对序列模型中的一批图像使用tf.image.ssim_multiscale

2024-09-28 18:56:38 发布

您现在位置:Python中文网/ 问答频道 /正文

我试图使用SSIM作为我的Keras序列模型的损失值。输出值是一组6个图像-(6、32、28、3)。我想为model.fit()函数实现一个自定义的loss函数,但无法理解如何实现它。我附上我以前试过的-

def ssim_loss(y_true, y_pred):
    return 1-tf.reduce_mean(tf.image.ssim_multiscale(y_true, y_pred, 2.0))
def ssim_loss(y_true, y_pred):
  loss = 0
  for i in range(6):
    loss += tf.reduce_mean(tf.image.ssim_multiscale(y_true[i], y_pred[i], 2.0))
  return loss

y_真实形状与我检查过的y_pred相同。这两种解决方案都不起作用。我不知道如何访问y_pred的每个元素。我的输入形状是-(619256,3)。请帮忙。短暂性脑缺血发作


Tags: 函数imagetruereducereturntfdefmean