我试图实现一个定制的多输入多输出模型,该模型使用了this论文中提出的学习算法。模型本身在没有我用作基线的自定义学习算法的情况下运行良好。我遇到的问题是,代码卡在DebiaModel类的train_step函数的代码行中:
mc_pred = self.main_classifier([xu, xs], training=True)
它没有返回错误。运行一个小时后,我中断了内核,它返回错误消息说:
InvalidArgumentError: Operation 'while' has no attr named '_XlaCompile'.
During handling of the above exception, another exception occurred:
InvalidArgumentError: Operation 'gradients/while_grad/Placeholder_28' has no attr named '_read_only_resource_inputs'.
我不确定问题是什么,我也尝试在tf.GradientTape中使用persistent=True,而不是在一个watch中声明两个GradientTape。但是,发生了完全相同的错误
有人知道这个问题是什么吗?如何解决
我使用的是Tensorflow V2.3.0和Keras V2.4.0
源代码
class model_components:
def mitigation_expert():
inputs = Input(shape=(300,), dtype=tf.int32, name="me_input")
x = Embedding(num_tokens, 300, weights=[embedding_matrix], input_length=max_length, trainable=False, name="me_embedding")(inputs)
x = LSTM(300, return_sequences=False, name="me_lstm")(x)
model = Model(inputs, x)
return model
def control_expert():
inputs = Input(shape=(22,), dtype=tf.int32, name="ce_input")
y = Dense(19, activation='relu', name="ce_hidden")(inputs)
model = Model(inputs, y)
return model
def main_classifier():
# Expert components
me = model_components.mitigation_expert()
ce = model_components.control_expert()
# Main classifier
ensemble = concatenate([me.output, ce.output], name="pred_ensemble")
pred_output = Dense(319, activation="relu", name="pred_hidden")(ensemble)
pred_output = Dense(3, activation="softmax", name="pred_output")(pred_output)
model = Model(inputs=[me.input, ce.input], outputs=pred_output, name="main_classifier")
return model
def adversary_classifier():
# Mitigation Expert component
me = model_components.mitigation_expert()
# Adversary classifier
adv_output = Dense(300, activation='relu', name="adv_hidden")(me.output)
adv_output = Dense(1, activation='sigmoid', name="adv_output")(adv_output)
model = Model(inputs=me.input, outputs=adv_output, name="adversary_classifier")
return model
def tf_normalize(x):
return x / (tf.norm(x) + np.finfo(np.float32).tiny)
class DebiasModel(keras.Model):
def __init__(self, main_classifier, adversary_classifier):
super(DebiasModel, self).__init__()
self.main_classifier = main_classifier
self.adversary_classifier = adversary_classifier
def compile(self, mc_optimizer, adv_optimizer, mc_loss, adv_loss, debias_param):
super(DebiasModel, self).compile()
self.mc_optimizer = mc_optimizer
self.adv_optimizer = adv_optimizer
self.mc_loss = mc_loss
self.adv_loss = adv_loss
self.debias_param = debias_param
def train_step(self, data):
# Unpack data from model.fit()
x, y, sample_weight = data
# Unpack input and output features
xu, xs = x
y_mc = y['pred_output']
z_adv = y['adv_output']
# Unpack sample_weights
mainClass_weights = sample_weight["pred_output"]
protectClass_weights = sample_weight["adv_output"]
# Generate prediction and compute loss for Main_Classifier
with tf.GradientTape() as mc_tape, tf.GradientTape() as me_mc_tape:
mc_pred = self.main_classifier([xu, xs], training=True)
mc_loss = self.mc_loss(y_mc, mc_pred, sample_weight=mainClass_weights)
# Compute and Apply Gradients for CE & Main Classifier
mc_trainable_vars = self.main_classifier.trainable_weights[3:]
mc_grads = mc_tape.gradient(mc_loss, mc_trainable_vars)
self.mc_optimizer.apply_gradients(zip(mc_grads, mc_trainable_vars))
# Generate prediction and compute loss for Adversary_Classifier
with tf.GradientTape() as adv_tape, tf.GradientTape() as me_adv_tape:
adv_pred = self.adversary_classifier(xu)
adv_loss = self.adv_loss(z_adv, adv_pred, sample_weight=protectClass_weights)
# Compute and Apply Gradients for CE & Main Classifier
adv_trainable_vars = self.adversary_classifier.trainable_weights[3:]
adv_grads = adv_tape.gradient(adv_loss, adv_trainable_vars)
self.adv_optimizer.apply_gradients(zip(adv_grads, adv_trainable_vars))
# Compute and Apply Gradients to debias ME
me_adv_debias_trainable_vars = self.adversary_classifier.trainable_weights[:3]
adv_debias_grads = me_adv_tape.gradient(adv_loss, me_adv_debias_trainable_vars)
adv_debias_dict = tf.lookup.StaticHashTable(
tf.lookup.KeyValueTensorInitializer(me_adv_debias_trainable_vars, adv_debias_grads), 0)
me_mc_debias_trainable_vars = self.main_classifier.trainable_weights[:3]
mc_debias_grads = me_mc_tape.gradient(mc_loss, me_mc_debias_trainable_vars)
me_grads = []
for g, v in zip(mc_debias_grads, me_mc_debias_trainable_vars):
unit_adv = tf_normalize(adv_debias_dict.lookup(v))
g -= tf.math.reduce_sum(g * unit_adv) * unit_adv
g -= self.debias_param * adv_debias_dict.lookup(v)
me_grads.append(zip(g, v))
self.mc_optimizer.apply_gradients(me_grads)
return {"pred_loss": mc_loss, "adv_loss": adv_loss}
model = DebiasModel(model_components.main_classifier(),
model_components.adversary_classifier())
model.compile(mc_optimizer=tf.keras.optimizers.Adam(),
adv_optimizer=tf.keras.optimizers.Adam(),
mc_loss=tf.keras.losses.CategoricalCrossentropy(),
adv_loss=tf.keras.losses.BinaryCrossentropy(),
debias_param=1)
epoch = 5
sample_weights = {
"pred_output": mainClass_weight,
"adv_output": protectClass_weight,}
model.fit(x=[xu_train, xs_train],
y={"pred_output": y_train, "adv_output": z_train},
validation_data=([xu_val, xs_val], {"pred_output": y_val, "adv_output": z_val}),
sample_weight=sample_weights, epochs=epoch, batch_size=256, verbose=1)
错误回溯
---------------------------------------------------------------------------
InvalidArgumentError Traceback (most recent call last)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py in get_attr(self, name)
2485 with c_api_util.tf_buffer() as buf:
-> 2486 pywrap_tf_session.TF_OperationGetAttrValueProto(self._c_op, name, buf)
2487 data = pywrap_tf_session.TF_GetBuffer(buf)
InvalidArgumentError: Operation 'while' has no attr named '_XlaCompile'.
During handling of the above exception, another exception occurred:
ValueError Traceback (most recent call last)
51 frames
ValueError: Operation 'while' has no attr named '_XlaCompile'.
During handling of the above exception, another exception occurred:
InvalidArgumentError Traceback (most recent call last)
InvalidArgumentError: Operation 'gradients/while_grad/Placeholder_28' has no attr named '_read_only_resource_inputs'.
注意:我没有添加完整的回溯,但如果需要,我可以提供它。非常感谢
目前没有回答
相关问题 更多 >
编程相关推荐