概述
使用 PyTorch Lightning 进行模型训练可以简化深度学习项目的开发流程,提高代码的可读性和可维护性。以下是使用 PyTorch Lightning 完成模型训练的主要步骤:
安装 PyTorch Lightning
首先,确保已安装 PyTorch Lightning。可以使用以下命令通过 pip 安装:
pip install pytorch-lightning
1定义 LightningModule
创建一个继承自
pl.LightningModule
的类,用于定义模型结构、前向传播、损失计算和优化器配置等。import pytorch_lightning as pl import torch from torch import nn import torch.nn.functional as F class LitModel(pl.LightningModule): def __init__(self): super(LitModel, self).__init__() self.layer = nn.Linear(28 * 28, 10) def forward(self, x): return torch.relu(self.layer(x.view(x.size(0), -1))) def training_step(self, batch, batch_idx): x, y = batch logits = self(x) loss = F.cross_entropy(logits, y) return loss def configure_optimizers(self): return torch.optim.Adam(self.parameters(), lr=0.001)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21准备数据
使用
LightningDataModule
或自定义数据加载器来处理数据的加载和预处理。from torch.utils.data import DataLoader, random_split from torchvision.datasets import MNIST from torchvision import transforms transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]) dataset = MNIST(root='data', train=True, download=True, transform=transform) train_set, val_set = random_split(dataset, [55000, 5000]) train_loader = DataLoader(train_set, batch_size=64, num_workers=4) val_loader = DataLoader(val_set, batch_size=64, num_workers=4)
1
2
3
4
5
6
7
8
9
10初始化 Trainer 并开始训练
使用
pl.Trainer
初始化训练器,设置训练参数,并调用fit
方法开始训练。from pytorch_lightning import Trainer model = LitModel() trainer = Trainer(max_epochs=10, gpus=1) trainer.fit(model, train_loader, val_loader)
1
2
3
4
5模型验证和测试
在训练完成后,可以使用验证集或测试集评估模型性能。
test_set = MNIST(root='data', train=False, download=True, transform=transform) test_loader = DataLoader(test_set, batch_size=64, num_workers=4) trainer.test(model, test_loader)
1
2
3
通过以上步骤,您可以使用 PyTorch Lightning 高效地完成模型的训练和评估过程。
上次更新: 2025/04/02, 12:03:38