文章目录
- Tensor维度变换
- 1. view / reshape
- 1.1 view 函数
- 1.2 reshape 函数
- 2. squeeze / unsqueeze
- 2.1 unsqueeze 函数
- 案例
- 2.2 squeeze 函数
- 3. expand / repeat
- 3.1 expand 函数
- 3.2 repeat 函数
- 4. 矩阵转置
- 4.1 t 函数
- 4.2 transpose 函数
- 案例:数据污染
- 4.3 permute 函数
- 5. Broadcasting
Tensor维度变换
1. view / reshape
- 在 Pytorch 0.3 时,使用的默认 API 是 view
- 在 Pytorch 0.4 时,为了与numpy一致,增加了 reshape 方法
- 保证其元素个数不变的前提下,任意改变其维度
- 若改变了元素个数,就会报错
1.1 view 函数
a = torch.rand(2, 1, 2, 2) # 共有 2 * 1 * 2 * 2 = 8 个元素
print(a.shape) # torch.Size([2, 1, 2, 2])
print(a.numel()) # 8
print(a)
'''
tensor([
[[[0.6904, 0.6917],[0.1554, 0.4077]]],
[[[0.7704, 0.3776],[0.5143, 0.8417]]]
])
'''
b = a.view(2, 2 * 2) # 表示把 1 2 3三个维度合并成 1 个维度,第 0 维度不动
print(b.shape) # torch.Size([2, 4])
print(b.numel()) # 8
print(b)
'''
tensor([[0.6904, 0.6917, 0.1554, 0.4077],
[0.7704, 0.3776, 0.5143, 0.8417]])
'''
c = b.view(8)
print(c.shape) # torch.Size([8])
print(c.numel()) # 8
print(c)
'''
tensor([0.6904, 0.6917, 0.1554, 0.4077, 0.7704, 0.3776, 0.5143, 0.8417])
'''
d = c.view(8, 1, 1, 1, 1, 1, 1, 1, 1, 1)
print(d.shape) # torch.Size([8, 1, 1, 1, 1, 1, 1, 1, 1, 1])
print(d)
'''
tensor([[[[[[[[[[0.6904]]]]]]]]],
[[[[[[[[[0.6917]]]]]]]]],
[[[[[[[[[0.1554]]]]]]]]],
[[[[[[[[[0.4077]]]]]]]]],
[[[[[[[[[0.7704]]]]]]]]],
[[[[[[[[[0.3776]]]]]]]]],
[[[[[[[[[0.5143]]]]]]]]],
[[[[[[[[[0.8417]]]]]]]]]])
'''
print(a.dim()) # 4维度
print(b.dim()) # 2维度
print(c.dim()) # 1维度
print(d.dim()) # 10w维度
1.2 reshape 函数
- 与 view 函数用法一致
a = torch.ones(2, 1, 2)
print(a.shape) # torch.Size([2, 1, 2, 2])
print(a.dim()) # 3
print(a.numel()) # 4
print(a)
'''
tensor([
[[1., 1.]],
[[1., 1.]]
])
'''
b = a.reshape(4)
print(b.shape) # torch.Size([2, 4])
print(b.dim()) # 1
print(b.numel()) # 4
print(b)
'''
tensor([1., 1., 1., 1.])
'''
2. squeeze / unsqueeze
2.1 unsqueeze 函数
def unsqueeze(dim) -> Tensor
- unsqueeze:展开,增加一个维度,但是不改变数据个数
- dim : 参数范围 dim ∈ [ -(原dimMax + 1), (原dimMax + 1) ),能取左边界,无法取右边界
- 功能:在 dim 维度处插入一个维度,原 dim 维度及后续维度后移
参数dim正数范围: [ 0, 原 dimMax ],下面测试代码为 dim∈ [0, 4] ,如果在第 5 维度新增维度会报错
import torch
a = torch.ones(4, 3, 28, 28)
print(a.dim()) # 4
print(a.shape) # torch.Size([4, 3, 28, 28])
b = a.unsqueeze(0) # 在 0 维度处插入一个维度
print(b.dim()) # 5
print(b.shape) # torch.Size([1, 4, 3, 28, 28])
b = a.unsqueeze(1) # 在 1 维度处插入一个维度
print(b.dim()) # 5
print(b.shape) # torch.Size([4, 1, 3, 28, 28])
b = a.unsqueeze(2) # 在 2 维度处插入一个维度
print(b.dim()) # 5
print(b.shape) # torch.Size([4, 3, 1, 28, 28])
b = a.unsqueeze(3) # 在 3 维度处插入一个维度
print(b.dim()) # 5
print(b.shape) # torch.Size([4, 3, 28, 1, 28])
b = a.unsqueeze(4) # 在 4 维度处插入一个维度
print(b.dim()) # 5
print(b.shape) # torch.Size([4, 3, 28, 28, 1])
b = a.unsqueeze(5) # 在 5 维度处插入一个维度
print(b.dim())
print(b.shape)
# IndexError: Dimension out of range (expected to be in range of [-5, 4], but got 5)
参数dim负数范围: [ - ( 原 dimMax + 1 ), 0 ],下面测试代码为 dim∈ [ -5, 0 ]
import torch
a = torch.ones(4, 3, 28, 28)
print(a.dim()) # 4
print(a.shape) # torch.Size([4, 3, 28, 28])
b = a.unsqueeze(0) # 在 0 维度前面插入一个新维度
print(b.dim()) # 5
print(b.shape) # torch.Size([1, 4, 3, 28, 28])
b = a.unsqueeze(-1) # 在 4 维度处插入一个维度
print(b.dim()) # 5
print(b.shape) # torch.Size([4, 3, 28, 28, 1])
b = a.unsqueeze(-2) # 在 3 维度处插入一个维度
print(b.dim()) # 5
print(b.shape) # torch.Size([4, 3, 28, 1, 28])
b = a.unsqueeze(-3) # 在 2 维度处插入一个维度
print(b.dim()) # 5
print(b.shape) # torch.Size([4, 3, 1, 28, 28])
b = a.unsqueeze(-4) # 在 1 维度处插入一个维度
print(b.dim()) # 5
print(b.shape) # torch.Size([4, 1, 3, 28, 28])
b = a.unsqueeze(-5) # 在 0 维度处插入一个维度
print(b.dim()) # 5
print(b.shape) # torch.Size([1, 4, 3, 28, 28])
案例
- 将 b Tensor 的维度变成和 f Tensor 的维度一样
b = torch.rand(32)
f = torch.rand(4, 32, 14, 14)
b = b.unsqueeze(1).unsqueeze(2).unsqueeze(0)
print(b.shape) # torch.Size([1, 32, 1, 1])
2.2 squeeze 函数
def squeeze(dim=None) -> Tensor
- 参数不填:将所有维度中值为1的维度删除
- 填入正dim:将dim维度删除,若dim超过原维度最大值,则报错
- 填入负dim:将倒数dim维度删除,若该dim中的值不为 1 则不删除,只有 dim 中的值为 1 才删除该维度
参数dim正数范围:[1, 原 dimMax]
a = torch.ones(1, 1, 2, 2)
print(a.dim()) # 4
print(a.shape) # torch.Size([1, 1, 2, 2])
b = a.squeeze()
print(b.dim()) # 2
print(b.shape) # torch.Size([2, 2])
b = a.squeeze(0)
print(b.dim()) # 3
print(b.shape) # torch.Size([1, 2, 2])
b = a.squeeze(1)
print(b.dim()) # 3
print(b.shape) # torch.Size([1, 2, 2])
b = a.squeeze(2)
print(b.dim()) # 2
print(b.shape) # torch.Size([2, 2])
b = a.squeeze(3)
print(b.dim()) # 4
print(b.shape) # torch.Size([1, 1, 2, 2])
b = a.squeeze(4)
print(b.dim())
print(b.shape)
# IndexError: Dimension out of range (expected to be in range of [-4, 3], but got 4)
参数dim负数范围:[- 原 dimMax, -1]
a = torch.ones(1, 1, 2, 2)
print(a.dim()) # 4
print(a.shape) # torch.Size([4, 3, 28, 28])
b = a.squeeze()
print(b.dim()) # 2
print(b.shape) # torch.Size([2, 2])
b = a.squeeze(0)
print(b.dim()) # 3
print(b.shape) # torch.Size([1, 2, 2])
b = a.squeeze(-1)
print(b.dim()) # 4
print(b.shape) # torch.Size([1, 1, 2, 2])
b = a.squeeze(-2)
print(b.dim()) # 4
print(b.shape) # torch.Size([1, 1, 2, 2])
b = a.squeeze(-3)
print(b.dim()) # 3
print(b.shape) # torch.Size([1, 2, 2])
b = a.squeeze(-4)
print(b.dim()) # 3
print(b.shape) # torch.Size([1, 2, 2])
3. expand / repeat
3.1 expand 函数
def expand(*sizes) -> Tensor
- 按照传入的参数,将原 dim 中的值个数按照参数的值进行修改,并将其中的值按原有的值进行初始化
- 原维度中的值必须 小于等于 改变后的维度中的值
- 将不想改变的维度中填入 -1 后,该维度中的值不会改变
a = torch.rand(2, 1, 2, 2)
b = torch.rand(2, 1, 1, 1)
c = b.expand(a.shape)
print(c.shape) # torch.Size([2, 1, 2, 2])
print(a)
print(b)
print(c)
'''
tensor([[[[0.4356, 0.0878],
[0.3151, 0.1201]]],
[[[0.3978, 0.2646],
[0.0629, 0.9532]]]])
tensor([[[[0.9870]]],
[[[0.1181]]]])
tensor([[[[0.9870, 0.9870],
[0.9870, 0.9870]]],
[[[0.1181, 0.1181],
[0.1181, 0.1181]]]])
'''
3.2 repeat 函数
def repeat(*sizes) -> Tensor
- 将原维度中的值进行复制sizes次
b = torch.rand(2, 1, 1, 1)
c = b.repeat(1, 2, 1, 2)
print(c.shape) # torch.Size([2, 2, 1, 2])
print(c)
'''
tensor([[[[0.2657, 0.2657]],
[[0.2657, 0.2657]]],
[[[0.8223, 0.8223]],
[[0.8223, 0.8223]]]])
'''
代码解释:
c = b.repeat(1, 2, 1, 2)
将 b 的 0 维度中的值复制 1 次,所以复制之后 0 维度中有 2 * 1 个值
将 b 的 1 维度中的值复制 2 次,所以复制之后 1 维度中有 1 * 2 个值
将 b 的 2 维度中的值复制 1 次,所以复制之后 2 维度中有 1 * 1 个值
将 b 的 3 维度中的值复制 2 次,所以复制之后 3 维度中有 1 * 2 个值
4. 矩阵转置
4.1 t 函数
def t() -> Tensor
- 将矩阵转置
- 只适用于 2D 的矩阵,其他多维度都不适用
a = torch.randn(2, 2)
print(a)
b = a.t()
print(b)
'''
tensor([[-0.6856, 0.7479],
[ 0.7589, 0.6101]])
tensor([[-0.6856, 0.7589],
[ 0.7479, 0.6101]])
'''
4.2 transpose 函数
def transpose(dim0, dim1) -> Tensor
- 功能:将 dim0 维度与 dim1 维度进行交换
a = torch.randn(4, 3, 28, 28)
print(a.shape) # torch.Size([4, 3, 28, 28])
a1 = a.transpose(0, 3)
print(a1.shape) # torch.Size([28, 3, 28, 4])
案例:数据污染
-
注意
:view 函数运行之后需要记住原维度及size才能恢复原状,要不然可能会产生数据污染,如下代码所示
a = torch.randn(4, 3, 32, 32)
a1 = a.transpose(1, 3).contiguous().view(4, 3 * 32 * 32).view(4, 3, 32, 32)
a2 = a.transpose(1, 3).contiguous().view(4, 3 * 32 * 32).view(4, 32, 32, 3).transpose(1, 3)
# 比较是否完全一样
print(torch.all(torch.eq(a, a1))) # tensor(False)
print(torch.all(torch.eq(a, a2))) # tensor(True)
4.3 permute 函数
def permute(*dims) -> Tensor
- 对维度顺序进行指定
a = torch.rand(4, 3, 28, 28)
b = a.permute(0, 2, 3, 1)
print(b.shape) # torch.Size([4, 28, 28, 3])
解释:
将原来的 0 维度放置在 0 维度处
将原来的 2 维度放置在 1 维度处
将原来的 3 维度放置在 2 维度处
将原来的 1 维度放置在 3 维度处
5. Broadcasting
特点:
- 自动扩展维度并改变维度中的值:自动调用 expand 函数
- 没有数据拷贝:不需要拷贝数据
例如:Tensor A [4, 32, 14, 14] ; Tensor B [32, 1, 1]
将 B 转化为 A :先扩展维度,再扩展维度中的值
[32, 1, 1] -> [1, 32, 1, 1] -> [4, 32, 14, 14]