PyTorch cross-entropy loss: weight tensor doesn't match the number of classes?

Posted 2024-05-03 12:10:13


I'm not quite sure what I'm doing wrong here, or whether this is a bug in PyTorch. I'm trying to predict a number of classes, 5 in this case, but one of them, class 0, dominates all the others. It's essentially the background class and we're not very interested in it, so I want to use the weight argument of the cross-entropy loss to emphasise the other 4 classes.

I was reading the documentation at https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html?highlight=crossentropyloss, which says the first argument is a tensor of shape (minibatch, C, d1, d2, ..., dK).

C is clearly the number of classes, which in my case is 5, so class labels 0 through 4.

Now, according to the docs: "weight (Tensor, optional) – a manual rescaling weight given to each class. If given, has to be a Tensor of size C" – so a tensor of size 5.
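For reference, here is that shape contract in miniature (toy sizes, just to illustrate): the input is (minibatch, C, d1, ..., dK), the target is (minibatch, d1, ..., dK), and the weight is a 1-D tensor with exactly C entries:

import torch
from torch import nn

# Toy sizes, just to illustrate the documented contract
N, C = 2, 5
logits = torch.randn(N, C, 4, 6, 8)         # (minibatch, C, d1, d2, d3)
labels = torch.randint(0, C, (N, 4, 6, 8))  # (minibatch, d1, d2, d3)
weight = torch.ones(C)                      # size C, i.e. 5
print(nn.CrossEntropyLoss(weight=weight)(logits, labels))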

Here is the code I have so far:

def loss_func(result, target) -> torch.Tensor:
    print(result.shape, target.shape)
    # weight per class, de-emphasising the dominant background class 0
    class_weights = torch.tensor([0.1, 1.0, 1.0, 1,0, 1.0], dtype=torch.float16, device=result.device)
    criterion = nn.CrossEntropyLoss(weight=class_weights)
    # the target arrives sparse, so convert to a dense LongTensor on the same device
    dense = target.to_dense().long().to(result.device)
    loss = criterion(result, dense)
    return loss

This is what I get:

torch.Size([2, 5, 26, 150, 320]) torch.Size([2, 26, 150, 320])
Traceback (most recent call last):
  File "/home/oni/Projects/PhD/sea_elegance_multi/train_unet.py", line 213, in <module>
    train(args, model, train_data, test_data, optimiser, writer)
  File "/home/oni/Projects/PhD/sea_elegance_multi/train_unet.py", line 118, in train
    loss = loss_func(result, target_mask)
  File "/home/oni/Projects/PhD/sea_elegance_multi/train_unet.py", line 61, in loss_func
    loss = criterion(result, dense)
  File "/home/oni/.conda/envs/seaelegance/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/oni/.conda/envs/seaelegance/lib/python3.9/site-packages/torch/nn/modules/loss.py", line 1120, in forward
    return F.cross_entropy(input, target, weight=self.weight,
  File "/home/oni/.conda/envs/seaelegance/lib/python3.9/site-packages/torch/nn/functional.py", line 2824, in cross_entropy
    return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
RuntimeError: weight tensor should be defined either for all or no classes at /pytorch/aten/src/THCUNN/generic/SpatialClassNLLCriterion.cu:27

Now, if I change the weight tensor to [0.1, 1.0, 1.0, 1.0], it works fine. Or rather, I then get NaN and have to add ignore_index=0.
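For context, ignore_index makes the loss skip every target equal to that class, both in the sum and in the averaging. A small sketch (toy sizes again):

import torch
from torch import nn

criterion = nn.CrossEntropyLoss(ignore_index=0)  # voxels labelled 0 contribute nothing
logits = torch.randn(2, 5, 3, 3)
labels = torch.randint(0, 5, (2, 3, 3))
print(criterion(logits, labels))
# caveat: if every target in the batch is the ignored class, the mean
# reduction divides by zero and the loss itself comes out as NaN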

I'm using a batch size of 2 and 3D images of 320x150x26 voxels, so [2, 5, 26, 150, 320] looks right to me. I'm wondering whether I'm missing something or whether this is a bug. I'm using float16 throughout, which can occasionally cause NaNs, but "weight tensor should be defined either for all or no classes" looks like a bug to me.
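If the NaNs really are a half-precision issue, one workaround I've been considering (an assumption on my part, not something I've verified) is to keep the model in float16 but compute the loss in float32:

def loss_func_fp32(result, target) -> torch.Tensor:
    # hypothetical variant: cast the logits up to float32 just for the loss,
    # since the log-softmax reduction can overflow or underflow in float16
    class_weights = torch.tensor([0.1, 1.0, 1.0, 1.0, 1.0], dtype=torch.float32, device=result.device)
    criterion = nn.CrossEntropyLoss(weight=class_weights)
    dense = target.to_dense().long().to(result.device)
    return criterion(result.float(), dense)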

Has anyone else run into this? Perhaps I just need to upgrade my PyTorch setup? I'm running 1.9.0.

Here is a minimal working example:

import torch
from torch import nn

def loss_func_works(result, target) -> torch.Tensor:
    print(result.shape, target.shape)
    # this weight tensor is accepted without complaint
    class_weights = torch.tensor([1.0, 1.0, 1,0, 1.0], dtype=torch.float32, device=result.device)
    criterion = nn.CrossEntropyLoss(weight=class_weights)
    dense = target.long().to(result.device)
    loss = criterion(result, dense)
    return loss


def loss_func_fails(result, target) -> torch.Tensor:
    print(result.shape, target.shape)
    # identical apart from the weight tensor, but this one raises the RuntimeError above
    class_weights = torch.tensor([0.1, 1.0, 1.0, 1,0, 1.0], dtype=torch.float32, device=result.device)
    criterion = nn.CrossEntropyLoss(weight=class_weights)
    dense = target.long().to(result.device)
    loss = criterion(result, dense)
    return loss


if __name__ == "__main__":
    # dummy logits of shape (N, C, d1, d2, d3) and targets of shape (N, d1, d2, d3)
    result = torch.ones((2, 5, 26, 150, 320), dtype=torch.float32)
    target = torch.ones((2, 26, 150, 320), dtype=torch.float32)

    loss_func_works(result, target)
    loss_func_fails(result, target)
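For anyone chasing the same error, a quick sanity check before building the criterion is to compare the weight tensor against the C dimension of the logits (a debugging sketch, the helper name is mine):

def check_weights(class_weights: torch.Tensor, result: torch.Tensor) -> None:
    # result is (N, C, d1, ..., dK); weight must be 1-D with exactly C entries
    num_classes = result.shape[1]
    if class_weights.ndim != 1 or class_weights.numel() != num_classes:
        raise ValueError(f"weight has shape {tuple(class_weights.shape)}, "
                         f"expected a 1-D tensor of size {num_classes}")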