在the official example of PyTorch中,它给出了一个损失函数,如下所示
def nll(input, target):
return -input[range(target.shape[0]), target].mean()
loss_func = nll
如何理解上述函数中“输入[range(target.shape[0]),target]”的语法?
“输入”有一个火炬大小([64,10]),“目标”有一个火炬大小([64])。为什么在这里使用“范围”功能
Tags:
range函数用作创建从0到64的向量/列表/生成器的快捷方式。所以它本质上是[0,1,2,…64]的缩写
要明确这一点,您可以执行以下操作:
相关问题 更多 >
编程相关推荐