Example of tensorflow.contrib.learn.ExportStrategy

Posted 2024-09-29 23:24:23


Can anyone provide a complete, working code example for TensorFlow's tf.contrib.learn.ExportStrategy?

The documentation lacks examples, and I could not find any for this seemingly obscure TensorFlow operation on GitHub or Stack Overflow either.

Documentation: https://www.tensorflow.org/api_docs/python/tf/contrib/learn/ExportStrategy


1 answer
User
#1 · Posted 2024-09-29 23:24:23

Google CloudML has a good working example here: https://github.com/GoogleCloudPlatform/cloudml-samples/tree/master/census/customestimator/trainer

You will need their full code to get the example running, but here is the gist of how to use ExportStrategy:

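# Excerpt only: parse_csv, LABEL_COLUMN, INPUT_COLUMNS,
# generate_experiment_fn and args come from the full sample linked above
# and are not defined here.
# Requires TensorFlow 1.x; tf.contrib was removed in TensorFlow 2.0.
import tensorflow as tf
from tensorflow.contrib.learn.python.learn import learn_runner
from tensorflow.contrib.learn.python.learn.utils import (
    saved_model_export_utils)
from tensorflow.contrib.training.python.training import hparam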

def csv_serving_input_fn():
    """Build the serving inputs for requests sent as CSV rows."""
    csv_row = tf.placeholder(
        shape=[None],
        dtype=tf.string
    )
    features = parse_csv(csv_row)
    # Ignore label column
    features.pop(LABEL_COLUMN)
    return tf.estimator.export.ServingInputReceiver(
        features, {'csv_row': csv_row})


def example_serving_input_fn():
    """Build the serving inputs for serialized tf.Example protos."""
    example_bytestring = tf.placeholder(
        shape=[None],
        dtype=tf.string,
    )
    features = tf.parse_example(
        example_bytestring,
        tf.feature_column.make_parse_example_spec(INPUT_COLUMNS)
    )
    return tf.estimator.export.ServingInputReceiver(
        features, {'example_proto': example_bytestring})


def json_serving_input_fn():
    """Build the serving inputs for JSON: one placeholder per input column."""
    inputs = {}
    for feat in INPUT_COLUMNS:
        inputs[feat.name] = tf.placeholder(shape=[None], dtype=feat.dtype)
    return tf.estimator.export.ServingInputReceiver(inputs, inputs)


SERVING_FUNCTIONS = {
    'JSON': json_serving_input_fn,
    'EXAMPLE': example_serving_input_fn,
    'CSV': csv_serving_input_fn
}

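# Run the training job.
# learn_runner pulls configuration information from environment
# variables using tf.contrib.learn.RunConfig and uses this configuration
# to conditionally execute the Experiment or the parameter server code.
# The export strategy is built from whichever serving input function
# matches args.export_format.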
learn_runner.run(
  generate_experiment_fn(
      min_eval_frequency=args.min_eval_frequency,
      eval_delay_secs=args.eval_delay_secs,
      train_steps=args.train_steps,
      eval_steps=args.eval_steps,
      export_strategies=[saved_model_export_utils.make_export_strategy(
          SERVING_FUNCTIONS[args.export_format],
          exports_to_keep=1
      )]
  ),
  run_config=tf.contrib.learn.RunConfig(model_dir=args.job_dir),
  hparams=hparam.HParams(**args.__dict__)
)
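
The `exports_to_keep=1` argument above asks the export strategy to retain only the most recent export. As a rough, TensorFlow-free illustration of that retention idea (not the library's actual implementation), the sketch below assumes each export lands in a numerically named (timestamp) subdirectory of a base directory and deletes all but the newest ones. The names `prune_exports` and `demo`, and the timestamps used, are hypothetical and introduced only for this illustration.

```python
import os
import shutil
import tempfile


def prune_exports(export_base, exports_to_keep):
    """Delete all but the newest `exports_to_keep` export directories.

    Assumption: every export is a subdirectory of `export_base` whose name
    is a purely numeric timestamp. Returns the kept names, oldest first.
    """
    # Ignore anything that is not a numerically named directory.
    versions = sorted(
        (name for name in os.listdir(export_base)
         if name.isdigit() and os.path.isdir(os.path.join(export_base, name))),
        key=int,
    )
    # Assumption for this sketch: None means "keep everything".
    if exports_to_keep is None:
        return versions
    keep = versions[-exports_to_keep:] if exports_to_keep > 0 else []
    for name in versions:
        if name not in keep:
            shutil.rmtree(os.path.join(export_base, name))
    return keep


def demo():
    """Create three fake exports, prune to one, and report what is left."""
    with tempfile.TemporaryDirectory() as export_base:
        for stamp in ("1500000000", "1500000600", "1500001200"):
            os.makedirs(os.path.join(export_base, stamp))
        kept = prune_exports(export_base, exports_to_keep=1)
        remaining = sorted(os.listdir(export_base))
    return kept, remaining


if __name__ == "__main__":
    print(demo())
```

With `exports_to_keep=1`, only the directory with the largest timestamp survives, which matches the intent of keeping a single, latest SavedModel under the job directory.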
