Why does training with tf.GradientTape in TensorFlow 2 behave differently from training with the fit API?



I'm new to TensorFlow 2.

I'm familiar with using Keras in TensorFlow 1, and I usually train models with the fit API. But TensorFlow 2 introduced eager execution, so I implemented a simple image classifier on the CIFAR-10 dataset and compared training it for 20 epochs with fit versus with tf.GradientTape.

After several runs, the results were as follows:

  • Model trained with the fit API
    • Training set: loss around 0.61-0.65, accuracy 76%-80%
    • Validation set: loss around 0.8, accuracy 72%-75%
  • Model trained with tf.GradientTape
    • Training set: loss around 0.15-0.2, accuracy 91%-94%
    • Validation set: loss around 1.8-2, accuracy 64%-67%

I don't know why the two models behave so differently; I suspect I did something wrong. What strikes me as odd is that with tf.GradientTape the model starts overfitting the training set much sooner.

Here are some snippets (a sketch of the data preparation they assume follows after them):

  1. Using the fit API
model = SimpleClassifier(10)
model.compile(
    optimizer=Adam(),
    loss=tf.keras.losses.CategoricalCrossentropy(),
    metrics=[tf.keras.metrics.CategoricalAccuracy()]
)
model.fit(
    X[:split_idx, :, :, :], y[:split_idx, :],
    batch_size=256, epochs=20,
    validation_data=(X[split_idx:, :, :, :], y[split_idx:, :])
)
  2. Using tf.GradientTape
with tf.GradientTape() as tape:
    y_pred = model(tf.stop_gradient(train_X))
    loss = loss_fn(train_y, y_pred)
    gradients = tape.gradient(loss, model.trainable_weights)
model.optimizer.apply_gradients(zip(gradients, model.trainable_weights))
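
For context, both snippets assume X, y, and split_idx already exist. A minimal way to prepare them from CIFAR-10 (an illustrative sketch; the preprocessing in the actual Colab notebook may differ):

import numpy as np
import tensorflow as tf

# Load CIFAR-10 and build the (X, y) arrays that the snippets slice at split_idx.
(train_X, train_y), _ = tf.keras.datasets.cifar10.load_data()
X = train_X.astype("float32") / 255.0           # scale pixels to [0, 1]
y = tf.keras.utils.to_categorical(train_y, 10)  # one-hot labels for CategoricalCrossentropy
split_idx = int(0.9 * X.shape[0])               # e.g. 90% train / 10% validation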

The full code can be found here in Colab.

1 Answer

There are a few things in the tf.GradientTape code that should be fixed:
1) Use trainable_variables, not trainable_weights. You want to apply the gradients to all trainable variables, not just the model weights:

# gradients = tape.gradient(loss, model.trainable_weights)
gradients = tape.gradient(loss, model.trainable_variables)

# and

# model.optimizer.apply_gradients(zip(gradients, model.trainable_weights))
model.optimizer.apply_gradients(zip(gradients, model.trainable_variables))

2) Remove tf.stop_gradient from the input tensor:

with tf.GradientTape() as tape:
#    y_pred = model(tf.stop_gradient(train_X))
    y_pred = model(train_X, training=True)

Note that I also added the training argument. It should be handled in the model definition as well, so that layers whose behavior depends on the phase (such as BatchNormalization and Dropout) are included correctly:

    def call(self, X, training=None):
        X = self.cnn_1(X)
        X = self.bn_1(X, training=training)
        X = self.cnn_2(X)
        X = self.max_pool_2d(X)
        X = self.dropout_1(X)

        X = self.cnn_3(X)
        X = self.bn_2(X, training=training)
        X = self.cnn_4(X)
        X = self.bn_3(X, training=training)
        X = self.cnn_5(X)
        X = self.max_pool_2d(X)
        X = self.dropout_2(X)

        X = self.flatten(X)
        X = self.dense_1(X)
        X = self.dropout_3(X, training=training)
        X = self.dense_2(X)
        return self.out(X)
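
For completeness, the layer attributes referenced in call would have to be created in __init__. A minimal sketch consistent with the method above (the filter counts, units, and dropout rates are illustrative assumptions, not taken from the original model):

class SimpleClassifier(tf.keras.Model):
    def __init__(self, num_classes):
        super().__init__()
        # Layer names mirror the call() method above; the hyperparameters are guesses.
        self.cnn_1 = tf.keras.layers.Conv2D(32, 3, padding="same", activation="relu")
        self.bn_1 = tf.keras.layers.BatchNormalization()
        self.cnn_2 = tf.keras.layers.Conv2D(32, 3, padding="same", activation="relu")
        self.max_pool_2d = tf.keras.layers.MaxPooling2D()  # stateless, reused twice in call()
        self.dropout_1 = tf.keras.layers.Dropout(0.25)
        self.cnn_3 = tf.keras.layers.Conv2D(64, 3, padding="same", activation="relu")
        self.bn_2 = tf.keras.layers.BatchNormalization()
        self.cnn_4 = tf.keras.layers.Conv2D(64, 3, padding="same", activation="relu")
        self.bn_3 = tf.keras.layers.BatchNormalization()
        self.cnn_5 = tf.keras.layers.Conv2D(64, 3, padding="same", activation="relu")
        self.dropout_2 = tf.keras.layers.Dropout(0.25)
        self.flatten = tf.keras.layers.Flatten()
        self.dense_1 = tf.keras.layers.Dense(512, activation="relu")
        self.dropout_3 = tf.keras.layers.Dropout(0.5)
        self.dense_2 = tf.keras.layers.Dense(128, activation="relu")
        self.out = tf.keras.layers.Dense(num_classes, activation="softmax")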

With these few changes I got slightly better scores, much more comparable to the keras fit results:

[19/20] loss: 0.64020, acc: 0.76965, val_loss: 0.71291, val_acc: 0.75318: 100%|██████████| 137/137 [00:12<00:00, 11.25it/s]
[20/20] loss: 0.62999, acc: 0.77649, val_loss: 0.77925, val_acc: 0.73219: 100%|██████████| 137/137 [00:12<00:00, 11.30it/s]

So the answer is: the difference probably comes down to Keras fit doing these things under the hood.
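
To make that concrete, here is roughly what a single fit training step amounts to (a simplified sketch; the real Model.train_step also handles sample weights, regularization losses, and compiled metrics):

import tensorflow as tf

@tf.function
def train_step(model, optimizer, loss_fn, x, y):
    # fit() calls the model in training mode, so BatchNormalization uses batch
    # statistics and Dropout is active -- exactly what the fixes above restore.
    with tf.GradientTape() as tape:
        y_pred = model(x, training=True)
        loss = loss_fn(y, y_pred)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    return loss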

Finally, for clarity and reproducibility, here is part of the training/evaluation code I used:

for bIdx, (train_X, train_y) in enumerate(train_batch):
    if bIdx < epoch_max_iter:
        with tf.GradientTape() as tape:
            y_pred = model(train_X, training=True)
            loss = loss_fn(train_y, y_pred)
            total_loss += (np.sum(loss.numpy()) * train_X.shape[0])
            total_num += train_X.shape[0]
            # gradients = tape.gradient(loss, model.trainable_weights)
            gradients = tape.gradient(loss, model.trainable_variables)
        total_acc += (metrics(train_y, y_pred) * train_X.shape[0])

        running_loss = (total_loss/total_num)
        running_acc = (total_acc/total_num)
        # model.optimizer.apply_gradients(zip(gradients, model.trainable_weights))
        model.optimizer.apply_gradients(zip(gradients, model.trainable_variables))

        pbar.set_description("[{}/{}] loss: {:.5f}, acc: {:.5f}".format(e, epochs, running_loss, running_acc))
        pbar.refresh()
        pbar.update()
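
As an aside, the manual bookkeeping with total_loss, total_acc and total_num could be replaced by stateful Keras metric objects, which keep the running averages themselves; a sketch assuming loss_fn and the data pipeline from above:

# Stateful metrics accumulate across batches; reset them at the start of each epoch.
train_loss = tf.keras.metrics.Mean(name="loss")
train_acc = tf.keras.metrics.CategoricalAccuracy(name="acc")

for train_X, train_y in train_batch:
    with tf.GradientTape() as tape:
        y_pred = model(train_X, training=True)
        loss = loss_fn(train_y, y_pred)
    gradients = tape.gradient(loss, model.trainable_variables)
    model.optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    train_loss.update_state(loss)
    train_acc.update_state(train_y, y_pred)

print("loss: {:.5f}, acc: {:.5f}".format(
    float(train_loss.result()), float(train_acc.result())))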

And the evaluation loop:

# Eval loop
# Calculate something wrong here
val_total_loss = 0
val_total_acc = 0
total_val_num = 0
for bIdx, (val_X, val_y) in enumerate(val_batch):
    if bIdx >= max_val_iterations:
        break
    y_pred = model(val_X, training=False)
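
The snippet breaks off here; a completion consistent with the training loop's bookkeeping (my reconstruction, not necessarily what the notebook does) could look like:

    # Inference mode above: Dropout is disabled, BatchNormalization uses moving statistics
    loss = loss_fn(val_y, y_pred)
    val_total_loss += (np.sum(loss.numpy()) * val_X.shape[0])
    val_total_acc += (metrics(val_y, y_pred) * val_X.shape[0])
    total_val_num += val_X.shape[0]

# Per-epoch validation averages, weighted by batch size as in training
val_loss = val_total_loss / total_val_num
val_acc = val_total_acc / total_val_num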
