Geeks_Z の Blog Geeks_Z の Blog
首页
  • 学习笔记

    • 《HTML》
    • 《CSS》
    • 《JavaWeb》
    • 《Vue》
  • 后端文章

    • Linux
    • Maven
    • 汇编语言
    • 软件工程
    • 计算机网络概述
    • Conda
    • Pip
    • Shell
    • SSH
    • Mac快捷键
    • Zotero
  • 学习笔记

    • 《数据结构与算法》
    • 《算法设计与分析》
    • 《Spring》
    • 《SpringMVC》
    • 《SpringBoot》
    • 《SpringCloud》
    • 《Nginx》
  • 深度学习文章
  • 学习笔记

    • 《PyTorch》
    • 《ReinforementLearning》
    • 《MetaLearning》
  • 学习笔记

    • 《高等数学》
    • 《线性代数》
    • 《概率论与数理统计》
  • 增量学习
  • 哈希学习
GitHub (opens new window)

Geeks_Z

AI小学生
首页
  • 学习笔记

    • 《HTML》
    • 《CSS》
    • 《JavaWeb》
    • 《Vue》
  • 后端文章

    • Linux
    • Maven
    • 汇编语言
    • 软件工程
    • 计算机网络概述
    • Conda
    • Pip
    • Shell
    • SSH
    • Mac快捷键
    • Zotero
  • 学习笔记

    • 《数据结构与算法》
    • 《算法设计与分析》
    • 《Spring》
    • 《SpringMVC》
    • 《SpringBoot》
    • 《SpringCloud》
    • 《Nginx》
  • 深度学习文章
  • 学习笔记

    • 《PyTorch》
    • 《ReinforementLearning》
    • 《MetaLearning》
  • 学习笔记

    • 《高等数学》
    • 《线性代数》
    • 《概率论与数理统计》
  • 增量学习
  • 哈希学习
GitHub (opens new window)
  • Python

  • MLTutorials

  • 卷积神经网络

  • 循环神经网络

  • Transformer

  • VisionTransformer

  • 扩散模型

  • 计算机视觉

  • PTM

  • MoE

  • LoRAMoE

  • LongTailed

  • 多模态

  • 知识蒸馏

  • PEFT

  • 对比学习

  • 小样本学习

  • 迁移学习

  • 零样本学习

  • 集成学习

  • Mamba

  • PyTorch

    • PyTorch概述

    • Tensors

    • 数据处理

    • 模型

      • 模型构建
      • 模型容器
      • 模型参数
      • 权值初始化
      • 模型保存与加载
      • 模型修改
      • 模型优化
      • nn.Module
        • state_dict
        • loadstatedict
          • 参数说明
          • 注意事项
        • model.train()
          • 使用示例
        • model.eval()
          • 主要作用
          • 使用方法
          • 注意事项
        • model.eval()和torch.no_grad()的区别
        • Reference
      • 模型示例
    • 训练

    • 并行计算

    • 可视化

    • 实战

    • timm

    • Pytorch Lightning

    • 数据增强

    • 面经与bug解决

    • 常用代码片段

    • Reference
  • CL

  • CIL

  • 小样本类增量学习FSCIL

  • UCIL

  • 多模态增量学习MMCL

  • LTCIL

  • DIL

  • 论文阅读与写作

  • 分布外检测

  • GPU

  • 深度学习调参指南

  • AINotes
  • PyTorch
  • 模型
Geeks_Z
2022-10-30
目录

nn.Module

state_dict

在PyTorch中,state_dict是一个字典对象,用于存储模型或优化器的参数。这个字典将每一层或优化器的参数映射到对应的张量。state_dict的主要作用在于方便模型的保存和加载,以便在训练过程中恢复模型的状态或在其他任务中重用模型。

  • 对于模型(如torch.nn.Module的实例),state_dict包含模型的可学习参数(如权重和偏置)。只有包含可学习参数的层(如卷积层、线性层等)和已注册的缓冲区(如Batch Normalization层的运行均值和方差)才会在state_dict中有对应的条目。这些参数是在模型训练过程中被优化器更新的。

  • 对于优化器(如torch.optim的实例),state_dict包含优化器的状态信息以及使用的超参数(如学习率、动量等)。这些状态信息用于在训练过程中更新模型的参数。

通过调用torch.save(model.state_dict(), PATH)可以将模型的state_dict保存到磁盘上,其中PATH是保存的路径。同样地,通过model.load_state_dict(torch.load(PATH))可以加载之前保存的state_dict到模型中。

#encoding:utf-8
 
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import numpy as mp
import matplotlib.pyplot as plt
import torch.nn.functional as F
 
