最优传输之生成模型
WGAN
假设
为什么这样设?第一个等号说明,
因此写成连续形式则有: $$b^T F=\int [p(x)f(x) +q(x)g(x)] dx\tag{2.3}$$
此外对于约束
展开后,实际上说的就是: $$f(x_i)+g(y_j)\leq c(x_i,y_j)\tag{2.5}$$
那么(1.16)式就转化为了: $$\max_{f,g} \left{\int [p(x)f(x) +q(x)g(x)] dx \mid f(x) + g(y) \leq c(x,y)\right}\tag{2.6}$$
这也是(1.10)W距离对应的对偶问题,即求解两个分布的W距离相当于求解(2.4)。
对于特殊情况
这启发我们,如果令
凑巧的是,(2.6)刚好等价于(2.8),下面是证明:
另外,由于
假设(2.6)的最优解为
显然
因为为
其中,第一个等号根据(2.9)的定义;第二个不等号是因为
上述说明,(2.6)的解一定包含在(2.8)中(目标符合,约束也符合),因此有
最终我们把W距离转化为了(2.8)式。对于约束
对于生成模型,我们会要求两个分布的距离尽量靠近,所以如果再最小化
这正是WGAN,因此WGAN的损失函数就是在缩短两个分布的W距离。
扩散模型
根据之前的推导,WGAN主要是在优化W距离,而扩散模型主要是优化KL距离(散度),那么这两个距离之间会有关联吗?2022年的一篇文章《https://arxiv.org/abs/2212.06359 (opens new window)》就介绍了扩散模型的得分匹配损失(实质上也是KL距离的上界#ref_4)是W距离的一个上界,因此在某种程度上,优化得分就等于优化W距离,这样就将扩散模型和WGAN联系到一起了。
这篇文章介绍的最核心的定理如下:
定理1. 假设
从
从
定理1告诉我们,如果我们优化得分匹配损失,那么也相当于优化W距离,所以扩散模型不但在优化两个分布之间的KL距离,还在悄悄优化W距离,也就揭示扩散模型模型和WGAN的联系了。
要完整这个定理需要用到很多最优传输中的引理,作者在论文中也只是简单引用,因此这里就简单介绍一下作者主要的证明思路:
根据#ref_5的定理8.4.7和#ref_6的推论5.25,有: $$-\frac{1}{2}\frac{dW_2^2(p_t(x),q_t(y))}{dt}=\mathbb{E}{\pi_t(x,y)}\left[(x-y)\cdot(\frac{dy}{dt}-\frac{dx}{dt}) \right]\tag{2.18}$$
这个式子是证明的**核心,**其中
(2.14)对应的概率流ODE为: $$\frac{dx}{dt}=f(x,t)-g(t)^2\nabla{x} \log p_t(x)\tag{2.19}$$
(2.15)对应的概率流ODE为: $$\frac{dy}{dt}=f(y,t) - g(t)^2 s_\theta(y,t) + \frac{1}{2}g(t)^2 \nabla_{y} \log q_t(y)\tag{2.20}$$
带入(2.18),有: $$\begin{align} -\frac{1}{2}\frac{dW_2^2(p_t(x),q_t(y))}{dt}&=\mathbb{E}{\pi_t(x,y)}\left[(x-y)\cdot(f(y,t)-f(x,t)) \right]\ &\quad+g(t)^2\mathbb{E}{\pi_t(x,y)}\left[(x-y)\cdot(s_\theta(x,t)-s_\theta(y,t)) \right]\ &\quad+g(t)^2\mathbb{E}{\pi_t(x,y)}\left[(x-y)\cdot(\log \nabla{x}p_t(x)-s_\theta(x,t)) \right]\ &\quad+\frac{g(t)^2}{2}\mathbb{E}{\pi_t(x,y)}\left[(x-y)\cdot(\log \nabla{y}q_t(y)-\log \nabla_{x}p_t(x)) \right]\ \end{align}\tag{2.21}$$
右边第一项和第二项根据(2.17)的约束,可以很容易得到: $$\begin{align} \mathbb{E}{\pi_t(x,y)}\left[(x-y)\cdot(f(y,t)-f(x,t)) \right]&\leq L_f(t)\mathbb{E}{\pi_t(x,y)}[|x-y|^2]\ &=L_f(t)W_2^2(p_t(x),q_t(y)) \end{align}\tag{2.22}$$
和: $$\begin{align} g(t)^2\mathbb{E}{\pi_t(x,y)}\left[(x-y)\cdot(s\theta(x,t)-s_\theta(y,t)) \right]&\leq g(t)^2L_s(t)\mathbb{E}{\pi_t(x,y)}[|x-y|^2]\ &=g(t)^2L_s(t)W_2^2(p_t(x),q_t(y)) \end{align}\tag{2.23}$$
第三项利用积分Cauchy-Schwarz不等式: $$\begin{align} g(t)^2\mathbb{E}{\pi_t(x,y)}\left[(x-y)\cdot(\log \nabla_{x}p_t(x)-s_\theta(x,t)) \right] &\leq g(t)^2 \mathbb{E}{\pi_t(x,y)}[|x-y|^2]^{\frac{1}{2}}\mathbb{E}{\pi_t(x,y)}[|\log \nabla_{x}p_t(x)-s_\theta(x,t)|^2]^{\frac{1}{2}}\ &=g(t)^2W_2(p_t(x),q_t(y))\mathbb{E}{p_t(x)}[|\log \nabla{x}p_t(x)-s_\theta(x,t)|^2]^{\frac{1}{2}} \end{align}\tag{2.24}$$
第四项根据论文附录的引理2,有: $$\mathbb{E}{\pi_t(x,y)}\left[(x-y)\cdot(\log \nabla{y}q_t(y)-\log \nabla_{x}p_t(x)) \right]\leq 0\tag{2.25}$$
综合(2.22)-(2.25),最终有: $$\begin{align} -\frac{1}{2}\frac{dW_2^2(p_t(x),q_t(y))}{dt} & \leq L_f(t)W_2^2(p_t(x),q_t(y)) \ &\quad+ g(t)^2L_s(t)W_2^2(p_t(x),q_t(y))\ &\quad+g(t)^2W_2(p_t(x),q_t(y))b_t^{\frac{1}{2}}\ \end{align}\tag{2.26}$$
简记
因此(2.26)两端可以整理为: $$-\frac{dW_2(p_t(x),q_t(y))}{dt}\leq (L_f(t)+g(t)^2L_s(t))W_2(p_t(x),q_t(y)) +g(t)^2b_t^{\frac{1}{2}} \tag{2.28}$$
这是非齐次线性一阶微分方程,利用常数变易法,设: $$W_2(p_t(x),q_t(y))=C_t\exp\left(\int_{t}^0 L_f(r)+g(r)^2L_s(r)dr\right)=C_t/I(t)\tag{2.29}$$
带入(2.28)有: $$-\frac{dC_t}{dt}\leq \exp\left(\int_0^t L_f(r)+g(r)^2L_s(r)dr\right) g(t)^2b_t^{\frac{1}{2}}=I(t)g(t)^2b_t^{\frac{1}{2}}\tag{2.30}$$
两边同时对
由于
完成了证明。
注:苏老师在博客《https://spaces.ac.cn/archives/9467 (opens new window)》中对于ODE情况给出了自己的证明,和原论文的主要的差别在于(2.18)式的推导。苏老师把期望的
其实这种距离对