在Python中编写和注册自定义Tensorflow操作

2024-09-29 22:05:02 发布

您现在位置:Python中文网/ 问答频道 /正文

我想用Python编写一个定制的Tensorflow操作,并在Protobuf注册表中注册,以执行解释的here等操作。原始BuffF注册是关键,因为我不会直接从Python使用这个OP,但是如果它像C++ OP那样注册并加载到Python运行时环境,那么我可以在我的环境中运行它。在

我希望代码看起来像

import tensorflow as tf
from google.protobuf import json_format
from tensorflow.python.ops.data_flow_ops import QueueBase, _as_type_list, _as_shape_list, _as_name_list

""" Missing the Python equivalent of,                                                                                                                                                                        

  class HDF5QueueOp : public ResourceOpKernel<QueueInterface> {                                                                                                                                              
  public:                                                                                                                                                                                                    
      // Implementation                                                                                                                                                                                      
  };                                                                                                                                                                                                         

  REGISTER_OP("HDF5Queue")                                                                                                                                                                                   
  .Output("handle: resource")                                                                                                                                                                                
  .Attr("filename: string")                                                                                                                                                                                  
  .Attr("datasets: list(string)")                                                                                                                                                                            
  .Attr("overwrite: bool = false")                                                                                                                                                                           
  .Attr("component_types: list(type) >= 0 = []")                                                                                                                                                             
  .Attr("shapes: list(shape) >= 0 = []")                                                                                                                                                                     
  .Attr("shared_name: string = ''")                                                                                                                                                                          
  .Attr("container: string = ''")                                                                                                                                                                            
  .Attr("capacity: int = -1")                                                                                                                                                                                
  .SetIsStateful()                                                                                                                                                                                           
  .SetShapeFn(TwoElementOutput);                                                                                                                                                                             

"""

class HDF5Queue(QueueBase):
  def __init__(self, stream_id, stream_columns, dtypes=None, capacity=100,
               shapes=None, names=None, name="hdf5_queue"):
    if not dtypes:
      dtypes = [tf.int64, tf.float32]

    if not shapes:
      shapes = [[1], [1]]

    dtypes = _as_type_list(dtypes)
    shapes = _as_shape_list(shapes, dtypes)
    names = _as_name_list(names, dtypes)
    queue_ref = _op_def_lib.apply_op("HDF5Queue", stream_id=stream_id,
                                     stream_columns=stream_columns, capacity=capacity,
                                     component_types=dtypes, shapes=shapes,
                                     name=name, container=None, shared_name=None)
    super(HDF5Queue, self).__init__(dtypes, shapes,
                                    names, queue_ref)

以上是TF的标准配置。例如,在FIFOQueue中可以看到。Python WrapperProtobuf RegistrationC++ Implementation。在编译过程中生成了一个Python包装器,我不喜欢,但是您可以通过运行grep -A 10 -B 10 -n FIFO $(find /usr/local -name "*gen_data_flow*.py") /dev/null查看它的使用位置

下面将以JSON格式转储TF图的Protobuf消息。我希望这能用一个块来为HDF5GeQueE操作转储,就像我写C++操作一样。在

^{pr2}$

Tags: nameimportnonestreamstringnamesaslist
1条回答
网友
1楼 · 发布于 2024-09-29 22:05:02

这可以用py_func完成。这里有一个例子。在

import tensorflow as tf
from google.protobuf import json_format
import sys, json, base64, numpy
from tensorflow.python.ops.script_ops import _py_funcs as py_func_registry
from tensorflow.core.protobuf.meta_graph_pb2 import MetaGraphDef

graph = tf.Graph()
graph2 = tf.Graph()

def f(x):
    return x

def g(x):
    return 2*x

with graph.as_default():
    x = tf.placeholder(tf.float32, shape=(3,), name='x')
    y = tf.py_func(f, [x], tf.float32, name='y')

    # py_func_registry._funcs.clear() # Optional line to clear the Python function registry
    msg = json.loads(json_format.MessageToJson(tf.train.export_meta_graph()))

# Change the function being used by py_func
msg['graphDef']['node'][1]['attr']['token']['s'] = base64.b64encode(py_func_registry.insert(g))

with graph2.as_default():    
    # Load graph
    meta_graph_def = MetaGraphDef()
    json_format.Parse(json.dumps(msg), meta_graph_def)
    tf.train.import_meta_graph(meta_graph_def)

    sess = tf.Session(graph=graph2)
    print sess.run('y:0', feed_dict={'x:0':numpy.array([1, 2, 3])})
    print g(numpy.array([1, 2, 3]))

相关问题 更多 >

    热门问题