#define model
class TheModelClass(nn.Module):
    def __init__(self):
        super(TheModelClass,self).__init__()
        self.conv1=nn.Conv2d(3,6,5)
        self.pool=nn.MaxPool2d(2,2)
        self.conv2=nn.Conv2d(6,16,5)
        self.fc1=nn.Linear(16*5*5,120)
        self.fc2=nn.Linear(120,84)
        self.fc3=nn.Linear(84,10)
 
    def forward(self,x):
        x=self.pool(F.relu(self.conv1(x)))
        x=self.pool(F.relu(self.conv2(x)))
        x=x.view(-1,16*5*5)
        x=F.relu(self.fc1(x))
        x=F.relu(self.fc2(x))
        x=self.fc3(x)
        return x
 
def main():
    # Initialize model
    model = TheModelClass()
 
    #Initialize optimizer
    optimizer=optim.SGD(model.parameters(),lr=0.001,momentum=0.9)
 
    #print model's state_dict
    print('Model.state_dict:')
    for param_tensor in model.state_dict():
        #打印 key value字典
        print(param_tensor,'\t',model.state_dict()[param_tensor].size())
 
    #print optimizer's state_dict
    print('Optimizer,s state_dict:')
    for var_name in optimizer.state_dict():
        print(var_name,'\t',optimizer.state_dict()[var_name])
 
 
 
