Geeks_Z の Blog Geeks_Z の Blog
首页
  • 学习笔记

    • 《HTML》
    • 《CSS》
    • 《JavaWeb》
    • 《Vue》
  • 后端文章

    • Linux
    • Maven
    • 汇编语言
    • 软件工程
    • 计算机网络概述
    • Conda
    • Pip
    • Shell
    • SSH
    • Mac快捷键
    • Zotero
  • 学习笔记

    • 《数据结构与算法》
    • 《算法设计与分析》
    • 《Spring》
    • 《SpringMVC》
    • 《SpringBoot》
    • 《SpringCloud》
    • 《Nginx》
  • 深度学习文章
  • 学习笔记

    • 《PyTorch》
    • 《ReinforementLearning》
    • 《MetaLearning》
  • 学习笔记

    • 《高等数学》
    • 《线性代数》
    • 《概率论与数理统计》
  • 增量学习
  • 哈希学习
GitHub (opens new window)

Geeks_Z

AI小学生
首页
  • 学习笔记

    • 《HTML》
    • 《CSS》
    • 《JavaWeb》
    • 《Vue》
  • 后端文章

    • Linux
    • Maven
    • 汇编语言
    • 软件工程
    • 计算机网络概述
    • Conda
    • Pip
    • Shell
    • SSH
    • Mac快捷键
    • Zotero
  • 学习笔记

    • 《数据结构与算法》
    • 《算法设计与分析》
    • 《Spring》
    • 《SpringMVC》
    • 《SpringBoot》
    • 《SpringCloud》
    • 《Nginx》
  • 深度学习文章
  • 学习笔记

    • 《PyTorch》
    • 《ReinforementLearning》
    • 《MetaLearning》
  • 学习笔记

    • 《高等数学》
    • 《线性代数》
    • 《概率论与数理统计》
  • 增量学习
  • 哈希学习
GitHub (opens new window)
  • 线性代数

  • 概率论与数理统计

  • 矩阵

  • 分布

    • 统计量及其分布
    • 参数估计
    • 假设检验
    • 先验分布与后验分布
    • 分布度量
    • 交叉熵
    • 最优运输概述
    • Wasserstein距离
    • 基于最优传输的分类损失函数
    • 最优传输之生成模型
      • WGAN
      • 扩散模型
    • 最优传输之梯度流
    • 两个多元正态分布的KL散度巴氏距离和W距离
  • 数学笔记
  • 分布
Geeks_Z
2025-01-19
目录

最优传输之生成模型

WGAN

假设 F 的形式为: $$F = \left[ \begin{array}{c} f_1\ f_2 \ \vdots \ f_n \ \hline g_1 \ g_2 \ \vdots \ g_n \ \end{array} \right] = \left[ \begin{array}{c} f(x_1) \ f(x_2) \ \vdots \ f(x_n) \ \hline g(y_1) \ g(y_2) \ \vdots \ g(y_n) \ \end{array} \right]\tag{2.1}$$
为什么这样设?第一个等号说明, F 的形状跟 b 一样,因此每个位置 x1,x2,...,xn,y1,y2,...yn 上对应一个待求解的未知变量。第二个等号说明我们可以用已知的 fi 和 xi 或 gi 和 yi 来得到一个统一的 f 和 g ,这样有利于接下来的推导分析。

bTF 就可以展开为: $$b^T F=\sum_{i,j} p(x_i)f(x_i)+q(y_i)g(y_i)\tag{2.2}$$
因此写成连续形式则有: $$b^T F=\int [p(x)f(x) +q(x)g(x)] dx\tag{2.3}$$

