BUG
使用Cross_entropy损失函数时出现 RuntimeError: multi-target not supported at …
可能存在的问题
1)其标签必须为0~n-1,而且必须为1维的,如果设置标签为[nx1]的,则也会出现以上错误。
2)标签y打印:
tensor([[1, 0], [1, 0], [1, 0], [1, 0], [1, 0], [1, 0], [1, 0], [1, 0]], device='cuda:0')
对于n.CrossEntroyLoss,目标必须是间隔[0,#class]的单个数字,而不是一个热编码的目标向量。您的目标是[1,0],因此PyTorch认为您希望每个输入都有多个不受支持的标签。
替换您的one-hot编码标签:
[1, 0] --> 0
[0, 1] --> 1
输入以下代码可解决上述问题。
y = torch.argmax(y, dim=1)