在自定义损失函数Tensorflow 2.4中迭代实现非平衡多标签分类

2024-06-28 19:25:07 发布

您现在位置:Python中文网/ 问答频道 /正文

我的数据集中有大约160000张图像,每个图像都有多个标记/标签。第一个标记位于大约80%的图像上,最后一个标记位于大约0.5%的数据集上。我一直在尝试制作一个自定义的损耗函数来称量标签,这样网络就不会走捷径来获得98%的准确度。我总共有534个不同的标签,我发现很少有关于在这种不平衡的情况下为多标签分类创建损失函数的信息。要返回每个标签的损失值,我使用以下代码:

import numpy as np
positive_weights = {}
negative_weights = {}
for c in sudosamples.columns:
    positive_weights[c] = sudosamples.shape[0]/(2*np.count_nonzero(sudosamples[c]==1))
    negative_weights[c] = sudosamples.shape[0]/(2*np.count_nonzero(sudosamples[c]==0))

sudosamples是一个数据框,它保存通过列组织的所有图像的所有标签。每个标签都是一个热编码标签。此代码工作正常。问题出现在损失函数定义中,如下所示:

import tensorflow.keras.backend as K
def loss_fn(y_true,y_pred):
    y_true = tf.cast(y_true,tf.float32)
    loss = 0
    for c,i in zip(sudosamples.columns,range(len(sudosamples.columns))):
        loss -= positive_weights[c]*y_true[i]*K.log(y_pred[i])+negative_weights[c]*(1-y_true[i])*K.log(1-y_pred[i])
    return loss

这段代码不起作用,但是,如果我删除for循环并给出常量值,而不是从for循环派生的变量,它就可以正常工作。我不想为所有534个标签编写这行代码。如果我这样做了,我甚至不能完全肯定它会起作用。目前,上面的损失函数运行一个错误,如下所示:

  File "C:\Users\cws72\tensorflow\Projects\bot\Multi-Label-Training-V2.py", line 127, in <module>
    history = model.fit(Train_gen, epochs=10,steps_per_epoch=num_samples//128,validation_data=Val_gen,validation_steps=num_val_samples//128,callbacks=[cp_callback])

  File "C:\Users\cws72\anaconda3\envs\Tensorflow-GPU-WorkDir\lib\site-packages\tensorflow\python\keras\engine\training.py", line 1100, in fit
    tmp_logs = self.train_function(iterator)

  File "C:\Users\cws72\anaconda3\envs\Tensorflow-GPU-WorkDir\lib\site-packages\tensorflow\python\eager\def_function.py", line 828, in __call__
    result = self._call(*args, **kwds)

  File "C:\Users\cws72\anaconda3\envs\Tensorflow-GPU-WorkDir\lib\site-packages\tensorflow\python\eager\def_function.py", line 888, in _call
    return self._stateless_fn(*args, **kwds)

  File "C:\Users\cws72\anaconda3\envs\Tensorflow-GPU-WorkDir\lib\site-packages\tensorflow\python\eager\function.py", line 2942, in __call__
    return graph_function._call_flat(

  File "C:\Users\cws72\anaconda3\envs\Tensorflow-GPU-WorkDir\lib\site-packages\tensorflow\python\eager\function.py", line 1918, in _call_flat
    return self._build_call_outputs(self._inference_function.call(

  File "C:\Users\cws72\anaconda3\envs\Tensorflow-GPU-WorkDir\lib\site-packages\tensorflow\python\eager\function.py", line 555, in call
    outputs = execute.execute(

  File "C:\Users\cws72\anaconda3\envs\Tensorflow-GPU-WorkDir\lib\site-packages\tensorflow\python\eager\execute.py", line 59, in quick_execute
    tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,

InvalidArgumentError: 2 root error(s) found.
  (0) Invalid argument:  slice index 484 of dimension 0 out of bounds.
     [[node loss_fn/strided_slice_1936 (defined at C:\Users\cws72\tensorflow\Projects\bot\Multi-Label-Training-V2.py:104) ]]
     [[assert_greater_equal/Assert/AssertGuard/branch_executed/_9/_55]]
  (1) Invalid argument:  slice index 484 of dimension 0 out of bounds.
     [[node loss_fn/strided_slice_1936 (defined at C:\Users\cws72\tensorflow\Projects\bot\Multi-Label-Training-V2.py:104) ]]
0 successful operations.
0 derived errors ignored. [Op:__inference_train_function_527426]

Errors may have originated from an input operation.
Input Source operations connected to node loss_fn/strided_slice_1936:
 loss_fn/Cast (defined at C:\Users\cws72\tensorflow\Projects\bot\Multi-Label-Training-V2.py:100)

Input Source operations connected to node loss_fn/strided_slice_1936:
 loss_fn/Cast (defined at C:\Users\cws72\tensorflow\Projects\bot\Multi-Label-Training-V2.py:100)

Function call stack:
train_function -> train_function

从技术上讲,这并不是全部错误,因为tensorflow发出了其中一半的错误:

2021-03-20 17:03:54.071766: W tensorflow/core/framework/op_kernel.cc:1763] OP_REQUIRES failed at strided_slice_op.cc:108 : Invalid argument: slice index 248 of dimension 0 out of bounds.

它们之间唯一的区别是切片索引。其中有几个是子进程错误,但是我知道这是我的GPU的问题,它不应该影响我运行代码的能力。如果我在没有tf.cast的情况下运行定义,我会得到一个错误,它期望的是一个整数值而不是一个浮点值(是的,我说的没错)。我正在为我的数据集使用生成器


Tags: inpygputensorflowlineslicefunction标签