I'm trying to train a Transformer encoder on a TPU (from here: https://www.tensorflow.org/tutorials/text/transformer):
```python
import numpy as np
import tensorflow as tf

# Transformer, loss_function and the hyperparameters (num_layers, d_model,
# num_heads, dff, dropout_rate) come from the tutorial.

def test():
    train_step_signature = [
        tf.TensorSpec(shape=(None, None), dtype=tf.int64),
        tf.TensorSpec(shape=(None, None), dtype=tf.int64),
    ]

    @tf.function(input_signature=train_step_signature)
    def train_step(inp, tar):
        with tf.GradientTape() as tape:
            predictions = transf(inp,
                                 True,
                                 None,
                                 None,
                                 None)
            loss = loss_function(tar, predictions)  # <- error is here
            # I use SparseCategoricalCrossentropy()

    vocabsize = 1000
    transf = Transformer(num_layers, d_model, num_heads, dff,
                         vocabsize, vocabsize,
                         pe_input=vocabsize,
                         pe_target=vocabsize,
                         rate=dropout_rate)
    for iter in range(1, 75000):
        print(iter)
        inp = np.random.randint(vocabsize, size=(5, 11))
        tar = np.random.randint(vocabsize, size=(5, 11))
        train_step(inp, tar)
```
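It works on the CPU. On the TPU, however, after about 100 iterations it fails with the following error when loss_function is called (the line marked above):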
```
InvalidArgumentError:
Function invoked by the following node is not compilable: {{node __inference_train_step_4179}} = __inference_train_step_4179[_XlaMustCompile=true, config_proto="\n\007\n\003GPU\020\000\n\007\n\003CPU\020\0012\002J\0008\001", executor_type=""](dummy_input, dummy_input, dummy_input, dummy_input...
Uncompilable nodes: sparse_categorical_crossentropy/SparseSoftmaxCrossEntropyWithLogits/assert_equal_1/Assert/Const:
unsupported op: Const op with type DT_STRING is not supported by XLA.
Stacktrace: Node: __inference_train_step_4179, function: Node: sparse_categorical_crossentropy/SparseSoftmaxCrossEntropyWithLogits/assert_equal_1/Assert/Const, function: __inference_train_step_4179 ...
```
I understand that the error is caused by an assertion inside the loss function, which XLA does not support. What can I do about it?
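One possible workaround, as a sketch under assumptions: judging by the node name in the error (SparseSoftmaxCrossEntropyWithLogits/assert_equal_1), the assert is the dynamic shape check that tf.nn.sparse_softmax_cross_entropy_with_logits adds when the logits are not rank 2 and the static shapes are not fully defined, which is the case with the (None, None) input signature. Flattening the logits to (batch * seq_len, vocab) and the labels to (batch * seq_len,) before computing the loss should let the op skip that check. The helper names flat_sparse_ce and numpy_reference are hypothetical. The sketch assumes the model returns raw logits, and it takes the plain mean over all tokens, like SparseCategoricalCrossentropy(from_logits=True) with its default reduction, without the padding mask used in the tutorial's loss_function.

```python
import numpy as np
import tensorflow as tf

VOCAB_SIZE = 1000  # vocabsize from the question


def flat_sparse_ce(tar, predictions):
    """Hypothetical helper: mean sparse cross-entropy computed on rank-2 logits.

    tar:         integer labels, shape (batch, seq_len)
    predictions: raw logits (assumption), shape (batch, seq_len, VOCAB_SIZE)
    With rank-2 logits the op returns early and does not add the
    assert_equal shape check that XLA cannot compile.
    """
    logits_2d = tf.reshape(predictions, [-1, VOCAB_SIZE])  # (batch*seq_len, vocab)
    labels_1d = tf.reshape(tar, [-1])                       # (batch*seq_len,)
    per_token = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=labels_1d, logits=logits_2d)
    return tf.reduce_mean(per_token)


# Trace with the same dynamic (None, None) style of signature as the question.
traced = tf.function(
    flat_sparse_ce,
    input_signature=[
        tf.TensorSpec(shape=(None, None), dtype=tf.int64),
        tf.TensorSpec(shape=(None, None, VOCAB_SIZE), dtype=tf.float32),
    ],
)

# Count Assert ops in the traced graph; these are what XLA rejected.
graph_ops = traced.get_concrete_function().graph.get_operations()
num_asserts = sum(op.type == "Assert" for op in graph_ops)
print("Assert ops in traced loss:", num_asserts)


def numpy_reference(tar, logits):
    """Hypothetical NumPy reference: numerically stable log-softmax, then mean NLL."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    picked = np.take_along_axis(log_probs, tar[..., None], axis=-1)
    return -picked.mean()


# Random data shaped like the batches in the question's loop.
rng = np.random.default_rng(0)
tar = rng.integers(VOCAB_SIZE, size=(5, 11))  # int64 labels
logits = rng.normal(size=(5, 11, VOCAB_SIZE)).astype(np.float32)
print(float(traced(tar, logits)), float(numpy_reference(tar, logits)))
```

Inside train_step the loss line would then become loss = flat_sparse_ce(tar, predictions). Note that SparseCategoricalCrossentropy() defaults to from_logits=False, while the tutorial's final Dense layer outputs logits and the tutorial creates its loss with from_logits=True. If every batch really has shape (5, 11), declaring those fixed shapes in train_step_signature instead of (None, None) may also remove the check, since the op only adds it when the static shapes are not fully defined.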