WKL
0. 摘要
自Hinton等人的开创性工作以来,基于Kullback-Leibler散度(KL-Div)的知识蒸馏已成为主流,其变种在近期取得了令人瞩目的性能。然而,KL-Div仅比较教师和学生模型之间对应类别的概率,缺乏跨类别比较的机制。此外,KL-Div在应用于中间层时存在问题,因为它无法处理非重叠分布,并且不了解底层流形的几何结构。为了解决这些问题,我们提出了基于Wasserstein距离(WD)的知识蒸馏方法。具体来说,我们提出了一种基于离散WD的logit蒸馏方法WKD-L,该方法通过跨类别比较概率,从而能够显式利用类别之间的丰富关联。此外,我们引入了一种基于连续WD的特征蒸馏方法WKD-F,该方法使用参数化方法对特征分布进行建模,并采用连续WD从中间层传递知识。在图像分类和目标检测上的综合评估表明:(1)在logit蒸馏方面,WKD-L优于非常强大的KL-Div变种;(2)在特征蒸馏方面,WKD-F优于KL-Div的对应方法和最先进的竞争对手。源代码可在https://peihuali.org/WKD (opens new window)获取。
1. 引言
知识蒸馏(KD)旨在将高性能、大容量的教师模型的知识转移到轻量级的学生模型中。近年来,它在深度学习中引起了越来越多的关注,并在视觉识别和目标检测等领域得到了广泛应用[1]。在Hinton等人的开创性工作中,他们引入了Kullback-Leibler散度(KL-Div)来进行知识蒸馏,其中学生模型的类别概率预测被约束为与教师模型相似[2]。自那时起,KL-Div在logit蒸馏中占据了主导地位,并且其变种[3;4;5]在近期取得了令人瞩目的性能。此外,这些logit蒸馏方法与许多从中间层传递知识的最先进方法[6;7;8]是互补的。
尽管取得了巨大成功,KL-Div存在两个缺点,阻碍了教师模型知识的完全传递。首先,KL-Div仅比较教师和学生模型之间对应类别的概率,缺乏跨类别比较的机制。然而,现实世界中的类别在视觉上存在不同程度的相似性,例如,哺乳动物如狗和狼在视觉上更为相似,而与汽车和自行车等人工制品则截然不同。深度神经网络(DNNs)能够区分数千个类别[9],并且对这些复杂关系有很好的理解,如图1a所示。不幸的是,由于其类别到类别的性质,经典的KD[2]及其变种[3;4;5]无法显式利用这种丰富的跨类别知识。
其次,KL-Div在从中间层蒸馏知识时存在问题。图像的高维特征通常尺寸较小,因此在特征空间中分布非常稀疏[10, 第2章]。这不仅使得KL-Div所需的非参数密度估计(例如直方图)由于维度灾难而不可行,还导致了KL-Div无法处理的非重叠离散分布[11]。人们可能会转向参数化、连续的方法(例如高斯分布)来建模特征分布。然而,KL-Div及其变种在测量连续分布之间的差异时能力有限,因为它不是一种度量[12],并且不了解底层流形的几何结构[13]。
Wasserstein距离(WD)[14],也称为Earth Mover's Distance(EMD)或最优传输,有潜力解决KL-Div的局限性。两个概率分布之间的WD通常定义为将一个分布转换为另一个分布的最小成本。一些工作已经探索了使用WD从中间层进行知识传递[15;16]。具体来说,他们基于离散WD测量教师和学生模型之间一批图像的差异,这种方式关注的是跨实例的软比较,未能利用跨类别的关系。此外,他们主要追求非参数化的分布建模方法,性能落后于最先进的基于KL-Div的方法。
为了解决这些问题,我们提出了一种基于Wasserstein距离的知识蒸馏方法,称为WKD。该方法适用于logits(WKD-L)以及中间层(WKD-F),如图1b所示。在WKD-L中,我们使用离散WD来最小化教师和学生模型预测概率之间的差异,从而实现跨类别比较,有效利用类别之间的相互关系,与经典KL-Div的类别到类别比较形成了鲜明对比。我们提出使用Centered Kernel Alignment(CKA)[17;18]来量化类别之间的相互关系,它测量任意两个类别之间特征的相似性。
对于WKD-F,我们将WD引入中间层以从特征中提取知识。与logits不同,中间层不涉及类别概率。因此,我们让学生模型直接匹配教师模型的特征分布。由于DNN特征的维度较高,非参数化方法(例如直方图)由于维度灾难而不可行[10, 第2章],我们选择参数化方法进行分布建模。具体来说,我们使用最广泛使用的连续分布之一(即高斯分布),它在给定从特征估计的一阶和二阶矩时具有最大熵[19, 第1章]。高斯分布之间的WD可以以闭式形式计算,并且是底层流形上的黎曼度量[20]。
我们在以下方面总结了我们的贡献:
- 我们提出了一种基于离散WD的logit蒸馏方法(WKD-L)。它可以通过跨类别比较教师和学生模型预测的概率,利用类别之间的丰富关系,克服了类别到类别KL散度的缺点。
- 我们将连续WD引入中间层进行特征蒸馏(WKD-F)。它可以有效利用高斯分布的黎曼空间几何结构,优于不了解几何结构的KL散度。
- 在图像分类和目标检测任务中,WKD-L的表现优于非常强大的基于KL-Div的logit蒸馏方法,而WKD-F在特征蒸馏方面优于KL-Div的对应方法和竞争对手。它们的结合进一步提高了性能。
2. 用于知识传递的Wasserstein距离
给定一个预训练的高性能教师模型T,我们的任务是训练一个轻量级的学生模型S,使其能够从教师模型中蒸馏知识。因此,学生的监督来自具有交叉熵损失的真实标签和来自教师模型的蒸馏损失,这将在接下来的两节中描述。
2.1 离散 Wasserstein 距离用于 Logit 蒸馏
类别之间的相互关系(IRs)。如图1a和图4所示,现实世界中的类别在特征空间中表现出复杂的拓扑关系。例如,哺乳动物物种彼此更接近,而远离人工制品或食物。此外,同一类别的特征聚集在一起形成分布,而相邻类别的特征重叠且无法完全分离。因此,我们提出基于CKA[18]来量化类别之间的相互关系,CKA是一种归一化的Hilbert-Schmidt独立准则(HSIC),它通过将两组特征映射到再生核希尔伯特空间(RKHS)来建模它们的统计关系[21]。
给定类别
其中
损失函数。给定一个输入图像(实例),我们令
我们分别用
KL 散度(2)仅比较教师和学生模型之间对应类别的预测概率,本质上缺乏进行跨类别比较的机制,如图 2 所示。尽管在梯度反向传播过程中,由于 softmax 函数的存在,一个类别的概率会影响其他类别的概率,但这种隐式影响是微弱的,最重要的是,它无法显式利用丰富的成对相互关系知识,如公式(1)所述。
与 KL 散度不同,WD 进行跨类别比较,因此自然利用了类别之间的相互关系,如图 1b(左)所示。我们将离散 WD 公式化为一个熵正则化的线性规划 [23]:
约束条件为
其中
最近的工作 [3] 揭示了目标概率(即目标类别的概率)和非目标概率在训练中扮演不同的角色:前者关注训练样本的难度,而后者包含显著的“暗知识”。研究表明,这种分离有助于平衡它们的角色,并大大提高了经典 KD 的性能 [3;4]。受此启发,我们也考虑了类似的分离策略。令
其中
2.2 连续 Wasserstein 距离用于特征蒸馏
由于 DNN 中间层输出的特征通常是高维且尺寸较小,因此非参数化方法(例如直方图和核密度估计)不可行。因此,我们使用最广泛使用的参数化方法之一(即高斯分布)进行分布建模。
2.2.1 特征分布建模
给定一个输入图像,我们考虑 DNN 模型某个中间层输出的特征图,其空间高度、宽度和通道数分别为
我们通过一个具有均值向量
其中
2.2.2 损失函数
令高斯分布
其中
其中
KL 散度 [30] 和对称 KL 散度(即 Jeffreys 散度)[31] 在高斯分布的情况下都有闭式表达式 [32],可以用于知识传递。然而,它们不是度量 [12],并且不了解高斯分布空间的几何结构 [13],而该空间是一个黎曼空间。相反,Wasserstein 距离是黎曼度量,能够测量内在距离 [20]。值得注意的是,G²DeNet[33] 提出了一种基于李群的高斯分布之间的度量,可以用于定义蒸馏损失。除了高斯分布外,还可以使用拉普拉斯分布和指数分布对特征分布进行建模。最后,尽管直方图或核密度估计是不可行的,但仍然可以用概率质量函数 ( PMF ) 来建模特征分布,并相应地使用离散 WD 来定义蒸馏损失。
3. 相关工作
我们总结了与我们的方法相关的 KD 方法,并在表 1 中展示了它们的联系和区别。
3.1 基于 KL 散度的知识蒸馏
Zhao 等人 [3] 揭示了经典 KD 损失 [2] 是一种耦合形式,限制了其性能,从而提出了一种解耦形式(DKD),该形式由目标类别的二元 logit 损失和所有非目标类别的多类 logit 损失组成。Yang 等人 [4] 提出了一种归一化 KD(NKD)方法,该方法将经典 KD 损失分解为目标损失(类似于广泛使用的交叉熵损失)和归一化非目标预测损失。WTTM[5] 引入了 Rényi 熵正则化,无需对学生模型进行温度缩放。尽管性能竞争激烈,但它们无法显式利用类别之间的关系。相比之下,我们基于 Wasserstein 距离的方法能够进行跨类别比较,从而利用丰富的类别相互关系。
3.2 基于 Wasserstein 距离的知识蒸馏
现有的基于 WD 的 KD 方法 [15;16] 主要关注特征蒸馏的跨实例匹配,如图 3(左)所示。Chen 等人 [15] 提出了 Wasserstein 对比表示蒸馏(WCoRD)框架,该框架包括全局和局部对比损失。前者最小化教师和学生模型分布之间的互信息(通过 WD 的对偶形式);后者仅对倒数第二层的特征进行匹配,最小化一批图像的特征之间的 Wasserstein 距离。Lohit 等人 [16] 独立提出了一种类似的跨实例匹配方法,称为 EMD+IPOT,该方法从所有中间层传递知识,并通过不精确的近似最优传输算法 [34] 计算离散 WD。我们的工作与它们的区别在于:(1)它们未能利用类别之间的相互关系,而我们的 WKD-L 可以充分利用这些关系;(2)它们关注基于离散 WD 的跨实例匹配,而我们的 WKD-F 涉及基于连续 WD 的逐实例匹配。
3.3 其他基于统计建模的方法
NST[35] 是最早将特征蒸馏形式化为分布匹配问题的方法之一,其中学生模型模仿教师模型的分布,基于最大均值差异(MMD)。他们表明,在 MMD 的候选核中,二阶多项式核表现最佳,并且基于激活的注意力转移(AT)[36] 是 NST 的特例。Yang 等人 [37] 提出了一种新颖的损失函数,该方法将学生模型学习的统计信息通过自适应实例归一化传递回教师模型。Liu 等人 [6] 提出了通道间相关性(ICKD-C)来建模特征的多样性和同源性,以实现更好的知识传递。NST 和 ICKD-C 都可以视为基于 Frobenius 范数的分布建模,分别沿空间维度和通道维度的二阶矩,如图 3(右)所示。然而,它们未能利用二阶矩矩阵的几何结构,这些矩阵是对称正定(SPD)的,并形成一个黎曼空间 [38;39]。Ahn 等人 [40] 引入了基于互信息的变分信息蒸馏(VID)。VID 假设特征分布是高斯的,并且如果进一步假设高斯分布具有单位方差,则其损失简化为均方损失(即 FitNet[24])。
4. 实验
我们在 ImageNet[41] 和 CIFAR-100[42] 上评估了 WKD 的图像分类性能。此外,我们还在自知识蒸馏(Self-KD)中评估了 WKD 的有效性。进一步地,我们将 WKD 扩展到目标检测任务,并在 MS-COCO[43] 上进行了实验。我们使用 PyTorch 框架 [44] 进行模型的训练和测试,实验设备为 Intel Core i9-13900K CPU 和 GeForce RTX 4090 GPU。
4.1 实验设置
4.1.1 图像分类
ImageNet:包含 1,000 个类别,训练集有 1.28M 张图像,验证集有 50K 张图像,测试集有 100K 张图像。我们按照 [25] 的设置,使用 SGD 优化器训练模型 100 个 epoch,批量大小为 256,动量为 0.9,权重衰减为 1e-4。初始学习率为 0.1,在第 30、60 和 90 个 epoch 时分别除以 10。我们使用随机裁剪和随机水平翻转进行数据增强。对于 WKD-L,我们使用 POT 库 [45] 求解离散 Wasserstein 距离,参数
,迭代次数为 9。对于 WKD-F,投影器采用瓶颈结构,即一个 1×1 卷积(Conv)和一个 3×3 卷积,均具有 256 个过滤器,最后是一个 1×1 卷积,带有 BN 和 ReLU,以匹配教师模型特征图的尺寸。 CIFAR-100:包含 60K 张 32×32 像素的图像,来自 100 个类别,训练集有 50K 张图像,测试集有 10K 张图像。我们按照 OFA[46] 的设置,在卷积神经网络(CNN)和视觉 Transformer 架构中进行实验。所有模型训练 300 个 epoch,批量大小为 512,采用余弦退火调度。对于基于 CNN 的学生模型,我们使用 SGD 优化器,初始学习率为 2.5e-2,权重衰减为 2e-3。对于基于 Transformer 的学生模型,我们使用 AdamW 优化器,初始学习率为 2.5e-4,权重衰减为 2e-3。
4.1.2 目标检测
- MS-COCO:是一个常用的目标检测基准,包含 80 个类别。按照常规做法,我们使用 COCO 2017 的标准划分,训练集有 118K 张图像,验证集有 5K 张图像。我们采用 Faster-RCNN[47] 框架,并在 Detectron2 平台 [49] 上使用特征金字塔网络(FPN)[48]。与之前的艺术 [50;29;51] 一样,我们使用官方训练并发布的检测模型作为教师模型,而学生模型的骨干网络则使用在 ImageNet 上预训练的权重进行初始化。学生网络训练 180K 次迭代,批量大小为 8;初始学习率为 0.01,在 120K 和 160K 次迭代时分别衰减 0.1 倍。
4.2 WKD 关键组件的分析
我们在 ImageNet 上分析了 WKD-L 和 WKD-F 的关键组件。我们采用 ResNet34 作为教师模型,ResNet18 作为学生模型(即设置 (a)),它们的 Top-1 准确率分别为 73.31% 和 69.75%。详见第 C.1 节对超参数(例如温度和权重)的分析。
4.2.1 WKD-L 的消融实验
- WD 与 KL-Div 的比较:我们比较了 WD 和 KL-Div 在有(w/)和没有(w/o)目标概率分离的情况下的性能,如表 2a 所示。在没有分离的情况下,WD(w/o)比 KL-Div(w/o)提高了 1.0%;在有分离的情况下,WD(w/)显著优于基于 DKD 和 NKD 的 KL-Div(w/)。上述比较清楚地表明:(1)WD 在这两种情况下均优于 KL-Div;(2)分离策略对 WD 也至关重要。因此,我们在全文中使用目标概率分离的 WD。

