如何使用KerAsselAttention包可视化attention LSTM?

2024-09-27 00:22:16 发布

您现在位置:Python中文网/ 问答频道 /正文

我使用(keras-self-attention)在KERAS中实现attention LSTM。在训练了模型之后,我如何将注意力部分形象化?这是一个时间序列预测案例。在

from keras.models import Sequential
from keras_self_attention import SeqWeightedAttention
from keras.layers import LSTM, Dense, Flatten

model = Sequential()
model.add(LSTM(activation = 'tanh' ,units = 200, return_sequences = True, 
               input_shape = (TrainD[0].shape[1], TrainD[0].shape[2])))
model.add(SeqSelfAttention())
model.add(Flatten())    
model.add(Dense(1, activation = 'relu'))

model.compile(optimizer = 'adam', loss = 'mse')

Tags: fromimportselfaddmodelactivationkerasdense
1条回答
网友
1楼 · 发布于 2024-09-27 00:22:16

一种方法是获取给定输入的SeqSelfAttention的输出,并对它们进行组织,以便显示每个通道的预测结果(见下文)。要了解更高级的内容,请查看iNNvestigate library(包括使用示例)。在


说明show_features_1D获取layer_name(可以是一个子串)层输出并显示每个通道的预测(标记),时间步长沿x轴,输出值沿y轴。
  • input_data=单批形状(1, input_shape)
  • prefetched_outputs=已经获得层输出;覆盖input_data
  • max_timesteps=要显示的最大时间步数
  • max_col_subplots=沿水平方向的子批次的最大值
  • equate_axes=强制所有x轴和y轴相等(建议用于公平比较)
  • show_y_zero=是否将y=0显示为红线
  • channel_axis=层特征维度(例如units对于LSTM,这是最后一个)
  • scale_width, scale_height=缩放显示的图像宽度和高度
  • dpi=图像质量(每英寸点数)

视觉效果(如下)说明

  • 第一种方法有助于查看提取特征的形状,而不管其大小,例如提供有关频率内容的信息
  • 第二种方法有助于查看特征关系,例如相对大小、偏差和频率。下面的结果与上面的图像形成了鲜明的对比,因为运行print(outs_1)会发现所有的量值都很小,变化不大,所以包括y=0点和等分轴会产生一条直线状的视觉效果,这可以解释为自我注意是偏向性的。在
  • 第三种方法对于将太多的特性可视化是很有用的;用batch_shape而不是input_shape定义模型会删除打印形状中的所有?,我们可以看到第一个输出的形状是(10, 60, 240),第二个输出的形状是(10, 240, 240)。换句话说,第一个输出返回LSTM channel attention,第二个输出返回“timesteps attention”。下面的热图结果可以解释为显示注意“冷却”w.r.t.时间步数。在

SeqWeightedAttention更容易可视化,但没有太多可可视化的;您需要去掉上面的Flatten才能使其工作。然后注意力的输出形状变成(10, 60)(10, 240)-对于这两个形状,可以使用一个简单的直方图plt.hist(只需确保排除批处理维度,即feed (60,)或{})。在


from keras.layers import Input, Dense, LSTM, Flatten, concatenate
from keras.models import Model
from keras.optimizers import Adam
from keras_self_attention = SeqSelfAttention
import numpy as np 

ipt   = Input(shape=(240,4))
x     = LSTM(60, activation='tanh', return_sequences=True)(ipt)
x     = SeqSelfAttention(return_attention=True)(x)
x     = concatenate(x)
x     = Flatten()(x)
out   = Dense(1, activation='sigmoid')(x)
model = Model(ipt,out)
model.compile(Adam(lr=1e-2), loss='binary_crossentropy')

X = np.random.rand(10,240,4) # dummy data
Y = np.random.randint(0,2,(10,1)) # dummy labels
model.train_on_batch(X, Y)

outs = get_layer_outputs(model, 'seq', X[0:1], 1)
outs_1 = outs[0]
outs_2 = outs[1]

show_features_1D(model,'lstm',X[0:1],max_timesteps=100,equate_axes=False,show_y_zero=False)
show_features_1D(model,'lstm',X[0:1],max_timesteps=100,equate_axes=True, show_y_zero=True)
show_features_2D(outs_2[0])  # [0] for 2D since 'outs_2' is 3D


^{pr2}$
def show_features_2D(data, cmap='bwr', norm=None,
                     scale_width=1, scale_height=1):
    if norm is not None:
        vmin, vmax = norm
    else:
        vmin, vmax = None, None  # scale automatically per min-max of 'data'

    plt.imshow(data, cmap=cmap, vmin=vmin, vmax=vmax)
    plt.xlabel('Timesteps', weight='bold', fontsize=14)
    plt.ylabel('Attention features', weight='bold', fontsize=14)
    plt.colorbar(fraction=0.046, pad=0.04)  # works for any size plot

    plt.gcf().set_size_inches(8*scale_width, 8*scale_height)
    plt.show()

def get_layer_outputs(model, layer_name, input_data, learning_phase=1):
    outputs   = [layer.output for layer in model.layers if layer_name in layer.name]
    layers_fn = K.function([model.input, K.learning_phase()], outputs)
    return layers_fn([input_data, learning_phase])

SeqWeightedAttention示例每个请求:

ipt   = Input(batch_shape=(10,240,4))
x     = LSTM(60, activation='tanh', return_sequences=True)(ipt)
x     = SeqWeightedAttention(return_attention=True)(x)
x     = concatenate(x)
out   = Dense(1, activation='sigmoid')(x)
model = Model(ipt,out)
model.compile(Adam(lr=1e-2), loss='binary_crossentropy')

X = np.random.rand(10,240,4) # dummy data
Y = np.random.randint(0,2,(10,1)) # dummy labels
model.train_on_batch(X, Y)

outs = get_layer_outputs(model, 'seq', X, 1)
outs_1 = outs[0][0] # additional index since using batch_shape
outs_2 = outs[1][0]

plt.hist(outs_1, bins=500); plt.show()
plt.hist(outs_2, bins=500); plt.show()

相关问题 更多 >

    热门问题