在混淆矩阵中从四面打印标签

2024-09-28 17:04:03 发布

您现在位置:Python中文网/ 问答频道 /正文

我需要绘制混乱矩阵,从各个方面进行注释。我想在右侧打印与底部相同的标签时遇到问题([…类名…,“总样本数”,“精度”“mIoU]”)。此外,似乎顶部刻度与底部刻度不对齐

以下是我尝试过的:

    fig, ax1 = get_new_fig('Conf matrix default', figsize)

    
    ax = sn.heatmap(df_cm, annot=annot, annot_kws={"size": fz}, linewidths=lw, ax=ax1,
                    cbar=cbar, cmap=cmap, linecolor='w', fmt=fmt)

    ax_new = ax.twinx().twiny()

    labels = ['' for _ in range(len(ax.get_xticklabels()))]

    labels[-3] = 'Total samples'
    labels[-2] = 'Accuracy'
    labels[-1] = 'mIoU'

    ticks = [tick for tick in ax.get_xticks()]

    ax_new.set_xticks(ticks)
    ax_new.set_yticks(ticks)
    ax_new.yaxis.set_label_position('right')

    ax_new.set_xticklabels([text.get_text() for text in ax.get_xticklabels()], fontsize=10, rotation=-45)  # top
    ax_new.set_yticklabels(labels, fontsize=10, rotation=-25)  # right

    # set ticklabels
    ax.set_xticklabels(labels, rotation=45, fontsize=10)  # bottom
    ax.set_yticklabels(ax.get_yticklabels(), rotation=25, fontsize=10)  # left

enter image description here

我非常感谢任何帮助,因为我不知道我是否在代码中遗漏了什么

提前谢谢


Tags: textinnewforgetlabelsaxset
1条回答
网友
1楼 · 发布于 2024-09-28 17:04:03

问题在于新ax的限制。这些需要等于原始ax的限制。尤其是原始ax的y轴反转的事实,导致新y轴没有可见的记号标签。不同的限制也会阻止x轴记号的对齐

ax_new.set_xlim(ax.get_xlim())ax_new.set_ylim(ax.get_ylim())应该可以解决这个问题plt.tight_layout()可以帮助在周围的绘图中很好地定位所有标签

右y记号标签的旋转问题似乎有点难。下面的代码通过分离twinxtwiny轴来解决此问题:

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np

fig, ax1 = plt.subplots()
ax1.set_title('Conf matrix default')

df_cm = pd.DataFrame(np.random.rand(9, 9), columns=range(1, 10), index=range(1, 10))
ax = sns.heatmap(df_cm, annot=True, annot_kws={"size": 12}, linewidths=2, ax=ax1,
                 cbar=False, linecolor='w', fmt='.2f')

ax_new1 = ax.twinx()
ax_new2 = ax_new1.twiny()

labels = ['' for _ in range(len(ax.get_xticklabels()))]

labels[-3] = 'Total samples'
labels[-2] = 'Accuracy'
labels[-1] = 'mIoU'

ticks = [tick for tick in ax.get_xticks()]

ax_new2.set_xticks(ticks)
ax_new1.set_yticks(ticks)
ax_new1.yaxis.set_label_position('right')

ax_new2.set_xticklabels([text.get_text() for text in ax.get_xticklabels()], fontsize=10, rotation=-45)  # top
ax_new1.set_yticklabels(labels, fontsize=10, rotation=-45)  # right

# set ticklabels
ax.set_xticklabels(labels, rotation=45, fontsize=10)  # bottom
ax.set_yticklabels(ax.get_yticklabels(), rotation=25, fontsize=10)  # left

ax_new2.set_xlim(ax.get_xlim())
ax_new1.set_ylim(ax.get_ylim())

plt.tight_layout()
plt.show()

resulting heatmap

相关问题 更多 >