目录
- 1、数据类型
- 2、维度变换
- view/reshape
- Squeese/unsqueeze
- Expand/repeat
- permute
- 3、Broadcast
- 什么时候用broadcast
- 4、拼接和拆分
- cat
- stack
- split
- chunk
- 5、数学运算
- 基本运算(四则)
- 矩阵相乘 matmul
- power
- 近似值
- clamp
- 6、统计属性
- norm 范数
- mean,sum,min,max,prod
- dim,keepdim
- Top-k
- 比较
- 7、进阶操作
- where
- gather
1、数据类型
在python中的各种数据类型都用Tensor进行概括:
对于string类型,pytorch中要计算string类型的数据,需要先将其也转化为可以处理的Tensor类型
- one-hot 编码
[0,1,0,0],[1,0,0,0]… - Embedding
word2vec,glove
pytorch中的数据类型如下所列:
在程序中可以用.type()和isinstance()检验类型
注意部署在CPU和GPU上是不一样的
注意pytorch里标量是0维的,生成方法如下:
注意pytorch中标量的shape和size都是空数组,长度为0
注意和dim为1的张量作区分:
还可以从numpy中转化得到Tensor
多维的情况(Dim=3)
2、维度变换
view/reshape
view是之前版本的api,与reshape完全一致
举例,现在有一个四维的Tensor。在mnist数据集中,它可以代表4张图片(batch size)灰度信息,尺寸是28*28
a=torch.rand(4,1,28,28)
#view函数 要满足prod相等,注意要表示正确的实际意义意义
#此操作的意义是对每一张图片,直接784个数字作为一维,忽略了二维位置信息,适用于全连接层
a.view(4,28*28)
#此操作看成四个二维数组
a.view(4*1,28,28)
Squeese/unsqueeze
unsqueeze 可以添加更高维度。参数为非负数的话在之前插入维度,负数的话在索引之后插入
举例:
当维度不同的Tensor相加的时候,要先用unsqueeze进行维度展开,然后将各个维度大小进行变换后相加:
squeeze:删减维度,无参数则挤压掉所有可以挤压的维度(dim size=1的维度)给出索引则挤压掉指定维度。如果输入了不能挤压的维度,不会报错,但是Tensor不变
Expand/repeat
比如现在有维度为[32]和[4,32,14,14]的两个Tensor,可以先用unsqueeze将维度扩展为[1,32,1,1]用expand就可以进行维度大小扩展(重复值,但是不重新分配内存)
如果参数是-1 则维度不变 。如果输入一个除了-1的负数,维度会变成这个负数(一般不这样用)
repeat传入的参数为每一维要拷贝的次数
permute
可以交换不同的维度,参数是原来的维度索引
3、Broadcast
Broadcast总是在“大维度”上进行自动扩张,可以认为左边的维度是大维度。
实际问题:有一个Feature map:[4,32,14,14],分别代表batch size、通道数、长、宽
需要加上一个偏置[32,1,1]
相当于unsqueeze+expand
实际中,可能要在一个高维Tensor上加上一个标量,就用到了broadcasting
另外用broadcast可以节省内存
什么时候用broadcast
低维度要么是1(可以自动扩展相加),要么和被加Tensor的低维度匹配
比如A [4,32,8] B为标量,可以用broadcast机制相加。先将低维扩展成维度=8,再扩展出两个高维。如果B为Tensor,维数大小与A的低维大小相同,也可以自动扩展高维之后相加,比如B为[1,8]
情形1
情形2
在每张图片的每个通道都叠加一个二维Tensor
不可用情形
高维只给了两张的信息,操作无法完成。可以用B[n]举出某一张的Tensor然后相加
4、拼接和拆分
cat
假设有两份成绩单,一份是1-4班的成绩单,一份是5-9班的成绩单,成绩单Tensor有三个维度,分别代表班级、学生和课程。
现在要将此两个Tensor拼接在一起,就可以用cat,传入两个Tensor以及合并的维度
举二维Tensor的例子,在dim=0上拼接,即按照行来拼接:[4,4] [4,4] [4,4] 得到[12,4]
在dim=1上拼接,[4,4] [4,3]就得到[4,7]
注意在cat的时候除了拼接的维度,其他维度的size要一样
stack
在拼接的时候创建新维度
对比:
比如有两张表,都是328,如果用cat会合成一张648的大表,但是有时候我们想要分开存放,便于调用,这样就要用stack,创建了一个新的维度(班级),便于调用管理。
注意用stack的话,除了生成的新维度,其他维度都要相同。
split
参数格式1:切分后每个单元的长度(如果是一个数代表每个的长度都是这么多;如果是一个列表则分别代表每个拆分后的维度大小)+维度索引
拆分之前:[2,32,8]
chunk
按照数量来拆分
5、数学运算
基本运算(四则)
下面的add用到了广播机制
sub,mul,div 操作同理
矩阵相乘 matmul
一般就使用matmul
实例:降低某一维度的长度:
乘以一个[784,512]矩阵
可以将[4,784]->[512,784]
这里反着写是因为pytorch约定chanel-out chanel-in的顺序,后面进行矩阵相乘的时候用.t()转置一下
二维以上的矩阵相乘,只对后面两维作相乘运算
如果之前的维数不一样,由于broadcast机制可以自动扩展相乘
broadcast在0维扩展,如果无法用broadcast扩展,则会报错
power
接收矩阵和每个元素的pow
其他操作同理 rsqrt是平方根的倒数
exp和log:
近似值
取下、取上、四舍五入、取整、取小数
clamp
如果只有一个参数,限定最小值
如果两个参数,限定最小值和最大值
6、统计属性
norm 范数
第一范数是和 第二范数是平方和开根号
mean,sum,min,max,prod
注意argmax和argmin返回的是索引。且是变成vector后的索引
在用argmax的时候可以输入维度索引:
dim,keepdim
统计信息会消除dimension,用keepdim可以避免消除dimension
Top-k
取最大的前k个,同样可以用dim指定维度
将largest设置为false(默认为true)可以求前k小的
kthvalue返回第k小的值和索引
比较
可以对Tensor的每个元素进行比较,返回与原Tensor维度相同的0-1Tensor作为结果
注意eq和equal的区别,前者逐个比较Tensor元素,后者比较Tensor整体
7、进阶操作
where
参数:条件+原数据A+原数据B
条件是一个和A和B维度相同的Tensor,0代表来自ATensor该位置的元素,1代表来自BTensor该位置的元素
例子:使一个Tensor中大于0.5的数字取0,小于等于0.5的数字取1:
为什么要用where而不用for循环一个个比较,是因为后者完全使用cpu,用where是用的并行运算,用的GPU,速度会更快。
gather
根据索引表,从一个表中采集不同的元素
用gather进行查表操作用法如下
例子: