How to use cross-entropy loss in TensorFlow when training on TPU?

Posted 2024-10-02 02:30:16


I'm trying to train a Transformer encoder on a TPU (from here: https://www.tensorflow.org/tutorials/text/transformer):

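```python
import numpy as np
import tensorflow as tf

# Transformer, loss_function, num_layers, d_model, num_heads, dff and
# dropout_rate come from (an adapted version of) the tutorial linked above
# and are not shown in the question.

def test():
    # Batch size and sequence length are left dynamic: (None, None).
    train_step_signature = [
        tf.TensorSpec(shape=(None, None), dtype=tf.int64),
        tf.TensorSpec(shape=(None, None), dtype=tf.int64),
    ]

    @tf.function(input_signature=train_step_signature)
    def train_step(inp, tar):
        with tf.GradientTape() as tape:
            # Forward pass with training=True and no masks.
            predictions = transf(inp, True, None, None, None)
            loss = loss_function(tar, predictions)  # <- error is here
                                                    # I use SparseCategoricalCrossentropy()
        # Gradient computation and the optimizer update are not shown here.

    vocabsize = 1000
    transf = Transformer(num_layers, d_model, num_heads, dff,
                         vocabsize, vocabsize,
                         pe_input=vocabsize,
                         pe_target=vocabsize,
                         rate=dropout_rate)
    for step in range(1, 75000):
        print(step)
        # Random token IDs; dtype=np.int64 matches the tf.int64 TensorSpec on
        # every platform (the NumPy default integer can be int32 on Windows).
        inp = np.random.randint(vocabsize, size=(5, 11), dtype=np.int64)
        tar = np.random.randint(vocabsize, size=(5, 11), dtype=np.int64)
        train_step(inp, tar)
```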

It works on CPU. On TPU, however, after about 100 iterations an error is raised at the call to `loss_function` (marked above):

```
InvalidArgumentError:
Function invoked by the following node is not compilable: {{node __inference_train_step_4179}} = __inference_train_step_4179[_XlaMustCompile=true, config_proto="\n\007\n\003GPU\020\000\n\007\n\003CPU\020\0012\002J\0008\001", executor_type=""](dummy_input, dummy_input, dummy_input, dummy_input...
Uncompilable nodes: sparse_categorical_crossentropy/SparseSoftmaxCrossEntropyWithLogits/assert_equal_1/Assert/Const:
unsupported op: Const op with type DT_STRING is not supported by XLA.
Stacktrace: Node: __inference_train_step_4179, function: Node: sparse_categorical_crossentropy/SparseSoftmaxCrossEntropyWithLogits/assert_equal_1/Assert/Const, function: __inference_train_step_4179 ...
```
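The failing node sits under `SparseSoftmaxCrossEntropyWithLogits/assert_equal_1`. In TensorFlow 2.x, `tf.nn.sparse_softmax_cross_entropy_with_logits` appears to add such a runtime `assert_equal` only when the labels' shape and the logits' leading dimensions are not fully known at trace time, which is exactly what the `(None, None)` `input_signature` produces. A minimal sketch to test that assumption: it traces the same loss once with dynamic shapes and once with fixed `(5, 11)` shapes and lists the `Assert` ops in each graph. `VOCAB`, `sparse_ce` and `assert_op_names` are hypothetical names; the functions are only traced, never run.

```python
import tensorflow as tf

VOCAB = 1000  # hypothetical, mirrors vocabsize in the question


def sparse_ce(labels, logits):
    # labels: [batch, seq] int64; logits: [batch, seq, VOCAB] float32.
    return tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits)


# Same loss traced with unknown batch/sequence dims (like the question's
# (None, None) signature) and with fully static dims (5, 11).
dynamic_loss = tf.function(sparse_ce, input_signature=[
    tf.TensorSpec(shape=(None, None), dtype=tf.int64),
    tf.TensorSpec(shape=(None, None, VOCAB), dtype=tf.float32),
])
static_loss = tf.function(sparse_ce, input_signature=[
    tf.TensorSpec(shape=(5, 11), dtype=tf.int64),
    tf.TensorSpec(shape=(5, 11, VOCAB), dtype=tf.float32),
])


def assert_op_names(fn):
    """Names of all Assert ops in the graph traced for fn's input_signature."""
    graph = fn.get_concrete_function().graph
    return [op.name for op in graph.get_operations() if op.type == "Assert"]


print("dynamic shapes:", assert_op_names(dynamic_loss))
print("static shapes: ", assert_op_names(static_loss))
```

If only the dynamic trace contains an `Assert` op, one option is to give `train_step_signature` the real, fixed batch shape, assuming the model keeps those static shapes all the way to its output.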

I understand that the error is caused by an assertion inside the loss function that XLA does not support. What can I do about it?
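Another possible workaround, sketched under the assumption that the failing check is that dynamic shape assertion: flatten labels to rank 1 and logits to rank 2 before calling `tf.nn.sparse_softmax_cross_entropy_with_logits`, since in TensorFlow 2.x's implementation rank-2 logits go straight to the kernel without the runtime check. `flat_loss_function` is a hypothetical replacement for the question's `loss_function`; it assumes `predictions` are unnormalized logits of shape `[batch, seq, vocab]` with a static vocab dimension (as a final `Dense` layer produces) and, like the tutorial, treats token ID 0 as padding.

```python
import tensorflow as tf


def flat_loss_function(real, pred):
    """Masked sparse cross-entropy computed on flattened tensors (hypothetical).

    real: integer labels, shape [batch, seq]; ID 0 is assumed to be padding.
    pred: logits, shape [batch, seq, vocab]; vocab must be statically known.
    """
    vocab = pred.shape[-1]  # static last dim, e.g. from a Dense output layer
    # Rank-1 labels and rank-2 logits: no dynamic assert_equal is emitted.
    flat_real = tf.reshape(real, [-1])
    flat_pred = tf.reshape(pred, [-1, vocab])
    per_token = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=flat_real, logits=flat_pred)
    # Ignore padding positions, as the tutorial's loss_function does.
    mask = tf.cast(tf.not_equal(flat_real, 0), per_token.dtype)
    # divide_no_nan returns 0 instead of NaN if the batch is all padding.
    return tf.math.divide_no_nan(tf.reduce_sum(per_token * mask), tf.reduce_sum(mask))
```

Inside `train_step`, the failing line would then read `loss = flat_loss_function(tar, predictions)`.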

