Geeks_Z の Blog Geeks_Z の Blog
首页
  • 学习笔记

    • 《HTML》
    • 《CSS》
    • 《JavaWeb》
    • 《Vue》
  • 后端文章

    • Linux
    • Maven
    • 汇编语言
    • 软件工程
    • 计算机网络概述
    • Conda
    • Pip
    • Shell
    • SSH
    • Mac快捷键
    • Zotero
  • 学习笔记

    • 《数据结构与算法》
    • 《算法设计与分析》
    • 《Spring》
    • 《SpringMVC》
    • 《SpringBoot》
    • 《SpringCloud》
    • 《Nginx》
  • 深度学习文章
  • 学习笔记

    • 《PyTorch》
    • 《ReinforementLearning》
    • 《MetaLearning》
  • 学习笔记

    • 《高等数学》
    • 《线性代数》
    • 《概率论与数理统计》
  • 增量学习
  • 哈希学习
GitHub (opens new window)

Geeks_Z

AI小学生
首页
  • 学习笔记

    • 《HTML》
    • 《CSS》
    • 《JavaWeb》
    • 《Vue》
  • 后端文章

    • Linux
    • Maven
    • 汇编语言
    • 软件工程
    • 计算机网络概述
    • Conda
    • Pip
    • Shell
    • SSH
    • Mac快捷键
    • Zotero
  • 学习笔记

    • 《数据结构与算法》
    • 《算法设计与分析》
    • 《Spring》
    • 《SpringMVC》
    • 《SpringBoot》
    • 《SpringCloud》
    • 《Nginx》
  • 深度学习文章
  • 学习笔记

    • 《PyTorch》
    • 《ReinforementLearning》
    • 《MetaLearning》
  • 学习笔记

    • 《高等数学》
    • 《线性代数》
    • 《概率论与数理统计》
  • 增量学习
  • 哈希学习
GitHub (opens new window)
  • Python

  • MLTutorials

  • 卷积神经网络

  • 循环神经网络

  • Transformer

  • VisionTransformer

  • 扩散模型

  • 计算机视觉

  • PTM

  • MoE

  • LoRAMoE

  • LongTailed

  • 多模态

  • 知识蒸馏

  • PEFT

  • 对比学习

  • 小样本学习

  • 迁移学习

  • 零样本学习

  • 集成学习

  • Mamba

  • PyTorch

    • PyTorch概述

    • Tensors

    • 数据处理

    • 模型

    • 训练

    • 并行计算

    • 可视化

    • 实战

    • timm

    • Pytorch Lightning

      • 概述
      • 记录训练loss
      • Pytorch_Lightning
    • 数据增强

    • 面经与bug解决

    • 常用代码片段

    • Reference
  • CL

  • CIL

  • 小样本类增量学习FSCIL

  • UCIL

  • 多模态增量学习MMCL

  • LTCIL

  • DIL

  • 论文阅读与写作

  • 分布外检测

  • GPU

  • 深度学习调参指南

  • AINotes
  • PyTorch
  • Pytorch Lightning
Geeks_Z
2025-03-18

概述

使用 PyTorch Lightning 进行模型训练可以简化深度学习项目的开发流程,提高代码的可读性和可维护性。以下是使用 PyTorch Lightning 完成模型训练的主要步骤:

  1. 安装 PyTorch Lightning

    首先,确保已安装 PyTorch Lightning。可以使用以下命令通过 pip 安装:

    pip install pytorch-lightning
    
    1
  2. 定义 LightningModule

    创建一个继承自 pl.LightningModule 的类,用于定义模型结构、前向传播、损失计算和优化器配置等。

    import pytorch_lightning as pl
    import torch
    from torch import nn
    import torch.nn.functional as F
    
    class LitModel(pl.LightningModule):
        def __init__(self):
            super(LitModel, self).__init__()
            self.layer = nn.Linear(28 * 28, 10)
    
        def forward(self, x):
            return torch.relu(self.layer(x.view(x.size(0), -1)))
    
        def training_step(self, batch, batch_idx):
            x, y = batch
            logits = self(x)
            loss = F.cross_entropy(logits, y)
            return loss
    
        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=0.001)
    
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
  3. 准备数据

    使用 LightningDataModule 或自定义数据加载器来处理数据的加载和预处理。

    from torch.utils.data import DataLoader, random_split
    from torchvision.datasets import MNIST
    from torchvision import transforms
    
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
    dataset = MNIST(root='data', train=True, download=True, transform=transform)
    train_set, val_set = random_split(dataset, [55000, 5000])
    
    train_loader = DataLoader(train_set, batch_size=64, num_workers=4)
    val_loader = DataLoader(val_set, batch_size=64, num_workers=4)
    
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
  4. 初始化 Trainer 并开始训练

    使用 pl.Trainer 初始化训练器,设置训练参数,并调用 fit 方法开始训练。

    from pytorch_lightning import Trainer
    
    model = LitModel()
    trainer = Trainer(max_epochs=10, gpus=1)
    trainer.fit(model, train_loader, val_loader)
    
    1
    2
    3
    4
    5
  5. 模型验证和测试

    在训练完成后,可以使用验证集或测试集评估模型性能。

    test_set = MNIST(root='data', train=False, download=True, transform=transform)
    test_loader = DataLoader(test_set, batch_size=64, num_workers=4)
    trainer.test(model, test_loader)
    
    1
    2
    3

通过以上步骤,您可以使用 PyTorch Lightning 高效地完成模型的训练和评估过程。

上次更新: 2025/06/25, 11:25:50
create_model解读
记录训练loss

← create_model解读 记录训练loss→

最近更新
01
帮助信息查看
06-08
02
常用命令
06-08
03
学习资源
06-07
更多文章>
Theme by Vdoing | Copyright © 2022-2025 Geeks_Z | MIT License
京公网安备 11010802040735号 | 京ICP备2022029989号-1
  • 跟随系统
  • 浅色模式
  • 深色模式
  • 阅读模式