Geeks_Z の Blog Geeks_Z の Blog
首页
  • 学习笔记

    • 《HTML》
    • 《CSS》
    • 《JavaWeb》
    • 《Vue》
  • 后端文章

    • Linux
    • Maven
    • 汇编语言
    • 软件工程
    • 计算机网络概述
    • Conda
    • Pip
    • Shell
    • SSH
    • Mac快捷键
    • Zotero
  • 学习笔记

    • 《数据结构与算法》
    • 《算法设计与分析》
    • 《Spring》
    • 《SpringMVC》
    • 《SpringBoot》
    • 《SpringCloud》
    • 《Nginx》
  • 深度学习文章
  • 学习笔记

    • 《PyTorch》
    • 《ReinforementLearning》
    • 《MetaLearning》
  • 学习笔记

    • 《高等数学》
    • 《线性代数》
    • 《概率论与数理统计》
  • 增量学习
  • 哈希学习
GitHub (opens new window)

Geeks_Z

AI小学生
首页
  • 学习笔记

    • 《HTML》
    • 《CSS》
    • 《JavaWeb》
    • 《Vue》
  • 后端文章

    • Linux
    • Maven
    • 汇编语言
    • 软件工程
    • 计算机网络概述
    • Conda
    • Pip
    • Shell
    • SSH
    • Mac快捷键
    • Zotero
  • 学习笔记

    • 《数据结构与算法》
    • 《算法设计与分析》
    • 《Spring》
    • 《SpringMVC》
    • 《SpringBoot》
    • 《SpringCloud》
    • 《Nginx》
  • 深度学习文章
  • 学习笔记

    • 《PyTorch》
    • 《ReinforementLearning》
    • 《MetaLearning》
  • 学习笔记

    • 《高等数学》
    • 《线性代数》
    • 《概率论与数理统计》
  • 增量学习
  • 哈希学习
GitHub (opens new window)
  • Python

  • MLTutorials

  • 卷积神经网络

  • 循环神经网络

  • Transformer

  • VisionTransformer

  • 扩散模型

  • 计算机视觉

  • PTM

  • MoE

  • LoRAMoE

  • LongTailed

  • 多模态

  • 知识蒸馏

  • PEFT

  • 对比学习

  • 小样本学习

  • 迁移学习

  • 零样本学习

  • 集成学习

  • Mamba

  • PyTorch

    • PyTorch概述

    • Tensors

    • 数据处理

    • 模型

      • 模型构建
      • 模型容器
      • 模型参数
      • 权值初始化
      • 模型保存与加载
        • 模型存储内容
        • torch.save
          • 主要参数
          • 保存整个 Module
          • 只保存模型的参数
        • torch.load
          • 主要参数
          • 加载整个 Module
          • 只加载模型的参数
          • 注意事项
        • 模型的断点续训练
        • 单卡和多卡模型存储的区别
        • 单卡/多卡情况分类讨论
        • 其他参数的保存和读取
        • Reference
      • 模型修改
      • 模型优化
      • nn.Module
      • 模型示例
    • 训练

    • 并行计算

    • 可视化

    • 实战

    • timm

    • Pytorch Lightning

    • 数据增强

    • 面经与bug解决

    • 常用代码片段

    • Reference
  • CL

  • CIL

  • 小样本类增量学习FSCIL

  • UCIL

  • 多模态增量学习MMCL

  • LTCIL

  • DIL

  • 论文阅读与写作

  • 分布外检测

  • GPU

  • 深度学习调参指南

  • AINotes
  • PyTorch
  • 模型
Geeks_Z
2023-01-26
目录

模型保存与加载

模型存储内容

一个PyTorch模型主要包含两个部分:模型结构和权重。其中模型是继承nn.Module的类,权重的数据结构是一个字典(key是层名,value是权重向量)。存储也由此分为两种形式:存储整个模型(包括结构和权重),和只存储模型权重。

from torchvision import models
model = models.resnet152(pretrained=True)
save_dir = './resnet152.pth'

# 保存整个模型
torch.save(model, save_dir)
# 保存模型权重
torch.save(model.state_dict, save_dir)
1
2
3
4
5
6
7
8

