我的数据在这里代表一个大约有3600个节点的图。我想利用图中最后一个节点的已知数据来预测下一个节点
在这里,我使用了一个LSTM单元层作为我的6个死亡数据。[x,y,u,v,p,type]。最后我只想得到3D[u,v,p]输出。 实际上它做得很好
但当我看到同一个图在30000次迭代后的损耗。 它看起来像,因为我附上了一张照片上面
我不知道,为什么我的损失没有减少到我们预期的0.025以下。为什么有时会出现很高的波动而不是曲线呢
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
import torch
plt.switch_backend('agg')
class Sequence(nn.Module):
def __init__(self):
super(Sequence, self).__init__()
self.lstmcell = nn.LSTMCell(5, 256)
self.linear = nn.Linear(256, 3)
def forward(self, x, path):
outputs = []
h_t = torch.zeros(1, 256)
c_t = torch.zeros(1, 256)
for i in range(x.shape[0]):
if trainData[path[i],5] == 0:
neighbours = graphDict[str(path[i])]
y = graphFlags[neighbours]
idx = np.where(y !=0)[0]
z = torch.Tensor(x[neighbours[idx],:])
u = torch.mean(z,0)
else:
u = (x[path[i],:]).clone()
input_t = u.view(1,5)
h_t, c_t = self.lstmcell(input_t, (h_t, c_t))
output = torch.tanh(self.linear(h_t))
outputs += [output]
graphFlags[path[i]] = 1
x[path[i], 0:3] = output
outputs = torch.cat((outputs), 0)
return outputs
for t in range(30000):
graphFlags = np.copy(trainData[:,5])
trueData_ = torch.Tensor(trueData[:, 2:5])
binary_repr = np.unpackbits(np.copy(trainData)[:,5].astype(np.uint8).reshape(-1,1), axis=1)[:,6:]
x = np.concatenate((np.copy(trainData)[:,2:5], binary_repr), axis=1)
x = torch.Tensor(x)
optimizer.zero_grad()
out = seq(x, path)
loss = criterion(out, trueData_)
counter+=1
loss.backward()
optimizer.step()
lossHisto.append(loss.item())
if counter == 100:
saveName = 'out'
SaveName = 'lossHisto'
torch.save(out, '/mnt/fs2/home/bhavesh/Ergebnisse/test1/'+saveName+str(t))
BILD = torch.load('/mnt/fs2/home/bhavesh/Ergebnisse/test1/'+saveName+str(t))
plt.figure(figsize=(20,10))
plt.scatter(trueData[:,0],trueData[:,1],c=torch.Tensor(BILD.data[:,2]),marker='.')
plt.colorbar()
plt.savefig("/mnt/fs2/home/bhavesh/Ergebnisse/test1/out%d.png"%t)
plt.clf()
torch.save(lossHisto, '/mnt/fs2/home/bhavesh/Ergebnisse/test1/'+SaveName+str(t))
LOSS = torch.load('/mnt/fs2/home/bhavesh/Ergebnisse/test1/'+SaveName+str(t))
plt.plot(LOSS)
plt.grid()
plt.savefig("/mnt/fs2/home/bhavesh/Ergebnisse/test/lossHisto%d.png"%t)
counter = 0
print(loss.item())
目前没有回答
相关问题 更多 >
编程相关推荐