我创建了一个线性的ReLu网络,它应该与我的数据过度匹配。我使用bcewithlogisticsloss作为损失函数。我用它来分类三维点。因为数据很小,所以我不想分批处理。而且效果很好。然而,现在我已经实现了批处理,它似乎预测值不是我所期望的(即0或1),相反,它给我的数字像-25.4562我没有改变任何其他从网络只有批处理。你知道吗
我尝试了二进制丢失函数BSELoss,但是它似乎是pytorch版本中的一个bug,所以我不能使用它。你可以看看我的代码如下:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# We load the training data
Samples, Ocupancy = common.load_samples()
for i in range(0,Ocupancy.shape[0]):
if Ocupancy[i] > 1 or Ocupancy[i] < 0:
print("upsie")
max = np.amax(Samples)
min = np.amin(Samples)
x_test = torch.from_numpy(Samples.astype(np.float32)).to(device)
y_test = torch.from_numpy(Ocupancy.astype(np.float32)).to(device)
train_data = CustomDataset(x_test, y_test)
train_loader = DataLoader(dataset=train_data, batch_size= 22500, shuffle=False) # Batches_size equal to the number of points in each slice
phi = common.MLP(3, 1).to(device)
criterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(phi.parameters(), lr = 0.01)
epoch = 10
fit_start_time = time.time()
for epoch in range(epoch):
for x_batch, y_batch in train_loader:
#optimizer.zero_grad()
x_train = x_batch.to(device)
y_train = y_batch.to(device)
y_pred = phi(x_batch)
print(y_pred)
# Compute Loss
loss = criterion(y_pred.squeeze(), y_batch.squeeze())
print('Epoch {}: train loss: {}'.format(epoch, loss.item())) # Backward pass
loss.backward()
optimizer.step()
fit_end_time = time.time()
print("Total time = %f" % (fit_end_time - fit_start_time))
min = -2
max = 2
resolution = 0.05
X,Y,Z = np.mgrid[min:max:resolution,min:max:resolution,min:max:resolution] # sample way more
xyz = torch.from_numpy(np.vstack([X.ravel(), Y.ravel(),Z.ravel()]).transpose().astype(np.float32)).to(device)
eval = LabelData(xyz)
eval_loader = DataLoader(dataset=eval, batch_size= 22500, shuffle=False) # Make bigger batches
# feed the network bit by bit?
i = 0
for x_batch in eval_loader:
phi.eval()
labels = phi(x_batch).to(device)
print(labels)
visualization_iso(X,Y,Z,labels)
我希望预测值是0或1,或者至少是一个概率,但是它给了我很多我不明白的数字。比如:19.5953 请看一下我的代码,如果你发现任何大错误请告诉我。我真的很困惑,因为在我扩展我使用的数据之前它工作得很好。你知道吗
敬礼
我可能错了,但我试着根据你的问题代码来回答。你知道吗
您使用的是
BCEwithlogitsloss
,这意味着模型应该输出logits
。logits
是使用sigmoid激活之前的输出。回想一下,sigmoid激活用于将输出转换为概率(基本上介于0和1之间)。Logits可以是任何实数。你知道吗基于此,我认为应该通过sigmoid激活传递模型的输出,即
F.sigmoid(phi(x_batch))
。或者您也可以只检查模型的输出是大于0还是小于0。如果大于0,则标签应为1。你知道吗相关问题 更多 >
编程相关推荐