对于PyTorch而言,pt, pth和pkl三种数据格式均支持模型权重和整个模型的存储,因此使用上没有差别。

torch.save

torch.save 是 PyTorch 中用于保存模型、张量、字典或其他 Python 对象到磁盘的函数。

torch.save(obj, f, pickle_module, pickle_protocol=2, _use_new_zipfile_serialization=False)
1

主要参数

  1. obj:要保存的对象。这可以是一个模型 (nn.Module 的实例)、张量 (torch.Tensor)、字典或任何其他 Python 对象。

  2. f:保存对象的文件路径或文件对象(类似一个打开的文件句柄)。可以是一个字符串,表示文件路径,也可以是一个已打开的文件对象(如通过 open 函数打开的文件)。

  3. pickle_module:用于序列化的模块。默认为 Python 的 pickle 模块。但在某些情况下,如使用 PyTorch 的特定环境(如某些受限的或旧版本的 Python 环境),你可能需要使用 torch.jit.pickle 作为替代。

  4. pickle_protocol:使用的 pickle 协议版本。默认为 2,但在某些情况下,你可能需要设置为 4 以获得更好的兼容性或性能。

  5. _use_new_zipfile_serialization:一个布尔值,决定是否使用新的 zipfile 序列化格式。默认是 False。这个参数主要是为了解决在某些环境中的兼容性问题。

其中模型保存还有两种方式:

保存整个 Module

这种方法比较耗时,保存的文件大

torch.save(net, path)
1

只保存模型的参数

推荐这种方法,运行比较快,保存的文件比较小

state_sict = net.state_dict()
torch.save(state_sict, path)
1
2
  1. 可以使用model.eval()将 dropout 和 batch normalization 层设置成 evaluation 模式。
  2. load_state_dict()函数需要一个 dict 类型的输入,而不是保存模型的 PATH。所以这样 model.load_state_dict(PATH)是错误的,而应该model.load_state_dict(torch.load(PATH))。
  3. 如果你想保存验证集上表现最好的模型,那么这样best_model_state=model.state_dict()是错误的。因为这属于浅复制,也就是说此时这个 best_model_state 会随着后续的训练过程而不断被更新,最后保存的其实是个 overfit 的模型。所以正确的做法应该是best_model_state=deepcopy(model.state_dict())。

下面是保存 LeNet 的例子。在网络初始化中,把权值都设置为 2020,然后保存模型。

import torch
import numpy as np
import torch.nn as nn
from common_tools import set_seed

