tensor维度转换
view() 转换维度
torch.view
是 PyTorch 中用于重新塑形张量(tensor)的函数,它返回一个新的张量,这个张量与原始张量具有相同的数据但不同的形状。它不会改变张量中的数据,只是改变数据的视图方式。
Tensor.view(*shape) → Tensor
- shape (torch.Size or int...): 期望的新张量的形状。可以是一个
torch.Size
对象,也可以是一个整数序列。这个形状描述了新张量的各个维度的大小。
torch.view
的使用有一些重要的注意事项:
元素总数必须保持不变:重新塑形后的张量必须包含与原始张量相同数量的元素。换句话说,原始张量和新张量的元素总数必须相等。
-1 的特殊用法:在
torch.view
的形状参数中,可以使用-1
来自动计算该维度的大小。-1
的位置表示该维度的大小将根据张量中的元素总数和其他维度的大小自动计算。这通常用于当我们知道除了一个维度以外的所有维度大小,但不想手动计算那个维度的大小时。
下面是一些使用 torch.view
的示例:
import torch
# 创建一个形状为 (4, 4) 的张量
x = torch.randn(4, 4)
print(x.size()) # 输出: torch.Size([4, 4])
# 使用 torch.view 将 x 重塑为形状 (16,) 的一维张量
y = x.view(16)
print(y.size()) # 输出: torch.Size([16])
# 使用 -1 自动计算维度大小,将 x 重塑为形状 (2, 8) 的二维张量
z = x.view(2, -1)
print(z.size()) # 输出: torch.Size([2, 8])
# 也可以重塑为更高维度的张量,只要元素总数保持不变
m = x.view(2, 2, 4)
print(m.size()) # 输出: torch.Size([2, 2, 4])
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
在上面的例子中,我们展示了如何使用 torch.view
来改变张量的形状。注意,在使用 torch.view
时,需要确保新形状与原始形状的元素总数相匹配,否则 PyTorch 会抛出一个错误。
最后,需要提醒的是,虽然 torch.view
可以改变张量的形状,但它并不会改变张量在内存中的布局。也就是说,它不会移动张量中的数据,只是改变了解释这些数据的方式。因此,torch.view
的操作是高效的,并且通常用于在神经网络的不同层之间调整数据的形状。
reshape() 转换维度
permute() 坐标系变换
torch.permute
是 PyTorch 中的一个函数,用于重新排列张量(tensor)的维度。
基本用法
torch.permute(input, dims) → Tensor
input
:输入的张量。dims
:一个包含整数的元组,表示新的维度顺序。
示例
假设我们有一个形状为 (3, 4, 5)
的张量,我们可以使用 torch.permute
来改变其维度的顺序。
import torch
# 创建一个形状为 (3, 4, 5) 的张量
x = torch.randn(3, 4, 5)
# 使用 permute 交换第一和第二个维度
y = torch.permute(x, (1, 0, 2))
print(y.shape) # 输出:torch.Size([4, 3, 5])
2
3
4
5
6
7
8
9
在上面的例子中,我们创建了一个形状为 (3, 4, 5)
的张量 x
,然后使用 torch.permute
将第一和第二个维度交换,得到形状为 (4, 3, 5)
的张量 y
。
注意点
- 维度数必须匹配:
dims
中的整数数量必须与输入张量的维度数相同。 - 索引的唯一性:
dims
中的每个整数都必须是唯一的,不能有重复。 - 原地操作:
torch.permute
默认是返回一个新的张量,而不是在原地修改输入张量。如果你想在原地修改,你可以使用.permute_()
方法(注意末尾的下划线)。
与其他函数的关系
torch.permute
与 torch.transpose
类似,但 torch.transpose
只用于交换两个维度,而 torch.permute
可以用于任意维度的重新排列。
例如,使用 torch.transpose
交换第一和第二个维度:
y = x.transpose(0, 1)
这与使用 torch.permute
的效果相同:
y = torch.permute(x, (1, 0, 2))
但是,如果你需要更复杂的维度重排,torch.permute
会更加灵活。
squeeze()/unsqueeze() 降维/升维
torch.squeeze()
是 PyTorch 中的一个函数,用于从张量(tensor)中移除所有大小为 1 的维度。
参数
input
(Tensor): 输入张量。dim
(int, optional): 要移除的维度的索引。如果提供,则仅移除该特定维度。
返回值
返回一个与输入张量相同数据的新张量,但已移除了大小为 1 的维度。
示例
示例 1:移除所有大小为 1 的维度
import torch
# 创建一个形状为 (1, 3, 1, 4) 的张量
x = torch.randn(1, 3, 1, 4)
print(x.shape) # 输出: torch.Size([1, 3, 1, 4])
# 使用 torch.squeeze() 移除所有大小为 1 的维度
y = torch.squeeze(x)
print(y.shape) # 输出: torch.Size([3, 4])
2
3
4
5
6
7
8
9
示例 2:仅移除指定维度
import torch
# 创建一个形状为 (1, 3, 1, 4) 的张量
x = torch.randn(1, 3, 1, 4)
print(x.shape) # 输出: torch.Size([1, 3, 1, 4])
# 使用 torch.squeeze() 仅移除第 0 个维度(即形状中的第一个维度)
y = torch.squeeze(x, dim=0)
print(y.shape) # 输出: torch.Size([3, 1, 4])
2
3
4
5
6
7
8
9
注意事项
torch.squeeze()
不会改变原始张量,而是返回一个新的张量。- 如果指定的维度
dim
不是大小为 1 的维度,那么torch.squeeze()
将不会移除该维度,并且返回的张量将与输入张量具有相同的形状。 torch.squeeze()
常用于处理从某些操作(如torch.unsqueeze()
或某些神经网络层)中产生的额外大小为 1 的维度,这些维度可能会使张量的形状变得复杂。
torch.unsqueeze()
是 PyTorch 中的一个函数,用于在指定的维度上增加一个维度大小为 1 的维度。这通常用于改变张量(tensor)的形状(shape),以便进行某些操作或匹配其他张量的形状。
参数
torch.unsqueeze(input, dim)
input
(Tensor): 输入张量。dim
(int): 在哪个维度上增加维度。
示例
import torch
# 创建一个形状为 [3, 4] 的张量
x = torch.tensor([[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 10, 11, 12]])
print(x.shape) # 输出: torch.Size([3, 4])
# 在维度 0 上增加一个维度,形状变为 [1, 3, 4]
x_unsqueeze0 = torch.unsqueeze(x, 0)
print(x_unsqueeze0.shape) # 输出: torch.Size([1, 3, 4])
# 在维度 1 上增加一个维度,形状变为 [3, 1, 4]
x_unsqueeze1 = torch.unsqueeze(x, 1)
print(x_unsqueeze1.shape) # 输出: torch.Size([3, 1, 4])
# 在维度 2 上增加一个维度,形状变为 [3, 4, 1]
x_unsqueeze2 = torch.unsqueeze(x, 2)
print(x_unsqueeze2.shape) # 输出: torch.Size([3, 4, 1])
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
注意事项
- 如果指定的维度
dim
已经存在(即其大小不为 1),则该函数将引发错误。 torch.unsqueeze()
会返回一个新的张量,而不是修改原始张量。- 如果你想要删除一个维度大小为 1 的维度,可以使用
torch.squeeze()
函数。
expand() 扩张张量
torch.expand()
是 PyTorch 中的一个函数,用于扩展张量(tensor)的维度。这个函数不会改变张量的数据,只是改变张量的形状(shape)。expand()
返回的新的张量与原始张量共享相同的内存空间,也就是说它们指向同一块内存区域。
函数的基本语法如下:
torch.expand(input, size)
input
:输入的张量。size
:一个整数元组,表示扩展后的张量的大小。
import torch
# 创建一个形状为 (3,1) 的张量
x = torch.tensor([[1], [2], [3]])
# print(x.shape) # torch.Size([3, 1])
# 使用 expand 扩展 x 的维度到 (3, 4)
y = x.expand( (3, 4))
print(y.shape) # 输出: torch.Size([3, 4])
print(y)
2
3
4
5
6
7
8
9
10
输出:
torch.Size([3, 4])
tensor([[1, 1, 1, 1],
[2, 2, 2, 2],
[3, 3, 3, 3]])
2
3
4
需要注意的是,expand()
只能在单维度上进行扩展,并且这个维度的大小必须为 1。
尽管 expand()
提供了扩展张量维度的功能,但它并不总是最佳选择。在某些情况下,使用 view()
或 reshape()
函数可能更为合适,因为它们可以更改张量的形状而不受上述限制。但是,view()
和 reshape()
要求新的形状与原始张量的元素总数相匹配,而 expand()
则没有这个要求。
narraw() 缩小张量
resize_() 重设尺寸
repeat()
Tensor
对象有一个 repeat()
方法,它用于沿着指定的维度重复张量。 repeat()
方法并不会增加张量的维度,而是沿着指定的维度复制张量的元素。这与 torch.tensor.expand()
有所不同,expand()
会增加一个或多个维度的大小,但不会复制数据(即“广播”)。
Tensor.repeat()
的基本语法如下:
torch.Tensor.repeat(*sizes)
*sizes
:一个整数序列,指定在每个维度上重复的次数。
示例:
import torch
x = torch.tensor([1, 2, 3])
y = x.repeat(2) # 在每个元素上重复2次
print(y) # tensor([1, 1, 2, 2, 3, 3])
x = torch.tensor([[1, 2], [3, 4]])
y = x.repeat(2, 1) # 在第一个维度上重复2次,第二个维度上重复1次(即不重复)
print(y)
# tensor([[1, 2],
# [1, 2],
# [3, 4],
# [3, 4]])
z = x.repeat(1, 2) # 在第一个维度上重复1次(即不重复),第二个维度上重复2次
print(z)
# tensor([[1, 2, 1, 2],
# [3, 4, 3, 4]])
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
请注意,repeat()
方法会实际复制数据,这可能会消耗大量内存。因此,在处理大型张量时,要特别小心。如果可能的话,考虑使用 expand()
或其他方法来避免不必要的内存消耗。
unfold() 重复张量
cat()
torch.cat()
是 PyTorch 中的一个函数,用于沿指定维度连接张量(tensors)。
参数
tensors
(sequence of Tensors): 要连接的张量序列。dim
(int, optional): 要连接的维度。默认值为 0。out
(Tensor, optional): 输出张量。
返回值
返回一个新的张量,它是输入张量在指定维度上的连接。
示例
示例 1:沿第一个维度连接张量
import torch
# 创建两个形状为 (3, 4) 的张量
tensor1 = torch.randn(3, 4)
tensor2 = torch.randn(3, 4)
# 沿第一个维度(即行)连接这两个张量
result = torch.cat((tensor1, tensor2), dim=0)
# result 的形状现在是 (6, 4),因为它包含了两个原始张量的所有行
print(result.shape) # 输出: torch.Size([6, 4])
2
3
4
5
6
7
8
9
10
11
示例 2:沿第二个维度连接张量
import torch
# 创建两个形状为 (2, 3) 的张量
tensor1 = torch.randn(2, 3)
tensor2 = torch.randn(2, 3)
# 沿第二个维度(即列)连接这两个张量
result = torch.cat((tensor1, tensor2), dim=1)
# result 的形状现在是 (2, 6),因为它包含了两个原始张量的所有列
print(result.shape) # 输出: torch.Size([2, 6])
2
3
4
5
6
7
8
9
10
11
注意事项
- 所有要连接的张量必须在除了连接维度以外的所有维度上具有相同的大小。
dim
参数指定了连接应该发生的维度。例如,dim=0
表示按行连接(增加行数),而dim=1
表示按列连接(增加列数)。torch.cat()
不会改变原始张量,而是返回一个新的连接后的张量。
stack() 拼接张量
在 PyTorch 中,stack()
函数用于沿着一个新维度连接张量(tensors)。
1. 函数签名
torch.stack(tensors, dim=0, out=None) → Tensor
tensors
(sequence of Tensors): 需要连接的张量序列。dim
(int, optional): 插入新维度的索引。默认为 0。out
(Tensor, optional): 输出张量。
2. 用法示例
假设我们有三个形状为(2, 3)
的张量,我们希望将它们沿着新的维度(例如第 0 维)堆叠起来:
import torch
# 创建三个形状为(2, 3)的张量
x1 = torch.tensor([[1, 2, 3], [4, 5, 6]])
x2 = torch.tensor([[7, 8, 9], [10, 11, 12]])
x3 = torch.tensor([[13, 14, 15], [16, 17, 18]])
# 使用stack沿着第0维堆叠
stacked = torch.stack((x1, x2, x3), dim=0)
print(stacked.shape) # 输出: torch.Size([3, 2, 3])
2
3
4
5
6
7
8
9
10
11
在这个例子中,stacked
张量的形状是(3, 2, 3)
,因为我们沿着第 0 维(新的维度)堆叠了三个形状为(2, 3)
的张量。
3. 与cat()
的区别
stack()
和cat()
都可以用来连接张量,但它们的工作方式略有不同。cat()
函数沿着现有维度连接张量,而stack()
则创建一个新的维度来连接张量。
例如,如果你有三个形状为(2, 3)
的张量,并使用cat()
沿着第 0 维连接它们,你将得到一个形状为(6, 3)
的张量。但是,如果你使用stack()
,你将得到一个形状为(3, 2, 3)
的张量。
4. 注意事项
- 所有要堆叠的张量必须具有相同的形状(除了堆叠的维度)。
- 堆叠操作会增加张量的维度数。