Geeks_Z の Blog Geeks_Z の Blog
首页
  • 学习笔记

    • 《HTML》
    • 《CSS》
    • 《JavaWeb》
    • 《Vue》
  • 后端文章

    • Linux
    • Maven
    • 汇编语言
    • 软件工程
    • 计算机网络概述
    • Conda
    • Pip
    • Shell
    • SSH
    • Mac快捷键
    • Zotero
  • 学习笔记

    • 《数据结构与算法》
    • 《算法设计与分析》
    • 《Spring》
    • 《SpringMVC》
    • 《SpringBoot》
    • 《SpringCloud》
    • 《Nginx》
  • 深度学习文章
  • 学习笔记

    • 《PyTorch》
    • 《ReinforementLearning》
    • 《MetaLearning》
  • 学习笔记

    • 《高等数学》
    • 《线性代数》
    • 《概率论与数理统计》
  • 增量学习
  • 哈希学习
GitHub (opens new window)

Geeks_Z

AI小学生
首页
  • 学习笔记

    • 《HTML》
    • 《CSS》
    • 《JavaWeb》
    • 《Vue》
  • 后端文章

    • Linux
    • Maven
    • 汇编语言
    • 软件工程
    • 计算机网络概述
    • Conda
    • Pip
    • Shell
    • SSH
    • Mac快捷键
    • Zotero
  • 学习笔记

    • 《数据结构与算法》
    • 《算法设计与分析》
    • 《Spring》
    • 《SpringMVC》
    • 《SpringBoot》
    • 《SpringCloud》
    • 《Nginx》
  • 深度学习文章
  • 学习笔记

    • 《PyTorch》
    • 《ReinforementLearning》
    • 《MetaLearning》
  • 学习笔记

    • 《高等数学》
    • 《线性代数》
    • 《概率论与数理统计》
  • 增量学习
  • 哈希学习
GitHub (opens new window)
  • Python

  • MLTutorials

  • 卷积神经网络

  • 循环神经网络

  • Transformer

  • VisionTransformer

    • ViT综述
    • ViT全文翻译
    • ViT详解
    • ViT模型列表
    • ViT微调时position_embedding插值
      • 1. 问题描述
      • 2.positional embedding如何interpolate
      • 3. 输入的sequence length改变了ViT还能正常前向推断?
    • ViT_four secrets
    • DyT
    • Sparse-Tuning
  • 扩散模型

  • 计算机视觉

  • PTM

  • MoE

  • LoRAMoE

  • LongTailed

  • 多模态

  • 知识蒸馏

  • PEFT

  • 对比学习

  • 小样本学习

  • 迁移学习

  • 零样本学习

  • 集成学习

  • Mamba

  • PyTorch

  • CL

  • CIL

  • 小样本类增量学习FSCIL

  • UCIL

  • 多模态增量学习MMCL

  • LTCIL

  • DIL

  • 论文阅读与写作

  • 分布外检测

  • GPU

  • 深度学习调参指南

  • AINotes
  • VisionTransformer
Geeks_Z
2024-12-21
目录

ViT微调时position_embedding插值

ViT 微调时position embedding如何插值(interpolate)【源码解析】

Author: [好耶]

