在python脚本中,我创建了一个类,其中包含keras
模型,如下所示:
from keras.layers import Input, Activation, Dense
from keras.models import Model
class Klass:
def __init__(self, input_dims, output_dims, hidden_dims, optimizer, a, b):
self.input_dims = input_dims
self.output_dims = output_dims
self.hidden_dims = hidden_dims
self.optimizer = optimizer
self.a = a
self.b = b
self.__build_nn()
def __build_nn(self):
inputs = Input(shape=(self.input_dims,))
net = inputs
for h_dim in self.hidden_dims:
net = Dense(h_dim, kernel_initializer='he_uniform')(net)
net = Activation("relu")(net)
outputs = Dense(self.output_dims)(net)
outputs = Activation("linear")(outputs)
self.nn1 = Model(inputs=inputs, outputs=outputs)
self.nn2 = Model(inputs=inputs, outputs=outputs)
self.nn1.compile(optimizer=self.optimizer, loss='mean_squared_error')
self.nn2.compile(optimizer=self.optimizer, loss='mean_squared_error')
在创建了一个Klass
实例之后,我想对其进行深度复制:
但是,这会抛出一个TypeError: can't pickle _thread.RLock objects
。我非常确定这个错误与class对象中的keras
模型有关,因为我能够在没有keras
模型的情况下获得类似类的深层副本。在
不幸的是,我在互联网上找不到解决这个问题的方法,因为大多数关于深度复制keras
模型的问题都是试图克隆keras
模型,就像here。在
那么,如何获得包含keras
模型的类的深层副本呢?在
编辑
这三个问题(1,2,3)在不同的情况下提到了类似的错误。然而,那里提供的解决方案并不适用于我的情况。在
编辑2
正如注释中所建议的,我在类中添加了一个copy
方法。这是一个可行的解决方案吗?在
class Klass:
def __init__(self, input_dims, output_dims, hidden_dims, optimizer, a, b):
self.input_dims = input_dims
self.output_dims = output_dims
self.hidden_dims = hidden_dims
self.optimizer = optimizer
self.a = a
self.b = b
self.__build_nn()
# [...]
def copy(self):
new = Klass(self.input_dims, self.output_dims, self.hidden_dims,
self.optimizer, self.a, self.b)
new.nn1.set_weights(self.nn1.get_weights())
new.nn2.set_weights(self.nn2.get_weights())
return new
在注释中解决:为
Klass
添加了一个copy
方法,它将旧的Klass
实例的权重复制到新创建的实例。在相关问题 更多 >
编程相关推荐