Tensorflow:如何在QueueRunn中使用“新”数据集API

2024-10-06 13:35:06 发布

您现在位置:Python中文网/ 问答频道 /正文

基本上我有一个要处理的图像列表。 我需要在加载后做一些预处理(数据扩充),然后输入到TF的主图中。 目前我正在使用一个定制的生成器,它获取一个路径列表,生成一对张量(图像),并通过占位符向网络馈送。每个批次的顺序处理时间约为0.5s。在

我刚刚阅读了datasetAPI,可以通过使用.from_generator()函数直接使用它,并且可以直接使用.get_next()作为输入。在

但是QueueRunner如何融入框架中呢?Dataset是否隐式地利用queue+dequeue来维护其generator/get_next管道,或者它要求我以后显式地馈送到FIFOQueue中?如果答案是后一个,那么维护管道以训练+验证多个random_shuffle时间段的最佳实践是什么?(我的意思是,我需要维护多少DS/queueRunner,在哪里设置洗牌和时间?)在


Tags: 数据from图像路径网络列表get管道
1条回答
网友
1楼 · 发布于 2024-10-06 13:35:06

如果使用数据集API,则不必使用QueueRunner来拥有队列/缓冲区。可以使用数据集API创建队列/缓冲区,并对数据进行预处理并并行训练网络。如果有数据集,可以使用prefetch functionshuffle function创建队列/缓冲区。在

有关更多信息,请参阅official tutorial on the Dataset API。在

以下是在CPU上使用预处理的预取缓冲区的示例:

 NUM_THREADS = 8
 BUFFER_SIZE = 100

 data = ...
 labels = ...
 inputs = (data, labels)

 def pre_processing(data_, labels_):
     with tf.device("/cpu:0"):
         # do some pre-processing here
         return data_, labels_

 dataset_source = tf.data.Dataset.from_tensor_slices(inputs)
 dataset = dataset_source.map(pre_processing, num_parallel_calls=NUM_THREADS)

 dataset = dataset.repeat(1)  # repeats for one epoch
 dataset = dataset.prefetch(BUFFER_SIZE)

 iterator = tf.data.Iterator.from_structure(dataset.output_types,
                                            dataset.output_shapes)
 next_element = iterator.get_next()
 init_op = iterator.make_initializer(dataset)

 with tf.Session() as sess:
     sess.run(init_op)
     while True:
         try:
             sess.run(next_element)
         except tf.errors.OutOfRangeError:
             break

相关问题 更多 >