创建我自己的损失,而不中断梯度磁带记录的梯度链

2024-06-26 00:30:20 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在学习Tensorflow 2.10和Python 3.7.7

我试图使用教程“Tensorflow - Custom training: walkthrough”来使用我自己的损失函数

这是我的损失函数的第一个版本,它可以工作:

    def loss(model, x, y):
      output = model(x)
      return tf.norm(y - output)

我换了另一个,但它不起作用:

def my_loss(model, x, y):
  output = model(x)

  # Only valid values for output var are 0.0 and 1.0.
  output_np = np.array(output)
  output_np[output_np >= 0.5] = 1.0
  output_np[output_np < 0.5] = 0.0

  # Counts how many 1.0 are on y var.
  unique, counts = np.unique(y, return_counts=True)
  dict_wmh = dict(zip(unique, counts))  
  wmh_count = 0
  if 1.0 in dict_wmh:
    wmh_count = dict_wmh[1.0]

  # Add y and output to get another array.
  c = y + output_np
  unique, counts = np.unique(c, return_counts=True)
  dict_net = dict(zip(unique, counts))

  # Counts how many 2.0 are on this new array.
  net_count = 0
  if 2.0 in dict_net:
    net_count = dict_net[2.0]

  # Return the different between the number of ones in the label and the network output.
  return wmh_count - net_count
  #return tf.convert_to_tensor(wmh_count - net_count, dtype=tf.float32)

我得到一个错误:

Cannot convert value 0 to a TensorFlow DType.

如果我用这个更改报税表:

return tf.convert_to_tensor(wmh_count - net_count, dtype=tf.float32)

我得到一个错误:

No gradients provided for any variable: ['conv1_1/kernel:0', 'conv1_1/bias:0', ...

关于此功能:

def grad(model, inputs, targets):
    with tf.GradientTape() as tape:
        tape.watch(model.trainable_variables)
        #loss_value = loss(model, inputs, targets)
        loss_value = my_loss(model, inputs, targets)

    return loss_value, tape.gradient(loss_value, model.trainable_variables)

因为我问过这个SO question,现在我知道我的新损失函数“interrupts the gradient chain registered by the gradient tape

此新损失函数(my_loss)执行以下操作:

  1. 将模型的输出转换为值为0.0或1.0的数组
  2. 创建一个新数组,将此output添加到y(值也为0.0或1.0)
  3. 统计此新阵列中有多少个2.0
  4. 统计y数组中有多少个1.0
  5. 返回点3和点4之间的差值

如何将其转换为张量而不是“interrupts the gradient chain registered by the gradient tape

也许有一个本机函数来完成我想做的事情

顺便说一下,我已经把它转换成了numpy数组,因为我不能用张量来做这个


Tags: the函数outputnetmodelreturnvaluetf