我的问题是:在预处理期间,我希望使用tf.data.Dataset
和tf.function
API将从一组函数中随机选择的函数应用于数据集示例
具体来说,我的数据是3D体积,我希望应用一组24个预定义旋转函数的旋转。我想在tf.function
中编写这段代码,这样就限制了numpy
和列表索引等包的使用
例如,我想这样做:
import tensorflow as tf
@tf.function
def func1(tensor):
# Apply some rotation here
...
@tf.function
def func2(tensor):
...
...
@tf.function
def func24(tensor):
...
@tf.function
def apply(tensor):
list_of_funcs = [func1, func2, ..., func24]
# Randomly sample from 0-23
a = tf.random.uniform([1], minval=0, maxval=23, dtype=tf.int32)
return list_of_funcs[a](tensor)
但是,我不能将list_of_funcs
索引为TypeError: list indices must be integers or slices, not Tensor
。此外,我无法将这些函数(AFAIK)收集到tf.Tensor
中并使用tf.gather
因此,我的问题是:如何在tf.function
中合理、整洁地从这些函数中采样
您可以使用一组嵌套的^{} 。如果满足条件,它将调用
true_fn
或false_fn
。因为您有两个以上的函数,所以可以将它们嵌套为任意多个函数。例如,我在做一个函数,根据随机变量的值,将输入乘以2、3、4或5请注意,如果不满足任何条件,
mult_5
将执行也许可以尝试使用tf.py_function,它:
例如(在Google Colab上测试):
相关问题 更多 >
编程相关推荐