Learning Mixtures of Product Distributions
Ryan O’Donnell (IAS), Jon Feldman (Columbia University), Rocco Servedio (Columbia University)
Learning Distributions
• There is an unknown distribution P over $\mathbb{R}^n$, or maybe just over $\{0,1\}^n$.
• An algorithm gets access to random samples from P.
• In time polynomial in n/ε it should output a hypothesis distribution Q which (w.h.p.) is ε-close to P. [Technical details later.]
Learning Distributions
[Figure: a distribution over $\mathbb{R}$.]
Hopeless in general!
Learning Classes of Distributions
• Since this is hopeless in general, one assumes that P comes from a class of distributions C.
• We speak of whether C is polynomial-time learnable or not; this means that there is one algorithm that learns every P in C.
• Some easily learnable classes:
  • C = {Gaussians over $\mathbb{R}^n$}
  • C = {Product distributions over $\{0,1\}^n$}
Learning product distributions over $\{0,1\}^n$
E.g. n = 3. Samples: 010, 011, 011, 111, 010, 011, 010, 010, 111, 000
Hypothesis: [.2 .9 .5] (the fraction of samples with a 1 in each coordinate)
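The hypothesis here is just the vector of per-coordinate sample means. Below is a minimal Python sketch. It assumes the 30 bits on the slide are read in order as ten 3-bit samples, and the helper name `empirical_means` is ours.

```python
# The ten 3-bit samples from the slide, one string per sample.
SAMPLES = ["010", "011", "011", "111", "010", "011", "010", "010", "111", "000"]


def empirical_means(samples):
    """Estimate each coordinate's mean as the fraction of samples with a 1 there."""
    n = len(samples[0])
    counts = [0] * n
    for s in samples:
        for j, bit in enumerate(s):
            counts[j] += int(bit)
    return [c / len(samples) for c in counts]


print(empirical_means(SAMPLES))  # [0.2, 0.9, 0.5]
```

A product distribution over $\{0,1\}^n$ is determined by its n coordinate means, so estimating those means is the core of learning this class.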
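Mixtures of product distributions
Fix k ≥ 2 and let $\pi_1 + \pi_2 + \cdots + \pi_k = 1$. The π-mixture of distributions $P_1, \ldots, P_k$ is:
• Draw i according to the mixture weights $\pi_i$.
• Draw from $P_i$.
In the case of product distributions over $\{0,1\}^n$, each $P_i$ is given by its vector of means, where $\mu_i^j$ is the probability that coordinate j equals 1 under $P_i$:
$\pi_1$: $[\mu_1^1\ \mu_1^2\ \mu_1^3\ \cdots\ \mu_1^n]$
$\pi_2$: $[\mu_2^1\ \mu_2^2\ \mu_2^3\ \cdots\ \mu_2^n]$
⋮
$\pi_k$: $[\mu_k^1\ \mu_k^2\ \mu_k^3\ \cdots\ \mu_k^n]$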
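The two-step sampling process can be sketched in Python as follows. This is a minimal illustration, not code from the talk. The function `sample_mixture` and the layout of the means are our own choices: row i holds $\mu_i^1, \ldots, \mu_i^n$.

```python
import random


def sample_mixture(weights, means, rng):
    """Draw one point of {0,1}^n from the pi-mixture of product distributions.

    weights[i] is pi_i and means[i][j] is mu_i^j, the probability that
    coordinate j is 1 under component P_i.
    """
    # Step 1: draw the component index i according to the mixture weights.
    i = rng.choices(range(len(weights)), weights=weights)[0]
    # Step 2: draw every coordinate independently from P_i.
    return [1 if rng.random() < mu else 0 for mu in means[i]]


# Example: the 60% / 40% mixture over {0,1}^4 from these slides.
rng = random.Random(0)
weights = [0.6, 0.4]
means = [[0.8, 0.8, 0.6, 0.2], [0.2, 0.4, 0.3, 0.8]]
print([sample_mixture(weights, means, rng) for _ in range(3)])
```

The learner sees only the returned bits. The component index i stays hidden, and this is what makes a mixture harder to learn than a single product distribution.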
Learning mixture example
E.g. n = 4. Samples: 1100, 0001, 0101, 0110, 0001, 1110, 0101, 0011, 1110, 1010
True distribution:
60%: [.8 .8 .6 .2]
40%: [.2 .4 .3 .8]
Prior work
• [KMRRSS94]: learned in time $\mathrm{poly}(n/\varepsilon, 2^k)$ in the special case that there is a number p < ½ such that every $\mu_i^j$ is either p or 1 − p.
• [FM99]: learned mixtures of 2 product distributions over $\{0,1\}^n$ in polynomial time (with a few minor technical deficiencies).
• [CGG98]: learned a generalization of 2 product distributions over $\{0,1\}^n$, with no deficiencies.
The latter two leave mixtures of 3 or more as an open problem: there is a qualitative difference between 2 and 3. [FM99] also leaves open learning mixtures of Gaussians and other $\mathbb{R}^n$ distributions.
Our results
• A poly(n/ε)-time algorithm learning a mixture of k product distributions over $\{0,1\}^n$ for any constant k.
• Evidence that getting a poly(n/ε) algorithm for k = ω(1) [even in the case where the μ’s are in {0, ½, 1}] will be very hard (if it is possible at all).
• Generalizations:
  • Let $C_1, \ldots, C_n$ be “nice” classes of distributions over $\mathbb{R}$ (…definable in terms of O(1) moments…). The algorithm learns mixtures of O(1) distributions in $C_1 \times \cdots \times C_n$.
  • Only pairwise independence of the coordinates is used…
Technical definitions
When is a hypothesis distribution Q “ε-close” to the target distribution P?
• $L_1$ distance? $\int |P(x) - Q(x)|$.
• KL divergence: $\mathrm{KL}(P \,\|\, Q) = \int P(x) \log \frac{P(x)}{Q(x)}$.
Getting a KL-close hypothesis is more stringent; fact: $L_1 \le O(\mathrm{KL}^{1/2})$. We learn under KL divergence, which leads to some technical advantages (and some technical difficulties).
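For product distributions over $\{0,1\}^n$, both notions are easy to compute when n is small. The Python sketch below evaluates the $L_1$ distance by brute force and computes KL as a sum of per-coordinate Bernoulli terms. It then checks Pinsker's inequality $L_1 \le \sqrt{2\,\mathrm{KL}}$ (KL in nats), a standard concrete form of the stated fact. The helper names and the example means P and Q are ours.

```python
import itertools
import math


def prob(means, x):
    """P(x) for the product distribution over {0,1}^n with the given means."""
    p = 1.0
    for mu, bit in zip(means, x):
        p *= mu if bit else 1.0 - mu
    return p


def l1_distance(p_means, q_means):
    """Sum over x in {0,1}^n of |P(x) - Q(x)|, by brute force (small n only)."""
    cube = itertools.product((0, 1), repeat=len(p_means))
    return sum(abs(prob(p_means, x) - prob(q_means, x)) for x in cube)


def kl_divergence(p_means, q_means):
    """KL(P || Q) in nats. For product distributions it splits into a sum of
    per-coordinate Bernoulli KLs. Assumes every mean of Q is strictly between
    0 and 1; otherwise KL can be infinite."""
    total = 0.0
    for p, q in zip(p_means, q_means):
        if p > 0:
            total += p * math.log(p / q)
        if p < 1:
            total += (1 - p) * math.log((1 - p) / (1 - q))
    return total


P = [0.2, 0.9, 0.5]
Q = [0.25, 0.85, 0.5]
# Pinsker's inequality, L1 <= sqrt(2 KL), is one concrete form of L1 <= O(KL^(1/2)).
print(l1_distance(P, Q), math.sqrt(2 * kl_divergence(P, Q)))
```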
Learning distributions summary
• Learning a class of distributions C.
• Let P be any distribution in the class.
• Given ε, δ > 0.
• Get samples and do $\mathrm{poly}(n/\varepsilon, \log(1/\delta))$ work.
• With probability at least 1 − δ, output a hypothesis Q which satisfies KL(P || Q) < ε.
Some intuition for k = 2
Idea: find two coordinates j and j' to “key off.”
• Suppose you notice that the bits in coordinates j and j' are very frequently different.
• Then probably most of the …0…1… examples come from one mixture component and most of the …1…0… examples come from the other.
• Use this separation to estimate all the other means.
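To see the keying-off idea on numbers, take the 60% / 40% example mixture over $\{0,1\}^4$ and key off coordinates 1 and 4 (our choice of j and j'). This exact-arithmetic Python sketch uses hypothetical helper names. It computes how often the two bits differ and, via Bayes' rule, which component each pattern most likely came from.

```python
from fractions import Fraction as F

# The 60% / 40% example mixture over {0,1}^4 from these slides.
WEIGHTS = [F(6, 10), F(4, 10)]
MEANS = [[F(8, 10), F(8, 10), F(6, 10), F(2, 10)],
         [F(2, 10), F(4, 10), F(3, 10), F(8, 10)]]


def pattern_prob(i, j, jp, bj, bjp):
    """Pr[x_j = bj and x_j' = bjp] under component i (coordinates are independent)."""
    pj = MEANS[i][j] if bj else 1 - MEANS[i][j]
    pjp = MEANS[i][jp] if bjp else 1 - MEANS[i][jp]
    return pj * pjp


def posterior(j, jp, bj, bjp):
    """Pr[component i | x_j = bj, x_j' = bjp] for each component i (Bayes' rule)."""
    joint = [w * pattern_prob(i, j, jp, bj, bjp) for i, w in enumerate(WEIGHTS)]
    total = sum(joint)
    return [p / total for p in joint]


# Key off coordinates 1 and 4 (indices 0 and 3).
p_differ = sum(WEIGHTS[i] * (pattern_prob(i, 0, 3, 1, 0) + pattern_prob(i, 0, 3, 0, 1))
               for i in range(2))
print(p_differ)               # 17/25: the two bits differ 68% of the time
print(posterior(0, 3, 1, 0))  # [Fraction(24, 25), Fraction(1, 25)]
print(posterior(0, 3, 0, 1))  # [Fraction(3, 35), Fraction(32, 35)]
```

A 10 pattern therefore comes from the first component 96% of the time, and a 01 pattern comes from the second about 91% of the time.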
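More details for the intuition
Suppose you somehow “know” the following three things, where p and q denote the mean vectors of the two mixture components:
• The mixture weights are 60% / 40%.
• There are j and j' such that the means satisfy $\begin{vmatrix} p_j & p_{j'} \\ q_j & q_{j'} \end{vmatrix} > \varepsilon$.
• The values $p_j, p_{j'}, q_j, q_{j'}$ themselves.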
More details for the intuition
Main algorithmic idea: for each coordinate m, estimate (to within $\varepsilon^2$) the correlations between j and m and between j' and m:
$\mathrm{corr}(j, m) = (0.6\, p_j)\, p_m + (0.4\, q_j)\, q_m$
$\mathrm{corr}(j', m) = (0.6\, p_{j'})\, p_m + (0.4\, q_{j'})\, q_m$
Solve this system of equations for $p_m, q_m$. Done! Since the determinant is > ε, any error in estimating the correlations does not blow up too much.
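The solve step is a 2×2 linear system in $p_m, q_m$. Below is a minimal Python sketch using Cramer's rule. The function name and argument layout are ours. It is checked on the 60% / 40% example mixture by keying off coordinates 1 and 4 and recovering coordinate 2.

```python
def solve_coordinate(pi, p_key, q_key, c_j, c_jp):
    """Recover (p_m, q_m) for a two-component mixture from corr(j, m), corr(j', m).

    pi = (pi_1, pi_2), p_key = (p_j, p_j'), q_key = (q_j, q_j'). Solves
        c_j  = (pi_1 p_j ) p_m + (pi_2 q_j ) q_m
        c_jp = (pi_1 p_j') p_m + (pi_2 q_j') q_m
    by Cramer's rule.
    """
    a, b = pi[0] * p_key[0], pi[1] * q_key[0]
    c, d = pi[0] * p_key[1], pi[1] * q_key[1]
    det = a * d - b * c  # equals pi_1 * pi_2 * (p_j q_j' - q_j p_j')
    if abs(det) < 1e-12:
        raise ValueError("the key coordinates j, j' give a (nearly) singular system")
    p_m = (c_j * d - b * c_jp) / det
    q_m = (a * c_jp - c * c_j) / det
    return p_m, q_m


def corr(j, m, pi, p, q):
    """corr(j, m) = pi_1 p_j p_m + pi_2 q_j q_m for a two-component mixture."""
    return pi[0] * p[j] * p[m] + pi[1] * q[j] * q[m]


# 60% / 40% example; key off coordinates 1 and 4 (indices 0, 3), recover coordinate 2.
pi, p, q = (0.6, 0.4), [0.8, 0.8, 0.6, 0.2], [0.2, 0.4, 0.3, 0.8]
print(solve_coordinate(pi, (p[0], p[3]), (q[0], q[3]),
                       corr(0, 1, pi, p, q), corr(3, 1, pi, p, q)))
# approximately (0.8, 0.4)
```

Suppose each correlation is off by at most δ. The inverse matrix has entries of size at most about 1/|det|, so the solution moves by an amount of order δ/|det|. This is why a determinant bounded away from 0 keeps the error under control.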
Two questions
1. This assumes that there is some 2×2 submatrix which is far from singular. In general, there is no reason to believe this is the case.
• But if not, then one set of means is very nearly a multiple of the other set, and the problem becomes very easy.
2. How did we know $\pi_1, \pi_2$? How did we know which j and j' were good? How did we know the 4 means $p_j, p_{j'}, q_j, q_{j'}$?
Guessing
Just guess, i.e., “try” “all” possibilities.
• Guess whether the 2 × n matrix is essentially rank 1 or not.
• Guess $\pi_1, \pi_2$ to within $\varepsilon^2$. (Time: $1/\varepsilon^4$.)
• Guess the correct j, j'. (Time: $n^2$.)
• Guess $p_j, p_{j'}, q_j, q_{j'}$ to within $\varepsilon^2$. (Time: $1/\varepsilon^8$.)
Solve the system of equations in every case. Time: poly(n/ε).
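The guessing step amounts to enumerating a finite grid of candidates. This Python sketch rests on our own assumptions. Guesses are taken from an evenly spaced grid whose spacing is about `step` (ε² on the slide). The helper names are hypothetical. The binary “essentially rank 1” guess is left out.

```python
import itertools


def grid(step):
    """Evenly spaced points 0, 1/m, ..., 1 with m = round(1/step), so spacing is about step."""
    m = max(1, round(1 / step))
    return [i / m for i in range(m + 1)]


def candidate_guesses(n, step):
    """Yield every ((pi_1, pi_2), (j, j'), (p_j, p_j', q_j, q_j')) guess for k = 2.

    There are about (1/step)^2 * n^2 * (1/step)^4 guesses, i.e. poly(n/eps) for step = eps^2.
    """
    g = grid(step)
    for pis in itertools.product(g, repeat=2):           # weights pi_1, pi_2
        for key in itertools.permutations(range(n), 2):  # ordered pairs j != j'
            for means in itertools.product(g, repeat=4): # p_j, p_j', q_j, q_j'
                yield pis, key, means


print(sum(1 for _ in candidate_guesses(n=3, step=0.5)))  # 3**2 * 6 * 3**4 = 4374
```

Each yielded tuple would then be fed to the 2×2 solve for every other coordinate, producing one candidate hypothesis per guess.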
Checking guesses
• After this we get a whole bunch of candidate hypotheses.
• When we get lucky and make all the right guesses, the resulting candidate hypothesis will be a good one, say, ε-close in KL to the truth.
Can we pick the (or a) candidate hypothesis which is KL-close to the truth? I.e., can we guess and check? Yes: use a Maximum Likelihood test…
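Checking with ML
Suppose Q is a candidate hypothesis for P. Estimate its “log likelihood” on a set S of samples:
$\log \prod_{x \in S} Q(x) = \sum_{x \in S} \log Q(x) \approx |S| \cdot \mathbf{E}_{x \sim P}[\log Q(x)] = |S| \int P(x) \log Q(x) = |S| \left[ \int P \log P - \mathrm{KL}(P \,\|\, Q) \right].$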
Checking with ML, cont’d
• By Chernoff bounds, if we take enough samples, all candidate hypotheses Q will have their “estimated log-likelihoods” close to their expectations.
• Any KL-close Q will look very good in the ML test.
• Anything which looks good in the ML test is KL-close.
• Thus, assuming there is an ε-close candidate hypothesis among the guesses, we find an O(ε)-close candidate hypothesis.
• I.e., we can guess and check.
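Below is a minimal Python sketch of the ML check. The helper names are ours and logs are natural. It scores each candidate mixture by its log-likelihood on a sample and keeps the best one. The demo draws 20,000 samples from the 60% / 40% example mixture and compares the truth against two wrong candidates: one with the weights swapped, and one uniform product distribution.

```python
import math
import random


def mixture_prob(weights, means, x):
    """Q(x) for a mixture of product distributions over {0,1}^n."""
    total = 0.0
    for w, mu in zip(weights, means):
        p = w
        for m, bit in zip(mu, x):
            p *= m if bit else 1.0 - m
        total += p
    return total


def log_likelihood(candidate, samples):
    """Sum over x in S of log Q(x); divided by |S| it estimates E_P[log Q] = -H(P) - KL(P || Q)."""
    weights, means = candidate
    return sum(math.log(mixture_prob(weights, means, x)) for x in samples)


def pick_best(candidates, samples):
    """Maximum-likelihood test: return the candidate with the largest log-likelihood."""
    return max(candidates, key=lambda c: log_likelihood(c, samples))


# Draw samples from the 60% / 40% example mixture, then test three candidates.
rng = random.Random(0)
truth = ([0.6, 0.4], [[0.8, 0.8, 0.6, 0.2], [0.2, 0.4, 0.3, 0.8]])
samples = []
for _ in range(20000):
    comp = 0 if rng.random() < truth[0][0] else 1
    samples.append([1 if rng.random() < m else 0 for m in truth[1][comp]])
swapped = ([0.4, 0.6], truth[1])           # right means, wrong weights
uniform = ([1.0], [[0.5, 0.5, 0.5, 0.5]])  # a single product distribution
print(pick_best([swapped, uniform, truth], samples) is truth)
```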
Overview of the algorithm
We now give the precise algorithm for learning a mixture of k product distributions, along with intuition for why it works. Intuitively:
• Estimate all the pairwise correlations of bits.
• Guess a number of parameters of the mixture distribution.
• Use the guesses and correlation estimates to solve for the remaining parameters.
• Show that whenever the guesses are close, the resulting parameter estimates give a candidate hypothesis that is close in KL.
• Check the candidates with the ML test and pick the best one.
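The algorithm
1. Estimate all pairwise correlations corr(j, j'), j ≠ j', to within $(\varepsilon/n)^k$. (Time: $(n/\varepsilon)^k$.)
Note: $\mathrm{corr}(j, j') = \sum_{i=1}^{k} \pi_i \mu_i^j \mu_i^{j'} = \langle \tilde\mu^j, \tilde\mu^{j'} \rangle$, where $\tilde\mu^j = \left( \pi_i^{1/2} \mu_i^j \right)_{i=1,\ldots,k}$.
2. Guess all $\pi_i$ to within $(\varepsilon/n)^k$. (Time: $(n/\varepsilon)^{k^2}$.)
Now it suffices to estimate all the vectors $\tilde\mu^j$, j = 1, …, n.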
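The note in step 1 says each correlation is a dot product of rescaled mean vectors. The Python sketch below uses hypothetical helper names. It builds $\tilde\mu^j$ for the 60% / 40% example mixture and checks the identity. It also shows the empirical estimator, which is the only statistic the algorithm takes from the samples.

```python
import math


def scaled_columns(weights, means):
    """tilde-mu^j = (sqrt(pi_i) * mu_i^j) for i = 1..k, one k-vector per coordinate j."""
    k, n = len(weights), len(means[0])
    return [[math.sqrt(weights[i]) * means[i][j] for i in range(k)] for j in range(n)]


def corr_exact(weights, means, j, jp):
    """corr(j, j') = sum_i pi_i mu_i^j mu_i^j' = E[x_j x_j'] for j != j'."""
    return sum(w * mu[j] * mu[jp] for w, mu in zip(weights, means))


def corr_estimate(samples, j, jp):
    """Empirical E[x_j x_j'] over a list of 0/1 samples."""
    return sum(x[j] * x[jp] for x in samples) / len(samples)


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


weights = [0.6, 0.4]
means = [[0.8, 0.8, 0.6, 0.2], [0.2, 0.4, 0.3, 0.8]]
cols = scaled_columns(weights, means)
# The identity corr(j, j') = <tilde-mu^j, tilde-mu^j'> holds for every pair j != j'.
print(corr_exact(weights, means, 0, 3), dot(cols[0], cols[3]))  # both about 0.16
```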
Guessing matrices from most of their Gram matrices
Let A be the k × n matrix whose columns are the vectors $\tilde\mu^j$:
$A = \begin{pmatrix} \tilde\mu^1 & \tilde\mu^2 & \cdots & \tilde\mu^n \end{pmatrix}$
After estimating all correlations, we know all dot products of distinct columns of A to high accuracy.
Goal: determine all entries of A, making only O(1) guesses.
Two remarks
• This is the final problem, where all the main action and technical challenge lie. Note that all we ever do with the samples is estimate pairwise correlations.
• If we knew the dot products of the columns of A with themselves, we’d have the whole matrix $A^{\mathsf T} A$. That would be great; we could just factor it and recover A (up to an orthogonal transformation). Unfortunately, there doesn’t seem to be any way to get at these quantities, $\sum_{i=1}^{k} \pi_i (\mu_i^j)^2$.
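One way to see why the diagonal entries are out of reach, using the 60% / 40% example mixture (our illustration): for 0/1 data $x_j^2 = x_j$. The sample analogue of a column's self dot product therefore estimates $\mathbf{E}[x_j] = \sum_i \pi_i \mu_i^j$, not $\sum_i \pi_i (\mu_i^j)^2$.

```python
from fractions import Fraction as F

# 60% / 40% example mixture over {0,1}^4, in exact arithmetic.
WEIGHTS = [F(6, 10), F(4, 10)]
MEANS = [[F(8, 10), F(8, 10), F(6, 10), F(2, 10)],
         [F(2, 10), F(4, 10), F(3, 10), F(8, 10)]]


def gram_diagonal(j):
    """The missing entry (A^T A)_{jj} = sum_i pi_i (mu_i^j)^2."""
    return sum(w * mu[j] ** 2 for w, mu in zip(WEIGHTS, MEANS))


def naive_self_correlation(j):
    """What samples give for E[x_j * x_j]: since x_j is a bit, x_j^2 = x_j,
    so this is just E[x_j] = sum_i pi_i mu_i^j."""
    return sum(w * mu[j] for w, mu in zip(WEIGHTS, MEANS))


print(gram_diagonal(0), naive_self_correlation(0))  # 2/5 14/25
```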
Keying off a nonsingular submatrix
Idea: find a nonsingular k × k submatrix to “key off.” As before, the “usual” case is that A has full rank.
• Then there is a k × k nonsingular submatrix $A_J$.
• Guess this submatrix (time: $n^k$) and all its entries to within $(\varepsilon/n)^k$ (time: $(n/\varepsilon)^{k^3}$, the final running time).
• Now use this submatrix and the correlation estimates to find all other entries of A: for all m ∉ J, $A_J^{\mathsf T} A_m = \left(\mathrm{corr}(m, j)\right)_{j \in J}$.
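Given the key columns, each remaining column is the solution of a k × k linear system. In the algorithm the key columns $A_J$ are guessed. In this Python sketch they are taken from a known A so that the recovery can be checked. The helper names and the 3-component example are ours, and plain Gaussian elimination stands in for any linear solver.

```python
import math


def solve(matrix, rhs):
    """Solve matrix @ x = rhs by Gaussian elimination with partial pivoting."""
    size = len(rhs)
    aug = [list(row) + [b] for row, b in zip(matrix, rhs)]
    for col in range(size):
        pivot = max(range(col, size), key=lambda r: abs(aug[r][col]))
        if abs(aug[pivot][col]) < 1e-12:
            raise ValueError("key submatrix A_J is (nearly) singular")
        aug[col], aug[pivot] = aug[pivot], aug[col]
        for r in range(col + 1, size):
            factor = aug[r][col] / aug[col][col]
            for c in range(col, size + 1):
                aug[r][c] -= factor * aug[col][c]
    x = [0.0] * size
    for r in range(size - 1, -1, -1):
        tail = sum(aug[r][c] * x[c] for c in range(r + 1, size))
        x[r] = (aug[r][size] - tail) / aug[r][r]
    return x


def recover_column(key_columns, corrs):
    """Solve A_J^T A_m = (corr(m, j))_{j in J} for the unknown column A_m.

    key_columns lists the k columns A_j, j in J; they are the rows of A_J^T.
    """
    return solve(key_columns, corrs)


def to_means(column, weights):
    """Undo the sqrt(pi_i) scaling: mu_i^m = A_m[i] / sqrt(pi_i)."""
    return [a / math.sqrt(w) for a, w in zip(column, weights)]


# Hypothetical 3-component mixture over {0,1}^5; J = the first three coordinates.
weights = [0.5, 0.3, 0.2]
means = [[0.9, 0.1, 0.5, 0.7, 0.2],
         [0.2, 0.8, 0.4, 0.1, 0.6],
         [0.3, 0.3, 0.9, 0.5, 0.5]]
cols = [[math.sqrt(weights[i]) * means[i][j] for i in range(3)] for j in range(5)]
J = [0, 1, 2]
for m in (3, 4):
    corrs = [sum(a * b for a, b in zip(cols[m], cols[j])) for j in J]
    # Prints values close to column m of `means`.
    print(m, to_means(recover_column([cols[j] for j in J], corrs), weights))
```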
Non-full-rank case
But what if A is not full rank? (Or, in the actual analysis, if A is extremely close to being rank deficient.) A genuine problem. Then A has some perpendicular space (the orthogonal complement of its column space in $\mathbb{R}^k$) of dimension 0 < d ≤ k, spanned by some orthonormal vectors $u_1, \ldots, u_d$.
• Guess d and the vectors $u_1, \ldots, u_d$.
Now adjoin these columns to A, getting a full-rank matrix
$A' = \begin{pmatrix} A & u_1 & u_2 & \cdots & u_d \end{pmatrix}$.
Non-full-rank case, cont’d
Now A' has full rank and we can do the full-rank case! Why do we still know all pairwise dot products of the columns of A'?
• Dot products of the u’s with the columns of A are 0!
• Dot products of the u’s with each other are known too: 0 for distinct u’s, 1 for a u with itself. (Don’t need this.)
4. Guess a k × k submatrix of A' and all its entries. Use these to solve for all other entries.
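What a correct guess of $u_1, \ldots, u_d$ looks like can be checked numerically. The algorithm has to guess these vectors because it never sees A. This Python sketch instead computes them from a known toy matrix using Gram-Schmidt (all names and numbers are ours). It then verifies that they are orthogonal to every column of A and that $A'$ has full rank.

```python
import math


def orthonormal_complement(columns, k, tol=1e-9):
    """Orthonormal u_1, ..., u_d spanning the part of R^k perpendicular to every column.

    Runs modified Gram-Schmidt over the columns of A followed by the standard
    basis e_1, ..., e_k; vectors that survive after the columns are exhausted
    are orthogonal to A's column space and together span its complement.
    """
    basis, complement = [], []
    identity = [[1.0 if t == i else 0.0 for t in range(k)] for i in range(k)]
    for idx, v in enumerate([list(c) for c in columns] + identity):
        w = list(v)
        for b in basis:
            proj = sum(x * y for x, y in zip(w, b))
            w = [x - proj * y for x, y in zip(w, b)]
        norm = math.sqrt(sum(x * x for x in w))
        if norm > tol:
            u = [x / norm for x in w]
            basis.append(u)
            if idx >= len(columns):
                complement.append(u)
    return complement


# A toy rank-2 matrix with k = 3 rows, given by its 4 columns (column 4 = column 1 + column 3).
A_cols = [[1.0, 2.0, 3.0], [2.0, 4.0, 6.0], [0.0, 1.0, 1.0], [1.0, 3.0, 4.0]]
us = orthonormal_complement(A_cols, k=3)
A_prime_cols = A_cols + us  # A' = [A | u_1 ... u_d]
# d = 1, and the largest |<u, column>| is zero up to rounding error.
print(len(us), max(abs(sum(a * b for a, b in zip(u, c))) for u in us for c in A_cols))
```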
The actual analysis
• The actual analysis of this algorithm is quite delicate.
• There are some linear algebra and numerical analysis ideas.
• The main issue: the degree to which A is “essentially” of rank k − d is similar to the degree to which all the guessed vectors u really do have dot product 0 with A’s original columns.
• The key is to find a large multiplicative “gap” between A’s singular values and treat its location as the essential rank of A.
• This is where the necessary accuracy $(\varepsilon/n)^k$ comes in.
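The slide does not spell out the rule for locating the gap, so the following Python sketch is only one plausible reading. It sorts the singular values, takes the largest ratio between consecutive ones, and calls its position the essential rank. If no ratio reaches a threshold, it falls back to full rank. The threshold `min_gap` is a hypothetical parameter. In the actual analysis the required gap is tied to the accuracy $(\varepsilon/n)^k$, which this sketch does not model. The singular values are assumed to come from an SVD computed elsewhere.

```python
def essential_rank(singular_values, min_gap):
    """Locate the largest multiplicative gap s_r / s_(r+1) among the sorted
    singular values and return r; return full rank if no ratio reaches min_gap."""
    s = sorted(singular_values, reverse=True)
    best_r, best_ratio = len(s), 0.0
    for r in range(1, len(s)):
        hi, lo = s[r - 1], s[r]
        if hi == 0:
            break  # all remaining values are 0, so there are no further gaps
        ratio = float("inf") if lo == 0 else hi / lo
        if ratio >= min_gap and ratio > best_ratio:
            best_r, best_ratio = r, ratio
    return best_r


print(essential_rank([5.0, 4.0, 3.0, 1e-6, 1e-7], min_gap=100))  # 3
print(essential_rank([5.0, 4.0, 3.0], min_gap=100))              # 3 (no gap: full rank)
```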
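Can we learn a mixture of ω(1) product distributions?
Claim: Let T be a decision tree on $\{0,1\}^n$ with k leaves. Then the uniform distribution over the inputs which make T output 1 is a mixture of at most k product distributions. Indeed, each leaf labeled 1 contributes a product distribution that fixes the bits queried on its path and leaves all other bits uniform, so all these product distributions have means 0, ½, or 1.
Example (figure): a tree with root $x_1$ whose accepting paths are $x_1 = 0, x_2 = 0$ and $x_1 = 1, x_3 = 0, x_2 = 1$; the corresponding mixture is
2/3: [0, 0, ½, ½, ½, …]
1/3: [1, 1, 0, ½, ½, …]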
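The claim is constructive: each 1-leaf of T becomes one mixture component. The Python sketch below uses exact fractions and reproduces the example's 2/3 and 1/3 mixture. The function name and the path-dictionary input format are our own assumptions.

```python
from fractions import Fraction


def dt_to_mixture(accepting_paths, n):
    """Mixture of product distributions equal to the uniform distribution on T^{-1}(1).

    accepting_paths holds one dict {variable index: required bit} per leaf labeled 1.
    Each such leaf fixes the bits on its path and leaves the other bits uniform,
    so it is a product distribution with means in {0, 1/2, 1}; its weight is
    proportional to 2^(-depth), the fraction of {0,1}^n that reaches it.
    """
    half = Fraction(1, 2)
    raw = [half ** len(path) for path in accepting_paths]
    total = sum(raw)
    return [(w / total, [Fraction(path[v]) if v in path else half for v in range(n)])
            for w, path in zip(raw, accepting_paths)]


# The slide's example (0-indexed variables): accepting paths
# x1 = 0, x2 = 0   and   x1 = 1, x3 = 0, x2 = 1.
for weight, means in dt_to_mixture([{0: 0, 1: 0}, {0: 1, 2: 0, 1: 1}], n=5):
    print(weight, [str(m) for m in means])
# 2/3 ['0', '0', '1/2', '1/2', '1/2']
# 1/3 ['1', '1', '0', '1/2', '1/2']
```

The leaves of a decision tree partition $\{0,1\}^n$. The uniform distribution on accepting inputs is therefore exactly this mixture, with one component per 1-leaf.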
Learning DTs under uniform
Cor: If one can learn a mixture of k product distributions over $\{0,1\}^n$ (even 0/½/1 ones) in poly(n) time, one can PAC-learn k-leaf decision trees under uniform in poly(n) time.
PAC-learning ω(1)-size DTs under uniform is an extremely notorious problem:
• easier than learning ω(1)-term DNF under uniform, a 20-year-old problem;
• essentially equivalent to learning ω(1)-juntas under uniform; worth 1000 dollars from A. Blum to solve.
Generalizations
We gave an algorithm that guessed the means of an unknown mixture of k product distributions. What assumptions did we really need?
• pairwise independence of the coordinates
• means falling in a bounded range [−poly(n), poly(n)]
• 1-d distributions (and pairwise products of the same) are “samplable”: we can find the true correlations by estimation
• the means defined the 1-d distributions
The last of these is rarely true. But…
Higher moments
• Suppose we ran the algorithm and got N guesses for the means of all the distributions.
• Now run the algorithm again, but whenever you get the point $(x_1, \ldots, x_n)$, treat it as $(x_1^2, \ldots, x_n^2)$.
• You will get N guesses for the second moments!
• Take the cross product of the two lists to get $N^2$ guesses for the (mean, second moment) pairs.
• Guess and check, as always.
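The bookkeeping for the higher-moments trick is small. This Python sketch does not implement the learner itself. The helpers `squared`, `pair_guesses` and `variances` are hypothetical. The first shows the transformed input and the second builds the $N^2$ cross product. The third shows, as one example of why second moments help, how a (mean, second moment) pair pins down a Gaussian's variance.

```python
import itertools


def squared(sample):
    """The point (x_1^2, ..., x_n^2) that the second run of the learner sees."""
    return [x * x for x in sample]


def pair_guesses(mean_guesses, second_moment_guesses):
    """All N^2 combinations of a mean-vector guess with a second-moment guess."""
    return list(itertools.product(mean_guesses, second_moment_guesses))


def variances(means, second_moments):
    """Per-coordinate variance E[x^2] - E[x]^2 implied by one candidate pair."""
    return [s - m * m for m, s in zip(means, second_moments)]


mean_guesses = [[0.0, 1.0], [0.1, 0.9]]           # N = 2 guesses, n = 2
second_moment_guesses = [[1.0, 2.0], [1.2, 1.8]]  # N = 2 guesses
candidates = pair_guesses(mean_guesses, second_moment_guesses)
print(len(candidates))            # 4, i.e. N^2
print(variances(*candidates[0]))  # [1.0, 1.0]
```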
Generalizations
Let $C_1, \ldots, C_n$ be families of distributions on $\mathbb{R}$ which have the following “niceness properties”:
• means bounded in [−poly(n), poly(n)]
• sharp tail bounds / samplability
• defined by O(1) moments, with closeness in moments implying closeness in KL
• … more technical concerns…
We should be able to learn O(1)-mixtures from $C_1 \times \cdots \times C_n$ in the same time. We can definitely learn mixtures of axis-aligned Gaussians and mixtures of distributions on O(1)-sized sets.
Open questions
• Quantify some nice properties of families of distributions over $\mathbb{R}$ which this algorithm can learn.
• Simplify the algorithm:
  • Simpler analysis?
  • Faster? $n^{k^2}$? $n^k$? $n^{\log k}$???
• Specific fast results for k = 2, 3.
• Solve other distribution-learning problems.