<p>This works like a charm (<a href="http://zachmoshe.com/2017/04/03/pickling-keras-models.html" rel="noreferrer">http://zachmoshe.com/2017/04/03/pickling-keras-models.html</a>):</p>
<pre><code>import tempfile
import keras.models


def make_keras_picklable():
    def __getstate__(self):
        # Save the model to a temporary HDF5 file and keep its raw bytes
        with tempfile.NamedTemporaryFile(suffix='.hdf5', delete=True) as fd:
            keras.models.save_model(self, fd.name, overwrite=True)
            model_str = fd.read()
        return {'model_str': model_str}

    def __setstate__(self, state):
        # Write the bytes back to a temporary file and load the model from it
        with tempfile.NamedTemporaryFile(suffix='.hdf5', delete=True) as fd:
            fd.write(state['model_str'])
            fd.flush()
            model = keras.models.load_model(fd.name)
        # Adopt the loaded model's attributes
        self.__dict__ = model.__dict__

    # Patch both methods onto keras.models.Model so every model becomes picklable
    cls = keras.models.Model
    cls.__getstate__ = __getstate__
    cls.__setstate__ = __setstate__


make_keras_picklable()
</code></pre>
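<p>The same pattern can be tried without Keras installed. Below is a minimal sketch, assuming a hypothetical <code>FileBackedModel</code> class that, like a Keras model, can only be saved to and loaded from a file path; its <code>save</code>/<code>load</code> methods and the JSON file format are illustrative stand-ins, not Keras APIs. As in the code above, <code>__getstate__</code> pickles only the saved file's bytes and <code>__setstate__</code> writes them back and reloads.</p>

```python
import json
import os
import pickle
import tempfile


class FileBackedModel:
    """Hypothetical stand-in for a Keras model: it can only be persisted
    through a file path, similar to keras.models.save_model/load_model."""

    def __init__(self, weights):
        self.weights = weights

    def save(self, path):
        with open(path, "w") as f:
            json.dump({"weights": self.weights}, f)

    @classmethod
    def load(cls, path):
        with open(path) as f:
            return cls(json.load(f)["weights"])

    def __getstate__(self):
        # Save to a temporary file, then pickle only the file's raw bytes
        fd, path = tempfile.mkstemp(suffix=".json")
        os.close(fd)
        try:
            self.save(path)
            with open(path, "rb") as f:
                return {"model_bytes": f.read()}
        finally:
            os.remove(path)

    def __setstate__(self, state):
        # Write the bytes back to a temporary file and reload from it
        fd, path = tempfile.mkstemp(suffix=".json")
        try:
            with os.fdopen(fd, "wb") as f:
                f.write(state["model_bytes"])
            restored = type(self).load(path)
        finally:
            os.remove(path)
        # Adopt the reloaded object's attributes
        self.__dict__.update(restored.__dict__)


model = FileBackedModel([0.5, -1.25])
clone = pickle.loads(pickle.dumps(model))
print(clone.weights)  # [0.5, -1.25]
```

<p>This sketch uses <code>tempfile.mkstemp</code> rather than <code>NamedTemporaryFile</code> because the file is reopened by name, which an open <code>NamedTemporaryFile</code> may not allow on Windows.</p>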
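<p>Also, I ran into problems with <code>model.to_json()</code> failing because of circular references. That error was somehow swallowed by the code above, which made the pickle call run forever.</p>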
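<p>One way to make such a failure visible is to run the serialization step on its own before pickling, so the exception surfaces instead of being hidden. Below is a minimal stdlib-only sketch, assuming the circular-reference error comes from Python's <code>json</code> module (which raises <code>ValueError</code> on self-referencing structures); the <code>to_json_or_raise</code> helper and the self-referencing <code>config</code> dict are hypothetical and stand in for a real model config.</p>

```python
import json


def to_json_or_raise(config):
    """Hypothetical helper: serialize a config, re-raising with context."""
    try:
        return json.dumps(config)
    except ValueError as exc:
        # json.dumps raises ValueError("Circular reference detected")
        # when the structure refers back to itself.
        raise RuntimeError(f"config could not be serialized: {exc}") from exc


config = {"name": "dense"}
config["self"] = config  # deliberately create a cycle

try:
    to_json_or_raise(config)
except RuntimeError as err:
    print(err)  # config could not be serialized: Circular reference detected
```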