我有以下代码:
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import matplotlib.pyplot as plt
import os
import keras
from random import choice
import sys
devicet = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(devicet)
if devicet == 'cpu':
print ('Using CPU')
else:
print ('Using GPU')
cuda0 = torch.device('cuda:0')
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.step1 = nn.Linear(5, 25)
self.step2 = nn.Linear(25, 50)
self.step3 = nn.Linear(50, 100)
self.step4 = nn.Linear(100, 100)
self.step5 = nn.Linear(100, 10)
self.step6 = nn.Linear(10, 1)
def forward(self, x):
x = F.relu(x)
x = self.step1(x)
x = F.relu(x)
x = self.step2(x)
x = F.relu(x)
x = self.step3(x)
x = F.relu(x)
x = self.step4(x)
x = F.relu(x)
x = self.step5(x)
x = F.relu(x)
x = self.step6(x)
x = F.relu(x)
return (x)
net = Net()
x = torch.rand(10,5)
num = choice(range(10))
zero_tensor = torch.zeros(num, 1)
one_tensor = torch.ones(10-num, 1)
y = torch.cat((zero_tensor,one_tensor),0)
x.to(devicet)
y.to(devicet)
learning_rate = 1e-3
optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)
loss_fn = torch.nn.BCELoss()
acc_list = []
for i in tqdm(range(1000),desc='Training'):
y_pred = net(x)
loss = loss_fn(y_pred, y)
loss.backward()
optimizer.step()
acc_list.append(abs(net(x).detach().numpy()[0]-y.detach().numpy()[0]))
with torch.no_grad():
for param in net.parameters():
param -= learning_rate * param.grad
optimizer.zero_grad()
print ('\nFinished training in {} epochs.'.format(len(acc_list)))
plt.plot(range(len(acc_list)),acc_list)
plt.show()
for i in range(10):
print (str(net(x).detach().numpy()[i][0])+', '+str(y.detach().numpy()[i][0]))
当我运行它时,它始终只打印出以下内容:
为什么它不做任何训练?如果我使用MSE损耗,它会工作(实际上,它有时只对MSE损耗有效,有时它会做与图像中相同的事情),只有当我使用BCE时,它才会完全停止工作
最后一层活化
您只输出正值,这些正值应该介于
0
和1
之间。对于初学者,这些值具体如下:将}
torch.sigmoid
与BCELoss
一起使用甚至更好,只需输出x
并使用直接使用logits的^{训练
您正在使用
Adam
优化器并在此处手动执行SGD:从本质上讲,您要应用两次优化步骤,这可能太多,可能会破坏权重
optimizer.step()
已经做到了这一点,两者都不需要强>准确度
本部分:
我假设您想要计算准确度,它将是这样的(同样不要通过
net(x)
两次通过网络推送数据,您已经有了y_pred
!):相关问题 更多 >
编程相关推荐