Tensorflow: PartialTensorShape: Incompatible ranks during merge: 2 vs. 1

Posted 2024-09-30 18:19:35


I am using Keras with TF 2.2 as the backend, and training fails with this error:

Traceback (most recent call last):
  File "run.py", line 97, in <module>
    task_entry_function()
  File "/data-crystina/src/capreolus-unpublished/capreolus/task/rerank.py", line 47, in train
    return self.rerank_run(best_search_run, self.get_results_path())
  File "/data-crystina/src/capreolus-unpublished/capreolus/task/rerank.py", line 85, in rerank_run
    self.benchmark.relevance_level,
  File "/data-crystina/src/capreolus-unpublished/capreolus/trainer/__init__.py", line 578, in train
    use_multiprocessing=True,
  File "/data-crystina/anaconda3/envs/maxp/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py", line 66, in _method_wrapper
    return method(self, *args, **kwargs)
  File "/data-crystina/anaconda3/envs/maxp/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py", line 855, in fit
    callbacks.on_train_batch_end(step, logs)
  File "/data-crystina/anaconda3/envs/maxp/lib/python3.7/site-packages/tensorflow/python/keras/callbacks.py", line 389, in on_train_batch_end
    logs = self._process_logs(logs)
  File "/data-crystina/anaconda3/envs/maxp/lib/python3.7/site-packages/tensorflow/python/keras/callbacks.py", line 265, in _process_logs
    return tf_utils.to_numpy_or_python_type(logs)
  File "/data-crystina/anaconda3/envs/maxp/lib/python3.7/site-packages/tensorflow/python/keras/utils/tf_utils.py", line 523, in to_numpy_or_python_type
    return nest.map_structure(_to_single_numpy_or_python_type, tensors)
  File "/data-crystina/anaconda3/envs/maxp/lib/python3.7/site-packages/tensorflow/python/util/nest.py", line 617, in map_structure
    structure[0], [func(*x) for x in entries],
  File "/data-crystina/anaconda3/envs/maxp/lib/python3.7/site-packages/tensorflow/python/util/nest.py", line 617, in <listcomp>
    structure[0], [func(*x) for x in entries],
  File "/data-crystina/anaconda3/envs/maxp/lib/python3.7/site-packages/tensorflow/python/keras/utils/tf_utils.py", line 519, in _to_single_numpy_or_python_type
    x = t.numpy()
  File "/data-crystina/anaconda3/envs/maxp/lib/python3.7/site-packages/tensorflow/python/framework/ops.py", line 961, in numpy
    maybe_arr = self._numpy()  # pylint: disable=protected-access
  File "/data-crystina/anaconda3/envs/maxp/lib/python3.7/site-packages/tensorflow/python/framework/ops.py", line 929, in _numpy
    six.raise_from(core._status_to_exception(e.code, e.message), None)
  File "<string>", line 3, in raise_from
tensorflow.python.framework.errors_impl.InvalidArgumentError: {{function_node __inference_train_function_100056}} PartialTensorShape: Incompatible ranks during merge: 2 vs. 1
         [[{{node map_6/TensorArrayV2Stack/TensorListStack}}]]
         [[MultiDeviceIteratorGetNextFromShard]]
         [[RemoteCall]]
         [[IteratorGetNextAsOptional]]
2020-07-03 07:19:03.088112: W tensorflow/core/distributed_runtime/eager/remote_tensor_handle_data.cc:76] Unable to destroy remote tensor handles. If you are running a tf.function, it usually indicates some op in the graph gets an error: {{function_node __inference_train_function_100056}} PartialTensorShape: Incompatible ranks during merge: 2 vs. 1
         [[{{node map_6/TensorArrayV2Stack/TensorListStack}}]]
         [[MultiDeviceIteratorGetNextFromShard]]
         [[RemoteCall]]
         [[IteratorGetNextAsOptional]]

Sorry, I could not come up with a small snippet that reproduces this. However, I went into ..python3.7/site-packages/tensorflow/python/keras/callbacks.py and looked at this function:

  def on_train_batch_end(self, batch, logs=None):
    """Calls the `on_train_batch_end` methods of its callbacks.

    Arguments:
        batch: integer, index of batch within the current epoch.
        logs: dict. Metric results for this batch.
    """
    if self._should_call_train_batch_hooks:
      # print("<<<<", logs.keys())
      # print(">>>", type(list(logs.values())[0]))
      logs = self._process_logs(logs)
      self._call_batch_hook(ModeKeys.TRAIN, 'end', batch, logs=logs)

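I printed logs and found that it is a dict with a single key, loss, whose value has type <class 'tensorflow.python.framework.ops.EagerTensor'>. However, logs["loss"] itself cannot be printed because doing so raises the same error, and the same happens with logs["loss"].shape. I could not find a similar case online. Has anyone seen this before?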
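
The node names in the trace (map_6/TensorArrayV2Stack/TensorListStack under IteratorGetNextAsOptional) suggest, though do not prove, that the rank mismatch arises while input examples are being stacked, and that reading logs["loss"] is only where the error first surfaces. A minimal sketch for narrowing this down, assuming the raw examples can be iterated as dicts of feature name to value before the mapping step; rank_of, find_rank_mismatches, and the feature name "doc" are hypothetical helpers and names, not part of TensorFlow or capreolus:

```python
def rank_of(value):
    """Return the number of dimensions of an array-like or nested list."""
    shape = getattr(value, "shape", None)
    if shape is not None:
        # NumPy arrays and eager tensors with a known rank expose a sized shape.
        return len(shape)
    rank = 0
    while isinstance(value, (list, tuple)):
        rank += 1
        if not value:
            break
        value = value[0]
    return rank


def find_rank_mismatches(examples):
    """Report features whose rank differs from the first example seen.

    `examples` is any iterable of dicts mapping feature name -> value.
    Returns a list of (example_index, feature_name, expected_rank, actual_rank).
    """
    expected = {}
    mismatches = []
    for index, example in enumerate(examples):
        for name, value in example.items():
            rank = rank_of(value)
            if name not in expected:
                expected[name] = rank
            elif rank != expected[name]:
                mismatches.append((index, name, expected[name], rank))
    return mismatches


# Hypothetical data: the second example is rank 1 where rank 2 is expected.
examples = [{"doc": [[1, 2], [3, 4]]}, {"doc": [1, 2]}]
print(find_rank_mismatches(examples))  # [(1, 'doc', 2, 1)]
```

If a check like this flags an example, the "2 vs. 1" in the message would correspond to a feature that is sometimes 2-D and sometimes 1-D, for instance when a single passage is not wrapped in an outer list.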

