Tensorflow 2:获取张量值

2024-07-02 04:13:31 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在切换到TF2,我刚刚跟踪了tutorial, 列车和梯级功能现在定义为@tf.功能". 在

我如何打印张量y峎pred和loss的值?在

@tf.function
def train_step(images, labels):
  with tf.GradientTape() as tape:
    predictions = model(images)
    loss = loss_object(labels, predictions)

    print("train preds: ", y_pred)
    print("train loss: ", loss)

  gradients = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))

  train_loss(loss)
  train_accuracy(labels, predictions)

Tags: 功能labelsmodeltftrainvariablesimagesprint
1条回答
网友
1楼 · 发布于 2024-07-02 04:13:31

print在Python世界中执行(而不是在graph中),因此当tf.function跟踪函数以构造图形时,它只打印一次张量。如果要在图形中打印,请使用^{}。在

相关问题 更多 >

    热门问题