此外对于约束 A⊤F≤C : $$\underbrace{\left[ \begin{array}{cccc|cccc} 1 & 0 & ... & 0 & 1 & 0 & ... & 0\ 1 & 0 & ... & 0 & 0 & 1 & ... & 0\ \vdots & \vdots & \ddots & \vdots & \vdots & \vdots & \ddots & \vdots\ \hline 0 & 1 & ... & 0 & 1 & 0 & ... & 0\ 0 & 1 & ... & 0 & 0 & 1 & ... & 0\ \vdots & \vdots & \ddots & \vdots & \vdots & \vdots & \ddots & \vdots\ \hline \vdots & \vdots & \ddots & \vdots & \vdots & \vdots & \ddots & \vdots\ \hline 0 & 0 & ... & 1 & 1 & 0 & ... & 0\ 0 & 0 & ... & 1 & 0 & 1 & ... & 0\ \vdots & \vdots & \ddots & \vdots & \vdots & \vdots & \ddots & \vdots\ \end{array} \right]}{A^T} \underbrace{\left[ \begin{array}{c} f(x_1) \ f(x_2) \ \vdots \ f(x_n) \ \hline g(y_1) \ g(y_2) \ \vdots \ g(y_n) \ \end{array} \right]}{F} \leq \underbrace{\begin{bmatrix} c(x_1, y_1) \ c(x_1, y_2) \ \vdots \ \hline c(x_2, y_1) \ c(x_2, y_2) \ \vdots \ \hline \vdots \ \hline c(x_n, y_1) \ c(x_n, y_2) \ \vdots\ \end{bmatrix}}_C \tag{2.4}$$
展开后,实际上说的就是: $$f(x_i)+g(y_j)\leq c(x_i,y_j)\tag{2.5}$$

那么(1.16)式就转化为了: $$\max_{f,g} \left{\int [p(x)f(x) +q(x)g(x)] dx \mid f(x) + g(y) \leq c(x,y)\right}\tag{2.6}$$
这也是(1.10)W距离对应的对偶问题,即求解两个分布的W距离相当于求解(2.4)。

对于特殊情况 x=y ,此时 c(x,x)=0 ,则: $$f(x)+g(x)\leq 0 \Rightarrow g(x)\leq -f(x)\tag{2.7}$$
这启发我们,如果令 g(x)=−f(x) ,(2.6)式就能写成: $$\max_{f} \left{\int [p(x)f(x) -q(x)f(x)] dx \mid f(x) - f(y) \leq c(x,y)\right}\tag{2.8}$$
凑巧的是,(2.6)刚好等价于(2.8),下面是证明:

g(x)=−f(x) 是(2.7)中 g(x)≤−f(x) 的特殊情况,因此容易知道(2.8)的解一定在(2.6)中,记为 (2.6)⊇(2.8) 。
另外,由于 p(x)f(x)+q(x)g(x)≤p(x)f(x)−q(x)f(x) ,因此(2.6)的目标一定被(2.8)包含,所以只需证明(2.8)和(2.6)的约束是等价的,那么就有 (2.6)⊆(2.8) ,综上就能推出 (2.6)=(2.8) 了。但对于约束,由于 $ f(x) + g(y) \leq f(x) - f(y)$ ,所以 f(x)−f(y) 不一定会小于 c(x,y) ,因此关于这部分需要换个思路。
假设(2.6)的最优解为 f∗(x) 和 g∗(y) ,(2.8)的目标此时就为 p(x)f∗(x)−q(x)f∗(x) ,那么可以想办法构造出一个新解 f∗∗(x) ,其不仅满足目标,且符合约束 f∗∗(x)−f∗∗(y)≤c(x,y) 。由于 f∗(x)+g∗(y)≤c(x,y) ,则 $f^(x) \leq c(x,y)-g^(y) $ ,令: $$\begin{align} f^{}(x)&=\min_{y} {c(x,y)-g^(y) }\ z(x) &= \operatorname{argmin}_{y} {c(x,y)-g^(y) } \end{align}\tag{2.9}$$
显然 f∗(x)≤f∗∗(x) ,那么(2.6)的目标: $$p(x)f^
(x) +q(x)g^*(x)\leq p(x)f^{
}(x) +q(x)g^(x)\tag{2.10}$$
因为为 f∗(x) 和 g∗(y) 已经是最优解了,所以(2.10)只能取等号, f∗∗(x) 和 g∗(y) 也是一对最优解,那么约束: $$\begin{align} f^{}(x) - f^{}(y) &= [c(x, z(x)) - g^{
}(z(x))] - [c(y, z(y)) - g^{}(z(y))] \ &\leq [c(x, z(y)) - g^{}(z(y))] - [c(y, z(y)) - g^{*}(z(y))] \ &= c(x, z(y)) - c(y, z(y)) \ &\leq c(x, y) \end{align}\tag{2.11}$$
其中,第一个等号根据(2.9)的定义;第二个不等号是因为 z(x) 已是最优解,换成其它肯定会放大;最后一个不等号是由于距离的三角不等式,而W距离中的 c(x,y) 定义为欧式距离,这显然满足。
上述说明,(2.6)的解一定包含在(2.8)中(目标符合,约束也符合),因此有 (2.6)⊆(2.8) ,综合起来就有 (2.6)=(2.8) ,完成了证明。

