Geeks_Z の Blog Geeks_Z の Blog
首页
  • 学习笔记

    • 《HTML》
    • 《CSS》
    • 《JavaWeb》
    • 《Vue》
  • 后端文章

    • Linux
    • Maven
    • 汇编语言
    • 软件工程
    • 计算机网络概述
    • Conda
    • Pip
    • Shell
    • SSH
    • Mac快捷键
    • Zotero
  • 学习笔记

    • 《数据结构与算法》
    • 《算法设计与分析》
    • 《Spring》
    • 《SpringMVC》
    • 《SpringBoot》
    • 《SpringCloud》
    • 《Nginx》
  • 深度学习文章
  • 学习笔记

    • 《PyTorch》
    • 《ReinforementLearning》
    • 《MetaLearning》
  • 学习笔记

    • 《高等数学》
    • 《线性代数》
    • 《概率论与数理统计》
  • 增量学习
  • 哈希学习
GitHub (opens new window)

Geeks_Z

AI小学生
首页
  • 学习笔记

    • 《HTML》
    • 《CSS》
    • 《JavaWeb》
    • 《Vue》
  • 后端文章

    • Linux
    • Maven
    • 汇编语言
    • 软件工程
    • 计算机网络概述
    • Conda
    • Pip
    • Shell
    • SSH
    • Mac快捷键
    • Zotero
  • 学习笔记

    • 《数据结构与算法》
    • 《算法设计与分析》
    • 《Spring》
    • 《SpringMVC》
    • 《SpringBoot》
    • 《SpringCloud》
    • 《Nginx》
  • 深度学习文章
  • 学习笔记

    • 《PyTorch》
    • 《ReinforementLearning》
    • 《MetaLearning》
  • 学习笔记

    • 《高等数学》
    • 《线性代数》
    • 《概率论与数理统计》
  • 增量学习
  • 哈希学习
GitHub (opens new window)
  • Python

  • MLTutorials

  • 卷积神经网络

  • 循环神经网络

  • Transformer

  • VisionTransformer

  • 扩散模型

  • 计算机视觉

  • PTM

  • MoE

  • LoRAMoE

  • LongTailed

  • 多模态

  • 知识蒸馏

  • PEFT

  • 对比学习

  • 小样本学习

  • 迁移学习

  • 零样本学习

  • 集成学习

  • Mamba

  • PyTorch

    • PyTorch概述

    • Tensors

    • 数据处理

    • 模型

    • 训练

    • 并行计算

    • 可视化

      • 可视化网络结构
      • CNN卷积层可视化
      • TensorBoard
      • 使用wandb可视化训练过程
        • 使用 wandb 可视化训练过程
        • wandb 的安装
        • wandb 的使用
        • demo 演示
      • wandb相关参数解释
    • 实战

    • timm

    • Pytorch Lightning

    • 数据增强

    • 面经与bug解决

    • 常用代码片段

    • Reference
  • CL

  • CIL

  • 小样本类增量学习FSCIL

  • UCIL

  • 多模态增量学习MMCL

  • LTCIL

  • DIL

  • 论文阅读与写作

  • 分布外检测

  • GPU

  • 深度学习调参指南

  • AINotes
  • PyTorch
  • 可视化
Geeks_Z
2024-04-09
目录

使用wandb可视化训练过程

使用 wandb 可视化训练过程

Tensorboard 对数据的保存仅限于本地,也很难分析超参数不同对实验的影响。wandb 的出现很好的解决了这些问题。wandb 是 Weights & Biases 的缩写,它能够自动记录模型训练过程中的超参数和输出指标,然后可视化和比较结果,并快速与其他人共享结果。目前它能够和 Jupyter、TensorFlow、Pytorch、Keras、Scikit、fast.ai、LightGBM、XGBoost 一起结合使用。

wandb 的安装

wandb 的安装非常简单,我们只需要使用 pip 安装即可。

pip install wandb
1

安装完成后,我们需要在官网 (opens new window)注册一个账号并复制下自己的 API keys,然后在本地使用下面的命令登录。

wandb login
1

这时,我们会看到下面的界面,只需要粘贴你的 API keys 即可。

wandb 的使用

wandb 的使用也非常简单,只需要在代码中添加几行代码即可。

import wandb
wandb.init(project='my-project', entity='my-name')
1
2

这里的 project 和 entity 是你在 wandb 上创建的项目名称和用户名,如果你还没有创建项目,可以参考官方文档 (opens new window)。

demo 演示

下面我们使用一个 CIFAR10 的图像分类 demo 来演示 wandb 的使用。


