Geeks_Z の Blog Geeks_Z の Blog
首页
  • 学习笔记

    • 《HTML》
    • 《CSS》
    • 《JavaWeb》
    • 《Vue》
  • 后端文章

    • Linux
    • Maven
    • 汇编语言
    • 软件工程
    • 计算机网络概述
    • Conda
    • Pip
    • Shell
    • SSH
    • Mac快捷键
    • Zotero
  • 学习笔记

    • 《数据结构与算法》
    • 《算法设计与分析》
    • 《Spring》
    • 《SpringMVC》
    • 《SpringBoot》
    • 《SpringCloud》
    • 《Nginx》
  • 深度学习文章
  • 学习笔记

    • 《PyTorch》
    • 《ReinforementLearning》
    • 《MetaLearning》
  • 学习笔记

    • 《高等数学》
    • 《线性代数》
    • 《概率论与数理统计》
  • 增量学习
  • 哈希学习
GitHub (opens new window)

Geeks_Z

AI小学生
首页
  • 学习笔记

    • 《HTML》
    • 《CSS》
    • 《JavaWeb》
    • 《Vue》
  • 后端文章

    • Linux
    • Maven
    • 汇编语言
    • 软件工程
    • 计算机网络概述
    • Conda
    • Pip
    • Shell
    • SSH
    • Mac快捷键
    • Zotero
  • 学习笔记

    • 《数据结构与算法》
    • 《算法设计与分析》
    • 《Spring》
    • 《SpringMVC》
    • 《SpringBoot》
    • 《SpringCloud》
    • 《Nginx》
  • 深度学习文章
  • 学习笔记

    • 《PyTorch》
    • 《ReinforementLearning》
    • 《MetaLearning》
  • 学习笔记

    • 《高等数学》
    • 《线性代数》
    • 《概率论与数理统计》
  • 增量学习
  • 哈希学习
GitHub (opens new window)
  • Python

  • MLTutorials

  • 卷积神经网络

  • 循环神经网络

  • Transformer

  • VisionTransformer

  • 扩散模型

  • 计算机视觉

  • PTM

  • MoE

  • LoRAMoE

  • LongTailed

  • 多模态

  • 知识蒸馏

  • PEFT

  • 对比学习

  • 小样本学习

  • 迁移学习

  • 零样本学习

  • 集成学习

  • Mamba

  • PyTorch

    • PyTorch概述

    • Tensors

      • Tensors
      • 自动求导
      • AI硬件加速设备
      • tensor类型转换
      • tensor维度转换
      • 常见函数
        • torch.topk()
        • torch.unique()
        • torch.nonzero
        • torch.where
          • 基本语法
          • 使用示例
        • scatter()
          • 官方示例
          • scatter_()
          • Reference
      • tensor可视化为图片
    • 数据处理

    • 模型

    • 训练

    • 并行计算

    • 可视化

    • 实战

    • timm

    • Pytorch Lightning

    • 数据增强

    • 面经与bug解决

    • 常用代码片段

    • Reference
  • CL

  • CIL

  • 小样本类增量学习FSCIL

  • UCIL

  • 多模态增量学习MMCL

  • LTCIL

  • DIL

  • 论文阅读与写作

  • 分布外检测

  • GPU

  • 深度学习调参指南

  • AINotes
  • PyTorch
  • Tensors
Geeks_Z
2023-06-29
目录

常见函数

torch.topk()

  1. 作用 取一个 tensor 的 topk 元素,返回值为降序后的前 k 个大小的元素值及索引
  2. 使用方法
  • dim=0 表示按照列求 topn
  • dim=1 表示按照行求 topn
  • 默认情况下,dim=1
  1. 示例

    >>> x = torch.arange(1., 6.)
    >>> x
    tensor([ 1.,  2.,  3.,  4.,  5.])
    >>> torch.topk(x, 3)
    torch.return_types.topk(values=tensor([5., 4., 3.]), indices=tensor([4, 3, 2]))
    
    1
    2
    3
    4
    5

torch.unique()

torch.unique()的功能类似于数学中的集合,就是挑出 tensor 中的独立不重复元素。

这个方法的参数在官方解释文档中有这么几个:torch.unique(input, sorted=True, return_inverse=False, return_counts=False, dim=None)

input: 待处理的 tensor

sorted:是否对返回的无重复张量按照数值进行排列,默认是生序排列的

return_inverse: 是否返回原始 tensor 中的每个元素在这个无重复张量中的索引

return_counts: 统计原始张量中每个独立元素的个数

dim: 值沿着哪个维度进行 unique 的处理,这个我试验后没有搞懂怎样的机理。如果处理的张量都是一维的,那么这个不需要理会。