- 类别相互关系的建模方法:表 2b 比较了两种类别相互关系的建模方法,即 CKA 和余弦相似度。对于前者,我们评估了不同的核;对于后者,我们评估了使用分类器权重或类别质心作为原型的性能。我们注意到,所有基于 WD 的方法均显著优于 KL-Div 基线。总体而言,基于 CKA 的 IR 性能优于基于余弦相似度的 IR,表明其能够更好地表示类别之间的相似性。对于基于 CKA 的 IR,RBF 核优于多项式核,而线性核表现最佳,因此我们在全文中使用线性核。
4.2.2 WKD-F 的消融实验

完整协方差矩阵与对角协方差矩阵的比较:如表 3a 所示,对于完整协方差矩阵的高斯分布(Full),WD 优于 G²DeNet[33],表明前者更适合特征蒸馏。当使用 WD 时,对角协方差矩阵的高斯分布(Diag)比完整协方差矩阵的高斯分布(Full)具有更高的准确率。我们推测原因在于高维特征空间中完整协方差矩阵的估计不够稳健 [52];相比之下,对于对角协方差矩阵的高斯分布(Diag),我们只需要估计单维数据的 1D 方差。此外,对角协方差矩阵的高斯分布(Diag)比完整协方差矩阵的高斯分布(Full)更高效。因此,我们在全文中使用对角协方差矩阵的高斯分布(Diag)。
分布建模方法的选择:在表 3a 中,我们比较了不同的参数化方法进行知识蒸馏,包括高斯分布、拉普拉斯分布、指数分布,以及单独的一阶矩和二阶矩。我们注意到,基于高斯分布(Diag)的 KL 散度和对称 KL 散度表现相似,但均低于 WD。原因可能是 KL 相关的散度不是内在距离,无法利用高斯分布流形的几何结构。对于统计矩,我们注意到通道矩优于空间矩。对于通道表示,一阶矩优于二阶矩,表明均值在特征分布中起更重要的作用。最后,基于 PMF 的非参数化方法表现不如基于高斯分布的参数化方法。
逐实例匹配与跨实例匹配的比较:我们的 WKD-F 是一种基于连续 WD 的逐实例匹配方法,而 WCoRD 和 EMD+IPOT 关注基于离散 WD 的跨实例匹配。如表 3b 所示,WCoRD 的准确率显著高于 EMD+IPOT,这可能归因于其额外的基于互信息的全局对比损失;WKD-F 显著优于这两种方法,表明我们的策略具有优势。值得注意的是,WKD-F 的运行速度显著快于依赖于优化算法求解离散 WD 的两种方法。
蒸馏位置和网格方案的影响:我们在表 3c 中评估了进行分布匹配的位置和不同网格方案的影响。从第 3 行和第 4 行可以看出,Conv_5x 阶段的性能显著优于 Conv_4x 阶段,表明高层特征更适合知识传递。比较第 4 行和第 5 行,我们发现 2×2 网格并未优于 1×1 网格。最后,结合 Conv_4x 和 Conv_5x 的特征并未带来进一步的性能提升。因此,我们在 ImageNet 分类任务中使用 Conv_5x 的特征和 1×1 网格。
4.3 ImageNet 上的图像分类实验
我们在 ImageNet 上进行了图像分类实验,并比较了两种设置下的性能。设置(a)涉及同构架构,其中教师和学生网络分别为 ResNet34 和 ResNet18[9];设置(b)涉及异构架构,其中教师为 ResNet50,学生为 MobileNetV1[57]。详细的超参数设置见第 C.2 节。

