CODA-Prompt
0. 摘要
计算机视觉模型在学习不断变化的训练数据中的新概念时,会遭受一种称为灾难性遗忘的现象。这种持续学习问题的典型解决方案需要广泛复习以前见过的数据,这增加了内存成本,并且可能违反数据隐私。最近,大规模预训练视觉变换器模型的出现使得提示方法成为数据复述的替代方案。这些方法依赖于关键 - 查询机制来生成提示,并且被发现在免复习的持续学习设置中高度抵抗灾难性遗忘。然而,这些方法的关键机制并没有与任务序列一起端到端训练。我们的实验表明,这导致它们的可塑性降低,因此牺牲了新任务的准确性,并且无法从扩展的参数容量中受益。我们提出学习一组提示组件,这些组件用输入条件权重组装,产生输入条件提示,从而产生一种新颖的基于注意力的端到端关键 - 查询方案。我们的实验表明,我们在建立的基准测试上超过了当前的最佳方法 DualPrompt,平均最终准确率提高了 4.5%。我们还在包含类别增量和域增量任务转换的持续学习基准测试上超越了最先进的水平,准确率提高了 4.4%,这对应于许多实际设置。我们的代码可在 https://github.com/GT-RIPL/CODA-Prompt (opens new window) 找到。
1. 引言
要使计算机视觉模型在现实世界中成功,它必须克服一个脆弱的假设,即它在部署后将遇到的 concepts 将与其在训练期间学习的 concepts 相匹配。真实世界确实是动态的,包含不断出现的对象和类别。
2. 背景和相关工作
2.1. 持续学习
持续学习方法可以分为几类,这些类别在不同的问题设置和约束下都是有用的。一组方法在遇到新任务时扩展模型的架构;这些方法对于模型随任务增长而实际的应用非常有效。另一种方法是在训练新任务时用过去任务的知识来规范模型。这可以通过在权重空间(即,惩罚模型参数的变化)或预测空间(即,惩罚模型预测的变化)中规范模型来完成。在预测空间中规范知识是使用知识蒸馏来完成的,并且已经发现它比基于模型规范的方法在任务标签未给出时对持续学习表现更好。存储数据的复述或从生成模型中取样的复述在存储训练数据或训练/保存生成模型是可能的时候非常有效。不幸的是,对于许多机器学习应用来说,长期存储训练数据将违反数据隐私,并且还会承担大的内存成本。关于生成模型,这个训练过程在持续学习设置中与分类模型相比更加计算和内存密集,此外,可能还会违反数据合法性问题,因为使用生成模型增加了记忆潜在敏感数据的机会。这激发了我们对免复述方法的工作,以减轻灾难性遗忘。免复述持续学习:最近的提议使用深度模型反演来产生复述图像。虽然这些方法与生成建模方法相比表现良好,并且仅从少量存储的图像中复述,但模型反演是一个与高计算成本相关的慢过程,在持续学习设置中,此外这些方法的表现比基于复述的方法差很多。
2.2. 视觉变换器中的持续学习
最近的工作已经证明变换器可以很好地推广到看不见的领域。例如,一个研究变化视觉变换器中注意力头的数量,并得出结论,与基于 CNN 的等价物相比,ViTs 在对抗遗忘方面提供了更多的鲁棒性。另一个显示,从零开始训练的 ViT 的普通对应物更倾向于遗忘。最后,DyTox 学习一个统一的模型,采用参数隔离方法,并动态扩展最后一层处理的标记,以减轻遗忘。对于每个任务,他们使用基于任务注意力的解码器块学习每个头部的新任务特定标记。然而,上述工作要么依赖于示例来对抗遗忘,要么需要从头开始训练一个新的变换器模型,这是一个昂贵的操作,因此与我们使用预训练的 ViTs 进行无示例 CL 的目标不同。持续学习的提示:对于持续学习的提示方法通过学习少量可插入的模型指令(提示)而不是直接修改编码器参数来对抗灾难性遗忘。我们设置中的当前最先进方法 DualPrompt 和 L2P,创建一个提示池,用于插入模型,将输入数据与提示匹配而无需任务 ID,采用局部聚类优化。我们的方法建立在这些方法的基础上,稍后在本文中讨论。我们注意到,最近的 S-Prompts 方法与这些方法类似,但是为域增量学习(即,在协变量分布任务转换下学习相同的类别集)设计的,这与我们论文中的类别增量持续学习设置不同(即,在新任务中学习新的对象类别)。
3. 预备知识
3.1. 持续学习
在我们的持续学习设置中,模型将顺序显示 N 个任务,对应于语义对象类别的不重叠子集。每个类别只出现在一个任务中,目标是随着新对象类别的引入而增量学习对它们的分类,同时保持对以前学习过的类别的性能。为了描述我们的模型,我们将预训练的视觉编码器表示为
3.2. 使用 Prefix-Tuning 的提示
在我们的工作中,我们没有改变以前最先进方法 DualPrompt 的提示技术基础。我们专注于提示的选择和形成(包括使所有提示组件端到端优化),并且结果提示用于视觉变换器编码器与 DualPrompt 相同的方式。这使我们能够公平地比较我们新的贡献。正如在 DualPrompt 中所做的,我们将提示传递给预训练的 ViT 变换器 [13, 63] 的多个多头自注意力(MSA)层,并使用 prefix-tuning 而不是 prompt-tuning,它将提示添加到 MSA 层的键和值之前,而不是添加到输入标记之前。提示发生的层设置为超参数,并且提示参数在层之间是唯一的。我们定义一个提示参数为
其中
结果,我们现在只训练少量参数(提示),而保持其余的变换器编码器未修改。现在剩下的关键问题是如何在连续的方式中选择和更新提示?下一节将讨论以前的工作如何选择和更新提示。
3.3. L2P 和 DualPrompt
L2P 和 DualPrompt 使用基于图像的提示查询从池中选择提示。具体来说,这些方法使用基于键 - 值对的查询策略,从候选提示池中动态选择特定于实例的提示。每个提示
4. 方法
4.1. 提示形成
虽然提示已被证明在减轻灾难性遗忘方面表现非常好,但现有的最先进方法 DualPrompt 缺乏在给定任务中扩展学习能力的能力。具体来说,DualPrompt 为每个新任务学习一个提示——不管任务是简单(例如,学习少量新对象类别)还是复杂(例如,学习大量新对象类别)。有人可能会尝试增加提示的长度,但我们在实验(第 5.3 节)中表明,增加提示长度的回报已经饱和。直观上,我们希望一种学习能力与任务数据的基本复杂性相关的方法是,而不是单一提示。此外,我们希望一种端到端可微分的方法,我们推测这增加了用更高准确性学习新任务的能力。
我们通过引入一个新的轴:一组提示组件来扩展我们的学习能力。而不是从池中选择和选择提示,我们学习一组提示组件,通过加权求和形成分解的提示,该提示传递给相应的 MSA 层。这使我们能够将我们的提示能力扩展到任意深度,并捕获任何任务的丰富底层复杂性,同时保持固定的提示长度。此外,新任务中的提示将固有地重用过去任务的先前获得的知识,而不是从零开始初始化新任务提示。具体来说,我们用加权求和替换可学习的提示参数
其中
4.2. 提示组件加权
而不是提示池,我们有一组提示组件,并希望给定查询
其中
其中
4.3. 扩展与正交性
减轻灾难性遗忘的关键是避免覆盖在先前任务中获得的知识。当我们访问新任务时,我们冻结当前组件并扩展集合,只更新新组件。这在图 2 的底部可视化,其中现有的参数被锁定,只有扩展的参数被优化。具体来说,在任务
其中
4.4. 完整优化
给定任务分类损失
其中
5. 实验
我们在类别增量持续学习设置中使用几个图像数据集对我们的方法进行基准测试。我们实现了不存储训练数据进行复述的基线:Learning without Forgetting (LwF)、Learning to Prompt (L2P) 和 DualPrompt。此外,我们报告了上限性能(即,离线训练)和仅使用新任务训练数据训练的神经网络的性能(我们称之为 FT),并包括了一个改进版本的 FT,它使用与 L2P/DualPrompt 相同的分类器实现(称为 FT++)。我们还与 Experience Replay 进行比较,以提供我们结果的额外上下文。我们使用 PyTorch 实现我们的方法和所有基线,使用在 ImageNet1K 上预训练的 ViT-B/16 主干。我们的贡献包括流行提示基线 L2P 和 DualPrompt 的忠实 PyTorch 实现,这些基线在 JAX 中发布。我们对竞争方法的实现实际上在大多数基准测试中提高了 DualPrompt 的性能,以及由于改进的提示类型,L2P 的性能大幅提升(我们称之为 L2P++)。DualPrompt 在层 1-2 中使用长度 5 的提示(称为通用提示),在层 3-5 中使用长度 20 的提示(称为任务专家提示)。我们为 CODA-P 在与 DualPrompt 相同的层(层 1-5)中插入提示,并使用长度为 8 和 100 个提示组件,这是通过在验证数据上进行超参数扫描选择的,以在性能和参数效率之间获得最佳权衡。因为我们的方法比 DualPrompt 引入了更多的可学习参数,所以我们包括了我们方法的一个变体 CODA-P-S,它使用与 DualPrompt 相同数量的参数进行额外比较。我们表明,即使在这种情况下,我们的方法也优于其他方法,同时仍然保持如果需要的话可以扩展性能的能力。我们在测试数据集上报告结果,但强调所有超参数和设计决策(包括基线)是使用验证数据(训练数据的 20%)制作的。与 DualPrompt 不同,我们的基准测试对任务类别顺序进行了几种不同的洗牌,并报告这些运行的平均值和标准差。我们使用一致的种子(每次试验不同)这样做,以便结果可以直接比较。这是至关重要的,因为由于任务难度和出现顺序的差异,不同洗牌顺序的结果可能会有所不同。我们在附录 A 和 B 中包括了额外的实现细节和结果。评估指标:我们使用(1)平均最终准确率
5.1. CODA-P 在现有基准测试中设定了 SOTA
我们首先在几个建立的基准测试上评估我们的方法和最新技术。表 1 提供了由 200 个对象类别组成的 ImageNet-R 的结果,包括卡通、涂鸦和原始 ImageNet 数据集中的困难示例。这个基准测试很有吸引力,因为训练数据的分布与预训练数据(ImageNet)有显著的距离,因此提供了一个公平和具有挑战性的问题设置。除了原始的 10 任务基准测试,我们还提供了较小数量的大任务(5 任务)和较大数量的小任务(20 任务)的结果。我们首先注意到,我们报告的 DualPrompt 的性能比原始的 DualPrompt 论文高出几个百分点。我们在不同的框架(PyTorch)中从头开始重新实现了该方法,怀疑差异与我们对不同类别顺序洗牌的平均有关。我们注意到,我们的 L2P 实现(L2P++)比原来报告的性能要好得多,因为我们使用了与 DualPrompt 相同的提示形式(为了公平比较)。此外,我们的“Deep L2P++”,它与 DualPrompt 一样在 5 个层中提示(而不是仅在层 1 中),实际上与 DualPrompt 表现相似。重要的是,我们看到我们的方法在所有三个任务长度的平均准确率上有强劲的增长,平均准确率比 DualPrompt 提高了 4.5%。我们提醒读者,CODA-P 是我们提出的方法,具有调整的提示长度和组件集大小,而 CODA-P-S 被缩小(见附录 A 以获得确切细节),以与 DualPrompt 的可学习参数数量完全匹配。我们注意到我们的方法通常有略高的平均水平遗忘与 DualPrompt 相比。鉴于平均准确率是衡量实际性能的关键指标,涵盖了方法的可塑性和遗忘,我们不担心遗忘的轻微上升。事实上,我们认为这是非常合理和反映了我们方法的优势:我们的方法有更大的能力学习新任务,因此我们可以牺牲略高的遗忘。关键的是,我们看到随着任务序列长度的增长,我们的方法与 DualPrompt 的遗忘指标开始收敛到类似的值。我们提供了经验重放的结果,以提供我们结果的额外上下文。虽然对于短任务序列,具有相当大的核心集大小的重放与我们方法之间的差距很小,但我们看到我们的方法在更长的任务序列中强烈优于重放。表 2a 和 2b 提供了在额外的 10 任务 CIFAR-100 和 5 任务 DomainNet 基准测试上的结果。虽然这两个数据集在挑战性和与预训练分布的距离方面都不如 ImageNet-R 有影响力,但这些表格为我们方法的性能提供了额外的上下文。这些表格展示了与 ImageNet-R 基准测试相似的故事,分别提高了 +3.2% 和 +2.5%。我们确实注意到,LwF 在这些基准测试上略微优于提示。
5.2. CODA-P 为新的双转换基准测试设定了 SOTA
我们还在具有挑战性的双转换基准测试上进行了评估,使用了 ImageNet-R 数据集。我们的动机是展示对两种不同类型的持续分布转换的鲁棒性:语义和协变量。我们通过从 ImageNet-R 数据集中随机选择图像类型来包括在每个任务的训练数据中(同时保持评估数据不变)来实现这一点。读者被引用到附录 C 以获取更多详细信息。这个基准测试的结果在表 3 中提供。与我们的方法是 ImageNet-R 的 5 任务基准测试中 SOTA 提高了 4.5% 平均准确率相比,我们看到这个 5 任务基准测试在类似的遗忘范围内提高了 4.4%。我们的结果表明 CODA-P 更好地泛化到现实世界类型的转换。
5.3. 消融研究和额外分析
我们通过消融研究和额外分析更仔细地查看我们的方法。在表 4 中,我们展示了移除我们方法的几个关键组件的影响:注意力键、过去任务组件的冻结和我们的正交性正则化。我们表明,移除注意力略微降低了遗忘,降低了平均准确率(更重要的指标)。这是有道理的,因为移除注意力键使我们的查询处理更接近 L2P/DualPrompt 方法,这些方法拥有低遗忘但缺乏足够的学习能力。我们看到在移除冻结和正交性时出现了更大的下降。这表明这些方面对我们的方法至关重要。我们的直觉是,没有这些,我们的提示形成类似于浅 MLP 模块,如果不加以规范,应该会受到高遗忘的影响。我们还查看了我们的能力,沿着我们新引入的提示组件维度扩展模型容量,使用 5 任务 ImageNet-R 基准测试(与验证数据)。我们在图 3 中展示了,与在 DualPrompt 中学习的提示数量相等的提示组件数量时,我们实现了更高的性能。重要的是,当我们增加组件数量时,我们能够利用扩展的参数实现显著更高的性能,发现我们的方法比 DualPrompt 更接近上限性能而不是平均准确率。因为 L2P 的提示池也可以扩展到任意大小,我们也在这个分析中包括了结果。然而,我们表明 L2P 方法在池大小等于任务序列长度的两倍时达到峰值,然后性能下降。这反映了我们的提示组件从规模中受益,而现有的提示池实际上遭受了损失。最后,我们使用相同的设置展示了平均准确率与提示长度的关系图 4。这个实验的目的是强调准确率随着提示长度的饱和,从而激发了在我们引入的组件维度扩展的需求。附录 B 中有额外的分析和超参数扫描。
6. 结论
我们介绍了 COntinual decomposed attention-based prompting(CODA-Prompt)用于免复述的持续学习。我们的方法组装了可学习的提示组件,然后插入到预训练的 ViT 编码器中进行图像分类。重要的是,CODA-Prompt 是端到端优化的(与涉及两个单独优化的先前 SOTA 方法不同)。此外,CODA-Prompt 可以将提示能力扩展到任意大小。我们在建立的基准测试和包含语义和协变量转换的双分布转换基准测试(突出了我们方法的潜在现实世界影响和普遍性)上设定了新的 SOTA。
致谢
这项材料是基于在国家科学基金会的支持下进行的工作,授权号为 2239292。