有 Java 编程相关的问题?

你可以在下面搜索框中键入要查询的问题!

Tensorflow Java多GPU推理

我有一台带有多个GPU的服务器,希望在java应用程序的模型推理过程中充分利用它们。 默认情况下,tensorflow会占用所有可用的GPU,但只使用第一个GPU

我可以想出三种方法来解决这个问题:

  1. 在进程级别限制设备可见性,即使用CUDA_VISIBLE_DEVICES环境变量

    这需要我运行几个java应用程序实例,并在它们之间分配流量。这不是个诱人的主意

  2. 在一个应用程序中启动几个会话,并尝试通过ConfigProto为每个会话分配一个设备:

    public class DistributedPredictor {
    
        private Predictor[] nested;
        private int[] counters;
    
        // ...
    
        public DistributedPredictor(String modelPath, int numDevices, int numThreadsPerDevice) {
            nested = new Predictor[numDevices];
            counters = new int[numDevices];
    
            for (int i = 0; i < nested.length; i++) {
                nested[i] = new Predictor(modelPath, i, numDevices, numThreadsPerDevice);
            }
        }
    
        public Prediction predict(Data data) {
            int i = acquirePredictorIndex();
            Prediction result = nested[i].predict(data);
            releasePredictorIndex(i);
            return result;
        }
    
        private synchronized int acquirePredictorIndex() {
            int i = argmin(counters);
            counters[i] += 1;
            return i;
        }
    
        private synchronized void releasePredictorIndex(int i) {
            counters[i] -= 1;
        }
    }
    
    
    public class Predictor {
    
        private Session session;
    
        public Predictor(String modelPath, int deviceIdx, int numDevices, int numThreadsPerDevice) {
    
            GPUOptions gpuOptions = GPUOptions.newBuilder()
                    .setVisibleDeviceList("" + deviceIdx)
                    .setAllowGrowth(true)
                    .build();
    
            ConfigProto config = ConfigProto.newBuilder()
                    .setGpuOptions(gpuOptions)
                    .setInterOpParallelismThreads(numDevices * numThreadsPerDevice)
                    .build();
    
            byte[] graphDef = Files.readAllBytes(Paths.get(modelPath));
            Graph graph = new Graph();
            graph.importGraphDef(graphDef);
    
            this.session = new Session(graph, config.toByteArray());
        }
    
        public Prediction predict(Data data) {
            // ...
        }
    }
    

    这种方法乍一看似乎效果不错。然而,会话偶尔会忽略setVisibleDeviceList选项,所有会话都会针对第一个设备,导致内存不足崩溃

  3. 使用tf.device()规范以python中的多塔方式构建模型。在java方面,在共享会话中为不同的Predictor提供不同的塔

    我觉得自己很笨重,而且习惯上是错误的

更新:按照@ash的建议,还有另一个选项:

  1. 通过修改现有图形的定义(graphDef),为其每个操作分配适当的设备

    为了完成这项工作,可以修改方法2的代码:

    public class Predictor {
    
        private Session session;
    
        public Predictor(String modelPath, int deviceIdx, int numDevices, int numThreadsPerDevice) {
    
            byte[] graphDef = Files.readAllBytes(Paths.get(modelPath));
            graphDef = setGraphDefDevice(graphDef, deviceIdx)
    
            Graph graph = new Graph();
            graph.importGraphDef(graphDef);
    
            ConfigProto config = ConfigProto.newBuilder()
                    .setAllowSoftPlacement(true)
                    .build();
    
            this.session = new Session(graph, config.toByteArray());
        }
    
        private static byte[] setGraphDefDevice(byte[] graphDef, int deviceIdx) throws InvalidProtocolBufferException {
            String deviceString = String.format("/gpu:%d", deviceIdx);
    
            GraphDef.Builder builder = GraphDef.parseFrom(graphDef).toBuilder();
            for (int i = 0; i < builder.getNodeCount(); i++) {
                builder.getNodeBuilder(i).setDevice(deviceString);
            }
            return builder.build().toByteArray();
        }
    
        public Prediction predict(Data data) {
            // ...
        }
    }
    

    就像前面提到的其他方法一样,这个方法并没有让我摆脱在设备之间手动分发数据的束缚。但至少它工作稳定,相对容易实现。总的来说,这看起来是(几乎)正常的技术

