Geeks_Z の Blog Geeks_Z の Blog
首页
  • 学习笔记

    • 《HTML》
    • 《CSS》
    • 《JavaWeb》
    • 《Vue》
  • 后端文章

    • Linux
    • Maven
    • 汇编语言
    • 软件工程
    • 计算机网络概述
    • Conda
    • Pip
    • Shell
    • SSH
    • Mac快捷键
    • Zotero
  • 学习笔记

    • 《数据结构与算法》
    • 《算法设计与分析》
    • 《Spring》
    • 《SpringMVC》
    • 《SpringBoot》
    • 《SpringCloud》
    • 《Nginx》
  • 深度学习文章
  • 学习笔记

    • 《PyTorch》
    • 《ReinforementLearning》
    • 《MetaLearning》
  • 学习笔记

    • 《高等数学》
    • 《线性代数》
    • 《概率论与数理统计》
  • 增量学习
  • 哈希学习
GitHub (opens new window)

Geeks_Z

AI小学生
首页
  • 学习笔记

    • 《HTML》
    • 《CSS》
    • 《JavaWeb》
    • 《Vue》
  • 后端文章

    • Linux
    • Maven
    • 汇编语言
    • 软件工程
    • 计算机网络概述
    • Conda
    • Pip
    • Shell
    • SSH
    • Mac快捷键
    • Zotero
  • 学习笔记

    • 《数据结构与算法》
    • 《算法设计与分析》
    • 《Spring》
    • 《SpringMVC》
    • 《SpringBoot》
    • 《SpringCloud》
    • 《Nginx》
  • 深度学习文章
  • 学习笔记

    • 《PyTorch》
    • 《ReinforementLearning》
    • 《MetaLearning》
  • 学习笔记

    • 《高等数学》
    • 《线性代数》
    • 《概率论与数理统计》
  • 增量学习
  • 哈希学习
GitHub (opens new window)
  • Python

  • MLTutorials

  • 卷积神经网络

  • 循环神经网络

  • Transformer

  • VisionTransformer

  • 扩散模型

  • 计算机视觉

  • PTM

  • MoE

  • LoRAMoE

  • LongTailed

  • 多模态

  • 知识蒸馏

  • PEFT

  • 对比学习

  • 小样本学习

  • 迁移学习

  • 零样本学习

  • 集成学习

  • Mamba

  • PyTorch

    • PyTorch概述

    • Tensors

    • 数据处理

    • 模型

    • 训练

      • 基本配置
        • 常见的包
        • 超参数设置
        • 基本配置
          • 导入包和版本查询
          • 可复现性
          • 显卡设置
      • 损失函数
      • 优化器
      • PyTorch计算图
    • 并行计算

    • 可视化

    • 实战

    • timm

    • Pytorch Lightning

    • 数据增强

    • 面经与bug解决

    • 常用代码片段

    • Reference
  • CL

  • CIL

  • 小样本类增量学习FSCIL

  • UCIL

  • 多模态增量学习MMCL

  • LTCIL

  • DIL

  • 论文阅读与写作

  • 分布外检测

  • GPU

  • 深度学习调参指南

  • AINotes
  • PyTorch
  • 训练
Geeks_Z
2024-03-18
目录

基本配置

常见的包

import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.optim as optimizer
1
2
3
4
5
6

超参数设置

  • batch size
  • 初始学习率(初始)
  • 训练次数(max_epochs)
  • GPU 配置
batch_size = 16
# 批次的大小
lr = 1e-4
# 优化器的学习率
max_epochs = 100
1
2
3
4
5

除了直接将超参数设置在训练的代码里,我们也可以使用 yaml、json,dict 等文件来存储超参数,这样可以方便后续的调试和修改,这种方式也是常见的深度学习库(mmdetection,Paddledetection,detectron2)和一些 AI Lab 里面比较常见的一种参数设置方式。

基本配置

我们的数据和模型如果没有经过显式指明设备,默认会存储在 CPU 上,为了加速模型的训练,我们需要显式调用 GPU,一般情况下 GPU 的设置有两种常见的方式:

导入包和版本查询

import torch
import torch.nn as nn
import torchvision
print(torch.__version__)
print(torch.version.cuda)
print(torch.backends.cudnn.version())
print(torch.cuda.get_device_name(0))
1
2
3
4
5
6
7

可复现性

在硬件设备(CPU、GPU)不同时,完全的可复现性无法保证,即使随机种子相同。但是,在同一个设备上,应该保证可复现性。具体做法是,在程序开始的时候固定 torch 的随机种子,同时也把 numpy 的随机种子固定。

np.random.seed(0)
torch.manual_seed(0)
torch.cuda.manual_seed_all(0)

torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
1
2
3
4
5
6

显卡设置

如果只需要一张显卡

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
1
2

如果需要指定多张显卡,比如 0,1 号显卡

PyTorch会在第一次导入时缓存CUDA状态 如果设置CUDA_VISIBLE_DEVICES在导入torch之后,设置会失效

import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'  # 必须在导入torch前设置!
import torch
1
2
3

也可以在命令行运行代码时设置显卡

CUDA_VISIBLE_DEVICES=0,1 python train.py
1

清除显存

torch.cuda.empty_cache()
1

也可以使用在命令行重置 GPU 的指令

nvidia-smi --gpu-reset -i [gpu_id]
1
#PyTorch
上次更新: 2025/06/25, 11:25:50
模型示例
损失函数

← 模型示例 损失函数→

最近更新
01
帮助信息查看
06-08
02
常用命令
06-08
03
学习资源
06-07
更多文章>
Theme by Vdoing | Copyright © 2022-2025 Geeks_Z | MIT License
京公网安备 11010802040735号 | 京ICP备2022029989号-1
  • 跟随系统
  • 浅色模式
  • 深色模式
  • 阅读模式