图神经网络中的极低梯度值

2024-10-01 15:41:26 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在使用Pytorch几何库实现一个回归问题的图形神经网络。该模型定义为:

import torch
from torch.nn.parameter import Parameter
import torch.optim as optim
import torch.nn.functional as F
import torch.nn as nn
from torch_geometric.nn import GCNConv

class Model(nn.Module):
    def __init__(self, nin=1, nhid1=128, nout=128, hid_l=64, out_l=1):
        super(Model, self).__init__()
       
        self.gc1 = GCNConv(in_channels= nin, out_channels= nhid1)
        self.gc2 = GCNConv(in_channels= nhid1, out_channels= nout)
        self.lay1 = nn.Linear(nout ,hid_l)
        self.l0 = nn.Linear(hid_l,hid_l)
        self.l1 = nn.Linear(hid_l,hid_l)
        self.lay2 = nn.Linear(hid_l ,out_l)
        self.active = nn.LeakyReLU(0.1)
       
        with torch.no_grad():
            self.gc1.weight = Parameter(nn.init.uniform_(torch.empty(nin,nhid1),a=0.0,b=1.0))
            self.gc1.bias = Parameter(nn.init.uniform_(torch.empty(nhid1),a=0.0,b=1.0))
            self.gc2.weight = Parameter(nn.init.uniform_(torch.empty(nhid1,nout),a=0.0,b=1.0))
            self.gc2.bias = Parameter(nn.init.uniform_(torch.empty(nout),a=0.0,b=1.0))
            self.lay1.weight = Parameter(nn.init.uniform_(torch.empty(hid_l, nout ),a=0.0,b=1.0))
            self.l0.weight = Parameter(nn.init.uniform_(torch.empty(hid_l, hid_l),a=0.0,b=1.0))
            self.l1.weight = Parameter(nn.init.uniform_(torch.empty(hid_l, hid_l),a=0.0,b=1.0))
            self.lay2.weight = Parameter(nn.init.uniform_(torch.empty(out_l,hid_l),a=0.0,b=1.0))
                       

    def forward(self, features, edge_list):
        x = self.active(self.gc1(features, edge_list))
        x = self.active(self.gc2(x, edge_list))
        x = self.active(self.lay1(x))
        x = self.active(self.l0(x))
        x = self.active(self.l1(x))
        x = self.active(self.lay2(x))
       
        return x

其中,features是维度[n X 1]的特征矩阵,edge_list是Pytorch几何中使用的边索引。梯度似乎有一个非常急剧的下降,即使是在40个阶段的训练与批量梯度下降。我正在学习具有大约1K个节点和~3K条边的无标度图。如何获得更好的渐变值

enter image description here


Tags: importselfparameterinituniformnntorchout

热门问题