Gradients are zero when manually updating weights in PyTorch

Posted 2024-10-01 22:38:08

I'm trying to implement a simple neural network that updates the weights for MNIST manually with autograd, similar to the autograd example given here. This is my code:

import os
import sys

import torch
import torchvision

class Datasets:
    """Helper for extracting datasets."""

    def __init__(self, root='data/', batch_size=25):
        if not os.path.exists(root):
            os.mkdir(root)
        self.root = root
        self.batch_size = batch_size

    def get_mnist_loaders(self):
        train_data = torchvision.datasets.MNIST(
                root=self.root, train=True, download=True)
        test_data = torchvision.datasets.MNIST(
                root=self.root, train=False, download=True)


        train_loader = torch.utils.data.DataLoader(
                dataset=train_data, batch_size=self.batch_size, shuffle=True)
        test_loader = torch.utils.data.DataLoader(
                dataset=test_data, batch_size=self.batch_size, shuffle=False)

        return train_loader, test_loader

    def create_batches(self, data, labels, batch_size):
        return [(data[i:i+batch_size], labels[i:i+batch_size])
            for i in range(0, len(data), max(1, batch_size))]

def train1():
    dtype = torch.float
    n_inputs = 28*28
    n_hidden1 = 300
    n_hidden2 = 100
    n_outputs = 10
    batch_size = 200
    n_epochs = 25
    learning_rate = 0.01
    test_step = 100 
    device = torch.device("cpu")

    datasets = Datasets(batch_size=batch_size)
    train_loader, test_loader = datasets.get_mnist_loaders()

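    # two-hidden-layer MLP: 784 -> 300 -> 100 -> 10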
    def feed_forward(X):
        x_shape = list(X.size())
        X = X.view(x_shape[0], x_shape[1]*x_shape[2])
        hidden1 = torch.mm(X, w1)
        hidden1 += b1
        hidden1 = hidden1.clamp(min=0)
        hidden2 = torch.mm(hidden1, w2) + b2
        hidden2 = hidden2.clamp(min=0)
        logits = torch.mm(hidden2, w3) + b3
        softmax = pytorch_softmax(logits)
        return softmax

    def accuracy(y_pred, y):

        if list(y_pred.size()) != list(y.size()):
            raise ValueError('Inputs have different shapes.')

        total_correct = 0
        total = 0
        for i, (y1, y2) in enumerate(zip(y_pred, y)):
            if y1 == y2:
                total_correct += 1
            total += 1

        return total_correct / total

    w1 = torch.randn(n_inputs, n_hidden1, device=device, dtype=dtype, requires_grad=True)
    b1 = torch.nn.Parameter(torch.zeros(n_hidden1), requires_grad=True)

    w2 = torch.randn(n_hidden1, n_hidden2, requires_grad=True)
    b2 = torch.nn.Parameter(torch.zeros(n_hidden2), requires_grad=True)

    w3 = torch.randn(n_hidden2, n_outputs, dtype=dtype, requires_grad=True)
    b3 = torch.nn.Parameter(torch.zeros(n_outputs), requires_grad=True)

    pytorch_softmax = torch.nn.Softmax(0)
    pytorch_cross_entropy = torch.nn.CrossEntropyLoss(reduction='mean')

    step = 0
    for epoch in range(n_epochs):
        batches = datasets.create_batches(train_loader.dataset.train_data,
                                          train_loader.dataset.train_labels,
                                          batch_size)
        for x, y in batches:
            step += 1

            softmax = feed_forward(x.float())
            vals, y_pred = torch.max(softmax, 1)
            accuracy_ = accuracy(y_pred, y)
            cross_entropy = pytorch_cross_entropy(softmax, y)

            print(epoch, step, cross_entropy.item(), accuracy_)

            cross_entropy.backward()

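            # manual SGD step; the grads are zeroed afterwards so they don't accumulate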
            with torch.no_grad():
                w1 -= learning_rate * w1.grad
                w2 -= learning_rate * w2.grad
                w3 -= learning_rate * w3.grad

                b1 -= learning_rate * b1.grad
                b2 -= learning_rate * b2.grad
                b3 -= learning_rate * b3.grad

                w1.grad.zero_()
                w2.grad.zero_()
                w3.grad.zero_()

                b1.grad.zero_()
                b2.grad.zero_()
                b3.grad.zero_()

if __name__ == '__main__':
    train1()

However, the network doesn't seem to train. When I print some of the gradients (e.g. w1.grad.data[:10, :10]), they consist of zeros. I've tried using weight.data and weight.grad.data to update the weights, and I've tried removing the w.grad.zero_() part (even though it appears in the example too), but that didn't help. What's wrong here?


2 answers

There are three problems here.

First, you are applying softmax over the wrong axis. It should be the last axis:

pytorch_softmax = torch.nn.Softmax(-1)
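
A quick way to see the difference, using a hypothetical random batch of 4 samples and 10 classes:

import torch

logits = torch.randn(4, 10)  # 4 samples, 10 classes

# Softmax(0) normalizes down the batch axis: each column sums to 1
print(torch.nn.Softmax(0)(logits).sum(dim=0))   # all ones, one per class

# Softmax(-1) normalizes over the class axis: each row sums to 1
print(torch.nn.Softmax(-1)(logits).sum(dim=1))  # all ones, one per sample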

Second, your logits consist of very large numbers, so the resulting derivatives are vanishingly small, and that is why you see zeros.

Some things you can do include normalizing the data, adding batch normalization, clamping, and so on. I can see that your data X is a tensor whose values range from 0 to 255.
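
For instance, a minimal sketch of the normalization idea, using a random stand-in tensor in place of a raw MNIST batch:

import torch

x = torch.randint(0, 256, (25, 28, 28))  # stand-in for a raw MNIST batch

# Pixels lie in [0, 255]; scaling to [0, 1] keeps the logits (and
# therefore the gradients) in a reasonable range.
x = x.float() / 255.0
print(x.min().item(), x.max().item())  # now in [0, 1]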

Third, you shouldn't wrap your tensors in nn.Parameter, since nn.Parameter is only meant to be used inside nn.Module classes.
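
If you keep the question's manual-update style, the biases can simply be plain leaf tensors, just like the weights; a minimal sketch:

import torch

n_hidden1 = 300

# A plain leaf tensor with requires_grad=True is all that manual
# updates need; nn.Parameter only matters inside an nn.Module, where
# it registers the tensor with module.parameters().
b1 = torch.zeros(n_hidden1, requires_grad=True)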

When you feed the network with tensors, gradients are not computed by default. To make it work, you can either wrap your FloatTensor in torch.autograd.Variable or set the tensor's requires_grad attribute. Here is an example.
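
Since the example linked above isn't reproduced here, a minimal sketch of the requires_grad approach (note that Variable has been deprecated in recent PyTorch versions, so setting the attribute directly is preferred):

import torch

w = torch.randn(3, 3)
w.requires_grad_(True)  # tell autograd to track operations on w

x = torch.randn(1, 3)
loss = (x @ w).sum()
loss.backward()
print(w.grad)  # populated instead of None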
