擅长:python、mysql、java
<p>您可以使用一组嵌套的<a href="https://www.tensorflow.org/api_docs/python/tf/cond" rel="nofollow noreferrer">^{<cd1>}</a>。如果满足条件,它将调用<code>true_fn</code>或<code>false_fn</code>。因为您有两个以上的函数,所以可以将它们嵌套为任意多个函数。例如,我在做一个函数,根据随机变量的值,将输入乘以2、3、4或5</p>
<pre><code>import tensorflow as tf
x = 10
@tf.function
def mult_2():
tf.print(f'i was 2, returning {x} multiplied by 2')
return tf.multiply(x, 2)
@tf.function
def mult_3():
tf.print(f'i was 3, returning {x} multiplied by 3')
return tf.multiply(x, 3)
@tf.function
def mult_4():
tf.print(f'i was 4, returning {x} multiplied by 4')
return tf.multiply(x, 4)
@tf.function
def mult_5():
tf.print(f'i was 5, returning {x} multiplied by 5')
return tf.multiply(x, 5)
i = tf.random.uniform((), 1, 5, dtype=tf.int32)
tf.cond(i == 2, mult_2,
lambda: tf.cond(i == 3, mult_3,
lambda: tf.cond(i == 4, mult_4, mult_5)))
</code></pre>
<pre><code>I was 3, returning 10 multiplied by 3
</code></pre>
<pre><code><tf.Tensor: shape=(), dtype=int32, numpy=30>
</code></pre>
<p>请注意,如果不满足任何条件,<code>mult_5</code>将执行</p>