如何将数据读入Tensorflow?

2024-05-20 02:30:31 发布

您现在位置:Python中文网/ 问答频道 /正文

我试图从CSV文件读取数据到tensorflow

https://www.tensorflow.org/versions/r0.7/how_tos/reading_data/index.html#filenames-shuffling-and-epoch-limits

公文中的示例代码如下:

col1, col2, col3, col4, col5 = tf.decode_csv(value, record_defaults=record_defaults)

要读取文件,我需要事先知道文件中有多少列和行,如果有1000列,我需要定义1000个变量,比如col1, col2, col3, col4, col5,..., col1000 ,,这看起来不是读取数据的有效方法。

我的问题

  1. 将CSV文件读入Tensorflow的最佳方法是什么?

  2. 有没有办法在Tensorflow中读取数据库(比如mongoDB)?


Tags: 文件csv方法httpstensorflowwww读取数据record
3条回答

当然,您可以实现从mongo直接读取批量随机排序训练数据,以馈送给tensorflow。下面是我的方式:

        for step in range(self.steps):


            pageNum=1;
            while(True):
                trainArray,trainLabelsArray = loadBatchTrainDataFromMongo(****)
                if len(trainArray)==0:
                    logging.info("train datas consume up!")
                    break;
                logging.info("started to train")
                sess.run([model.train_op],
                         feed_dict={self.input: trainArray,
                                    self.output: np.asarray(trainLabelsArray),
                                    self.keep_prob: params['dropout_rate']})

                pageNum=pageNum+1;

同时还需要对mongodb中的训练数据进行预处理,例如:为mongodb中的每个训练数据分配一个随机排序值。。。

  1. 你绝对不需要定义col1,col2,到col1000。。。

    一般来说,你可能会这样做:

    
    columns = tf.decode_csv(value, record_defaults=record_defaults)
    features = tf.pack(columns)
    do_whatever_you_want_to_play_with_features(features)
    
  2. 我不知道从MongoDB直接读取数据的现成方法。也许您可以编写一个简短的脚本,以Tensorflow支持的格式转换MongoDB中的数据,我建议使用二进制格式TFRecord,它比csv记录的读取速度快得多。This是一篇关于这个主题的好博客文章。或者您可以选择自己实现自定义的数据读取器,请参见这里的the official doc

def func()
    return 1,2,3,4

b = func() 

print b #(1, 2, 3, 4)

print [num for num in b] # [1, 2, 3, 4]

Hi它与tensorflow无关它的简单python不需要定义1000个变量。decode_csv返回一个元组。

不知道如何处理数据库,我认为您可以使用python,只需将数组形式的数据输入到tensorflow。

希望这对你有帮助

相关问题 更多 >