看到了在张量流图上打开XLA的潜在(和高度实验性的)好处,我想我应该尝试一下。在
我可以标记JIT XLA的某些操作
with tf.device("/job:localhost/replica:0/task:0/device:XLA_GPU:0"):
output = tf.add(input1, input2)
然而,文件警告说,这仅仅是为了测试。我想用推荐的方法来做这个
^{pr2}$但是我想不出从model_fn
外部设置tf.Session
的方法。在
此伪代码可以更好地澄清问题:
def char_cnn_model(features, target, mode, params, model_dir):
# do stuff to build the CNN
...
return tf.contrib.learn.ModelFnOps(mode=mode,
predictions=predictions_dict,
loss=loss,
train_op=train_op,
eval_metric_ops=eval_metric_ops)
def main(unused_argv):
# load training data, test data etc
...
classifier = learn.Estimator(
model_fn=char_cnn_model, # defined above
model_dir=model_dir,
params=params,
config=tf.contrib.learn.RunConfig(save_checkpoints_secs=60,
log_device_placement=True,
tf_random_seed=7))
classifier.fit(
X_train,
y_train,
steps=5000,
monitors=[validation_monitor]) # validation_monitor defined in main
accuracy_score = classifier.evaluate(x=X_test, y=y_test)
^{
我可能想得太多了,^{
现在可以通过RunConfig的
session_config
参数来实现。 下面是一个例子:看来还没有。MonitoredSession is being instantiated directly和选项在本地传递。 您唯一的选择是使用XLA配置的受监视会话子类化并重写train方法。在
相关问题 更多 >
编程相关推荐