擅长:python、mysql、java
<p>您似乎对<code>tf.where</code>用法感到困惑。从<a href="https://www.tensorflow.org/api_docs/python/tf/where" rel="nofollow noreferrer">documentation</a>可以看出,tf.where应该采用三个参数,否则它将简单地返回<code>None</code>,如这里所述</p>
<pre><code>tf.where(
condition, x=None, y=None, name=None
)
</code></pre>
<p>这就是为什么你的损失无助于学习任何东西,因为它总是会返回<code>None</code>不管怎样</p>
<p>对于你的问题,如果你想检查这两个条件,然后暗示损失,这是你应该怎么做</p>
<p>假设<code>y_true!=0</code>和<code>y_pred!=0</code>分别给出损失<code>some_loss1</code>和<code>some_loss2</code>,那么总损失可以通过嵌套<code>tf.where</code>计算为</p>
<pre><code>some_loss1=tf.constant(1000.0) #say
some_loss12=tf.constant(1000.0) #say
loss = tf.where(y_pred<0.1,tf.where(y_true<0.1,tf.constant(0.0),some_loss1),some_loss2)
</code></pre>
<p>这将惩罚双方</p>
<p>此外,如果要将此损失添加到MSE损失中,请创建不同的变量名称,因为它将已获得的MSE值添加到此掩码损失中</p>