我们已经尝试过使用tf.nn.embedding_lookup
并且它有效。但它需要密集的输入数据,而现在我们需要tf.nn.embedding_lookup_sparse
用于稀疏的输入。
我已经写了下面的代码,但是有一些错误。
import tensorflow as tf
import numpy as np
example1 = tf.SparseTensor(indices=[[4], [7]], values=[1, 1], shape=[10])
example2 = tf.SparseTensor(indices=[[3], [6], [9]], values=[1, 1, 1], shape=[10])
vocabulary_size = 10
embedding_size = 1
var = np.array([0.0, 1.0, 4.0, 9.0, 16.0, 25.0, 36.0, 49.0, 64.0, 81.0])
#embeddings = tf.Variable(tf.ones([vocabulary_size, embedding_size]))
embeddings = tf.Variable(var)
embed = tf.nn.embedding_lookup_sparse(embeddings, example2, None)
with tf.Session() as sess:
sess.run(tf.initialize_all_variables())
print(sess.run(embed))
错误日志如下所示。
现在我不知道如何正确地修复和使用这个方法。如有任何意见,我们将不胜感激。
在深入到safe_embedding_lookup_sparse
的单元测试之后,我更困惑的是,如果给出稀疏权重,为什么会得到这个结果,特别是为什么我们得到了embedding_weights[0][3]
这样的结果,而3
在上面的代码中没有出现。
tf.nn.embedding_lookup_sparse()
使用Segmentation组合嵌入,这要求SparseTensor的索引从0开始并增加1。这就是为什么你会犯这个错误。与布尔值不同,稀疏张量只需要保存要从嵌入中检索的每一行的索引。这是你修改过的代码:
此外,还可以使用} 组合器之一组合单词嵌入:
tf.SparseTensor()
中的索引,使用允许的^{例如:
相关问题 更多 >
编程相关推荐