580.704 Mathematical Foundations of BME, Reza Shadmehr. The loss function, the normal equation, cross validation, LMS algorithm, steepest descent algorithm
Review of the linear classification problem
• Hypothesis class: we assume that the function we are about to approximate belongs to some space of functions F. We don't know the true function f, but we hypothesize that it too belongs to F. Hypothesis: $F = \{\, f(\mathbf{x};\mathbf{w}) = \mathrm{sign}(\mathbf{w}^T\mathbf{x}) \,\}$
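• Estimation: we are given a training set of examples and labels and, using some adaptation algorithm, we find an estimate $\hat{\mathbf{w}}$ of the weights.
[Figure axes: trial number, pixel number]
• Whenever our estimate was wrong, change the weight of each "expert": $\mathbf{w} \leftarrow \mathbf{w} + y^{(i)}\mathbf{x}^{(i)}$ whenever $y^{(i)} \neq \mathrm{sign}(\mathbf{w}^T\mathbf{x}^{(i)})$.
• Evaluation: we measure how well our estimate generalizes to novel examples.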
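The mistake-driven estimation step can be sketched as follows. This is a minimal sketch assuming the classic perceptron rule (add $y\mathbf{x}$ to the weights on every mistake), labels in {-1, +1}, and a constant 1 feature acting as a bias; the function names and the toy "pixel" data are hypothetical.

```python
def predict(w, x):
    """Weighted vote of the 'experts' (pixels): +1 if w.x > 0, else -1."""
    score = sum(wi * xi for wi, xi in zip(w, x))
    return 1 if score > 0 else -1


def perceptron(examples, labels, max_epochs=20):
    """Assumed perceptron rule: on each mistake, w <- w + y * x."""
    w = [0.0] * len(examples[0])
    for _ in range(max_epochs):
        mistakes = 0
        for x, y in zip(examples, labels):
            if predict(w, x) != y:  # estimate was wrong: shift every expert's weight
                w = [wi + y * xi for wi, xi in zip(w, x)]
                mistakes += 1
        if mistakes == 0:  # a full pass with no errors: stop
            break
    return w


# Hypothetical toy data: [bias, pixel 1, pixel 2]; label +1 when pixel 1 is brighter.
examples = [[1, 2, 1], [1, 3, 0], [1, 0, 2], [1, 1, 3]]
labels = [1, 1, -1, -1]
w = perceptron(examples, labels)
print(w)  # [0.0, 2.0, -1.0]
```

On this separable toy set the loop stops after the second pass, once every training example is labeled correctly.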
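Loss function
• The loss function provides a cost for being wrong. The objective of adaptation is to minimize the loss function. In the case of the image labeling problem, we might use the 0-1 loss: $\mathrm{Loss}(y, \hat{y}) = 0$ if $y = \hat{y}$, and $1$ otherwise.
• We might want to minimize the loss over the training set. Empirical loss function: $J_n(\mathbf{w}) = \frac{1}{n}\sum_{i=1}^{n}\mathrm{Loss}\big(y^{(i)}, f(\mathbf{x}^{(i)};\mathbf{w})\big)$
• This is a function of the parameters $\mathbf{w}$, and we can minimize it directly: find the $\mathbf{w}$ that minimizes $J_n(\mathbf{w})$.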
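The empirical loss is simply the average cost over the training pairs, evaluated for a given weight vector. A sketch, assuming a 0-1 loss and the same kind of sign-based linear classifier; the helper names and toy data are hypothetical.

```python
def predict(w, x):
    """Linear classifier: +1 if w.x > 0, else -1."""
    return 1 if sum(wi * xi for wi, xi in zip(w, x)) > 0 else -1


def zero_one_loss(y, y_hat):
    """Assumed cost of being wrong: 1 for a wrong label, 0 for a correct one."""
    return 0 if y == y_hat else 1


def empirical_loss(w, examples, labels):
    """Average loss over the training set; a function of the parameters w only."""
    total = sum(zero_one_loss(y, predict(w, x)) for x, y in zip(examples, labels))
    return total / len(examples)


# Hypothetical training set: [bias, pixel 1, pixel 2] with labels in {-1, +1}.
examples = [[1, 2, 1], [1, 3, 0], [1, 0, 2], [1, 1, 3]]
labels = [1, 1, -1, -1]
print(empirical_loss([0, 1, -1], examples, labels))  # 0.0
print(empirical_loss([0, 0, 1], examples, labels))   # 0.75
```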
Issues about minimization of the loss function
• Why should we minimize the loss over the training set when we are actually interested in minimizing the loss over the test set?
• We assume that each training and test example-label pair $(\mathbf{x}, y)$ is drawn independently and at random from the same, unknown population of examples and labels.
• We represent this population as a joint probability distribution $p(\mathbf{x}, y)$, so that each training/test example is a sample from this distribution: $(\mathbf{x}^{(i)}, y^{(i)}) \sim p(\mathbf{x}, y)$.
• The training loss, based on a few sampled examples and labels, serves as a proxy for the test performance measured over the whole population:
Empirical loss (training set only): $J_n(\mathbf{w}) = \frac{1}{n}\sum_{i=1}^{n}\mathrm{Loss}\big(y^{(i)}, f(\mathbf{x}^{(i)};\mathbf{w})\big)$
Expected loss over all the data (training and test together): $J(\mathbf{w}) = E_{(\mathbf{x},y)\sim p(\mathbf{x},y)}\big[\mathrm{Loss}\big(y, f(\mathbf{x};\mathbf{w})\big)\big]$
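To see why the training loss can stand in for the population loss, one can sample from a distribution whose expected loss is known. A sketch under loudly stated assumptions: the population $p(x, y)$ here is hypothetical ($x$ uniform on [0, 1], $y = 2x$ plus Gaussian noise with standard deviation 0.5) and the loss is squared error, so for the fixed weights $w_0 = 0$, $w_1 = 2$ the expected loss is exactly $0.5^2 = 0.25$.

```python
import random


def sample_pair(rng, sigma=0.5):
    """Draw one (x, y) from a hypothetical population p(x, y):
    x ~ Uniform(0, 1) and y = 2x + Gaussian noise with standard deviation sigma."""
    x = rng.uniform(0.0, 1.0)
    y = 2.0 * x + rng.gauss(0.0, sigma)
    return x, y


def empirical_loss(w0, w1, pairs):
    """Mean squared error of the model w0 + w1 * x over the sampled pairs."""
    return sum((y - (w0 + w1 * x)) ** 2 for x, y in pairs) / len(pairs)


rng = random.Random(0)
# For w0 = 0, w1 = 2 the expected loss over the whole population is sigma**2 = 0.25.
for n in (10, 100, 10_000):
    pairs = [sample_pair(rng) for _ in range(n)]
    print(n, round(empirical_loss(0.0, 2.0, pairs), 4))
```

As n grows, the printed empirical loss settles near 0.25; with only a few samples it can be noticeably off, which is the gap between the two quantities the slide contrasts.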
Regression
• The goal is to make quantitative (real-valued) predictions on the basis of a vector of features or attributes.
• Example: years to onset of Huntington's disease (HD) in genetically at-risk individuals.

Current age   CAG repeats   HD onset   Mother HD?   Father HD?
43            37            55         1            0
51            33            37         0            1
49            39            43         1            1

• We need to:
– specify the hypothesized class of functions (e.g., linear);
– select how to measure prediction loss (the loss function);
– solve the resulting minimization problem.
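Regression: hypothesized class and the loss function
• Univariate regression: $f(x;\mathbf{w}) = w_0 + w_1x$
• Multivariate regression: $f(\mathbf{x};\mathbf{w}) = w_0 + w_1x_1 + \cdots + w_dx_d = \mathbf{w}^T\mathbf{x}$, where $\mathbf{x} = [1, x_1, \ldots, x_d]^T$
• Parameters we need to find: $\mathbf{w} = [w_0, w_1, \ldots, w_d]^T$
• Loss function (squared error): $\mathrm{Loss}(y, \hat{y}) = (y - \hat{y})^2$
• Empirical loss (mean squared error): $J_n(\mathbf{w}) = \frac{1}{n}\sum_{i=1}^{n}\big(y^{(i)} - f(\mathbf{x}^{(i)};\mathbf{w})\big)^2$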
Regression: estimation of parameters
• We have to minimize the empirical loss function $J_n(\mathbf{w})$.
• This function is quadratic in $\mathbf{w}$, so it has a minimum at some point in $\mathbf{w}$ space. We find the $\mathbf{w}$ that minimizes the loss function by finding the conditions under which the derivative of the loss function with respect to $\mathbf{w}$ is zero.
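Chain rule reminder: $\frac{d}{dw}\,g\big(h(w)\big) = g'\big(h(w)\big)\,h'(w)$
Optimality conditions: finding the conditions that minimize the empirical loss function (univariate case, $J_n(\mathbf{w}) = \frac{1}{n}\sum_{i=1}^{n}\big(y^{(i)} - w_0 - w_1x^{(i)}\big)^2$):
$\frac{\partial J_n}{\partial w_0} = -\frac{2}{n}\sum_{i=1}^{n}\big(y^{(i)} - w_0 - w_1x^{(i)}\big) = 0$
$\frac{\partial J_n}{\partial w_1} = -\frac{2}{n}\sum_{i=1}^{n}\big(y^{(i)} - w_0 - w_1x^{(i)}\big)\,x^{(i)} = 0$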
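Solving the two univariate optimality conditions gives the familiar closed form. A short worked derivation, assuming the mean-squared-error loss $J_n(\mathbf{w}) = \frac{1}{n}\sum_{i}\big(y^{(i)} - w_0 - w_1x^{(i)}\big)^2$ and writing $\bar{x}$, $\bar{y}$ for the sample means:

```latex
\frac{\partial J_n}{\partial w_0} = 0
  \;\Rightarrow\; \sum_{i=1}^{n}\big(y^{(i)} - w_0 - w_1 x^{(i)}\big) = 0
  \;\Rightarrow\; w_0 = \bar{y} - w_1\bar{x}

\frac{\partial J_n}{\partial w_1} = 0
  \;\Rightarrow\; \sum_{i=1}^{n}\big(y^{(i)} - \bar{y} - w_1 (x^{(i)} - \bar{x})\big)\,x^{(i)} = 0
  \;\Rightarrow\; w_1 = \frac{\sum_{i=1}^{n}(x^{(i)} - \bar{x})(y^{(i)} - \bar{y})}{\sum_{i=1}^{n}(x^{(i)} - \bar{x})^2}
```

The last step uses $\sum_i (y^{(i)} - \bar{y})\,\bar{x} = 0$ and $\sum_i (x^{(i)} - \bar{x})\,\bar{x} = 0$, which allow $x^{(i)}$ to be replaced by $x^{(i)} - \bar{x}$ in both sums.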
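Expected behavior of the model errors (residuals)
• Error in the prediction (model residual): $\epsilon^{(i)} = y^{(i)} - f(\mathbf{x}^{(i)};\hat{\mathbf{w}})$
• At the optimum, the prediction error should have zero mean and contain no linear trends, i.e., be uncorrelated with any linear function of the inputs: $\frac{1}{n}\sum_{i}\epsilon^{(i)} = 0$ and $\frac{1}{n}\sum_{i}\epsilon^{(i)}x^{(i)} = 0$. But there may exist some non-linear function of the inputs that can account for the residuals.
[Figure panels with axes labeled y and x]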
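The residual properties can be checked numerically. A sketch, assuming hypothetical data with a purely quadratic relationship $y = x^2$: the fitted line leaves residuals with zero mean and no linear trend, yet the residuals are exactly $x^2 - 2$, a non-linear function of the input.

```python
def fit_univariate(xs, ys):
    """Least-squares line y = w0 + w1 * x from the closed-form optimality conditions."""
    n = len(xs)
    x_bar = sum(xs) / n
    y_bar = sum(ys) / n
    sxy = sum((x - x_bar) * (y - y_bar) for x, y in zip(xs, ys))
    sxx = sum((x - x_bar) ** 2 for x in xs)
    w1 = sxy / sxx
    w0 = y_bar - w1 * x_bar
    return w0, w1


# Hypothetical data with a purely quadratic relationship y = x^2.
xs = [-2.0, -1.0, 0.0, 1.0, 2.0]
ys = [4.0, 1.0, 0.0, 1.0, 4.0]
w0, w1 = fit_univariate(xs, ys)
residuals = [y - (w0 + w1 * x) for x, y in zip(xs, ys)]

mean_residual = sum(residuals) / len(residuals)                   # zero mean
residual_x = sum(r * x for r, x in zip(residuals, xs)) / len(xs)  # no linear trend
print(w0, w1)     # 2.0 0.0
print(residuals)  # [2.0, -1.0, -2.0, -1.0, 2.0]
print(mean_residual, residual_x)
```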
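Loss function: matrix notation
$J_n(\mathbf{w}) = \frac{1}{n}\|\mathbf{y} - X\mathbf{w}\|^2 = \frac{1}{n}(\mathbf{y} - X\mathbf{w})^T(\mathbf{y} - X\mathbf{w})$, where $\|\cdot\|$ is the "L2" norm, $\mathbf{y} = [y^{(1)}, \ldots, y^{(n)}]^T$, and the $i$-th row of $X$ is $\mathbf{x}^{(i)T}$.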
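Optimality condition: minimize the mean squared error $J_n(\mathbf{w}) = \frac{1}{n}\|\mathbf{y} - X\mathbf{w}\|^2$
$\frac{\partial J_n}{\partial \mathbf{w}} = -\frac{2}{n}X^T(\mathbf{y} - X\mathbf{w}) = 0 \;\Rightarrow\; X^TX\mathbf{w} = X^T\mathbf{y}$, the "normal" equation
$\hat{\mathbf{w}} = (X^TX)^{-1}X^T\mathbf{y}$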
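Solving the normal equation only requires forming $X^TX$ and $X^T\mathbf{y}$ and solving a small linear system. A dependency-free sketch, assuming each row of $X$ already begins with the constant 1; the Gaussian-elimination helper and the data (generated exactly by $y = 1 + 2x_1 - 3x_2$) are hypothetical illustrations, not a numerically robust library routine.

```python
def solve(a, b):
    """Solve the square system a w = b by Gaussian elimination with partial pivoting."""
    n = len(b)
    m = [row[:] + [bi] for row, bi in zip(a, b)]  # augmented matrix [a | b]
    for k in range(n):
        p = max(range(k, n), key=lambda r: abs(m[r][k]))  # pivot row
        m[k], m[p] = m[p], m[k]
        for r in range(k + 1, n):
            f = m[r][k] / m[k][k]
            for c in range(k, n + 1):
                m[r][c] -= f * m[k][c]
    w = [0.0] * n
    for r in reversed(range(n)):  # back substitution
        w[r] = (m[r][n] - sum(m[r][c] * w[c] for c in range(r + 1, n))) / m[r][r]
    return w


def normal_equation(X, y):
    """Least-squares weights from X^T X w = X^T y; each row of X starts with a 1."""
    d = len(X[0])
    xtx = [[sum(row[i] * row[j] for row in X) for j in range(d)] for i in range(d)]
    xty = [sum(row[i] * yi for row, yi in zip(X, y)) for i in range(d)]
    return solve(xtx, xty)


# Hypothetical data generated exactly by y = 1 + 2*x1 - 3*x2.
X = [[1, 0, 0], [1, 1, 0], [1, 0, 1], [1, 1, 1], [1, 2, 1]]
y = [1, 3, -2, 0, 2]
w_hat = normal_equation(X, y)
print([round(v, 6) for v in w_hat])  # [1.0, 2.0, -3.0]
```

Because the toy data are noise-free and $X$ has full column rank, the normal equation recovers the generating weights exactly (up to rounding).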
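Review of regression
• Univariate regression: $f(x;\mathbf{w}) = w_0 + w_1x$
• Multivariate regression: $f(\mathbf{x};\mathbf{w}) = \mathbf{w}^T\mathbf{x}$
• Parameters we need to find: $\mathbf{w}$
• Loss function: $\big(y - f(\mathbf{x};\mathbf{w})\big)^2$
• Empirical loss: $J_n(\mathbf{w}) = \frac{1}{n}\sum_{i=1}^{n}\big(y^{(i)} - f(\mathbf{x}^{(i)};\mathbf{w})\big)^2$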
Regression with polynomials
• Univariate regression with $m$-th order polynomials: $f(x;\mathbf{w}) = w_0 + w_1x + w_2x^2 + \cdots + w_mx^m$
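Because the polynomial model is still linear in the parameters $\mathbf{w}$, it fits the same least-squares machinery. A sketch, assuming the matrix form $J_n(\mathbf{w}) = \frac{1}{n}\|\mathbf{y} - X\mathbf{w}\|^2$: each row of the design matrix holds the powers of one scalar input, and the normal-equation solution applies unchanged.

```latex
X = \begin{bmatrix}
1 & x^{(1)} & (x^{(1)})^2 & \cdots & (x^{(1)})^m \\
\vdots & \vdots & \vdots & & \vdots \\
1 & x^{(n)} & (x^{(n)})^2 & \cdots & (x^{(n)})^m
\end{bmatrix},
\qquad
\hat{\mathbf{w}} = (X^T X)^{-1} X^T \mathbf{y}
```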
Regression with polynomials: the fit to the training data improves with increased order
Over-fitting
• We want to fit the training set, but as model complexity increases, we run the risk of over-fitting.
[Figure panels: train set; one point left out]
• When the model is over-fitting, leaving a single data point out of the training set can drastically change the fit.
Cross validation
• We want to fit the training set, but we also want to generalize correctly. To measure generalization, we leave out one data point (the test point), fit the remaining data, and then measure the error on the test point. The average error over all possible test points is the cross-validation error:
$CV = \frac{1}{n}\sum_{i=1}^{n}\big(y^{(i)} - f(\mathbf{x}^{(i)};\hat{\mathbf{w}}^{(-i)})\big)^2$
where $\hat{\mathbf{w}}^{(-i)}$ are the weights estimated from a training set that does not include the $i$-th data point.
Cross validation
• Cross-validation error will often increase when the model structure is over-fitting the data.
[Figure: mean-squared error (training set) and cross-validation error plotted against model order; the actual data were generated with a 2nd-order polynomial process]
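Training error and leave-one-out cross-validation error can be compared side by side as the polynomial order grows. A sketch under stated assumptions: the data are hypothetical, drawn from the 2nd-order process $y = 1 + 0.5x - 2x^2$ plus Gaussian noise (standard deviation 0.1), each fit uses the normal equation on the features $1, x, \ldots, x^m$, and orders are kept small so the systems stay well conditioned.

```python
import random


def solve(a, b):
    """Gaussian elimination with partial pivoting for a small square system a w = b."""
    n = len(b)
    m = [row[:] + [v] for row, v in zip(a, b)]
    for k in range(n):
        p = max(range(k, n), key=lambda r: abs(m[r][k]))
        m[k], m[p] = m[p], m[k]
        for r in range(k + 1, n):
            f = m[r][k] / m[k][k]
            for c in range(k, n + 1):
                m[r][c] -= f * m[k][c]
    w = [0.0] * n
    for r in reversed(range(n)):
        w[r] = (m[r][n] - sum(m[r][c] * w[c] for c in range(r + 1, n))) / m[r][r]
    return w


def fit_poly(xs, ys, order):
    """Least-squares polynomial fit via the normal equation on features 1, x, ..., x^order."""
    feats = [[x ** j for j in range(order + 1)] for x in xs]
    d = order + 1
    xtx = [[sum(f[i] * f[j] for f in feats) for j in range(d)] for i in range(d)]
    xty = [sum(f[i] * y for f, y in zip(feats, ys)) for i in range(d)]
    return solve(xtx, xty)


def poly_eval(w, x):
    return sum(wj * x ** j for j, wj in enumerate(w))


def train_mse(xs, ys, order):
    """Mean squared error on the same data used for fitting."""
    w = fit_poly(xs, ys, order)
    return sum((y - poly_eval(w, x)) ** 2 for x, y in zip(xs, ys)) / len(xs)


def loocv_error(xs, ys, order):
    """Average squared error on each held-out point, fitting on the other n - 1 points."""
    errs = []
    for i in range(len(xs)):
        w_minus_i = fit_poly(xs[:i] + xs[i + 1:], ys[:i] + ys[i + 1:], order)
        errs.append((ys[i] - poly_eval(w_minus_i, xs[i])) ** 2)
    return sum(errs) / len(errs)


rng = random.Random(0)
xs = [-1.0 + 0.25 * k for k in range(9)]
ys = [1.0 + 0.5 * x - 2.0 * x ** 2 + rng.gauss(0.0, 0.1) for x in xs]  # hypothetical data
for order in range(6):
    print(order, round(train_mse(xs, ys, order), 5), round(loocv_error(xs, ys, order), 5))
```

The training column can only shrink as the order grows (each model contains the previous one), whereas the cross-validation column is the one to watch for the onset of over-fitting.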
Batch vs. online learning algorithms
• In "batch" learning, we don't have to make any predictions until we have seen all of the data. At that point, we build a model to fit all the data.
• In "online" learning, data points are given to us one at a time, and we use each example pair to update our model:
– we are given an x and, with our current model, we predict a y;
– the teacher tells us our error;
– we modify our model.
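Online learning: the LMS algorithm
• Assume we have the model $\hat{y} = \mathbf{w}^T\mathbf{x}$.
• When we project $\mathbf{w}$ onto $\mathbf{x}$, we get a scalar proportional to our prediction $\mathbf{w}^T\mathbf{x}$.
• What we want is to change $\mathbf{w}$ so that, when we project it onto $\mathbf{x}$, we get the target: $\mathbf{w}^T\mathbf{x} = y$.
• Anywhere along the dashed line, i.e., the set of all $\mathbf{w}$ with $\mathbf{w}^T\mathbf{x} = y$, is a solution we're looking for.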
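The LMS algorithm
$\mathbf{w}^{(t+1)} = \mathbf{w}^{(t)} + \eta\,\big(y^{(t)} - \mathbf{w}^{(t)T}\mathbf{x}^{(t)}\big)\,\mathbf{x}^{(t)}$, where $\eta$ is the "step size"
• $\mathbf{w}$ changes along a vector parallel to the input $\mathbf{x}$ in that trial (the unit vector $\mathbf{x}/\|\mathbf{x}\|$), with a magnitude proportional to the prediction error in that trial.
• With the step size $\eta = 1/\|\mathbf{x}^{(t)}\|^2$, we change $\mathbf{w}$ to completely account for the error in that trial: the new prediction $\mathbf{w}^{(t+1)T}\mathbf{x}^{(t)}$ equals $y^{(t)}$.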
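The online protocol (predict, receive the error, modify the model) maps directly onto the LMS update. A sketch, assuming plain Python lists as vectors; `full_correction_step` is a hypothetical helper that uses the step size $\eta = 1/\|\mathbf{x}\|^2$, and the data stream (generated by $y = 3x_1 - x_2$) is made up for illustration.

```python
def dot(a, b):
    return sum(ai * bi for ai, bi in zip(a, b))


def lms_step(w, x, y, eta):
    """One LMS update: move w along x by eta times the prediction error."""
    error = y - dot(w, x)  # the teacher's feedback on this trial
    return [wi + eta * error * xi for wi, xi in zip(w, x)]


def full_correction_step(w, x, y):
    """LMS step with eta = 1 / ||x||^2, which removes the whole error on this trial."""
    return lms_step(w, x, y, 1.0 / dot(x, x))


def lms_online(stream, w, eta):
    """Online learning: predict, receive the error, modify the model, one pair at a time."""
    for x, y in stream:
        w = lms_step(w, x, y, eta)
    return w


w = full_correction_step([0.0, 0.0], [1.0, 2.0], 5.0)
print(w, dot(w, [1.0, 2.0]))  # the new prediction on this x equals the target 5.0

# Hypothetical consistent stream generated by y = 3*x1 - x2, presented repeatedly.
stream = [([1.0, 0.0], 3.0), ([0.0, 1.0], -1.0), ([1.0, 1.0], 2.0)] * 200
print(lms_online(stream, [0.0, 0.0], eta=0.1))
```

Since this stream has an exact linear solution, a small fixed step size is enough for the weights to approach $[3, -1]$.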
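The LMS algorithm attempts to minimize a squared-error loss function by approximating the gradient of the loss function.
• Steepest descent algorithm (average error over all data points): $\mathbf{w}^{(k+1)} = \mathbf{w}^{(k)} + \frac{2\eta}{n}\sum_{i=1}^{n}\big(y^{(i)} - \mathbf{w}^{(k)T}\mathbf{x}^{(i)}\big)\,\mathbf{x}^{(i)}$
• LMS: uses the local error on the current trial as a rough estimate of the average error (the factor 2 is absorbed into $\eta$): $\mathbf{w}^{(t+1)} = \mathbf{w}^{(t)} + \eta\,\big(y^{(t)} - \mathbf{w}^{(t)T}\mathbf{x}^{(t)}\big)\,\mathbf{x}^{(t)}$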
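The relationship between the two updates can be made concrete: the steepest-descent direction is the average, over all data points, of the quantity LMS computes from a single trial. A sketch, assuming hypothetical three-point data that no weight vector fits exactly, and leaving the constant factor 2 out of the step (absorbed into `eta`).

```python
def dot(a, b):
    return sum(ai * bi for ai, bi in zip(a, b))


def lms_direction(w, x, y):
    """Local estimate used by LMS: the error on one trial times that trial's input."""
    err = y - dot(w, x)
    return [err * xi for xi in x]


def steepest_descent_direction(w, X, y):
    """Average of the per-trial LMS quantities over all data points; proportional to
    the negative gradient of the mean squared error (factor 2 left out)."""
    dirs = [lms_direction(w, xi, yi) for xi, yi in zip(X, y)]
    return [sum(d[j] for d in dirs) / len(dirs) for j in range(len(w))]


def steepest_descent(X, y, eta, n_steps):
    w = [0.0] * len(X[0])
    for _ in range(n_steps):
        g = steepest_descent_direction(w, X, y)
        w = [wj + eta * gj for wj, gj in zip(w, g)]
    return w


# Hypothetical inconsistent data: no w fits all three points exactly.
X = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
y = [1.0, 1.0, 3.0]
print(steepest_descent_direction([0.0, 0.0], X, y))
print(steepest_descent(X, y, eta=1.0, n_steps=200))  # approaches [4/3, 4/3]
```

Here $X^TX = [[2, 1], [1, 2]]$ and $X^T\mathbf{y} = [4, 4]$, so the normal-equation solution is $[4/3, 4/3]$, which is where steepest descent ends up.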
Convergence of the LMS algorithm
[Figure panels: iterating over two data points; iterating over three data points; equilibrium point]
• With three data points, the solution will not move to a single point and stay put. It converges to a small region of the parameter space but keeps bouncing around, as long as $\eta > 0$.
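The difference between two and three data points shows up clearly in a small simulation. A sketch, assuming a fixed step size of 0.2, points presented in a fixed order, and hypothetical data: with two points an exact solution exists and LMS settles on it; with three inconsistent points the weights end up cycling through distinct values on every pass.

```python
def lms_pass_history(X, y, eta, n_passes):
    """Run LMS with a fixed step size, presenting the points in order;
    record w after every single update."""
    w = [0.0] * len(X[0])
    history = []
    for _ in range(n_passes):
        for xi, yi in zip(X, y):
            err = yi - sum(wj * xj for wj, xj in zip(w, xi))
            w = [wj + eta * err * xj for wj, xj in zip(w, xi)]
            history.append(w)
    return history


# Two data points: an exact solution w = [1, 1] exists, and LMS settles on it.
two = lms_pass_history([[1.0, 0.0], [0.0, 1.0]], [1.0, 1.0], eta=0.2, n_passes=200)
# Three data points with no exact solution: LMS keeps bouncing around near [4/3, 4/3].
X3, y3 = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]], [1.0, 1.0, 3.0]
three = lms_pass_history(X3, y3, eta=0.2, n_passes=200)
print(two[-2:])
print(three[-3:])
```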
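Convergence of the LMS algorithm
• It is difficult to prove "convergence" of LMS because the weights keep bouncing around. But we can prove convergence for the steepest descent algorithm and then use the fact that LMS is a stochastic approximation to it.
• Steepest descent on $J_n(\mathbf{w}) = \frac{1}{n}\|\mathbf{y} - X\mathbf{w}\|^2$: $\mathbf{w}^{(k+1)} = \mathbf{w}^{(k)} + \frac{2\eta}{n}X^T(\mathbf{y} - X\mathbf{w}^{(k)}) = A\mathbf{w}^{(k)} + \mathbf{b}$, with $A = I - \frac{2\eta}{n}X^TX$ and $\mathbf{b} = \frac{2\eta}{n}X^T\mathbf{y}$
• Unrolling the iteration gives $\mathbf{w}^{(k)} = A^k\mathbf{w}^{(0)} + \sum_{j=0}^{k-1}A^j\mathbf{b}$, a geometric series.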
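Convergence of a geometric series of matrices (see homework): for a square matrix $A$ whose eigenvalues all have magnitude less than 1, $A^k \to 0$ and $\sum_{j=0}^{\infty}A^j = (I - A)^{-1}$.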
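Convergence of the steepest descent algorithm (see homework for the proof): if $X^TX$ is invertible and $0 < \eta < n/\lambda_{\max}(X^TX)$, steepest descent on the mean squared error converges to $(X^TX)^{-1}X^T\mathbf{y}$, the solution of the normal equation.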
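The step-size condition can be seen in action on a tiny problem. A sketch, assuming the loss $J_n(\mathbf{w}) = \frac{1}{n}\|\mathbf{y} - X\mathbf{w}\|^2$, the update $\mathbf{w} \leftarrow \mathbf{w} + \frac{2\eta}{n}X^T(\mathbf{y} - X\mathbf{w})$, and hypothetical data for which $n = 3$ and the eigenvalues of $X^TX$ are 3 and 1, so the bound $\eta < n/\lambda_{\max}$ works out to $\eta < 1$.

```python
def steepest_descent(X, y, eta, n_steps):
    """w <- w + (2 * eta / n) * X^T (y - X w): gradient descent on (1/n)||y - Xw||^2."""
    n, d = len(X), len(X[0])
    w = [0.0] * d
    for _ in range(n_steps):
        err = [yi - sum(wj * xj for wj, xj in zip(w, xi)) for xi, yi in zip(X, y)]
        w = [w[j] + 2.0 * eta / n * sum(e * xi[j] for e, xi in zip(err, X)) for j in range(d)]
    return w


X = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # X^T X = [[2, 1], [1, 2]], eigenvalues 3 and 1
y = [1.0, 1.0, 3.0]                       # normal-equation solution: w = [4/3, 4/3]
w_star = [4.0 / 3.0, 4.0 / 3.0]

for eta in (0.9, 1.1):  # the condition 0 < eta < n / lambda_max gives eta < 3 / 3 = 1
    w = steepest_descent(X, y, eta, n_steps=100)
    dist = max(abs(a - b) for a, b in zip(w, w_star))
    print(eta, dist)
```

With $\eta = 0.9$ the iteration matrix has eigenvalues $-0.8$ and $0.4$, so the error shrinks; with $\eta = 1.1$ one eigenvalue becomes $-1.2$ and the iterates oscillate with growing amplitude.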
We have shown convergence of the steepest descent algorithm to the solution of the normal equation. LMS is a stochastic approximation to steepest descent, so it "converges" as well, but it jumps around stochastically as long as the learning rate is greater than zero. Convergence to a single point can be reached when the learning rate is systematically made smaller on each step. We will call such changes of the learning rate "adaptive learning" and will see a principled approach to this problem when we consider Bayesian approaches to learning.
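Shrinking the learning rate over time is what lets the bouncing die out. A sketch, assuming hypothetical three-point data with least-squares solution $[4/3, 4/3]$ and a hypothetical schedule $\eta_k = 1/(k+1)$ on pass $k$ (one of many schedules whose steps shrink to zero while their sum still diverges); it is compared with a fixed $\eta = 0.2$.

```python
def lms(X, y, step_size, n_passes):
    """LMS over the data in a fixed order; step_size(k) is the learning rate on pass k.
    Returns the weights recorded after each update of the final pass."""
    w = [0.0] * len(X[0])
    last_pass = []
    for k in range(1, n_passes + 1):
        eta = step_size(k)
        last_pass = []
        for xi, yi in zip(X, y):
            err = yi - sum(wj * xj for wj, xj in zip(w, xi))
            w = [wj + eta * err * xj for wj, xj in zip(w, xi)]
            last_pass.append(w)
    return last_pass


X = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
y = [1.0, 1.0, 3.0]  # least-squares (normal-equation) solution: [4/3, 4/3]

fixed = lms(X, y, lambda k: 0.2, n_passes=10_000)
decaying = lms(X, y, lambda k: 1.0 / (k + 1), n_passes=10_000)  # hypothetical schedule
print(fixed)     # keeps bouncing between distinct points
print(decaying)  # nearly identical points close to [4/3, 4/3]
```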
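Summary: Linear Regression
• Univariate regression: $f(x;\mathbf{w}) = w_0 + w_1x$
• Multivariate regression: $f(\mathbf{x};\mathbf{w}) = \mathbf{w}^T\mathbf{x}$
• Parameters we need to find: $\mathbf{w}$
• Loss function: $\big(y - f(\mathbf{x};\mathbf{w})\big)^2$
• Empirical loss: $J_n(\mathbf{w}) = \frac{1}{n}\sum_{i=1}^{n}\big(y^{(i)} - f(\mathbf{x}^{(i)};\mathbf{w})\big)^2$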
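Summary: Iterative learning
• Increased model complexity reduces error over the training data but can increase the leave-one-out cross-validation error. We want a model that fits the training data and generalizes correctly.
• LMS algorithm: $\mathbf{w}^{(t+1)} = \mathbf{w}^{(t)} + \eta\,\big(y^{(t)} - \mathbf{w}^{(t)T}\mathbf{x}^{(t)}\big)\,\mathbf{x}^{(t)}$
– $\mathbf{w}$ changes along a vector parallel to the input $\mathbf{x}$ in that trial, with a magnitude proportional to the error in that trial.
• Steepest descent algorithm: $\mathbf{w}^{(k+1)} = \mathbf{w}^{(k)} + \frac{2\eta}{n}X^T(\mathbf{y} - X\mathbf{w}^{(k)})$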