目前正在尝试使用PyTorch实现增强算法。我希望能够收集负责任的输出后,折扣奖励。因此,考虑到actions内存,我创建了一个索引的张量,并尝试使用张量索引选择,但没有成功。有人能帮忙吗?在
rH = np.array(rH) # discounted reward
aH = np.array(aH) # action_holder
sH = np.vstack(np.array(sH)) # states holder
statesTensor = Variable(torch.from_numpy(sH).type(torch.FloatTensor))
out = model.forward(statesTensor)
indexes = GuiltyOnes(out, aH)
flat = out.view(1,-1)
respos = torch.index_select(flat, 1, torch.from_numpy(indexes).type(torch.LongTensor))
我得到以下错误:
^{pr2}$
您的情况可能类似于this one,因此,您应该使用
Variable
来代替:请记住,pytorch错误消息并不总是真正准确的。在这种情况下,这对我来说是相当误导的
相关问题 更多 >
编程相关推荐