当结果稀疏度已知时的稀疏矩阵乘法(在python中)

2024-09-28 21:52:04 发布

您现在位置:Python中文网/ 问答频道 /正文

假设我们想要计算给定稀疏矩阵A,B的C=A*B,但是对C的一个非常小的子集合感兴趣,用一个索引对列表来表示:
行=[i1,i2,i3。。。]
cols=[j1,j2,j3。。。]
A和B都非常大(比如50Kx50K),但是非常稀疏(1%的条目是非零的)。在

我们如何计算乘法的这个子集?在

下面是一个运行速度非常慢的天真实现:

def naive(A, B, rows, cols):
    N = len(rows)
    vals = []
    for n in xrange(N):
        v = A.getrow(rows[n]) * B.getcol(cols[n])
        vals.append(v[0, 0])

    R = sps.coo_matrix((np.array(vals), (np.array(rows), np.array(cols))), shape=(A.shape[0], B.shape[1]), dtype=np.float64)
    return R

即使对于小矩阵,这也很糟糕:

^{pr2}$

在我的机器上,naive()在1分钟后完成,大部分的工作都花在构建行/列(在getrow()和getcol()中)。
当然,将这个(非常小)示例转换为稠密矩阵,计算大约需要100ms:

A0 = np.array(A.todense())
B0 = np.array(B.todense())
X0 = np.array(X.todense())
A0.dot(B0) * X0

对如何有效地计算这种矩阵乘法有什么想法?在


Tags: np矩阵b0a0arraymatrixrowscols
1条回答
网友
1楼 · 发布于 2024-09-28 21:52:04

稀疏矩阵的格式在这里很重要。您总是需要一个行形式a和一个来自B的列。因此,将A存储为csr,将{}存储为{},以消除{}/getcol的开销。不幸的是,这只是故事的一小部分。在

最佳解决方案在很大程度上取决于稀疏矩阵的结构(大量的稀疏列/行等),但是您可以尝试基于字典和集合的方法。对于每一行的矩阵A,保留以下内容:

  • 该行上所有非零列索引的集合
  • 一种以非零索引为键,以相应的非零值为值的字典

对于矩阵B,每个列都保留类似的dict和set。在

为了计算乘法结果中的元素(M,N),A的M行与B的N列相乘。乘法:

  • 求非零集的集合交集
  • 计算非零元素的乘法之和(即上面的交集)

在大多数情况下,这应该非常快,因为在稀疏矩阵中,集合交集通常非常小。在

一些代码:

class rowarray():
    def __init__(self, arr):
        self.rows = []
        for row in arr:
            nonzeros = np.nonzero(row)[0]
            nzvalues = { i: row[i] for i in nonzeros }
            self.rows.append((set(nonzeros), nzvalues))

    def __getitem__(self, key):
        return self.rows[key]

    def __len__(self):
        return len(self.rows)


class colarray(rowarray):
    def __init__(self, arr):
        rowarray.__init__(self, arr.T)


def maybe_less_naive(A, B, rows, cols):
    N = len(rows)
    vals = []
    for n in xrange(N):
        nz1,v1 = A[rows[n]]
        nz2,v2 = B[cols[n]]
        # list of common non-zeros
        nz = nz1.intersection(nz2)
        # sum of non-zeros
        vals.append(sum([ v1[i]*v2[i] for i in nz]))

    R = sps.coo_matrix((np.array(vals), (np.array(rows), np.array(cols))), shape=(len(A), len(B)), dtype=np.float64)
    return R

D = 1000

Ap = np.random.randn(D, D)
Ap[np.abs(Ap) > 0.1] = 0
A = rowarray(Ap)
Bp = np.random.randn(D, D)
Bp[np.abs(Bp) > 0.1] = 0
B = colarray(Bp)

X = np.random.randn(D, D)
X[np.abs(X) > 0.1] = 0
X[X != 0] = 1
X = sps.csr_matrix(X)
rows, cols = X.nonzero()
maybe_less_naive(A, B, rows, cols)

这是一个有点效率,乘法大约需要2秒的测试(80000个元素)。结果似乎基本相同。在


对表演的一些评论。在

对每个输出元素执行两个操作:

  • 设置交点
  • 乘法

集合交集的复杂度应该是O(min(m,n)),其中m和n是每个操作数中非零的数目。这是矩阵大小不变的,只有每行/列的非零的平均数才是重要的。在

乘法(和dict查找)的数量取决于在上面的交集中找到的非零的数量。在

如果两个矩阵都有概率(密度)p的随机分布非零,且行/列长度为n,则:

  • 集合交集:O(np)
  • 字典查找,乘法:O(np^2)

这表明,对于真正稀疏的矩阵,求交集是关键点。这也可以通过分析来验证;大部分时间都花在计算交叉点上。在

当这反映到现实世界中时,我们似乎要花20美元左右来获得一行/列80个非零。这并不是盲目的快,而且代码当然可以更快。Cython可能是一种解决方案,但这可能是Python不是最佳解决方案的问题之一。对于排序整数的简单线性匹配(合并排序类型算法)在用C编写时应该至少快一个数量级

需要注意的一点是,该算法可以同时针对多个元素并行执行。不需要为单个线程解决问题,因为只要一个线程处理一个输出点,计算是独立的。在

相关问题 更多 >