Geeks_Z の Blog Geeks_Z の Blog
首页
  • 学习笔记

    • 《HTML》
    • 《CSS》
    • 《JavaWeb》
    • 《Vue》
  • 后端文章

    • Linux
    • Maven
    • 汇编语言
    • 软件工程
    • 计算机网络概述
    • Conda
    • Pip
    • Shell
    • SSH
    • Mac快捷键
    • Zotero
  • 学习笔记

    • 《数据结构与算法》
    • 《算法设计与分析》
    • 《Spring》
    • 《SpringMVC》
    • 《SpringBoot》
    • 《SpringCloud》
    • 《Nginx》
  • 深度学习文章
  • 学习笔记

    • 《PyTorch》
    • 《ReinforementLearning》
    • 《MetaLearning》
  • 学习笔记

    • 《高等数学》
    • 《线性代数》
    • 《概率论与数理统计》
  • 增量学习
  • 哈希学习
GitHub (opens new window)

Geeks_Z

AI小学生
首页
  • 学习笔记

    • 《HTML》
    • 《CSS》
    • 《JavaWeb》
    • 《Vue》
  • 后端文章

    • Linux
    • Maven
    • 汇编语言
    • 软件工程
    • 计算机网络概述
    • Conda
    • Pip
    • Shell
    • SSH
    • Mac快捷键
    • Zotero
  • 学习笔记

    • 《数据结构与算法》
    • 《算法设计与分析》
    • 《Spring》
    • 《SpringMVC》
    • 《SpringBoot》
    • 《SpringCloud》
    • 《Nginx》
  • 深度学习文章
  • 学习笔记

    • 《PyTorch》
    • 《ReinforementLearning》
    • 《MetaLearning》
  • 学习笔记

    • 《高等数学》
    • 《线性代数》
    • 《概率论与数理统计》
  • 增量学习
  • 哈希学习
GitHub (opens new window)
  • Python

  • MLTutorials

  • 卷积神经网络

  • 循环神经网络

  • Transformer

  • VisionTransformer

  • 扩散模型

  • 计算机视觉

  • PTM

  • MoE

  • LoRAMoE

  • LongTailed

  • 多模态

  • 知识蒸馏

  • PEFT

  • 对比学习

  • 小样本学习

  • 迁移学习

  • 零样本学习

  • 集成学习

  • Mamba

  • PyTorch

    • PyTorch概述

    • Tensors

    • 数据处理

    • 模型

    • 训练

    • 并行计算

    • 可视化

    • 实战

    • timm

    • Pytorch Lightning

    • 数据增强

      • 数据增强
      • 基础数据增强
      • 高级数据增强
        • 高级数据增强方法
          • Mixup 和 CutMix
          • 应用 Mixup 或 CutMix 进行训练
          • TransMix
          • 导入必要的库
          • 数据准备
          • 训练和验证
      • 自监督学习增强
      • imgaug
    • 面经与bug解决

    • 常用代码片段

    • Reference
  • CL

  • CIL

  • 小样本类增量学习FSCIL

  • UCIL

  • 多模态增量学习MMCL

  • LTCIL

  • DIL

  • 论文阅读与写作

  • 分布外检测

  • GPU

  • 深度学习调参指南

  • AINotes
  • PyTorch
  • 数据增强
Geeks_Z
2024-07-15
目录

高级数据增强

高级数据增强方法

  • Cutout:在图像上随机遮挡一个矩形区域。
  • Mixup:将两张图像按照一定比例进行线性混合,同时混合对应的标签。
  • CutMix:将一张图像的矩形区域剪切并粘贴到另一张图像上,同时混合标签。
  • Random Erasing:在图像中随机擦除一个区域。
  • TransMix: TransMix 是一种用于增强 Vision Transformer (ViT) 模型的高级数据增强方法。它结合了 Mixup 和 CutMix 的思想,并应用于 Transformer 的 attention 机制上。

Mixup 和 CutMix

import numpy as np
import torch

def mixup_data(x, y, alpha=1.0):
    '''Returns mixed inputs, pairs of targets, and lambda'''
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1

    batch_size = x.size()[0]
    index = torch.randperm(batch_size).cuda()

    mixed_x = lam * x + (1 - lam) * x[index, :]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam

