I am trying to implement a Wasserstein loss function in PyTorch, using the SciPy implementation as a reference. Since using PyTorch functions in the forward() method means I don't have to write a backward() function, I have done so throughout (the SciPy version only involves the equivalent NumPy functions). Here is what I have:
class WassersteinLoss(nn.Module):
    def __init__(self):
        super(WassersteinLoss, self).__init__()

    def forward(self, u, v):
        result = torch.empty(len(u))
        for i in range(len(u)):
            u_values, v_values = u[i], v[i]
            u_sorter, v_sorter = torch.argsort(u_values), torch.argsort(v_values)
            all_values = torch.cat((u_values, v_values))
            all_values, idx = torch.sort(all_values)
            # Compute the differences between pairs of successive values of u and v.
            deltas = torch.sub(all_values[1:], all_values[:-1])
            # Get the respective positions of the values of u and v among the values of
            # both distributions.
            u_cdf_indices = torch.searchsorted(u_values[u_sorter], all_values[:-1], right=True)
            v_cdf_indices = torch.searchsorted(v_values[v_sorter], all_values[:-1], right=True)
            # Calculate the CDFs of u and v.
            u_cdf = torch.div(u_cdf_indices, len(u_values))
            v_cdf = torch.div(v_cdf_indices, len(v_values))
            # Compute the value of the integral based on the CDFs.
            result[i] = torch.sum(torch.multiply(torch.abs(u_cdf - v_cdf), deltas))
        return result.mean()
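One building block worth noting: in recent PyTorch versions, torch.searchsorted accepts batched 2D inputs (an (N, M) sorted sequence and (N, K) values), with row i of the values searched in row i of the sequence. That is the primitive that makes a loop-free version possible. A minimal check of this behavior (illustrative example, not part of my code):

```python
import torch

# When both arguments are 2D, searchsorted operates row by row:
# row i of `values` is searched within row i of `seqs`.
seqs = torch.tensor([[0., 1., 3.],
                     [5., 6., 8.]])   # each row already sorted
values = torch.tensor([[0., 2., 4.],
                       [4., 6., 9.]])
idx = torch.searchsorted(seqs, values, right=True)
print(idx)  # tensor([[1, 2, 3], [0, 2, 3]])
```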
In the function above, u and v are tensors of shape (N, M), where N is the number of samples in the batch. Since my for loop iterates over all the samples, and each sample's computation is independent of the others, I believe I would see a significant speedup if I could get rid of this loop. So far I have tried performing all the computations along the dim=1 axis, but that hasn't worked.
Here is the code for a test case:
from scipy.stats import wasserstein_distance
import torch
import torch.nn as nn
print(wasserstein_distance([0, 1, 3], [5, 6, 8]))
# Output is 5
criterion = WassersteinLoss()
print(criterion(torch.tensor([[0, 1, 3]]), torch.tensor([[5, 6, 8]])))
# Output is tensor(5.)
Any input on how to modify the forward() function to eliminate the for loop would be greatly appreciated.
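One possible direction, sketched under the assumption that batched torch.searchsorted is available (PyTorch ≥ 1.6): sort along dim=1 and let every step of the original loop body operate on whole (N, M) tensors at once. The class name BatchedWassersteinLoss is my own; this is a sketch, not a tested drop-in replacement.

```python
import torch
import torch.nn as nn

class BatchedWassersteinLoss(nn.Module):
    """Sketch: per-row 1D Wasserstein-1 distance for (N, M) tensors,
    with no Python loop over the batch dimension."""
    def forward(self, u, v):
        # Sort each sample's values along dim=1.
        u_sorted, _ = torch.sort(u, dim=1)
        v_sorted, _ = torch.sort(v, dim=1)
        # Pool and sort all values per row; successive differences give
        # the integration intervals.
        all_values, _ = torch.sort(torch.cat((u, v), dim=1), dim=1)
        deltas = all_values[:, 1:] - all_values[:, :-1]
        # Batched searchsorted: row i of all_values is searched in row i
        # of the sorted samples. .contiguous() keeps searchsorted happy.
        queries = all_values[:, :-1].contiguous()
        u_cdf = torch.searchsorted(u_sorted.contiguous(), queries, right=True) / u.shape[1]
        v_cdf = torch.searchsorted(v_sorted.contiguous(), queries, right=True) / v.shape[1]
        # Integrate |CDF_u - CDF_v| per row, then average over the batch.
        return torch.sum(torch.abs(u_cdf - v_cdf) * deltas, dim=1).mean()

criterion = BatchedWassersteinLoss()
loss = criterion(torch.tensor([[0., 1., 3.]]), torch.tensor([[5., 6., 8.]]))
print(loss)  # tensor(5.)
```

On the original test case this reproduces scipy's value of 5; whether it matches the looped version on larger batches would still need to be checked against scipy.stats.wasserstein_distance row by row.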