可变长度序列的连续固定长度批

2024-07-08 10:55:27 发布

您现在位置:Python中文网/ 问答频道 /正文

我试图最大限度地提高GPU占用率在培训期间。我有可变长度的序列,我想密集包装成固定长度的批次。基本上,我希望短序列后面跟着另一个序列,我希望长序列被拆分,以便它们在下一批中继续。示例:

// Say batch size is 2 and desired sequence length is 4
s1 = [a, b, c, d, e, f]
s2 = [x, y, z]
s3 = [l, m, n, o]

// Resulting batches:
b1 = [[a, b, c, d]
      [x, y, z, l]]
b2 = [[e, f, _, _]
      [m, n, o, _]]

在Tensorflow中有没有一种简单的方法可以做到这一点?我的序列来自tf.TextLineReader

file_queue = tf.train.string_input_producer('./example_text')
reader = tf.TextLineReader()
key, sentence = reader.read(file_queue)
// convert string to int32 vector
sequence_tensor = to_sequence(sentence)

// what I wish I had:
batch = tf.fixed_length_batch_from_variable_length_sequences(
    sequence_tensor, batch_size, fixed_length)

提前谢谢你的建议。你知道吗


Tags: tosizestringqueueistfbatch序列
1条回答
网友
1楼 · 发布于 2024-07-08 10:55:27

好吧,我有一个工作的例子,这几乎是我所希望的。下面的代码以我希望的方式生成批处理,但它需要使用占位符在TF会话中传入和传出数据。我希望能够完全从TF图中构建这些批处理。你知道吗

希望我是愚蠢的,有一些明显的解决办法,有人可以指出。也请原谅这个案子。你知道吗

import tensorflow as tf

def buildBatch(seqLength, batchSize):

    def lineToSequence(line):
        line = tf.expand_dims(line, axis=0)
        line = tf.sparse_tensor_to_dense(tf.string_split(line), '_')
        line = tf.concat([line, [['<GO>']]], 1)
        return line

    data = tf.contrib.data.TextLineDataset(['./exampleFile.txt'])
    data = data.map(lambda line: lineToSequence(line))
    iterator = data.make_initializable_iterator()

    # Grab lines from the file until the the sequence length is met and shave off any extra
    def getFixedLengthSequence(start):
        c = lambda s: tf.shape(s)[1] < seqLength # while sequence is is too short
        b = lambda s: tf.concat([s, iterator.get_next()], 1) # concatenate the next line
        sentences = tf.while_loop(c, b, [start], back_prop=False, parallel_iterations=1,
            shape_invariants=[tf.TensorShape([1, None])])

        clippedToLength = tf.expand_dims(sentences[0, :seqLength], axis=0)
        leftover = tf.expand_dims(sentences[0, seqLength:], axis=0)
        return clippedToLength, leftover

    # Placeholders pass in the start of each sequence (which are saved from the last batch)
    startOfThisBatch = [tf.placeholder(tf.string, shape=[1,None]) for i in range(batchSize)]
    # Capture what is leftover from each sequence so it can be passed in to start the next batch
    startOfNextBatch = [tf.TensorArray(tf.string, size=1) for i in range(batchSize)]

    # Build the batch
    thisBatch = []
    for i, seqStart in enumerate(startOfThisBatch):
        seq, leftover = getFixedLengthSequence(seqStart)
        thisBatch.append(seq)
        startOfNextBatch[i] = startOfNextBatch[i].write(0, leftover)
    thisBatch = tf.concat(thisBatch, axis=0)
    startOfNextBatch = [b.read(0) for b in startOfNextBatch]

    return thisBatch, startOfThisBatch, startOfNextBatch, iterator.initializer


def printBatch():
    sequenceLength = 10
    batchSize = 3

    batch, startOfThisBatch, startOfNextBatch, iteratorInit = buildBatch(sequenceLength, batchSize)
    # The very first batch starts with <GO> tokens
    batchStarts = [[['<GO>']]]*batchSize

    sv = tf.train.Supervisor()
    with sv.managed_session() as sess:
        sess.run(iteratorInit)
        for b in range(4):
            # Populate feed dict with the beginning of each sequence in the batch
            feed = {}
            for i in range(batchSize):
                feed[startOfThisBatch[i]] = batchStarts[i]

            # Call TF to get this batch and the starting sequences of the next batch
            out, batchStarts = sess.run([batch, startOfNextBatch], feed_dict=feed)

            print 'Batch', b, ':'
            for seq in out:
                print " ".join(seq)
            print

printBatch()

结果:

Batch 0 :  
<GO> A spokesman said the company has been affected by  
<GO> Having a little flexibility on that issue would go  
<GO> Long before the advent of e-commerce , Wal-Mart 's 

Batch 1 :  
the credit crunch in the United States . <GO> Abu  
a long way to putting together a final package .  
founder Sam Walton set out his vision for a successful  

Batch 2 :  
Dhabi is going ahead to build solar city and no  
<GO> Her back was torn open , her liver was  
retail operation : " We let folks know we 're  

Batch 3 :  
pollution city . <GO> Now it has 175 staging centers  
ruptured , one of her lungs had collapsed and the  
interested in them and that they 're vital to us   

请注意,每个句子在下一批中继续。使用的示例文本文件来自1-billion word benchmark dataset,每行包含一个句子。你知道吗

相关问题 更多 >

    热门问题