使用tf.function的条件分支

2024-09-30 06:18:55 发布

您现在位置:Python中文网/ 问答频道 /正文

当使用带有条件分支的tf.function时,我在使用梯度带计算梯度时遇到问题

在梯度磁带范围内,我试图计算zw.r.tself.LMn的梯度。当我不使用@tf.function注释函数时,这工作得非常好。错误源于子类tf.keras.layers.Layer调用函数:

def call(self, x, training=None, labels=None, cur_switch=None, basis_filters=None):
    if basis_filters is None:
        basis_filters = self.initial_basis_filters

    gap = tf.reduce_mean(x, axis=[1, 2])
    z = tf.einsum('bc,cd->bd', gap, self.LMn)

    # ... code for out

    return out, z

实际误差如下所示:

....
C:\...\ops\cond_v2.py:387 <lambda>
    lambda: _grad_fn(func_graph, grads), [], {},
C:\...\ops\cond_v2.py:363 _grad_fn
    assert len(func_graph.outputs) == len(grads)

AssertionError: 

更具体地说

func_graphs.outputs = [<tf.Tensor 'vgg16/block1a/block1a_conv/cond_2/Identity:0' shape=(128, 32, 32, None) dtype=float32>, <tf.Tensor 'vgg16/block1a/block1a_conv/cond_2/OptionalFromValue:0' shape=() dtype=variant>, <tf.Tensor 'vgg16/block1a/block1a_conv/cond_2/OptionalFromValue_1:0' shape=() dtype=variant>]

grads = (<tf.Tensor 'gradient_tape/vgg16/block1a/block1a_conv/strided_slice_3/StridedSliceGrad_1:0' shape=(128, 32, 32, None) dtype=float32>,)

我可以假设这些输出中的每一个都对应于一些条件分支输出集的情况,这些条件分支输出随后被馈送到tf.einsum函数中。我已经阅读了梯度带文档中的所有边缘情况和注意事项,因为这似乎是问题所在。请注意,我只使用超参数(作为pythonic变量传递到tf.function,如basis_filters)执行条件计算。还有一些条件分支使用这些超参数的(tensorflow ops)函数,这是允许的吗?或者我需要在tf.function之外计算这些值,并将它们作为pythonic变量传入吗

我知道这个问题并不完全清楚,如果需要,我可以提供任何额外的信息。对于如何处理此类问题,提供一些指导将非常有帮助

谢谢


Tags: nonetf分支basisfunction条件filters梯度

热门问题