目录

  • 1、数据类型
  • 2、维度变换
  • view/reshape
  • Squeese/unsqueeze
  • Expand/repeat
  • permute
  • 3、Broadcast
  • 什么时候用broadcast
  • 4、拼接和拆分
  • cat
  • stack
  • split
  • chunk
  • 5、数学运算
  • 基本运算(四则)
  • 矩阵相乘 matmul
  • power
  • 近似值
  • clamp
  • 6、统计属性
  • norm 范数
  • mean,sum,min,max,prod
  • dim,keepdim
  • Top-k
  • 比较
  • 7、进阶操作
  • where
  • gather


1、数据类型

在python中的各种数据类型都用Tensor进行概括:

tensor 如何索引 tensor常用操作_标量


对于string类型,pytorch中要计算string类型的数据,需要先将其也转化为可以处理的Tensor类型

  • one-hot 编码
    [0,1,0,0],[1,0,0,0]…
  • Embedding
    word2vec,glove

pytorch中的数据类型如下所列:

tensor 如何索引 tensor常用操作_tensor 如何索引_02


在程序中可以用.type()和isinstance()检验类型

tensor 如何索引 tensor常用操作_矩阵相乘_03


注意部署在CPU和GPU上是不一样的

tensor 如何索引 tensor常用操作_数据类型_04


注意pytorch里标量是0维的,生成方法如下:

tensor 如何索引 tensor常用操作_标量_05


注意pytorch中标量的shape和size都是空数组,长度为0

tensor 如何索引 tensor常用操作_标量_06


注意和dim为1的张量作区分:

tensor 如何索引 tensor常用操作_标量_07


还可以从numpy中转化得到Tensor

tensor 如何索引 tensor常用操作_标量_08


多维的情况(Dim=3)

tensor 如何索引 tensor常用操作_tensor 如何索引_09

2、维度变换

tensor 如何索引 tensor常用操作_标量_10

view/reshape

view是之前版本的api,与reshape完全一致

举例,现在有一个四维的Tensor。在mnist数据集中,它可以代表4张图片(batch size)灰度信息,尺寸是28*28

a=torch.rand(4,1,28,28)

#view函数 要满足prod相等,注意要表示正确的实际意义意义
#此操作的意义是对每一张图片,直接784个数字作为一维,忽略了二维位置信息,适用于全连接层
a.view(4,28*28)
#此操作看成四个二维数组
a.view(4*1,28,28)

Squeese/unsqueeze

unsqueeze 可以添加更高维度。参数为非负数的话在之前插入维度,负数的话在索引之后插入

tensor 如何索引 tensor常用操作_tensor 如何索引_11

tensor 如何索引 tensor常用操作_标量_12


举例:

tensor 如何索引 tensor常用操作_数据类型_13

当维度不同的Tensor相加的时候,要先用unsqueeze进行维度展开,然后将各个维度大小进行变换后相加:

tensor 如何索引 tensor常用操作_tensor 如何索引_14


squeeze:删减维度,无参数则挤压掉所有可以挤压的维度(dim size=1的维度)给出索引则挤压掉指定维度。如果输入了不能挤压的维度,不会报错,但是Tensor不变

tensor 如何索引 tensor常用操作_数据类型_15

Expand/repeat

比如现在有维度为[32]和[4,32,14,14]的两个Tensor,可以先用unsqueeze将维度扩展为[1,32,1,1]用expand就可以进行维度大小扩展(重复值,但是不重新分配内存)

tensor 如何索引 tensor常用操作_矩阵相乘_16


如果参数是-1 则维度不变 。如果输入一个除了-1的负数,维度会变成这个负数(一般不这样用)

repeat传入的参数为每一维要拷贝的次数

tensor 如何索引 tensor常用操作_数据类型_17

permute

可以交换不同的维度,参数是原来的维度索引

tensor 如何索引 tensor常用操作_tensor 如何索引_18

tensor 如何索引 tensor常用操作_数据类型_19

3、Broadcast

Broadcast总是在“大维度”上进行自动扩张,可以认为左边的维度是大维度。

实际问题:有一个Feature map:[4,32,14,14],分别代表batch size、通道数、长、宽

需要加上一个偏置[32,1,1]

tensor 如何索引 tensor常用操作_数据类型_20


相当于unsqueeze+expand

tensor 如何索引 tensor常用操作_矩阵相乘_21


实际中,可能要在一个高维Tensor上加上一个标量,就用到了broadcasting

tensor 如何索引 tensor常用操作_标量_22


另外用broadcast可以节省内存

什么时候用broadcast

低维度要么是1(可以自动扩展相加),要么和被加Tensor的低维度匹配

比如A [4,32,8] B为标量,可以用broadcast机制相加。先将低维扩展成维度=8,再扩展出两个高维。如果B为Tensor,维数大小与A的低维大小相同,也可以自动扩展高维之后相加,比如B为[1,8]

情形1

tensor 如何索引 tensor常用操作_标量_23


情形2

在每张图片的每个通道都叠加一个二维Tensor

tensor 如何索引 tensor常用操作_数据类型_24


不可用情形

高维只给了两张的信息,操作无法完成。可以用B[n]举出某一张的Tensor然后相加

tensor 如何索引 tensor常用操作_标量_25

4、拼接和拆分

tensor 如何索引 tensor常用操作_矩阵相乘_26

cat

