V-MoE
Scaling Vision with Sparse Mixture of Experts (opens new window)
Code (opens new window) | NeurIPS 2021
摘要
稀疏门控混合专家网络 (MoESparsely-gated Mixture of Experts networks,MoE)) 这种方法已经在自然语言处理领域中表现出了出色的可扩展性。但是在计算机视觉中,几乎所有性能网络都是 "密集 (Dense) 的",也就是说,每个输入都由所有的参数来处理。
本文就提出了视觉领域经典的稀疏门控混合专家网络 Vision MoE (V-MoE),它是 Vision Transformer 的稀疏版本,V-MoE 是一种可扩展的架构,其性能和最大的密集网络适配。在图像识别任务中,V-MoE 与 SOTA 的网络的性能相匹配,同时在推理时只需要不到一半的 FLOPs。作者还对路由算法进行了扩展,对整个批次中每个输入的子集进行优先级排序,从而对每幅图像实现自适应的计算,使得测试时的计算更加平滑。
背景
提高模型容量的新方法:稀疏门控混合专家模型
在深度学习的实践经验中,增加网络容量和数据集大小通常会提高模型的性能,这种方法在 NLP 领域已经取得了成功,比如各种各样的 Transformer。
但是,训练以及部署这样的模型需要很大的代价,十分昂贵。一方面的原因是这些深度神经网络通常是 "密集 (Dense)" 的,这个含义是:对于任何单个输入,都需要模型的全部参数来处理。因此,扩大模型的规模往往意味着高昂的计算代价。
那么,能不能有一种方法,可以在提高模型尺寸的同时,保持其训练和推理的成本大致恒定?
稀疏门控混合专家网络 (Sparsely-gated Mixture of Experts networks) 这种方法就是为了实现这个目的,其已经在自然语言处理领域中表现出了出色的可扩展性,能够在解锁万亿参数的模型时以更少的资源实现训练和推理。
本文把这种方法应用到了 CV 领域中,引入的模型称之为 V-MoE,是一种 Vision Transformer (ViT) 架构的稀疏变体。V-MoE 用稀疏的 MoE 层取代了 ViT 中密集的 MLP (又称 Expert)。每个输入图片被 "路由" 到不同 Expert 的子集中。但是,MoE 这种技术路线有其独特的劣势:即不可微性 (non-differentiability),使得训练这样的模型本身就具有挑战性。
作者在本文探索了多种设计的范式和思路,并为 V-MoE 的预训练和微调提出了一种行之有效的策略,明显优于其他 Dense 的模型。而且,V-MoE 模型还很灵活,我们可以自由调节已经训练好的模型的稀疏度来自由改变其性能和推理成本的 trade-off。使用了 V-MoE 技术之后,我们可以把一个 ViT 视觉模型扩展到 15B 参数大小,这是迄今为止最大的视觉模型,其性能也和 SOTA 的 Dense 模型相当,同时需要更少的时间进行训练。
V-MoE 的贡献
V-MoE 的贡献可以概括为:
超大规模的视觉模型(Vision models at scale): 稀疏门控混合专家网络是一种用于视觉的分布式稀疏激活 Transformer 模型,包含 24 个 MoE 层,每层包含 32 个 Expert,总参数量 15B。这个模型可以稳定地训练,无缝微调;
性能和推理开销(Performance and inference): V-MoE 在上游、few-shot 和完整微调指标方面大大优于 Dense 的竞争者。此外,在推理时,V-MoE 模型可以调整为:1) 计算量或实际运行时间一半的情况下匹配最大的 Dense 模型性能,或者 2) 成本相当的情况下显著优于最大的 Dense 模型;
基于优先级的 Batch 路由算法 (Batch Prioritized Routing): V-MoE 提出一种新的基于优先级的路由算法,让 V-MoE 丢弃掉最没用的 Patch,从而对每个图像投入较少的计算。效果是 V-MoE 可以节约掉 20% 的计算开销,同时匹配最大的 Dense 模型性能。
V-MoE 的计算方法
一句话概括:不同的 Expert 模型负责不同的输入的部分。
V-MoE 的架构如下图 1 所示。V-MoE 把 ViT 的一部分 Block 里面的 MLP 换成了 Sparse MoE。MLP 的表达式为:
每个 Expert 的函数是:

