我正试着在几批上逐步褪去一层keras。因此,我编写了一个自定义层“DecayingSkip”。另外,我把它剩余地加到另一层。我试图实现一个淡出跳过连接。 但是,代码似乎无法正常工作。该模型编译和训练,但层激活并没有像预期的那样淡出。我做错什么了?你知道吗
class DecayingSkip(Layer):
def __init__(self, fade_out_at_batch, **kwargs):
self.fade_out_at_batch = K.variable(fade_out_at_batch)
self.btch_cnt = K.variable(0)
super(decayingSkip, self).__init__(**kwargs)
def call(self, x):
self.btch_cnt = self.btch_cnt + 1.0
return K.switch(
self.btch_cnt >= self.fade_out_at_batch,
x * 0,
x * (1.0 - ((1.0 / self.fade_out_at_batch) * self.btch_cnt))
)
def add_fade_out(fadeOutLayer, layer, fade_out_at_batch):
cnn_match = Conv2D(filters=int(layer.shape[-1]), kernel_size=1, activation=bounded_relu)(fadeOutLayer)
fadeOutLayer = DecayingSkip(fade_out_at_batch=fade_out_at_batch, name=name + '_fade_out')(cnn_match)
return Add()([fadeOutLayer, layer])
此外,在另一次尝试中,我尝试使用一个tensorflow变量,我在会话中更改了该变量,如:
def add_fade_out(fadeOutLayer, layer):
fadeOutLayer = Conv2D(filters=int(layer.shape[-1]), kernel_size=1, activation='relu')(fadeOutLayer)
alph = K.variable(1.0, name='alpha')
fadeOutLayer = Lambda(lambda x: x * alph)(fadeOutLayer)
return Add()([fadeOutLayer, layer])
sess = K.get_session()
lw = sess.graph.get_tensor_by_name("alpha:0")
sess.run(K.tf.assign(lw, new_value))
这也不管用。为什么?你知道吗
我想我找到了解决办法。我将层的调用函数更改为:
相关问题 更多 >
编程相关推荐