gumbel 与 softmax
对于变量 \(x_1,x_2,..x_k\),以及标准 gumbel 分布上的噪声 \(\epsilon_1,\epsilon_2,..\epsilon_k\),以下操作:
\[argmax(x_1+\epsilon_1, x_2+\epsilon_2,..x_k+\epsilon_k)\]
等同于在 \(softmax(x_1,x_2,..x_k)\) 这个多元离散分布上采样。
证明:
在 \(softmax(x_1,x_2,..x_k)\) 采样得到 \(i\) 的概率是:
\[ \frac{e^{x_i}}{\sum_{j=1}^{k}e^{x_j}} \tag{*} \]
另一方面 \(argmax(x_1+\epsilon_1, x_2+\epsilon_2,..x_k+\epsilon_k)=i\) 的意思是说
\(x_1+\epsilon_1<x_i+\epsilon_i \\
x_2+\epsilon_2<x_i+\epsilon_i \\
...\\
x_k+\epsilon_k<x_i+\epsilon_i \\\)
等价于
\(\epsilon_1<x_i+\epsilon_i-x_1 \\
\epsilon_2<x_i+\epsilon_i-x_2 \\
...\\
\epsilon_k<x_i+\epsilon_i-x_k \\\)
给定 \(\epsilon_i\) ,其概率为:
\(\prod_{j\neq{i}}e^{-e^{-(x_i+\epsilon_i-x_j)}} \\
=e^{-\sum_{j\neq{i}}e^{-(x_i+\epsilon_i-x_j)}}\)
遍历 \(\epsilon_i\) 的所有取值,
\(p\{argmax(x_1+\epsilon_1, x_2+\epsilon_2,..x_k+\epsilon_k)=i\} \\
=\int^{+\infty}_{-\infty} e^{-\epsilon_i-e^{-\epsilon_i}}\cdot e^{-\sum_{j\neq{i}}e^{-(x_i+\epsilon_i-x_j)}} d\epsilon_i\)
指数部分化简为:
\[-\epsilon_i-e^{-\epsilon_i}\cdot {(1+\sum_{j\neq{i}}\frac{e^{x_j}}{e^{x_i}})} \\
=-\epsilon_i-\frac {e^{-\epsilon_i}} {\frac{e^{x_i}}{\sum_{j=1}^{k} e^{x_j}}}\]
令 \({\frac{e^{x_i}}{\sum_{j=1}^{k} e^{x_j}}}=a\) ,上面的积分可表示为:
\(\int^{+\infty}_{-\infty} e^{-\epsilon_i-\frac{e^{-\epsilon_i}}{a}}d\epsilon_i\)
通过 wolfram 网站可得该积分为 \(a \cdot e^{\frac{1}{a}\cdot -e^{-x}}|^{+\infty}_{-\infty} = a\) ,从而也就说明 \(p\{argmax(x_1+\epsilon_1, x_2+\epsilon_2,..x_k+\epsilon_k)=i\}={\frac{e^{x_i}}{\sum_{j=1}^{k} e^{x_j}}}\)
结合 (*) 式,命题得证。
作用:
当需要将多元离散分布上的样本作为神经网络的输入时,一种方法是将该采样以 onehot 编码传入神经网络。通常这一操作可用带温度 \(softmax\) 来模拟:
\(softmax(\frac{x_1+\epsilon_1}{T}, \frac{x_2+\epsilon_2}{T},..\frac{x_k+\epsilon_k}{T})\)
其中 \(T\) 越小,近似程度越高。
如果 \(x_1,x_2..x_k\) 是一个神经网络的输出,那么通过上面的加 \(gumbel\) 噪声以及带温度的 \(softmax\) 两个操作,我们用可导连续计算模拟了不可导的采样动作,从而沟通了两个神经网络,使得反向梯度传播成为可能。类比于 VAE 中对高斯分布采样的 reparameter trick,这里是对多元离散分布的 reparameter trick 。
reference:
https://en.wikipedia.org/wiki/Gumbel_distribution
https://lips.cs.princeton.edu/the-gumbel-max-trick-for-discrete-distributions/
https://arxiv.org/pdf/1611.01144.pdf