python中稀疏数组/张量的有效多项式采样

2024-06-26 01:48:10 发布

您现在位置:Python中文网/ 问答频道 /正文

我有一个稀疏数组/张量,如下所示

import torch
from torch_sparse import SparseTensor


row = torch.tensor([0, 0, 0, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 3])
col = torch.tensor([1, 2, 3, 0, 2, 0, 1, 4, 5, 0, 2, 5, 2, 4])
value = torch.rand([14])
adj_t = SparseTensor(row=row, col=col, value=value, sparse_sizes=(4, 9))

我想对n_samples列索引进行采样,可以替换,也可以不替换。我可以通过首先将adj_t转换为dense,然后使用torch.multinomial或者类似地使用numpy.random.choice来实现这一点

sample = torch.multinomial(adj_t.to_dense(), num_samples=2, replacement=True)

但是将稀疏数组转换为稠密数组和torch.multinomial 并不是很有效。有稀疏版本的torch.multinomial。如果没有,如何着手实施


Tags: fromimportvaluecoltorch数组denserow
1条回答
网友
1楼 · 发布于 2024-06-26 01:48:10

我不确定这是否能像你的一艘班轮那样有效地完成

据我所知,实现您的目标的一种方法是:

  1. 按其在sparese张量中出现的行对值进行分组,例如使用this solution:np.split(value.numpy(), np.unique(row.numpy(), return_index=True)[1][1:])
  2. 使用numpy.random.multinominal为每行创建随机选择的索引列表
  3. 将索引映射到col中的相应值(即第0行的011是第1行的2,第2行的24-所有这些都是根据rowcol值确定的)

您可能不想使用任何内置循环来避免性能下降

相关问题 更多 >