具有可训练初始状态的TensorFlow Keras RNN
trainable-initial-state-rnn的Python项目详细描述
处理TensorFlow-Keras递归神经网络的初始状态 如[1]中所建议的,在培训期间作为一个或多个参数学习的层。在
默认情况下,普通RNN使用全零初始状态。为什么不让神经 网络学习一个更智能的初始状态?在
trainable-initial-state-rnn包提供了一个类 TrainableInitialStateRNN,它可以包装任何tf.kerasRNN(或 双向RNN)和管理新的初始状态变量 重量RNS。在
典型用法如下。在
importtensorflowastffromtrainable_initial_state_rnnimportTrainableInitialStateRNNbase_rnn=tf.keras.layers.LSTM(256)rnn=TrainableInitialStateRNN(base_rnn)# Treats initial state as a variable!model=tf.keras.Model(...)# Use rnn like any other tf.keras layer in your modelmodel.compile(...)history=model.fit(...)
文件可在 Read the Docs。在
安装
trainable_initial_state_rnn包可以使用 pip实用程序直接从包的 GitHub page:
^{pr2}$或者,从 Python Package Index (PyPI):
pip install trainable-initial-state-rnn
Note.安装用于开发的项目(例如,对
源代码),从GitHub克隆项目存储库并运行make dev
:
git clone https://github.com/artemmavrin/trainable-initial-state-rnn.git cd trainable-initial-state-rnn # Optional but recommended: create and activate a new Python virtual environment make dev
这将额外安装所需的要求 运行测试、检查代码覆盖率和生成文档。在
参考文献
[1] | Felix A. Gers, Nicol N. Schraudolph, Jürgen Schmidhuber. Learning Precise Timing with LSTM Recurrent Networks. Journal of Machine Learning Research 3 (2002) 115-143. (Link) |
- 项目
标签: