长列表作为集合标签输入的问题(未对齐)

2024-06-01 06:59:50 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在关注NMT(https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb)教程,并将其应用于我自己的用例。不幸的是,当我尝试绘制注意权重时,如果输入太长(例如14而不是7),我会遇到x轴对齐问题

在此代码块中,绘图工作正常:

import numpy as np
from matplotlib import pyplot as plt

def plot_attention():
    attention = np.array([[7.78877574e-10, 4.04739769e-10, 6.65854022e-05, 1.63362725e-04,
    2.85054208e-04, 8.50252633e-04, 4.58042100e-02],
   [9.23501700e-02, 5.69618285e-01, 1.80586591e-01, 9.78111699e-02,
    2.71992851e-02, 9.59911197e-03, 2.54837354e-03]])



    sentence = ['<start>', 'hace', 'mucho', 'frio', 'aqui', '.', '<end>'] 
    predicted_sentence = ['it', 's']



    fig = plt.figure(figsize=(10,10))
    ax = fig.add_subplot(1, 1, 1)
    ax.matshow(attention, cmap='viridis')

    fontdict = {'fontsize': 14}

    ax.set_xticklabels([''] + sentence, fontdict=fontdict, rotation=90)
    ax.set_yticklabels([''] + predicted_sentence, fontdict=fontdict)

    plt.show()

plot_attention()

但由于列表“句子”中的元素较多,似乎不一致:

def plot_attention():
    attention = np.array([[7.78877574e-10, 4.04739769e-10, 6.65854022e-05, 1.63362725e-04,
    2.85054208e-04, 8.50252633e-04, 4.58042100e-02, 7.78877574e-10, 4.04739769e-10, 6.65854022e-05, 1.63362725e-04,
    2.85054208e-04, 8.50252633e-04, 4.58042100e-02],
   [9.23501700e-02, 5.69618285e-01, 1.80586591e-01, 9.78111699e-02,
    2.71992851e-02, 9.59911197e-03, 2.54837354e-03, 7.78877574e-10, 4.04739769e-10, 6.65854022e-05, 1.63362725e-04,
    2.85054208e-04, 8.50252633e-04, 4.58042100e-02]])




    sentence = ['<start>', 'hace', 'mucho', 'frio', 'aqui', '.', '<end>', '<start>', 'hace', 'mucho', 'frio', 'aqui', '.', '<end>'] 
    predicted_sentence = ['it', 's']


    fig = plt.figure(figsize=(20,10))
    ax = fig.add_subplot(1, 1, 1)
    ax.matshow(attention, cmap='viridis')

    fontdict = {'fontsize': 14}

    ax.set_xticklabels([''] + sentence, fontdict=fontdict, rotation=90)
    ax.set_yticklabels([''] + predicted_sentence, fontdict=fontdict)

    plt.show()

plot_attention()

I expect the x-axis to be perfectly aligned and that all elements of the x-axis显示(不是像现在那样每秒钟显示一次)


Tags: plottensorflownpfigpltaxstartsentence
1条回答
网友
1楼 · 发布于 2024-06-01 06:59:50

问题在于,您只设置记号标签,而没有指定记号的位置。无论何时修改勾号标签,都应始终首先设置勾号位置。因此,在代码中执行以下操作

ax.set_xticks(range(len(sentence)))
ax.set_yticks(range(len(predicted_sentence)))
ax.set_xticklabels(sentence, fontdict=fontdict, rotation=90)
ax.set_yticklabels(predicted_sentence, fontdict=fontdict)

enter image description here

相关问题 更多 >