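I'm building an LSTM for my final project. For most of it I've been following TensorFlow's tutorial here: https://www.tensorflow.org/tutorials/sequences/text_generation, especially the parts on saving and loading the model. However, restoring the model fails with the following error: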
Traceback (most recent call last):
  File "D:\xxx\Documents\Class Coding\Artificial Intelligence\Shelley\Writerbot.py", line 187, in <module>
    restore_progress()
  File "D:\xxx\Documents\Class Coding\Artificial Intelligence\Shelley\Writerbot.py", line 141, in restore_progress
    shelley.load_weights(weights)
  File "C:\Users\xxx\AppData\Roaming\Python\Python36\site-packages\tensorflow\python\keras\engine\network.py", line 1508, in load_weights
    if _is_hdf5_filepath(filepath):
  File "C:\Users\xxx\AppData\Roaming\Python\Python36\site-packages\tensorflow\python\keras\engine\network.py", line 1648, in _is_hdf5_filepath
    return filepath.endswith('.h5') or filepath.endswith('.keras')
AttributeError: 'NoneType' object has no attribute 'endswith'
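Below is my code for saving and restoring the weights. As far as I can tell, this is the only relevant part, since the rest of the traceback comes from inside Keras: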
import os

import tensorflow as tf

# lstm, vocab, embed, totalunits, epochs, epochsteps, botfeed and the
# training-time shelley are defined elsewhere in Writerbot.py (not shown).


def create_shelley(vocab, embedding, numunits, batch):
    """This is what actually creates a neural network."""
    shelley = tf.keras.Sequential([
        tf.keras.layers.Embedding(vocab, embedding,
                                  batch_input_shape=[batch, None]),
        lstm(numunits,
             return_sequences=True,
             recurrent_initializer='glorot_uniform',
             stateful=True),
        tf.keras.layers.Dense(vocab)
    ])
    return shelley


def train():
    """We create weight checkpoints as we train our neural network on files fed into it."""
    checkpoints = 'D:\\xxx\\Documents\\Class Coding\\Artificial Intelligence\\Shelley\\trainingcheckpoints'
    # Weights for each epoch are saved under the prefix ckpt_1, ckpt_2, ... in this folder
    prefix = os.path.join(checkpoints, "ckpt_{epoch}")
    callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=prefix,
        save_weights_only=True)
    print(epochsteps)
    history = shelley.fit(botfeed.repeat(), epochs=epochs, steps_per_epoch=epochsteps, callbacks=[callback])


def restore_progress():
    """Load the most recent weight checkpoint."""
    trainingcheckpoints = "D:\\Robin Pegau\\Documents\\Class Coding\\Artificial Intelligence\\Shelley\\trainingcheckpoints\\checkpoint"
    weights = tf.train.latest_checkpoint(trainingcheckpoints)
    # Rebuild the model with batch size 1, then load the saved weights into it
    shelley = create_shelley(vocab, embed, totalunits, batch=1)
    shelley.load_weights(weights)
    shelley.build(tf.TensorShape([1, None]))


restore_progress()
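A likely cause, offered as a hedged reading of the traceback: `tf.train.latest_checkpoint` expects the directory that contains the `checkpoint` state file, not the path of that file. Given `...\trainingcheckpoints\checkpoint`, it looks for `...\trainingcheckpoints\checkpoint\checkpoint`, finds nothing and returns `None`, and `load_weights(None)` then fails in `_is_hdf5_filepath` exactly as shown above. Separately, `restore_progress()` uses `D:\Robin Pegau\...` while `train()` uses `D:\xxx\...`; if that is more than redaction, the two functions point at different folders. The sketch below is plain Python with no TensorFlow, and both helper names (`checkpoint_dir_from`, `require_weights`) are hypothetical; it only illustrates normalizing the path and failing early on `None`.

```python
import os


def checkpoint_dir_from(path):
    """Hypothetical helper: return the folder tf.train.latest_checkpoint expects.

    If `path` ends in the 'checkpoint' state file itself, drop that last
    component and return its parent directory; otherwise return `path`.
    """
    path = os.path.normpath(path)
    if os.path.basename(path) == "checkpoint":
        return os.path.dirname(path)
    return path


def require_weights(weights_path):
    """Hypothetical guard: raise a clear error instead of the NoneType AttributeError."""
    if weights_path is None:
        raise FileNotFoundError(
            "No checkpoint found: pass the checkpoint directory, "
            "not the 'checkpoint' file, to tf.train.latest_checkpoint")
    return weights_path
```

Inside `restore_progress()` this would be used as `weights = require_weights(tf.train.latest_checkpoint(checkpoint_dir_from(trainingcheckpoints)))`. Simply deleting the trailing `\\checkpoint` from the string should have the same effect on the lookup.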
The folder contains a file named "checkpoint" with no file extension, plus files that look like "ckpt_[x].index" and "ckpt_[x].data-00000-of-00001".
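That layout matches TensorFlow's checkpoint format. The extension-less `checkpoint` file is a small text record whose `model_checkpoint_path: "ckpt_x"` line names the newest checkpoint, and each `ckpt_x.index` / `ckpt_x.data-00000-of-00001` pair holds the weights for one epoch. The prefix `ckpt_x`, without either suffix, is what `load_weights` takes. The sketch below roughly approximates the lookup `tf.train.latest_checkpoint` performs, using only the standard library; `read_latest_prefix` is a hypothetical name, and the real function does more (for example, it unescapes absolute Windows paths, which this regex does not).

```python
import os
import re


def read_latest_prefix(checkpoint_dir):
    """Hypothetical approximation of tf.train.latest_checkpoint.

    Reads the 'checkpoint' state file inside `checkpoint_dir`, extracts the
    model_checkpoint_path entry and returns the full prefix (e.g. .../ckpt_5)
    only if its .index file exists; otherwise returns None.
    """
    state_file = os.path.join(checkpoint_dir, "checkpoint")
    if not os.path.isfile(state_file):
        return None  # e.g. checkpoint_dir was the state file, not its folder
    with open(state_file, encoding="utf-8") as f:
        match = re.search(r'^model_checkpoint_path:\s*"(.*)"', f.read(), re.MULTILINE)
    if match is None:
        return None
    prefix = match.group(1)  # escaped absolute Windows paths are not handled here
    if not os.path.isabs(prefix):
        prefix = os.path.join(checkpoint_dir, prefix)
    # The prefix itself is not a file; its .index companion is.
    return prefix if os.path.isfile(prefix + ".index") else None
```

Called on the `trainingcheckpoints` folder, it should return the prefix named in the state file; called on the `checkpoint` file itself, it returns `None`, which is the same value that reaches `load_weights` in the traceback.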
Thanks in advance for any help.