Python/Numpy: What am I doing wrong when computing precision-recall curves for semantic (image) segmentation?

Posted 2024-10-02 08:23:37


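I am doing semantic (image) segmentation (of materials) and am trying to plot a precision-recall curve for each material (class), but so far the curves look very strange: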

[Figure: PR curves for my semantic segmentation problem for ~100 thresholds]

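I start with three arrays of equal shape (num_pixels, num_classes): pixel_probs contains the softmax probability of each class for each pixel, while pixel_labels_pred and pixel_labels_gt contain the one-hot encoded predicted and ground-truth labels for each pixel, respectively. I then loop over my set of thresholds and, in each iteration, set to zero all entries of pixel_labels_pred and pixel_labels_gt whose corresponding entry in pixel_probs is below the threshold (these are not supposed to contribute to a class's true positives, false positives, or false negatives). I then compute precision and recall for each class/column. If a label has no valid precision or recall for some threshold (indicated by irrelevant_mask), those values are set to -1 and filtered out during plotting. I end up with two equally sized arrays, precision_mat and recall_mat, of shape (num_classes, num_thresholds), which contain the precision-recall pairs to plot. My code is as follows: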

import numpy as np
import matplotlib.pyplot as plt

precision_mat, recall_mat = np.array([]).reshape(num_labels, 0), np.array([]).reshape(num_labels, 0)
thresholds = np.unique(np.round(pixel_probs, decimals=2))

for threshold in sorted(thresholds):
    print("Current threshold: " + str(threshold))
    pixel_labels_pred_mskd = pixel_labels_pred
    pixel_labels_gt_mskd = pixel_labels_gt
    pixel_labels_pred_mskd[pixel_probs < threshold] = 0
    pixel_labels_gt_mskd[pixel_probs < threshold] = 0

    tps = np.sum(np.logical_and(pixel_labels_gt_mskd, pixel_labels_pred_mskd), axis=0)
    fps = np.sum(np.logical_and(np.logical_not(pixel_labels_gt_mskd), pixel_labels_pred_mskd), axis=0)
    fns = np.sum(np.logical_and(pixel_labels_gt_mskd, np.logical_not(pixel_labels_pred_mskd)), axis=0)

    irrelevant_mask = np.logical_or(tps + fps == 0, tps + fns == 0)
    tps[irrelevant_mask], fps[irrelevant_mask], fns[irrelevant_mask] = -1, -1, -1 # In order to avoid zero division
    precisions = tps / (tps + fps)
    recalls = tps / (tps + fns)
    precisions[irrelevant_mask], recalls[irrelevant_mask] = -1, -1 # In order to filter these out during plotting
    precision_mat = np.concatenate([precision_mat, np.expand_dims(precisions, axis=-1)], axis=-1)
    recall_mat = np.concatenate([recall_mat, np.expand_dims(recalls, axis=-1)], axis=-1)

fig = plt.figure()
fig.set_size_inches(12, 5)
index_rgb_mapping = get_index_rgb_mapping()
for label_index in range(len(labels)):
    recalls = recall_mat[label_index]
    precisions = precision_mat[label_index]
    print(list(zip(recalls, precisions)))
    relevant_mask = np.logical_and(recalls >= 0, precisions >= 0)
    recalls = recalls[relevant_mask]
    precisions = precisions[relevant_mask]
    sort_order = np.argsort(recalls)
    recalls = recalls[sort_order]
    precisions = precisions[sort_order]
    print(list(zip(recalls, precisions)))
    plt.plot(recalls, precisions, '-o', markersize=2.5, linewidth=2, label=labels[label_index],
             color=index_rgb_mapping[label_index])
plt.title("Precision-recall curve")
plt.legend(loc='upper left', fontsize=8.5, ncol=1, bbox_to_anchor=(1, 1))
plt.xlabel('recall', fontsize=12)
plt.ylabel('precision', fontsize=12)
plt.savefig(DIR + "test/pr_curves.png")
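For comparison, here is a minimal sketch of one common way to compute per-class (one-vs-rest) precision-recall points. It is not the code above: only the prediction is derived from the threshold (`pixel_probs >= threshold`), the ground truth is left untouched, and every iteration builds fresh arrays instead of writing into the inputs. The function name `pr_curves_one_vs_rest` and the NaN convention for undefined values are assumptions made for illustration.

```python
import numpy as np

def pr_curves_one_vs_rest(pixel_probs, pixel_labels_gt, thresholds):
    """Hypothetical helper: returns (precision_mat, recall_mat), each of
    shape (num_classes, num_thresholds). Undefined entries are NaN."""
    probs = np.asarray(pixel_probs, dtype=float)
    gt = np.asarray(pixel_labels_gt).astype(bool)  # copy; the input is never modified
    num_classes = probs.shape[1]
    precision_mat = np.full((num_classes, len(thresholds)), np.nan)
    recall_mat = np.full((num_classes, len(thresholds)), np.nan)

    for j, t in enumerate(thresholds):
        pred = probs >= t  # fresh boolean array for this threshold
        tps = np.sum(pred & gt, axis=0).astype(float)
        fps = np.sum(pred & ~gt, axis=0).astype(float)
        fns = np.sum(~pred & gt, axis=0).astype(float)

        # Divide only where the denominator is non-zero; the rest stays NaN.
        prec = np.full(num_classes, np.nan)
        rec = np.full(num_classes, np.nan)
        np.divide(tps, tps + fps, out=prec, where=(tps + fps) > 0)
        np.divide(tps, tps + fns, out=rec, where=(tps + fns) > 0)
        precision_mat[:, j] = prec
        recall_mat[:, j] = rec
    return precision_mat, recall_mat

# Tiny made-up example: 4 pixels, 2 classes.
probs = np.array([[0.9, 0.1], [0.6, 0.4], [0.2, 0.8], [0.4, 0.6]])
gt = np.array([[1, 0], [0, 1], [0, 1], [1, 0]])
precision_mat, recall_mat = pr_curves_one_vs_rest(probs, gt, [0.3, 0.5, 0.7])
print(recall_mat)
```

In this formulation the number of ground-truth positives per class is fixed, so raising the threshold can only remove true positives, and recall can only stay the same or fall.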

For the threshold 0.3 (one of 10 thresholds), the intermediate output looks fine:

Current threshold: 0.3
sum(pixel_labels_pred) : [   403  89522 205550  79966   5511      0  19153   6167  13746      0     160     50     32      0  96441    224   3012   1677   9330    811     582]
sum(pixel_labels_gt) :   [    34  73592 162568  86453   5936      0   5295   1722   3481      0       0      0      0      0  73393      0    422      5   3322     21      88]
sum(pixel_labels_pred_mskd) : [   144  78175 185814  77120   4316      0  14919   3630   9719      0      80     50     32      0  93369    176   1274    825   6343    409     313]
sum(pixel_labels_gt_mskd) : [      4  57588 143507  80909   3532    0   4710   1112   2362      0       0      0      0      0  73103      0     99      0   1245      0      32]
tps: [  0  54282 133949  74121   3220   0   4637   1004   2130      0       0      0      0      0  73007      0     91      0   1004      0      32]
fps: [  144 23893 51865  2999  1096     0   10282  2626  7589     0    80    50     32     0 20362   176  1183   825  5339   409   281]
fns: [   4 3306 9558 6788  312    0   73  108  232    0    0    0    0    0    96    0    8    0  241    0    0]
precisions: [ 0.     0.694  0.721  0.961  0.746 -1.     0.311  0.277  0.219 -1.  -1.    -1.    -1.    -1.     0.782 -1.     0.071 -1.     0.158 -1.   0.102]
recalls:    [ 0.     0.943  0.933  0.916  0.912 -1.     0.985  0.903  0.902 -1.  -1.    -1.    -1.    -1.     0.999 -1.     0.919 -1.     0.806 -1.   1.   ]
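As a sanity check on the arithmetic, the counts printed above can be fed through the same formulas. The sketch below copies the first nine columns of `tps`, `fps` and `fns` from that output and uses `np.divide` with `where=` as one possible alternative to the `-1` sentinel. Marking undefined entries as NaN is an assumption here, not what the code above does.

```python
import numpy as np

# First nine columns of the printed counts for threshold 0.3.
tps = np.array([0, 54282, 133949, 74121, 3220, 0, 4637, 1004, 2130], dtype=float)
fps = np.array([144, 23893, 51865, 2999, 1096, 0, 10282, 2626, 7589], dtype=float)
fns = np.array([4, 3306, 9558, 6788, 312, 0, 73, 108, 232], dtype=float)

# Start from NaN and divide only where the denominator is non-zero.
precisions = np.full(tps.shape, np.nan)
recalls = np.full(tps.shape, np.nan)
np.divide(tps, tps + fps, out=precisions, where=(tps + fps) > 0)
np.divide(tps, tps + fns, out=recalls, where=(tps + fns) > 0)

print(np.round(precisions, 3))
print(np.round(recalls, 3))
```

The defined entries agree with the printed `precisions` and `recalls` rows to three decimals, and the class with no counts at all (index 5) comes out as NaN instead of -1.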

When plotting the curves, my output looks like this:

Plotting label index 2...
list(zip(recalls, precisions)), start of iteration:
[[ 0.807  0.682]
 [ 0.828  0.682]
 [ 0.869  0.687]
 [ 0.933  0.721]
 [ 0.986  0.772]
 [ 1.     0.804]
 [ 1.     0.823]
 [ 1.     0.821]
 [ 1.     0.823]
 [ 1.     1.   ]
 [-1.    -1.   ]]
relevant_mask: [ True  True  True  True  True  True  True  True  True  True False]
list(zip(recalls, precisions)), before plotting:
[[0.807 0.682]
 [0.828 0.682]
 [0.869 0.687]
 [0.933 0.721]
 [0.986 0.772]
 [1.    0.804]
 [1.    0.823]
 [1.    0.821]
 [1.    0.823]
 [1.    1.   ]]

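Clearly, recall and precision both increase together, which is very strange behavior for a precision-recall curve. Does anyone know what I am doing wrong?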
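This observation can be turned into a quick check. When only the prediction depends on the threshold, recall is expected to be non-increasing as the threshold rises, so a rising sequence is a useful warning sign. The sketch below applies that check to the recall values for label index 2 printed above, which appear in ascending threshold order in the first listing; the helper name `recall_is_non_increasing` is hypothetical.

```python
import numpy as np

def recall_is_non_increasing(recalls, tol=1e-9):
    """Hypothetical helper: True if recall never rises along the sequence.
    Entries below zero are treated as the -1 'undefined' sentinel and dropped."""
    recalls = np.asarray(recalls, dtype=float)
    recalls = recalls[recalls >= 0]
    return bool(np.all(np.diff(recalls) <= tol))

# Recall values for label index 2, in ascending threshold order.
recalls_label_2 = [0.807, 0.828, 0.869, 0.933, 0.986, 1.0, 1.0, 1.0, 1.0, 1.0, -1.0]
print(recall_is_non_increasing(recalls_label_2))  # prints False
```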

