nn.Module
state_dict
在PyTorch中,
state_dict
是一个字典对象,用于存储模型或优化器的参数。这个字典将每一层或优化器的参数映射到对应的张量。state_dict
的主要作用在于方便模型的保存和加载,以便在训练过程中恢复模型的状态或在其他任务中重用模型。
对于模型(如
torch.nn.Module
的实例),state_dict
包含模型的可学习参数(如权重和偏置)。只有包含可学习参数的层(如卷积层、线性层等)和已注册的缓冲区(如Batch Normalization层的运行均值和方差)才会在state_dict
中有对应的条目。这些参数是在模型训练过程中被优化器更新的。对于优化器(如
torch.optim
的实例),state_dict
包含优化器的状态信息以及使用的超参数(如学习率、动量等)。这些状态信息用于在训练过程中更新模型的参数。
通过调用torch.save(model.state_dict(), PATH)
可以将模型的state_dict
保存到磁盘上,其中PATH
是保存的路径。同样地,通过model.load_state_dict(torch.load(PATH))
可以加载之前保存的state_dict
到模型中。
#encoding:utf-8
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import numpy as mp
import matplotlib.pyplot as plt
import torch.nn.functional as F
#define model
class TheModelClass(nn.Module):
def __init__(self):
super(TheModelClass,self).__init__()
self.conv1=nn.Conv2d(3,6,5)
self.pool=nn.MaxPool2d(2,2)
self.conv2=nn.Conv2d(6,16,5)
self.fc1=nn.Linear(16*5*5,120)
self.fc2=nn.Linear(120,84)
self.fc3=nn.Linear(84,10)
def forward(self,x):
x=self.pool(F.relu(self.conv1(x)))
x=self.pool(F.relu(self.conv2(x)))
x=x.view(-1,16*5*5)
x=F.relu(self.fc1(x))
x=F.relu(self.fc2(x))
x=self.fc3(x)
return x
def main():
# Initialize model
model = TheModelClass()
#Initialize optimizer
optimizer=optim.SGD(model.parameters(),lr=0.001,momentum=0.9)
#print model's state_dict
print('Model.state_dict:')
for param_tensor in model.state_dict():
#打印 key value字典
print(param_tensor,'\t',model.state_dict()[param_tensor].size())
#print optimizer's state_dict
print('Optimizer,s state_dict:')
for var_name in optimizer.state_dict():
print(var_name,'\t',optimizer.state_dict()[var_name])
if __name__=='__main__':
main()
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
Model.state_dict:
conv1.weight torch.Size([6, 3, 5, 5])
conv1.bias torch.Size([6])
conv2.weight torch.Size([16, 6, 5, 5])
conv2.bias torch.Size([16])
fc1.weight torch.Size([120, 400])
fc1.bias torch.Size([120])
fc2.weight torch.Size([84, 120])
fc2.bias torch.Size([84])
fc3.weight torch.Size([10, 84])
fc3.bias torch.Size([10])
Optimizer`s state_dict:
state {}
param_groups [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'maximize': False, 'foreach': None, 'differentiable': False, 'params': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]}]
2
3
4
5
6
7
8
9
10
11
12
13
14
15
load_state_dict
load_state_dict
是 PyTorch 中torch.nn.Module
类的一个方法,用于加载模型的状态字典(state dictionary)。状态字典是一个包含模型所有参数的字典,通常通过调用state_dict()
方法获得。
参数说明
load_state_dict
的主要参数是一个字典,该字典包含了要加载的参数。通常,这个字典是通过 torch.load()
从一个文件(通常是 .pth
或 .pt
文件)中加载的。
state_dict = torch.load('path_to_model.pth')
model.load_state_dict(state_dict)
2
注意事项
- 模型结构匹配:在调用
load_state_dict
之前,确保你已经定义了与保存状态字典时完全相同的模型结构。如果模型结构不匹配,你将无法加载状态字典,因为 PyTorch 无法将参数映射到正确的位置。 - 设备:加载的状态字典中的参数默认在 CPU 上。如果你想在 GPU 上使用这些参数,你需要先将模型移动到 GPU 上,然后再加载状态字典。
model = model.to('cuda')
model.load_state_dict(torch.load('path_to_model.pth'))
2
或者,你也可以在加载状态字典后移动模型:
model.load_state_dict(torch.load('path_to_model.pth', map_location=torch.device('cuda')))
model = model.to('cuda')
2
- 优化器状态:除了模型参数外,你可能还想加载优化器的状态。这可以通过类似的方式完成,但请注意,优化器的状态字典应该单独加载。
optimizer_state_dict = torch.load('path_to_optimizer.pth')
optimizer.load_state_dict(optimizer_state_dict)
2
model.train()
model.train()和model.eval()的区别主要在于Batch Normalization和Dropout两层。

当调用 model.train()
时,模型将处于训练模式,这通常意味着:
Dropout:如果模型中包含 Dropout 层,那么在训练模式下,Dropout 层会在前向传播时随机地将一部分神经元的输出设置为零。这有助于防止模型过拟合。在评估模式下,Dropout 层不会进行任何操作,所有神经元的输出都会被保留。
BatchNorm:Batch Normalization(BatchNorm)层在训练和评估模式下的行为也有所不同。在训练模式下,BatchNorm 会计算每个批次的均值和方差,并使用这些统计量来标准化输入。同时,它还会更新其内部运行均值和方差的估计值。在评估模式下,BatchNorm 会使用这些运行均值和方差来进行标准化,而不是每个批次的统计量。
其他层:有些自定义的层或模块也可能在训练和评估模式下有不同的行为。这取决于这些层或模块的实现。
使用示例
在训练循环的开始,你通常会调用 model.train()
来确保模型处于正确的模式:
model = MyModel() # 假设 MyModel 是你的模型类
model.train() # 将模型设置为训练模式
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
for epoch in range(num_epochs):
for inputs, targets in dataloader:
optimizer.zero_grad() # 清零梯度
outputs = model(inputs) # 前向传播
loss = criterion(outputs, targets) # 计算损失
loss.backward() # 反向传播
optimizer.step() # 更新权重
2
3
4
5
6
7
8
9
10
11
12
在评估或测试模型时,你应该使用 model.eval()
来确保模型不会应用 Dropout 或使用批次的统计量进行 BatchNorm:
model.eval() # 将模型设置为评估模式
with torch.no_grad(): # 不计算梯度,节省内存和计算资源
for inputs, targets in test_dataloader:
outputs = model(inputs) # 前向传播
# ... 计算性能指标等 ...
2
3
4
5
6
请注意,在评估模式下,使用 with torch.no_grad():
块是一个好习惯,因为它可以防止计算不必要的梯度,从而节省计算资源和内存。
model.eval()

model.eval()
是 PyTorch 框架中一个非常关键的方法,用于将模型设置为评估模式(evaluation mode)。在模型训练完成后,我们通常会对模型进行评估或测试以检查其性能。这时候,调用 model.eval()
是非常必要的,因为它会影响模型中的某些层(如 Dropout 和 Batch Normalization)的行为。
主要作用
Dropout:在训练模式下,Dropout 层会随机丢弃一部分神经元的输出,这有助于防止模型过拟合。但在评估模式下,
model.eval()
会关闭 Dropout 层的功能,确保所有神经元都参与前向传播,从而得到更稳定、更准确的输出。Batch Normalization:BatchNorm 层在训练和评估模式下的行为也不同。在训练模式下,BatchNorm 会使用当前批次的均值和方差来标准化输入。而在评估模式下,
model.eval()
会指示 BatchNorm 使用训练过程中积累的运行均值(running mean)和运行方差(running variance)来进行标准化。这样做的好处是,模型在评估时对每个输入的标准化方式是一致的,不受批次大小的影响。
使用方法
在 PyTorch 中,使用 model.eval()
很简单。你只需在模型评估或测试之前调用它即可:
model = MyModel() # 假设 MyModel 是你的模型类
model.load_state_dict(torch.load('path_to_model.pth')) # 加载预训练模型参数
model.eval() # 将模型设置为评估模式
with torch.no_grad(): # 不计算梯度,节省内存和计算资源
for inputs, targets in test_dataloader:
outputs = model(inputs) # 进行前向传播
# ... 计算性能指标等 ...
2
3
4
5
6
7
8
9
在上面的代码中,我们首先加载了预训练的模型参数,然后调用 model.eval()
将模型设置为评估模式。注意,我们还使用了 with torch.no_grad():
块来确保在评估过程中不计算梯度,这有助于节省内存和计算资源。
注意事项
确保在正确的位置调用:确保在评估或测试开始前调用
model.eval()
,并在训练开始前调用model.train()
。不要在训练循环内部多次调用model.eval()
,除非你有特定的需求。梯度计算:调用
model.eval()
后,模型中的所有可学习参数的requires_grad
属性将被设置为False
,这意味着在评估模式下不会计算梯度。这有助于加速推理过程。BatchNorm 和 Dropout 的固定:如前所述,
model.eval()
会固定 BatchNorm 层和关闭 Dropout 层,确保在评估时模型的行为是一致的。
model.eval()和torch.no_grad()的区别
在PyTorch中进行validation/test时,会使用model.eval()
切换到测试模式,在该模式下:
主要用于通知dropout层和BN层在training和validation/test模式间切换:
- 在train模式下,dropout网络层会按照设定的参数p,设置保留激活单元的概率(保留概率=p)。BN层会继续计算数据的mean和var等参数并更新。
- 在eval模式下,dropout层会让所有的激活单元都通过,而BN层会停止计算和更新mean和var,直接使用在训练阶段已经学出的mean和var值。
eval模式不会影响各层的gradient计算行为,即gradient计算和存储与training模式一样,只是不进行反向传播(back probagation)。
而with torch.no_grad()
则主要是用于停止autograd模块的工作,以起到加速和节省显存的作用。它的作用是将该with语句包裹起来的部分停止梯度的更新,从而节省了GPU算力和显存,但是并不会影响dropout和BN层的行为。
如果不在意显存大小和计算时间的话,仅仅使用model.eval()
已足够得到正确的validation/test的结果;而with torch.no_grad()
则是更进一步加速和节省gpu空间(因为不用计算和存储梯度),从而可以更快计算,也可以跑更大的batch来测试。