使用tensorflow java API有没有一种优雅的方法来完成这样的基本任务?任何想法都将不胜感激


共 (2) 个答案

  1. # 1 楼答案

    简言之:有一个变通方法,每个GPU有一个会话

    详情:

    一般流程是TensorFlow运行时尊重为图中的操作指定的设备。如果没有为一个操作指定设备,那么它会根据一些试探法“放置”它。这些启发式方法目前导致“在GPU上放置操作:0,如果GPU可用,并且操作有GPU内核”(^{},如果您感兴趣的话)

    我认为您所要求的是TensorFlow的合理功能要求——能够将序列化图中的设备视为“虚拟”设备,以便在运行时映射到一组“phyiscal”设备,或者设置“默认设备”。此功能当前不存在。向ConfigProto添加这样一个选项可能需要提交一个功能请求

    我可以提出一个临时解决办法。首先,对你提出的解决方案进行一些评论

    1. 你的第一个想法肯定会奏效,但正如你所指出的,很麻烦

    2. ConfigProto中使用visible_device_list的设置不太有效,因为这实际上是每个进程的设置,在进程中创建第一个会话后会被忽略。这当然没有得到应有的记录(不幸的是,这出现在每会话配置中)。然而,这解释了为什么你在这里的建议不起作用,为什么你仍然看到一个单一的GPU正在使用

    3. 这可能管用

    另一种选择是以不同的图形结束(操作显式地放在不同的GPU上),结果是每个GPU有一个会话。类似的内容可用于编辑图形并为每个操作显式分配设备:

    public static byte[] modifyGraphDef(byte[] graphDef, String device) throws Exception {
      GraphDef.Builder builder = GraphDef.parseFrom(graphDef).toBuilder();
      for (int i = 0; i < builder.getNodeCount(); ++i) {
        builder.getNodeBuilder(i).setDevice(device);
      }
      return builder.build().toByteArray();
    } 
    

    之后,您可以使用以下方法为每个GPU创建GraphSession

    final int NUM_GPUS = 8;
    // setAllowSoftPlacement: Just in case our device modifications were too aggressive
    // (e.g., setting a GPU device on an operation that only has CPU kernels)
    // setLogDevicePlacment: So we can see what happens.
    byte[] config =
        ConfigProto.newBuilder()
            .setLogDevicePlacement(true)
            .setAllowSoftPlacement(true)
            .build()
            .toByteArray();
    Graph graphs[] = new Graph[NUM_GPUS];
    Session sessions[] = new Session[NUM_GPUS];
    for (int i = 0; i < NUM_GPUS; ++i) {
      graphs[i] = new Graph();
      graphs[i].importGraphDef(modifyGraphDef(graphDef, String.format("/gpu:%d", i)));
      sessions[i] = new Session(graphs[i], config);    
    }
    

    然后使用sessions[i]在GPU#i上执行图形

    希望有帮助

  2. # 2 楼答案

    在python中,可以按如下方式执行:

    def get_frozen_graph(graph_file):
        """Read Frozen Graph file from disk."""
        with tf.gfile.GFile(graph_file, "rb") as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
        return graph_def
    
    trt_graph1 = get_frozen_graph('/home/ved/ved_1/frozen_inference_graph.pb')
    
    with tf.device('/gpu:1'):
        [tf_input_l1, tf_scores_l1, tf_boxes_l1, tf_classes_l1, tf_num_detections_l1, tf_masks_l1] = tf.import_graph_def(trt_graph1, 
                        return_elements=['image_tensor:0', 'detection_scores:0', 
                        'detection_boxes:0', 'detection_classes:0','num_detections:0', 'detection_masks:0'])
        
    tf_sess1 = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
    
    trt_graph2 = get_frozen_graph('/home/ved/ved_2/frozen_inference_graph.pb')
    
    with tf.device('/gpu:0'):
        [tf_input_l2, tf_scores_l2, tf_boxes_l2, tf_classes_l2, tf_num_detections_l2] = tf.import_graph_def(trt_graph2, 
                        return_elements=['image_tensor:0', 'detection_scores:0', 
                        'detection_boxes:0', 'detection_classes:0','num_detections:0'])
        
    tf_sess2 = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))