假设有两份成绩单,一份是1-4班的成绩单,一份是5-9班的成绩单,成绩单Tensor有三个维度,分别代表班级、学生和课程。

tensor 如何索引 tensor常用操作_矩阵相乘_27


现在要将此两个Tensor拼接在一起,就可以用cat,传入两个Tensor以及合并的维度

tensor 如何索引 tensor常用操作_tensor 如何索引_28


举二维Tensor的例子,在dim=0上拼接,即按照行来拼接:[4,4] [4,4] [4,4] 得到[12,4]

tensor 如何索引 tensor常用操作_标量_29


在dim=1上拼接,[4,4] [4,3]就得到[4,7]

tensor 如何索引 tensor常用操作_数据类型_30


注意在cat的时候除了拼接的维度,其他维度的size要一样

stack

在拼接的时候创建新维度

对比:

tensor 如何索引 tensor常用操作_tensor 如何索引_31


比如有两张表,都是328,如果用cat会合成一张648的大表,但是有时候我们想要分开存放,便于调用,这样就要用stack,创建了一个新的维度(班级),便于调用管理。

tensor 如何索引 tensor常用操作_数据类型_32


注意用stack的话,除了生成的新维度,其他维度都要相同。

split

参数格式1:切分后每个单元的长度(如果是一个数代表每个的长度都是这么多;如果是一个列表则分别代表每个拆分后的维度大小)+维度索引

拆分之前:[2,32,8]

tensor 如何索引 tensor常用操作_标量_33

tensor 如何索引 tensor常用操作_标量_34

chunk

按照数量来拆分

tensor 如何索引 tensor常用操作_数据类型_35

5、数学运算

tensor 如何索引 tensor常用操作_矩阵相乘_36

基本运算(四则)

下面的add用到了广播机制

tensor 如何索引 tensor常用操作_标量_37


sub,mul,div 操作同理

tensor 如何索引 tensor常用操作_tensor 如何索引_38

矩阵相乘 matmul

tensor 如何索引 tensor常用操作_数据类型_39


一般就使用matmul

tensor 如何索引 tensor常用操作_数据类型_40


实例:降低某一维度的长度:

乘以一个[784,512]矩阵

可以将[4,784]->[512,784]

这里反着写是因为pytorch约定chanel-out chanel-in的顺序,后面进行矩阵相乘的时候用.t()转置一下

tensor 如何索引 tensor常用操作_tensor 如何索引_41


二维以上的矩阵相乘,只对后面两维作相乘运算

tensor 如何索引 tensor常用操作_矩阵相乘_42


tensor 如何索引 tensor常用操作_标量_43


如果之前的维数不一样,由于broadcast机制可以自动扩展相乘

tensor 如何索引 tensor常用操作_数据类型_44


broadcast在0维扩展,如果无法用broadcast扩展,则会报错

tensor 如何索引 tensor常用操作_数据类型_45

power

接收矩阵和每个元素的pow

tensor 如何索引 tensor常用操作_tensor 如何索引_46


其他操作同理 rsqrt是平方根的倒数

tensor 如何索引 tensor常用操作_标量_47


exp和log:

tensor 如何索引 tensor常用操作_tensor 如何索引_48

近似值

取下、取上、四舍五入、取整、取小数

tensor 如何索引 tensor常用操作_数据类型_49

clamp

如果只有一个参数,限定最小值

如果两个参数,限定最小值和最大值

tensor 如何索引 tensor常用操作_标量_50

6、统计属性

norm 范数

第一范数是和 第二范数是平方和开根号

tensor 如何索引 tensor常用操作_标量_51

mean,sum,min,max,prod

注意argmax和argmin返回的是索引。且是变成vector后的索引

tensor 如何索引 tensor常用操作_数据类型_52


在用argmax的时候可以输入维度索引:

tensor 如何索引 tensor常用操作_矩阵相乘_53

dim,keepdim

统计信息会消除dimension,用keepdim可以避免消除dimension

tensor 如何索引 tensor常用操作_标量_54

Top-k

取最大的前k个,同样可以用dim指定维度

将largest设置为false(默认为true)可以求前k小的

kthvalue返回第k小的值和索引

tensor 如何索引 tensor常用操作_数据类型_55

比较

可以对Tensor的每个元素进行比较,返回与原Tensor维度相同的0-1Tensor作为结果

注意eq和equal的区别,前者逐个比较Tensor元素,后者比较Tensor整体

tensor 如何索引 tensor常用操作_tensor 如何索引_56

7、进阶操作

tensor 如何索引 tensor常用操作_tensor 如何索引_57

where

参数:条件+原数据A+原数据B

条件是一个和A和B维度相同的Tensor,0代表来自ATensor该位置的元素,1代表来自BTensor该位置的元素

tensor 如何索引 tensor常用操作_标量_58


例子:使一个Tensor中大于0.5的数字取0,小于等于0.5的数字取1:

tensor 如何索引 tensor常用操作_tensor 如何索引_59


为什么要用where而不用for循环一个个比较,是因为后者完全使用cpu,用where是用的并行运算,用的GPU,速度会更快。

gather

tensor 如何索引 tensor常用操作_tensor 如何索引_60


根据索引表,从一个表中采集不同的元素

tensor 如何索引 tensor常用操作_tensor 如何索引_61


用gather进行查表操作用法如下

tensor 如何索引 tensor常用操作_数据类型_62


例子:

tensor 如何索引 tensor常用操作_tensor 如何索引_63