尝试从Python代码中删除forloop,对矩阵执行查找表操作

2024-09-28 05:16:33 发布

您现在位置:Python中文网/ 问答频道 /正文

我觉得这个问题和我之前问过的问题很相似,但我想不通。如何将这两行代码转换成没有for循环的一行?你知道吗

for i in xrange(X.shape[0]):
  dW[:,y[i]] -= X[i]

在英语中,矩阵X中的每一行都应该从由向量y给出的矩阵dW中的相应列中减去

我应该提到dW是DXC,X是NXD,所以X的转置与W的形状不同,否则我可以对X的行重新排序,直接进行转置。但是,dW中的列可能有多个需要减去的对应行。你知道吗

我觉得我对python中的索引应该如何工作没有一个很好的理解,这使得删除不必要的for循环很困难,甚至不知道什么for循环可以删除。你知道吗


Tags: 代码infor排序矩阵向量形状shape
2条回答

矢量化的简单方法是:

dW[:,y] -= X.T

除了,虽然不是很明显或者没有很好的记录,但是这会给y中的重复索引带来问题。对于这些情况,有ufunc.at方法(numpy中的元素操作被实现为“ufunc”或“通用函数”)。引用the docs

ufunc.at(a, indices, b=None)

Performs unbuffered in place operation on operand ‘a’ for elements specified by ‘indices’. For addition ufunc, this method is equivalent to a[indices] += b, except that results are accumulated for elements that are indexed more than once. For example, a[[0,0]] += 1 will only increment the first element once because of buffering, whereas add.at(a, [0,0], 1) will increment the first element twice.

所以在你的情况下:

np.subtract.at(dW.T, y, X)

不幸的是,就矢量化技术而言,ufunc.at的效率相对较低,因此与循环相比的加速可能没有那么令人印象深刻。你知道吗

方法#1这是一种单线性向量化方法,使用^{}^{}-

dWout -= (np.arange(dW.shape[1])[:,None] == y).dot(X).T

解释:举一个小例子来了解发生了什么-

输入:

In [259]: X
Out[259]: 
array([[ 0.80195208,  0.40566743,  0.62585574,  0.53571781],
       [ 0.56643339,  0.4635662 ,  0.4290103 ,  0.14457036],
       [ 0.31823491,  0.12329964,  0.41682841,  0.09544716]])

In [260]: y
Out[260]: array([1, 2, 2])

首先,我们创建分布在dW第二轴长度上的y索引的2D掩码。你知道吗

dW4 x 5形数组。所以,面具应该是:

In [261]: mask = (np.arange(dW.shape[1])[:,None] == y)

In [262]: mask
Out[262]: 
array([[False, False, False],
       [ True, False, False],
       [False,  True,  True],
       [False, False, False],
       [False, False, False]], dtype=bool)

这里使用^{}来创建2D掩码。你知道吗

接下来,我们使用矩阵乘法对y中的相同索引求和-

In [264]: mask.dot(X)
Out[264]: 
array([[ 0.        ,  0.        ,  0.        ,  0.        ],
       [ 0.80195208,  0.40566743,  0.62585574,  0.53571781],
       [ 0.8846683 ,  0.58686584,  0.84583872,  0.24001752],
       [ 0.        ,  0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ]])

因此,对应于在第二列和第三列具有True值的掩码的第三行,我们将用该矩阵乘法从X求第二行和第三行的和。这将作为乘法输出的第三行。你知道吗

因为在原始循环代码中,我们正在跨列更新dW,所以我们需要转置乘法结果,然后更新。你知道吗


方法#2这里有另一种矢量化方法,尽管不是使用^{}的一行-

sidx = y.argsort()
unq,shift_idx = np.unique(y[sidx],return_index=True)
dWout[:,unq] -= np.add.reduceat(X[sidx],shift_idx,axis=0).T

相关问题 更多 >

    热门问题