我如何获得每一个历元而不是每批的损失?

2024-05-17 04:35:08 发布

您现在位置:Python中文网/ 问答频道 /正文

在我的理解中,epoch是对整个数据集的任意重复运行,而这些数据又被分部分处理,即所谓的批处理。在每一个train_on_batch计算完一个损失后,权重将被更新,下一批将得到更好的结果。这些损失是我到NNs的质量和学习状态的指标。在

在几种来源中,损耗是按历元计算(并打印)的。因此,我不确定我是否做得对。在

现在我的GAN是这样的:

for epoch:
  for batch:

    fakes = generator.predict_on_batch(batch)

    dlc = discriminator.train_on_batch(batch, ..)
    dlf = discriminator.train_on_batch(fakes, ..)
    dis_loss_total = 0.5 *  np.add(dlc, dlf)

    g_loss = gan.train_on_batch(batch,..)

    # save losses to array to work with later

这些损失是针对每批产品的。我怎样才能得到它们?顺便说一句:我需要一个时代的损失吗?为了什么?在


Tags: to数据foronbatchtrain权重损失
1条回答
网友
1楼 · 发布于 2024-05-17 04:35:08

没有直接的方法来计算一个时代的损失。实际上,一个时期的损失通常被定义为该时期批次损失的平均值。因此,您可以在一个epoch中累积损失值,并在最后除以epoch中的批数:

epoch_loss = []
for epoch in range(n_epochs):
    acc_loss = 0.
    for batch in range(n_batches):
        # do the training 
        loss = model.train_on_batch(...)
        acc_loss += loss
    epoch_loss.append(acc_loss / n_batches)

至于另一个问题,epoch loss的一个用法可能是用它作为停止训练的一个指标(然而,验证损失通常用于此目的,而不是训练损失)。在

相关问题 更多 >