给tensor增加一维
b = a.unsqueeze(0)
import torch
a = torch.randn(3, 200, 200)
b = a.unsqueeze(0)
print(a.shape)
print(b.shape)
删除tensor一维
squeeze只能删除维度为1的某一维。若某个维度不为1,可以用切片取出该维度的一个数据,再用squeeze删除。
b = a.squeeze(0)
import torch
a = torch.randn(1, 3, 200, 200)
b = a.squeeze(0)
print(a.shape)
print(b.shape)