列车估计员()和.predict()对于小数据集来说太慢

2024-05-20 13:16:30 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在尝试实现一个DQN,它在同一个模型上多次调用Estimator.train(),然后调用Estimator.predict(),每个模型都有少量的示例。但是每次调用至少需要几百毫秒到一秒钟以上,这与1-20这样的小数字的示例数无关。你知道吗

我认为这些延迟是由重建图形和保存每次调用的检查点造成的。有没有办法在内存中保持相同的图形和参数,以便快速训练预测迭代或以其他方式加速迭代?你知道吗


Tags: 内存模型图形示例参数dqn方式train
1条回答
网友
1楼 · 发布于 2024-05-20 13:16:30

转换为tf.keras.Model而不是Estimator,使用tf.keras.Model.fit()而不是Estimator.train()fit()没有train()所具有的固定延迟。Keras predict()也没有。你知道吗

相关问题 更多 >