前言

 在看FCOS算法源码时,发现获取正样本点用到了scatter这个函数,故记录下。

1、官方文档解释

  先贴出链接:​

Tensor.scatter_(dim, index, src, reduce=None) → Tensor

 接收三个参数: dim, index和src。该函数作用就是在dim维度上,根据index提供的索引,从src中提取对应元素来赋值给Tensor。 以下是官方给的一个三维张量例子。

Pytorch的scatter函数详解_深度学习

 需要注意两个点:index和src的dim维度数必须一样! 以官方3-D tensor为例,即self、src和index的维度均为3;若是2D-tensor则self、src和index的维度均为2。因为需要用index的元素作为索引,故index中元素的大小必须<self.size(d) 且 src.size(d)。

2、举个例子

Pytorch的scatter函数详解_python_02

附上代码:

src = torch.Tensor([[0,1,2,3,4],[5,6,7,8,9]])
self = torch.zeros((3,5))
index = torch.tensor([[0,1,2,0,0],[2,0,0,1,2]])
self.scatter_(dim=0,index=index,src=src)
print(self)

输出:

Pytorch的scatter函数详解_标量_03

总结

  在实际编程中,src往往是标量,即是个常数。根据定义,等式右边的src[i][j] 恒等于标量。即此时scatter函数作用就是根据index将self中对应位置变成常数即可。