下面分别对这些不同的参数进行实验讲解分析。

import torch

x = torch.tensor([4,0,1,2,1,2,3])#生成一个tensor,作为实验输入
print(x)

out = torch.unique(x) #所有参数都设置为默认的
print(out)#将处理结果打印出来
#结果如下:
#tensor([0, 1, 2, 3, 4])   #将x中的不重复元素挑了出来,并且默认为生序排列

out = torch.unique(x,sorted=False)#将默认的生序排列改为False
print(out)
#输出结果如下:
#tensor([3, 2, 1, 0, 4])  #将x中的独立元素找了出来,就按照原始顺序输出

out = torch.unique(x,return_inverse=True)#将原始数据中的每个元素在新生成的独立元素张量中的索引输出
print(out)
#输出结果如下:
#(tensor([0, 1, 2, 3, 4]), tensor([4, 0, 1, 2, 1, 2, 3]))  #第一个张量是排序后输出的独立张量,第二个结果对应着原始数据中的每个元素在新的独立无重复张量中的索引,比如x[0]=4,在新的张量中的索引为4, x[1]=0,在新的张量中的索引为0,x[6]=3,在新的张量中的索引为3

out = torch.unique(x,return_counts=True) #返回每个独立元素的个数
print(out)
#输出结果如下
#(tensor([0, 1, 2, 3, 4]), tensor([1, 2, 2, 1, 1]))  #0这个元素在原始数据中的数量为1,1这个元素在原始数据中的数量为2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24

torch.nonzero

torch.nonzero() 是 PyTorch 中的一个函数,用于获取张量中非零元素的索引。这个函数返回一个二维张量,其中每一行都包含输入张量中一个非零元素的索引。

这个函数的语法如下:

torch.nonzero(input, *, out=None)
1

其中,input 是输入张量。

让我们来看一个例子,假设我们有一个形状为 (3, 3) 的张量 x:

import torch

x = torch.tensor([[0, 1, 0], [2, 0, 2], [0, 3, 0]])
print(x)
1
2
3
4

现在,我们可以使用 torch.nonzero() 来获取 x 中非零元素的索引:

indices = torch.nonzero(x)
print(indices)
1
2

输出结果如下:

tensor([[0, 1],
        [1, 0],
        [1, 2],
        [2, 1]])
1
2
3
4

可以看到,indices 是一个形状为 (4, 2) 的张量,其中每一行都是 x 中一个非零元素的索引。

需要注意的是,torch.nonzero() 返回的索引是按照行优先顺序排列的,也就是说,它首先返回第一行的非零元素的索引,然后返回第二行的,依此类推。

在 PyTorch 中,scatter() 和 scatter_() 函数通常用于在特定维度上根据索引更新张量(tensor)的值。这两个函数的主要区别在于它们是否原地(in-place)修改输入张量。

torch.where 是 PyTorch 中的一个条件选择函数,常用于根据给定条件在两个张量之间进行元素级选择。


torch.where

基本语法

torch.where(condition, x, y)
1
  • condition:布尔张量,元素值为 True 的位置选择 x,为 False 的位置选择 y。
  • x:当 condition 为 True 时使用的值或张量。
  • y:当 condition 为 False 时使用的值或张量。

使用示例

选择性替换元素

import torch

a = torch.tensor([1, 2, 3, 4, 5])
b = torch.tensor([10, 20, 30, 40, 50])
condition = a > 3

result = torch.where(condition, a, b)
print(result)  # tensor([10, 20, 30,  4,  5])
1
2
3
4
5
6
7
8
  • a > 3 生成布尔张量 [False, False, False, True, True]
  • 只有 a 中大于 3 的元素被保留,其他地方使用 b 的元素。

应用于多维张量

A = torch.tensor([[1, -2], [3, -4]])
B = torch.tensor([[10, 20], [30, 40]])

condition = A < 0  # 找到负数
result = torch.where(condition, B, A)
print(result)
# tensor([[ 1, 20],
#         [ 3, 40]])
1
2
3
4
5
6
7
8
  • 负数用 B 中对应元素替换,其他保持 A。

仅提供 condition(索引操作)

如果只提供 condition,torch.where 会返回满足条件的索引。

x = torch.tensor([[1, 2, 3], [4, 5, 6]])
indices = torch.where(x > 3)
print(indices)  # (tensor([1, 1, 1]), tensor([0, 1, 2]))
1
2
3
  • indices 代表行索引和列索引,可以用于索引 x[indices],取出满足条件的元素。
values = x[indices]
print(values)  # tensor([4, 5, 6])
1
2

scatter()

scatter() 函数根据提供的索引将源张量的值分散到目标张量中。它不会修改源张量或目标张量本身(即原地操作)。

函数签名:

