TensorFlow: tf.vectorized_map with argmax fails with "No converter defined for ArgMax". How do I add a converter?

Posted 2024-10-01 05:00:08

I wrote a tf.function (TensorFlow version 2.1):

import numpy as np
import tensorflow as tf

@tf.function
def make_xcorr(T, S, sigma_S, outshape):
    ...
    xcorr = ...  # combination of T and S
    sigma_xcorr = ...  # combination of T, S and sigma_S

    def bootstrap(_i):
        # One bootstrap draw: sample around xcorr with per-element stddev sigma_xcorr
        _xcorr = tf.random.normal(outshape, xcorr, sigma_xcorr)
        _maxidxs = tf.math.argmax(_xcorr, axis=1)  # shape (batch,)
        return _maxidxs

    maxidxs = tf.vectorized_map(bootstrap, np.arange(100))
    return maxidxs

This function is called inside a for loop:

def main():
    ...
    for i in range(...):
        ...
        out = make_xcorr(...)
        ...

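The idea is to run the bootstrap operation 100 times inside the tf.function. I also tried tf.map_fn instead of tf.vectorized_map, and that works, but I wanted to see whether tf.vectorized_map would be faster. With tf.vectorized_map I get the following error: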

WARNING:tensorflow:Note that RandomStandardNormal inside pfor op may not give same output as inside a sequential loop. 

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/gangora/scripts/make_cross_correlation_testgpu.py", line 1018, in main_make_cross_correlation_testgpu
    xcorr, maxidxs, sigma_xcorr, _maxidxs = make_xcorr_err1(flux_template_vec, data_partial, sigma_partial, (end-start, vec_z.shape[0]))
  File "/home/gangora/anaconda3/envs/angora_env/lib/python3.6/site-packages/tensorflow_core/python/eager/def_function.py", line 568, in __call__
    result = self._call(*args, **kwds)
  File "/home/gangora/anaconda3/envs/angora_env/lib/python3.6/site-packages/tensorflow_core/python/eager/def_function.py", line 615, in _call
    self._initialize(args, kwds, add_initializers_to=initializers)
  File "/home/gangora/anaconda3/envs/angora_env/lib/python3.6/site-packages/tensorflow_core/python/eager/def_function.py", line 497, in _initialize
    *args, **kwds))
  File "/home/gangora/anaconda3/envs/angora_env/lib/python3.6/site-packages/tensorflow_core/python/eager/function.py", line 2389, in _get_concrete_function_internal_garbage_collected
    graph_function, _, _ = self._maybe_define_function(args, kwargs)
  File "/home/gangora/anaconda3/envs/angora_env/lib/python3.6/site-packages/tensorflow_core/python/eager/function.py", line 2703, in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)
  File "/home/gangora/anaconda3/envs/angora_env/lib/python3.6/site-packages/tensorflow_core/python/eager/function.py", line 2593, in _create_graph_function
    capture_by_value=self._capture_by_value),
  File "/home/gangora/anaconda3/envs/angora_env/lib/python3.6/site-packages/tensorflow_core/python/framework/func_graph.py", line 978, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "/home/gangora/anaconda3/envs/angora_env/lib/python3.6/site-packages/tensorflow_core/python/eager/def_function.py", line 439, in wrapped_fn
    return weak_wrapped_fn().__wrapped__(*args, **kwds)
  File "/home/gangora/anaconda3/envs/angora_env/lib/python3.6/site-packages/tensorflow_core/python/framework/func_graph.py", line 968, in wrapper
    raise e.ag_error_metadata.to_exception(e)
ValueError: in converted code:

    /home/gangora/scripts/make_cross_correlation_testgpu.py:764 make_xcorr_err1  *
        _maxidxs = tf.vectorized_map(bootstrap, np.arange(100))#, parallel_iterations=100, back_prop=False, swap_memory=True, infer_shape=False )
    /home/gangora/anaconda3/envs/angora_env/lib/python3.6/site-packages/tensorflow_core/python/ops/parallel_for/control_flow_ops.py:394 vectorized_map
        return pfor(loop_fn, batch_size)
    /home/gangora/anaconda3/envs/angora_env/lib/python3.6/site-packages/tensorflow_core/python/ops/parallel_for/control_flow_ops.py:189 pfor
        return f()
    /home/gangora/anaconda3/envs/angora_env/lib/python3.6/site-packages/tensorflow_core/python/ops/parallel_for/control_flow_ops.py:183 f
        return _pfor_impl(loop_fn, iters, parallel_iterations=parallel_iterations)
    /home/gangora/anaconda3/envs/angora_env/lib/python3.6/site-packages/tensorflow_core/python/ops/parallel_for/control_flow_ops.py:256 _pfor_impl
        outputs.append(converter.convert(loop_fn_output))
    /home/gangora/anaconda3/envs/angora_env/lib/python3.6/site-packages/tensorflow_core/python/ops/parallel_for/pfor.py:1280 convert
        output = self._convert_helper(y)
    /home/gangora/anaconda3/envs/angora_env/lib/python3.6/site-packages/tensorflow_core/python/ops/parallel_for/pfor.py:1460 _convert_helper
        (y_op.type, y_op, converted_inputs))

    ValueError: No converter defined for ArgMax
    name: "loop_body/ArgMax"
    op: "ArgMax"
    input: "loop_body/random_normal"
    input: "loop_body/ArgMax/dimension"
    attr {
      key: "T"
      value {
        type: DT_FLOAT
      }
    }
    attr {
      key: "Tidx"
      value {
        type: DT_INT32
      }
    }
    attr {
      key: "output_type"
      value {
        type: DT_INT64
      }
    }

    inputs: [WrappedTensor(t=<tf.Tensor 'loop_body/random_normal/pfor/Add_1:0' shape=(100, 30000, 3000) dtype=float32>, is_stacked=True, is_sparse_stacked=False), WrappedTensor(t=<tf.Tensor 'loop_body/ArgMax/dimension:0' shape=() dtype=int32>, is_stacked=False, is_sparse_stacked=False)].
    Either add a converter or set --op_conversion_fallback_to_while_loop=True, which may run slower

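The traceback seems to suggest adding a converter for the argmax op, but I really don't know how to implement such a converter. If it does not slow the code down, I could also try setting --op_conversion_fallback_to_while_loop=True, but I don't know where to set this flag.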

In short, my question is: how do I add a converter for a TensorFlow op such as argmax?

Thanks
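
One way to sidestep the missing ArgMax converter, sketched here under assumptions rather than taken from the post, is to vectorize the bootstrap by hand: draw all samples in a single call with a leading bootstrap axis and take the argmax over the last axis. The function name `bootstrap_argmax` and the parameter `n_boot` are hypothetical, and `xcorr` and `sigma_xcorr` are assumed to be float tensors of the same 2-D shape (batch, n_lags). This does not add a pfor converter; it only avoids needing one.

```python
import tensorflow as tf


@tf.function
def bootstrap_argmax(xcorr, sigma_xcorr, n_boot=100):
    """Draw n_boot noisy copies of xcorr and return the argmax of each row.

    Hypothetical helper: xcorr and sigma_xcorr are assumed to share the
    shape (batch, n_lags). The result has shape (n_boot, batch), dtype int64.
    """
    # Shape of the stacked samples: (n_boot, batch, n_lags).
    sample_shape = tf.concat([[n_boot], tf.shape(xcorr)], axis=0)
    # Standard-normal noise, scaled and shifted by broadcasting over axis 0.
    noise = tf.random.normal(sample_shape, dtype=xcorr.dtype)
    samples = xcorr[tf.newaxis, ...] + sigma_xcorr[tf.newaxis, ...] * noise
    # The original per-sample axis=1 becomes axis=2 after stacking.
    return tf.math.argmax(samples, axis=2)


# Tiny demonstration with made-up numbers.
demo_xcorr = tf.constant([[0.0, 10.0, 0.0], [5.0, 0.0, 0.0]])
demo_sigma = tf.ones_like(demo_xcorr) * 1e-3
demo_idx = bootstrap_argmax(demo_xcorr, demo_sigma, n_boot=4)
print(demo_idx.shape, demo_idx.dtype)
```

Note that the stacked tensor holds every bootstrap sample at once. For the shape reported in the traceback, (100, 30000, 3000) in float32, that is roughly 36 GB, so in practice the bootstrap axis may need to be processed in smaller chunks.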

