我正在尝试按其索引筛选tensorflow.dataset
:
dataset = tf.data.Dataset.from_tensor_slices((sequences_matrix, label_data.astype(np.int8)))
dataset = dataset.cache()
dataset = dataset.enumerate()
@tf.function
def filter_function(i, data):
return i in train_index # train_index is a list of integers
train_dataset = dataset.filter(filter_function)
但我得到的错误如下:
Traceback (most recent call last):
File "/home/marzi/workspace/nlp_classification/src/main.py", line 355, in <module>
if __name__ == '__main__': main()
File "/home/marzi/workspace/nlp_classification/src/main.py", line 320, in main
deep_learning_algo(THE_DATA, HYPER_DICT)
File "/home/marzi/workspace/nlp_classification/src/main.py", line 226, in deep_learning_algo
tokenizer_name=tokenizer_name
File "/home/marzi/workspace/nlp_classification/src/train.py", line 118, in fit_normal
train_dataset = dataset.filter(filter_function)
File "/home/marzi/anaconda3/envs/nlp_classification/lib/python3.7/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 1862, in filter
return FilterDataset(self, predicate)
File "/home/marzi/anaconda3/envs/nlp_classification/lib/python3.7/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 4264, in __init__
use_legacy_function=use_legacy_function)
File "/home/marzi/anaconda3/envs/nlp_classification/lib/python3.7/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 3371, in __init__
self._function = wrapper_fn.get_concrete_function()
File "/home/marzi/anaconda3/envs/nlp_classification/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 2939, in get_concrete_function
*args, **kwargs)
File "/home/marzi/anaconda3/envs/nlp_classification/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 2906, in _get_concrete_function_garbage_collected
graph_function, args, kwargs = self._maybe_define_function(args, kwargs)
File "/home/marzi/anaconda3/envs/nlp_classification/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 3213, in _maybe_define_function
graph_function = self._create_graph_function(args, kwargs)
File "/home/marzi/anaconda3/envs/nlp_classification/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 3075, in _create_graph_function
capture_by_value=self._capture_by_value),
File "/home/marzi/anaconda3/envs/nlp_classification/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py", line 986, in func_graph_from_py_func
func_outputs = python_func(*func_args, **func_kwargs)
File "/home/marzi/anaconda3/envs/nlp_classification/lib/python3.7/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 3364, in wrapper_fn
ret = _wrapper_helper(*args)
File "/home/marzi/anaconda3/envs/nlp_classification/lib/python3.7/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 3299, in _wrapper_helper
ret = autograph.tf_convert(func, ag_ctx)(*nested_args)
File "/home/marzi/anaconda3/envs/nlp_classification/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 780, in __call__
result = self._call(*args, **kwds)
File "/home/marzi/anaconda3/envs/nlp_classification/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 823, in _call
self._initialize(args, kwds, add_initializers_to=initializers)
File "/home/marzi/anaconda3/envs/nlp_classification/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 697, in _initialize
*args, **kwds))
File "/home/marzi/anaconda3/envs/nlp_classification/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 2855, in _get_concrete_function_internal_garbage_collected
graph_function, _, _ = self._maybe_define_function(args, kwargs)
File "/home/marzi/anaconda3/envs/nlp_classification/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 3213, in _maybe_define_function
graph_function = self._create_graph_function(args, kwargs)
File "/home/marzi/anaconda3/envs/nlp_classification/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 3075, in _create_graph_function
capture_by_value=self._capture_by_value),
File "/home/marzi/anaconda3/envs/nlp_classification/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py", line 986, in func_graph_from_py_func
func_outputs = python_func(*func_args, **func_kwargs)
File "/home/marzi/anaconda3/envs/nlp_classification/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 600, in wrapped_fn
return weak_wrapped_fn().__wrapped__(*args, **kwds)
File "/home/marzi/anaconda3/envs/nlp_classification/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py", line 973, in wrapper
raise e.ag_error_metadata.to_exception(e)
tensorflow.python.framework.errors_impl.OperatorNotAllowedInGraphError: in user code:
/home/marzi/workspace/nlp_classification/src/train.py:116 filter_function *
return i in train_index
/home/marzi/anaconda3/envs/nlp_classification/lib/python3.7/site-packages/tensorflow/python/framework/ops.py:877 __bool__
self._disallow_bool_casting()
/home/marzi/anaconda3/envs/nlp_classification/lib/python3.7/site-packages/tensorflow/python/framework/ops.py:487 _disallow_bool_casting
"using a `tf.Tensor` as a Python `bool`")
/home/marzi/anaconda3/envs/nlp_classification/lib/python3.7/site-packages/tensorflow/python/framework/ops.py:474 _disallow_when_autograph_enabled
" indicate you are trying to use an unsupported feature.".format(task))
OperatorNotAllowedInGraphError: using a `tf.Tensor` as a Python `bool` is not allowed: AutoGraph did convert this function. This might indicate you are trying to use an unsupported feature.
但是如果我将filter函数中的条件从i in train_index
更改为i > 10
,它就可以正常工作。我不明白这两个条件之间有什么区别,这两个条件使其中一个产生错误,而另一个没有
使用
@tf.function
将把操作转换为图形模式,并在图形模式下列出理解is not supported。您可以改为使用tf.map_fn
或tf.py_function
:例如:
更多阅读:Better performance with tf.function
相关问题 更多 >
编程相关推荐