Geeks_Z の Blog Geeks_Z の Blog
首页
  • 学习笔记

    • 《HTML》
    • 《CSS》
    • 《JavaWeb》
    • 《Vue》
  • 后端文章

    • Linux
    • Maven
    • 汇编语言
    • 软件工程
    • 计算机网络概述
    • Conda
    • Pip
    • Shell
    • SSH
    • Mac快捷键
    • Zotero
  • 学习笔记

    • 《数据结构与算法》
    • 《算法设计与分析》
    • 《Spring》
    • 《SpringMVC》
    • 《SpringBoot》
    • 《SpringCloud》
    • 《Nginx》
  • 深度学习文章
  • 学习笔记

    • 《PyTorch》
    • 《ReinforementLearning》
    • 《MetaLearning》
  • 学习笔记

    • 《高等数学》
    • 《线性代数》
    • 《概率论与数理统计》
  • 增量学习
  • 哈希学习
GitHub (opens new window)

Geeks_Z

AI小学生
首页
  • 学习笔记

    • 《HTML》
    • 《CSS》
    • 《JavaWeb》
    • 《Vue》
  • 后端文章

    • Linux
    • Maven
    • 汇编语言
    • 软件工程
    • 计算机网络概述
    • Conda
    • Pip
    • Shell
    • SSH
    • Mac快捷键
    • Zotero
  • 学习笔记

    • 《数据结构与算法》
    • 《算法设计与分析》
    • 《Spring》
    • 《SpringMVC》
    • 《SpringBoot》
    • 《SpringCloud》
    • 《Nginx》
  • 深度学习文章
  • 学习笔记

    • 《PyTorch》
    • 《ReinforementLearning》
    • 《MetaLearning》
  • 学习笔记

    • 《高等数学》
    • 《线性代数》
    • 《概率论与数理统计》
  • 增量学习
  • 哈希学习
GitHub (opens new window)
  • 线性代数

  • 概率论与数理统计

  • 矩阵

  • 分布

    • 统计量及其分布
    • 参数估计
    • 假设检验
    • 先验分布与后验分布
    • 分布度量
    • 交叉熵
    • 最优运输概述
    • Wasserstein距离
    • 基于最优传输的分类损失函数
      • 概率散度
      • 最优传输
      • 成本函数
      • 实验效果
      • 个人思考
      • 文章小结
    • 最优传输之生成模型
    • 最优传输之梯度流
    • 两个多元正态分布的KL散度巴氏距离和W距离
  • 数学笔记
  • 分布
Geeks_Z
2025-01-18
目录

基于最优传输的分类损失函数

EMO:基于最优传输思想设计的分类损失函数 (opens new window)

众所周知,分类任务的标准损失是交叉熵(Cross Entropy,等价于最大似然MLE,即Maximum Likelihood Estimation),它有着简单高效的特点,但在某些场景下也暴露出一些问题,如偏离评价指标、过度自信等,相应的改进工作也有很多,此前我们也介绍过一些,比如《再谈类别不平衡问题:调节权重与魔改Loss的对比联系》 (opens new window)、《如何训练你的准确率?》 (opens new window)、《缓解交叉熵过度自信的一个简明方案》 (opens new window)等。由于LLM的训练也可以理解为逐token的分类任务,默认损失也是交叉熵,因此这些改进工作在LLM流行的今天依然有一定的价值。

在这篇文章中,我们介绍一篇名为《EMO: Earth Mover Distance Optimization for Auto-Regressive Language Modeling》 (opens new window)的工作,它基于最优传输思想提出了新的改进损失函数EMO,声称能大幅提高LLM的微调效果。其中细节如何?让我们一探究竟。

概率散度

假设pi 模型预测的第i 类别的概率,i=1,2,⋯,n,t 是目标类别,那么交叉熵损失为

(1)L=−log⁡pt

如果将标签t one hot形式的分布τ 示出来(即τt=1,τi=0|i≠t,i∈[1,n]),那么它可以重写成

