Geeks_Z の Blog Geeks_Z の Blog
首页
  • 学习笔记

    • 《HTML》
    • 《CSS》
    • 《JavaWeb》
    • 《Vue》
  • 后端文章

    • Linux
    • Maven
    • 汇编语言
    • 软件工程
    • 计算机网络概述
    • Conda
    • Pip
    • Shell
    • SSH
    • Mac快捷键
    • Zotero
  • 学习笔记

    • 《数据结构与算法》
    • 《算法设计与分析》
    • 《Spring》
    • 《SpringMVC》
    • 《SpringBoot》
    • 《SpringCloud》
    • 《Nginx》
  • 深度学习文章
  • 学习笔记

    • 《PyTorch》
    • 《ReinforementLearning》
    • 《MetaLearning》
  • 学习笔记

    • 《高等数学》
    • 《线性代数》
    • 《概率论与数理统计》
  • 增量学习
  • 哈希学习
GitHub (opens new window)

Geeks_Z

AI小学生
首页
  • 学习笔记

    • 《HTML》
    • 《CSS》
    • 《JavaWeb》
    • 《Vue》
  • 后端文章

    • Linux
    • Maven
    • 汇编语言
    • 软件工程
    • 计算机网络概述
    • Conda
    • Pip
    • Shell
    • SSH
    • Mac快捷键
    • Zotero
  • 学习笔记

    • 《数据结构与算法》
    • 《算法设计与分析》
    • 《Spring》
    • 《SpringMVC》
    • 《SpringBoot》
    • 《SpringCloud》
    • 《Nginx》
  • 深度学习文章
  • 学习笔记

    • 《PyTorch》
    • 《ReinforementLearning》
    • 《MetaLearning》
  • 学习笔记

    • 《高等数学》
    • 《线性代数》
    • 《概率论与数理统计》
  • 增量学习
  • 哈希学习
GitHub (opens new window)
  • Python

  • MLTutorials

  • 卷积神经网络

  • 循环神经网络

  • Transformer

  • VisionTransformer

  • 扩散模型

    • DDPM
    • VAE
    • DDIM
      • DDIM原理
      • 实验结果
      • 代码实现
      • 其它:重建和插值
      • 小结
      • 参考
    • StableDiffusion
    • 扩散模型和最优传输
  • 计算机视觉

  • PTM

  • MoE

  • LoRAMoE

  • LongTailed

  • 多模态

  • 知识蒸馏

  • PEFT

  • 对比学习

  • 小样本学习

  • 迁移学习

  • 零样本学习

  • 集成学习

  • Mamba

  • PyTorch

  • CL

  • CIL

  • 小样本类增量学习FSCIL

  • UCIL

  • 多模态增量学习MMCL

  • LTCIL

  • DIL

  • 论文阅读与写作

  • 分布外检测

  • GPU

  • 深度学习调参指南

  • AINotes
  • 扩散模型
Geeks_Z
2024-10-17
目录

DDIM

扩散模型之DDIM

Author: [小小将]

