文章目录

  • 需要注意的地方
  • 返回的都是 list of view
  • 数据的连续性
  • torch.chunk
  • 函数原型
  • 例程
  • torch.tensor_split
  • 函数原型
  • 例程
  • torch.split
  • 函数原型
  • 例程
  • torch.unbind
  • 函数原型
  • 例程
  • dsplit、vsplit、hsplit
  • 函数原型
  • 例程


Function

Description

Detail

chunk

Attempts to split a tensor into the specified number of chunks.

按指定数量分割张量

tensor_split

Splits a tensor into multiple sub-tensors, all of which are views of input, along dimension dim according to the indices or number of sections specified by indices_or_sections.

按指定引索分割张量

split

Splits the tensor into chunks.

分割张量

unbind

Removes a tensor dimension.

对张量进行解耦操作

dsplit

Splits input, a tensor with three or more dimensions, into multiple tensors depthwise according to indices_or_sections.

按深度方向分割张量

hsplit

Splits input, a tensor with one or more dimensions, into multiple tensors horizontally according to indices_or_sections.

按水平方向分割张量

vsplit

Splits input, a tensor with two or more dimensions, into multiple tensors vertically according to indices_or_sections.

按垂直方向分割张量

需要注意的地方

返回的都是 list of view

在本章中提到的函数,返回的结果都是原张量的view,换句话说如果修改了结果的值,会导致原本的张量被修改。所以如果要想子张量和原张量数据隔离,可以使用

tensor.clone().detach()

torch.clone(tensor)

的方式,创健一个备份。

数据的连续性

另外就是关于数据切分的方式,是按照顺序的形式进行拆分,比如说有一个一维的数据,

[[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12], …]

现在想把它们拆分成3组,那么这个时候的顺序就是这样的了:

第一次拆分时,先按照某个维度区分出3组(例如维度dim=1)

[1, 2], [3, 4], [5, 6]

然后发现还有数据,于是:

[[1, 2], [7, 8], …],
[[3, 4], [9, 10], …],
[[5, 6], [11, 12], …]

直到全部数据都切分完毕,这就是最终的数据形式。

torch.chunk

如果说数据可以通过cat粘合,那么chunk就可以把tensor按维度方向进行分割。我们来看看这个函数原型:

函数原型

torch.chunk(input, chunks, dim=0) → List of Tensors

例程

为了更好说明这个函数是怎么使用的,不如直接看看它的执行结果如何。

>>> tensor = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9]).reshape(3, 3)
>>> tensor
tensor([[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]])
        
>>> list_of_tensors = torch.chunk(tensor, 3, dim=0)
>>> list_of_tensors
(tensor([[1, 2, 3]]), tensor([[4, 5, 6]]), tensor([[7, 8, 9]]))

>>> list_of_tensors = torch.chunk(tensor, 3, dim=1)
>>> list_of_tensors
(tensor([[1],
        [4],
        [7]]), tensor([[2],
        [5],
        [8]]), tensor([[3],
        [6],
        [9]]))

从这个例子来看,chunk命令,实际上是按照维度进行切分,但不是我们通常理解的按x轴、y轴或z轴这种定义。这在所有Numpy为基础的框架,都是一样的道理。为了更好理解这个问题,我们再看一个例子就好了。

COLUMNS = 5
ROWS = 10
DEPTH = 8

a = torch.arange(0, COLUMNS * ROWS * DEPTH).reshape(ROWS, COLUMNS, DEPTH)

ts = torch.chunk(a, 10, dim=0)
print(ts[0].shape)  # torch.Size([1, 5, 8])

ts = torch.chunk(a, 5, dim=0)
print(ts[0].shape)  # torch.Size([2, 5, 8])

ts = torch.chunk(a, 2, dim=2)
print(ts[0].shape)  # torch.Size([10, 5, 4])

chunks与所对应的dim应该能整除,如果不能整除,那么就返回一个自认为合适的划分,保证这组除最后一个,其他块的维度都是一样。为了更好说明这个情况,看下面这个例子:

COLUMNS = 5
ROWS = 10
DEPTH = 8

a = torch.arange(0, COLUMNS * ROWS * DEPTH).reshape(ROWS, COLUMNS, DEPTH)

ts = torch.chunk(a, 4, dim=1) # 5 / 4 无法整除,所以会返回函数认为最合适的划分
print(len(ts))  # 只返回了3个划分

for t in ts:
    print(t.shape) 
    # torch.Size([10, 2, 8])
	# torch.Size([10, 2, 8])
	# torch.Size([10, 1, 8])