(2)L=−∑iτilog⁡pi

这个形式同时适用于非one hot的标签τ(即软标签),它等价于优化τ,p KL散度:

(3)KL(τ‖p)=∑iτilog⁡τipi=∑iτilog⁡τi−∑iτilog⁡pi

当τ 定时,最右端第一项就是一个常数,所以它跟交叉熵目标是等价的。

这个结果表明,我们在做MLE,或者说以交叉熵为损失时,实则就是在最小化目标分布和预测分布的KL散度。由于KL散度的一般推广是f散度(参考《f-GAN简介:GAN模型的生产车间》 (opens new window)),所以很自然想到换用其他f散度或许有改良作用。事实上,确实有不少工作是按照这个思路进行的,比如《缓解交叉熵过度自信的一个简明方案》 (opens new window)介绍的方法,其论文的出发点是“Total Variation距离”,也是f散度的一种。

最优传输

不过,每种f散度或多或少有些问题,要说概率分布之间的理想度量,当属基于最优传输思想的“推土机距离(Earth Mover's Distance,EMD)”,不了解的读者可以参考一下笔者之前写的《从Wasserstein距离、对偶理论到WGAN》 (opens new window)。

简单来说,推土机距离定义为两个分布之间的最优传输成本:

(4)C[p,τ]=infγ∈Π[p,τ]∑i,jγi,jci,j

这里的γ∈Π[p,τ] 的是γ 任意以p,τ 边缘分布的联合分布,ci,j 实现给定的成本函数,代表“从i 运到j 成本”,inf 下确界,意思就是说将最低的运输成本作为p,τ 间的差异度量。正如基于f散度的Vanilla GAN换成基于最优传输的Wasserstein GAN能够更好的收敛性质,我们期望如果将分类的损失函数换成两个分布的W距离,也能收敛到更好的结果。

当τ one hot分布时,目标分布就是一个点t,那么就无所谓最不最优了,传输方案就只有一个,即把p 所有东西都搬到同一个点t,所以此时就有

(5)C[p,τ]=∑ipici,t

如果τ 一般的软标签分布,那么C[p,τ] 计算是一个线性规划问题,求解起来比较复杂,由于piτj 定义的分布也属于Π[p,τ],那么我们有

(6)C[p,τ]=infγ∈Π[p,τ]∑i,jγi,jci,j≤∑i,jpiτjci,j

这是一个容易计算的上界,也可以作为优化目标,式(5) 对应τj=δj,t,其中δ “克罗内克δ函数 (opens new window)”。

成本函数

现在回到原论文所关心的场景——LLM的微调,包括二次预训练和微调到下游任务等。正如本文开头所述,LLM的训练可以理解为逐token的分类任务(类别即所有token),每个标签是one hot的,所以适用于式(5)。

式(5) 差成本函数ci,t 没定下来。如果简单地认为只要i≠t,那么成本都是1,即ci,t=1−δi,t,那么

(7)C[p,τ]=∑ipici,t=∑i(pi−piδi,t)=1−pt

这其实就是在最大化准确率的光滑近似(参考《函数光滑化杂谈:不可导函数的可导逼近》 (opens new window))。但直觉上,所有i≠t 给予同样程度的惩罚似乎过于简单了,理想情况下应该根据相似度来给每个不同的i 计不同的成本,即相似度越大,传输成本越低,那么我们可以将传输成本设计为

ci,t=1−cos⁡(ei,et)=1−⟨ei‖ei‖,et‖et‖⟩

这里的ei,et 事先获取到Token Embedding,原论文是将预训练模型的LM Head作为Token Embedding的,并且根据最优传输的定义成本函数是要实现给定的,因此计算相似度的Token Embedding要在训练过程中固定不变。

有了成本函数后,我们就可以计算

C[p,τ]=∑ipici,t=∑i(pi−pi⟨ei‖ei‖,et‖et‖⟩)=1−⟨∑ipiei‖ei‖,et‖et‖⟩

这就是EMO(Earth Mover Distance Optimization)最终的训练损失。由于embedding_size通常远小于vocab_size,所以先算∑ipiei‖ei‖ 明显降低计算量。

实验效果

由于笔者对LLM的研究还处于预训练阶段,还未涉及到微调,所以暂时没有自己的实验结果,只能先跟大家一起看看原论文的实验。不得不说,原论文的实验结果还是比较惊艳的。

首先,是小模型上的继续预训练实验,相比交叉熵(MLE)的提升最多的有10个点,并且是全面SOTA:

值得一提的是,这里的评价指标是MAUVE,越大越好,它提出自《MAUVE: Measuring the Gap Between Neural Text and Human Text using Divergence Frontiers》 (opens new window),是跟人工评价最相关的自动评测指标之一。此外,对比方法的TaiLr我们曾在《缓解交叉熵过度自信的一个简明方案》 (opens new window)简单介绍过。

可能有读者想EMO更好是不是单纯因为评价指标选得好?并不是,让人意外的是,EMO训练的模型,甚至PPL都更好(PPL跟MLE更相关):

然后是将LLAMA-7B/13B微调到下游任务做Few Shot的效果,同样很出色:

最后对比了不同模型规模和数据规模的效果,显示出EMO在不同模型和数据规模上都有不错的表现:

个人思考

总的来说,原论文的“成绩单”还是非常漂亮的,值得一试。唯一的疑虑可能是原论文的实验数据量其实都不算大,不清楚进一步增大数据量后是否会缩小EMO和MLE的差距。

就笔者看来,EMO之所以能取得更好的结果,是因为它通过Embedding算相似度,来为“近义词”分配了更合理的损失,从而使得模型的学习更加合理。因为虽然形式上LLM也是分类任务,但它并不是一个简单的对与错问题,并不是说下一个预测的token跟标签token不一致,句子就不合理了,因此引入语义上的相似度来设计损失对LLM的训练是有帮助的。可以进一步猜测的是,vocab_size越大、token颗粒度越大的情况下,EMO的效果应该越好,因为vocab_size大了“近义词”就可能越多。

当然,引入语义相似度也导致了EMO不适用于从零训练,因为它需要一个训练好的LM Head作为Token Embedding。当然,一个可能的解决方案是考虑用其他方式,比如经典的Word2Vec来事先训练好Token Embedding,但这可能会有一个风险,即经典方式训练的Token Embedding是否会降低LLM能力的天花板(毕竟存在不一致性)。

此外,即便Token Embedding没问题,从零预训练时单纯用EMO可能还存在收敛过慢的问题,这是因为根据笔者在《如何训练你的准确率?》 (opens new window)的末尾提出的损失函数视角: 首先寻找评测指标的一个光滑近似,最好能表达成每个样本的期望形式,然后将错误方向的误差逐渐拉到无穷大(保证模型能更关注错误样本),但同时在正确方向保证与原始形式是一阶近似。

也就是说,为了保证(从零训练的)收敛速度,错误方向的损失最好能拉到无穷大,而EMO显然不满足这一点,因此将EMO用于从零训练的时候,大概率是EMO与MLE的某个加权组合,才能平衡收敛速度和最终效果。

文章小结

本文介绍了交叉熵损失的一个新的“替代品”——基于最优传输思想的EMO,与以往的小提升不同,EMO在LLM的微调实验中取得了较为明显的提升。

上次更新: 2025/06/25, 11:25:50
Wasserstein距离
最优传输之生成模型

← Wasserstein距离 最优传输之生成模型→

最近更新
01
帮助信息查看
06-08
02
常用命令
06-08
03
学习资源
06-07
更多文章>
Theme by Vdoing | Copyright © 2022-2025 Geeks_Z | MIT License
京公网安备 11010802040735号 | 京ICP备2022029989号-1
  • 跟随系统
  • 浅色模式
  • 深色模式
  • 阅读模式