如何理解“输入[范围(target.shape[0]),目标]?

2024-06-01 07:17:06 发布

您现在位置:Python中文网/ 问答频道 /正文

the official example of PyTorch中,它给出了一个损失函数,如下所示

def nll(input, target):
    return -input[range(target.shape[0]), target].mean()

loss_func = nll

如何理解上述函数中“输入[range(target.shape[0]),target]”的语法? “输入”有一个火炬大小([64,10]),“目标”有一个火炬大小([64])。为什么在这里使用“范围”功能


Tags: ofthe函数targetinputreturnexampledef
1条回答
网友
1楼 · 发布于 2024-06-01 07:17:06

range函数用作创建从0到64的向量/列表/生成器的快捷方式。所以它本质上是[0,1,2,…64]的缩写

要明确这一点,您可以执行以下操作:

def nll(input, target):
    minputlist = list(range(target.shape[0]))
    print(minputlist )
    return -input[minputlist, target].mean()

相关问题 更多 >