为什么在我的算法中,LSTM单元的损耗图显示了许多上下波动?

2024-09-30 22:22:38 发布

您现在位置:Python中文网/ 问答频道 /正文

Hier I attached Image of my Loss Diagram. I am very new to stack overflow, if possible, please try to answer here.

我的数据在这里代表一个大约有3600个节点的图。我想利用图中最后一个节点的已知数据来预测下一个节点

在这里,我使用了一个LSTM单元层作为我的6个死亡数据。[x,y,u,v,p,type]。最后我只想得到3D[u,v,p]输出。 实际上它做得很好

但当我看到同一个图在30000次迭代后的损耗。 它看起来像,因为我附上了一张照片上面

我不知道,为什么我的损失没有减少到我们预期的0.025以下。为什么有时会出现很高的波动而不是曲线呢

import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
import torch
plt.switch_backend('agg')

class Sequence(nn.Module):
    def __init__(self):
        super(Sequence, self).__init__()
        self.lstmcell = nn.LSTMCell(5, 256)
        self.linear = nn.Linear(256, 3)

    def forward(self, x, path):

        outputs = []
        h_t = torch.zeros(1, 256)
        c_t = torch.zeros(1, 256)         

        for i in range(x.shape[0]):
            if  trainData[path[i],5] == 0:
                neighbours = graphDict[str(path[i])]
                y = graphFlags[neighbours]
                idx = np.where(y !=0)[0]
                z = torch.Tensor(x[neighbours[idx],:])
                u = torch.mean(z,0)
            else:              
                u = (x[path[i],:]).clone()
            input_t = u.view(1,5)
            h_t, c_t = self.lstmcell(input_t, (h_t, c_t))          
            output = torch.tanh(self.linear(h_t))
            outputs += [output]
            graphFlags[path[i]] = 1
            x[path[i], 0:3] = output     
        outputs = torch.cat((outputs), 0)
        return outputs

for t in range(30000):
    graphFlags = np.copy(trainData[:,5])
    trueData_ = torch.Tensor(trueData[:, 2:5])

    binary_repr = np.unpackbits(np.copy(trainData)[:,5].astype(np.uint8).reshape(-1,1), axis=1)[:,6:]   
    x = np.concatenate((np.copy(trainData)[:,2:5], binary_repr), axis=1)
    x = torch.Tensor(x)

    optimizer.zero_grad()
    out = seq(x, path)

    loss = criterion(out, trueData_)
    counter+=1
    loss.backward()
    optimizer.step()
    lossHisto.append(loss.item())

    if counter == 100:

        saveName = 'out'
        SaveName = 'lossHisto'
        torch.save(out, '/mnt/fs2/home/bhavesh/Ergebnisse/test1/'+saveName+str(t))
        BILD = torch.load('/mnt/fs2/home/bhavesh/Ergebnisse/test1/'+saveName+str(t))
        plt.figure(figsize=(20,10))
        plt.scatter(trueData[:,0],trueData[:,1],c=torch.Tensor(BILD.data[:,2]),marker='.')
        plt.colorbar()   
        plt.savefig("/mnt/fs2/home/bhavesh/Ergebnisse/test1/out%d.png"%t)   
        plt.clf()
        torch.save(lossHisto, '/mnt/fs2/home/bhavesh/Ergebnisse/test1/'+SaveName+str(t))
        LOSS = torch.load('/mnt/fs2/home/bhavesh/Ergebnisse/test1/'+SaveName+str(t))
        plt.plot(LOSS)
        plt.grid()
        plt.savefig("/mnt/fs2/home/bhavesh/Ergebnisse/test/lossHisto%d.png"%t)
        counter = 0

    print(loss.item()) 

Tags: pathselfhomenppltnntorchoutputs