import torch
# Demonstration of gathering along dim 1 (columns) of a 3x2 integer tensor.
input_tensor = torch.tensor(
    [
        [1, 2],
        [3, 4],
        [5, 6],
    ]
)

# Per-row column indices; same shape as the result that gather produces.
gather_input = torch.tensor(
    [
        [0, 0],
        [1, 0],
        [1, 1],
    ]
)

# With dim=1: output_tensor[i][j] == input_tensor[i][gather_input[i][j]].
output_tensor = input_tensor.gather(dim=1, index=gather_input)
print(output_tensor)

# Expected output:
# tensor([[1, 1],
#         [4, 3],
#         [6, 6]])