PyTorch的概述
PyTorch 实现模型训练的 5 大要素

- 数据:包括数据读取,数据清洗,进行数据划分和数据预处理,比如读取图片如何预处理及数据增强。
- 模型:包括构建模型模块,组织复杂网络,初始化网络参数,定义网络层。
- 损失函数:包括创建损失函数,设置损失函数超参数,根据不同任务选择合适的损失函数。
- 优化器:包括根据梯度使用某种优化器更新参数,管理模型参数,管理多个参数组实现不同学习率,调整学习率。
- 迭代训练:组织上面 4 个模块进行反复训练。包括观察训练效果,绘制 Loss/ Accuracy 曲线,用 TensorBoard 进行可视化分析。
深度学习代码 coding 顺序
比较好的顺序是先写 model,再写 dataset,最后写 train。
model:构成了整个深度学习训练与推断系统骨架,也确定了整个 AI 模型的输入和输出格式。对于视觉任务,模型架构多为卷积神经网络或是最新的 ViT 模型;对于 NLP 任务,模型架构多为 Transformer 以及 Bert;对于时间序列预测,模型架构多为 RNN 或 LSTM。不同的 model 对应了不同的数据输入格式,如 ResNet 一般是输入多通道二维矩阵,而 ViT 则需要输入带有位置信息的图像 patchs。确定了用什么样的 model 后,数据的输入格式也就确定下来。根据确定的输入格式,我们才能构建对应的 dataset。
dataset:构建了整个 AI 模型的输入与输出格式。在写作 dataset 组件时,我们需要考虑数据的存储位置与存储方式,如数据是否是分布式存储的,模型是否要在多机多卡的情况下运行,读写速度是否存在瓶颈,如果机械硬盘带来了读写瓶颈则需要将数据预加载进内存等。在写 dataset 组件时,我们也要反向微调 model 组件。例如,确定了分布式训练的数据读写后,需要用 nn.DataParallel 或者 nn.DistributedDataParallel 等模块包裹 model,使模型能够在多机多卡上运行。此外,dataset 组件的写作也会影响训练策略,这也为构建 train 组件做了铺垫。比如根据显存大小,我们需要确定相应的 BatchSize,而 BatchSize 则直接影响学习率的大小。再比如根据数据的分布情况,我们需要选择不同的采样策略进行 Feature Balance,而这也会体现在训练策略中。
train:构建了模型的训练策略以及评估方法,它是最重要也是最复杂的组件。先构建 model 与 dataset 可以添加限制,减少 train 组件的复杂度。在 train 组件中,我们需要根据训练环境(单机多卡,多机多卡或是联邦学习)确定模型更新的策略,以及确定训练总时长 epochs,优化器的类型,学习率的大小与衰减策略,参数的初始化方法,模型损失函数。此外,为了对抗过拟合,提升泛化性,还需要引入合适的正则化方法,如 Dropout,BatchNorm,L2-Regularization,Data Augmentation 等。有些提升泛化性能的方法可以直接在 train 组件中实现(如添加 L2-Reg,Mixup),有些则需要添加进 model 中(如 Dropout 与 BatchNorm),还有些需要添加进 dataset 中(如 Data Augmentation)。
PyTorch学习资源
- Awesome-pytorch-list (opens new window):目前已获12K Star,包含了NLP,CV,常见库,论文实现以及Pytorch的其他项目。
- PyTorch官方文档 (opens new window):官方发布的文档,十分丰富。
- Pytorch-handbook (opens new window):GitHub上已经收获14.8K,pytorch手中书。
- PyTorch官方社区 (opens new window):PyTorch拥有一个活跃的社区,在这里你可以和开发pytorch的人们进行交流。
- PyTorch官方tutorials (opens new window):官方编写的tutorials,可以结合colab边动手边学习
- 动手学深度学习 (opens new window):动手学深度学习是由李沐老师主讲的一门深度学习入门课,拥有成熟的书籍资源和课程资源,在B站,Youtube均有回放。
- Awesome-PyTorch-Chinese (opens new window):常见的中文优质PyTorch资源
- labml.ai Deep Learning Paper Implementations (opens new window):手把手实现经典网络代码
- YSDA course in Natural Language Processing (opens new window):YSDA course in Natural Language Processing
- huggingface (opens new window):hugging face
- ModelScope (opens new window): 魔搭社区