最终我们把W距离转化为了(2.8)式。对于约束 f(x)−f(y)≤c(x,y) ,根据欧氏距离的对称性还有 f(x)−f(y)≥−c(x,y) 写成期望形式: $$\max_{f, |f(x) - f(y)| \leq |x-y|} \left{\mathbb{E}{p(x)}[f(x)] -\mathbb{E}{q(x)}[f(x)] \right}\tag{2.12}$$
对于生成模型,我们会要求两个分布的距离尽量靠近,所以如果再最小化 p(x) 和 q(x) 的W距离,假设 g 为生成器, q(x)=q(g(z)),z∼N(0,1) ,那么最终的形式为: $$\min_{g}\max_{f, |f(x) - f(y)| \leq |x-y|} \left{\mathbb{E}{p(x)}[f(x)] -\mathbb{E}{q(g(z))}[f(g(z))] \right}\tag{2.13}$$
这正是WGAN,因此WGAN的损失函数就是在缩短两个分布的W距离。

扩散模型

根据之前的推导,WGAN主要是在优化W距离,而扩散模型主要是优化KL距离(散度),那么这两个距离之间会有关联吗?2022年的一篇文章《https://arxiv.org/abs/2212.06359 (opens new window)》就介绍了扩散模型的得分匹配损失(实质上也是KL距离的上界#ref_4)是W距离的一个上界,因此在某种程度上,优化得分就等于优化W距离,这样就将扩散模型和WGAN联系到一起了。

这篇文章介绍的最核心的定理如下:

定理1. 假设 pt(x) 服从以下正向SDE演化过程: $$dx=f(x,t)dt+g(t)dw,\qquad t\in[0,T]\tag{2.14}$$
从 t=0 开始,定义 p0(x) 为数据分布。令 $ s_{\theta}(t, x)$ 是由(2.14)经过得分匹配损失训练得到的。假设 qt(x) 服从以下逆向SDE演化过程: $$dx=[f(x,t)-g(t)^2s_\theta(x,t)]dt+g(t)dw,\qquad t\in[0,T]\tag{2.15}$$
从 t=T 开始,定义 qT(x) 为指定先验分布(例如标准高斯噪声)。那么有以下关系: $$W_2(p_0, q_0) \leq \int_{0}^{T} g(t)I(t)\mathbb{E}{p_t} \left[ | \nabla \log p_t(x) - s{\theta}(t, x) |^2 \right]^{\frac{1}{2}} dt + I(T) W_2(p_T, q_T)\tag{2.16}$$
W2(p0,q0) 表示 p0(x) 和 q0(x) 之间的W-2距离; I(t)=exp⁡(∫0t(Lf(r)+Ls(r)g(r)2)dr) 单调递增,其中两个非负函数 Lf(t) 和 Ls(t) 来自论文的前提假设, f(x,t) 满足Lipschitz约束, sθ(x,t) 满足单边Lipschitz约束: $$\begin{align} |f(x,t)-f(y,t)|&\leq L_f(t)|x-y|\ (s_\theta(x,t)-s_\theta(y,t))\cdot (x-y)&\leq L_s(t)|x-y|^2 \end{align}\tag{2.17}$$
定理1告诉我们,如果我们优化得分匹配损失,那么也相当于优化W距离,所以扩散模型不但在优化两个分布之间的KL距离,还在悄悄优化W距离,也就揭示扩散模型模型和WGAN的联系了。

