我有一个稀疏数组/张量,如下所示
import torch
from torch_sparse import SparseTensor
row = torch.tensor([0, 0, 0, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 3])
col = torch.tensor([1, 2, 3, 0, 2, 0, 1, 4, 5, 0, 2, 5, 2, 4])
value = torch.rand([14])
adj_t = SparseTensor(row=row, col=col, value=value, sparse_sizes=(4, 9))
我想对n_samples
列索引进行采样,可以替换,也可以不替换。我可以通过首先将adj_t
转换为dense
,然后使用torch.multinomial
或者类似地使用numpy.random.choice
来实现这一点
sample = torch.multinomial(adj_t.to_dense(), num_samples=2, replacement=True)
但是将稀疏数组转换为稠密数组和torch.multinomial
并不是很有效。有稀疏版本的torch.multinomial
。如果没有,如何着手实施
我不确定这是否能像你的一艘班轮那样有效地完成
据我所知,实现您的目标的一种方法是:
np.split(value.numpy(), np.unique(row.numpy(), return_index=True)[1][1:])
col
中的相应值(即第0行的0
是1
,1
是第1行的2
,第2行的2
是4
-所有这些都是根据row
和col
值确定的)您可能不想使用任何内置循环来避免性能下降
相关问题 更多 >
编程相关推荐