NoneType found when loading weights

Posted 2024-09-29 07:33:07


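I'm building an LSTM for a final project. For most of it I've been following TensorFlow's tutorial at https://www.tensorflow.org/tutorials/sequences/text_generation, in particular for how to save and load the model. However, loading the weights fails with the following error: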

Traceback (most recent call last):
  File "D:\xxx\Documents\Class Coding\Artificial Intelligence\Shelley\Writerbot.py", line 187, in <module>
    restore_progress()
  File "D:\xxx\Documents\Class Coding\Artificial Intelligence\Shelley\Writerbot.py", line 141, in restore_progress
    shelley.load_weights(weights)
  File "C:\Users\xxx\AppData\Roaming\Python\Python36\site-packages\tensorflow\python\keras\engine\network.py", line 1508, in load_weights
    if _is_hdf5_filepath(filepath):
  File "C:\Users\xxx\AppData\Roaming\Python\Python36\site-packages\tensorflow\python\keras\engine\network.py", line 1648, in _is_hdf5_filepath
    return filepath.endswith('.h5') or filepath.endswith('.keras')
AttributeError: 'NoneType' object has no attribute 'endswith'

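Here is the code related to saving and restoring the weights. As far as I can tell it is the only relevant part, since the rest of the traceback comes from Keras: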

import os
import tensorflow as tf

def create_shelley(vocab, embedding, numunits, batch):
    """This is what actually creates a neural network."""
    # `lstm` is defined elsewhere in Writerbot.py and is not shown in this post.
    shelley = tf.keras.Sequential([
        tf.keras.layers.Embedding(vocab, embedding,
                                  batch_input_shape=[batch, None]),
        lstm(numunits,
             return_sequences=True,
             recurrent_initializer='glorot_uniform',
             stateful=True),
        tf.keras.layers.Dense(vocab)
    ])
    return shelley

def train():
    """We create weight checkpoints as we train our neural network on files fed into it."""
    # shelley, botfeed, epochs and epochsteps are globals defined elsewhere (not shown).
    checkpoints = 'D:\\xxx\\Documents\\Class Coding\\Artificial Intelligence\\Shelley\\trainingcheckpoints'
    prefix = os.path.join(checkpoints, "ckpt_{epoch}")
    callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=prefix,
        save_weights_only=True)

    print(epochsteps)
    history = shelley.fit(botfeed.repeat(), epochs=epochs, steps_per_epoch=epochsteps, callbacks=[callback])

def restore_progress():
    """Load the most recent weight checkpoint."""
    trainingcheckpoints = "D:\\Robin Pegau\\Documents\\Class Coding\\Artificial Intelligence\\Shelley\\trainingcheckpoints\\checkpoint"
    weights = tf.train.latest_checkpoint(trainingcheckpoints)
    shelley = create_shelley(vocab, embed, totalunits, batch = 1)
    shelley.load_weights(weights)
    shelley.build(tf.TensorShape([1, None]))

restore_progress()
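In the traceback, `weights` reaches `load_weights` as `None`, and Keras' `_is_hdf5_filepath` then calls `.endswith(...)` on it. `tf.train.latest_checkpoint` returns `None` when it finds no checkpoint in the directory it is given, and it expects the checkpoint directory (here `trainingcheckpoints`), not the `checkpoint` file inside it. The sketch below is one possible guard, not part of TensorFlow: the helpers `checkpoint_dir_from` and `require_weights` are hypothetical names, and the TensorFlow calls appear only as comments so the sketch runs without TensorFlow installed.

```python
import os
from typing import Optional


def checkpoint_dir_from(path: str) -> str:
    """Hypothetical helper: accept either the checkpoint directory or the
    'checkpoint' state file inside it, and return the directory.

    tf.train.latest_checkpoint expects the directory, so a path ending in
    ...\\trainingcheckpoints\\checkpoint is mapped to ...\\trainingcheckpoints.
    """
    normalized = os.path.normpath(path)
    if os.path.basename(normalized) == "checkpoint" and os.path.isfile(normalized):
        return os.path.dirname(normalized)
    return path


def require_weights(prefix: Optional[str], searched_dir: str) -> str:
    """Hypothetical guard: raise a readable error instead of letting
    load_weights(None) fail inside Keras with an AttributeError."""
    if prefix is None:
        raise FileNotFoundError(f"No checkpoint found in {searched_dir!r}")
    return prefix


# Intended use with TensorFlow (kept as comments so this sketch runs without it):
#   ckpt_dir = checkpoint_dir_from(trainingcheckpoints)
#   weights = require_weights(tf.train.latest_checkpoint(ckpt_dir), ckpt_dir)
#   shelley.load_weights(weights)
```

Failing early with a `FileNotFoundError` that names the searched directory makes a wrong path show up directly, rather than as an `AttributeError` deep inside Keras.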

There is a "checkpoint" file with no file extension. There are also files that look like "ckpt_[x].index" and "ckpt_[x].data-00000-of-00001".
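The extensionless `checkpoint` file is a small text file (a `CheckpointState` in protobuf text format) that records the prefix of the most recent checkpoint, and each prefix such as `ckpt_3` corresponds to a `.index` file plus one or more `.data-...` shards. The function below is a simplified, hypothetical stand-in for `tf.train.latest_checkpoint`, not TensorFlow's implementation: it assumes the recorded paths contain no escaped characters and skips details the real function handles. It illustrates why passing the `checkpoint` file itself yields `None`: the lookup appends `checkpoint` to whatever path it receives.

```python
import os
import re
from typing import Optional


def latest_checkpoint_sketch(checkpoint_dir: str) -> Optional[str]:
    """Simplified, illustrative stand-in for tf.train.latest_checkpoint.

    Reads the text file named 'checkpoint' *inside* checkpoint_dir and returns
    the prefix recorded as model_checkpoint_path, or None if the state file or
    the referenced .index file cannot be found.
    """
    state_file = os.path.join(checkpoint_dir, "checkpoint")
    if not os.path.isfile(state_file):
        # e.g. checkpoint_dir is itself the 'checkpoint' file, so this looks
        # for .../checkpoint/checkpoint, which does not exist.
        return None
    with open(state_file, encoding="utf-8") as f:
        match = re.search(r'^model_checkpoint_path:\s*"(.*)"', f.read(), re.MULTILINE)
    if match is None:
        return None
    prefix = match.group(1)
    if not os.path.isabs(prefix):
        # Relative prefixes are resolved against the checkpoint directory.
        prefix = os.path.join(checkpoint_dir, prefix)
    # A prefix is only usable if its .index file exists next to the data shards.
    return prefix if os.path.isfile(prefix + ".index") else None
```

In this sketch, calling `latest_checkpoint_sketch` on the `trainingcheckpoints` directory would return the full path of the latest prefix, while calling it on `trainingcheckpoints\checkpoint` returns `None`, the same value that reached `load_weights` in the traceback.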

Thanks in advance for your help.

