使用TensorFlow中的随机方法执行一个选项的问题?

2024-10-03 17:14:54 发布

您现在位置:Python中文网/ 问答频道 /正文

如何以可训练的方式从多个备选方案中随机选择一个执行流?例如:

import random
from tensorflow import keras

class RandomModel(keras.Model):
    def __init__(self, model_set):
        super(RandomModel, self).__init__()
        self.models = model_set


    def call(self, inputs):
        """Calls one of its models at random"""
        return random.sample(self.models, 1)[0](inputs)


def new_model():
    return keras.Sequential([
        keras.layers.Dense(10, activation='softmax')
    ])

model = RandomModel({new_model(), new_model()})
model.build(input_shape=(32, 784))
model.summary()

虽然这段代码runs,但它似乎不允许渐变反向传播。这是它的输出:

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
Total params: 0
Trainable params: 0
Non-trainable params: 0
_________________________________________________________________

Tags: importselfnewmodelreturninitmodelsdef
1条回答
网友
1楼 · 发布于 2024-10-03 17:14:54

我找到a way来做这个。但是,由于嵌套的tf.cond操作,执行速度很慢:

def random_network_applied_to_inputs(inputs, networks):
    """
    Returns a tf.cond tree that does binary search
    before applying a network to the inputs.
    """
    length = len(networks)

    index = tf.random.uniform(
        shape=[],
        minval=0,
        maxval=length,
        dtype=tf.dtypes.int32
    )

    def branch(lower_bound, upper_bound):
        if lower_bound + 1 == upper_bound:
            return networks[lower_bound](inputs)
        else:
            center = (lower_bound + upper_bound) // 2
            return tf.cond(
                pred=index < center,
                true_fn=lambda: branch(lower_bound, center),
                false_fn=lambda: branch(center, upper_bound)
            )

    return branch(0, length)

相关问题 更多 >