有条件地中断tensorflow.while\u循环

2024-09-30 18:35:25 发布

您现在位置:Python中文网/ 问答频道 /正文

我正试图修改来自nshepperd的fork(https://github.com/nshepperd/gpt-2)的GPT-2示例生成代码

具体来说,下面的代码是sample.py文件的一部分:

with tf.name_scope('sample_sequence'):
    # Don't feed the last context token -- leave that to the loop below
    # TODO: Would be slightly faster if we called step on the entire context,
    # rather than leaving the last token transformer calculation to the while loop.
    context_output = step(hparams, context[:, :-1])

    def body(past, prev, output):
        next_outputs = step(hparams, prev[:, tf.newaxis], past=past)
        logits = next_outputs['logits'][:, -1, :]  / tf.to_float(temperature)
        if penalize > 0.0:
            logits = penalize_used(logits, output, penalize=penalize)
        if top_p > 0.0:
            logits = top_p_logits(logits, p=top_p, epsilon=epsilon)
        else:
            logits = top_k_logits(logits, k=top_k, epsilon=epsilon)
        samples = tf.multinomial(logits, num_samples=1, output_dtype=tf.int32)
        return [
            tf.concat([past, next_outputs['presents']], axis=-2),
            tf.squeeze(samples, axis=[1]),
            tf.concat([output, samples], axis=1),
        ]

    def cond(*args):
        return True

    _, _, tokens = tf.while_loop(
        cond=cond, body=body,
        maximum_iterations=length,
        loop_vars=[
            context_output['presents'],
            context[:, -1],
            context,
        ],
        shape_invariants=[
            tf.TensorShape(model.past_shape(hparams=hparams, batch_size=batch_size)),
            tf.TensorShape([batch_size]),
            tf.TensorShape([batch_size, None]),
        ],
        back_prop=False,
    )

    return tokens

基本上,我要做的是,一旦它生成了一个具有特定值的令牌,就让它停止,例如!EndText!。然而,由于我对tensorflow非常陌生,我对如何做到这一点非常不确定,特别是关于这一点的官方文档很少。如果我理解正确,我需要修改cond函数(我理解为在body函数的所有输出上循环),以便在enc.decode(output)=“!EndText!”时它会中断,但是我或多或少完全不知道从哪里开始


Tags: theloopoutputtftopbatchcontextbody
1条回答
网友
1楼 · 发布于 2024-09-30 18:35:25

我找到了答案(通过比我更熟悉TF的人的大力帮助)

唯一需要修改的是删除cond下的“returntrue”,而是插入以下return语句

return tf.math.logical_and(tf.not_equal(output[0][-1], tf.cast(X, tf.int32)), tf.not_equal(output[0][-2], tf.cast(Y, tf.int32)))

其中,X和Y表示(在本例中)两个应表示停止的令牌,如果您想要更少或更多的令牌(即,如果表示停止的令牌更少/更多不同),只需在上面的逻辑和语句中添加/删除术语即可

相关问题 更多 >