def cutmix_data(x, y, alpha=1.0):
    '''Returns cutmix inputs, pairs of targets, and lambda'''
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1

    batch_size = x.size()[0]
    index = torch.randperm(batch_size).cuda()

    bbx1, bby1, bbx2, bby2 = rand_bbox(x.size(), lam)
    x[:, :, bbx1:bbx2, bby1:bby2] = x[index, :, bbx1:bbx2, bby1:bby2]

    y_a, y_b = y, y[index]
    return x, y_a, y_b, lam

def rand_bbox(size, lam):
    '''Generate random bounding box'''
    W = size[2]
    H = size[3]
    cut_rat = np.sqrt(1. - lam)
    cut_w = np.int(W * cut_rat)
    cut_h = np.int(H * cut_rat)

    # uniform
    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)

    return bbx1, bby1, bbx2, bby2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51

应用 Mixup 或 CutMix 进行训练

for i, data in enumerate(trainloader, 0):
    inputs, labels = data
    inputs, labels = inputs.to(device), labels.to(device)

    # 使用Mixup
    inputs, targets_a, targets_b, lam = mixup_data(inputs, labels, alpha=1.0)
    # 或者使用CutMix
    # inputs, targets_a, targets_b, lam = cutmix_data(inputs, labels, alpha=1.0)

    inputs, targets_a, targets_b = map(torch.autograd.Variable, (inputs, targets_a, targets_b))

    optimizer.zero_grad()
    outputs = model(inputs)
    loss = mixup_criterion(criterion, outputs, targets_a, targets_b, lam)
    # 如果使用CutMix,使用相同的损失函数计算方法
    loss.backward()
    optimizer.step()
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17

TransMix

导入必要的库

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from vit_pytorch import ViT
import numpy as np

# 定义TransMix增强方法
def transmix_data(x, y, alpha=1.0):
    '''Returns transmix inputs, pairs of targets, and lambda'''
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1

    batch_size = x.size()[0]
    index = torch.randperm(batch_size).cuda()

    mixed_x = lam * x + (1 - lam) * x[index, :]
    y_a, y_b = y, y[index]

    # Generate random attention map
    attention_map = torch.rand_like(mixed_x)
    mixed_x = lam * x * attention_map + (1 - lam) * x[index, :] * (1 - attention_map)

    return mixed_x, y_a, y_b, lam

def transmix_criterion(criterion, pred, y_a, y_b, lam):
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30

数据准备

transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=32, shuffle=False, num_workers=2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14

训练和验证

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

def train(epoch):
    model.train()
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)

        inputs, targets_a, targets_b, lam = transmix_data(inputs, labels, alpha=1.0)
        inputs, targets_a, targets_b = map(torch.autograd.Variable, (inputs, targets_a, targets_b))

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = transmix_criterion(criterion, outputs, targets_a, targets_b, lam)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if i % 100 == 99:    # 每100个batch打印一次
            print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 100))
            running_loss = 0.0

def test():
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / total))

for epoch in range(10):  # 训练10个epoch
    train(epoch)
    test()
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42

这段代码展示了如何使用 TransMix 进行数据增强,并将其应用于 Vision Transformer 模型的训练过程中。TransMix 结合了 Mixup 和 CutMix 的思想,并通过随机生成的 attention map 来增强图像数据。这样可以提升模型的泛化能力,减少过拟合,并提高模型在不同数据分布上的表现。你可以根据具体的任务和数据集调整这些增强方法的参数和组合。

上次更新: 2025/06/25, 11:25:50
基础数据增强
自监督学习增强

← 基础数据增强 自监督学习增强→

最近更新
01
帮助信息查看
06-08
02
常用命令
06-08
03
学习资源
06-07
更多文章>
Theme by Vdoing | Copyright © 2022-2025 Geeks_Z | MIT License
京公网安备 11010802040735号 | 京ICP备2022029989号-1
  • 跟随系统
  • 浅色模式
  • 深色模式
  • 阅读模式