如何在Keras中实现Gumbel Softmax

2024-10-01 11:22:21 发布

您现在位置:Python中文网/ 问答频道 /正文

我试图建立一个seq2seq模型来生成句子并对它们进行分类。为了实现后者,我想将解码器的softmax输出输入到一个预先训练好的分类器。然而,该分类器使用Keras嵌入层,因此将原始softmax传递到分类器中不是一个选项。我想我可以使用gumbel softmax来获得一个热编码,然后使用我在这里找到的onehotmembelding层(https://github.com/keras-team/keras/issues/2505)来解决这个问题。在

ericjang为gumbel softmax提供了这个TensorFlow代码,我想知道如何将其转换为Keras层。特别是,我对hard属性感兴趣,它确保前向通道上的向量是严格分类的,但是在反向通道上,梯度是gumbel softmax的输出。我不知道如何在Keras建造这个。有人能帮忙吗?在

谢谢。在


Tags: https模型github编码分类器选项分类解码器