L2P
Learning to Prompt for Continual Learning (opens new window)
Code (opens new window) | CVPR 2022
大模型译 ⬇️
0. 摘要
在持续学习的背后主流范式是适应模型参数以适应非平稳数据分布,其中灾难性遗忘是核心挑战。典型方法依赖于复现缓冲区或在测试时已知的任务身份来检索学到的知识和解决遗忘问题,而这项工作提出了一种新的持续学习范式,旨在训练一个更简洁的内存系统,在测试时不访问任务身份。我们的方法学会动态提示(L2P)预训练模型,顺序地在不同任务转换下学习任务。在我们的提议框架中,提示是小的可学习参数,它们被维护在内存空间中。目标是优化提示以指导模型预测,并显式管理任务不变和任务特定知识,同时保持模型的可塑性。我们在不同的挑战性持续学习设置下,对流行的图像分类基准进行了全面实验,L2P 在所有基准上始终优于之前的最佳方法。令人惊讶的是,即使没有复现缓冲区,L2P 也取得了与基于复现的方法相媲美的结果,并且可以直接应用于具有挑战性的任务不可知持续学习。源代码可在 https://github.com/google-research/l2p 上找到。
1. 引言
与在独立同分布(i.i.d.)数据上训练的普通监督学习不同,持续学习解决的是训练一个单一模型在非平稳数据分布上的问题,其中不同的分类任务是顺序呈现的。然而,由于模型只在学习周期的各个阶段访问当前数据,因此它容易过度拟合当前可用的数据,并且由于灾难性遗忘,对之前训练过的数据的性能会恶化 [37]。
在持续学习中的大部分工作遵循的范式是通过持续适应整个或部分模型权重来学习,随着数据分布的转移,重点是保留过去的知识 [9,34]。尽管许多类型的方法取得了良好的结果,但仍有一些关键限制需要解决。首先,根据海马体的情节记忆根据互补学习系统(CLS)理论 [23,36],许多最先进的方法 [3,4,8] 依赖于复现缓冲区来重新训练过去的例子的一部分。然而,它们在缓冲区大小较小时会遭受严重的性能恶化 [4],并且在不允许复现缓冲区的情况下变得无效——例如,在数据隐私至关重要的真实世界场景中 [54]。这表明,简单地缓冲过去数据并重新训练模型可能不是检索过去知识的最佳方法。在没有访问复现缓冲区的情况下,另一部分工作 [19,26,45] 通过假设在测试时已知任务身份来绕过遗忘问题,以便它们能够在共享模型上附加任务独立模块进行推理。然而,知道测试时的任务身份限制了实际使用。
先前工作的局限性在持续学习中提出了关键问题 [13,16]:(1)情节记忆的形式能否超越缓冲过去数据到更智能和简洁的情节记忆系统?(2)如何在不知道其任务身份的情况下自动选择任意样本的相关知识组件?
为了回答第一个问题,我们从最近基于提示的学习(提示)中汲取灵感,这是自然语言处理(NLP)领域中的一种新的迁移学习技术。提示技术设计模型文本输入,用包含额外任务特定信息的模板化或可学习的提示标记,以便预训练的语言模型可以处理参数化输入以执行特定于提示的预测 [25,27,53]。直观地说,基于提示的学习将下游任务的学习重新定义为不是直接适应模型权重,而是设计“指导”模型有条件执行任务的提示。提示编码任务特定知识,并且比普通微调更有效地利用预训练的冻结模型 [25,47]。因此,利用提示来学习知识,并进一步存储学到的知识,在持续学习的背景下是很有希望的。
然而,如何将提示应用于直接解决持续学习中的上述第二个问题尚不清楚:一方面,如果我们在持续学习背景下为不同任务训练不同的提示,在测试时仍然需要任务身份以使用适当的任务特定提示进行预测。另一方面,作为迁移学习技术,提示的目标是使冻结的预训练模型在下游单独实现良好的性能,而不是顺序地。因此,如果我们改为对所有任务维护一个单一共享提示,灾难性遗忘的问题可能仍然存在(见第 5.4 节)。
为此,我们提出了一种新的持续学习方法,称为学习提示以进行持续学习(L2P),它与流行的基于复现的方法正交,并且适用于实际持续学习场景,无需已知的任务身份或边界。图 1 给出了我们的方法与典型持续学习方法的对比概述。L2P 利用预训练模型的代表性特征;然而,与在持续学习过程中调整参数不同,L2P 保持预训练模型不变,而是学习一组动态指导模型解决相应任务的提示。具体来说,提示被结构化为一个称为提示池的键值共享内存空间中,我们设计了一个查询机制,根据实例输入特征动态查找任务相关的提示子集。与监督损失联合优化的提示池确保共享提示编码共享知识以进行知识转移,而不共享的提示编码任务特定知识,有助于保持模型的可塑性。我们的显式设计分离了共享和任务特定知识,从而在优化过程中大大减少了任务特定知识之间的干扰,导致没有复现缓冲区的灾难性遗忘。逐实例查询机制消除了知道任务身份或边界的必要性,使得最具有挑战性但研究不足的任务不可知持续学习成为可能。然后,所选提示被添加到输入嵌入(图 2)之前,这隐式地为预训练模型添加了与任务相关的指令,以便模型回忆起进行相应任务最相关的特征。总之,这项工作做出了以下贡献:
我们提出了 L2P,一种基于提示的持续学习新框架,为持续学习提供了一种通过学习提示池内存空间来应对持续学习挑战的新机制,这些提示池作为参数化的“指令”,用于预训练模型顺序学习任务。该方法适用于处理最具挑战性的任务不可知持续学习。
我们在多个持续学习基准上进行了全面实验,包括类别和领域增量,以及任务不可知设置。提出的 L2P 在所有基准上始终优于之前的最佳方法。令人惊讶的是,即使没有复现缓冲区,L2P 仍然取得了与基于复现的方法相媲美的结果,这在复现缓冲区被禁止的真实世界场景中是理想的。
据我们所知,我们是第一个在持续学习领域引入提示概念的人。我们期望我们的方法为解决持续学习前沿挑战提供了不同的视角。
2. 相关工作
在这里我们建立联系并讨论我们的方法与相关工作之间的差异。
2.1. 持续学习
持续学习通常被定义为在顺序任务的非平稳数据上训练机器学习模型。我们定义一系列任务
2.2. 基于提示的学习和基线
基于提示的学习是 NLP 中的新兴技术。与传统的监督微调不同,这类方法设计了特定于任务的提示函数,以指导预训练模型执行相应任务的条件 [29]。最近的一项技术,提示调整(PT)[25],提出了通过学习提示参数来条件冻结的 T5 类语言模型 [47],以执行下游 NLP 任务,这些提示参数被添加到输入标记以指导模型预测。不失一般性,这里我们使用图像模态的变换器基础序列模型 [10,56] 来介绍 PT 的定义。该定义很容易推广到其他模态和基于序列的模型。给定一个 2D 图像
2.3. 提示迁移学习
提示的核心思想是将一个函数应用于修改输入文本,以便语言模型能够获得关于任务的额外信息。然而,提示函数的设计是具有挑战性的,需要依赖启发式方法。最近的工作,包括提示调整(PT)和前缀调整(Prefix Tuning)[27],试图通过在连续空间中应用可学习的提示来解决这个问题,它们在迁移学习中取得了优异的表现。提示以比竞争对手更小的额外参数捕获特定任务的知识,例如 Adapter[43,58] 和 LoRA[18]。提示的核心思想主要是为迁移学习设计的。注意,直接将提示应用于持续学习是非平凡的。我们提出的新框架揭示了其对持续学习问题的价值。
3. 预备知识
3.1. 持续学习协议
持续学习通常被定义为在顺序任务的非平稳数据上训练机器学习模型。我们定义一系列任务
3.2. 基于提示的学习和基线
基于提示的学习是 NLP 中的新兴技术。与传统的监督微调不同,这类方法设计了特定于任务的提示函数,以指导预训练模型执行相应任务的条件 [29]。最近的一项技术,提示调整(PT)[25],提出了通过学习提示参数来条件冻结的 T5 类语言模型 [47],以执行下游 NLP 任务,这些提示参数被添加到输入标记以指导模型预测。不失一般性,这里我们使用图像模态的变换器基础序列模型 [10,56] 来介绍 PT 的定义。该定义很容易推广到其他模态和基于序列的模型。给定一个 2D 图像
4. 学习提示(L2P)
4.1. 从提示到提示池
引入提示池的动机有三个。首先,测试时不知道任务身份,因此训练任务独立的提示是不可行的。其次,即使可以在测试时知道与任务无关的提示,它阻止了类似任务之间的知识共享 [16]。第三,虽然学习一个单一共享提示用于所有任务的简单方法可以共享知识,但它仍然会导致严重的遗忘问题(见第 5.4 节)。理想情况下,我们希望能够学习一个模型,当任务相似时能够共享知识,同时保持知识的独立性。因此,我们提出使用提示池来存储编码的知识,它可以被灵活地组合作为输入提供给模型。提示池定义为:
其中
其中;表示沿标记长度维度的连接。提示可以自由组合,因此它们可以共同编码知识(例如视觉特征或任务信息),供模型处理。理想情况下,我们希望通过提示组合实现更细粒度的知识共享方案:类似的输入倾向于共享更多的共同提示,反之亦然。
4.2. 实例级提示查询
我们设计了一个基于键值对的查询策略,为不同输入动态选择适当的提示(见图 2)。这种基于键值的记忆查询机制与其他领域的一些方法有一些设计原则相同,例如可微神经计算机 [14] 和 VQ-VAE[41],它们具有外部记忆以维护,并将其用于不同的目的。我们将每个提示作为值与一个可学习的键关联起来:{(k_1, P_1), (k_2, P_2), \cdots, (k_M, P_M)},其中
其中
其中
4.3. L2P 的优化目标
在每个训练步骤中,根据上述查询策略选择
其中
5. 实验
为了评估提出的 L2P,我们紧密跟随之前工作 [32,55,66] 的设置,并进行了全面的实验。特别是,我们主要考虑(1)类别增量设置,其中在推理期间任务身份是未知的;(2)领域增量设置,其中输入领域随时间变化;(3)任务不可知设置,其中没有明确任务边界。我们仔细比较了不同类别的 L2P 与最先进的(SOTA)方法在适当的实验设置下的表现。此外,我们进行了广泛的消融研究,以更深入地理解我们的方法。
5.1. 比较方法
我们比较了 L2P 与几个基线和最先进的(SOTA)持续学习方法。我们的方法基于预训练的 ViT-B/16[11,67],这已成为先进视觉社区的共同资产。我们仔细选择了在相同环境下进行比较的方法,以进行公平比较。许多最近的方法声称在最简单的任务增量设置中实现了 SOTA 性能,其中在测试时知道任务身份 [19,45,57]。我们不包括这些方法,因为它们不适用于更一般的类别增量设置。我们参考了多篇最近的综述论文 [9,34] 和最近的工作 [3,4,46],选择了最被认可和表现最好的方法。为了完整性,我们还包括了简单的顺序训练方法和代表性的基于正则化的方法。此外,我们参考了原始代码库以实现和超参数选择,以确保最佳可能的性能。
基线方法。上界是通常的监督微调方法,对所有任务的 i.i.d.数据进行微调,通常被认为是方法可以达到的上界性能。
FT-seq-frozen 是预训练模型冻结的简单顺序微调方法。FT-seq 代替微调预训练模型权重。EWC[21] 和 LwF[28] 是广泛比较的代表性基于正则化的方法。
SOTA 基于复现的方法。我们选择了 5 个先进的基于复现的方法进行比较,包括 ER[8,17]、GDumb[46]、BiC[61]、DER++[3] 和 Co2L[4]。ER 和 GDumb 在概念上很简单,但它们不仅在自己的工作中,而且在后来的文献 [3,34] 中也取得了非常强的性能。DER++ 和 Co2L 是最新的 SOTA 方法。
SOTA 基于架构的方法。我们选择了两个代表性的基于架构的方法进行比较。SupSup[60] 和 DualNet[44] 都基于 ResNet18。
5.2. 数据集和实验细节
数据集。我们使用 Split CIFAR-100[22] 和 5datasets[12] 进行类别增量设置,CORE50[30] 进行领域增量设置,以及 Gaussian scheduled CIFAR-100[52] 进行任务不可知设置,以评估我们方法的有效性。数据集的详细信息在附录 C 中介绍。
评估指标。对于具有任务边界和每个任务都有相关测试集的设置,我们使用两个指标,平均准确率(越高越好)和遗忘(越低越好),这些在以前的工作中广泛使用 [7,32,34]。对于没有任务边界或只有一个测试集可用的设置,我们遵循常见协议报告最终测试准确率 [30,52]。
训练细节。对于 L2P,我们使用 Adam[20] 训练所有模型,其中
5.3. 主要结果
类别增量学习结果。表 1 总结了这两个类别增量基准的结果。L2P 在不同配置下一致地超越了所有比较的方法,无论是在平均准确率还是遗忘方面。我们观察到,当缓冲区大小相对较大时,L2P 不仅超越了所有其他方法,而且还显著缩小了与 i.i.d.设置下上界性能的差距。当缓冲区大小变小时,L2P 以更大的优势超越了其他方法。最后,当没有缓冲区时,基于复现的方法不再适用,而 L2P 仍然通过击败基于正则化的方法并超越几乎所有小缓冲区的基于复现的方法而保持优越性能。
表 2 显示了 L2P 与基于架构的方法在 Split CIFAR-100 上的比较。我们不是用平均准确率的绝对性能来衡量每种方法的性能,而是用与上界(Diff)的差异来衡量给定特定架构的每种方法的性能。我们观察到,L2P 无论是有(DualNet)还是没有(SupSup)复现缓冲区,都以较大的优势超越了这两种方法。
L2P 在所有竞争方法中的卓越性能表明,我们提出的提示池成功地积累了经验知识,因此它能够总体上提高学习性能,即使没有复现缓冲区,也能减轻灾难性遗忘。
领域增量学习结果。表 3 总结了领域增量设置的结果。L2P 与其他方法相比,保持了最佳性能。有趣的是,所有基于复现的比较方法表现相当接近(除了 GDumb)。基线方法与上界结果之间的性能差距相对较小的观察结果也在 [30] 中报告,因此我们的方法与其他方法之间确实存在显著的性能差距。
任务不可知学习结果。尽管任务不可知设置通常被认为更具挑战性 [52],但这个话题的研究还不足。
我们对任务不可知设置进行了更多的探索性研究。表 4 总结了在具有挑战性的任务不可知学习设置上的结果。我们不与 LwF、BiC 和 Co2L 进行比较,因为它们需要任务边界来保存模型快照并计算蒸馏损失。扩展它们到这个设置的范围超出了我们的能力。我们还使用了 [5] 提出的在线版本的 EWC 来应对任务不可知设置。由于所有比较的方法都基于预训练模型,绝对数字与上界并不太远。可以看出,基于复现的方法有明显的优势。然而,即使缓冲区大小为零,L2P 仍然在所有方法中取得了最佳性能,包括那些有复现缓冲区的方法。我们认为,任务的更平滑过渡隐式地帮助 L2P 将知识巩固到提示中。由于我们有更好的提示,复现缓冲区的好处自然被削弱了。
5.4. 核心设计的有效性
L2P 的提示相关组件的有效性。表 5(第 1 行)去除了提示池设计,并使用单一提示顺序训练。性能显著下降,表明单一提示遭受了严重的灾难性遗忘和任务间知识干扰,而我们的提示池设计很好地编码了任务不变和任务特定知识。表 5(第 2 行)去除了与提示相关的可学习键,并直接使用提示的平均值作为键。结果表明,可学习的键在解耦查询和提示学习过程中发挥了重要作用。表 5(第 3 行)去除了多样化提示选择(仅在 5-datasets 实验中使用)。基本上,去除它允许不同任务的实例自由选择提示。性能下降表明,当任务多样化时,增加这种策略确实减少了不必要的知识共享,从而减轻了不相关任务之间的干扰。为了更好地理解提示选择机制,我们在图 3 中为 Split CIFAR-100 和 5-datasets 绘制了每个任务的最佳参数设置下的提示选择直方图。从 Split CIFAR-100 的图表(左)中,任务在很大程度上共享所有提示,这意味着我们的提示选择机制鼓励在类似任务之间共享更多知识。相比之下,在 5-datasets 的图表(右)中,多样化的任务需要更多的任务特定提示,共享更少。L2P 的超参数的有效性。回想一下,有三个关键超参数,包括提示池的大小
6. 结论
本文提出了一种新方法来解决持续学习中的一些关键挑战,该方法能够在不需要复现和任务身份的情况下实现强大的性能。L2P 将基于提示的学习引入持续学习,并提出了一种新技术,使单一预训练模型能够通过共享提示池适应顺序任务,成功地减轻了灾难性遗忘问题。结果表明,该方法在多个持续学习问题上显著优于以前的 SOTA,包括类别增量和领域增量。我们展示了我们的方法足够通用,能够处理更具挑战性的任务不可知设置,以前的方法无法应对。
A. 潜在的负面社会影响
L2P 是一种强大的持续学习方法,有很大的潜力被应用在各个领域。然而,它也有可能被滥用的方式。我们的方法采用一个预训练的模型作为主干,因此原始模型中的任何偏见和公平性问题 [38] 可能会在持续学习过程中被继承。我们鼓励任何用户彻底检查预训练模型,以减轻任何偏见和公平性问题。此外,该方法可能被部署在安全关键的应用中,例如自动驾驶系统 [15],这可能在对抗性攻击 [33] 方面带来潜在的安全问题。我们建议在未来的工作中测试我们方法的鲁棒性,并设计相应的防御技术来应对潜在的安全问题。
B. 局限性
尽管我们的方法在视觉模型上进行了演示,但它并没有对模态做出任何假设。我们将探索其他模态的工作留作未来的研究。此外,L2P 假设存在预训练的基于序列的模型。虽然它们已经成为先进社区中的常见资产和未来方向,但如何将我们的框架推广到其他视觉架构(例如 ConvNet)可能是一个吸引人的研究方向。实现能够满足现实世界要求的持续学习是一个重要的研究方向,这仍然是一个挑战。例如,任务不可知设置被认为是最具挑战性的设置,非常接近现实世界场景。尽管我们的方法朝着这个目标迈进了一步,但是目前常用的高斯调度 CIFAR-100 是合成的,仍然远离现实。因此,我们认为也需要更复杂的基准来评估任务不可知持续学习方法的能力,并推动这一现实世界挑战的进步。
C. 数据集详情和许可信息
Split CIFAR-100(类别增量)。这个数据集将原始的 CIFAR-100[22] 分成 10 个任务,每个任务有 10 个不相交的类别。由于任务来自同一个原始数据集,它们之间存在一些相似性,一些类别可能来自同一个超类别。尽管 CIFAR-100 是一个简单的图像分类数据集,它对持续学习研究来说仍然相当具有挑战性,特别是在类别增量设置中 [34]。
5-datasets(类别增量)。我们还使用了 [12] 中提出的一个具有挑战性的数据集。这个数据集由五个图像分类数据集组成:CIFAR-10、MNIST[24]、Fashion-MNIST[62]、SVHN[40] 和 notMNIST[2]。尽管每个数据集本身并不难,但它们的顺序训练即使使用 ImageNet 预训练的模型也是相当具有挑战性的,因为模型在任务多样化时容易遗忘 [39]。
CORE50(领域增量)。这是一个特别为持续目标识别设计的广泛使用的数据集 [30]。它收集了 50 个对象,分布在 11 个不同的领域中,其中 8 个领域(120,000 个样本)用于训练,其余的被视为单个测试集(45,000 个样本)。方法按领域顺序训练。
高斯调度 CIFAR-100(任务不可知)。数据的分布在整个学习过程中逐渐变化 [52],一个类别在批次中出现的概率遵循以间隔为中心的高斯分布。批次之间没有明确的任务边界,因此需要方法能够在不使用任何任务特定信息的情况下隐式适应非平稳数据分布,无论是在训练还是推理期间。
- CIFAR-10 和 CIFAR-100[22],Fashion-MNIST[62] 在 MIT 许可下授权。
- MNIST[24] 在创作共用署名 3.0 许可下授权。
- CORE50[30] 在创作共用署名 4.0 国际许可下授权。
- SVHN[40] 和 notMNIST[2] 的许可信息不可用。
D. 算法细节
为了更好地说明我们提出的方法,我们在算法 1 中展示了训练程序的完整图景。注意,对于预测,我们简单地将损失计算替换为标签预测。可选地,当已知任务边界先验时,我们可以用方程 4 替换顶部 N 个键的查找。
