如何使用Pythorch张量索引选择()?

2024-09-29 21:43:17 发布

您现在位置:Python中文网/ 问答频道 /正文

目前正在尝试使用PyTorch实现增强算法。我希望能够收集负责任的输出后,折扣奖励。因此,考虑到actions内存,我创建了一个索引的张量,并尝试使用张量索引选择,但没有成功。有人能帮忙吗?在

    rH = np.array(rH) # discounted reward
    aH = np.array(aH) # action_holder
    sH = np.vstack(np.array(sH)) # states holder

    statesTensor = Variable(torch.from_numpy(sH).type(torch.FloatTensor))
    out = model.forward(statesTensor)

    indexes = GuiltyOnes(out, aH)
    flat = out.view(1,-1)

    respos = torch.index_select(flat, 1, torch.from_numpy(indexes).type(torch.LongTensor))

我得到以下错误:

^{pr2}$

Tags: fromnumpytypeshnptorchpytorchout
1条回答
网友
1楼 · 发布于 2024-09-29 21:43:17

您的情况可能类似于this one,因此,您应该使用Variable来代替:

i = Variable(torch.from_numpy(indexes).long())
respos = torch.index_select(flat, 1, i)

请记住,pytorch错误消息并不总是真正准确的。在这种情况下,这对我来说是相当误导的

相关问题 更多 >

    热门问题