Pytorch张量的拆分与拼接
预览
在 PyTorch 中,对张量 (Tensor) 进行拆分通常会用到两个函数:
而对张量 (Tensor) 进行拼接通常会用到另外两个函数:
1.张量的拆分
- torch.split函数
torch.split(tensor, split_size_or_sections, dim = 0)
按块大小拆分张量
tensor 为待拆分张量
dim 指定张量拆分的所在维度,即在第几维对张量进行拆分。dim=0是按照行拆分,dim=1是按照列拆分。如果是三维向量的话,可以按照dim=2在矩阵的方向上划分。
split_size_or_sections 表示在 dim 维度拆分张量时每一块在该维度的尺寸大小 (int),或各块尺寸大小的列表 (list)
指定每一块的尺寸大小后,如果在该维度无法整除,则最后一块会取余数,尺寸较小一些
如:长度为 10 的张量,按单位长度 3 拆分,则前三块长度为 3,最后一块长度为 1
函数返回:所有拆分后的张量所组成的 tuple
函数并不会改变原 tensor
import torch
X = torch.randn(6, 2)
Y=torch.split(X, 2, dim = 0)
#返回一个元组tutle
(tensor([[-0.0039, -0.1259],
[-0.7630, 1.3833]]), tensor([[-0.7960, 0.2523],
[-0.5351, -0.5850]]), tensor([[ 0.3403, -0.2898],
[-0.3122, -0.7490]]))
Y=torch.split(X, 4, dim = 0)
#除不尽的取余数
(tensor([[ 1.4674, 0.7185],
[ 0.4943, 1.4040],
[-1.5243, 0.0566],
[-1.2039, -0.3079]]), tensor([[-2.9470, -1.6064],
[-0.8393, -0.5528]]))
- torch.chunk 函数
torch.chunk(input, chunks, dim = 0)
按块数拆分张量
input 为待拆分张量
dim 指定张量拆分的所在维度,即在第几维对张量进行拆分
chunks 表示在 dim 维度拆分张量时最后所分出的总块数 (int),根据该块数进行平均拆分
指定总块数后,如果在该维度无法整除,则每块长度向上取整,最后一块会取余数,尺寸较小一些,若余数恰好为 0,则会只分出 chunks - 1 块
如:
长度为 6 的张量,按 4 块拆分,则只分出三块,长度为 2 (6 / 4 = 1.5 → 2)
长度为 10 的张量,按 4 块拆分,则前三块长度为 3 (10 / 4 = 2.5 → 3),最后一块长度为 1
函数返回:所有拆分后的张量所组成的 tuple
函数并不会改变原 input
In [1]: X = torch.randn(6, 2)
In [2]: X
Out[2]:
tensor([[-0.3711, 0.7372],
[ 0.2608, -0.1129],
[-0.2785, 0.1560],
[-0.7589, -0.8927],
[ 0.1480, -0.0371],
[-0.8387, 0.6233]])
In [3]: torch.chunk(X, 2, dim = 0)
Out[3]:
(tensor([[-0.3711, 0.7372],
[ 0.2608, -0.1129],
[-0.2785, 0.1560]]),
tensor([[-0.7589, -0.8927],
[ 0.1480, -0.0371],
[-0.8387, 0.6233]]))
In [4]: torch.chunk(X, 3, dim = 0)
Out[4]:
(tensor([[-0.3711, 0.7372],
[ 0.2608, -0.1129]]),
tensor([[-0.2785, 0.1560],
[-0.7589, -0.8927]]),
tensor([[ 0.1480, -0.0371],
[-0.8387, 0.6233]]))
In [5]: torch.chunk(X, 4, dim = 0)
Out[5]:
(tensor([[-0.3711, 0.7372],
[ 0.2608, -0.1129]]),
tensor([[-0.2785, 0.1560],
[-0.7589, -0.8927]]),
tensor([[ 0.1480, -0.0371],
[-0.8387, 0.6233]]))
In [6]: Y = torch.randn(10, 2)
In [6]: Y
Out[6]:
tensor([[-0.9749, 1.3103],
[-0.4138, -0.8369],
[-0.1138, -1.6984],
[ 0.7512, -0.3417],
[-1.4575, -0.4392],
[-0.2035, -0.2962],
[-0.7533, -0.8294],
[ 0.0104, -1.3582],
[-1.5781, 0.8594],
[ 0.0286, 0.7611]])
In [7]: torch.chunk(Y, 4, dim = 0)
Out[7]:
(tensor([[-0.9749, 1.3103],
[-0.4138, -0.8369],
[-0.1138, -1.6984]]),
tensor([[ 0.7512, -0.3417],
[-1.4575, -0.4392],
[-0.2035, -0.2962]]),
tensor([[-0.7533, -0.8294],
[ 0.0104, -1.3582],
[-1.5781, 0.8594]]),
tensor([[0.0286, 0.7611]]))
这个函数还是很好理解的
2.张量的合并
可以用torch.cat和torch.stack方法将多个张量合并,但是torch.cat仅仅是张量的连接,不会增加维度,而torch.stack是堆叠,会增加维度。
- cat方法
torch.cat(tensors, dim = 0, out = None)
在已有维度拼接张量
tensors 为待拼接张量的序列,通常为 tuple
dim 指定张量拼接的所在维度,即在第几维对张量进行拼接,除该拼接维度外,其余维度上待拼接张量的尺寸必须相同
out 表示在拼接张量的输出,也可直接使用函数返回值
函数返回:拼接后所得到的张量
函数并不会改变原 tensors
- stack方法
torch.stack(tensors, dim = 0, out = None)
在新维度拼接张量
tensors 为待拼接张量的序列,通常为 tuple
dim 指定张量拼接的新维度对应已有维度的插入索引,即在原来第几维的位置上插入新维度对张量进行拼接,待拼接张量在所有已有维度上的尺寸必须完全相同
out 表示在拼接张量的输出,也可直接使用函数返回值
函数返回:拼接后所得到的张量
函数并不会改变原 tensors
In [1]: x = torch.randn(2, 3)
In [2]: x
Out[2]:
tensor([[-0.0288, 0.6936, -0.6222],
[ 0.8786, -1.1464, -0.6486]])
In [3]: torch.stack((x, x, x), dim = 0)
Out[3]:
tensor([[[-0.0288, 0.6936, -0.6222],
[ 0.8786, -1.1464, -0.6486]],
[[-0.0288, 0.6936, -0.6222],
[ 0.8786, -1.1464, -0.6486]],
[[-0.0288, 0.6936, -0.6222],
[ 0.8786, -1.1464, -0.6486]]])
In [4]: torch.stack((x, x, x), dim = 0).shape
Out[4]: torch.Size([3, 2, 3])
In [5]: torch.stack((x, x, x), dim = 1)
Out[5]:
tensor([[[-0.0288, 0.6936, -0.6222],
[-0.0288, 0.6936, -0.6222],
[-0.0288, 0.6936, -0.6222]],
[[ 0.8786, -1.1464, -0.6486],
[ 0.8786, -1.1464, -0.6486],
[ 0.8786, -1.1464, -0.6486]]])
In [6]: torch.stack((x, x, x), dim = 1).shape
Out[6]: torch.Size([2, 3, 3])