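I have a placeholder tensor with shape [batch_size, sentence_length, word_dim] and a list of indices with shape [batch_size, num_indices]. The indices refer to the second axis: they are the positions of words in each sentence. batch_size and sentence_length are only known at runtime.

How can I extract a tensor with shape [batch_size, len(indices), word_dim]?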
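To pin down what the requested output means, here is a minimal plain-Python sketch (no TensorFlow; the function name `gather_second_axis_reference` is hypothetical) of the semantics out[b][i] = x[b][idx[b][i]]. It uses the same constant values as the example further down.

```python
# Hypothetical reference implementation in plain Python, only to pin down
# the intended semantics:  out[b][i] = x[b][idx[b][i]]
#   x:   nested list of shape [batch_size, sentence_length, word_dim]
#   idx: nested list of shape [batch_size, num_indices]

def gather_second_axis_reference(x, idx):
    """Return a nested list of shape [batch_size, num_indices, word_dim]."""
    if len(x) != len(idx):
        raise ValueError("x and idx must have the same batch_size")
    out = []
    for sentence, word_ids in zip(x, idx):
        # For each batch element, pick the word vectors at the requested positions.
        out.append([list(sentence[i]) for i in word_ids])
    return out


x = [
    [[1, 2, 3], [3, 5, 6]],
    [[7, 8, 9], [10, 11, 12]],
    [[13, 14, 15], [16, 17, 18]],
]
idx = [[0, 1], [1, 0], [1, 1]]
print(gather_second_axis_reference(x, idx))
# [[[1, 2, 3], [3, 5, 6]], [[10, 11, 12], [7, 8, 9]], [[16, 17, 18], [16, 17, 18]]]
```

Nothing in this definition ties num_indices to sentence_length; an index may also repeat within a row.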
I was reading about tensorflow.gather, but it seems to gather slices only along the first axis. Am I right?
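A common workaround when only first-axis gathering is available (and the one the code in the edit follows) is to merge the first two axes: word i of sentence b becomes row b * sentence_length + i of a [batch_size * sentence_length, word_dim] matrix, so a single first-axis gather over that matrix can select arbitrary (b, i) pairs. The sketch below only checks that row-major index mapping in plain Python; the helper names `flatten_first_two_axes` and `flat_row` are hypothetical.

```python
# Plain-Python check of the row-major index mapping behind the
# "flatten, then gather along axis 0" workaround (names hypothetical).

def flatten_first_two_axes(x):
    """[batch_size, sentence_length, word_dim] -> [batch_size*sentence_length, word_dim]."""
    return [word for sentence in x for word in sentence]

def flat_row(b, i, sentence_length):
    """Row of the flattened matrix that holds x[b][i]."""
    return b * sentence_length + i

x = [[[1, 1], [2, 2], [3, 3]],
     [[4, 4], [5, 5], [6, 6]]]   # batch_size=2, sentence_length=3, word_dim=2
rows = flatten_first_two_axes(x)
sentence_length = len(x[0])

# Every (b, i) pair maps to the row holding the same word vector.
for b in range(len(x)):
    for i in range(sentence_length):
        assert rows[flat_row(b, i, sentence_length)] == x[b][i]
print(rows[flat_row(1, 2, sentence_length)])  # [6, 6]
```

In TensorFlow terms the flattening is a tf.reshape to [-1, word_dim] and the row lookup is tf.gather on the flattened matrix; the offset b * sentence_length is the part that has to be computed from runtime shapes.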
EDIT: I managed to get it working with constants:
```python
import tensorflow as tf

def tile_repeat(n, repTime):
    '''
    create something like 00..011..122..2 ..... (n-1)..(n-1):
    each number in range(n) appears repTime times consecutively.
    This is for flattening the indices.
    '''
    print n, repTime
    idx = tf.range(n)
    idx = tf.reshape(idx, [-1, 1])    # Convert to a n x 1 matrix.
    idx = tf.tile(idx, [1, int(repTime)])  # Copy the column repTime times: row k holds k repeated repTime times
    y = tf.reshape(idx, [-1])
    return y

def gather_along_second_axis(x, idx):
    '''
    x has shape: [batch_size, sentence_length, word_dim]
    idx has shape: [batch_size, num_indices]
    Basically, in each batch, get words from sentence having index specified in idx
    However, since tensorflow does not fully support indexing,
    gather only works on the first axis. We have to reshape the input data, gather, then reshape again
    '''
    reshapedIdx = tf.reshape(idx, [-1]) # [batch_size*num_indices]
    idx_flattened = tile_repeat(tf.shape(x)[0], tf.shape(x)[1]) * tf.shape(x)[1] + reshapedIdx
    y = tf.gather(tf.reshape(x, [-1,int(tf.shape(x)[2])]), # flatten input
                  idx_flattened)
    y = tf.reshape(y, tf.shape(x))
    return y

x = tf.constant([
    [[1,2,3],[3,5,6]],
    [[7,8,9],[10,11,12]],
    [[13,14,15],[16,17,18]]
])
idx=tf.constant([[0,1],[1,0],[1,1]])

y = gather_along_second_axis(x, idx)
with tf.Session(''):
    print y.eval()
    print tf.Tensor.get_shape(y)
```
The output is the gathered array, with shape (3, 2, 3).
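Tracing the constant example by hand shows what tile_repeat and idx_flattened evaluate to. The plain-Python mirror below (the name `tile_repeat_py` is hypothetical) reproduces those two steps exactly as the code above writes them, including the repeat count tf.shape(x)[1], i.e. sentence_length.

```python
# Plain-Python mirror of tile_repeat and idx_flattened for the constant
# example (batch_size=3, sentence_length=2, num_indices=2).

def tile_repeat_py(n, rep_time):
    """[0]*rep_time + [1]*rep_time + ... + [n-1]*rep_time, like tile_repeat."""
    return [b for b in range(n) for _ in range(rep_time)]

batch_size, sentence_length = 3, 2
idx = [[0, 1], [1, 0], [1, 1]]

reshaped_idx = [i for row in idx for i in row]  # [0, 1, 1, 0, 1, 1]
# Same repeat count as the original code: sentence_length (tf.shape(x)[1]).
offsets = [b * sentence_length for b in tile_repeat_py(batch_size, sentence_length)]
idx_flattened = [o + i for o, i in zip(offsets, reshaped_idx)]
print(offsets)        # [0, 0, 2, 2, 4, 4]
print(idx_flattened)  # [0, 1, 3, 2, 5, 5]
```

Rows 0, 1, 3, 2, 5, 5 of the flattened x are [1,2,3], [3,5,6], [10,11,12], [7,8,9], [16,17,18], [16,17,18]. The offsets line up with reshaped_idx only because num_indices also happens to be 2 here; Python's zip would silently truncate on a length mismatch, whereas the TensorFlow addition would fail.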
However, when the input is a placeholder, it does not work and returns an error:

```
    idx = tf.tile(idx, [1, int(repTime)])
TypeError: int() argument must be a string or a number, not 'Tensor'
```
Python 2.7, TensorFlow 0.12

Thanks in advance.
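The TypeError comes from int(repTime): with a placeholder whose first two dimensions are unknown, tf.shape(x)[1] is a symbolic Tensor whose value exists only when the graph runs, so it cannot become a Python int while the graph is being built. One way around it (a sketch, not verified against TensorFlow 0.12) is to keep every runtime size as a tensor, passing it to ops directly and building shape lists with tf.stack (tf.pack in older releases). The tiling can also be dropped: the per-row offset b * sentence_length can be broadcast onto idx, for example offsets = tf.range(batch_size) * sentence_length followed by idx + tf.expand_dims(offsets, 1). The plain-Python function below (name hypothetical) only mirrors that broadcasting arithmetic.

```python
def flat_indices_by_broadcast(idx, sentence_length):
    """Compute idx[b][i] + b * sentence_length for every (b, i).

    Mirrors this broadcasting formulation (hypothetical TF sketch, with
    batch_size = tf.shape(x)[0] and sentence_length = tf.shape(x)[1]):
        offsets = tf.range(batch_size) * sentence_length   # [batch_size]
        flat    = idx + tf.expand_dims(offsets, 1)         # [batch_size, num_indices]
    No size needs to be converted to a Python int.
    """
    offsets = [b * sentence_length for b in range(len(idx))]  # one offset per batch row
    # "Broadcast" each row's offset across that row's num_indices entries.
    return [[i + offsets[b] for i in row] for b, row in enumerate(idx)]

print(flat_indices_by_broadcast([[0, 1], [1, 0], [1, 1]], sentence_length=2))
# [[0, 1], [3, 2], [5, 5]]
```

Because the offsets have one entry per batch row and broadcast over the num_indices axis, this form works for any num_indices without choosing a repeat count.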
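@和伍's answer was very helpful. The code works with the example, but when x and idx have sentence_length != len(indices), it raises an error. I modified the code slightly, and now it works when sentence_length >= len(indices). I tested the new code with x and idx on Python 3.x.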
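In the constant-input code, two sizes are taken from x where they should come from idx: the repeat count passed to tile_repeat (tf.shape(x)[1], which must be num_indices) and the final reshape target (tf.shape(x), which must be [batch_size, num_indices, word_dim]). The plain-Python sketch below (hypothetical names, not the modified code mentioned above) applies both corrections to the flatten/gather/reshape pipeline.

```python
# Plain-Python sketch (hypothetical names) of the flatten/gather/reshape
# pipeline with the two sizes that must come from idx rather than x.

def tile_repeat_py(n, rep_time):
    """[0]*rep_time + [1]*rep_time + ... + [n-1]*rep_time."""
    return [b for b in range(n) for _ in range(rep_time)]

def gather_along_second_axis_py(x, idx):
    batch_size = len(x)
    sentence_length = len(x[0])
    num_indices = len(idx[0])

    reshaped_idx = [i for row in idx for i in row]  # batch_size * num_indices entries
    # Fix 1: repeat each batch id num_indices times, not sentence_length times,
    # so the offsets line up one-to-one with reshaped_idx.
    offsets = [b * sentence_length for b in tile_repeat_py(batch_size, num_indices)]
    idx_flattened = [o + i for o, i in zip(offsets, reshaped_idx)]

    rows = [word for sentence in x for word in sentence]  # [batch*sentence_length, word_dim]
    gathered = [rows[k] for k in idx_flattened]           # gather along the first axis

    # Fix 2: reshape to [batch_size, num_indices, word_dim], not to the shape of x.
    return [gathered[b * num_indices:(b + 1) * num_indices] for b in range(batch_size)]

x = [[[1, 1], [2, 2], [3, 3]],
     [[4, 4], [5, 5], [6, 6]]]  # sentence_length = 3
print(gather_along_second_axis_py(x, [[2], [0]]))
# [[[3, 3]], [[4, 4]]]
print(gather_along_second_axis_py(x, [[0, 2, 2, 1], [1, 1, 0, 2]]))
# [[[1, 1], [3, 3], [3, 3], [2, 2]], [[5, 5], [5, 5], [4, 4], [6, 6]]]
```

In TensorFlow terms, the first fix would mean taking the repeat count from tf.shape(idx)[1], and the second building the final shape from batch_size, num_indices and word_dim (for example with tf.stack) instead of passing tf.shape(x). In this plain-Python model, at least, nothing then requires num_indices <= sentence_length.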
Thanks to @AllenLavoie's comment, I was finally able to come up with a solution: