用pytorch函数计算二元熵损失

2024-10-16 17:27:25 发布

您现在位置:Python中文网/ 问答频道 /正文

我有一个关于计算二进制交叉熵的问题。我知道在pytorch中的结果是:

import torch
import torch.nn as nn
import torch.nn.functional as F
def lossfunc():
    return F.binary_cross_entropy

criterion = lossFunc()
input = torch.randn((3, 2), requires_grad=True)
target = torch.rand((3, 2), requires_grad=False)
loss = criterion(torch.sigmoid(input),target)

但是如何以这种方式完成lossfunc(),因为我不知道如何将参数传递给函数:

#the function that add sigmoid to input and calculate the binary cross entropy loss
def lossfunc():
   return

criterion = lossFunc()
input = torch.randn((3, 2), requires_grad=True)
target = torch.rand((3, 2), requires_grad=False)
loss = criterion(input,target)

Tags: importtargetinputreturndefasnntorch
1条回答
网友
1楼 · 发布于 2024-10-16 17:27:25

我觉得你把nnapi和函数Fapi搞混了。在函数api中,loss函数F.binary_cross_entropy可以直接用作函数。你知道吗

nnapi中,需要创建loss类的对象,例如criterion = nn.BCELoss()

因此,您可以简单地执行以下操作:

def lossFunc(input, target):
   return F.binary_cross_entropy(torch.sigmoid(input),target)

input = torch.randn((3, 2), requires_grad=True)
target = torch.rand((3, 2), requires_grad=False)
loss = lossFunc(input,target)

另外,PyTorch提供了^{}^{},它们结合了sigmoid和二进制交叉熵。你知道吗

相关问题 更多 >