torch.tensor_split

函数原型

torch.tensor_split(input, indices_or_sections, dim=0) → List of Tensors

它的作用和 chunk 很相似,我们来看看具体的代码吧。

例程

COLUMNS = 5
ROWS = 10
DEPTH = 8

a = torch.arange(0, COLUMNS * ROWS * DEPTH).reshape(ROWS, COLUMNS, DEPTH)

ts = torch.tensor_split(a, 4, dim=2)
print(len(ts)) # 8 / 2 = 4

for t in ts:
    print(t.shape)
    # torch.Size([10, 5, 2])
	# torch.Size([10, 5, 2])
	# torch.Size([10, 5, 2])
	# torch.Size([10, 5, 2])

和chunk最大的区别,在于如果某维度无法整除时,它会忠实的按照给定的维度进行划分,余数部分会被平分加入到列表的前几位。

COLUMNS = 5
ROWS = 10
DEPTH = 8

a = torch.arange(0, COLUMNS * ROWS * DEPTH).reshape(ROWS, COLUMNS, DEPTH)

ts = torch.tensor_split(a, 4, dim=1) # 5 / 4 无法整除,所以会返回函数认为最合适的划分
print(len(ts))  # 只返回了4个划分

for t in ts:
    print(t.shape) 
	# torch.Size([10, 2, 8])
	# torch.Size([10, 1, 8])
	# torch.Size([10, 1, 8])
	# torch.Size([10, 1, 8])

torch.split

函数原型

torch.split(tensor, split_size_or_sections, dim=0)

功能总体和上面提到的都很相似,只不过最大的区别在于它不是划分有多少块,而是指定每个view包含多少条数据,为了更好说明,我们依然来直接看看代码好了。

例程

COLUMNS = 5
ROWS = 10
DEPTH = 8

a = torch.arange(0, COLUMNS * ROWS * DEPTH).reshape(ROWS, COLUMNS, DEPTH)

ts = torch.split(a, 5, dim=0)
print(len(ts)) # 10条数据,按每个块包含5条,一共划分成了2块

for t in ts:
    print(t.shape)
	# torch.Size([5, 5, 8])
	# torch.Size([5, 5, 8])

那么如果维度大小不能整除时怎么办?

a = torch.arange(0, COLUMNS * ROWS * DEPTH).reshape(ROWS, COLUMNS, DEPTH)

ts = torch.split(a, 6, dim=0)
print(len(ts))  # 10条数据,按每个块包含6条,一共划分成了2块

for t in ts:
    print(t.shape)
	# torch.Size([6, 5, 8])
	# torch.Size([4, 5, 8])

可以看到,它和整除很像,会优先保证前面的块有足够的数据,最后的块往往不足这个数。

torch.unbind

函数原型

torch.unbind(input, dim=0) → seq

它的作用有一点像 torch.split 但是又有所不同的是,torch.split 会把所有的数据拆分成等分的全部压成一维的。而这个函数,并不会做那么具体,而是你指定某个维度,它直接就会把张量拆成一组张量。

例程

>>> torch.unbind(torch.tensor([[1, 2, 3],
>>>                            [4, 5, 6],
>>>                            [7, 8, 9]]))
(tensor([1, 2, 3]), tensor([4, 5, 6]), tensor([7, 8, 9]))

dsplit、vsplit、hsplit

这三个函数功能很相似,不过主要是针对三维数据进行划分的,其各自的函数原型如下:

函数原型

torch.dsplit(input, indices_or_sections) → List of Tensors
torch.vsplit(input, indices_or_sections) → List of Tensors
torch.hsplit(input, indices_or_sections) → List of Tensors

现在我们来看看例子吧:

例程

COLUMNS = 12
ROWS = 10
DEPTH = 8

a = torch.arange(0, COLUMNS * ROWS * DEPTH).reshape(ROWS, COLUMNS, DEPTH)

# 垂直方向划分,vsplit
vs = torch.vsplit(a, 2)
for s in vs:
    print(s.shape)
    # torch.Size([5, 12, 8])
	# torch.Size([5, 12, 8])

# 水平方向划分,hsplit
hs = torch.hsplit(a, 2)
for s in hs:
    print(s.shape)
   	# torch.Size([10, 6, 8])
	# torch.Size([10, 6, 8])

# 深度方向划分,dsplit
ds = torch.dsplit(a, 2)
for s in ds:
    print(s.shape)
    # torch.Size([10, 12, 4])
	# torch.Size([10, 12, 4])