在中手动更改学习速率火车站

2024-10-01 15:45:18 发布

您现在位置:Python中文网/ 问答频道 /正文

在 问题是,仅仅改变tf.train.AdamOptimizer中的learning_rate参数是否会导致行为的任何改变: 假设代码如下所示:

myLearnRate = 0.001
...
output = tf.someDataFlowGraph
trainLoss = tf.losses.someLoss(output)
trainStep = tf.train.AdamOptimizer(learning_rate=myLearnRate).minimize(trainLoss)
with tf.Session() as session:
    #first trainstep
    session.run(trainStep, feed_dict = {input:someData, target:someTarget})
    myLearnRate = myLearnRate * 0.1
    #second trainstep
    session.run(trainStep, feed_dict = {input:someData, target:someTarget})

减少的myLearnRate现在会应用到第二个trainStep中吗?这就是,节点trainStep的创建是否只评估一次:

^{pr2}$

或者是每隔session.run(train_step)计算一次?我怎么能在Tensorflow中检查我的AdamOptimizer,它是否改变了学习速度。在

免责声明1:我知道手动更改LearnRate是不好的做法。 免责声明2:我知道有一个类似的问题,但是通过输入一个张量作为learnRate来解决,这个问题在每个trainStephere)中更新。它使我倾向于假设它只使用张量作为AdamOptimizer中的learning_rate的输入,但我对此既不确定,也无法理解其背后的推理。在


Tags: runinputoutputratesessiontffeedtrain
2条回答

是的,优化器只创建一次:

tf.train.AdamOptimizer(learning_rate=myLearnRate)

它会记住传递的学习速率(事实上,如果你传递一个浮点数,它会为它创建一个张量),并且你未来对myLearnRate的更改不会影响它。在

是的,您可以创建一个占位符并将其传递给session.run(),如果您真的愿意的话。但是,正如你所说,这是相当罕见的,可能意味着你用错误的方式解决了你的原产地问题。在

在 简单的回答是不,你的新学习率不适用。TF在您第一次运行图时构建它,在Python端更改某些内容不会转化为在运行时改变图形。但是,您可以很容易地将新的学习率输入图表:

# Use a placeholder in the graph for your user-defined learning rate instead
learning_rate = tf.placeholder(tf.float32)
# ...
trainStep = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(trainLoss)
applied_rate = 0.001  # we will update this every training step
with tf.Session() as session:
    #first trainstep, feeding our applied rate to the graph
    session.run(trainStep, feed_dict = {input:someData,
                                        target:someTarget,
                                        learning_rate: applied_rate})
    applied_rate *= 0.1  # update the rate we feed to the graph
    #second trainstep
    session.run(trainStep, feed_dict = {input:someData,
                                        target:someTarget,
                                        learning_rate: applied_rate})

相关问题 更多 >

    热门问题