如何使用生成器训练TensorFlow网络以产生输入?

2024-07-08 10:38:21 发布

您现在位置:Python中文网/ 问答频道 /正文

TensorFlowdocs描述了一系列使用TFRecordReader、TextLineReader、QueueRunner等和队列读取数据的方法。

我想做的事情要简单得多:我有一个python生成器函数,它生成一个无限序列的训练数据作为(X,y)元组(都是numpy数组,第一个维度是批处理大小)。我只想训练一个使用这些数据作为输入的网络。

有没有一个使用生成数据的生成器训练TensorFlow网络的简单示例?(按照MNIST或CIFAR示例的思路)


Tags: 数据方法函数网络numpy示例队列序列
1条回答
网友
1楼 · 发布于 2024-07-08 10:38:21

假设您有一个生成数据的函数:

 def generator(data): 
    ...
    yield (X, y)

现在您需要另一个描述模型架构的函数。它可以是处理X的任何函数,并且必须预测y作为输出(例如,神经网络)。

假设您的函数接受X和y作为输入,以某种方式从X计算y的预测,并返回y和预测y之间的损失函数(例如,在回归的情况下,交叉熵或MSE):

 def neural_network(X, y): 
    # computation of prediction for y using X
    ...
    return loss(y, y_pred)

要使模型正常工作,需要为X和y定义占位符,然后运行会话:

 X = tf.placeholder(tf.float32, shape=(batch_size, x_dim))
 y = tf.placeholder(tf.float32, shape=(batch_size, y_dim))

占位符类似于“free variables”,在通过feed_dict运行会话时需要指定它:

 with tf.Session() as sess:
     # variables need to be initialized before any sess.run() calls
     tf.global_variables_initializer().run()

     for X_batch, y_batch in generator(data):
         feed_dict = {X: X_batch, y: y_batch} 
         _, loss_value, ... = sess.run([train_op, loss, ...], feed_dict)
         # train_op here stands for optimization operation you have defined
         # and loss for loss function (return value of neural_network function)

希望你会觉得有用。但是,请记住,这不是完全工作的实现,而是一个伪代码,因为您几乎没有指定任何细节。

相关问题 更多 >

    热门问题