从for循环构建pred\u fn\u对时,Tensorflow case函数不起作用

2024-09-27 07:22:23 发布

您现在位置:Python中文网/ 问答频道 /正文

我有个奇怪的问题

我只是尝试根据一些区间对张量值应用一个映射。 例如,让我们将区间[0,10]中的张量值映射到1,区间[10,20]中的张量值映射到2,区间[20,30]中的张量值映射到3,其他张量值映射到0

如果我通过一个列表来构建这个映射,并“手动”遍历它,它就可以正常工作

def build_mult_mapping(l):
    return lambda x: tf.case({
                              tf.math.logical_and(tf.greater(x, tf.constant(l[0][0][0])), tf.less(x, tf.constant(l[0][0][1]))): lambda: tf.constant(l[0][1]),
                              tf.math.logical_and(tf.greater(x, tf.constant(l[1][0][0])), tf.less(x, tf.constant(l[1][0][1]))): lambda: tf.constant(l[1][1]),
                              tf.math.logical_and(tf.greater(x, tf.constant(l[2][0][0])), tf.less(x, tf.constant(l[2][0][1]))): lambda: tf.constant(l[2][1])
                              },
                              default=lambda: tf.constant(l[3]), exclusive=True)

它很难看,但是如果我应用这个build_mult_mapping版本,它就可以正常工作

def f(x):
    mult_mapping = build_mult_mapping([((0, 10), 1), ((10, 20), 2), ((20, 30), 3), 0])
    return tf.map_fn(mult_mapping, x)

x = tf.constant([5, 15, 25, 35])
sess = tf.Session()
sess.run(f(x))

如果我运行这个块,我得到:array([1, 2, 3, 0], dtype=int32),这是预期的输出

现在如果我尝试用for循环迭代build_mult_mapping参数y,我会得到一个奇怪的行为。这是新的实现

def build_mult_mapping(l):
    return lambda x: tf.case({tf.math.logical_and(tf.greater(x, tf.constant(i[0][0])), tf.less(x, tf.constant(i[0][1]))): lambda: tf.constant(i[1])
                              for i in l[:-1]
                             },
                             default=lambda: tf.constant(l[-1]), exclusive=True)

如果我像以前一样运行同一个块,我会得到以下输出array([3, 3, 3, 0], dtype=int32),而我希望得到和以前一样的输出

有什么想法吗


Tags: andlambdabuildreturntfdefmathmapping

热门问题