为什么numpy的矩阵乘法比Python的ctypes快?

2024-05-11 05:46:05 发布

您现在位置:Python中文网/ 问答频道 /正文

我试图找出最快的矩阵乘法方法,并尝试了3种不同的方法:

  • 纯python实现:这里没有什么奇怪的。
  • 使用numpy.dot(a, b)实现Numpy
  • 使用Python中的ctypes模块与C接口。

这是转换为共享库的C代码:

#include <stdio.h>
#include <stdlib.h>

void matmult(float* a, float* b, float* c, int n) {
    int i = 0;
    int j = 0;
    int k = 0;

    /*float* c = malloc(nay * sizeof(float));*/

    for (i = 0; i < n; i++) {
        for (j = 0; j < n; j++) {
            int sub = 0;
            for (k = 0; k < n; k++) {
                sub = sub + a[i * n + k] * b[k * n + j];
            }
            c[i * n + j] = sub;
        }
    }
    return ;
}

以及调用它的Python代码:

def C_mat_mult(a, b):
    libmatmult = ctypes.CDLL("./matmult.so")

    dima = len(a) * len(a)
    dimb = len(b) * len(b)

    array_a = ctypes.c_float * dima
    array_b = ctypes.c_float * dimb
    array_c = ctypes.c_float * dima

    suma = array_a()
    sumb = array_b()
    sumc = array_c()

    inda = 0
    for i in range(0, len(a)):
        for j in range(0, len(a[i])):
            suma[inda] = a[i][j]
            inda = inda + 1
        indb = 0
    for i in range(0, len(b)):
        for j in range(0, len(b[i])):
            sumb[indb] = b[i][j]
            indb = indb + 1

    libmatmult.matmult(ctypes.byref(suma), ctypes.byref(sumb), ctypes.byref(sumc), 2);

    res = numpy.zeros([len(a), len(a)])
    indc = 0
    for i in range(0, len(sumc)):
        res[indc][i % len(a)] = sumc[i]
        if i % len(a) == len(a) - 1:
            indc = indc + 1

    return res

我敢打赌使用C的版本会更快。。。我会输的!下面是我的基准,它似乎表明我要么做得不对,要么numpy太快了:

benchmark

我想知道为什么numpy版本比ctypes版本快,我甚至没有讨论纯Python实现,因为它有点明显。


Tags: innumpyforlenrangefloatctypesarray
3条回答

NumPy使用高度优化的、经过仔细调整的BLAS方法进行矩阵乘法(另请参见:ATLAS)。本例中的特定函数是GEMM(用于泛型矩阵乘法)。您可以通过搜索dgemm.f(在Netlib中)来查找原始文件。

顺便说一下,优化不仅仅是编译器优化。上面,菲利普提到了铜匠——温诺格拉德。如果我没记错的话,这是ATLAS中大多数矩阵乘法情况下使用的算法(尽管一位评论者指出这可能是Strassen的算法)。

换言之,您的matmult算法是微不足道的实现。有更快的方法做同样的事情。

我对Numpy不太熟悉,但源代码在Github上。dot产品的一部分是在https://github.com/numpy/numpy/blob/master/numpy/core/src/multiarray/arraytypes.c.src中实现的,我假设它被转换为每个数据类型的特定C实现。例如:

/**begin repeat
 *
 * #name = BYTE, UBYTE, SHORT, USHORT, INT, UINT,
 * LONG, ULONG, LONGLONG, ULONGLONG,
 * FLOAT, DOUBLE, LONGDOUBLE,
 * DATETIME, TIMEDELTA#
 * #type = npy_byte, npy_ubyte, npy_short, npy_ushort, npy_int, npy_uint,
 * npy_long, npy_ulong, npy_longlong, npy_ulonglong,
 * npy_float, npy_double, npy_longdouble,
 * npy_datetime, npy_timedelta#
 * #out = npy_long, npy_ulong, npy_long, npy_ulong, npy_long, npy_ulong,
 * npy_long, npy_ulong, npy_longlong, npy_ulonglong,
 * npy_float, npy_double, npy_longdouble,
 * npy_datetime, npy_timedelta#
 */
static void
@name@_dot(char *ip1, npy_intp is1, char *ip2, npy_intp is2, char *op, npy_intp n,
           void *NPY_UNUSED(ignore))
{
    @out@ tmp = (@out@)0;
    npy_intp i;

    for (i = 0; i < n; i++, ip1 += is1, ip2 += is2) {
        tmp += (@out@)(*((@type@ *)ip1)) *
               (@out@)(*((@type@ *)ip2));
    }
    *((@type@ *)op) = (@type@) tmp;
}
/**end repeat**/

这似乎是计算一维点积,即矢量上的点积。在Github浏览的几分钟里,我找不到矩阵的源代码,但它可能对结果矩阵中的每个元素使用一个FLOAT_dot调用。这意味着这个函数中的循环对应于最内部的循环。

它们之间的一个区别是,“步幅”(输入中连续元素之间的差异)在调用函数之前显式计算一次。在您的例子中没有跨距,每次都计算每个输入的偏移量,例如a[i * n + k]。我本来希望一个好的编译器能够优化这个步骤,使之与Numpy的步骤类似,但是也许它不能证明这个步骤是一个常量(或者它没有被优化)。

Numpy也可能在调用这个函数的高级代码中使用缓存效果来做一些聪明的事情。一个常见的技巧是考虑每一行是连续的,还是每一列——并尝试首先遍历每一个连续的部分。对于每一个点积,一个输入矩阵必须由行遍历,另一个由列遍历(除非它们碰巧以不同的主顺序存储)。但它至少可以对结果元素做到这一点。

Numpy还包含从不同的基本实现中选择某些操作(包括“dot”)实现的代码。例如,它可以使用BLAS库。从上面的讨论来看,似乎使用了CBLAS。这是从Fortran翻译成C语言的。我认为在您的测试中使用的实现应该是在这里找到的:http://www.netlib.org/clapack/cblas/sdot.c

请注意,此程序是由一台计算机编写的,供另一台计算机读取。但您可以在底部看到,它使用展开循环一次处理5个元素:

for (i = mp1; i <= *n; i += 5) {
stemp = stemp + SX(i) * SY(i) + SX(i + 1) * SY(i + 1) + SX(i + 2) * 
    SY(i + 2) + SX(i + 3) * SY(i + 3) + SX(i + 4) * SY(i + 4);
}

这个展开因子很可能是在分析了几个之后选择的。但它的一个理论优势是在每个分支点之间进行更多的算术运算,编译器和CPU可以选择如何优化调度它们,以获得尽可能多的指令流水线。

用来实现某种功能的语言本身就是一种糟糕的性能度量。通常,使用更合适的算法是决定因素。

在您的例子中,您使用的是学校里教的矩阵乘法的朴素方法,即O(n^3)。但是,对于某些类型的矩阵,例如平方矩阵、备用矩阵等,您可以做得更好。

看一下Coppersmith–Winograd algorithm(O(n^2.3737)中的平方矩阵乘法)作为快速矩阵乘法的良好起点。另请参阅“引用”一节,其中列出了一些指向更快方法的指针。


要获得令人惊讶的性能提升,请尝试编写一个快速的strlen(),并将其与glibc实现进行比较。如果你不能战胜它,请阅读glibc的strlen()源代码,它有相当好的注释。

相关问题 更多 >