Batch Size
- batch size不能太大,也不能太小;太小会浪费计算资源,太大则会浪费内存;一般设置为16的倍数。对于推荐来说32-64-128-512测试效果再高一般也不会正向了,再低训练太慢了。
- Learning rate和batch size是两个重要的参数,而且二者也是相互影响的,在反向传播时直接影响梯度。一般情况下,先调batch size,再调learning rate。
选择BatchSize
总结: Batch Size决定训练速度,并且不应该被直接用于调整验证集性能。通常来说,可用硬件支持的最大Batch Size是较为理想的数值。
- Batch Size是决定训练时间和计算资源消耗的关键因素。
- 增加Batch Size通常会减少训练时间。这非常有益,因为它:
- 能使固定时间间隔内超参数调整更彻底,最终训练出更好的模型。
- 减少开发周期的延迟,能更多地测试新想法。
- 资源消耗和Batch Size之间并没有明确的关系,增加Batch Size让资源消耗增加、减少或是保持不变都有可能。
- Batch Size不应该被当作验证集性能的可调超参数。
- 只要调整好所有超参数(尤其是学习率和正则化超参数)并且训练步数足够,理论上任意的Batch Size都能获得相同的最终性能(参见 Shallue et al. 2018 (opens new window))
确定可行的Batch Size并估计训练吞吐量
- 对于给定的模型和优化器,可用硬件通常能能够支持一系列Batch Size。限制因素通常是加速器(GPU/TPU 等)的内存。
- 不幸的是,如果不运行或者编译完整的训练程序,就很难计算出适合内存的Batch Size。
- 最简单的解决方案通常是以不同的批次大小(例如,用 2 的幂来尝试)运行少量的训练实验,直到其中一个实验超过可用内存。
- 对于每个Batch Size,我们应该训练足够长的时间以准确估计训练吞吐量
训练吞吐量 =每秒处理的样本数量
或者,我们可以估计每步时间 :
每步时间 =(Batch Size)/(训练吞吐量)
- 当加速器内存未饱和时,如果Batch Size加倍,训练吞吐量也应该加倍(或至少接近加倍)。等效地,随着Batch Size的增加,每步的时间应该是恒定的(或至少接近恒定的)。
- 如果与上述情况不符,那么训练工作流可能存在瓶颈,例如 I/O 或计算节点间的同步。有必要在开始下一步前对此进行诊断和矫正。
- 如果训练吞吐量到某个Batch Size之后就不再增加,那么我们只考虑使用该Batch Size(即使硬件支持更大的Batch Size)
- 使用更大Batch Size的所有好处都假定训练吞吐量增加。如果没有,请修复瓶颈或使用较小的Batch Size。
- 使用梯度积累技术可以支持的更大的Batch Size。但其不提供任何训练吞吐量优势,故在应用工作中通常应避免使用它。
- 每次更改模型或优化器时,可能都需要重复这些步骤(例如,不同的模型架构可能允许更大的Batch Size)。
选择合适的Batch Size以最小化训练时间
训练时间 =(每步时间)x(总步数)
- 对于所有可行的Batch Size,我们通常可以认为每步的时间近似恒定(实际上,增加Batch Size通常会产生一些开销)。
- Batch Size越大,达到某一性能目标所需的步数通常会减少(前提是在更改Batch Size时重新调整所有相关超参数;Shallue et al. 2018 (opens new window))。
- 例如,将Batch Size翻倍可能会使训练步数减半。这称为完美缩放。
- 完美缩放适用于Batch Size在临界值之前,超过该临界总步数的减少效果将会下降。
- 最终,增加Batch Size不会再使训练步数减少(永远不会增加)。
- 因此,最小化训练时间的Batch Size通常是最大的Batch Size,也同时减少了所需的训练步数。
- Batch Size取决于数据集、模型和优化器,除了通过实验为每个新问题找到它之外,如何计算它是一个悬而未决的问题。
- 比较Batch Size时,请注意效果(epoch)预算(运行所有实验,固定训练样本的数量达到设定的效果所花的时间)和步数预算(运行设定步数的试验)之间的区别。
- 将Batch Size与效果预算进行比较只会涉及到完美缩放的范围,即使更大的Batch Size仍可能通过减少所需的训练步数来提供有意义的加速。
- 通常,可用硬件支持的最大Batch Size将小于临界Batch Size。因此,一个好的经验法则(不运行任何实验)是使用尽可能大的Batch Size。
- 如果最终增加了训练时间,那么使用更大的Batch Size就没有意义了。
选择合适的Batch Size以最小化资源消耗
- 有两种类型的资源成本与增加 batch size有关。
- 前期成本,例如购买新硬件或重写训练工作流以实现多GPU/多TPU训练。
- 使用成本,例如,根据团队的资源预算计费,从云供应商处计费,电力/维护成本。
如果增加 batch size有很大的前期成本,那么直到项目成熟且容易权衡成本效益前,推迟其的增加可能更好。实施多机并行训练程序可能会引入错误 (opens new window)和一些棘手的细节 (opens new window),所以无论如何,一开始最好是用一个比较简单的工作流。(另一方面,当需要进行大量的调优实验时,训练时间的大幅加速可能会在过程的早期非常有利)。
我们把总的使用成本(可能包括多种不同类型的成本)称为 "资源消耗 "。我们可以将资源消耗分解为以下几个部分。
资源消耗 = (每步的资源消耗) x (总步数)
增加 batch size通常可以减少总步骤数。资源消耗是增加还是减少,将取决于每步的消耗如何变化。
- 增加 batch size可能会减少资源消耗。例如,如果大batch size的每一步都可以在与小batch size相同的硬件上运行(每一步只增加少量时间),那么每一步资源消耗的增加可能被步骤数的减少所抵消。
- 增加 batch size可能不会改变资源消耗。例如,如果将batch size增加一倍,所需的步骤数减少一半,所使用的GPU数量增加一倍,总消耗量(以GPU小时计)将不会改变。
- 增加 batch size可能会增加资源消耗。例如,如果增加batch size需要升级硬件,那么每步消耗的增加可能超过训练所需步数的减少。
更改Batch Size需要重新调整大多数超参数
- 大多数超参数的最佳值对Batch Size敏感。因此,更改Batch Size通常需要重新开始调整过程。
- 与Batch Size交互最强烈的超参数是优化器超参数(学习率、动量等)和正则化超参数,所以有必要对于针对每个Batch Size单独调整它们。
- 在项目开始时选择Batch Size时请记住,如果您以后需要切换到不同的Batch Size,则为新的Batch Size重新调整所有内容可能会很困难、耗时且成本高昂。
Batch Norm会对Batch Size的选择造成什么影响?
Batch norm 很复杂,一般来说,应该使用与计算梯度不同的 batch size 来计算统计数据(像Ghost Batch Norm (opens new window)采用固定值的batch size)。
BatchNorm的实现细节
总结:目前Batch Norm通常可以用Layer Norm代替,但在不能替换的情况下,在更改批大小或主机数量时会有一些棘手的细节。
- Batch norm 使用当前批次的均值和方差对激活值进行归一化,但在多设备设置中,除非明确同步处理,否则这些统计数据在每个设备上都是不同的。
- 据说(主要在ImageNet上)仅使用约 64 个样本计算这些归一化统计数据在实际应用中效果更好(请参阅 Ghost Batch Norm (opens new window))。
- 将总批大小与用于计算批归一化统计数据的样本数量分离对于批次大小的比较特别有用。
- Ghost batch norm 实现并不总能正确处理每台设备的批次大小 > 虚拟批次大小的情况。在这种情况下,我们实际上需要在每个设备上对批次进行二次抽样,以获得适当数量的批归一化统计样本。
- 在测试模式中使用的指数移动平均(EMA)仅仅是训练统计数据的线性组合,因此这些 EMA 只需要在将它们保存在检查点之前进行同步。然而一些常见的批归一化实现不同步这些EMA,并只保存第一个设备的EMA。
为什么不应该调整Batch Size来直接提高验证集性能?
- 在不更改训练工作流其他细节的情况下, 修改batch size 通常会影响验证集的性能。
- 但是,如果针对每个batch size单独调优,则两个batch size之间的验证集性能差异通常会消失。
- 受batch size影响最强烈的那些超参数,即优化器超参数(例如:学习率、动量)和正则化超参数,这些东西对于每个batch size进行单独调优的时候是最重要的。
- 由于样本方差的原因,较小的batch size会在训练算法中引入更多的不确定性,并且这些不确定性可能存在着正则化效果。因此,较大的batch size可能更容易过度拟合。并且,这可能需要更强的正则化和/或额外的正则化技术。
- 此外, 当修改batch size的大小时,训练步骤的数量可能也需要进行调整。
- 一旦考虑了所有这些因素带来的影响,目前还没有任何能够令人信服的证据表明batch size会影响最大可实现的验证性能(具体请阅读 Shallue et al. 2018 (opens new window))。
上次更新: 2025/04/02, 12:03:38