我有:
def loss_fn(self, pred, truth):
truth_flat = torch.reshape(truth, (truth.size(0),-1)).to(truth.device)
pred_flat = torch.reshape(pred, (pred.size(0),-1)).to(pred.device)
stoi_loss = NegSTOILoss(sample_rate=16000)(pred_flat, truth_flat)
print('truth', truth.size(), truth_flat.size(), stoi_loss)
return torch.nn.MSELoss()(pred, truth)
如您所见,我正在确保它位于同一台设备上,但仍然会出现错误:
Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu
有什么想法吗
您正在分配给两个不同的设备:
truth.device
和pred.device
相关问题 更多 >
编程相关推荐