在PyTorch中沿轴上的所有索引应用函数

2024-10-06 12:25:28 发布

您现在位置:Python中文网/ 问答频道 /正文

我试图在PyTorch中实现Wasserstein Loss函数,为此我引用了Scipy实现。因为在forward()方法中使用PyTorch函数意味着不必编写backward()函数,所以我在代码中已经这样做了(Scipy版本中只涉及了等效的Numpy函数)。以下是我所拥有的:

class WassersteinLoss(nn.Module): 
  def __init__(self):
    super(WassersteinLoss, self).__init__()
  def forward(self,u,v):
    result = torch.empty((len(u)))
    for i in range(len(u)):
      u_values,v_values = u[i],v[i]
      u_sorter,v_sorter = torch.argsort(u_values),torch.argsort(v_values)
      all_values = torch.cat((u_values,v_values))
      all_values,idx = torch.sort(all_values)

      # Compute the differences between pairs of successive values of u and v.
      deltas = torch.sub(all_values[1:],all_values[:-1])

      # Get the respective positions of the values of u and v among the values of
      # both distributions.
      u_cdf_indices = torch.searchsorted(u_values[u_sorter],all_values[:-1],right=True)
      v_cdf_indices = torch.searchsorted(v_values[v_sorter],all_values[:-1],right=True)

      # Calculate the CDFs of u and v
      u_cdf = torch.div(u_cdf_indices,len(u_values))
      v_cdf = torch.div(v_cdf_indices,len(v_values))

      # Compute the value of the integral based on the CDFs.
      result[i] = torch.sum(torch.multiply(torch.abs(u_cdf-v_cdf),deltas))
    return result.mean()

在上述函数中,u和v是形状向量(NxM),其中N是批次中的样本数。因为我的for循环基本上覆盖了所有的样本,所以每个样本的计算都是独立的,因为批处理中的样本相互依赖。我相信我会看到一个显着的加速,如果我能摆脱这个循环。到目前为止,我一直尝试沿dim=1轴执行所有计算,但这不起作用

以下是测试用例的代码:

from scipy.stats import wasserstein_distance
import torch
import torch.nn as nn

print(wasserstein_distance([0, 1, 3], [5, 6, 8]))
#Output is 5

criterion = WassersteinLoss()
print(criterion(torch.tensor([[0, 1, 3]]), torch.tensor([[5, 6, 8]])))
#Output is tensor(5.)

任何关于如何修改forward()函数以消除for循环的输入都将不胜感激


Tags: ofthe函数selflennntorchresult