张量流模型在大尺度上需要指数时间

2024-09-30 18:29:39 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在使用TensorFlow为诗人检测服装图像中的特征。我训练了4个不同的模特(袖子、形状、长度和裙摆)。现在我将图像url传递给每个模型并存储结果。因为我有大量的数据(100k图像),所以使用spark一次播放4个模型,并通过图像RDD来检测特征。这需要指数级的时间。它从3秒/图像开始,并不断增加执行时间。当脚本已经检测到10k图像时,它的执行时间达到8秒/图像。我是新来的Tensorflow,将非常感谢如果有任何想法,使执行时间线性。你知道吗

def getLabelDresses(file_name):
    resultDict = {}
    t = read_tensor_from_image_file(file_name,
                              input_height=input_height,
                              input_width=input_width,
                              input_mean=input_mean,
                              input_std=input_std)
    input_name = "import/" + input_layer
    output_name = "import/" + output_layer



    with tf.Graph().as_default() as g:

        graph_def = tf.GraphDef()

        graph_def.ParseFromString(model_data_hemline.value)

        tf.import_graph_def(graph_def)

        input_operation_hemline = g.get_operation_by_name(input_name);
        output_operation_hemline = g.get_operation_by_name(output_name);

        with tf.Session() as sess:
            results = sess.run(output_operation_hemline.outputs[0],{input_operation_hemline.outputs[0]: t})

        results = np.squeeze(results)

        top_k = results.argsort()[-1:][::-1]
        labels = load_labels(label_file_hemline)
        resultDict['hemline'] = labels[top_k[0]]

    with tf.Graph().as_default() as g:

        graph_def = tf.GraphDef()

        graph_def.ParseFromString(model_data_shape.value)

        tf.import_graph_def(graph_def)

        input_operation_shape = g.get_operation_by_name(input_name);
        output_operation_shape = g.get_operation_by_name(output_name);

        with tf.Session() as sess:
            results = sess.run(output_operation_shape.outputs[0],{input_operation_shape.outputs[0]: t})

        results = np.squeeze(results)

        top_k = results.argsort()[-1:][::-1]
        labels = load_labels(label_file_shape)
        resultDict['shape'] = labels[top_k[0]]

    with tf.Graph().as_default() as g:

        graph_def = tf.GraphDef()

        graph_def.ParseFromString(model_data_length.value)

        tf.import_graph_def(graph_def)

        input_operation_length = g.get_operation_by_name(input_name);
        output_operation_length = g.get_operation_by_name(output_name);

        with tf.Session() as sess:
            results = sess.run(output_operation_length.outputs[0],{input_operation_length.outputs[0]: t})

        results = np.squeeze(results)

        top_k = results.argsort()[-1:][::-1]
        labels = load_labels(label_file_length)
        resultDict['length'] = labels[top_k[0]]

    with tf.Graph().as_default() as g:

        graph_def = tf.GraphDef()

        graph_def.ParseFromString(model_data_sleeve.value)

        tf.import_graph_def(graph_def)

        input_operation_sleeve = g.get_operation_by_name(input_name);
        output_operation_sleeve = g.get_operation_by_name(output_name);

        with tf.Session() as sess:
            results = sess.run(output_operation_sleeve.outputs[0],{input_operation_sleeve.outputs[0]: t})

        results = np.squeeze(results)

        top_k = results.argsort()[-1:][::-1]
        labels = load_labels(label_file_sleeve)
        resultDict['sleeve'] = labels[top_k[0]]     

    return resultDict;


model_file_hemline = "/home/ubuntu/amit/feature-sensing/tf_files/attribute/dresses/hemline/retrained_graph_hemline.pb"
label_file_hemline = "/home/ubuntu/amit/feature-sensing/tf_files/attribute/dresses/hemline/retrained_labels_hemline.txt"
model_file_length = "/home/ubuntu/amit/feature-sensing/tf_files/attribute/dresses/length/retrained_graph_length.pb"
label_file_length = "/home/ubuntu/amit/feature-sensing/tf_files/attribute/dresses/length/retrained_labels_length.txt"
model_file_shape = "/home/ubuntu/amit/feature-sensing/tf_files/attribute/dresses/shape/retrained_graph_shape.pb"
label_file_shape = "/home/ubuntu/amit/feature-sensing/tf_files/attribute/dresses/shape/retrained_labels_shape.txt"
model_file_sleeve = "/home/ubuntu/amit/feature-sensing/tf_files/attribute/dresses/sleeve/retrained_graph_sleeve.pb"
label_file_sleeve = "/home/ubuntu/amit/feature-sensing/tf_files/attribute/dresses/sleeve/retrained_labels_sleeve.txt"

with gfile.FastGFile(model_file_hemline, "rb") as f:
    model_data = f.read()
    model_data_hemline = sc.broadcast(model_data)
with gfile.FastGFile(model_file_length, "rb") as f:
    model_data = f.read()
    model_data_length = sc.broadcast(model_data)
with gfile.FastGFile(model_file_shape, "rb") as f:
    model_data = f.read()
    model_data_shape = sc.broadcast(model_data)
with gfile.FastGFile(model_file_sleeve, "rb") as f:
    model_data = f.read()
    model_data_sleeve = sc.broadcast(model_data)

def calculate(row):
    path = "/tmp/"+row.guid
    url = row.modelno
    print(path, url)
    if(url is not None):
        import urllib.request
        urllib.request.urlretrieve(url, path)
        t1=time.time() 
        result = getLabelDresses(path)
        print(time.time()-t1)
        print(result)
        return row
    return row

product2.rdd.map(calculate).collect()

Tags: nameinputdatalabelsmodeltfdefas
1条回答
网友
1楼 · 发布于 2024-09-30 18:29:39

代码中对getLabelDresses的每次调用都会向图形中添加操作。你知道吗

将代码分为设置(模型加载)部分,执行一次,以及为每个图像执行的执行部分。后者应该只包含对Session.run的调用。你知道吗

另一个选择是在使用^{}处理下一个图像之前清除图形。但这不太可取。你知道吗

相关问题 更多 >