if __name__=='__main__':
    main()
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
Model.state_dict:
conv1.weight 	 torch.Size([6, 3, 5, 5])
conv1.bias 	 torch.Size([6])
conv2.weight 	 torch.Size([16, 6, 5, 5])
conv2.bias 	 torch.Size([16])
fc1.weight 	 torch.Size([120, 400])
fc1.bias 	 torch.Size([120])
fc2.weight 	 torch.Size([84, 120])
fc2.bias 	 torch.Size([84])
fc3.weight 	 torch.Size([10, 84])
fc3.bias 	 torch.Size([10])
Optimizer`s state_dict:
state 	 {}
param_groups 	 [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'maximize': False, 'foreach': None, 'differentiable': False, 'params': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]}]

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15

load_state_dict

load_state_dict 是 PyTorch 中 torch.nn.Module 类的一个方法,用于加载模型的状态字典(state dictionary)。状态字典是一个包含模型所有参数的字典,通常通过调用 state_dict() 方法获得。

参数说明

load_state_dict 的主要参数是一个字典,该字典包含了要加载的参数。通常,这个字典是通过 torch.load() 从一个文件(通常是 .pth 或 .pt 文件)中加载的。

state_dict = torch.load('path_to_model.pth')
model.load_state_dict(state_dict)
1
2

注意事项

  1. 模型结构匹配:在调用 load_state_dict 之前,确保你已经定义了与保存状态字典时完全相同的模型结构。如果模型结构不匹配,你将无法加载状态字典,因为 PyTorch 无法将参数映射到正确的位置。
  2. 设备:加载的状态字典中的参数默认在 CPU 上。如果你想在 GPU 上使用这些参数,你需要先将模型移动到 GPU 上,然后再加载状态字典。
model = model.to('cuda')
model.load_state_dict(torch.load('path_to_model.pth'))
1
2

或者,你也可以在加载状态字典后移动模型:

model.load_state_dict(torch.load('path_to_model.pth', map_location=torch.device('cuda')))
model = model.to('cuda')
1
2
  1. 优化器状态:除了模型参数外,你可能还想加载优化器的状态。这可以通过类似的方式完成,但请注意,优化器的状态字典应该单独加载。
optimizer_state_dict = torch.load('path_to_optimizer.pth')
optimizer.load_state_dict(optimizer_state_dict)
1
2

model.train()

model.train()和model.eval()的区别主要在于Batch Normalization和Dropout两层。

当调用 model.train() 时,模型将处于训练模式,这通常意味着:

  1. Dropout:如果模型中包含 Dropout 层,那么在训练模式下,Dropout 层会在前向传播时随机地将一部分神经元的输出设置为零。这有助于防止模型过拟合。在评估模式下,Dropout 层不会进行任何操作,所有神经元的输出都会被保留。

  2. BatchNorm:Batch Normalization(BatchNorm)层在训练和评估模式下的行为也有所不同。在训练模式下,BatchNorm 会计算每个批次的均值和方差,并使用这些统计量来标准化输入。同时,它还会更新其内部运行均值和方差的估计值。在评估模式下,BatchNorm 会使用这些运行均值和方差来进行标准化,而不是每个批次的统计量。

  3. 其他层:有些自定义的层或模块也可能在训练和评估模式下有不同的行为。这取决于这些层或模块的实现。

使用示例

在训练循环的开始,你通常会调用 model.train() 来确保模型处于正确的模式:

model = MyModel()  # 假设 MyModel 是你的模型类
model.train()  # 将模型设置为训练模式

optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

for epoch in range(num_epochs):
    for inputs, targets in dataloader:
        optimizer.zero_grad()  # 清零梯度
        outputs = model(inputs)  # 前向传播
        loss = criterion(outputs, targets)  # 计算损失
        loss.backward()  # 反向传播
        optimizer.step()  # 更新权重
1
2
3
4
5
6
7
8
9
10
11
12

在评估或测试模型时,你应该使用 model.eval() 来确保模型不会应用 Dropout 或使用批次的统计量进行 BatchNorm:

model.eval()  # 将模型设置为评估模式

with torch.no_grad():  # 不计算梯度,节省内存和计算资源
    for inputs, targets in test_dataloader:
        outputs = model(inputs)  # 前向传播
        # ... 计算性能指标等 ...
1
2
3
4
5
6

请注意,在评估模式下,使用 with torch.no_grad(): 块是一个好习惯,因为它可以防止计算不必要的梯度,从而节省计算资源和内存。

model.eval()

model.eval() 是 PyTorch 框架中一个非常关键的方法,用于将模型设置为评估模式(evaluation mode)。在模型训练完成后,我们通常会对模型进行评估或测试以检查其性能。这时候,调用 model.eval() 是非常必要的,因为它会影响模型中的某些层(如 Dropout 和 Batch Normalization)的行为。

主要作用

  1. Dropout:在训练模式下,Dropout 层会随机丢弃一部分神经元的输出,这有助于防止模型过拟合。但在评估模式下,model.eval() 会关闭 Dropout 层的功能,确保所有神经元都参与前向传播,从而得到更稳定、更准确的输出。

  2. Batch Normalization:BatchNorm 层在训练和评估模式下的行为也不同。在训练模式下,BatchNorm 会使用当前批次的均值和方差来标准化输入。而在评估模式下,model.eval() 会指示 BatchNorm 使用训练过程中积累的运行均值(running mean)和运行方差(running variance)来进行标准化。这样做的好处是,模型在评估时对每个输入的标准化方式是一致的,不受批次大小的影响。

使用方法

在 PyTorch 中,使用 model.eval() 很简单。你只需在模型评估或测试之前调用它即可:

model = MyModel()  # 假设 MyModel 是你的模型类
model.load_state_dict(torch.load('path_to_model.pth'))  # 加载预训练模型参数

model.eval()  # 将模型设置为评估模式

with torch.no_grad():  # 不计算梯度,节省内存和计算资源
    for inputs, targets in test_dataloader:
        outputs = model(inputs)  # 进行前向传播
        # ... 计算性能指标等 ...
1
2
3
4
5
6
7
8
9

在上面的代码中,我们首先加载了预训练的模型参数,然后调用 model.eval() 将模型设置为评估模式。注意,我们还使用了 with torch.no_grad(): 块来确保在评估过程中不计算梯度,这有助于节省内存和计算资源。

注意事项

  1. 确保在正确的位置调用:确保在评估或测试开始前调用 model.eval(),并在训练开始前调用 model.train()。不要在训练循环内部多次调用 model.eval(),除非你有特定的需求。

  2. 梯度计算:调用 model.eval() 后,模型中的所有可学习参数的 requires_grad 属性将被设置为 False,这意味着在评估模式下不会计算梯度。这有助于加速推理过程。

  3. BatchNorm 和 Dropout 的固定:如前所述,model.eval() 会固定 BatchNorm 层和关闭 Dropout 层,确保在评估时模型的行为是一致的。

model.eval()和torch.no_grad()的区别

在PyTorch中进行validation/test时,会使用model.eval()切换到测试模式,在该模式下:

  1. 主要用于通知dropout层和BN层在training和validation/test模式间切换:

    • 在train模式下,dropout网络层会按照设定的参数p,设置保留激活单元的概率(保留概率=p)。BN层会继续计算数据的mean和var等参数并更新。
    • 在eval模式下,dropout层会让所有的激活单元都通过,而BN层会停止计算和更新mean和var,直接使用在训练阶段已经学出的mean和var值。
  2. eval模式不会影响各层的gradient计算行为,即gradient计算和存储与training模式一样,只是不进行反向传播(back probagation)。

而with torch.no_grad()则主要是用于停止autograd模块的工作,以起到加速和节省显存的作用。它的作用是将该with语句包裹起来的部分停止梯度的更新,从而节省了GPU算力和显存,但是并不会影响dropout和BN层的行为。

如果不在意显存大小和计算时间的话,仅仅使用model.eval()已足够得到正确的validation/test的结果;而with torch.no_grad()则是更进一步加速和节省gpu空间(因为不用计算和存储梯度),从而可以更快计算,也可以跑更大的batch来测试。

Reference

  • PyTorch:model.train()和model.eval()用法和区别,以及model.eval()和torch.no_grad()的区别 (opens new window)
#PyTorch
上次更新: 2025/06/25, 11:25:50
模型优化
模型示例

← 模型优化 模型示例→

最近更新
01
帮助信息查看
06-08
02
常用命令
06-08
03
学习资源
06-07
更多文章>
Theme by Vdoing | Copyright © 2022-2025 Geeks_Z | MIT License
京公网安备 11010802040735号 | 京ICP备2022029989号-1
  • 跟随系统
  • 浅色模式
  • 深色模式
  • 阅读模式