处理顺序数据的注意机制,考虑每个时间戳的上下文
keras-self-attention的Python项目详细描述
凯拉斯自我关注
处理顺序数据的注意机制,考虑每个时间戳的上下文。
安装
pip install keras-self-attention
用法
基本
默认情况下,注意层使用附加注意,并在计算相关性时考虑整个上下文。下面的代码创建了一个关注层,它遵循第一节中的公式(attention_activation
是e_{t, t'}
的激活函数):
importkerasfromkeras_self_attentionimportSeqSelfAttentionmodel=keras.models.Sequential()model.add(keras.layers.Embedding(input_dim=10000,output_dim=300,mask_zero=True))model.add(keras.layers.Bidirectional(keras.layers.LSTM(units=128,return_sequences=True)))model.add(SeqSelfAttention(attention_activation='sigmoid'))model.add(keras.layers.Dense(units=5))model.compile(optimizer='adam',loss='categorical_crossentropy',metrics=['categorical_accuracy'],)model.summary()
局部注意力
对于一个数据来说,全球环境可能过于宽泛。参数attention_width
控制本地上下文的宽度:
fromkeras_self_attentionimportSeqSelfAttentionSeqSelfAttention(attention_width=15,attention_activation='sigmoid',name='Attention',)
乘法注意
你可以通过设置attention_type
:
fromkeras_self_attentionimportSeqSelfAttentionSeqSelfAttention(attention_width=15,attention_type=SeqSelfAttention.ATTENTION_TYPE_MUL,attention_activation=None,kernel_regularizer=keras.regularizers.l2(1e-6),use_attention_bias=False,name='Attention',)
正则化
要使用正则化器,请将attention_regularizer_weight
设置为正数:
importkerasfromkeras_self_attentionimportSeqSelfAttentioninputs=keras.layers.Input(shape=(None,))embd=keras.layers.Embedding(input_dim=32,output_dim=16,mask_zero=True)(inputs)lstm=keras.layers.Bidirectional(keras.layers.LSTM(units=16,return_sequences=True))(embd)att=SeqSelfAttention(attention_type=SeqSelfAttention.ATTENTION_TYPE_MUL,kernel_regularizer=keras.regularizers.l2(1e-4),bias_regularizer=keras.regularizers.l1(1e-4),attention_regularizer_weight=1e-4,name='Attention')(lstm)dense=keras.layers.Dense(units=5,name='Dense')(att)model=keras.models.Model(inputs=inputs,outputs=[dense])model.compile(optimizer='adam',loss={'Dense':'sparse_categorical_crossentropy'},metrics={'Dense':'categorical_accuracy'},)model.summary(line_length=100)
加载模型
确保将SeqSelfAttention
添加到自定义对象:
importkeraskeras.models.load_model(model_path,custom_objects=SeqSelfAttention.get_custom_objects())
仅历史记录
当只能使用历史数据时,将history_only
设置为True
:
SeqSelfAttention(attention_width=3,history_only=True,name='Attention',)
多头
请参考keras-multi-head。