import torch

batch_size = 2
hidden_dim = 5


def scatter_fill(columns, value, *, batch_size=batch_size, hidden_dim=hidden_dim):
    """Return a zero float tensor of shape (batch_size, hidden_dim) with ``value``
    written at one column per row.

    Args:
        columns: sequence of column indices, one per row, in row order; entry
            ``i`` is the column of row ``i`` that receives ``value``.
        value: scalar written at each selected position.
        batch_size: number of rows (defaults to the module-level setting).
        hidden_dim: number of columns (defaults to the module-level setting).

    Returns:
        A new tensor from ``torch.zeros`` with the scattered entries filled in.
    """
    # scatter_ along the last dim needs an index with the same rank as the
    # target, so reshape the flat column list to (rows, 1): one write per row.
    index = torch.tensor(columns, dtype=torch.long).view(-1, 1)
    return torch.zeros(batch_size, hidden_dim).scatter_(dim=-1, index=index, value=value)


# Row 0 gets a 1 in column 2, row 1 gets a 1 in column 1.
x = scatter_fill([2, 1], value=1)

print(x)

# Same positions, different fill value.
x = scatter_fill([2, 1], value=2)

print(x)

# Row 1 now targets column 3 instead of column 1.
x = scatter_fill([2, 3], value=2)

print(x)

Printed output of the three `print(x)` calls above:

tensor([[0., 0., 1., 0., 0.],
        [0., 1., 0., 0., 0.]])

tensor([[0., 0., 2., 0., 0.],
        [0., 2., 0., 0., 0.]])

tensor([[0., 0., 2., 0., 0.],
        [0., 0., 0., 2., 0.]])