Link: [https://zhuanlan.zhihu.com/p/565698027]

“What I cannot create, I do not understand.” -- Richard Feynman

上一篇文章https://zhuanlan.zhihu.com/p/563661713 (opens new window)介绍了经典扩散模型DDPM的原理和实现,对于扩散模型来说,一个最大的缺点是需要设置较长的扩散步数才能得到好的效果,这导致了生成样本的速度较慢,比如扩散步数为1000的话,那么生成一个样本就要模型推理1000次。这篇文章我们将介绍另外一种扩散模型DDIM(https://arxiv.org/abs/2010.02502 (opens new window)),DDIM和DDPM有相同的训练目标,但是它不再限制扩散过程必须是一个马尔卡夫链,这使得DDIM可以采用更小的采样步数来加速生成过程,DDIM的另外是一个特点是从一个随机噪音生成样本的过程是一个确定的过程(中间没有加入随机噪音)。

DDIM原理

在介绍DDIM之前,先来回顾一下DDPM。在DDPM中,扩散过程(前向过程)定义为一个马尔卡夫链:

q(x1:T|x0)=∏t=1Tq(xt|xt−1)q(xt|xt−1)=N(xt;αtαt−1xt−1,(1−αtαt−1)I)

注意,在DDIM的论文中,αt 实是DDPM论文中的α¯t,那么DDPM论文中的前向过程βt 为:

βt=(1−αtαt−1)

扩散过程的一个重要特性是可以直接用x0 对任意的xt 行采样:

q(xt|x0)=N(xt;αtx0,(1−αt)I)

而DDPM的反向过程也定义为一个马尔卡夫链:

pθ(x0:T)=p(xT)∏t=1Tpθ(xt−1|xt)pθ(xt−1|xt)=N(xt−1;μθ(xt,t),Σθ(xt,t))

这里用神经网络pθ(xt−1|xt) 拟合真实的分布q(xt−1|xt)。DDPM的前向过程和反向过程如下所示:


我们近一步发现后验分布q(xt−1|xt,x0) 一个可获取的高斯分布:

q(xt−1|xt,x0)=N(xt−1;μ~(xt,x0),β~tI)

其中这个高斯分布的方差是定值,而均值是一个依赖x0 xt 组合函数:

μ~t(xt,x0)=αt(1−αt−1)αt−1(1−αt)xt+αt−1βt1−αtx0

然后我们基于变分法得到如下的优化目标:

L=Eq(x1:T|x0)[log⁡q(x1:T|x0)pθ(x0:T)]=DKL(q(xT|x0)∥pθ(xT))⏟LT+∑t=2TEq(xt|x0)[DKL(q(xt−1|xt,x0)∥pθ(xt−1|xt))]⏟Lt−1−Eq(x1|x0)log⁡pθ(x0|x1)⏟L0

根据两个高斯分布的KL公式,我们近一步得到:

Lt−1=Eq(xt|x0)[12σt2∥μ~t(xt,x0)−μθ(xt,t)∥2]

根据扩散过程的特性,我们通过重参数化可以近一步简化上述目标:

Lt−1=Ex0,ϵ∼N(0,I)[βt22σt2αt(1−α¯t)∥ϵ−ϵθ(α¯tx0+1−α¯tϵ,t)∥2]

如果去掉系数,那么就能得到更简化的优化目标:

Lt−1simple=Ex0,ϵ∼N(0,I)[∥ϵ−ϵθ(α¯tx0+1−α¯tϵ,t)∥2]

仔细分析DDPM的优化目标会发现,DDPM其实仅仅依赖边缘分布q(xt|x0),而并不是直接作用在联合分布q(x1:T|x0)。这带来的一个启示是:DDPM这个隐变量模型可以有很多推理分布来选择,只要推理分布满足边缘分布条件(扩散过程的特性)即可,而且这些推理过程并不一定要是马尔卡夫链。但值得注意的一个点是,我们要得到DDPM的优化目标,还需要知道分布q(xt−1|xt,x0),之前我们在根据贝叶斯公式推导这个分布时是知道分布q(xt|xt−1) ,而且依赖了前向过程的马尔卡夫链特性。如果要解除对前向过程的依赖,那么我们就需要直接定义这个分布q(xt−1|xt,x0)。 基于上述分析,DDIM论文中将推理分布定义为:

qσ(x1:T|x0)=qσ(xT|x0)∏t=2Tqσ(xt−1|xt,x0)

这里要同时满足qσ(xT|x0)=N(αTx0,(1−αT)I) 及对于所有的t≥2 :

qσ(xt−1|xt,x0)=N(xt−1;αt−1x0+1−αt−1−σt2xt−αtx01−αt,σt2I)

这里的方差σt2 一个实数,不同的设置就是不一样的分布,所以qσ(x1:T|x0) 实是一系列的推理分布。可以看到这里分布qσ(xt−1|xt,x0) 均值也定义为一个依赖x0 xt 组合函数,之所以定义为这样的形式,是因为根据qσ(xT|x0),我们可以通过数学归纳法证明,对于所有的t 满足:

qσ(xt|x0)=N(xt;αtx0,(1−αt)I)

这部分的证明见DDIM论文的附录部分,另外博客https://kexue.fm/archives/9181 (opens new window)也从待定系数法来证明了分布qσ(xt−1|xt,x0) 构造的形式。 可以看到这里定义的推理分布qσ(x1:T|x0) 没有直接定义前向过程,但这里满足了我们前面要讨论的两个条件:边缘分布qσ(xt|x0)=N(xt;αtx0,(1−αt)I),同时已知后验分布qσ(xt−1|xt,x0)。同样地,我们可以按照和DDPM的一样的方式去推导优化目标,最终也会得到同样的Lsimple(虽然VLB的系数不同,论文3.2部分也证明了这个结论)。 论文也给出了一个前向过程是非马尔可夫链的示例,如下图所示,这里前向过程是qσ(xt|xt−1,x0),由于生成xt 仅依赖xt−1,而且依赖x0,所以是一个非马尔可夫链:


注意,这里只是一个前向过程的示例,而实际上我们上述定义的推理分布并不需要前向过程就可以得到和DDPM一样的优化目标。与DDPM一样,这里也是用神经网络ϵθ 预测噪音,那么根据qσ(xt−1|xt,x0) 形式,在生成阶段,我们可以用如下公式来从xt 成xt−1:

xt−1=αt−1(xt−1−αtϵθ(xt,t)αt⏟predictedx0)+1−αt−1−σt2⋅ϵθ(xt,t)⏟direction pointing to xt+σtϵt⏟random noise

这里将生成过程分成三个部分:一是由预测的x0 产生的,二是由指向xt 部分,三是随机噪音(这里ϵt 与xt 关的噪音)。论文将σt2 一步定义为:

σt2=η⋅β~t=η⋅(1−αt−1)/(1−αt)(1−αt/αt−1)

这里考虑两种情况,一是η=1,此时σt2=β~t,此时生成过程就和DDPM一样了。另外一种情况是η=0,这个时候生成过程就没有随机噪音了,是一个确定性的过程,论文将这种情况下的模型称为DDIM(denoising diffusion implicit model),一旦最初的随机噪音xT 定了,那么DDIM的样本生成就变成了确定的过程。

上面我们终于得到了DDIM模型,那么我们现在来看如何来加速生成过程。虽然DDIM和DDPM的训练过程一样,但是我们前面已经说了,DDIM并没有明确前向过程,这意味着我们可以定义一个更短的步数的前向过程。具体地,这里我们从原始的序列[1,...,T] 样一个长度为S 子序列[τ1,...,τS],我们将xτ1,...,xτS 前向过程定义为一个马尔卡夫链,并且它们满足:q(xτi|x0)=N(xt;ατix0,(1−ατi)I)。下图展示了一个具体的示例:


那么生成过程也可以用这个子序列的反向马尔卡夫链来替代,由于S 以设置比原来的步数L 小,那么就可以加速生成过程。这里的生成过程变成:

xτi−1=ατi−1(xτi−1−ατiϵθ(xτi,τi)ατi)+1−ατi−1−στi2⋅ϵθ(xτi,τi)+στiϵ

其实上述的加速,我们是将前向过程按如下方式进行了分解:

qσ,τ(x1:T|x0)=qσ,τ(xT|x0)∏i=1Sqσ(xτi−1|xτi,x0)∏t∈τ¯qσ,τ(xt|x0)

其中τ¯={1,...,T}∖τ。这包含了两个图:其中一个就是由{xτi}i=1S 成的马尔可夫链,另外一个是剩余的变量{xt}t∈τ¯ 成的星状图。同时生成过程,我们也只用马尔可夫链的那部分来生成:

pθ(x0:T)=p(xT)∏i=1Spθ(xτi−1|xτi)⏟use to produce sample×∏t∈τ¯pθ(x0|xt)⏟only for VLB

论文共设计了两种方法来采样子序列,分别是:

  • Linear:采用线性的序列τi=⌊ci⌋;
  • Quadratic:采样二次方的序列τi=⌊ci2⌋;

这里的c 一个定值,它的设定使得τ−1 接近T。论文中只对CIFAR10数据集采用Quadratic序列,其它数据集均采用Linear序列。

实验结果

下表为不同的η 以及不同采样步数下的对比结果,可以看到DDIM(η=0)在较短的步数下就能得到比较好的效果,媲美DDPM(η=1)的生成效果。如果S 置为50,那么相比原来的生成过程就可以加速20倍。

代码实现

DDIM和DDPM的训练过程一样,所以可以直接在DDPM的基础上加一个新的生成方法(这里主要参考了https://github.com/ermongroup/ddim (opens new window)以及https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_ddim.py (opens new window)),具体代码如下所示:

class GaussianDiffusion:
    def __init__(self, timesteps=1000, beta_schedule='linear'):
     pass

    # ...
        
 # use ddim to sample
    @torch.no_grad()
    def ddim_sample(
        self,
        model,
        image_size,
        batch_size=8,
        channels=3,
        ddim_timesteps=50,
        ddim_discr_method="uniform",
        ddim_eta=0.0,
        clip_denoised=True):
        # make ddim timestep sequence
        if ddim_discr_method == 'uniform':
            c = self.timesteps // ddim_timesteps
            ddim_timestep_seq = np.asarray(list(range(0, self.timesteps, c)))
        elif ddim_discr_method == 'quad':
            ddim_timestep_seq = (
                (np.linspace(0, np.sqrt(self.timesteps * .8), ddim_timesteps)) ** 2
            ).astype(int)
        else:
            raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
        # add one to get the final alpha values right (the ones from first scale to data during sampling)
        ddim_timestep_seq = ddim_timestep_seq + 1
        # previous sequence
        ddim_timestep_prev_seq = np.append(np.array([0]), ddim_timestep_seq[:-1])
        
        device = next(model.parameters()).device
        # start from pure noise (for each example in the batch)
        sample_img = torch.randn((batch_size, channels, image_size, image_size), device=device)
        for i in tqdm(reversed(range(0, ddim_timesteps)), desc='sampling loop time step', total=ddim_timesteps):
            t = torch.full((batch_size,), ddim_timestep_seq[i], device=device, dtype=torch.long)
            prev_t = torch.full((batch_size,), ddim_timestep_prev_seq[i], device=device, dtype=torch.long)
            
            # 1. get current and previous alpha_cumprod
            alpha_cumprod_t = self._extract(self.alphas_cumprod, t, sample_img.shape)
            alpha_cumprod_t_prev = self._extract(self.alphas_cumprod, prev_t, sample_img.shape)
    
            # 2. predict noise using model
            pred_noise = model(sample_img, t)
            
            # 3. get the predicted x_0
            pred_x0 = (sample_img - torch.sqrt((1. - alpha_cumprod_t)) * pred_noise) / torch.sqrt(alpha_cumprod_t)
            if clip_denoised:
                pred_x0 = torch.clamp(pred_x0, min=-1., max=1.)
            
            # 4. compute variance: "sigma_t(η)" -> see formula (16)
            # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
            sigmas_t = ddim_eta * torch.sqrt(
                (1 - alpha_cumprod_t_prev) / (1 - alpha_cumprod_t) * (1 - alpha_cumprod_t / alpha_cumprod_t_prev))
            
            # 5. compute "direction pointing to x_t" of formula (12)
            pred_dir_xt = torch.sqrt(1 - alpha_cumprod_t_prev - sigmas_t**2) * pred_noise
            
            # 6. compute x_{t-1} of formula (12)
            x_prev = torch.sqrt(alpha_cumprod_t_prev) * pred_x0 + pred_dir_xt + sigmas_t * torch.randn_like(sample_img)

            sample_img = x_prev
            
        return sample_img.cpu().numpy()

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67

这里以MNIST数据集为例,训练的扩散步数为500,直接采用DDPM(即推理500次)生成的样本如下所示:


同样的模型,我们采用DDIM来加速生成过程,这里DDIM的采样步数为50,其生成的样本质量和500步的DDPM相当:


完整的代码示例见https://github.com/xiaohu2015/nngen (opens new window)。

其它:重建和插值

在DDIM论文中,还额外讨论了两个小点的内容**:重建和插值**。所谓重建是指的首先用原始图像求逆得到对应的噪音然后再进行生成的过程;而插值是指的对两个随机噪音进行插值从而得到融合两种噪音的图像。 首先是重建,对于DDIM,其η=0,这个时候从xt 成xt−1 更新公式就变为:

xt−1=αt−1(xt−1−αtϵθ(xt,t)αt)+1−αt−1⋅ϵθ(xt,t)

我们进一步对上述公式进行变换可得到:

xt−1αt−1=xtαt+(1−αt−1αt−1−1−αtαt)ϵθ(xt,t)

当T 够大时,以上公式其实可以看成用欧拉法来求解一个常微分方程(ODE,ordinary differential equation):

xt−Δtαt−Δt=xtαt+(1−αt−Δtαt−Δt−1−αtαt)ϵθ(xt,t)

这里令σ=1−α/α,x¯=x/α,它们都是关于t 函数,这样对应的ODE就是:

dx¯(t)=ϵθ(x¯(t)σ2+1,t)dσ(t)

看成ODE后,我们可以利用如下公式对生成过程进行逆操作:

xt+1αt+1=xtαt+(1−αt+1αt+1−1−αtαt)ϵθ(xt,t)

这意味着,我们可以由一个原始图像x0 到对应的随机噪音xT,然后我们再用xT 行生成就可以重建原始图像x0(具体的代码实现见https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py#L524-L560 (opens new window))。论文也通过在CIFAR10测试集上的实验来证明当步数足够时,这种方式可以得到较低的重建误差:


第二个插值,对于DDIM,两个不同的随机噪音会产生不同的图像,但是如果我们对这两个随机噪音进行插值生成新的xT,那么将生成融合的图像。这里采用的插值方法是球面线性插值( spherical linear interpolation):

xT(α)=sin⁡((1−α)θ)sin⁡(θ)xT(0)+sin⁡(αθ)sin⁡(θ)xT(1)θ=arccos⁡((xT(0))TxT(1)∥xT(0)∥∥xT(1)∥)

这里的参数α∈[0,1] 制插值系数,具体的代码实现见https://github.com/ermongroup/ddim/blob/main/runners/diffusion.py#L296-L334 (opens new window)。下图展示了一些具体的插值效果:


DDIM的重建和插值也在文本转图像模型DALLE-2中使用,不过这里插值的是扩散模型的条件CLIP image embedding,详情见论文**https://arxiv.org/abs/2204.06125 (opens new window)**。



小结

如果从直观上看,DDIM的加速方式非常简单,直接采样一个子序列,其实论文https://arxiv.org/abs/2102.09672 (opens new window)也采用了类似的方式来加速。另外DDIM和其它扩散模型的一个较大的区别是其生成过程是确定性的。

参考

  • https://arxiv.org/abs/2010.02502 (opens new window)
  • https://github.com/ermongroup/ddim (opens new window)
  • https://github.com/openai/improved-diffusion (opens new window)
  • https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_ddim.py (opens new window)
  • https://github.com/CompVis/latent-diffusion/blob/main/ldm/models/diffusion/ddim.py (opens new window)
  • https://kexue.fm/archives/9181 (opens new window)
上次更新: 2025/06/25, 11:25:50
VAE
StableDiffusion

← VAE StableDiffusion→

最近更新
01
帮助信息查看
06-08
02
常用命令
06-08
03
学习资源
06-07
更多文章>
Theme by Vdoing | Copyright © 2022-2025 Geeks_Z | MIT License
京公网安备 11010802040735号 | 京ICP备2022029989号-1
  • 跟随系统
  • 浅色模式
  • 深色模式
  • 阅读模式