Link: [https://zhuanlan.zhihu.com/p/592580783]

1. 问题描述

本文适合对Vision Transformer有一定了解(知道内部结构和一些实现细节,最好是精读过ViT这篇论文)的读者阅读,这篇博客不会详细说明ViT的结构和前向推断过程。

符号 含义
b batch size
N patch size
H W 低分辨率图像的高和宽
H' W' 高分辨率图像的高和宽
So 低分辨率图像的sequence length的长度(o是original的意思)
Sn 高分辨率图像的sequence length的长度(n是new的意思)
h hidden dimension,即每个patch经过linear layer后得到的vector的长度,原文是16x16x3=768

最近在读ViT相关的论文(ViT、DeiT、Swin Transformer),感觉看得比较细致,但ViT中有个细节我一直不太理解:就是在用高分辨率(high resolution)图像做微调时,作者在论文里说:保持patch size不变,直接把position embedding向量进行插值处理(interpolate),原文如下:


作者的意思是:当使用高分辨率(high resolution)图像对预训练好的ViT进行微调(fine-tuning)时,保持patch size (N×N) 不变(即每个patch中的像素数量不变),但由于image size( H×W ,且 H′=W′ )变大了,则sequence length sn=H′/N 也相应变大了。而预训练好的position embedding是对原先低分辨率(low resolution)图像的位置编码(即原来的sequence length so=H/N ),自然无法适应现在的新的sequence length sn 。作者对此提出的解决方案是对原先的postion embedding进行2D的插值处理。

这我就很困惑了:position embedding是个1-D的向量,怎么做2D的插值呢?查了好久也没找到满意的解释,最后还是去看了torchvision中ViT的实现才明白怎么回事儿,其实很简单。

2.positional embedding如何interpolate

我们用图来表示想做的事情:



如何把 so 变成 sn 呢?具体做法如下:

假设position_embedding_img的shape为 (b,h,so) ,其中 b 为batch size,设置 b=1 。 h 和 so 的含义见上面的表格。

  • 首先将position_embedding_img的shape由 (b,h,so) reshape成 ()(b,h,so,so)
  • 然后将后两维 (so,so) 使用torch.nn.functinoal.interpolate,插值成: (sn,sn) ,此时position_embedding_img_new的shape为: ()(b,h,sn,sn)
  • 最后再把position_embedding_img_new reshape成 ()(b,h,sn)

经过上述步骤,我们就将position_embedding_img的 (b,h,so) 变成了position_embedding_img_new的 (b,h,sn) 。示意图如下(这里设 b=1,h=1 ):

3. 输入的sequence length改变了ViT还能正常前向推断?


其实到了第二步就已经结束了,但可能有些人(包括我之前)还会有个疑问:之前我们预训练时输入给Transformer Encoder(即上图中红色圈出的部分)的tensor的shape为: (b,so,h) ,而如果使用高分辨率的img进行微调,那输入到Transformer Encoder的shape变成了: (b,sn,h) ,还可以前向推断吗?Transformer Encoder不需要改内部结构吗?

答案是不需要。原因在于微调时hidden dimension h 的值没有变,为什么这么说呢?我们考虑下Transformer Encoder的内部结构,主要是多头自注意力(multi-head self-attention)和MLP。multi-head self-attention其实就是把输入切分成n个头,分别进行self-attention,然后再把结果concat起来,所以我们以单头自注意力、batch size=1为例,self-attention的大致流程为:


可以看出,Transformer Encoder中训练的参数: 、、Wq、Wk、Wv 的形状都为 (h,h) ,并不会随着sequence length由 so 变为 sn 而发生改变。

同理,Transformer Encoder中的MLP的input layer的神经元个数也是 h ,和 sn 无关。

即Transformer Encoder中参数只和hidden embedding的长度 h 有关,和sequence length 、so、sn 无关。

因此,即使我们输入Transformer Encoder的维度由 (b,so,h) 变为 (b,sn,h) ,也不会影响ViT的前向推断过程。

如果想看Torchvision官方中关于interpolate代码的细节实现,我放在下面:

def interpolate_embeddings(
    image_size: int,
    patch_size: int,
    model_state: "OrderedDict[str, torch.Tensor]",
    interpolation_mode: str = "bicubic",
    reset_heads: bool = False,
) -> "OrderedDict[str, torch.Tensor]":
    """This function helps interpolating positional embeddings during checkpoint loading,
    especially when you want to apply a pre-trained model on images with different resolution.

    Args:
        image_size (int): Image size of the new model.
        patch_size (int): Patch size of the new model.
        model_state (OrderedDict[str, torch.Tensor]): State dict of the pre-trained model.
        interpolation_mode (str): The algorithm used for upsampling. Default: bicubic.
        reset_heads (bool): If true, not copying the state of heads. Default: False.

    Returns:
        OrderedDict[str, torch.Tensor]: A state dict which can be loaded into the new model.
    """
    # Shape of pos_embedding is (1, seq_length, hidden_dim)
    pos_embedding = model_state["encoder.pos_embedding"]
    n, seq_length, hidden_dim = pos_embedding.shape
    if n != 1:
        raise ValueError(f"Unexpected position embedding shape: {pos_embedding.shape}")

    new_seq_length = (image_size // patch_size) ** 2 + 1

    # Need to interpolate the weights for the position embedding.
    # We do this by reshaping the positions embeddings to a 2d grid, performing
    # an interpolation in the (h, w) space and then reshaping back to a 1d grid.
    if new_seq_length != seq_length:
        # The class token embedding shouldn't be interpolated so we split it up.
        seq_length -= 1
        new_seq_length -= 1
        pos_embedding_token = pos_embedding[:, :1, :]
        pos_embedding_img = pos_embedding[:, 1:, :]

        # (1, seq_length, hidden_dim) -> (1, hidden_dim, seq_length)
        pos_embedding_img = pos_embedding_img.permute(0, 2, 1)
        seq_length_1d = int(math.sqrt(seq_length))
        if seq_length_1d * seq_length_1d != seq_length:
            raise ValueError(
                f"seq_length is not a perfect square! Instead got seq_length_1d * seq_length_1d = {seq_length_1d * seq_length_1d } and seq_length = {seq_length}"
            )

        # (1, hidden_dim, seq_length) -> (1, hidden_dim, seq_l_1d, seq_l_1d)
        pos_embedding_img = pos_embedding_img.reshape(1, hidden_dim, seq_length_1d, seq_length_1d)
        new_seq_length_1d = image_size // patch_size

        # Perform interpolation.
        # (1, hidden_dim, seq_l_1d, seq_l_1d) -> (1, hidden_dim, new_seq_l_1d, new_seq_l_1d)
        new_pos_embedding_img = nn.functional.interpolate(
            pos_embedding_img,
            size=new_seq_length_1d,
            mode=interpolation_mode,
            align_corners=True,
        )

        # (1, hidden_dim, new_seq_l_1d, new_seq_l_1d) -> (1, hidden_dim, new_seq_length)
        new_pos_embedding_img = new_pos_embedding_img.reshape(1, hidden_dim, new_seq_length)

        # (1, hidden_dim, new_seq_length) -> (1, new_seq_length, hidden_dim)
        new_pos_embedding_img = new_pos_embedding_img.permute(0, 2, 1)
        new_pos_embedding = torch.cat([pos_embedding_token, new_pos_embedding_img], dim=1)

        model_state["encoder.pos_embedding"] = new_pos_embedding

        if reset_heads:
            model_state_copy: "OrderedDict[str, torch.Tensor]" = OrderedDict()
            for k, v in model_state.items():
                if not k.startswith("heads"):
                    model_state_copy[k] = v
            model_state = model_state_copy

    return model_state

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77

参考

https://arxiv.org/pdf/2010.11929.pdf (opens new window)

https://github.com/pytorch/vision/blob/main/torchvision/models/vision_transformer.py (opens new window)

https://blog.csdn.net/qq_44166630/article/details/127429697?spm=1001.2014.3001.5501 (opens new window)

上次更新: 2025/06/25, 11:25:50
ViT模型列表
ViT_four secrets

← ViT模型列表 ViT_four secrets→

最近更新
01
帮助信息查看
06-08
02
常用命令
06-08
03
学习资源
06-07
更多文章>
Theme by Vdoing | Copyright © 2022-2025 Geeks_Z | MIT License
京公网安备 11010802040735号 | 京ICP备2022029989号-1
  • 跟随系统
  • 浅色模式
  • 深色模式
  • 阅读模式