class LeNet2(nn.Module):
    def __init__(self, classes):
        super(LeNet2, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 6, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(6, 16, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )
        self.classifier = nn.Sequential(
            nn.Linear(16*5*5, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, classes)
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size()[0], -1)
        x = self.classifier(x)
        return x

    def initialize(self):
        for p in self.parameters():
            p.data.fill_(2020)

net = LeNet2(classes=2019)

# "训练"
print("训练前: ", net.features[0].weight[0, ...])
net.initialize()
print("训练后: ", net.features[0].weight[0, ...])

path_model = "./model.pkl"
path_state_dict = "./model_state_dict.pkl"

# 保存整个模型
torch.save(net, path_model)

# 保存模型参数
net_state_dict = net.state_dict()
torch.save(net_state_dict, path_state_dict)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50

运行完之后,文件夹中生成了``model.pkl和model_state_dict.pkl`,分别保存了整个网络和网络的参数

torch.load

torch.load 是 PyTorch 中用于从磁盘加载之前通过 torch.save 保存的对象(如模型、张量或字典)的函数。

torch.load(f, map_location=None, pickle_module, **pickle_load_args)
1

主要参数

  1. f:要加载的文件的路径或文件对象。这通常是一个字符串,表示文件的路径,或者是一个已经打开的文件对象。

  2. map_location:一个可选参数,用于指定如何重新映射存储位置。这个参数在加载模型到不同的设备(如CPU或GPU)时非常有用。它可以是一个设备对象(如 torch.device('cpu'))或包含设备标签的字符串。如果提供了这个参数,torch.load 会将加载的数据映射到指定的设备上。默认情况下,数据会加载到它原来保存的设备上。

  3. pickle_module:用于反序列化的模块。默认为 Python 的 pickle 模块,但在某些情况下,你可能需要使用 torch.jit.pickle。

  4. pickle_load_args:可选参数,允许传入额外的参数到 pickle_module.load() 方法中。

加载整个 Module

如果保存的时候,保存的是整个模型,那么加载时就加载整个模型。这种方法不需要事先创建一个模型对象,也不用知道模型的结构,代码如下:

path_model = "./model.pkl"
net_load = torch.load(path_model)

print(net_load)
1
2
3
4

输出如下:

LeNet2(
  (features): Sequential(
    (0): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
    (4): ReLU()
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (classifier): Sequential(
    (0): Linear(in_features=400, out_features=120, bias=True)
    (1): ReLU()
    (2): Linear(in_features=120, out_features=84, bias=True)
    (3): ReLU()
    (4): Linear(in_features=84, out_features=2019, bias=True)
  )
)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17

只加载模型的参数

如果保存的时候,保存的是模型的参数,那么加载时就参数。这种方法需要事先创建一个模型对象,再使用模型的load_state_dict()方法把参数加载到模型中,代码如下:

path_state_dict = "./model_state_dict.pkl"
state_dict_load = torch.load(path_state_dict)
net_new = LeNet2(classes=2019)

print("加载前: ", net_new.features[0].weight[0, ...])
net_new.load_state_dict(state_dict_load)
print("加载后: ", net_new.features[0].weight[0, ...])
1
2
3
4
5
6
7

输出如下:

加载前:  tensor([[[ 0.0775,  0.0374,  0.0163,  0.0196, -0.0884],
         [ 0.0293, -0.1051, -0.0362,  0.1122, -0.0616],
         [ 0.0083,  0.0274,  0.0158,  0.0301,  0.0937],
         [-0.0459, -0.1062,  0.0510, -0.0058,  0.1046],
         [-0.0672, -0.0204,  0.0134,  0.0594,  0.0421]],

        [[ 0.0058, -0.0435, -0.0550,  0.0591, -0.1067],
         [ 0.0929,  0.0202, -0.0027,  0.0264,  0.0409],
         [ 0.0038, -0.0219, -0.0522, -0.0065,  0.0717],
         [-0.0300, -0.0819, -0.0238, -0.0132, -0.0364],
         [ 0.0258, -0.0238, -0.0680, -0.0172,  0.0902]],

        [[-0.1087,  0.0948, -0.0848,  0.1148, -0.0212],
         [-0.0634,  0.0479,  0.0064, -0.0287,  0.0732],
         [-0.1080,  0.0522, -0.0891, -0.1137,  0.0838],
         [ 0.0740,  0.0965,  0.0893, -0.1075,  0.0277],
         [-0.0060, -0.0713,  0.0996,  0.0865, -0.0181]]],
       grad_fn=<SelectBackward0>)
加载后:  tensor([[[2020., 2020., 2020., 2020., 2020.],
         [2020., 2020., 2020., 2020., 2020.],
         [2020., 2020., 2020., 2020., 2020.],
         [2020., 2020., 2020., 2020., 2020.],
         [2020., 2020., 2020., 2020., 2020.]],

        [[2020., 2020., 2020., 2020., 2020.],
         [2020., 2020., 2020., 2020., 2020.],
         [2020., 2020., 2020., 2020., 2020.],
         [2020., 2020., 2020., 2020., 2020.],
         [2020., 2020., 2020., 2020., 2020.]],

        [[2020., 2020., 2020., 2020., 2020.],
         [2020., 2020., 2020., 2020., 2020.],
         [2020., 2020., 2020., 2020., 2020.],
         [2020., 2020., 2020., 2020., 2020.],
         [2020., 2020., 2020., 2020., 2020.]]], grad_fn=<SelectBackward0>)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35

注意事项

  • 如果你的模型是在GPU上训练的,并且你现在在CPU上加载它,或者反之,你可能需要在加载权重之前将模型移动到相应的设备上。例如,如果模型应该在GPU上运行,你可以使用model.to('cuda')来移动模型。
  • 当加载模型时,你可能还想加载优化器的状态(如果你打算继续训练模型的话)。这可以通过类似的方式完成:optimizer_state_dict = torch.load('optimizer_state.pth') 和 optimizer.load_state_dict(optimizer_state_dict)。
  • 如果你的模型是从旧版本的PyTorch保存的,而你现在正在使用新版本的PyTorch,通常不会有问题,因为PyTorch努力保持向后兼容性。但是,如果版本差异非常大,或者使用了特定的特性,则可能会出现问题。在这种情况下,尝试在保存模型时使用与加载模型时相同的PyTorch版本。
  • 模型在内存中是以对象的逻辑结构保存的,但是在硬盘中是以二进制流的方式保存的。
  • 序列化是指将内存中的数据以二进制序列的方式保存到硬盘中。PyTorch 的模型保存就是序列化。
  • 反序列化是指将硬盘中的二进制序列加载到内存中,得到模型的对象。PyTorch 的模型加载就是反序列化。

模型的断点续训练

在训练过程中,可能由于某种意外原因如断点等导致训练终止,这时需要重新开始训练。断点续练是在训练过程中每隔一定次数的 epoch 就保存模型的参数和优化器的参数,这样如果意外终止训练了,下次就可以重新加载最新的模型参数和优化器的参数,在这个基础上继续训练。

下面的代码中,每隔 5 个 epoch 就保存一次,保存的是一个 dict,包括模型参数、优化器的参数、epoch。然后在 epoch 大于 5 时,就break模拟训练意外终止。关键代码如下:

if (epoch+1) % checkpoint_interval == 0:

        checkpoint = {"model_state_dict": net.state_dict(),
                      "optimizer_state_dict": optimizer.state_dict(),
                      "epoch": epoch}
        path_checkpoint = "./checkpoint_{}_epoch.pkl".format(epoch)
        torch.save(checkpoint, path_checkpoint)
1
2
3
4
5
6
7

在 epoch 大于 5 时,就break模拟训练意外终止

if epoch > 5:
        print("训练意外中断...")
        break
1
2
3

断点续训练的恢复代码如下:

path_checkpoint = "./checkpoint_4_epoch.pkl"
checkpoint = torch.load(path_checkpoint)

net.load_state_dict(checkpoint['model_state_dict'])

optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

start_epoch = checkpoint['epoch']

scheduler.last_epoch = start_epoch
1
2
3
4
5
6
7
8
9
10

需要注意的是,还要设置scheduler.last_epoch参数为保存的 epoch。模型训练的起始 epoch 也要修改为保存的 epoch。

单卡和多卡模型存储的区别

PyTorch中将模型和数据放到GPU上有两种方式——.cuda()和.to(device),后续内容针对前一种方式进行讨论。如果要使用多卡训练的话,需要对模型使用torch.nn.DataParallel。示例如下:

import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0' # 如果是多卡改成类似0,1,2
model = model.cuda()  # 单卡
model = torch.nn.DataParallel(model).cuda()  # 多卡
1
2
3
4

之后我们把model对应的layer名称打印出来看一下,可以观察到差别在于多卡并行的模型每层的名称前多了一个“module”。

  • 单卡模型的层名:
  • 多卡模型的层名:

这种模型表示的不同可能会导致模型保存和加载过程中需要处理一些矛盾点,下面对各种可能的情况做分类讨论。

单卡/多卡情况分类讨论

由于训练和测试所使用的硬件条件不同,在模型的保存和加载过程中可能因为单GPU和多GPU环境的不同带来模型不匹配等问题。这里对PyTorch框架下单卡/多卡下模型的保存和加载问题进行排列组合,样例模型是torchvision中预训练模型resnet152。

  • 单卡保存+单卡加载

在使用os.envision命令指定使用的GPU后,即可进行模型保存和读取操作。注意这里即便保存和读取时使用的GPU不同也无妨。

import os
import torch
from torchvision import models

os.environ['CUDA_VISIBLE_DEVICES'] = '0'   #这里替换成希望使用的GPU编号
model = models.resnet152(pretrained=True)
model.cuda()

save_dir = 'resnet152.pt'   #保存路径

# 保存+读取整个模型
torch.save(model, save_dir)
loaded_model = torch.load(save_dir)
loaded_model.cuda()

# 保存+读取模型权重
torch.save(model.state_dict(), save_dir)
loaded_model = models.resnet152()   #注意这里需要对模型结构有定义
loaded_model.load_state_dict(torch.load(save_dir))
loaded_model.cuda()
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
  • 单卡保存+多卡加载

这种情况的处理比较简单,读取单卡保存的模型后,使用nn.DataParallel函数进行分布式训练设置即可:

import os
import torch
from torchvision import models

os.environ['CUDA_VISIBLE_DEVICES'] = '0'   #这里替换成希望使用的GPU编号
model = models.resnet152(pretrained=True)
model.cuda()

# 保存+读取整个模型
torch.save(model, save_dir)

os.environ['CUDA_VISIBLE_DEVICES'] = '1,2'   #这里替换成希望使用的GPU编号
loaded_model = torch.load(save_dir)
loaded_model = nn.DataParallel(loaded_model).cuda()

# 保存+读取模型权重
torch.save(model.state_dict(), save_dir)

os.environ['CUDA_VISIBLE_DEVICES'] = '1,2'   #这里替换成希望使用的GPU编号
loaded_model = models.resnet152()   #注意这里需要对模型结构有定义
loaded_model.load_state_dict(torch.load(save_dir))
loaded_model = nn.DataParallel(loaded_model).cuda()
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
  • 多卡保存+单卡加载

这种情况下的核心问题是:如何去掉权重字典键名中的"module",以保证模型的统一性。

对于加载整个模型,直接提取模型的module属性即可:

import os
import torch
from torchvision import models

os.environ['CUDA_VISIBLE_DEVICES'] = '1,2'   #这里替换成希望使用的GPU编号

model = models.resnet152(pretrained=True)
model = nn.DataParallel(model).cuda()

# 保存+读取整个模型
torch.save(model, save_dir)

os.environ['CUDA_VISIBLE_DEVICES'] = '0'   #这里替换成希望使用的GPU编号
loaded_model = torch.load(save_dir).module
1
2
3
4
5
6
7
8
9
10
11
12
13
14

对于加载模型权重,有以下几种思路: 保存模型时保存模型的module属性对应的权重

import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2'   #这里替换成希望使用的GPU编号
import torch
from torchvision import models

save_dir = 'resnet152.pth'   #保存路径
model = models.resnet152(pretrained=True)
model = nn.DataParallel(model).cuda()

# 保存权重
torch.save(model.module.state_dict(), save_dir)
1
2
3
4
5
6
7
8
9
10
11

这样保存下来的模型参数就和单卡保存的模型参数一样了,可以直接加载。也是比较推荐的一种方法。 去除字典里的module麻烦,往model里添加module简单

import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2'   #这里替换成希望使用的GPU编号
import torch
from torchvision import models

model = models.resnet152(pretrained=True)
model = nn.DataParallel(model).cuda()

# 保存+读取模型权重
torch.save(model.state_dict(), save_dir)

os.environ['CUDA_VISIBLE_DEVICES'] = '0'   #这里替换成希望使用的GPU编号
loaded_model = models.resnet152()   #注意这里需要对模型结构有定义
loaded_model.load_state_dict(torch.load(save_dir))
loaded_model = nn.DataParallel(loaded_model).cuda()
loaded_model.state_dict = loaded_dict
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16

这样即便是单卡,也可以开始训练了(相当于分布到单卡上)

遍历字典去除module

from collections import OrderedDict
os.environ['CUDA_VISIBLE_DEVICES'] = '0'   #这里替换成希望使用的GPU编号

loaded_dict = torch.load(save_dir)

new_state_dict = OrderedDict()
for k, v in loaded_dict.items():
    name = k[7:] # module字段在最前面,从第7个字符开始就可以去掉module
    new_state_dict[name] = v #新字典的key值对应的value一一对应

loaded_model = models.resnet152()   #注意这里需要对模型结构有定义
loaded_model.state_dict = new_state_dict
loaded_model = loaded_model.cuda()
1
2
3
4
5
6
7
8
9
10
11
12
13

使用replace操作去除module

loaded_model = models.resnet152()    
loaded_dict = torch.load(save_dir)
loaded_model.load_state_dict({k.replace('module.', ''): v for k, v in loaded_dict.items()})
1
2
3
  • 多卡保存+多卡加载

由于是模型保存和加载都使用的是多卡,因此不存在模型层名前缀不同的问题。但多卡状态下存在一个device(使用的GPU)匹配的问题,即保存整个模型时会同时保存所使用的GPU id等信息,读取时若这些信息和当前使用的GPU信息不符则可能会报错或者程序不按预定状态运行。具体表现为以下两点:

读取整个模型再使用nn.DataParallel进行分布式训练设置

这种情况很可能会造成保存的整个模型中GPU id和读取环境下设置的GPU id不符,训练时数据所在device和模型所在device不一致而报错。

读取整个模型而不使用nn.DataParallel进行分布式训练设置

这种情况可能不会报错,测试中发现程序会自动使用设备的前n个GPU进行训练(n是保存的模型使用的GPU个数)。此时如果指定的GPU个数少于n,则会报错。在这种情况下,只有保存模型时环境的device id和读取模型时环境的device id一致,程序才会按照预期在指定的GPU上进行分布式训练。

相比之下,读取模型权重,之后再使用nn.DataParallel进行分布式训练设置则没有问题。因此多卡模式下建议使用权重的方式存储和读取模型:

import os
import torch
from torchvision import models

os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2'   #这里替换成希望使用的GPU编号

model = models.resnet152(pretrained=True)
model = nn.DataParallel(model).cuda()

# 保存+读取模型权重,强烈建议!!
torch.save(model.state_dict(), save_dir)
loaded_model = models.resnet152()   #注意这里需要对模型结构有定义
loaded_model.load_state_dict(torch.load(save_dir)))
loaded_model = nn.DataParallel(loaded_model).cuda()
1
2
3
4
5
6
7
8
9
10
11
12
13
14

如果只有保存的整个模型,也可以采用提取权重的方式构建新的模型:

# 读取整个模型
loaded_whole_model = torch.load(save_dir)
loaded_model = models.resnet152()   #注意这里需要对模型结构有定义
loaded_model.state_dict = loaded_whole_model.state_dict
loaded_model = nn.DataParallel(loaded_model).cuda()
1
2
3
4
5

另外,上面所有对于loaded_model修改权重字典的形式都是通过赋值来实现的,在PyTorch中还可以通过"load_state_dict"函数来实现。因此在上面的所有示例中,我们使用了两种实现方式。

loaded_model.load_state_dict(loaded_dict)
1

其他参数的保存和读取

在深度学习项目里,有时候我们不仅仅需要保存模型的权重,还需要保存一些其他的参数,比如训练的epoch数、训练的loss,优化器的参数,动态调整学习策略的参数等等。这些参数可以通过字典的形式保存在一个文件里,然后在读取模型时一起读取。这里我们以下方代码为例:

torch.save({
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'lr_scheduler': lr_scheduler.state_dict(),
        'epoch': epoch,
        'args': args,
    }, checkpoint_path)
1
2
3
4
5
6
7

这些参数的读取方式也是类似的:

checkpoint = torch.load(checkpoint_path)
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
epoch = checkpoint['epoch']
args = checkpoint['args']
1
2
3
4
5
6

Reference

  1. pytorch 中pkl和pth的区别? (opens new window)
  2. What is the difference between .pt, .pth and .pwf extentions in PyTorch? (opens new window)
#PyTorch
上次更新: 2025/06/25, 11:25:50
权值初始化
模型修改

← 权值初始化 模型修改→

最近更新
01
帮助信息查看
06-08
02
常用命令
06-08
03
学习资源
06-07
更多文章>
Theme by Vdoing | Copyright © 2022-2025 Geeks_Z | MIT License
京公网安备 11010802040735号 | 京ICP备2022029989号-1
  • 跟随系统
  • 浅色模式
  • 深色模式
  • 阅读模式