4.3.1 Logit 蒸馏的比较
在 logit 蒸馏方面,我们将 WKD-L 与 KD[2]、DKD[3]、NKD[4]、CTKD[54] 和 WTTM[5] 进行了比较。我们的 WKD-L 在两种设置下均优于经典的 KD 及其所有变种。特别是在设置(a)中,WKD-L 的 Top-1 准确率为 72.49%,比 KD 提高了 1.46%;在设置(b)中,WKD-L 的 Top-1 准确率为 73.17%,比 KD 提高了 2.67%。值得注意的是,WKD-L 显著优于 WTTM,后者是一种非常强大的 KD 变种,引入了样本自适应的加权方法。这表明,基于 Wasserstein 距离的跨类别比较优于基于 KL 散度的类别到类别比较。
4.3.2 特征蒸馏的比较
在特征蒸馏方面,我们将 WKD-F 与 FitNet[24]、CRD[25]、ReviewKD[29] 和 CAT[55] 进行了比较。我们的 WKD-F 在设置(a)中显著优于 ReviewKD,Top-1 准确率提高了 0.89%;在设置(b)中,WKD-F 的 Top-1 准确率为 73.12%,比 ReviewKD 提高了 0.41%。这表明,在特征蒸馏中,匹配高斯分布比直接匹配特征更为有效。
4.3.3 Logit 与特征蒸馏的结合
我们将 WKD-L 和 WKD-F 结合使用,进一步提升了性能。在设置(a)中,WKD-L+WKD-F 的 Top-1 准确率为 72.76%,比单独的 WKD-L 和 WKD-F 分别提高了 0.27% 和 0.26%;在设置(b)中,WKD-L+WKD-F 的 Top-1 准确率为 73.69%,比单独的 WKD-L 和 WKD-F 分别提高了 0.52% 和 0.57%。这表明,logit 蒸馏和特征蒸馏的结合能够进一步优化知识传递的效果。
4.3.4 与其他最先进方法的比较
我们还将 WKD-L+WKD-F 与其他最先进的方法进行了比较,包括 CRD+KD[25]、DPK[7]、FCFD[8] 和 KD-Zero[56]。在设置(a)中,WKD-L+WKD-F 的 Top-1 准确率为 72.76%,优于 CRD+KD(71.38%)和 FCFD(72.25%);在设置(b)中,WKD-L+WKD-F 的 Top-1 准确率为 73.69%,优于 DPK(72.25%)和 KD-Zero(72.17%)。这表明,WKD-L+WKD-F 在 logit 蒸馏和特征蒸馏的结合方面表现优异。
4.3.5 训练时延的比较
在设置(a)中,我们比较了不同方法的训练时延,结果如表 5 所示。对于 logit 蒸馏,WKD-L 的时延比基于 KL 散度的方法(例如 KD 和 NKD)高约 1.3 倍,这是由于优化离散 Wasserstein 距离的过程较为复杂。对于特征蒸馏,WKD-F 的时延与基于 KL 散度的方法相当,但比 ReviewKD 快约 1.6 倍,比 EMD+IPOT 快约 1.2 倍。这是因为 WKD-F 仅涉及均值向量和方差向量,计算成本较低。最后,WKD-L+WKD-F 的结合方法具有较高的时延,但其性能优于 ICKD-C,并且比最先进的 FCFD 更为高效。
4.4 CIFAR-100 上的图像分类实验
我们在 CIFAR-100 上评估了 WKD 的性能,实验设置包括教师模型为 CNN、学生模型为 Transformer,以及教师模型为 Transformer、学生模型为 CNN 的情况。我们使用的 CNN 模型包括 ResNet(RN)[9]、MobileNetV2(MNV2)[58] 和 ConvNeXt[59],而视觉 Transformer 模型包括 ViT[60]、DeiT[61] 和 Swin Transformer[62]。详细的超参数设置见第 C.5 节。
4.4.1 Logit 蒸馏的比较
在 logit 蒸馏方面,我们将 WKD-L 与 KD[2]、DKD[3]、DIST[63] 和 OFA[46] 进行了比较。如表 6 所示,无论是在从 Transformer 到 CNN 还是从 CNN 到 Transformer 的知识传递中,WKD-L 均表现优异。例如,在 Swin-T→RN18 设置中,WKD-L 的 Top-1 准确率为 81.42%,比 OFA 提高了 0.88%;在 ConvNeXt-T→DeiT-T 设置中,WKD-L 的 Top-1 准确率为 76.11%,比 OFA 提高了 0.35%。这表明,WKD-L 在跨架构知识传递中具有显著优势。

