水平堆叠层的包装层

keras-multi-head的Python项目详细描述


路缘石多头

TravisCoverageVersionDownloadsLicense

用于水平堆叠层的包装层。

安装

pip install keras-multi-head

用法

重复层

如果仅提供一个层,则将复制该层。参数layer_num控制最终将复制多少层。

importkerasfromkeras_multi_headimportMultiHeadmodel=keras.models.Sequential()model.add(keras.layers.Embedding(input_dim=100,output_dim=20,name='Embedding'))model.add(MultiHead(keras.layers.LSTM(units=32),layer_num=5,name='Multi-LSTMs'))model.add(keras.layers.Flatten(name='Flatten'))model.add(keras.layers.Dense(units=4,activation='softmax',name='Dense'))model.build()model.summary()

使用多层

第一个参数也可以是具有不同配置的层的列表,但是,它们必须具有相同的输出形状。

importkerasfromkeras_multi_headimportMultiHeadmodel=keras.models.Sequential()model.add(keras.layers.Embedding(input_dim=100,output_dim=20,name='Embedding'))model.add(MultiHead([keras.layers.Conv1D(filters=32,kernel_size=3,padding='same'),keras.layers.Conv1D(filters=32,kernel_size=5,padding='same'),keras.layers.Conv1D(filters=32,kernel_size=7,padding='same'),],name='Multi-CNNs'))model.build()model.summary()

线性变换

当给定hidden_dim时,输入数据将映射到每个层的相同形状的不同值。

正则化

当您希望从平行层中提取不同的特征时,将使用正则化。可以自定义层中权重的索引,间隔表示权重的部分和正则化因子。

例如,双向lstm层默认有6个权重,前3个属于前向层。前向层中的第二个权重(递归核)控制递归连接的门的计算。计算单元状态的核心是递归核的x 2到x 3个单元。我们可以对内核使用正则化:

importkerasfromkeras_multi_headimportMultiHeadmodel=keras.models.Sequential()model.add(keras.layers.Embedding(input_dim=5,output_dim=3,name='Embed'))model.add(MultiHead(layer=keras.layers.Bidirectional(keras.layers.LSTM(units=16),name='LSTM'),layer_num=5,reg_index=[1,4],reg_slice=(slice(None,None),slice(32,48)),reg_factor=0.1,name='Multi-Head-Attention',))model.add(keras.layers.Flatten(name='Flatten'))model.add(keras.layers.Dense(units=2,activation='softmax',name='Dense'))model.build()
  • reg_index:指数layer.get_weights(),单个整数或整数列表。
  • reg_sliceslices或slices的元组或以前选择的列表。如果在reg_index中提供了多个索引,并且reg_slice不是列表,则假定reg_slice等于所有索引。如果将此参数保留为None,则将使用整个数组。
  • reg_factor:正则化因子,浮点数或浮点数列表。

多头注意力

提供了一个更具体的多头层(因为普通层更难使用)。该层使用缩放的点积注意层作为其子层,只需要head_num

importkerasfromkeras_multi_headimportMultiHeadAttentioninput_layer=keras.layers.Input(shape=(2,3),name='Input',)att_layer=MultiHeadAttention(head_num=3,name='Multi-Head',)(input_layer)model=keras.models.Model(inputs=input_layer,outputs=att_layer)model.compile(optimizer='adam',loss='mse',metrics={},)model.summary()

当输入只有一层时,输入张量和输出张量的形状是相同的。当给定列表时,输入层将被视为查询、键和值:

importkerasfromkeras_multi_headimportMultiHeadAttentioninput_query=keras.layers.Input(shape=(2,3),name='Input-Q',)input_key=keras.layers.Input(shape=(4,5),name='Input-K',)input_value=keras.layers.Input(shape=(4,6),name='Input-V',)att_layer=MultiHeadAttention(head_num=3,name='Multi-Head',)([input_query,input_key,input_value])model=keras.models.Model(inputs=[input_query,input_key,input_value],outputs=att_layer)model.compile(optimizer='adam',loss='mse',metrics={},)model.summary()

欢迎加入QQ群-->: 979659372 Python中文网_新手群

推荐PyPI第三方库


热门话题
arduino JAVA串行数据接收非法字符   logging Create log method在java中记录当前执行方法的参数及其值   使用maven使用参数构建java项目   java在尝试失败后关闭   从excel工作表内容生成java xml   java Apache HttpClient:setConnectTimeout()vs.setConnectionTimeOlive()vs.setSocketTimeout()   Java的性能。forEach(列表::添加)vs。collect(Collectors.toList())   java生成给定数字的金字塔?   netbeans 8.2,windows 10上的java调试器错误   java glassfish部署错误   用java C++加密CryptoAPI SIMPLEBLOB   java Tomee jpa设置   需要使用Java/Selenium或任何语言从Googe TAG manager提取数据层信息以实现自动化   azure如何在Java中为Iterable的for循环内创建计数器并获取计数器变量的值   java JDialog未显示最小化/关闭按钮   java断言true,来自两个方法的变量