I'm not sure what I'm doing wrong here, or whether this is a bug in PyTorch. I'm trying to predict one of several classes – five in this case – but one of them, class 0, dominates all the others. It's essentially the background class and we're not very interested in it, so I want to use the weight argument of the cross-entropy function to emphasise the other four classes.
I've been reading the documentation at https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html?highlight=crossentropyloss
It says the first argument should be a tensor of shape (minibatch, C, d1, d2, ..., dK).
C is clearly the number of classes – 5 in my case, so class indices 0 to 4.
Now, according to the docs, "weight (Tensor, optional) – a manual rescaling weight given to each class. If given, has to be a Tensor of size C" – so size 5 here.
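As a quick illustration of that documented contract (this sketch is my own, not the question's code, with made-up small shapes): for C = 5 classes, CrossEntropyLoss expects a weight tensor of exactly 5 elements.

```python
import torch
from torch import nn

# Sketch of the documented contract: for C = 5 classes the weight
# tensor must have exactly C elements, one per class.
num_classes = 5
weights = torch.tensor([0.1, 1.0, 1.0, 1.0, 1.0])  # size C = 5
criterion = nn.CrossEntropyLoss(weight=weights)

logits = torch.randn(2, num_classes, 4, 4)         # (minibatch, C, d1, d2)
labels = torch.randint(0, num_classes, (2, 4, 4))  # class indices 0..4
loss = criterion(logits, labels)
print(loss.dim())  # 0 -- a scalar loss
```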
Here is the code I have so far:
def loss_func(result, target) -> torch.Tensor:
    print(result.shape, target.shape)
    class_weights = torch.tensor([0.1, 1.0, 1.0, 1,0, 1.0], dtype=torch.float16, device=result.device)
    criterion = nn.CrossEntropyLoss(weight=class_weights)
    dense = target.to_dense().long().to(result.device)
    loss = criterion(result, dense)
    return loss
This is the output I get:
torch.Size([2, 5, 26, 150, 320]) torch.Size([2, 26, 150, 320])
Traceback (most recent call last):
File "/home/oni/Projects/PhD/sea_elegance_multi/train_unet.py", line 213, in <module>
train(args, model, train_data, test_data, optimiser, writer)
File "/home/oni/Projects/PhD/sea_elegance_multi/train_unet.py", line 118, in train
loss = loss_func(result, target_mask)
File "/home/oni/Projects/PhD/sea_elegance_multi/train_unet.py", line 61, in loss_func
loss = criterion(result, dense)
File "/home/oni/.conda/envs/seaelegance/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
return forward_call(*input, **kwargs)
File "/home/oni/.conda/envs/seaelegance/lib/python3.9/site-packages/torch/nn/modules/loss.py", line 1120, in forward
return F.cross_entropy(input, target, weight=self.weight,
File "/home/oni/.conda/envs/seaelegance/lib/python3.9/site-packages/torch/nn/functional.py", line 2824, in cross_entropy
return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
RuntimeError: weight tensor should be defined either for all or no classes at /pytorch/aten/src/THCUNN/generic/SpatialClassNLLCriterion.cu:27
Now, if I change the weight tensor to [0.1, 1.0, 1.0, 1.0] it works fine – or rather, I then get NaN and have to add ignore_index=0.
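For context, ignore_index excludes targets with that label from the loss entirely. A small sketch (mine, with float32 and made-up shapes, not the question's code) of the four-weight workaround described above:

```python
import torch
from torch import nn

# Hedged sketch of the workaround: 4 weights plus ignore_index=0, so
# background voxels (class 0) are excluded from the loss entirely.
criterion = nn.CrossEntropyLoss(
    weight=torch.tensor([0.1, 1.0, 1.0, 1.0]),  # 4 entries, classes 0..3
    ignore_index=0,
)
logits = torch.randn(2, 4, 8, 8)                # (N, C, H, W)
labels = torch.ones(2, 8, 8, dtype=torch.long)  # every voxel is class 1
loss = criterion(logits, labels)
print(bool(torch.isfinite(loss)))  # True
```

Note that if every target equals the ignored index, the mean reduction divides by zero and produces NaN anyway, so at least some non-ignored voxels must be present.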
I'm using a batch size of 2 with 3D images of 320x150x26 voxels, so [2, 5, 26, 150, 320] looks correct to me. I'm wondering whether I'm missing something, or whether this is a bug? I'm outputting float16, which sometimes causes NaN, but the "weight tensor should be defined either for all or no classes" error looks like a bug to me.
Has anyone else run into this? Or do I just need to upgrade my PyTorch setup? I'm running 1.9.0.
Here is a minimal reproducible example:
import torch
from torch import nn
def loss_func_works(result, target) -> torch.Tensor:
    print(result.shape, target.shape)
    class_weights = torch.tensor([1.0, 1.0, 1,0, 1.0], dtype=torch.float32, device=result.device)
    criterion = nn.CrossEntropyLoss(weight=class_weights)
    dense = target.long().to(result.device)
    loss = criterion(result, dense)
    return loss

def loss_func_fails(result, target) -> torch.Tensor:
    print(result.shape, target.shape)
    class_weights = torch.tensor([0.1, 1.0, 1.0, 1,0, 1.0], dtype=torch.float32, device=result.device)
    criterion = nn.CrossEntropyLoss(weight=class_weights)
    dense = target.long().to(result.device)
    loss = criterion(result, dense)
    return loss

if __name__ == "__main__":
    result = torch.tensor((), dtype=torch.float32)
    result = result.new_ones((2, 5, 26, 150, 320))
    target = torch.tensor((), dtype=torch.float32)
    target = target.new_ones((2, 26, 150, 320))
    loss_func_works(result, target)
    loss_func_fails(result, target)
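One defensive pattern when debugging errors like the one above (my own sketch, not from the question) is to assert the weight tensor's element count against C before constructing the criterion, so a size mismatch fails up front with an explicit message rather than deep inside the CUDA kernel:

```python
import torch
from torch import nn

NUM_CLASSES = 5

# Hypothetical guard: verify the weight tensor size up front so any
# mismatch fails here with a clear message, not inside the loss kernel.
class_weights = torch.tensor([0.1, 1.0, 1.0, 1.0, 1.0], dtype=torch.float32)
assert class_weights.numel() == NUM_CLASSES, (
    f"expected {NUM_CLASSES} class weights, got {class_weights.numel()}"
)
criterion = nn.CrossEntropyLoss(weight=class_weights)
print(class_weights.numel())  # 5
```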