import random  # to set the python random seed
import numpy  # to set the numpy random seed
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchvision.models import resnet18
import warnings
warnings.filterwarnings('ignore')
1
2
3
4
5
6
7
8
9
10
11
12

使用 wandb 的第一步是初始化 wandb,这里我们使用 wandb.init()函数来初始化 wandb,其中 project 是你在 wandb 上创建的项目名称,name 是你的实验名称。

# 初始化wandb
import wandb
wandb.init(project="thorough-pytorch",
           name="wandb_demo",)
1
2
3
4

使用 wandb 的第二步是设置超参数,这里我们使用 wandb.config 来设置超参数,这样我们就可以在 wandb 的界面上看到超参数的变化。wandb.config 的使用方法和字典类似,我们可以使用 config.key 的方式来设置超参数。

# 超参数设置
config = wandb.config  # config的初始化
config.batch_size = 64
config.test_batch_size = 10
config.epochs = 5
config.lr = 0.01
config.momentum = 0.1
config.use_cuda = True
config.seed = 2043
config.log_interval = 10

# 设置随机数
def set_seed(seed):
    random.seed(config.seed)
    torch.manual_seed(config.seed)
    numpy.random.seed(config.seed)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17

第三步是构建训练和测试的 pipeline,这里我们使用 pytorch 的 CIFAR10 数据集和 resnet18 来构建训练和测试的 pipeline。

def train(model, device, train_loader, optimizer):
    model.train()

    for batch_id, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        criterion = nn.CrossEntropyLoss()
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

# wandb.log用来记录一些日志(accuracy,loss and epoch), 便于随时查看网路的性能
def test(model, device, test_loader, classes):
    model.eval()
    test_loss = 0
    correct = 0
    example_images = []

    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            criterion = nn.CrossEntropyLoss()
            test_loss += criterion(output, target).item()
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()
            example_images.append(wandb.Image(
                data[0], caption="Pred:{} Truth:{}".format(classes[pred[0].item()], classes[target[0]])))

   # 使用wandb.log 记录你想记录的指标
    wandb.log({
        "Examples": example_images,
        "Test Accuracy": 100. * correct / len(test_loader.dataset),
        "Test Loss": test_loss
    })

wandb.watch_called = False


def main():
    use_cuda = config.use_cuda and torch.cuda.is_available()
    device = torch.device("cuda:0" if use_cuda else "cpu")
    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

    # 设置随机数
    set_seed(config.seed)
    torch.backends.cudnn.deterministic = True

    # 数据预处理
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    # 加载数据
    train_loader = DataLoader(datasets.CIFAR10(
        root='dataset',
        train=True,
        download=True,
        transform=transform
    ), batch_size=config.batch_size, shuffle=True, **kwargs)

    test_loader = DataLoader(datasets.CIFAR10(
        root='dataset',
        train=False,
        download=True,
        transform=transform
    ), batch_size=config.batch_size, shuffle=False, **kwargs)

    classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

    model = resnet18(pretrained=True).to(device)
    optimizer = optim.SGD(model.parameters(), lr=config.lr, momentum=config.momentum)

    wandb.watch(model, log="all")
    for epoch in range(1, config.epochs + 1):
        train(model, device, train_loader, optimizer)
        test(model, device, test_loader, classes)

    # 本地和云端模型保存
    torch.save(model.state_dict(), 'model.pth')
    wandb.save('model.pth')


if __name__ == '__main__':
    main()

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88

当我们运行完上面的代码后,我们就可以在 wandb 的界面上看到我们的训练结果了和系统的性能指标。同时,我们还可以在 setting 里面设置训练完给我们发送邮件,这样我们就可以在训练完之后及时的查看训练结果了。

我们可以发现,使用 wandb 可以很方便的记录我们的训练结果,除此之外,wandb 还为我们提供了很多的功能,比如:模型的超参数搜索,模型的版本控制,模型的部署等等。这些功能都可以帮助我们更好的管理我们的模型,更好的进行模型的迭代和优化。

#PyTorch
上次更新: 2025/06/25, 11:25:50
TensorBoard
wandb相关参数解释

← TensorBoard wandb相关参数解释→

最近更新
01
帮助信息查看
06-08
02
常用命令
06-08
03
学习资源
06-07
更多文章>
Theme by Vdoing | Copyright © 2022-2025 Geeks_Z | MIT License
京公网安备 11010802040735号 | 京ICP备2022029989号-1
  • 跟随系统
  • 浅色模式
  • 深色模式
  • 阅读模式