From Information Theory to Variational Inference


Variational Inference $(\text{VI})$ is a family of methods for approximating hard-to-compute probability densities. The main idea behind VI is that a target distribution $p$ of some dataset can be estimated by introducing an approximate distribution $q$ and then iteratively minimizing the Kullback-Leibler divergence $\text{KL}(q||p)$ between $q$ and $p$. Many reinforcement learning algorithms, e.g., variational inference for policy search, aim to optimize the policy by minimizing the KL-divergence between a policy distribution and an improper reward-weighted distribution. This post discusses the following topics, which are basic but important concepts for understanding VI.

  • Information Theory
    • Information
    • Entropy
    • Kullback-Leibler divergence
  • Statistics
    • Jensen’s inequality
    • Evidence lower bound $(\text{ELBO})$
  • Graphical models
    • Bayesian Networks

and then talks about

  • Variational Inference

1. Information Theory


One of the core concepts in information theory is “information”. The amount of information contained in an event $x$ is formally $(\text{or mathematically})$ defined as

\[I(x) = \log{\frac{1}{p(x)}}\]

where $p(x)$ is the probability that event $x$ occurs. Informally, the more likely an event is $(\text{high probability})$, the less one learns from observing it $(\text{less information})$.

  • For example, the probability of a fair die showing a particular number, e.g., 3, is $1/6$. Thus, the information of that die roll is $I(x) = \log_{2}{\frac{1}{1/6}} = \log_{2}{6} \approx 2.58$ bits. On the other hand, the probability of a fair coin landing heads (or tails) is $1/2$, and hence the information of a coin toss is $I(x) = \log_{2}{\frac{1}{1/2}}=1$ bit.
  • In summary:

    | Event     | Probability | Information                                     |
    |-----------|-------------|-------------------------------------------------|
    | Coin toss | $1/2$       | $I(x)=\log_{2}{\frac{1}{1/2}}=1$ bit            |
    | Die roll  | $1/6$       | $I(x)=\log_{2}{\frac{1}{1/6}} \approx 2.58$ bits |
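
The two information values above can be reproduced with a few lines of Python. This is only a minimal sketch: the helper name `information` and its `base` argument are illustrative choices, not something defined elsewhere in this post.

```python
import math


def information(p, base=2):
    """Self-information I(x) = log_base(1 / p) of an event with probability p."""
    if not 0.0 < p <= 1.0:
        raise ValueError("p must be in (0, 1]")
    return math.log(1.0 / p, base)


coin_bits = information(1 / 2)  # a fair coin landing heads
die_bits = information(1 / 6)   # a fair die showing a 3

print(f"coin toss: {coin_bits:.2f} bits")
print(f"die roll:  {die_bits:.2f} bits")
```

Changing `base` to `math.e` gives the same quantities in nats instead of bits.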


“Entropy” measures the average “information” of the source data. Shannon defined the entropy $H$ of a discrete random variable $X$ taking values in $\{x_1, x_2, x_3, \cdots, x_n\}$ with probability mass function $p(X)$ explicitly as

\[\begin{aligned} H(X) &= E[I(X)] = E\left[\log{\frac{1}{p(X)}}\right]\\ &= - \sum_{i=1}^{n} p(x_i)\log_{b}{p(x_i)} \end{aligned}\]

where $b$ is the base of the logarithm, e.g., $b=2$, $b=10$, or $b=e$.

  • A discrete variable $X \in \{10, 20, 30, 40\}$ with equal probabilities $p(x_{i})=\frac{1}{4}$
\[\begin{aligned} H(X) &= - \sum_{i=1}^{n} p(x_i)\log_{2}{p(x_i)}\\ &= - \sum_{i=1}^{4} \frac{1}{4} \log_{2}{\frac{1}{4}}\\ &= 2~\text{bits} \end{aligned}\]
  • A continuous variable $X \geq 0$ whose probability density function is that of the exponential distribution, $p(x)=\lambda e^{-\lambda x}$ (the result is in nats because the natural logarithm is used)
\[\begin{aligned} H(X) &= - \int_{0}^{\infty} p(x) \log_{e}{p(x)} dx\\ &= -\int_{0}^{\infty} \lambda e^{-\lambda x} \log_{e} {\lambda e^{-\lambda x}} dx\\ &= -\int_{0}^{\infty} \lambda e^{-\lambda x} [\log_{e}\lambda - \lambda x] dx\\ &= -\int_{0}^{\infty} \lambda e^{-\lambda x}\log_{e}\lambda dx + \int_{0}^{\infty}\lambda e^{-\lambda x}\lambda x dx\\ &= -\log_{e}\lambda \int_{0}^{\infty} \lambda e^{-\lambda x} dx + \lambda \int_{0}^{\infty} \lambda x e^{-\lambda x} dx \\ &= -\log_{e}\lambda \int_{0}^{\infty} p(x) dx + \lambda \int_{0}^{\infty}x p(x)dx\\ &= -\log_{e}\lambda \cdot 1 + \lambda \cdot \frac{1}{\lambda} ~\rightarrow\text{mean of an exponential random variable is } \frac{1}{\lambda}\\ &= -\log_{e}\lambda + 1 \end{aligned}\]
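
Both entropy results can be checked numerically. The sketch below is an illustration under stated assumptions: the function names, the rate $\lambda = 2$, and the integration settings (`upper`, `steps`) are arbitrary choices for the example, and the integral is approximated with a simple midpoint rule rather than computed exactly.

```python
import math


def entropy(probs, base=2):
    """Shannon entropy H = -sum_i p_i * log_base(p_i); zero-probability terms are skipped."""
    return -sum(p * math.log(p, base) for p in probs if p > 0)


def exponential_entropy_numeric(lam, upper=40.0, steps=100_000):
    """Midpoint-rule approximation of -integral p(x) ln p(x) dx with p(x) = lam * exp(-lam * x)."""
    dx = upper / steps
    total = 0.0
    for k in range(steps):
        x = (k + 0.5) * dx
        p = lam * math.exp(-lam * x)
        total -= p * math.log(p) * dx
    return total


# Discrete example: four equally likely values.
h_uniform = entropy([0.25] * 4)

# Continuous example: compare the numerical integral with the closed form 1 - ln(lambda).
lam = 2.0
h_numeric = exponential_entropy_numeric(lam)
h_closed = 1.0 - math.log(lam)

print(f"uniform over 4 values: {h_uniform:.3f} bits")
print(f"exponential, numeric:  {h_numeric:.6f} nats")
print(f"exponential, closed:   {h_closed:.6f} nats")
```

Note that the continuous (differential) entropy can be negative, for example when $\lambda > e$, whereas the discrete entropy never is.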

Kullback-Leibler $(\text{KL})$ divergence

The Kullback-Leibler divergence $(\text{also named relative entropy})$ was first introduced by Solomon Kullback and Richard Leibler in 1951 as the directed divergence between two distributions. In statistics, the KL-divergence is commonly used to measure how one probability distribution differs from a second, reference probability distribution.

For discrete probability distributions $P$ and $Q$ defined on the same probability space $\mathcal{X}$, the KL-divergence between $P$ and $Q$ is defined by the sum below. It can also be interpreted as the mean of the logarithmic difference between the two distributions, where the expectation is taken with respect to $P$.

\[D_{KL}(P || Q) = \sum_{x\in\mathcal{X}} {\color{red}p(x)} \log{\frac{ {\color{red}p(x)}}{q(x)}}\]

For continuous probability distributions $P$ and $Q$ with densities $p$ and $q$, the KL-divergence is defined by the integral below.

\[D_{KL}(P || Q) = \int_{x\in\mathcal{X}} {\color{red}p(x)} \log{\frac{ {\color{red}p(x)}}{q(x)}} dx\]
  • Example: Calculate the KL-divergence in the discrete domain between $p$ and $q$, and between $p$ and $t$, where $p = (0.25, 0.25, 0.25, 0.25)$, $q = (0.25, 0.25, 0.26, 0.24)$ and $t = (0.6, 0.2, 0.15, 0.05)$. Using base-2 logarithms: \(\begin{aligned} D_{KL}(p||q) &= 0.25 \log_2{\frac{0.25}{0.25}} + 0.25 \log_2{\frac{0.25}{0.25}} + 0.25 \log_2{\frac{0.25}{0.26}} + 0.25 \log_2{\frac{0.25}{0.24}} \\ &\approx 0.00058~\text{bits}\\ D_{KL}(p||t) &=0.25 \log_2{\frac{0.25}{0.6}} + 0.25 \log_2{\frac{0.25}{0.2}} + 0.25 \log_2{\frac{0.25}{0.15}} + 0.25 \log_2{\frac{0.25}{0.05}} \\ &\approx 0.53~\text{bits}\\ D_{KL}(p||q) &\ll D_{KL}(p||t)\\ \end{aligned}\)
  • Example: Calculate the KL-divergence between two normal distributions, $p \sim \mathcal{N}(\mu_1, \sigma_1^2)$ and $q \sim \mathcal{N}(\mu_2, \sigma_2^2)$, using the natural logarithm. \(\begin{aligned} D_{KL}(P || Q) &= \int_{x\in\mathcal{X}} p(x) \log{\frac{p(x)}{q(x)}} dx\\ &= \int p(x) \log{\left( \frac{1}{\sqrt{2\pi} \sigma_1} e^{\frac{-(x-\mu_1)^2}{2\sigma_1^2}} / \frac{1}{\sqrt{2\pi} \sigma_2} e^{\frac{-(x-\mu_2)^2}{2\sigma_2^2}} \right)}dx\\ &= \int p(x) \left( \log \frac{\sigma_2}{\sigma_1} + \log e^{\frac{-(x-\mu_1)^2}{2\sigma_1^2}} - \log e^{\frac{-(x-\mu_2)^2}{2\sigma_2^2}} \right) dx \\ &= \int p(x) \left( \log \frac{\sigma_2}{\sigma_1} + \frac{-(x-\mu_1)^2}{2\sigma_1^2} - \frac{-(x-\mu_2)^2}{2\sigma_2^2} \right) dx \\ &=\log \frac{\sigma_2}{\sigma_1} \int p(x)dx - \frac{1}{2\sigma_1^2} \int p(x) (x-\mu_1)^2 dx + \frac{1}{2\sigma_2^2} \int p(x) (x-\mu_2)^2 dx\\ &= \log \frac{\sigma_2}{\sigma_1} - \frac{1}{2\sigma_1^2} (\sigma_1^2) + \frac{1}{2\sigma_2^2}(\sigma_1^2 + \mu_1^2 - 2\mu_1\mu_2 + \mu_2^2)\\ &= \log \frac{\sigma_2}{\sigma_1} + \frac{\sigma_1^2 + (\mu_1 - \mu_2)^2}{2\sigma_2^2}- \frac{1}{2} \end{aligned}\)
    For instance, with $\mu_1 = 0$, $\sigma_1^2 = 1$, $\mu_2 = 2$ and $\sigma_2^2 = 2$: \(\begin{aligned} D_{KL}(P || Q) &= \log \frac{\sqrt{2}}{1} + \frac{1 + (0-2)^2}{2\times2} - \frac{1}{2}\\ &\approx 1.10~\text{nats} \end{aligned}\)
    The second moment used above follows from expanding the square, with expectations taken under $p \sim \mathcal{N}(\mu_1, \sigma_1^2)$: \(\begin{aligned} \int p(x) (x-\mu_2)^2 dx &= \int p(x) \left( x^2 - 2x\mu_2 + \mu_2^2 \right)dx\\ &= \int p(x) x^2 dx - 2\mu_2 \int p(x) x dx + \mu_2^2 \int p(x) dx \\ &= E[x^2] - 2\mu_1 \mu_2 + \mu_2^2\\ &= \sigma_1^2 + \mu_1^2 - 2\mu_1 \mu_2 + \mu_2^2\\ \end{aligned}\)
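
Both examples can be evaluated in code. The following is a minimal sketch, not a reference implementation: the function names are illustrative, the discrete values for $p$, $q$ and $t$ are the ones appearing in the sums above, and the Gaussian parameters ($\mu_1 = 0$, $\sigma_1^2 = 1$, $\mu_2 = 2$, $\sigma_2^2 = 2$) are the ones appearing in the numeric substitution. The Gaussian case uses the standard closed form, in which the log term is $\log(\sigma_2/\sigma_1) = \frac{1}{2}\log(\sigma_2^2/\sigma_1^2)$, and cross-checks it against a midpoint-rule integral.

```python
import math


def kl_discrete(p, q, base=2):
    """D_KL(p || q) = sum_x p(x) * log_base(p(x) / q(x)) for discrete distributions."""
    return sum(pi * math.log(pi / qi, base) for pi, qi in zip(p, q) if pi > 0)


def kl_gaussian(mu1, var1, mu2, var2):
    """Closed-form D_KL(N(mu1, var1) || N(mu2, var2)) in nats."""
    return 0.5 * math.log(var2 / var1) + (var1 + (mu1 - mu2) ** 2) / (2 * var2) - 0.5


def log_normal_pdf(x, mu, var):
    """Natural log of the N(mu, var) density at x."""
    return -0.5 * math.log(2 * math.pi * var) - (x - mu) ** 2 / (2 * var)


def kl_gaussian_numeric(mu1, var1, mu2, var2, half_width=12.0, steps=100_000):
    """Midpoint-rule approximation of integral p(x) * (ln p(x) - ln q(x)) dx."""
    sd1 = math.sqrt(var1)
    lo = mu1 - half_width * sd1
    dx = 2 * half_width * sd1 / steps
    total = 0.0
    for k in range(steps):
        x = lo + (k + 0.5) * dx
        log_p = log_normal_pdf(x, mu1, var1)
        log_q = log_normal_pdf(x, mu2, var2)
        total += math.exp(log_p) * (log_p - log_q) * dx
    return total


p = [0.25, 0.25, 0.25, 0.25]
q = [0.25, 0.25, 0.26, 0.24]
t = [0.60, 0.20, 0.15, 0.05]

kl_pq = kl_discrete(p, q)  # in bits
kl_pt = kl_discrete(p, t)  # in bits

kl_closed = kl_gaussian(0.0, 1.0, 2.0, 2.0)           # in nats
kl_numeric = kl_gaussian_numeric(0.0, 1.0, 2.0, 2.0)  # in nats

print(f"D_KL(p||q) = {kl_pq:.5f} bits, D_KL(p||t) = {kl_pt:.2f} bits")
print(f"Gaussian KL: closed form {kl_closed:.4f}, numeric {kl_numeric:.4f} nats")
```

Working with log-densities in the numerical check avoids dividing two very small density values in the tails.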

Some properties of the KL-divergence are

  • It is non-negative: $D_{KL}(P || Q) \geq 0$, with equality if and only if $P = Q$,
  • It is asymmetric: in general, $D_{KL}(P || Q) \neq D_{KL}(Q||P)$,
  • It is invariant under parameter transformation $\text{(this property turns out to be very useful in machine learning and reinforcement learning, e.g., natural gradient)}$.
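
The first two properties are easy to observe empirically. The sketch below is an illustration only: the distributions `p` and `t`, the helper `random_distribution`, and the seed are arbitrary choices, and checking random samples is of course not a proof of non-negativity.

```python
import math
import random


def kl(p, q):
    """D_KL(p || q) in nats for two discrete distributions with q(x) > 0."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def random_distribution(n, rng):
    """Draw n strictly positive weights and normalize them to sum to 1."""
    weights = [rng.random() + 1e-3 for _ in range(n)]
    total = sum(weights)
    return [w / total for w in weights]


p = [0.25, 0.25, 0.25, 0.25]
t = [0.60, 0.20, 0.15, 0.05]

# Asymmetry: swapping the arguments changes the value.
forward = kl(p, t)
reverse = kl(t, p)

# Non-negativity: spot-check many random pairs (small tolerance for rounding).
rng = random.Random(0)
all_nonnegative = all(
    kl(random_distribution(5, rng), random_distribution(5, rng)) >= -1e-12
    for _ in range(1000)
)

print(f"D_KL(p||t) = {forward:.4f} nats, D_KL(t||p) = {reverse:.4f} nats")
print(f"all sampled divergences non-negative: {all_nonnegative}")
```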
2. Statistics

Jensen's inequality

Jensen’s inequality generalizes the statement that a secant line of a convex function lies above the graph $(\text{Wikipedia})$. Let $f(x)$ be a real continuous function $(\text{convex or concave})$ and let $t \in [0, 1]$. Jensen’s inequality then states

\[\begin{aligned} f(tx_1 + (1-t)x_2) &\leq tf(x_1) + (1-t)f(x_2) \quad \quad \text{where}~f(x)~\text{is convex}\\ f(tx_1 + (1-t)x_2) &\geq tf(x_1) + (1-t)f(x_2) \quad \quad \text{where}~f(x)~\text{is concave}\\ \end{aligned}\]

In probability theory, if $p_1, p_2, \cdots, p_n$ are positive numbers that sum to 1 and $f(x)$ is a convex function, then

\[\begin{aligned} &f\left(\sum_i^n p_i x_i\right) \leq \sum_i^n p_i f(x_i)\\ \Rightarrow~&f\left(E[x]\right) \leq E[f(x)] \end{aligned}\]

On the other hand, if $f(x)$ is a concave function, then

\[\begin{aligned} &f\left(\sum_i^n p_i x_i\right) \geq \sum_i^n p_i f(x_i)\\ \Rightarrow~&f\left(E[x]\right) \geq E[f(x)] \end{aligned}\]
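
Both directions of the inequality can be observed on a sample, taking the expectation to be the sample average (so every $p_i = 1/n$). This is a small sketch under assumptions of my own choosing: the sample is drawn uniformly from $[0.1, 10]$, with $x^2$ as the convex function and $\log x$ as the concave one.

```python
import math
import random

rng = random.Random(0)
xs = [rng.uniform(0.1, 10.0) for _ in range(10_000)]
mean_x = sum(xs) / len(xs)

# Convex case, f(x) = x^2: f(E[x]) <= E[f(x)].
square_of_mean = mean_x ** 2
mean_of_square = sum(x * x for x in xs) / len(xs)

# Concave case, f(x) = log(x): f(E[x]) >= E[f(x)].
log_of_mean = math.log(mean_x)
mean_of_log = sum(math.log(x) for x in xs) / len(xs)

print(f"convex:  f(E[x]) = {square_of_mean:.3f} <= E[f(x)] = {mean_of_square:.3f}")
print(f"concave: f(E[x]) = {log_of_mean:.3f} >= E[f(x)] = {mean_of_log:.3f}")
```

The concave case with $f = \log$ is exactly the step used to derive the evidence lower bound.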

Evidence Lower Bound $(\text{ELBO})$

Now, let’s start from the log probability of a random variable $X$, introducing a latent variable $Z$ and an arbitrary distribution $q(Z)$ over it. Since $f(x)=\log{(x)}$ is a concave function, Jensen’s inequality gives

\[\begin{aligned} \log{p(X)} &= \log \int_z p(X, Z) dz \\ &= \log \int_z p(X, Z) \frac{q(Z)}{q(Z)} dz \\ &= \log\left( E_z\left[\frac{p(X, Z)}{q(Z)}\right] \right) \geq E_z \left[\log \frac{p(X, Z)}{q(Z)} \right]\\ &= E_z [\log p(X, Z)] - E_z [\log(q(Z))] \end{aligned}\]

We denote $L=E_z [\log p(X, Z)] + H(Z)$ as the Evidence Lower Bound $(\text{ELBO})$, where the expectation $E_z$ is taken with respect to $q(Z)$ and $H(Z)=- E_z [\log(q(Z))]$ is the Shannon entropy of $q$. The $q(Z)$ in the equation is the distribution used to approximate the true posterior distribution $p(Z|X)$ in VI. Maximizing the ELBO over $q$ makes the bound on the log probability as tight as possible. Alternatively, if we want to maximize the marginal probability, we can instead maximize its lower bound $L$.
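
To see the bound at work, the sketch below estimates $L$ by Monte Carlo for a toy model in which $\log p(X)$ is known exactly. Everything specific here is an assumption made for illustration and does not come from the derivation above: the model $z \sim \mathcal{N}(0, 1)$, $x \mid z \sim \mathcal{N}(z, 1)$ (so that $x \sim \mathcal{N}(0, 2)$ and $p(z \mid x) = \mathcal{N}(x/2, 1/2)$), the Gaussian form $q(z) = \mathcal{N}(m, s^2)$, the observation $x = 1$, and all function names.

```python
import math
import random


def log_joint(x, z):
    """log p(x, z) for the toy model z ~ N(0, 1), x | z ~ N(z, 1)."""
    return -math.log(2 * math.pi) - 0.5 * z ** 2 - 0.5 * (x - z) ** 2


def log_evidence(x):
    """Exact log p(x); marginally x ~ N(0, 2) in this toy model."""
    return -0.5 * math.log(4 * math.pi) - x ** 2 / 4


def elbo_monte_carlo(x, m, s, n_samples=50_000, seed=0):
    """Estimate L = E_q[log p(x, z)] + H(q) for q(z) = N(m, s^2).

    The expectation is approximated by sampling z ~ q; the entropy of a
    Gaussian is available in closed form.
    """
    rng = random.Random(seed)
    expected_log_joint = sum(
        log_joint(x, rng.gauss(m, s)) for _ in range(n_samples)
    ) / n_samples
    entropy_q = 0.5 * math.log(2 * math.pi * math.e * s ** 2)
    return expected_log_joint + entropy_q


x_obs = 1.0
exact = log_evidence(x_obs)

# A poorly chosen q gives a loose bound ...
elbo_poor = elbo_monte_carlo(x_obs, m=-1.0, s=1.5)
# ... while q equal to the true posterior N(x/2, 1/2) makes the bound tight.
elbo_best = elbo_monte_carlo(x_obs, m=x_obs / 2, s=math.sqrt(0.5))

print(f"log p(x)          = {exact:.4f}")
print(f"ELBO, poor q      = {elbo_poor:.4f}")
print(f"ELBO, posterior q = {elbo_best:.4f}")
```

Because the expectation is estimated from samples, the value for the posterior $q$ matches $\log p(x)$ only up to Monte Carlo noise and may land slightly above or below it.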

3. Graphical Models

Probability theory plays a crucial role in modern machine learning. Graphical models provide a simple and elegant way to represent the structure of a probabilistic model and give insight into the properties of the model, especially its conditional independence properties. Generally, a graph consists of nodes and links, where each node represents a random variable and the links express probabilistic relationships between these variables. The graph then captures the way in which the joint distribution over all the random variables can be decomposed into a product of factors, each depending only on a subset of the variables.

Bayesian Networks

A Bayesian network is a directed graphical model, typically used to describe probability distributions in Bayesian inference. Consider a graphical model that represents the joint probability distribution over three variables $A$, $B$ and $C$. By the product rule of probability, we can write the joint distribution in the form

\[\begin{aligned} p(A, B, C) &= p(C|A, B)p(A, B)\\ &= p(C|A, B)p(B|A)p(A) \end{aligned}\]
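
The factorization can be made concrete with three binary variables. In the sketch below, all probability tables are made-up numbers chosen only for illustration; the point is that the product of the three factors is a valid joint distribution from which any marginal can be recovered by summation.

```python
from itertools import product

# Hypothetical conditional probability tables for binary A, B, C.
p_a = {0: 0.7, 1: 0.3}                      # p(A)
p_b_given_a = {                             # p(B | A), indexed as [a][b]
    0: {0: 0.9, 1: 0.1},
    1: {0: 0.4, 1: 0.6},
}
p_c_given_ab = {                            # p(C | A, B), indexed as [(a, b)][c]
    (0, 0): {0: 0.8, 1: 0.2},
    (0, 1): {0: 0.5, 1: 0.5},
    (1, 0): {0: 0.3, 1: 0.7},
    (1, 1): {0: 0.1, 1: 0.9},
}


def joint(a, b, c):
    """p(A, B, C) = p(C | A, B) * p(B | A) * p(A)."""
    return p_c_given_ab[(a, b)][c] * p_b_given_a[a][b] * p_a[a]


# The joint sums to 1 over all eight configurations.
total = sum(joint(a, b, c) for a, b, c in product((0, 1), repeat=3))

# Marginals follow by summing out the other variables.
marginal_a1 = sum(joint(1, b, c) for b, c in product((0, 1), repeat=2))
marginal_c1 = sum(joint(a, b, 1) for a, b in product((0, 1), repeat=2))

print(f"sum of joint = {total:.6f}")
print(f"p(A=1) = {marginal_a1:.3f}, p(C=1) = {marginal_c1:.3f}")
```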
4. Variational Inference

Finally, we are ready to introduce Variational Inference.

Problem Setup

Assume that $X$ are observations $(\text{data})$ and $Z$ are hidden variables, which may include the model “parameters”. The relationship between these two sets of variables can be represented using a graphical model.

The goal of variational inference is to infer the hidden variables from the observations; that is, we want the posterior distribution

\[p(Z|X) = \frac{p(X|Z)p(Z)}{p(X)} = \frac{p(X, Z)}{\int_Z p(X, Z) dz}\]

where the joint probability $p(X, Z)$ is generally easy to compute. However, the marginal probability $p(X)=\int_Z p(X,Z) dz$ is intractable in most cases.
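
The role of the normalizer $p(X)$ can be illustrated on a one-dimensional toy problem, where brute-force integration over $Z$ is still feasible. The sketch below is built on assumptions made purely for illustration: the model $z \sim \mathcal{N}(0, 1)$, $x \mid z \sim \mathcal{N}(z, 1)$, the observation $x = 1$, the grid limits, and the function names. For this model the exact answers are known ($x \sim \mathcal{N}(0, 2)$ and posterior mean $x/2$), so the grid result can be checked.

```python
import math


def joint_density(x, z):
    """p(x, z) = p(x | z) * p(z) for the toy model z ~ N(0, 1), x | z ~ N(z, 1)."""
    prior = math.exp(-0.5 * z * z) / math.sqrt(2 * math.pi)
    likelihood = math.exp(-0.5 * (x - z) ** 2) / math.sqrt(2 * math.pi)
    return likelihood * prior


def evidence_and_posterior_mean(x, lo=-10.0, hi=10.0, steps=20_000):
    """Midpoint-rule estimates of p(x) = integral p(x, z) dz and of E[z | x]."""
    dz = (hi - lo) / steps
    evidence = 0.0
    weighted = 0.0
    for k in range(steps):
        z = lo + (k + 0.5) * dz
        pxz = joint_density(x, z)
        evidence += pxz * dz
        weighted += z * pxz * dz
    return evidence, weighted / evidence


x_obs = 1.0
evidence, posterior_mean = evidence_and_posterior_mean(x_obs)
exact_evidence = math.exp(-x_obs ** 2 / 4) / math.sqrt(4 * math.pi)

print(f"p(x) on grid = {evidence:.6f}, exact = {exact_evidence:.6f}")
print(f"posterior mean on grid = {posterior_mean:.4f}")
```

With `steps` grid points per dimension, the same approach needs `steps ** d` evaluations for a $d$-dimensional $Z$, which is why the integral becomes impractical as the number of hidden variables grows.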

