TensorFlow: slicing a 3D tensor with a list of indices along the second axis

Published 2024-10-01 13:37:51


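I have a placeholder tensor of shape [batch_size, sentence_length, word_dim] and a list of indices with shape=[batch_size, num_indices]. The indices refer to the second axis and are the positions of words in the sentence. batch_size and sentence_length are only known at runtime.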

How can I extract a tensor of shape [batch_size, len(indices), word_dim]?

I was reading about tensorflow.gather, but it seems to gather slices only along the first axis. Am I right?
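To pin down the intended result, here is a minimal pure-Python sketch that uses nested lists instead of tensors. The function name gather_along_second_axis_reference is made up for illustration, and the sample data is the same as in the TensorFlow example further down; it only shows the semantics out[b][j] = x[b][idx[b][j]], not a TensorFlow implementation.

```python
def gather_along_second_axis_reference(x, idx):
    """Reference semantics on nested lists: out[b][j] = x[b][idx[b][j]]."""
    return [[x[b][i] for i in idx[b]] for b in range(len(x))]


# x has shape [batch_size=3, sentence_length=2, word_dim=3]
x = [
    [[1, 2, 3], [3, 5, 6]],
    [[7, 8, 9], [10, 11, 12]],
    [[13, 14, 15], [16, 17, 18]],
]
# idx has shape [batch_size=3, num_indices=2]
idx = [[0, 1], [1, 0], [1, 1]]

result = gather_along_second_axis_reference(x, idx)
print(result)
```

The result has shape [batch_size, num_indices, word_dim], which is what the TensorFlow version should produce as well.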

Edit: I managed to get it working with constants:

import tensorflow as tf

def tile_repeat(n, repTime):
    '''
    Create a 1-D tensor like 0..0 1..1 2..2 ... (n-1)..(n-1),
    where each number appears repTime times consecutively.
    This is used to flatten the indices.
    '''
    print n, repTime
    idx = tf.range(n)
    idx = tf.reshape(idx, [-1, 1])    # Convert to an n x 1 matrix.
    idx = tf.tile(idx, [1, int(repTime)])  # Repeat each row's number across repTime columns.
    y = tf.reshape(idx, [-1])
    return y

def gather_along_second_axis(x, idx):
    '''
    x has shape: [batch_size, sentence_length, word_dim]
    idx has shape: [batch_size, num_indices]
    For each batch element, pick the words of the sentence at the positions given in idx.
    Because TensorFlow does not fully support indexing and gather only works
    along the first axis, we flatten the input, gather, then reshape again.
    '''
    reshapedIdx = tf.reshape(idx, [-1]) # [batch_size*num_indices]
    idx_flattened = tile_repeat(tf.shape(x)[0], tf.shape(x)[1]) * tf.shape(x)[1] + reshapedIdx
    y = tf.gather(tf.reshape(x, [-1,int(tf.shape(x)[2])]),  # flatten input
                idx_flattened)
    y = tf.reshape(y, tf.shape(x))
    return y

x = tf.constant([
            [[1,2,3],[3,5,6]],
            [[7,8,9],[10,11,12]],
            [[13,14,15],[16,17,18]]
    ])
idx=tf.constant([[0,1],[1,0],[1,1]])

y = gather_along_second_axis(x, idx)
with tf.Session(''):
    print y.eval()
    print tf.Tensor.get_shape(y)

The output is:

[[[ 1  2  3]
  [ 3  5  6]]

 [[10 11 12]
  [ 7  8  9]]

 [[16 17 18]
  [16 17 18]]]

Shape: (3, 2, 3)

However, when the input is a placeholder, it does not work and returns this error:

idx = tf.tile(idx, [1, int(repTime)])  
TypeError: int() argument must be a string or a number, not 'Tensor'

Python 2.7, TensorFlow 0.12

Thanks in advance.


2 Answers

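The answer from @和伍 was very helpful. The code works with the example x and idx, where sentence_length == len(indices), but it raises an error when sentence_length != len(indices).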

I modified the code slightly, and it now works when sentence_length >= len(indices).

I tested it on Python 3.x with the new x and idx:

import tensorflow as tf


def tile_repeat(n, repTime):
    '''
    Create a 1-D tensor like 0..0 1..1 2..2 ... (n-1)..(n-1),
    where each number appears repTime times consecutively.
    This is used to flatten the indices.
    '''
    idx = tf.range(n)
    idx = tf.reshape(idx, [-1, 1])    # Convert to an n x 1 matrix.
    idx = tf.tile(idx, [1, repTime])  # Repeat each row's number across repTime columns.
    y = tf.reshape(idx, [-1])
    return y


def gather_along_second_axis(x, idx):
    '''
    x has shape: [batch_size, sentence_length, word_dim]
    idx has shape: [batch_size, num_indices]
    For each batch element, pick the words of the sentence at the positions given in idx.
    Because TensorFlow does not fully support indexing and gather only works
    along the first axis, we flatten the input, gather, then reshape again.
    '''
    reshapedIdx = tf.reshape(idx, [-1]) # [batch_size*num_indices]
    idx_flattened = tile_repeat(tf.shape(x)[0], tf.shape(idx)[1]) * tf.shape(x)[1] + reshapedIdx
    y = tf.gather(tf.reshape(x, [-1,tf.shape(x)[2]]),  # flatten input
                idx_flattened)
    y = tf.reshape(y, [tf.shape(x)[0],tf.shape(idx)[1],tf.shape(x)[2]])
    return y

x = tf.constant([
            [[1,2,3],[1,2,3],[3,5,6],[3,5,6]],
            [[7,8,9],[7,8,9],[10,11,12],[10,11,12]],
            [[13,14,15],[13,14,15],[16,17,18],[16,17,18]]
    ])
idx=tf.constant([[0,1],[1,2],[0,3]])

y = gather_along_second_axis(x, idx)
with tf.Session(''):
    print(y.eval())
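The index arithmetic behind this version can be checked without TensorFlow. The following is only a pure-Python sketch on nested lists, with the hypothetical helper names flat_indices and gather_flat: each pair (b, i) maps to row b * sentence_length + i of the input flattened to [batch_size * sentence_length, word_dim], and the batch offset is repeated num_indices times, which is why tile_repeat is called with tf.shape(idx)[1].

```python
def flat_indices(batch_size, sentence_length, idx):
    """Map every (b, i) pair to row b * sentence_length + i of the flattened input."""
    return [b * sentence_length + i for b in range(batch_size) for i in idx[b]]


def gather_flat(x, idx):
    """Flatten, gather, reshape: the same three steps as the TensorFlow code."""
    batch_size, sentence_length = len(x), len(x[0])
    num_indices = len(idx[0])
    # Flatten x to [batch_size * sentence_length, word_dim].
    rows = [word for sentence in x for word in sentence]
    # Gather along the (now first) axis.
    picked = [rows[k] for k in flat_indices(batch_size, sentence_length, idx)]
    # Reshape back to [batch_size, num_indices, word_dim].
    return [picked[b * num_indices:(b + 1) * num_indices] for b in range(batch_size)]


x = [
    [[1, 2, 3], [1, 2, 3], [3, 5, 6], [3, 5, 6]],
    [[7, 8, 9], [7, 8, 9], [10, 11, 12], [10, 11, 12]],
    [[13, 14, 15], [13, 14, 15], [16, 17, 18], [16, 17, 18]],
]
idx = [[0, 1], [1, 2], [0, 3]]

flat = flat_indices(len(x), len(x[0]), idx)
result = gather_flat(x, idx)
print(flat)
print(result)
```

Here sentence_length is 4 and num_indices is 2, so the output has shape [3, 2, 3] rather than the shape of x.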

Thanks to @AllenLavoie's comment, I was finally able to come up with a solution:

import tensorflow as tf

def tile_repeat(n, repTime):
    '''
    Create a 1-D tensor like 0..0 1..1 2..2 ... (n-1)..(n-1),
    where each number appears repTime times consecutively.
    This is used to flatten the indices.
    '''
    print n, repTime
    idx = tf.range(n)
    idx = tf.reshape(idx, [-1, 1])    # Convert to an n x 1 matrix.
    idx = tf.tile(idx, [1, repTime])  # Repeat each row's number across repTime columns.
    y = tf.reshape(idx, [-1])
    return y

def gather_along_second_axis(x, idx):
    '''
    x has shape: [batch_size, sentence_length, word_dim]
    idx has shape: [batch_size, num_indices]
    For each batch element, pick the words of the sentence at the positions given in idx.
    Because TensorFlow does not fully support indexing and gather only works
    along the first axis, we flatten the input, gather, then reshape again.
    '''
    reshapedIdx = tf.reshape(idx, [-1]) # [batch_size*num_indices]
    idx_flattened = tile_repeat(tf.shape(x)[0], tf.shape(x)[1]) * tf.shape(x)[1] + reshapedIdx
    y = tf.gather(tf.reshape(x, [-1,tf.shape(x)[2]]),  # flatten input
                idx_flattened)
    y = tf.reshape(y, tf.shape(x))
    return y

x = tf.constant([
            [[1,2,3],[3,5,6]],
            [[7,8,9],[10,11,12]],
            [[13,14,15],[16,17,18]]
    ])
idx=tf.constant([[0,1],[1,0],[1,1]])

y = gather_along_second_axis(x, idx)
with tf.Session(''):
    print y.eval()
    print tf.Tensor.get_shape(y)
