Are there any mistakes in my focal loss implementation?

Posted 2024-09-30 00:28:47


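import torch
import torch.nn.functional as F


def focalloss(input, target):
  """
  compute focal loss for multi-class classification
  :param input: logits not processed by softmax, in shape [batch, channel(classes)]
  :param target: class indices in shape [batch], long data type
  :return: focal loss (scalar, averaged over the batch)
  """

  # Class-balancing weight: alpha for class 1, (1 - alpha) for every other class.
  alpha = 0.5
  alpha_factor = torch.full(target.shape, alpha, dtype=input.dtype, device=input.device)
  alpha_factor = torch.where(torch.eq(target, 1), alpha_factor, 1. - alpha_factor)

  # Probability the model assigns to the true class of each sample.
  gamma = 2
  input_softmax = F.softmax(input, dim=1)
  index = target.unsqueeze(dim=1)
  # squeeze only dim 1 so a batch of size 1 keeps its batch dimension
  pred_weights = torch.gather(input_softmax, dim=1, index=index).squeeze(dim=1)
  focalweights = torch.ones_like(pred_weights) - pred_weights

  # Modulating factor alpha * (1 - p_t)^gamma; detached, so no gradient flows through it.
  focalweights = alpha_factor * torch.pow(focalweights, gamma)
  focalweights = focalweights.detach()

  # Per-sample cross-entropy, weighted by the factor above and averaged over the batch.
  cel = F.cross_entropy(input, target, reduction='none')
  assert len(target.shape) == 1

  fl = (focalweights * cel).sum() / target.numel()

  return fl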

When I train my model with this implementation, the loss decreases for the first several epochs but then rises again. I don't know why this happens, and I hope someone can help.
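One thing that may be worth checking is the focalweights.detach() call: in the usual formulation, FL = -alpha_t * (1 - p_t)^gamma * log(p_t), the modulating factor is part of the differentiated expression. Below is a minimal sketch for comparison, not a confirmed fix for the rising loss. The name focal_loss_reference and its alpha/gamma keyword arguments are hypothetical; the alpha rule (alpha for class 1, 1 - alpha otherwise) is copied from the code above.

```python
import torch
import torch.nn.functional as F


def focal_loss_reference(logits, target, alpha=0.5, gamma=2.0):
    """Hypothetical comparison version of focalloss: same inputs, but the
    modulating factor is NOT detached, so it contributes to the gradient."""
    # log p_t of the true class for each sample, computed stably from logits.
    log_pt = F.log_softmax(logits, dim=1).gather(1, target.unsqueeze(1)).squeeze(1)
    pt = log_pt.exp()

    # Same alpha rule as in the question: alpha for class 1, (1 - alpha) otherwise.
    alpha_t = torch.where(
        target == 1,
        torch.full_like(pt, alpha),
        torch.full_like(pt, 1.0 - alpha),
    )

    # FL = -alpha_t * (1 - p_t)^gamma * log(p_t), averaged over the batch.
    loss = -alpha_t * (1.0 - pt) ** gamma * log_pt
    return loss.mean()


if __name__ == "__main__":
    torch.manual_seed(0)
    logits = torch.randn(8, 3, requires_grad=True)  # [batch, classes]
    target = torch.randint(0, 3, (8,))              # [batch], long
    loss = focal_loss_reference(logits, target)
    loss.backward()
    print(loss.item(), logits.grad.shape)
```

On the same batch this should return the same loss value as focalloss, since detaching changes only the gradients, not the forward result. Comparing the gradients of the two versions is one way to see whether the detach matters for this training run.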