图 1:V-MoE 的架构:V-MoE 由 L 个 ViT Block 组成。每个 MLP (Expert) 都存储在一个单独的设备上,并处理固定数量的令牌
MoE 的路由函数
式中,
专家的缓冲区容量
专家的缓冲区容量(每个 Expert 处理的 token (即图像 Patch) 数):
如果路由器为给定的 Expert 分配超过
模型细节
V-MoE 一共是来自 5 种尺寸的 ViT 模型:ViT-S(mall), ViT-B(ase), ViT-L(arge) 和 ViTH(uge)
有 3 个主要设计决策会影响模型的成本:
MoE 层的数量: 本文尝试了两种:每一层都使用 MoE,通常命名为 Every xxx,或者只有最后几层使用 MoE,通常命名为 Last xxx。作者发现,尽管使用较少的 MoE 层减少了模型的参数数量,但它对质量的影响通常很小,并且可以显着加快模型,因为会产生更少的通信开销。
Expert 的数量: V-MoE 模型的成本不依赖于 Expert 的总数,而是每个 token 所选择的专家数量
。V-MoE 默认使用 ,而作者发现 Expert 总数 是实验设置中的最佳点。 模型容量 C: 在上游训练期间,默认设置
以给出少量松弛,而不会显着增加成本。
路由算法
原始的 Routing 算法如下图 3 的 Algorithm1 所示,设

图 3:Batch Prioritized Routing 和原始算法的对比
Batch Prioritized Routing(BPR) 算法如上图 3 的 Algorithm2 所示,思路就是优先处理最重要的 token (路由权重高的)。为了优先处理最重要的 token,作者提出对这
Loss function
Importance Loss
重要性损失:激励专家的平衡使用(所有专家具有同等的重要性)
专家
- expert
- a batch of images
is the layer-specific weight matrix for the router
We use the squared coefficient of variation of the importance distribution over experts,
计算每个专家对当前 batch 的重要性,然后计算所有重要性的均值和标准差
Load Loss
负载均衡(load balancing):专家接收大致相等数量的训练样本
where
The probability is defined over
Final Auxiliary Loss
The overall loss is:
We set
V-MoE 训练数据
V-MoE 在私有数据集 JFT-300M 上做预训练,它包含了约 305M 的训练集和 50000 的验证集。类别数 18291 (每张图像平均 1.89 个标签)。
对于 ImageNet 上的 Few-Shot 实验,作者只在每个类别使用 1, 5, 或者 10 shots 来调整上游模型,在验证集上评估得到的模型。
实验结果
V-MoE 的上游任务预训练数据集是 JFT-300M,它是一个多标签数据集,因此作者通过 precision@1 来衡量模型性能。如下图 4(a) 和图 5 所示是不同 V-MoE 和 ViT 变体相对于总训练时间的质量。图 4 的结果是 V-MoE last-n,就是最后的

图 4:(a) V-MoE 上游任务预训练实验结果。(b) V-MoE Few-Shot 实验结果

图 5:V-MoE 实验结果
放大 V-MoE
作为一种放大 ViT 模型的方法,作者自然希望验证下 V-MoE 方法对于超大 ViT 模型的泛化性能。为此,作者增加了模型的大小并使用更大的预训练数据集:JFT-3B 是 JFT-300M 的更大版本,它包含近 3B 图像,并且用 30k 类嘈杂注释。
除此之外,作者再对 Dense 模型放大的方法做了改进之后得到如下的放大 Sparse Mixture of Expert Layers 模型的方法:
- 低精度:使用 bfloat16 而不是 float32 来存储梯度移动平均。
- 学习率衰减:使用 inverse square root schedule 替换 linear schedule。
- 权重衰减:模型中 kernel weights 的权重衰减率为 0.03,bias 不使用,head kernel 的权重衰减为 3.0。
- 模型的 head:不使用 ViT 的 head,其第 1 个 token 被选择和使用,使用基于 Self-Attention 的 head。
作者训练了一个带有 48 个 MoE Block 的 V-MoE 模型,每个 Block 32 Experts,每个 token 被 2 个 Expert 处理,即
作者成功训练了 V-MoE-15B,这是迄今为止最大的视觉模型。它在 5-shot ImageNet 上的准确率为 82.78%,令人印象深刻,在完全微调时的准确率更是达到了惊人的 90.35%。当时 ImageNet 上当前最先进的技术是 Meta Pseudo-Labelling (MPL)。MPL 使用 ImageNet 伪标记在未标记的 JFT-300M 上训练基于 EffecentNet 的模型,达到了 90.2%。