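I'm doing semantic (image) segmentation of materials and trying to plot a precision-recall curve for each material (class), but so far the curves look very strange.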
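I start with three arrays of the same shape, `(num_pixels, num_classes)`: `pixel_probs` contains the softmax probability of each class for each pixel, while `pixel_labels_pred` and `pixel_labels_gt` contain the one-hot encoded predicted and ground-truth label of each pixel, respectively.

I then loop over my set of thresholds. At each iteration I set to zero every entry of `pixel_labels_pred` and `pixel_labels_gt` whose corresponding probability in `pixel_probs` is below the threshold (these entries then contribute no true positives, false positives, or false negatives to any class). Next I compute the precision and recall for each class/column. If a label has no valid precision or recall at a given threshold (indicated by `irrelevant_mask`), these values are set to -1 and are filtered out during plotting. I end up with two arrays of equal size, `precision_mat` and `recall_mat`, of shape `(num_classes, num_threshold)`, containing the precision-recall pairs to plot.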
My code is as follows:
```python
import numpy as np
import matplotlib.pyplot as plt

precision_mat, recall_mat = np.array([]).reshape(num_labels, 0), np.array([]).reshape(num_labels, 0)
thresholds = np.unique(np.round(pixel_probs, decimals=2))
for threshold in sorted(thresholds):
    print("Current threshold: " + str(threshold))
    pixel_labels_pred_mskd = pixel_labels_pred
    pixel_labels_gt_mskd = pixel_labels_gt
    pixel_labels_pred_mskd[pixel_probs < threshold] = 0
    pixel_labels_gt_mskd[pixel_probs < threshold] = 0
    tps = np.sum(np.logical_and(pixel_labels_gt_mskd, pixel_labels_pred_mskd), axis=0)
    fps = np.sum(np.logical_and(np.logical_not(pixel_labels_gt_mskd), pixel_labels_pred_mskd), axis=0)
    fns = np.sum(np.logical_and(pixel_labels_gt_mskd, np.logical_not(pixel_labels_pred_mskd)), axis=0)
    irrelevant_mask = np.logical_or(tps + fps == 0, tps + fns == 0)
    tps[irrelevant_mask], fps[irrelevant_mask], fns[irrelevant_mask] = -1, -1, -1  # In order to avoid zero division
    precisions = tps / (tps + fps)
    recalls = tps / (tps + fns)
    precisions[irrelevant_mask], recalls[irrelevant_mask] = -1, -1  # In order to filter these out during plotting
    precision_mat = np.concatenate([precision_mat, np.expand_dims(precisions, axis=-1)], axis=-1)
    recall_mat = np.concatenate([recall_mat, np.expand_dims(recalls, axis=-1)], axis=-1)

fig = plt.figure()
fig.set_size_inches(12, 5)
index_rgb_mapping = get_index_rgb_mapping()
for label_index in range(len(labels)):
    recalls = recall_mat[label_index]
    precisions = precision_mat[label_index]
    print(list(zip(recalls, precisions)))
    relevant_mask = np.logical_and(recalls >= 0, precisions >= 0)
    recalls = recalls[relevant_mask]
    precisions = precisions[relevant_mask]
    sort_order = np.argsort(recalls)
    recalls = recalls[sort_order]
    precisions = precisions[sort_order]
    print(list(zip(recalls, precisions)))
    plt.plot(recalls, precisions, '-o', markersize=2.5, linewidth=2, label=labels[label_index],
             color=index_rgb_mapping[label_index])
plt.title("Precision-recall curve")
plt.legend(loc='upper left', fontsize=8.5, ncol=1, bbox_to_anchor=(1, 1))
plt.xlabel('recall', fontsize=12)
plt.ylabel('precision', fontsize=12)
plt.savefig(DIR + "test/pr_curves.png")
```
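One NumPy detail in the loop above that is easy to miss: `pixel_labels_pred_mskd = pixel_labels_pred` binds a second name to the same array instead of copying it, so the two masking lines modify `pixel_labels_pred` and `pixel_labels_gt` themselves. The sketch below contrasts that with `.copy()`; the helper names and the toy arrays are hypothetical.

```python
import numpy as np


def mask_via_alias(labels, probs, threshold):
    masked = labels                    # same array object, not a copy
    masked[probs < threshold] = 0      # writes into `labels` as well
    return masked


def mask_on_copy(labels, probs, threshold):
    masked = labels.copy()             # independent buffer
    masked[probs < threshold] = 0      # `labels` is left untouched
    return masked


probs = np.array([[0.9, 0.1], [0.2, 0.8], [0.4, 0.6]])
labels_a = np.array([[1, 0], [0, 1], [1, 0]])
labels_b = np.array([[1, 0], [0, 1], [1, 0]])
masked_a = mask_via_alias(labels_a, probs, 0.5)
masked_b = mask_on_copy(labels_b, probs, 0.5)
print(labels_a.sum(), labels_b.sum())  # the aliased original lost one label; the copied one did not
```

Because the thresholds are visited in ascending order and `{p < t1}` is a subset of `{p < t2}` whenever `t1 <= t2`, the accumulated in-place masking happens to give the same masked arrays at each iteration as masking a fresh copy would. The side effect is that the original label arrays are permanently altered, so if the `sum(pixel_labels_pred)` and `sum(pixel_labels_gt)` lines in the output were printed inside the loop, they already reflect the masking from earlier iterations.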
For the threshold 0.3 (one of the 10 thresholds), the intermediate output looks fine:
```
Current threshold: 0.3
sum(pixel_labels_pred) : [ 403 89522 205550 79966 5511 0 19153 6167 13746 0 160 50 32 0 96441 224 3012 1677 9330 811 582]
sum(pixel_labels_gt) : [ 34 73592 162568 86453 5936 0 5295 1722 3481 0 0 0 0 0 73393 0 422 5 3322 21 88]
sum(pixel_labels_pred_mskd) : [ 144 78175 185814 77120 4316 0 14919 3630 9719 0 80 50 32 0 93369 176 1274 825 6343 409 313]
sum(pixel_labels_gt_mskd) : [ 4 57588 143507 80909 3532 0 4710 1112 2362 0 0 0 0 0 73103 0 99 0 1245 0 32]
tps: [ 0 54282 133949 74121 3220 0 4637 1004 2130 0 0 0 0 0 73007 0 91 0 1004 0 32]
fps: [ 144 23893 51865 2999 1096 0 10282 2626 7589 0 80 50 32 0 20362 176 1183 825 5339 409 281]
fns: [ 4 3306 9558 6788 312 0 73 108 232 0 0 0 0 0 96 0 8 0 241 0 0]
precisions: [ 0. 0.694 0.721 0.961 0.746 -1. 0.311 0.277 0.219 -1. -1. -1. -1. -1. 0.782 -1. 0.071 -1. 0.158 -1. 0.102]
recalls: [ 0. 0.943 0.933 0.916 0.912 -1. 0.985 0.903 0.902 -1. -1. -1. -1. -1. 0.999 -1. 0.919 -1. 0.806 -1. 1. ]
```
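As a consistency check on the threshold-0.3 numbers, this short sketch recomputes precision and recall for class indices 1 and 2 from the printed counts. It also shows that `tps + fns` equals the masked ground-truth total `sum(pixel_labels_gt_mskd)`, not the total printed as `sum(pixel_labels_gt)`, so the recall denominator only counts ground-truth pixels that survived the threshold. Since `pixel_labels_pred_mskd` is already zero wherever `pixel_labels_gt_mskd` was zeroed, leaving the ground truth unmasked would keep `tps` the same and only move those pixels into `fns`; the last value divides by the printed total to illustrate that (if that total was itself printed after earlier masking, the unmasked total may be larger still).

```python
# Counts copied from the threshold-0.3 output above, for class indices 1 and 2.
printed = {
    1: {"tp": 54282, "fp": 23893, "fn": 3306, "gt_mskd": 57588, "gt": 73592},
    2: {"tp": 133949, "fp": 51865, "fn": 9558, "gt_mskd": 143507, "gt": 162568},
}

results = {}
for c, n in printed.items():
    precision = n["tp"] / (n["tp"] + n["fp"])
    recall = n["tp"] / (n["tp"] + n["fn"])
    # The recall denominator is the masked ground-truth count.
    denominator_is_masked_gt = (n["tp"] + n["fn"]) == n["gt_mskd"]
    # Same TP count divided by the total printed as sum(pixel_labels_gt) instead.
    recall_vs_printed_gt = n["tp"] / n["gt"]
    results[c] = (round(precision, 3), round(recall, 3),
                  denominator_is_masked_gt, round(recall_vs_printed_gt, 3))
    print(c, results[c])
```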
When plotting the curves, my output looks like this:
```
Plotting label index 2...
list(zip(recalls, precisions)), start of iteration:
[[ 0.807 0.682]
 [ 0.828 0.682]
 [ 0.869 0.687]
 [ 0.933 0.721]
 [ 0.986 0.772]
 [ 1. 0.804]
 [ 1. 0.823]
 [ 1. 0.821]
 [ 1. 0.823]
 [ 1. 1. ]
 [-1. -1. ]]
relevant_mask: [ True True True True True True True True True True False]
list(zip(recalls, precisions)), before plotting:
[[0.807 0.682]
 [0.828 0.682]
 [0.869 0.687]
 [0.933 0.721]
 [0.986 0.772]
 [1. 0.804]
 [1. 0.823]
 [1. 0.821]
 [1. 0.823]
 [1. 1. ]]
```
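Clearly, recall and precision both increase together, which is very strange behavior for a precision-recall curve. Does anyone know what I'm doing wrong?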
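For comparison, here is a minimal sketch of the usual one-vs-rest way to sweep thresholds for a single class. It assumes, unlike the loop in the question, that only the predictions depend on the threshold (a pixel counts as predicted positive for class `c` when `probs[:, c] >= t`) while the ground truth is never masked. Under that assumption recall can only stay the same or fall as `t` increases, which is the shape a precision-recall curve normally has. The function name `one_vs_rest_pr_curve` is hypothetical.

```python
import numpy as np


def one_vs_rest_pr_curve(probs, gt_onehot, class_index, thresholds):
    """Precision and recall of one class over a sweep of thresholds.

    probs: (num_pixels, num_classes) softmax probabilities.
    gt_onehot: (num_pixels, num_classes) one-hot ground truth, left unmasked.
    """
    scores = probs[:, class_index]
    gt = gt_onehot[:, class_index].astype(bool)
    num_gt = gt.sum()  # fixed recall denominator: every ground-truth pixel of the class
    precisions, recalls = [], []
    for t in thresholds:
        pred = scores >= t  # only the prediction changes with the threshold
        tp = np.sum(pred & gt)
        fp = np.sum(pred & ~gt)
        precisions.append(tp / (tp + fp) if tp + fp > 0 else np.nan)
        recalls.append(tp / num_gt if num_gt > 0 else np.nan)
    return np.array(precisions), np.array(recalls)


rng = np.random.default_rng(0)
probs = rng.dirichlet(np.ones(3), size=200)
gt_onehot = np.eye(3)[rng.integers(0, 3, size=200)]
thresholds = np.linspace(0.0, 1.0, 11)
p, r = one_vs_rest_pr_curve(probs, gt_onehot, 1, thresholds)
```

If scikit-learn is available, `sklearn.metrics.precision_recall_curve(gt_onehot[:, c], probs[:, c])` computes the same kind of per-class curve over all distinct score values.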