4.4.2 特征蒸馏的比较
在特征蒸馏方面,我们将 WKD-F 与 FitNet[24]、CC[64]、RKD[65] 和 CRD[25] 进行了比较。WKD-F 在所有设置中均表现最佳,尤其是在从 Transformer 到 CNN 的知识传递中,WKD-F 显著优于之前的竞争对手。例如,在 Swin-T→RN18 设置中,WKD-F 的 Top-1 准确率为 81.57%,比 CRD 提高了 3.94%;在 ViT-S→MNV2 设置中,WKD-F 的 Top-1 准确率为 79.11%,比 CRD 提高了 0.97%。我们认为,WKD-F 的优势在于其分布建模和匹配策略,即使用高斯分布和 Wasserstein 距离。由于 CNN 和 Transformer 生成的特征差异较大 [46],WKD-F 的特征分布匹配策略比 FitNet 和 CRD 的原始特征比较更为有效。
4.4.3 Logit 与特征蒸馏的结合
我们将 WKD-L 和 WKD-F 结合使用,进一步提升了性能。例如,在 ConvNeXt-T→DeiT-T 设置中,WKD-L+WKD-F 的 Top-1 准确率为 76.11%,比单独的 WKD-L 和 WKD-F 分别提高了 0.35% 和 2.84%;在 ConvNeXt-T→Swin-P 设置中,WKD-L+WKD-F 的 Top-1 准确率为 78.94%,比单独的 WKD-L 和 WKD-F 分别提高了 0.62% 和 4.14%。这表明,logit 蒸馏和特征蒸馏的结合能够进一步优化跨架构知识传递的效果。
4.5 ImageNet 上的自知识蒸馏实验
我们在自知识蒸馏(Self-KD)框架中实现了 WKD,具体采用 Born-Again Network(BAN)[66] 的方法。首先,我们使用真实标签训练一个初始模型
4.5.1 实验结果
我们在 ImageNet 上使用 ResNet18 进行了实验,超参数与设置(a)一致。如表 7 所示,BAN 取得了具有竞争力的准确率,与最先进的结果相当。我们的方法取得了最佳结果,Top-1 准确率为 71.35%,比 BAN 提高了 0.85%,比第二好的 USKD[4] 提高了 0.6%。这表明,WKD 在自知识蒸馏中具有良好的泛化能力。
4.6 MS-COCO 上的目标检测实验
我们将 WKD 扩展到目标检测任务中,采用 Faster-RCNN[47] 框架。对于 WKD-L,我们使用检测头中的分类分支进行 logit 蒸馏。对于 WKD-F,我们从 RoIAlign 层输出的特征中传递知识,并选择 4×4 的空间网格来计算高斯分布。详细的实现细节、关键组件的消融实验和额外实验见附录 E 节。
4.6.1 实验结果
我们在两种设置下与现有方法进行了比较,如表 8 所示。在 RN101→RN18 设置中,教师为 ResNet101,学生为 ResNet18;在 RN50→MNV2 设置中,教师为 ResNet50,学生为 MobileNetV2[58]。
- Logit 蒸馏:我们的 WKD-L 显著优于经典的 KD[2],并且在 RN50→MNV2 设置中略优于 DKD[3]。
- 特征蒸馏:我们将 WKD-F 与 FitNet、FGFI[50]、ICD[51] 和 ReviewKD[29] 进行了比较。WKD-F 在两种设置中均显著优于 ReviewKD,这是之前最好的特征蒸馏方法。
- Logit 与特征蒸馏的结合:通过结合 WKD-L 和 WKD-F,我们取得了比 DKD+ReviewKD[3] 更好的性能。当额外使用边界框回归进行知识传递时,我们的 WKD-L+WKD-F 进一步提升了性能,超越了之前最先进的 FCFD[8]。
5. 结论
Wasserstein 距离(WD)在生成模型 [11] 等多个领域已显示出明显优于 KL 散度的优势。然而,在知识蒸馏领域,KL 散度仍然占据主导地位,且尚不清楚 WD 是否能够超越 KL 散度。我们认为,早期基于 WD 的知识蒸馏尝试未能充分发挥这一度量的潜力。因此,我们提出了一种基于 WD 的知识蒸馏新方法,能够从 logits 和特征中传递知识。大量的实验表明,离散 WD 在 logit 蒸馏中是非常有前途的替代方案,而连续 WD 在从中间层传递知识方面能够取得令人瞩目的性能。
尽管如此,我们的方法仍存在一些局限性。具体来说,WKD-L 的计算成本高于基于 KL 散度的 logit 蒸馏方法,而 WKD-F 假设特征遵循高斯分布。关于这些局限性和未来研究的详细讨论见附录 F 节。最后,我们希望我们的工作能够揭示 WD 的潜力,并激发更多关于这一度量在知识蒸馏中的应用研究。
附录
A. WKD 的实现细节
A.1 WKD-L 中的类别相互关系建模
类别相互关系的可视化:我们从 ImageNet 训练集中随机选择 100 个类别,每个类别随机选择 50 张图像。然后,我们将这些图像输入到预训练的 ResNet50 模型中,并从倒数第二层提取特征,使用 t-SNE 将其投影到 2D 空间。不同类别以不同颜色显示,如图 4a 所示。为了直观理解,我们根据特征在 2D 嵌入中的最近位置显示对应的图像,如图 4b 所示。可以看出,类别在特征空间中表现出复杂的拓扑关系(距离),例如哺乳动物彼此更接近,而远离人工制品或食物。这些关系编码了丰富的信息,对知识蒸馏非常有益。
基于 CKA 的 IR 量化:我们使用 CKA 来建模类别之间的相互关系,因为它能够有效表征深度表示的相似性 [17]。CKA 是归一化的 HSIC[18],通过将特征映射到 RKHS 来测量随机变量(特征)之间的统计依赖性。对于类别
基于余弦相似度的 IR 量化:除了 CKA,还可以使用两个类别原型之间的余弦相似度来量化 IR。类别原型可以自然地计算为该类别训练样本的特征质心,即
A.2 WKD-F 中的分布建模
我们使用高斯分布对特征分布进行建模。给定输入图像,我们将 DNN 某层输出的特征图重塑为矩阵
我们使用高斯分布
B. WKD 的计算复杂度
WKD-L 的复杂度为
C. 图像分类的额外实验
C.1 WKD 的更多消融实验
我们在 ImageNet 上对 WKD 的关键组件进行了消融实验,具体包括 WKD-L 的超参数(如温度和权重)和 WKD-F 的超参数(如均值 - 协方差比率和权重)。实验结果表明,WKD-L 和 WKD-F 在不同超参数设置下均表现稳定,且优于基线方法。
C.2 ImageNet 上的超参数总结
在设置(a)中,WKD-L 的超参数包括温度
D. 可视化
D.1 教师 - 学生差异的可视化
我们使用 Grad-CAM[79] 计算不同模型的类激活图(CAMs),如图 8 所示。可以看出,WKD-L 和 WKD-F 的 CAMs 与教师模型更为相似,且能够更准确地定位对象的重要区域。这表明 WKD-L 和 WKD-F 能够学习到具有更好表示能力的特征。
E. 目标检测的额外实验
E.1 COCO 上的实现细节
对于 WKD-L,我们使用离散 WD 匹配教师和学生模型分类分支的预测概率。对于 WKD-F,我们从 RoIAlign 层输出的特征中传递知识,并选择 4×4 的空间网格来计算高斯分布。
F. 局限性与未来研究
WKD-L 的计算成本高于基于 KL 散度的方法,但未来可以通过更快的 WD 算法 [45] 来优化。此外,WKD-F 假设特征遵循高斯分布,未来可以探索更鲁棒和高效的分布建模方法。
2.2.3 其他分布建模方法
除了高斯分布,我们还可以使用其他分布进行特征建模。例如:
- 拉普拉斯分布:拉普拉斯分布假设特征的分量是独立的,其概率密度函数为:
其中
- 指数分布:指数分布假设特征的分量是独立的,其概率密度函数为:
其中
然而,实验表明,高斯分布在特征建模中表现最好,尤其是在高维特征空间中。
2.2.4 非参数化方法
尽管非参数化方法(例如直方图和核密度估计)在高维空间中不可行,但我们仍然可以使用概率质量函数(PMF)进行分布建模。具体来说,给定一组特征
其中
2.2.5 空间金字塔池化
为了增强特征表示能力,我们可以使用空间金字塔池化策略 [28;29;6]。具体来说,我们将特征图划分为
2.2.6 与其他方法的比较
3.3.1 NST 与 ICKD-C 的比较
NST[35] 和 ICKD-C[6] 都采用了基于二阶矩的分布建模方法,但它们关注的维度不同:
- NST:沿空间维度计算二阶矩,即对每个空间位置的特征进行统计建模。
- ICKD-C:沿通道维度计算二阶矩,即对每个通道的特征进行统计建模。
尽管这两种方法在某些任务中表现良好,但它们未能充分利用二阶矩矩阵的几何结构,这些矩阵是对称正定的(SPD),并形成一个黎曼空间 [38;39]。相比之下,我们的 WKD-F 通过 Wasserstein 距离直接捕捉高斯分布之间的几何关系,从而在特征蒸馏中表现更优。
3.3.2 VID 与 WKD-F 的比较
VID[40] 通过变分信息蒸馏(VID)来传递知识,假设特征分布是高斯的,并通过互信息来度量特征之间的相似性。如果进一步假设高斯分布具有单位方差,VID 的损失函数简化为均方损失(即 FitNet[24])。尽管 VID 在某些任务中表现良好,但它仍然依赖于高斯分布的假设,并且无法像 WKD-F 那样直接利用 Wasserstein 距离的几何优势。
3.4 与其他最新方法的比较
在表 1 中,我们总结了与我们的方法相关的 KD 方法,并展示了它们的联系和区别。我们的 WKD-L 和 WKD-F 在 logit 蒸馏和特征蒸馏方面均表现优异,尤其是在利用类别之间的相互关系和几何结构方面。