tf.gather_nd 函数对应的pytorch函数

  • 1. 简单介绍
  • 2. 步入正题
  • 2.1 tensorflow tf.gather_nd()
  • 2.2 pytorch框架手动实现gather_nd()函数
  • 3. 重点之处==pytorch实现== tf.gather_nd()函数
  • 4.上文中第二节中 tuple_tensor()函数
  • 总结


1. 简单介绍

从一开始学习的是【python】中的pytorch框架,但是最近用到的一个深度学习的网络模块,并没有找到pytorch框架实现的代码,只能从tensorflow框架的代码中转化而来。
注意的是: 在tensorflow框架中的函数tf.gather_nd()函数,在pytorch框架中并没有与之相对应的,通过各种途径学习,现在找到一个适合自己的函数去代替。可能有的地方不适用,其他情况大家可以参考下面的链接对函数进行改动。

参考学习链接https://discuss.pytorch.org/t/how-to-do-the-tf-gather-nd-in-pytorch/6445/26

2. 步入正题

2.1 tensorflow tf.gather_nd()

tf.gather_nd()函数介绍:

tf.gather_nd(
params,
indices,
name=None
)

  • 作用:将params索引为indices指定形状的切片数组中(indices代表索引后的数组形状)
  • indices将切片定义为params的前N个维度,其中N = indices.shape [-1]
  • 通常要求indices.shape[-1] <= params.rank(可以用np.ndim(params)查看)
  • 如果等号成立是在索引具体元素
  • 如果等号不成立是在沿params的indices.shape[-1]轴进行切片
  • 返回维度: indices.shape[:-1] + params.shape[indices.shape[-1]:]
  • 前面的indices.shape[:-1]代表索引后的指定形状

想要更加详细的介绍,大家可以参考博客:

  1. https://www.w3cschool.cn/tensorflow_python/tensorflow_python-ctv72eru.html
  2. 知乎链接

在这里不做详细介绍了,直接上代码:

# tensorflow
x = np.random.rand(1, 5, 5)
out = np.random.rand(1, 5, 5)
map_lambda1 = tf.exp(tf.divide(tf.subtract(2.0, out), tf.add(1.0, out)))

band_index = tf.reduce_all([x <= 0.5, x >= -0.5], axis=0)
band = tf.where(band_index)
print('band ',[band])
print('\nmap_lambda1 ',map_lambda1)
lambda1 = tf.gather_nd(map_lambda1, [band])
print('\nlambda1 ',lambda1)

输出结果对比如下:

用pytorch实现harr特征 pytorch gather_nd_用pytorch实现harr特征

2.2 pytorch框架手动实现gather_nd()函数

代码:

# pytorch
x = np.random.rand(1, 5, 5)
out = np.random.rand(1, 5, 5)
torch_out = torch.from_numpy(out)
torch_map1 = torch.exp(torch.divide(torch.subtract(torch.full(torch_out.size(), 2.0), torch_out), torch.add(1.0, torch_out)))

torch_bandidx = torch.all((torch.tensor(x) <= 0.5) & (torch.tensor(x) >= -0.5), dim=0)
torch_band = torch.where(torch_bandidx)
torch_band = tuple_tensor(torch_band)
torch_band = torch_band[np.newaxis,:,:]
print("torch_band  ",torch_band)
print("torch_map1  ", torch_map1)
torch_lambda1 = gather_nd(torch_map1, torch_band)
torch_lambda2 = gather_nd(torch_map2, torch_band)
print(torch_lambda1)

输出结果:

用pytorch实现harr特征 pytorch gather_nd_深度学习_02

3. 重点之处pytorch实现 tf.gather_nd()函数

本文文章的重点之处,写到最后。

def gather_nd(params, indices):
    '''
    4D example
    params: tensor shaped [n_1, n_2, n_3, n_4] --> 4 dimensional
    indices: tensor shaped [m_1, m_2, m_3, m_4, 4] --> multidimensional list of 4D indices

    returns: tensor shaped [m_1, m_2, m_3, m_4]

    ND_example
    params: tensor shaped [n_1, ..., n_p] --> d-dimensional tensor
    indices: tensor shaped [m_1, ..., m_i, d] --> multidimensional list of d-dimensional indices

    returns: tensor shaped [m_1, ..., m_1]
    '''

    out_shape = indices.shape[:-1]
    indices = indices.unsqueeze(0).transpose(0, -1)  # roll last axis to fring
    ndim = indices.shape[0]
    indices = indices.long()
    idx = torch.zeros_like(indices[0], device=indices.device).long()
    m = 1

    for i in range(ndim)[::-1]:
        idx += indices[i] * m
        m *= params.size(i)

    out = torch.take(params, idx)

    return out.view(out_shape)

4.上文中第二节中 tuple_tensor()函数

如果想要知道为何要添加这样一个函数,请大家参考我的另外一一篇博客,这个里面讲到了tf.where()函数实现结果和torch.where()函数输出在结果上面的不同,进而需要这样一个tuple_tensor()函数进行一下处理。

def tuple_tensor(cant):
    zlist = []
    if len(cant) == 2:
        for i, j in zip(cant[0], cant[1]):
            data = torch.tensor([i, j])
            zlist.append(data)

    if len(cant) == 3:
        for i, j, k in zip(cant[0], cant[1], cant[2]):
            data = torch.tensor([i, j, k])
            zlist.append(data)

    out = torch.stack(zlist)
    return out

总结

tensorflow一步步转化为pytorch路程比较艰难,仅以自己走过的弯路给大家一些借鉴。