torch.scatter(input, dim, index, src, *, out=None)
1
  • input (Tensor): 目标张量。
  • dim (int): 沿其分散的维度。
  • index (LongTensor): 索引张量,其形状必须与 src 的形状在 dim 维度之外的其他所有维度上都匹配。
  • src (Tensor): 源张量,其形状必须与 input 在 dim 维度之外的其他所有维度上都匹配。
  • out (Tensor, optional): 输出张量。

官方示例

三维示例

y = y.scatter(dim,index,src)

#则结果为:
y[ index[i][j][k]  ] [j][k] = src[i][j][k] # if dim == 0
y[i] [ index[i][j][k] ] [k] = src[i][j][k] # if dim == 1
y[i][j] [ index[i][j][k] ]  = src[i][j][k] # if dim == 2
1
2
3
4
5
6

二维示例

y = y.scatter(dim,index,src)

#则:
y [ index[i][j] ] [j] = src[i][j] #if dim==0
y[i] [ index[i][j] ]  = src[i][j] #if dim==1
1
2
3
4
5
import torch

x = torch.randn(2,4)
print(x)
y = torch.zeros(3,4)
y = y.scatter_(0,torch.LongTensor([[2,1,2,2],[0,2,1,1]]),x)
print(y)


#结果为:
tensor([[-0.9669, -0.4518,  1.7987,  0.1546],
        [-0.1122, -0.7998,  0.6075,  1.0192]])
tensor([[-0.1122,  0.0000,  0.0000,  0.0000],
        [ 0.0000, -0.4518,  0.6075,  1.0192],
        [-0.9669, -0.7998,  1.7987,  0.1546]])


'''
scatter后:
y[ index[0][0] ] [0] = src[0][0] -> y[2][0]=-0.9669
y[ index[1][3] ] [3] = src[1][3] -> y[1][3]=1.10192
'''

#如果src为标量,则代表着将对应位置的数值改为src这个标量
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24

那么这个函数有什么作用呢?其实可以利用这个功能将 pytorch 中 mini batch 中的返回的 label(特指[ 1,0,4,9 ],即 size 为[4]这样的 label)转为 one-hot 类型的 label,举例子如下:

import torch

mini_batch = 4
out_planes = 6
out_put = torch.rand(mini_batch, out_planes)
softmax = torch.nn.Softmax(dim=1)
out_put = softmax(out_put)

print(out_put)
label = torch.tensor([1,3,3,5])
one_hot_label = torch.zeros(mini_batch, out_planes).scatter_(1,label.unsqueeze(1),1)
print(one_hot_label)
1
2
3
4
5
6
7
8
9
10
11
12

上述的这个例子假设是一个分类问题,我设置 out_planes=6,是假设总共有 6 类,mini_batch 是我们送入的网络的每个 mini_batch 的样本数量,这里我们不设置网络,直接假设网络的输出为一个随机的张量 ,通常我们要对这个输出进行 softmax 归一化,此时就代表着其属于每个类别的概率了。说到这里都不是重点,就是为了方便理解如何使用 scatter,将 size 为[mini_batch]的张量,转为 size 为[mini_batch, out_palnes]的张量,并且这个生成的张量的每个行向量都是 one-hot 类型的了。通过看下面的输出结果就完全能够理解了。

tensor([[0.1202, 0.2120, 0.1252, 0.1127, 0.2314, 0.1985],
        [0.1707, 0.1227, 0.2282, 0.0918, 0.1845, 0.2021],
        [0.1629, 0.1936, 0.1277, 0.1204, 0.1845, 0.2109],
        [0.1226, 0.1524, 0.2315, 0.2027, 0.1907, 0.1001]])
tensor([1, 3, 3, 5])
tensor([[0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 1.]])
1
2
3
4
5
6
7
8
9

scatter_()

scatter_() 函数与 scatter() 类似,但它会原地修改目标张量(即它会修改 input 张量本身)。

函数签名:

Tensor.scatter_(dim, index, src)
1
  • dim (int): 沿其分散的维度。
  • index (LongTensor): 索引张量。
  • src (Tensor): 源张量。

Reference

  • pytorch 中的 scatter_()函数使用和详解 (opens new window)
#PyTorch
上次更新: 2025/06/29, 11:12:32
tensor维度转换
tensor可视化为图片

← tensor维度转换 tensor可视化为图片→

最近更新
01
帮助信息查看
06-08
02
常用命令
06-08
03
学习资源
06-07
更多文章>
Theme by Vdoing | Copyright © 2022-2025 Geeks_Z | MIT License
京公网安备 11010802040735号 | 京ICP备2022029989号-1
  • 跟随系统
  • 浅色模式
  • 深色模式
  • 阅读模式