要完整这个定理需要用到很多最优传输中的引理,作者在论文中也只是简单引用,因此这里就简单介绍一下作者主要的证明思路:

根据#ref_5的定理8.4.7和#ref_6的推论5.25,有: $$-\frac{1}{2}\frac{dW_2^2(p_t(x),q_t(y))}{dt}=\mathbb{E}{\pi_t(x,y)}\left[(x-y)\cdot(\frac{dy}{dt}-\frac{dx}{dt}) \right]\tag{2.18}$$
这个式子是证明的**核心,**其中 πt(x,y) 表示 pt(x) 到 qt(y) 的最优传输策略, dx/dt 和 dy/dt 分别是 pt(x) 和 qt(y) 对应路径 x 和 y 对 t 的全微分,即为概率流ODE。
(2.14)对应的概率流ODE为: $$\frac{dx}{dt}=f(x,t)-g(t)^2\nabla
{x} \log p_t(x)\tag{2.19}$$

(2.15)对应的概率流ODE为: $$\frac{dy}{dt}=f(y,t) - g(t)^2 s_\theta(y,t) + \frac{1}{2}g(t)^2 \nabla_{y} \log q_t(y)\tag{2.20}$$
带入(2.18),有: $$\begin{align} -\frac{1}{2}\frac{dW_2^2(p_t(x),q_t(y))}{dt}&=\mathbb{E}{\pi_t(x,y)}\left[(x-y)\cdot(f(y,t)-f(x,t)) \right]\ &\quad+g(t)^2\mathbb{E}{\pi_t(x,y)}\left[(x-y)\cdot(s_\theta(x,t)-s_\theta(y,t)) \right]\ &\quad+g(t)^2\mathbb{E}{\pi_t(x,y)}\left[(x-y)\cdot(\log \nabla{x}p_t(x)-s_\theta(x,t)) \right]\ &\quad+\frac{g(t)^2}{2}\mathbb{E}{\pi_t(x,y)}\left[(x-y)\cdot(\log \nabla{y}q_t(y)-\log \nabla_{x}p_t(x)) \right]\ \end{align}\tag{2.21}$$
右边第一项和第二项根据(2.17)的约束,可以很容易得到: $$\begin{align} \mathbb{E}{\pi_t(x,y)}\left[(x-y)\cdot(f(y,t)-f(x,t)) \right]&\leq L_f(t)\mathbb{E}{\pi_t(x,y)}[|x-y|^2]\ &=L_f(t)W_2^2(p_t(x),q_t(y)) \end{align}\tag{2.22}$$
和: $$\begin{align} g(t)^2\mathbb{E}{\pi_t(x,y)}\left[(x-y)\cdot(s\theta(x,t)-s_\theta(y,t)) \right]&\leq g(t)^2L_s(t)\mathbb{E}{\pi_t(x,y)}[|x-y|^2]\ &=g(t)^2L_s(t)W_2^2(p_t(x),q_t(y)) \end{align}\tag{2.23}$$
第三项利用积分Cauchy-Schwarz不等式: $$\begin{align} g(t)^2\mathbb{E}
{\pi_t(x,y)}\left[(x-y)\cdot(\log \nabla_{x}p_t(x)-s_\theta(x,t)) \right] &\leq g(t)^2 \mathbb{E}{\pi_t(x,y)}[|x-y|^2]^{\frac{1}{2}}\mathbb{E}{\pi_t(x,y)}[|\log \nabla_{x}p_t(x)-s_\theta(x,t)|^2]^{\frac{1}{2}}\ &=g(t)^2W_2(p_t(x),q_t(y))\mathbb{E}{p_t(x)}[|\log \nabla{x}p_t(x)-s_\theta(x,t)|^2]^{\frac{1}{2}} \end{align}\tag{2.24}$$
第四项根据论文附录的引理2,有: $$\mathbb{E}{\pi_t(x,y)}\left[(x-y)\cdot(\log \nabla{y}q_t(y)-\log \nabla_{x}p_t(x)) \right]\leq 0\tag{2.25}$$
综合(2.22)-(2.25),最终有: $$\begin{align} -\frac{1}{2}\frac{dW_2^2(p_t(x),q_t(y))}{dt} & \leq L_f(t)W_2^2(p_t(x),q_t(y)) \ &\quad+ g(t)^2L_s(t)W_2^2(p_t(x),q_t(y))\ &\quad+g(t)^2W_2(p_t(x),q_t(y))b_t^{\frac{1}{2}}\ \end{align}\tag{2.26}$$
简记 bt=Ept(x)[∥log⁡∇xpt(x)−sθ(x,t)∥2] ,由于: $$-\frac{1}{2}\frac{dW_2^2(p_t(x),q_t(y))}{dt}=-W_2(p_t(x),q_t(y))\frac{dW_2(p_t(x),q_t(y))}{dt}\tag{2.27}$$
因此(2.26)两端可以整理为: $$-\frac{dW_2(p_t(x),q_t(y))}{dt}\leq (L_f(t)+g(t)^2L_s(t))W_2(p_t(x),q_t(y)) +g(t)^2b_t^{\frac{1}{2}} \tag{2.28}$$
这是非齐次线性一阶微分方程,利用常数变易法,设: $$W_2(p_t(x),q_t(y))=C_t\exp\left(\int_{t}^0 L_f(r)+g(r)^2L_s(r)dr\right)=C_t/I(t)\tag{2.29}$$
带入(2.28)有: $$-\frac{dC_t}{dt}\leq \exp\left(\int_0^t L_f(r)+g(r)^2L_s(r)dr\right) g(t)^2b_t^{\frac{1}{2}}=I(t)g(t)^2b_t^{\frac{1}{2}}\tag{2.30}$$
两边同时对 t 从0积到 T ,则: $$C_0\leq \int_0^T I(t)g(t)^2b_t^{\frac{1}{2}}dt+C_T\tag{2.31}$$
由于 W2(pT(x),qT(y))=CT/I(T) ,则 CT=I(T)W2(pT(x),qT(y)) ,最终: $$W_2(p_0(x),q_0(y))=C_0\leq \int_0^T I(t)g(t)^2b_t^{\frac{1}{2}}dt+I(T)W_2(p_T(x),q_T(y))\tag{2.32}$$
完成了证明。

