标量类型应为Long,但在PyTorch中找到Float,使用nn.CrossEntropyLoss()命令

2024-07-02 04:52:36 发布

您现在位置:Python中文网/ 问答频道 /正文

我想做一个标签火炬张量。我选择了两种不同的方法,第一种方法在使用^{计算损失时出错。我想知道为什么会发生这种情况,尽管张量结果是一样的

第一种方法:

labels = torch.hstack((torch.zeros(100),torch.ones(100),1+torch.ones(100)))

第二种方法:

labels_np = np.vstack((np.zeros((100,1)),np.ones((100,1)),1+np.ones((100,1))))
labels = torch.squeeze(torch.tensor(labels_np).long())

错误

expected scalar type Long but found Float in Pytoch

Tags: 方法labelsnponeszeros情况torch标签