A Novel Local Patch Framework for Fixing Supervised Learning Models Yilei Wang¹, Bingzheng Wei², Jun Yan², Yang Hu², Zhi-Hong Deng¹, Zheng Chen² ¹Peking University ²Microsoft Research Asia
Outline • Motivation & Background • Problem Definition & Algorithm Overview • Algorithm Details • Experiments - Classification • Experiments - Search Ranking • Conclusion
Motivation & Background • Supervised Learning: • The machine learning task of inferring a function from labeled training data • Prediction Error: • No matter how strong a learning model is, it will suffer from prediction errors • Causes: noise in the training data, dynamically changing data distributions, and weaknesses of the learner • Feedback from Users: • A good signal for a learning model to find its limitations and then improve accordingly
Learning to Fix Errors from Failure Cases • Automatically fix model prediction errors using the failure cases in feedback data • Input: • A well-trained supervised model (which we call the Mother Model) • A collection of failure cases in a feedback dataset • Output: • A fixed model, learned automatically from the failure cases • Previous Work • Model Retraining • Model Aggregation • Incremental Learning
Local Patching: from Global to Local • Learning models are generally optimized globally • As a result, fixing old prediction errors can introduce new ones • Our key idea: learn to fix the model locally using patches (Figure: a global fix introduces new errors, marked "New Error")
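Problem Definition • Our proposed Local Patch Framework (LPF) aims to learn a new model that combines the mother model with a set of local patches • Mother model: the original trained model • Patch model: a model that corrects predictions within one local region • Patch weight: a Gaussian distribution defined by a centroid and a range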
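To make the definition concrete, the sketch below assumes that each patch adds a Gaussian-weighted correction to the mother model's score, i.e. F(x) = f(x) + Σ_i K_i(x) g_i(x), where f is the mother model, g_i the i-th patch model, and K_i its Gaussian weight. The additive combination, the symbols, and the helper names `patched_predict` and `gaussian_weight` are illustrative assumptions, not the paper's confirmed formulation.

```python
import numpy as np

def gaussian_weight(x, centroid, sigma):
    """Weight in (0, 1]: 1 at the centroid, decaying with squared distance."""
    sq_dist = np.sum((x - centroid) ** 2, axis=-1)
    return np.exp(-sq_dist / (2.0 * sigma ** 2))

def patched_predict(x, mother, patches):
    """Assumed additive form F(x) = f(x) + sum_i K_i(x) * g_i(x).

    mother:  callable f(x) -> scores for the rows of x (the mother model)
    patches: list of (centroid, sigma, g_i) tuples, g_i being a patch model
    """
    score = mother(x)
    for centroid, sigma, g in patches:
        score = score + gaussian_weight(x, centroid, sigma) * g(x)
    return score
```

Far from every centroid the weights vanish and F(x) reduces to the mother model, which is the "local" property the framework is after.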
Algorithm Overview • Failure Case Collection • Learning Patch Regions / Failure Case Clustering • Cluster the failure cases into N groups through subspace learning, compute the centroid and range of each group, and then define the patches • Learning Patch Model • Learn a patch model using only the data samples that are sufficiently close to the patch centroid
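The first and last steps can be sketched as follows: failure cases are the feedback samples the mother model gets wrong, and a patch is trained only on samples near its centroid. The slides do not specify what "sufficiently close" means, so the `radius` threshold and the use of plain Euclidean distance in the input space are hypothetical choices; the framework itself may measure closeness in the learned subspace.

```python
import numpy as np

def collect_failure_cases(mother_predict, X_feedback, y_feedback):
    """Return the feedback samples that the mother model misclassifies."""
    wrong = mother_predict(X_feedback) != y_feedback
    return X_feedback[wrong], y_feedback[wrong]

def samples_near_patch(X, centroid, radius):
    """Indices of samples whose Euclidean distance to `centroid` is <= `radius`.

    `radius` is a hypothetical stand-in for "sufficiently close".
    """
    dist = np.linalg.norm(X - centroid, axis=1)
    return np.flatnonzero(dist <= radius)
```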
Learning Patch Region – Key Challenge • Failure cases may be diffusely distributed • Small N = large patch range → many success cases will be patched • Large N = small patch range → high computational complexity • How do we make this trade-off?
Solution: Clustered Metric Learning • Our solution to diffusion: metric learning • Learn a distance metric (i.e., a subspace) for the failure cases, such that similar failure cases aggregate and stay distant from the success cases. (Red circle = failure cases; blue circle = success cases) Key idea of patch model learning: • (Left): The cases in the original data space. • (Middle): The cases mapped to the learned subspace. • (Right): The failure cases repaired using a single patch.
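Metric Learning • A conditional distribution is defined over the cases • An ideal distribution is defined • The metric is learned so that the conditional distribution satisfies the ideal one • This is equivalent to maximizing the corresponding objective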
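These four steps match the structure of Maximally Collapsing Metric Learning (MCML, Globerson and Roweis, 2005): a softmax conditional distribution p(j|i) over neighbors in a projected space, an ideal distribution p0(j|i) that is uniform over same-label cases and zero elsewhere, and KL matching, which amounts to maximizing Σ_ij p0(j|i) log p(j|i). The sketch below implements that MCML-style objective with a linear projection `A`; it is an assumption that LPF uses exactly this form, and all function names are hypothetical. `labels` could hold, for example, the current failure-case group ids with success cases in their own class.

```python
import numpy as np

def conditional_probs(A, X):
    """p[i, j] proportional to exp(-||A x_i - A x_j||^2) over j != i; rows sum to 1."""
    Z = X @ A.T
    logits = -np.sum((Z[:, None, :] - Z[None, :, :]) ** 2, axis=-1)
    np.fill_diagonal(logits, -np.inf)              # a case is not its own neighbor
    logits -= logits.max(axis=1, keepdims=True)    # numerical stability
    p = np.exp(logits)
    return p / p.sum(axis=1, keepdims=True)

def ideal_probs(labels):
    """p0[i, j] uniform over the other cases sharing i's label, 0 elsewhere."""
    same = (labels[:, None] == labels[None, :]).astype(float)
    np.fill_diagonal(same, 0.0)
    return same / np.maximum(same.sum(axis=1, keepdims=True), 1.0)

def objective(A, X, p0):
    """sum_ij p0_ij log p_ij; maximizing it minimizes sum_i KL(p0_i || p_i)."""
    p = conditional_probs(A, X)
    mask = p0 > 0
    return float(np.sum(p0[mask] * np.log(np.maximum(p[mask], 1e-300))))

def gradient(A, X, p0):
    """d objective / dA = 2 A sum_ij (r_i p_ij - p0_ij) (x_i - x_j)(x_i - x_j)^T."""
    p = conditional_probs(A, X)
    r = p0.sum(axis=1, keepdims=True)              # 1 unless i has no same-label peer
    W = r * p - p0
    diffs = X[:, None, :] - X[None, :, :]
    M = np.einsum("ij,ijk,ijl->kl", W, diffs, diffs)
    return 2.0 * A @ M

def learn_metric(X, labels, dim, n_steps=30, lr=0.1, seed=0):
    """Gradient ascent on the objective, with backtracking on the step size."""
    rng = np.random.default_rng(seed)
    A = rng.normal(scale=0.1, size=(dim, X.shape[1]))
    p0 = ideal_probs(labels)
    f = objective(A, X, p0)
    for _ in range(n_steps):
        g = gradient(A, X, p0)
        step = lr
        while step > 1e-8:                         # accept only improving steps
            A_new = A + step * g
            f_new = objective(A_new, X, p0)
            if f_new > f:
                A, f = A_new, f_new
                break
            step /= 2.0
    return A
```

Backtracking keeps the objective from decreasing, which avoids hand-tuning a learning rate for this small illustration.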
Clustered Metric Learning • Algorithm: • 1. Assign each failure case to a random group • 2. Repeat the following steps: • a) For the current clusters, run the metric learning step • b) Update the group centroids, and re-assign each failure case to its closest centroid • Local Patch Region: • For each cluster i, we define a corresponding patch, using the cluster's centroid as the patch centroid and the cluster's variance as the patch variance • Gaussian weight: each patch weights a sample by a Gaussian defined by the patch centroid and variance
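The alternating procedure and the resulting patch regions can be sketched as below. The metric-learning step is passed in as a hypothetical callable `metric_step(F, groups)` returning a projection matrix; defining the patch variance as the mean squared distance to the centroid (an isotropic Gaussian) is likewise an assumption, since the slides do not give the exact formulas.

```python
import numpy as np

def clustered_metric_learning(F, n_groups, metric_step, n_iters=10, seed=0):
    """Alternate a metric-learning step with centroid updates on failure cases F.

    metric_step(F, groups) -> A is a caller-supplied (hypothetical) learner that
    returns a projection matrix A; distances are measured between rows of F @ A.T.
    """
    rng = np.random.default_rng(seed)
    groups = rng.integers(n_groups, size=len(F))   # 1. random initial groups
    A = np.eye(F.shape[1])
    for _ in range(n_iters):                       # 2. repeat
        A = metric_step(F, groups)                 # a) metric learning step
        Z = F @ A.T
        centroids = np.array([                     # b) update centroids
            Z[groups == g].mean(axis=0) if np.any(groups == g)
            else Z[rng.integers(len(Z))]           # re-seed an empty group
            for g in range(n_groups)
        ])
        sq = np.sum((Z[:, None, :] - centroids[None, :, :]) ** 2, axis=-1)
        new_groups = sq.argmin(axis=1)             # ... and re-assign
        if np.array_equal(new_groups, groups):
            break
        groups = new_groups
    return A, groups

def patch_regions(F, A, groups, n_groups):
    """One (centroid, variance) pair per non-empty group, in the projected space."""
    Z = F @ A.T
    regions = []
    for g in range(n_groups):
        members = Z[groups == g]
        if len(members) == 0:
            continue
        c = members.mean(axis=0)
        # Assumption: isotropic variance = mean squared distance to the centroid.
        var = max(float(np.mean(np.sum((members - c) ** 2, axis=1))), 1e-12)
        regions.append((c, var))
    return regions

def gaussian_weight(X, A, centroid, var):
    """exp(-||A x - c||^2 / (2 var)): 1 at the centroid, decaying with distance."""
    Z = X @ A.T
    return np.exp(-np.sum((Z - centroid) ** 2, axis=-1) / (2.0 * var))
```

Passing a step that always returns the identity matrix reduces the loop to plain k-means; a real metric learner, for example one that treats the current groups as labels, would take its place.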
Learning Patch Model • Objective: defined over the patch model's parameters and the sample labels • Parameter update: computed from the gradient of the objective with respect to the parameters • Note: the gradient depends on the specific patch model
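Because the update depends on the chosen patch model, here is one concrete instance under stated assumptions: a linear patch g(x) = x·w + b added to the mother model's score, trained by gradient descent on a logistic loss in which each sample is weighted by the patch's Gaussian weight K_j. The linear patch, the logistic loss, the {-1, +1} labels, and the function names are all illustrative choices, not the paper's specification.

```python
import numpy as np

def train_linear_patch(X, y, mother_scores, weights, lr=0.1, n_steps=200):
    """Fit a linear patch g(x) = x @ w + b on top of fixed mother-model scores.

    Minimizes sum_j K_j * log(1 + exp(-y_j * (f(x_j) + g(x_j)))), scaled by
    1 / sum_j K_j, where y_j is in {-1, +1} and K_j is the patch's Gaussian weight.
    """
    w = np.zeros(X.shape[1])
    b = 0.0
    total = weights.sum()
    for _ in range(n_steps):
        margin = y * (mother_scores + X @ w + b)
        sig = np.exp(-np.logaddexp(0.0, margin))   # = 1 / (1 + exp(margin)), overflow-safe
        coef = -weights * y * sig                  # K_j * d loss_j / d score_j
        w -= lr * (X.T @ coef) / total
        b -= lr * coef.sum() / total
    return w, b

def weighted_logistic_loss(X, y, mother_scores, weights, w, b):
    """Gaussian-weighted logistic loss of the patched scores f(x) + g(x)."""
    margin = y * (mother_scores + X @ w + b)
    return float(np.sum(weights * np.logaddexp(0.0, -margin)))
```

Samples far from the centroid receive weights near zero, so they barely influence the patch, which mirrors training "using only the data samples sufficiently close to the patch centroid".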
Experiments - Classification • Dataset • Three randomly selected UCI datasets • Spambase, Waveform, Optical Digit Recognition • Converted to binary classification datasets • ~5,000 instances in each dataset • Split: 60% training, 20% feedback, 20% test • Baseline Algorithms • SVM • Logistic Regression • SVM - retrained with training + feedback data • Logistic Regression - retrained with training + feedback data • SVM - Incremental Learning • Logistic Regression - Incremental Learning
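The 60/20/20 protocol can be reproduced with a random index split such as the one below; the function name and the fixed seed are illustrative assumptions. Under this protocol the mother model would be trained on the training part, failure cases would come from the feedback part, and accuracy would be reported on both the feedback and test parts, while the "retrained" baselines use training + feedback data.

```python
import numpy as np

def split_indices(n, fractions=(0.6, 0.2, 0.2), seed=0):
    """Randomly partition range(n) into training, feedback, and test indices."""
    rng = np.random.default_rng(seed)
    perm = rng.permutation(n)
    n_train = int(round(fractions[0] * n))
    n_feedback = int(round(fractions[1] * n))
    train = perm[:n_train]
    feedback = perm[n_train:n_train + n_feedback]
    test = perm[n_train + n_feedback:]   # the remainder, so no sample is dropped
    return train, feedback, test
```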
Classification Accuracy • Classification accuracy on feedback dataset • Classification accuracy on test dataset
Parameter Tuning • Number of Patches • Data-sensitive; in our experiments the best N was 2
Experiments - Search Ranking • Dataset • Data from a commonly used commercial search engine • ~14,126 <q, d> pairs • 5-grade relevance labels • Metrics • NDCG@K, K ∈ {1, 3, 5} • Baseline Algorithms • GBDT • GBDT + IL
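For reference, NDCG@K for one query can be computed as below, assuming the widely used gain 2^grade - 1 and log2(rank + 1) discount with grades 0 to 4 for the five labels; the slides do not state which NDCG variant was used, so this formulation is an assumption. The reported metric would be the mean over queries.

```python
import numpy as np

def dcg_at_k(grades, k):
    """DCG@k with gain 2^g - 1 and discount log2(rank + 1), ranks starting at 1."""
    g = np.asarray(grades, dtype=float)[:k]
    discounts = np.log2(np.arange(2, g.size + 2))
    return float(np.sum((2.0 ** g - 1.0) / discounts))

def ndcg_at_k(grades_in_ranked_order, k):
    """DCG@k of the model's ranking divided by DCG@k of the ideal ranking."""
    ideal = dcg_at_k(sorted(grades_in_ranked_order, reverse=True), k)
    if ideal == 0.0:
        return 0.0                       # no relevant documents for this query
    return dcg_at_k(grades_in_ranked_order, k) / ideal
```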
Conclusion • We proposed • The local model fixing problem • A novel patch framework for fixing the failure cases in the feedback dataset from a local view • The experimental results demonstrate the effectiveness of our proposed Local Patch Framework