tf.gather_nd 函数对应的pytorch函数
- 1. 简单介绍
- 2. 步入正题
- 2.1 tensorflow tf.gather_nd()
- 2.2 pytorch框架手动实现gather_nd()函数
- 3. 重点之处==pytorch实现== tf.gather_nd()函数
- 4.上文中第二节中 tuple_tensor()函数
- 总结
1. 简单介绍
从一开始学习的是【python】中的pytorch框架,但是最近用到的一个深度学习的网络模块,并没有找到pytorch框架实现的代码,只能从tensorflow框架的代码中转化而来。
注意的是: 在tensorflow框架中的函数tf.gather_nd()函数,在pytorch框架中并没有与之相对应的,通过各种途径学习,现在找到一个适合自己的函数去代替。可能有的地方不适用,其他情况大家可以参考下面的链接对函数进行改动。
参考学习链接:https://discuss.pytorch.org/t/how-to-do-the-tf-gather-nd-in-pytorch/6445/26
2. 步入正题
2.1 tensorflow tf.gather_nd()
tf.gather_nd()函数介绍:
tf.gather_nd(
params,
indices,
name=None
)
- 作用:将params索引为indices指定形状的切片数组中(indices代表索引后的数组形状)
- indices将切片定义为params的前N个维度,其中N = indices.shape [-1]
- 通常要求indices.shape[-1] <= params.rank(可以用np.ndim(params)查看)
- 如果等号成立是在索引具体元素
- 如果等号不成立是在沿params的indices.shape[-1]轴进行切片
- 返回维度: indices.shape[:-1] + params.shape[indices.shape[-1]:]
- 前面的indices.shape[:-1]代表索引后的指定形状
想要更加详细的介绍,大家可以参考博客:
在这里不做详细介绍了,直接上代码:
# tensorflow
x = np.random.rand(1, 5, 5)
out = np.random.rand(1, 5, 5)
map_lambda1 = tf.exp(tf.divide(tf.subtract(2.0, out), tf.add(1.0, out)))
band_index = tf.reduce_all([x <= 0.5, x >= -0.5], axis=0)
band = tf.where(band_index)
print('band ',[band])
print('\nmap_lambda1 ',map_lambda1)
lambda1 = tf.gather_nd(map_lambda1, [band])
print('\nlambda1 ',lambda1)
输出结果对比如下:
2.2 pytorch框架手动实现gather_nd()函数
代码:
# pytorch
x = np.random.rand(1, 5, 5)
out = np.random.rand(1, 5, 5)
torch_out = torch.from_numpy(out)
torch_map1 = torch.exp(torch.divide(torch.subtract(torch.full(torch_out.size(), 2.0), torch_out), torch.add(1.0, torch_out)))
torch_bandidx = torch.all((torch.tensor(x) <= 0.5) & (torch.tensor(x) >= -0.5), dim=0)
torch_band = torch.where(torch_bandidx)
torch_band = tuple_tensor(torch_band)
torch_band = torch_band[np.newaxis,:,:]
print("torch_band ",torch_band)
print("torch_map1 ", torch_map1)
torch_lambda1 = gather_nd(torch_map1, torch_band)
torch_lambda2 = gather_nd(torch_map2, torch_band)
print(torch_lambda1)
输出结果:
3. 重点之处pytorch实现 tf.gather_nd()函数
本文文章的重点之处,写到最后。
def gather_nd(params, indices):
'''
4D example
params: tensor shaped [n_1, n_2, n_3, n_4] --> 4 dimensional
indices: tensor shaped [m_1, m_2, m_3, m_4, 4] --> multidimensional list of 4D indices
returns: tensor shaped [m_1, m_2, m_3, m_4]
ND_example
params: tensor shaped [n_1, ..., n_p] --> d-dimensional tensor
indices: tensor shaped [m_1, ..., m_i, d] --> multidimensional list of d-dimensional indices
returns: tensor shaped [m_1, ..., m_1]
'''
out_shape = indices.shape[:-1]
indices = indices.unsqueeze(0).transpose(0, -1) # roll last axis to fring
ndim = indices.shape[0]
indices = indices.long()
idx = torch.zeros_like(indices[0], device=indices.device).long()
m = 1
for i in range(ndim)[::-1]:
idx += indices[i] * m
m *= params.size(i)
out = torch.take(params, idx)
return out.view(out_shape)
4.上文中第二节中 tuple_tensor()函数
如果想要知道为何要添加这样一个函数,请大家参考我的另外一一篇博客,这个里面讲到了tf.where()函数实现结果和torch.where()函数输出在结果上面的不同,进而需要这样一个tuple_tensor()函数进行一下处理。
def tuple_tensor(cant):
zlist = []
if len(cant) == 2:
for i, j in zip(cant[0], cant[1]):
data = torch.tensor([i, j])
zlist.append(data)
if len(cant) == 3:
for i, j, k in zip(cant[0], cant[1], cant[2]):
data = torch.tensor([i, j, k])
zlist.append(data)
out = torch.stack(zlist)
return out
总结
tensorflow一步步转化为pytorch路程比较艰难,仅以自己走过的弯路给大家一些借鉴。