对不起,我不知道该怎么解决。 我使用两个网络来构造两个嵌入,我有一个二进制目标来指示embeddingA和embeddingB是否匹配(1或-1)。 数据集如下所示:
embA0 embB0 1.0
embA1 embB1 -1.0
embA2 embB2 1.0
...
我希望使用余弦相似性来得到分类结果。 但我在选择损失函数时感到困惑,生成嵌入的两个网络是分开训练的,现在我可以想到以下两个选项:
计划1:
构造第三个网络,使用EmbeddedingA和EmbeddedingB作为nn.cosinesimilarity()的输入来计算最终结果(应为[-1,1]中的概率),然后选择一个两类损失函数
(对不起,我不知道该选择哪个损失函数。)
class cos_Similarity(nn.Module):
def __init__(self):
super(cos_Similarity,self).__init__()
cos=nn.CosineSimilarity(dim=2)
embA=generator_A()
embB=generator_B()
def forward(self,a,b):
output_a=embA(a)
output_b=embB(b)
return cos(output_a,output_b)
loss_func=nn.CrossEntropyLoss()
y=cos_Similarity(a,b)
loss=loss_func(y,target)
acc=np.int64(y>0)
计划2: 这两个嵌入作为输出,然后使用nn.cosinembeddingloss()作为损失函数,当我计算精度时,我使用nn.Cosinesimilarity()输出结果(概率为[-1,1])
output_a=embA(a)
output_b=embB(b)
cos=nn.CosineSimilarity(dim=2)
loss_function = torch.nn.CosineEmbeddingLoss()
loss=loss_function(output_a,output_b,target)
acc=cos(output_a,output_b)
我真的需要帮助。我如何做出选择?为什么?或者我只能通过实验结果为自己做出选择。 多谢各位
######################################
def train_func(train_loss_list):
train_data=load_data('train')
trainloader = DataLoader(train_data, batch_size=BATCH_SIZE)
cos_smi=nn.CosineSimilarity(dim=2)
train_loss = 0
for step,(a,b,target) in enumerate(trainloader):
try:
optimizer.zero_grad()
output_a = model_A(a) #generate embA
output_b = model_B(b) #generate embB
acc=cos_smi(output_a,output_b)
loss = loss_fn(output_a,output_b, target.unsqueeze(dim=1))
train_loss += loss.item()
loss.backward()
optimizer.step()
train_loss_list.append(loss)
if step%10==0:
print('train:',step,'step','loss:',loss,'acc',acc)
except Exception as e:
print('train:',step,'step')
print(repr(e))
return train_loss_list,train_loss/len(trainloader)
您可以使用三重丢失功能进行训练。您的输入是一组嵌入(比如1000行)。假设每一个都以200维编码。还有相似性标签。例如,第1行可能与1000行中的20行相似,而dis与其余980行相似。然后,您可以通过每次进行1+ve和1-ve匹配,对第1行使用三重态丢失函数。你可以对火车上的所有1000行这样做。这样,嵌入现在可以更好地进行微调。这是训练阶段
现在,为了进行推断,您可以找出余弦相似性来确定哪些行彼此接近,哪些不接近(k最近,其中k=1)。我想这就是你的模型的目标
我们在这里假设嵌入是“可转移的”,因为它来自诸如BERT(文本)或imagenet(图像)之类的东西,这些嵌入可以通过在顶部添加一层进行微调
作为对注释线程的响应
目标或管道似乎是:
我能想到的是以下几点。如果我误解了什么,请纠正我。免责声明是,我几乎是根据我的直觉编写的,不知道任何细节,所以如果你尝试运行,它可能会充满错误。让我们仍然尝试获得高层次的理解
型号
培训/评估
我省略了一些细节(例如,超参数值、损失函数和优化器等)。这整个过程和你想要的类似吗
相关问题 更多 >
编程相关推荐