我的机器学习任务的度量是weight TPR = 0.4 * TPR1 + 0.3 * TPR2 + 0.3 * TPR3
。通常,它要求模型具有较高的召回率,同时干扰较少的负样本
一些术语:
- TPR(True Positive Rate, Sensitivity) : TPR = TP /(TP + FN)
- FPR(False Positive Rate, 1 - Specificity): FPR = FP /(FP + TN)
- TP、FN、FP、TN stands for True Positive, False Negative, Fasle Positive and True Negative.
- TPR1:TPR at FPR = 0.001
- TPR2:TPR at FPR = 0.005
- TPR3:TPR at FPR = 0.01
由于keras没有这样的度量,我们需要编写自己的定制度量。另一个值得一提的词是,与lightgbm和xgboost不同,keras
中的自定义度量不是直接向前的,因为训练过程是基于张量而不是pandas/numpy数组
在lightgbm/Xgboost中,我有一个wtpr
自定义度量,它可以正常工作:
def tpr_weight_funtion(y_true,y_predict):
d = pd.DataFrame()
d['prob'] = list(y_predict)
d['y'] = list(y_true)
d = d.sort_values(['prob'], ascending=[0])
y = d.y
PosAll = pd.Series(y).value_counts()[1]
NegAll = pd.Series(y).value_counts()[0]
pCumsum = d['y'].cumsum()
nCumsum = np.arange(len(y)) - pCumsum + 1
pCumsumPer = pCumsum / PosAll
nCumsumPer = nCumsum / NegAll
TR1 = pCumsumPer[abs(nCumsumPer-0.001).idxmin()]
TR2 = pCumsumPer[abs(nCumsumPer-0.005).idxmin()]
TR3 = pCumsumPer[abs(nCumsumPer-0.01).idxmin()]
return 0.4 * TR1 + 0.3 * TR2 + 0.3 * TR3
在keras中,我在下面编写了一个自定义度量。它适用于常规张量输入,但在批量梯度下降的模型拟合过程中失败:
import keras.backend as K
def keras_wtpr_metric(y_true, y_predict):
n = y_predict.shape[0]
a = tf.dtypes.cast(y_predict, tf.float32)
b = tf.dtypes.cast(y_true, tf.float32)
a = tf.reshape(a,shape = [-1])
b = tf.reshape(b,shape = [-1])
d = tf.stack([a,b], axis = 0)
d = tf.gather(d, tf.argsort(a,direction='DESCENDING'), axis = 1)
PosAll = tf.math.reduce_sum(b, axis = -1) # the number of positive samples
NegAll = tf.math.reduce_sum(1-b, axis = -1) # the number of negative samples
pCumsum = tf.math.cumsum(d[1]) # TP
pCumsum = tf.dtypes.cast(pCumsum,dtype = tf.float32)
nCumsum = tf.range(0,n,dtype = tf.float32) - pCumsum + 1 # FP
pCumsumPer = pCumsum / PosAll # tpr
nCumsumPer = nCumsum / NegAll # fpr
TR1 = pCumsumPer[tf.math.argmin(abs(nCumsumPer-0.001))]
TR2 = pCumsumPer[tf.math.argmin(abs(nCumsumPer-0.005))]
TR3 = pCumsumPer[tf.math.argmin(abs(nCumsumPer-0.01))]
return tf.reduce_sum(0.4*TR1+0.3*TR2+0.3*TR3)
我的模型是:
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
x_train, x_test, y_train, y_test = train_test_split(load_breast_cancer().data, load_breast_cancer().target,test_size = 0.3)
model = keras.models.Sequential([
# I have a tabular data
keras.layers.Dense(256, activation='relu',input_shape = (x_train.shape[1],)),
keras.layers.Dense(64, activation = 'relu'),
keras.layers.Dense(1, activation = 'sigmoid')
])
model.compile(optimizer='adam',loss = 'binary_crossentropy', metrics = [keras_wtpr_metric])
# it seems can not work under batch training, I don't know why
model.fit(x=x_train, y= y_train, batch_size = 2048, epochs = 30,validation_data = [x_test,y_test])
错误消息是
Epoch 1/30
398/398 [==============================] - 1s 2ms/sample
---------------------------------------------------------------------------
InvalidArgumentError Traceback (most recent call last)
<ipython-input-639-da481d44d615> in <module>
5 ])
6 model.compile(optimizer='adam',loss = 'binary_crossentropy', metrics = [keras_wtpr_metric])
----> 7 model.fit(x=x_train, y= y_train, batch_size = 2048, epochs = 30,validation_data = [x_test,y_test]) # it seems can not work under batch training, I don't know why
~/opt/anaconda3/envs/envpython36/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_freq, max_queue_size, workers, use_multiprocessing, **kwargs)
726 max_queue_size=max_queue_size,
727 workers=workers,
--> 728 use_multiprocessing=use_multiprocessing)
729
730 def evaluate(self,
~/opt/anaconda3/envs/envpython36/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/training_v2.py in fit(self, model, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_freq, **kwargs)
322 mode=ModeKeys.TRAIN,
323 training_context=training_context,
--> 324 total_epochs=epochs)
325 cbks.make_logs(model, epoch_logs, training_result, ModeKeys.TRAIN)
326
~/opt/anaconda3/envs/envpython36/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/training_v2.py in run_one_epoch(model, iterator, execution_function, dataset_size, batch_size, strategy, steps_per_epoch, num_samples, mode, training_context, total_epochs)
121 step=step, mode=mode, size=current_batch_size) as batch_logs:
122 try:
--> 123 batch_outs = execution_function(iterator)
124 except (StopIteration, errors.OutOfRangeError):
125 # TODO(kaftan): File bug about tf function and errors.OutOfRangeError?
~/opt/anaconda3/envs/envpython36/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/training_v2_utils.py in execution_function(input_fn)
84 # `numpy` translates Tensors to values in Eager mode.
85 return nest.map_structure(_non_none_constant_value,
---> 86 distributed_function(input_fn))
87
88 return execution_function
~/opt/anaconda3/envs/envpython36/lib/python3.6/site-packages/tensorflow_core/python/eager/def_function.py in __call__(self, *args, **kwds)
455
456 tracing_count = self._get_tracing_count()
--> 457 result = self._call(*args, **kwds)
458 if tracing_count == self._get_tracing_count():
459 self._call_counter.called_without_tracing()
~/opt/anaconda3/envs/envpython36/lib/python3.6/site-packages/tensorflow_core/python/eager/def_function.py in _call(self, *args, **kwds)
518 # Lifting succeeded, so variables are initialized and we can run the
519 # stateless function.
--> 520 return self._stateless_fn(*args, **kwds)
521 else:
522 canon_args, canon_kwds = \
~/opt/anaconda3/envs/envpython36/lib/python3.6/site-packages/tensorflow_core/python/eager/function.py in __call__(self, *args, **kwargs)
1821 """Calls a graph function specialized to the inputs."""
1822 graph_function, args, kwargs = self._maybe_define_function(args, kwargs)
-> 1823 return graph_function._filtered_call(args, kwargs) # pylint: disable=protected-access
1824
1825 @property
~/opt/anaconda3/envs/envpython36/lib/python3.6/site-packages/tensorflow_core/python/eager/function.py in _filtered_call(self, args, kwargs)
1139 if isinstance(t, (ops.Tensor,
1140 resource_variable_ops.BaseResourceVariable))),
-> 1141 self.captured_inputs)
1142
1143 def _call_flat(self, args, captured_inputs, cancellation_manager=None):
~/opt/anaconda3/envs/envpython36/lib/python3.6/site-packages/tensorflow_core/python/eager/function.py in _call_flat(self, args, captured_inputs, cancellation_manager)
1222 if executing_eagerly:
1223 flat_outputs = forward_function.call(
-> 1224 ctx, args, cancellation_manager=cancellation_manager)
1225 else:
1226 gradient_name = self._delayed_rewrite_functions.register()
~/opt/anaconda3/envs/envpython36/lib/python3.6/site-packages/tensorflow_core/python/eager/function.py in call(self, ctx, args, cancellation_manager)
509 inputs=args,
510 attrs=("executor_type", executor_type, "config_proto", config),
--> 511 ctx=ctx)
512 else:
513 outputs = execute.execute_with_cancellation(
~/opt/anaconda3/envs/envpython36/lib/python3.6/site-packages/tensorflow_core/python/eager/execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
65 else:
66 message = e.message
---> 67 six.raise_from(core._status_to_exception(e.code, message), None)
68 except TypeError as e:
69 keras_symbolic_tensors = [
~/opt/anaconda3/envs/envpython36/lib/python3.6/site-packages/six.py in raise_from(value, from_value)
InvalidArgumentError: Incompatible shapes: [0] vs. [398]
[[node metrics/keras_wtpr_metric/sub_1 (defined at /Users/travis/opt/anaconda3/envs/envpython36/lib/python3.6/site-packages/tensorflow_core/python/framework/ops.py:1751) ]] [Op:__inference_distributed_function_681042]
Function call stack:
distributed_function
keras_wtpr_metric
失败了李>
使用
n = tf.shape(y_predict)[0]
intead ofn = y_predict.shape[0]
动态考虑批处理维度在圆括号中传递验证数据:
validation_data = (x_test,y_test)
这里是跑步笔记本:https://colab.research.google.com/drive/1uUb3nAk8CAsLYDJXGraNt1_sYYRYVihX?usp=sharing
相关问题 更多 >
编程相关推荐