两个张量在同一个设备上,但我得到了一个错误:预期所有张量都在同一个设备上,但至少找到了两个设备,cuda:0和cpu

2024-09-27 21:29:52 发布

您现在位置:Python中文网/ 问答频道 /正文

我有:

def loss_fn(self, pred, truth):        
    truth_flat = torch.reshape(truth, (truth.size(0),-1)).to(truth.device)
    pred_flat = torch.reshape(pred, (pred.size(0),-1)).to(pred.device)
    
    stoi_loss = NegSTOILoss(sample_rate=16000)(pred_flat, truth_flat)
    print('truth', truth.size(), truth_flat.size(), stoi_loss)

    return torch.nn.MSELoss()(pred, truth)

如您所见,我正在确保它位于同一台设备上,但仍然会出现错误:

Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu

有什么想法吗


Tags: tosampleselfsizedevicedeftorchfn

热门问题