Expectation Maximization Algorithm
This post explains the derivation of the EM algorithm. I would like to simplify it further, but I don’t want this algorithm to hold me up for too long. I am confident that I will see its commonalities with other methods later on.
Birds in NUAA campus
Assume that the wingspans of all the birds on the NUAA campus follow a single normal distribution $N(\mu, \sigma^2)$. We know that the real situation is not like that. Because Clark frequently sees three different types of birds on the NUAA campus, he believes that the wingspans of these three species follow normal distributions with different parameters: birds of type A follow $N(\mu_1, \sigma_1^2)$, birds of type B follow $N(\mu_2, \sigma_2^2)$, and birds of type C follow $N(\mu_3, \sigma_3^2)$.
Unfortunately, Clark is not a renowned ornithologist, and during his observations on campus, he only measured the wingspans of the birds without knowing their species. Fortunately, Clark has a good understanding of statistics, and after collecting his samples, he knows how to evaluate the wingspan distribution of the birds in the NUAA campus.
Expectation Maximization Algorithm
Clark clarified the problem and gave it serious thought. For the sample he drew (a bird from NUAA), he first needed to determine which category the bird belonged to, and then estimate the parameters of the normal distribution of wing spans for categories A, B, and C. These two tasks were interdependent: if we know what category the bird belongs to, we can use maximum likelihood to estimate the distribution of wing spans; conversely, knowing the parameters of the wing span distributions makes it easier to identify which category the bird belongs to (for example, if the mean wing span is largest for category C, then a larger wing span may indicate that the bird is of category C).
Clark’s mathematical intuition was remarkable. He immediately realized that by initially fixing the distribution of bird categories (0.3 for A, 0.6 for B, and 0.1 for C), he could estimate the wing span parameters, and then use these wing span parameters to refine the category distribution. He had a hunch that this iterative process would eventually converge to a stable solution. Clark thought of the EM (Expectation-Maximization) algorithm.
The EM (Expectation-Maximization) algorithm solves this problem with a heuristic iterative method. Since we cannot directly determine the parameters of the model distribution, we first guess the hidden data (the E-step of the EM algorithm), and then maximize the log-likelihood based on the observed data and the guessed hidden data to solve for our model parameters (the M-step of the EM algorithm). Because the hidden data were initially guessed, the model parameters obtained at this point are generally not yet the desired result. We therefore re-estimate the hidden data based on the current model parameters (the E-step) and again maximize the log-likelihood to solve for the model parameters (the M-step). This process is repeated until the model parameters change very little, which indicates that the algorithm has converged to suitable model parameters.
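To make the iteration concrete for Clark's birds, here is a minimal sketch of EM for a one-dimensional mixture of normal distributions, using only the Python standard library. The function names (`normal_pdf`, `log_likelihood`, `em_gmm_1d`, `make_sample`) and all the numbers in `make_sample` are made up for illustration; they are not measurements from the NUAA campus. The sketch assumes that no component collapses onto a single point, so it does not guard against zero variances.

```python
import math
import random


def normal_pdf(x, mu, var):
    """Density of N(mu, var) at x."""
    return math.exp(-((x - mu) ** 2) / (2.0 * var)) / math.sqrt(2.0 * math.pi * var)


def log_likelihood(xs, weights, mus, variances):
    """Observed-data log-likelihood: sum_i log sum_k w_k N(x_i | mu_k, var_k)."""
    return sum(
        math.log(sum(w * normal_pdf(x, m, v) for w, m, v in zip(weights, mus, variances)))
        for x in xs
    )


def em_gmm_1d(xs, weights, mus, variances, n_iter=50):
    """Run n_iter EM iterations; return final parameters and the log-likelihood trace."""
    k = len(weights)
    history = []
    for _ in range(n_iter):
        # E-step: posterior probability of each species for each bird.
        resp = []
        for x in xs:
            joint = [w * normal_pdf(x, m, v) for w, m, v in zip(weights, mus, variances)]
            total = sum(joint)
            resp.append([j / total for j in joint])
        # M-step: weighted maximum-likelihood updates of the parameters.
        nk = [sum(r[j] for r in resp) for j in range(k)]
        weights = [n / len(xs) for n in nk]
        mus = [sum(r[j] * x for r, x in zip(resp, xs)) / nk[j] for j in range(k)]
        variances = [
            sum(r[j] * (x - mus[j]) ** 2 for r, x in zip(resp, xs)) / nk[j]
            for j in range(k)
        ]
        history.append(log_likelihood(xs, weights, mus, variances))
    return weights, mus, variances, history


def make_sample(seed=0):
    """Synthetic wingspans (hypothetical numbers): three species, labels discarded."""
    rng = random.Random(seed)
    spec = [(20.0, 3.0, 90), (50.0, 3.0, 180), (90.0, 3.0, 30)]  # (mean, std, count)
    return [rng.gauss(mu, sd) for mu, sd, n in spec for _ in range(n)]
```

A possible call is `em_gmm_1d(make_sample(), [0.3, 0.6, 0.1], [10.0, 40.0, 100.0], [100.0, 100.0, 100.0])`, which starts from Clark's guessed category distribution and deliberately rough means and variances. The returned `history` list lets you watch the log-likelihood rise from one iteration to the next.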
One of the most intuitive ways to understand the idea of the EM algorithm is through the K-Means clustering algorithm. In K-Means clustering, the cluster assignment of each sample is the hidden data and the centroids are the parameters. We start from K initial centroids; assigning each sample to its nearest centroid corresponds to the E-step of the EM algorithm, and recomputing each centroid as the mean of the samples assigned to it corresponds to the M-step. The E-step and M-step are repeated until the centroids no longer change, which completes the K-Means clustering.
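The K-Means loop can be sketched in a few lines. This is a one-dimensional toy version with a hypothetical function name `kmeans_1d`. It follows one common reading of K-Means as a "hard" EM, in which the assignment step fills in the hidden data and the update step re-estimates the centroids.

```python
def kmeans_1d(xs, centroids, n_iter=100):
    """Toy 1-D K-Means: alternate assignment and centroid update until stable."""
    centroids = list(centroids)
    labels = []
    for _ in range(n_iter):
        # Assignment step: each sample goes to its nearest centroid.
        labels = [
            min(range(len(centroids)), key=lambda j: (x - centroids[j]) ** 2)
            for x in xs
        ]
        # Update step: each centroid becomes the mean of its assigned samples.
        new_centroids = []
        for j in range(len(centroids)):
            members = [x for x, lab in zip(xs, labels) if lab == j]
            # Keep the old centroid if no sample was assigned to it.
            new_centroids.append(sum(members) / len(members) if members else centroids[j])
        if new_centroids == centroids:
            break  # centroids no longer change
        centroids = new_centroids
    return centroids, labels
```

Unlike EM for a mixture model, each sample here belongs entirely to one cluster, instead of being shared between clusters according to a probability.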
The Derivation of the EM Algorithm
For $m$ independent samples $x=(x^{(1)},x^{(2)},\ldots,x^{(m)})$, with corresponding hidden data $z=(z^{(1)},z^{(2)},\ldots,z^{(m)})$, the pair $(x,z)$ constitutes the complete data. The parameter of the sample model is $\theta$, so the probability of an observation is $P(x^{(i)}|\theta)$, and the likelihood function for the complete data $(x^{(i)},z^{(i)})$ is $P(x^{(i)},z^{(i)}|\theta)$.
If there were no hidden variables $z$, we would only need to find an appropriate $\theta$ to maximize the likelihood function.
$\theta=\arg\max_{\theta}L(\theta)=\arg\max_{\theta}\sum_{i=1}^{m}\log P(x^{(i)}|\theta)$
By introducing the hidden variables, each observation probability becomes a marginal over $z^{(i)}$, and our goal becomes to find the $\theta$ that maximizes the log-likelihood with the hidden variables summed out.
$\theta=\arg\max_{\theta}L(\theta)=\arg\max_{\theta}\sum_{i=1}^{m}\log\sum_{z^{(i)}}P(x^{(i)},z^{(i)}|\theta)$
Naturally, we would think to take the partial derivatives with respect to the unknown $\theta$.
Theoretically this is feasible. However, since $P(x^{(i)}|\theta)$ is the marginal probability of $P(x^{(i)},z^{(i)}|\theta)$, there is a sum inside the $\log$, and the form after differentiation becomes very complex and difficult to solve. So, could we move the summation outside the $\log$? Let’s transform this expression as follows:
$\sum_{i=1}^{m}\log\sum_{z^{(i)}}P(x^{(i)},z^{(i)}|\theta)=\sum_{i=1}^{m}\log\sum_{z^{(i)}}Q_i(z^{(i)})\frac{P(x^{(i)},z^{(i)}|\theta)}{Q_i(z^{(i)})}\geq\sum_{i=1}^{m}\sum_{z^{(i)}}Q_i(z^{(i)})\log\frac{P(x^{(i)},z^{(i)}|\theta)}{Q_i(z^{(i)})}$
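Jensen’s inequality is applied here. Here $Q_i$ is any probability distribution over $z^{(i)}$, and because $f=\log$ is concave, $f\left(E_{z^{(i)}\sim Q_i}\left[\frac{P(x^{(i)},z^{(i)}|\theta)}{Q_i(z^{(i)})}\right]\right) \geq E_{z^{(i)}\sim Q_i}\left[f\left(\frac{P(x^{(i)},z^{(i)}|\theta)}{Q_i(z^{(i)})}\right)\right]$.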
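A quick numerical check can make the inequality less abstract. The sketch below uses a hypothetical helper `jensen_gap` and made-up numbers. It computes the difference between the log of an expectation and the expectation of the log for a discrete distribution `q`. The gap is never negative, and it vanishes when the values being averaged are all equal.

```python
import math


def jensen_gap(q, ratios):
    """log(E_q[r]) - E_q[log r] for a discrete distribution q over positive values r."""
    log_of_mean = math.log(sum(qi * r for qi, r in zip(q, ratios)))
    mean_of_log = sum(qi * math.log(r) for qi, r in zip(q, ratios))
    return log_of_mean - mean_of_log


# Unequal ratios: the gap is strictly positive.
gap_unequal = jensen_gap([0.3, 0.6, 0.1], [0.5, 2.0, 4.0])
# Constant ratios: the gap is zero up to floating-point error.
gap_constant = jensen_gap([0.3, 0.6, 0.1], [1.7, 1.7, 1.7])
```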
We have constructed a lower bound for the log-likelihood function, and next we need to find an appropriate $Q$ that makes this lower bound as tight as possible (this is the E-step). According to Jensen’s inequality, the condition for equality is that the random variable must be constant; hence we have: $\frac{P(x^{(i)},z^{(i)}|\theta)}{Q_i(z^{(i)})}=c$
Given that $\sum_{z}Q_i(z^{(i)})=1$, it follows that $\sum_{z}P(x^{(i)},z^{(i)}|\theta)=c\sum_{z}Q_i(z^{(i)})=c$.
$Q_i(z^{(i)})=\frac{P(x^{(i)},z^{(i)}|\theta)}{c}=\frac{P(x^{(i)},z^{(i)}|\theta)}{\sum_{z}P(x^{(i)},z^{(i)}|\theta)}=\frac{P(x^{(i)},z^{(i)}|\theta)}{P(x^{(i)}|\theta)}=P(z^{(i)}|x^{(i)},\theta)$
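As a sanity check on this result, the following sketch evaluates the lower bound for a single wingspan under a three-component normal mixture. The parameter values and the helper names `posterior` and `lower_bound` are assumptions chosen for illustration. With $Q_i$ set to the posterior, the bound should coincide with $\log P(x^{(i)}|\theta)$; with any other distribution, such as a uniform one, it should fall below it.

```python
import math


def normal_pdf(x, mu, var):
    """Density of N(mu, var) at x."""
    return math.exp(-((x - mu) ** 2) / (2.0 * var)) / math.sqrt(2.0 * math.pi * var)


def posterior(x, weights, mus, variances):
    """Return (P(z | x, theta) for each component, P(x | theta))."""
    joint = [w * normal_pdf(x, m, v) for w, m, v in zip(weights, mus, variances)]
    evidence = sum(joint)
    return [j / evidence for j in joint], evidence


def lower_bound(x, q, weights, mus, variances):
    """sum_z Q(z) * log( P(x, z | theta) / Q(z) ) for one sample x."""
    total = 0.0
    for qz, w, m, v in zip(q, weights, mus, variances):
        if qz > 0.0:  # terms with Q(z) = 0 contribute nothing
            total += qz * math.log(w * normal_pdf(x, m, v) / qz)
    return total


# Hypothetical parameters theta and one observed wingspan.
weights, mus, variances = [0.3, 0.6, 0.1], [20.0, 50.0, 90.0], [9.0, 16.0, 25.0]
x = 45.0
q_post, evidence = posterior(x, weights, mus, variances)
tight = lower_bound(x, q_post, weights, mus, variances)
loose = lower_bound(x, [1 / 3, 1 / 3, 1 / 3], weights, mus, variances)
```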
From the above, it can be seen that $Q(z)$ is the distribution of the latent variables given the observed samples and the model parameters. We have a lower bound on the log-likelihood that includes hidden data, and if we can maximize this lower bound, we are effectively trying to maximize our log-likelihood. That is, we need to maximize the following expression.
$\sum_{i=1}^{m}\sum_{z^{(i)}}Q_i(z^{(i)})\log\frac{P(x^{(i)},z^{(i)}|\theta)}{Q_i(z^{(i)})}$
Up to this point, we have worked out how to choose the distribution $Q(z)$ given fixed parameters $\theta$, thereby establishing a lower bound for the log-likelihood $L(\theta)$. This is the E-step. For the subsequent M-step, with $Q(z)$ held fixed, we adjust $\theta$ to maximize the lower bound of $L(\theta)$. Since $Q(z)$ is fixed, the term $-\sum_{i=1}^{m}\sum_{z^{(i)}}Q_i(z^{(i)})\log Q_i(z^{(i)})$ does not depend on $\theta$. Dropping this constant from the expression above, the quantity that we need to maximize is:
$\sum_{i=1}^{m}\sum_{z^{(i)}}Q_i(z^{(i)})\log P(x^{(i)},z^{(i)}|\theta)$
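For the bird example this maximization has a closed form. The notation here is an assumption that does not appear in the text above: $w_{ik}=Q_i(z^{(i)}=k)$ is the probability from the E-step that bird $i$ belongs to species $k$, and $\pi_k$ is the proportion of species $k$, so that $P(x^{(i)},z^{(i)}=k|\theta)=\pi_k N(x^{(i)}|\mu_k,\sigma_k^2)$. Setting the derivatives of the expression above to zero, with the constraint $\sum_k \pi_k=1$, gives the standard updates:
$\pi_k=\frac{1}{m}\sum_{i=1}^{m}w_{ik}$
$\mu_k=\frac{\sum_{i=1}^{m}w_{ik}x^{(i)}}{\sum_{i=1}^{m}w_{ik}}$
$\sigma_k^2=\frac{\sum_{i=1}^{m}w_{ik}(x^{(i)}-\mu_k)^2}{\sum_{i=1}^{m}w_{ik}}$
These are the usual maximum-likelihood estimates, except that each bird contributes to every species in proportion to $w_{ik}$.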
Thinking About the Convergence of the EM Algorithm
To prove that the EM (Expectation-Maximization) algorithm converges, we need to show that the value of our log-likelihood function (which has an upper bound) increases monotonically during the iterations. That is:
$\sum_{i=1}^{m}\log P(x^{(i)}|\theta^{new})\geq \sum_{i=1}^{m}\log P(x^{(i)}|\theta^{old})$
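A short sketch of why this holds, using a symbol $J$ introduced here only for convenience. Write the lower bound as
$J(Q,\theta)=\sum_{i=1}^{m}\sum_{z^{(i)}}Q_i(z^{(i)})\log\frac{P(x^{(i)},z^{(i)}|\theta)}{Q_i(z^{(i)})}$
and let $Q^{old}_i(z^{(i)})=P(z^{(i)}|x^{(i)},\theta^{old})$ be the distribution chosen in the E-step. Then
$\sum_{i=1}^{m}\log P(x^{(i)}|\theta^{new})\geq J(Q^{old},\theta^{new})\geq J(Q^{old},\theta^{old})=\sum_{i=1}^{m}\log P(x^{(i)}|\theta^{old})$
The first inequality is Jensen’s inequality, which holds for any $Q$ and any $\theta$. The second holds because the M-step picks $\theta^{new}$ to maximize $J(Q^{old},\theta)$. The final equality holds because the E-step choice of $Q^{old}$ makes the bound tight at $\theta^{old}$.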
Note:
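- If the log-likelihood is not concave in $\theta$, there is no guarantee of converging to the global maximum; EM may converge to a local maximum instead.
- The choice of initial values is therefore very important, since different starting points can lead to different solutions.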