注:苏老师在博客《https://spaces.ac.cn/archives/9467 (opens new window)》中对于ODE情况给出了自己的证明,和原论文的主要的差别在于(2.18)式的推导。苏老师把期望的 t 时刻最优传输方案 πt(x,y) 改为了由 pT(z) 通过 dx/dt 和 dy/dt ODE映射得到的 γt(x(z),y(z)) ,这样的好处是可以把W-2距离转化成关于中间变量 z 的期望,与时间无关,那么就可以很容易的对 t 求导得到类(2.18)式,后续的操作是一样的。但在SDE情况下不考虑最优传输方案会出现误差,无法使用类似的思路进行推导。

其实这种距离对 t 求导数的操作在#ref_7#ref_8#ref_4中都有用过,当时是利用KL散度对 t 的导数来推出它的一个上界为得分匹配,现在看来跟本文的结果异曲同工了,只是W距离的推导需要最优传输的背景,相对更加复杂。

上次更新: 2025/06/25, 11:25:50
基于最优传输的分类损失函数
最优传输之梯度流

← 基于最优传输的分类损失函数 最优传输之梯度流→

最近更新
01
帮助信息查看
06-08
02
常用命令
06-08
03
学习资源
06-07
更多文章>
Theme by Vdoing | Copyright © 2022-2025 Geeks_Z | MIT License
京公网安备 11010802040735号 | 京ICP备2022029989号-1
  • 跟随系统
  • 浅色模式
  • 深色模式
  • 阅读模式