带有非传感器参数的Tensorflow自定义渐变

2024-09-30 14:19:12 发布

您现在位置:Python中文网/ 问答频道 /正文

玩tensorflow的@tf.custom_梯度decorator我总是遇到麻烦,当函数的一个参数(可以争论这是否是一个op)是不可变的,我不能(或不想)定义它的梯度贡献。 因为似乎没有一种方法可以告诉tensorflow关于它的输入属性的任何信息,所以我选择了下面的解决方案,我认为这可能对其他人有用。下面是定义泊松基损失的示例。只有“fwd”是张量,所有其他参数不是:

def Loss_Poisson(fwd,meas,Bg=0.05, checkPos=False):
    meanmeas=np.mean(meas)
    if checkPos:
        fwd=((tf.sign(fwd) + 1)/2)*fwd
    @tf.custom_gradient
    def BarePoisson(myfwd):
        def grad(dy):
            mygrad=dy*(1.0 - meas/(myfwd+Bg))/meas.size  # the size accounts for the mean operation (rather than sum)
            return mygrad
        totalError = tf.reduce_mean((myfwd+Bg-meas) - meas * tf.log((myfwd+Bg)/(meas+Bg)))  
        return totalError,grad    
    return BarePoisson(fwd)/meanmeas

诀窍是使用一个内部函数,它只有一个(张量)参数,并偷偷使用外部函数的其他不可变参数。这条路@tf.custom_梯度装潢师对这些视而不见。在一次测试中,这很有效。 也许一个更清晰的解决方案会断言,这些(mean,Bg和checkPos)实际上是张量,因为这可能会导致错误的梯度。在

有没有更好的办法来对付这种黑客行为?在


Tags: 函数参数return定义tftensorflowdefcustom