基本数据类型Tensor
tensor是pytorch中的特殊数据类型,和numpy的数组类似,不过tensor可以在GPU和其他加速器上运行。
初始化tensor
- 直接从数据初始化
data = [xxx]
x_data = torch.tensor(data)
- 从numpy数据初始化
x_data = torch.from_numpy(np.array())
- 从其他tensor初始化,torch.*_like()
新的tensor保留原tensor的形状,数据可以改变,参数为(shape, datatype)
x_dada = torch.ones_like()
x_dada = torch.rand_like()
- torch.rand、ones、zeros创建
- torch.empty(shape)未初始化的,数据是随机的。
- torch.FloatTensor(shape或现有数据)等大写的初始化方式。
- torch.normal(u,std)根据高斯分布取随机数。
- torch.full([shape], x)全部设置成x的矩阵。
- torch.arange(x,y,step) 不包含右边界,step是步长
- torch.linspace(x,y,n)n是生成等差数列的个数,默认包含右边界。
- torch.logspace(x,y,steps,base)steps是生成个数,默认底数是10,可以设置成2,e。
- torch.randperm()生成随机数列并打散。
tensor维度
- data.dim()查看tensor的维度
- data.size()或data.shape查看tensor的维度,返回torch.Size([x,x])
- data.numel()返回tensor占的内存大小
tensor索引
- 切片索引
一个 : 全部取
start : end : step,不包含end,步长;end缺省默认到最后。 - a.index_select(索引,索引的索引)
- a[…]表示取所有可能的维度,此时右边的索引是最右边的索引。
- mask_select(x,mask),mask是一定条件下的掩码,返回的tensor会被打平。
- torch.take(src,[index])先打平再返回给定索引的tensor
维度变换
- view / reshape,两者参数一样,效果一样,传入想要的不改变矩阵大小的参数即可。维度信息会丢失。
- squeeze、unsqueeze unsqueeze插入位置
squeeze(要挤压维度的位置)
- 维度扩展expand和repeat
expand是广播形式的扩展,不复制数据,所以效率较高,而repeat复制了数据。expend扩展时只能是维度为1的扩展成相应的维度,如果不想扩展某个维度,那个维度的参数写成-1就行。
repeat的参数是每个维度要重复复制的次数,如果不改变某个维度,则应重复1次。
转置
- a.t()只能用于2D的tensor,也就是只能用于矩阵。
- transpose(),参数是要转换的两个维度,这里要注意转置后,数据的存储顺序已经改变了,因此要记录好原数据的顺序,不然返回原维度回出错
可以用a.transpose().contiguous()连续化,然后再改变维度。 - permute()参数是每个位置要放的原来的数据的维度。
broadcasting 自动扩展
如果认为定义最右边是最小维度,最左边是最高维度,则扩展时应满足小维度相同或小维度是1,否侧不能扩展。
合并与分割
- cat 合并操作参数为([a,b],dim),参数分别为合并的tensor的列表,在哪个维度上合并。要合并的维度可以不一样,但其他维度必须一样。
- stack 合并操作,参数和cat相同,但会创建一个新的维度,合并的tensor维度必须一致。
- split 拆分操作,如果拆分的长度都一样,则直接给定拆分的长度和拆分的维度。如果拆分的维度不一样,则给一个拆分长度的列表和要拆分的维度。
- chunk 按数量来拆分,给定参数拆分的数量和拆分的维度即可。
数学运算
- 基本运算:+ - * /
加减乘除可以直接用重载的运算符+ - * / ,其中整除是 //;也可以用函数:add,sub,mul,div - 矩阵相乘:torch.mm(a,b),torch.matmul(a,b), @
torch.mm和@适合于二维的矩阵相乘;更高维矩阵的相乘要用torch.matmul(),其实际是完成矩阵的并行计算,不参与计算前边的batch、channel的数据。 - 幂计算:pow(a,n) 或者a ** n; torch.exp() e的幂;torch.log2(),torch.log10()对数。
- 其他操作:torch.floor()下取整;torch.ceil()上取整;torch.trunc()取整数部分;torch.frac()取小数部分;torch.round()四舍五入。
- clamp操作,对数据进行裁剪。如果传入一个参数表示所有小于该数的值都变成该参数,如果传入一个范围表示将数据裁剪成该范围内的参数,如果大于改参数范围,就变成上届。
统计属性
- norm求范数,a.norm(n,dim),n表示范数类型L0,L1,L2…;dim表示在哪个维度上求范数。
- mean,max,min,sum,prod:分别求矩阵的平均值、最大值、最小值、求和、累乘。此外argmin(),argmax()返回最值的索引,如果不给最值得维度,求最值前会将矩阵打平。如果传入求最值得维度,会返回一个最值索引数列,代表每个的最值。
- dim和keepdim:dim就是维度,keepdim=True就是在返回最值索引时保持原矩阵的维度,如原来矩阵是410,如果求每行的最值就返回一个41的矩阵,维度还是2.
- topk和kthvalue:topk传入的参数是要返回前几的最值和要操作的维度,其中larges默认为True如果要返回最小值,则设为False就行了。 kthvalue返回的是第k小的值和索引。
- 比较操作:包括比较符号以及torch.eq()返回每个位置的元素是否相等,torch.equal()返回整体是否相等(True or False)。
高级操作
- where
torch.where(condition,x,y)返回一个tensor,其中condition是一个和x,y维度相同的条件矩阵,如果条件矩阵的某个位置满足条件就返回x对应位置的元素,否则返回y对应位置的元素。 - gather
torch.gather(input,dim,index,out=None) 返回一个tensor,具体就是根据索引返回input对应索引位置的元素。其实就是一个查表的过程。例如一个分类任务要分类的标签是[dog,cat,whale] 由于pytorch中没有string类型的数据,因此可以根据最终返回的索引tensor得到对应的分类结果。
prob = torch.randn(4,10)
idx = prob.topk(3,dim=1)
idx = idx[1]
label = torch.arange(10) + 100
torch.gather(label.expand(4,10),dim=1,index=idx)
out为: