Inference from TensorFlow checkpoints

Published 2024-06-14 19:14:03

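I feed characters (x_train) into the RNN model defined in Example 13 of this link. Below is the corresponding code for the model definition, input preprocessing and training: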

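import numpy as np
import tensorflow as tf
from sklearn import metrics
from tensorflow.contrib import learn

# MAX_DOCUMENT_LENGTH, HIDDEN_SIZE and the x_train/y_train/x_test/y_test data
# are defined as in the linked example.

def char_rnn_model(features, target):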
    """Character level recurrent neural network model to predict classes."""
    target = tf.one_hot(target, 15, 1, 0)
    #byte_list = tf.one_hot(features, 256, 1, 0)
    byte_list = tf.cast(tf.one_hot(features, 256, 1, 0), dtype=tf.float32)
    byte_list = tf.unstack(byte_list, axis=1)

    cell = tf.contrib.rnn.GRUCell(HIDDEN_SIZE)
    _, encoding = tf.contrib.rnn.static_rnn(cell, byte_list, dtype=tf.float32)


    logits = tf.contrib.layers.fully_connected(encoding, 15, activation_fn=None)
    #loss = tf.contrib.losses.softmax_cross_entropy(logits, target)
    loss = tf.contrib.losses.softmax_cross_entropy(logits=logits, onehot_labels=target)

    train_op = tf.contrib.layers.optimize_loss(
      loss,
      tf.contrib.framework.get_global_step(),
      optimizer='Adam',
      learning_rate=0.001)

    return ({
      'class': tf.argmax(logits, 1),
      'prob': tf.nn.softmax(logits)
    }, loss, train_op)

# pre-process 
char_processor = learn.preprocessing.ByteProcessor(MAX_DOCUMENT_LENGTH)
x_train = np.array(list(char_processor.fit_transform(x_train)))
x_test = np.array(list(char_processor.transform(x_test)))

# train
model_dir = "model"
classifier = learn.Estimator(model_fn=char_rnn_model,model_dir=model_dir)
count=0
n_epoch = 20
while count<n_epoch:
        print("\nEPOCH " + str(count))
        classifier.fit(x_train, y_train, steps=1000,batch_size=10)
        y_predicted = [
              p['class'] for p in classifier.predict(
              x_test, as_iterable=True,batch_size=10)
        ]
        score = metrics.accuracy_score(y_test, y_predicted)
        print('Accuracy: {0:f}'.format(score))
        count+=1

print(metrics.classification_report(y_test, y_predicted))
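To see where the uint8 array in the error further down comes from, here is a minimal NumPy sketch of what the preprocessing and the first lines of char_rnn_model do to a single string. `encode_bytes` and `one_hot_steps` are hypothetical helpers written for illustration only, not part of tf.contrib.learn, and MAX_DOCUMENT_LENGTH = 100 is an assumption taken from the length of the padded array in the traceback.

```python
import numpy as np

# Hypothetical stand-in for learn.preprocessing.ByteProcessor, for illustration:
# encode each document to bytes, truncate to max_len, zero-pad on the right.
def encode_bytes(docs, max_len):
    out = np.zeros((len(docs), max_len), dtype=np.uint8)
    for i, doc in enumerate(docs):
        data = np.frombuffer(doc.encode("utf-8"), dtype=np.uint8)[:max_len]
        out[i, :len(data)] = data
    return out

# Mirrors tf.one_hot(features, 256) followed by tf.unstack(..., axis=1):
# returns max_len arrays of shape (batch, depth), one per time step.
def one_hot_steps(ids, depth=256):
    one_hot = np.eye(depth, dtype=np.float32)[ids]  # (batch, max_len, depth)
    return [one_hot[:, t, :] for t in range(one_hot.shape[1])]

ids = encode_bytes(["Some Sequence of character"], 100)  # MAX_DOCUMENT_LENGTH assumed 100
steps = one_hot_steps(ids)
print(ids[0, :26])
print(len(steps))
print(steps[0].shape)
```

Note that the zero padding is itself one-hot encoded (as byte 0), so the padded positions still reach the GRU as non-empty input vectors.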

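After training, the directory model_dir is populated with files named:

  • model.ckpt-???????.index
  • model.ckpt-???????.meta
  • model.ckpt-??????.data-00000-of-00001

which store the model's weights and graph. I want to use them for inference.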

I managed to load them with the following code:

new_saver = tf.train.import_meta_graph(meta_file)
sess = tf.Session()
new_saver.restore(sess, tf.train.latest_checkpoint(model_dir))

where meta_file is the path to one of the model.ckpt-???????.meta files.
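After restoring, the graph is available as sess.graph, but running inference requires knowing the names of its input placeholder and of the prediction tensors. The helper below is a hedged sketch that lists the output tensor of every op of a given type; it is demonstrated on a toy graph whose names ("input", "class", "prob") are hypothetical, since the names inside the Estimator's checkpoint are not known here. `tf1` is an assumption that falls back to `tf` on older 1.x releases that lack `tf.compat.v1`.

```python
import tensorflow as tf

tf1 = getattr(tf.compat, "v1", tf)  # TF1-style graph API on TF 1.x and 2.x

def find_tensors(graph, op_type):
    """Names of the first output tensor of every op of the given type."""
    return [op.outputs[0].name for op in graph.get_operations()
            if op.type == op_type]

# Toy graph with hypothetical names; with the real checkpoint you would pass
# sess.graph after new_saver.restore(...).
g = tf.Graph()
with g.as_default():
    x = tf1.placeholder(tf.float32, shape=[None, 15], name="input")
    tf.argmax(x, 1, name="class")
    tf.nn.softmax(x, name="prob")

print(find_tensors(g, "Placeholder"))
print(find_tensors(g, "ArgMax"))
print(find_tensors(g, "Softmax"))
```

On the restored Estimator graph, the same call with "Placeholder", "ArgMax" or "Softmax" should narrow down candidates for the input and for the 'class' and 'prob' outputs of char_rnn_model, though the exact names will differ.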

I want to apply the trained model to a new character sequence, so I typed:

new_input = ["Some Sequence of character"]
new_input_processed = np.array(list(char_processor.transform(new_input)))
output = sess.run(new_input_processed)

But I got the following error:

   ---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-13-982f2b9b18b3> in <module>()
----> 1 output = sess.run(new_input_processed)

/home/user/anaconda2/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in run(self, fetches, feed_dict, options, run_metadata)
    898     try:
    899       result = self._run(None, fetches, feed_dict, options_ptr,
--> 900                          run_metadata_ptr)
    901       if run_metadata:
    902         proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)

/home/user/anaconda2/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in _run(self, handle, fetches, feed_dict, options, run_metadata)
   1118     # Create a fetch handler to take care of the structure of fetches.
   1119     fetch_handler = _FetchHandler(
-> 1120         self._graph, fetches, feed_dict_tensor, feed_handles=feed_handles)
   1121 
   1122     # Run request and get response.

/home/user/anaconda2/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in __init__(self, graph, fetches, feeds, feed_handles)
    425     """
    426     with graph.as_default():
--> 427       self._fetch_mapper = _FetchMapper.for_fetch(fetches)
    428     self._fetches = []
    429     self._targets = []

/home/user/anaconda2/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in for_fetch(fetch)
    251         if isinstance(fetch, tensor_type):
    252           fetches, contraction_fn = fetch_fn(fetch)
--> 253           return _ElementFetchMapper(fetches, contraction_fn)
    254     # Did not find anything.
    255     raise TypeError('Fetch argument %r has invalid type %r' % (fetch,

/home/user/anaconda2/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in __init__(self, fetches, contraction_fn)
    284         raise TypeError('Fetch argument %r has invalid type %r, '
    285                         'must be a string or Tensor. (%s)' %
--> 286                         (fetch, type(fetch), str(e)))
    287       except ValueError as e:
    288         raise ValueError('Fetch argument %r cannot be interpreted as a '

TypeError: Fetch argument array([[ 83, 111, 109, 101,  32,  83, 101, 113, 117, 101, 110,  99, 101,
         32, 111, 102,  32,  99, 104,  97, 114,  97,  99, 116, 101, 114,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0]], dtype=uint8) has invalid type <type 'numpy.ndarray'>, must be a string or Tensor. (Can not convert a ndarray into a Tensor or Operation.)

I am using TensorFlow 1.8.0 and Python 2.7.14.
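The TypeError arises because the first argument of Session.run is fetches: the tensors or operations whose values should be computed. Input data goes into feed_dict, keyed by the graph's input placeholder. The following is a minimal self-contained sketch of that pattern, not the Estimator's own export path: it saves a toy model and restores it the same way as above (import_meta_graph plus restore), then feeds a NumPy array and fetches an output tensor. The names "input:0" and "output:0" belong to this toy graph only and are hypothetical; the training checkpoint will use its own names, which have to be looked up in the restored graph.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

tf1 = getattr(tf.compat, "v1", tf)  # TF1-style graph API on TF 1.x and 2.x

ckpt_dir = tempfile.mkdtemp()

# 1) Build and save a toy model (stand-in for the Estimator's checkpoint).
with tf.Graph().as_default():
    x = tf1.placeholder(tf.float32, shape=[None, 3], name="input")
    w = tf.Variable([[1.0], [2.0], [3.0]], name="w")
    tf.identity(tf.matmul(x, w), name="output")
    saver = tf1.train.Saver()
    with tf1.Session() as sess:
        sess.run(tf1.global_variables_initializer())
        saver.save(sess, os.path.join(ckpt_dir, "model.ckpt"), global_step=1)

# 2) Restore as in the question, then feed the input and fetch a tensor.
with tf.Graph().as_default() as graph:
    ckpt = tf1.train.latest_checkpoint(ckpt_dir)
    restorer = tf1.train.import_meta_graph(ckpt + ".meta")
    with tf1.Session(graph=graph) as sess:
        restorer.restore(sess, ckpt)
        input_t = graph.get_tensor_by_name("input:0")
        output_t = graph.get_tensor_by_name("output:0")
        new_input = np.array([[1.0, 1.0, 1.0]], dtype=np.float32)
        # fetches = the tensor to compute; the data goes in feed_dict.
        result = sess.run(output_t, feed_dict={input_t: new_input})

print(result)
```

Applied to the question, the call would take the shape sess.run(prediction_tensor, feed_dict={input_placeholder: new_input_processed}), where both tensors are obtained with graph.get_tensor_by_name; fetching only the prediction tensor should avoid having to feed the labels, since the argmax over the logits does not depend on them.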

=================== EDIT ===================

Perhaps the function I should be using is export_savedmodel (https://www.tensorflow.org/versions/r1.15/api_docs/python/tf/contrib/learn/Estimator?authuser=0&hl=ro), but I don't understand all of its parameters.


Tags: runinselfnewmodeltffeedtrain