使用时启用XLAtf.contrib.学习.估算

2024-06-01 08:20:15 发布

您现在位置:Python中文网/ 问答频道 /正文

看到了在张量流图上打开XLA的潜在(和高度实验性的)好处,我想我应该尝试一下。在

问题:使用^{}时,如何启用JIT XLA?在

我可以标记JIT XLA的某些操作

with tf.device("/job:localhost/replica:0/task:0/device:XLA_GPU:0"):
    output = tf.add(input1, input2)

然而,文件警告说,这仅仅是为了测试。我想用推荐的方法来做这个

^{pr2}$

但是我想不出从model_fn外部设置tf.Session的方法。在

此伪代码可以更好地澄清问题:

def char_cnn_model(features, target, mode, params, model_dir):
    # do stuff to build the CNN

    ...

    return tf.contrib.learn.ModelFnOps(mode=mode,
        predictions=predictions_dict,
        loss=loss,
        train_op=train_op,
        eval_metric_ops=eval_metric_ops)

def main(unused_argv):
    # load training data, test data etc

    ...

    classifier = learn.Estimator(
        model_fn=char_cnn_model,  # defined above
        model_dir=model_dir,
        params=params,
        config=tf.contrib.learn.RunConfig(save_checkpoints_secs=60,
                                          log_device_placement=True,
                                          tf_random_seed=7))
    classifier.fit(
        X_train,
        y_train,
        steps=5000,
        monitors=[validation_monitor])  # validation_monitor defined in main

    accuracy_score = classifier.evaluate(x=X_test, y=y_test)                                      

^{}似乎是一个很好的候选者,但是它没有为会话提供一些内容(我想这是有道理的,为什么RunConfig会有一个指向现有会话的指针?)在

我可能想得太多了,^{}可能是我所需要的,但是我可以在会话创建后修改它的配置吗?在


Tags: 方法testmodelmodedevicetfdefdir
2条回答

现在可以通过RunConfigsession_config参数来实现。 下面是一个例子:

session_config = tf.ConfigProto()                                               
optimizer_options = session_config.graph_options.optimizer_options               
if xla:                                                     
    optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1           

run_config = tf.estimator.RunConfig(                                        
    session_config=session_config,  # use this session config
    log_device_placement=True                                           
)                                                                           

base_classifier = tf.estimator.Estimator(                                   
    model_fn=model_fn,                                                      
    model_dir=model_dir,                                                    
    config=run_config,                                                      
    params=model_params                                                                
)                          

看来还没有。MonitoredSession is being instantiated directly和选项在本地传递。 您唯一的选择是使用XLA配置的受监视会话子类化并重写train方法。在

相关问题 更多 >