我有一个函数,它返回一个大的线性方程组的剩余范数的平方
In [1]: import numpy as np
In [2]: A = np.random.rand(3600000, 200)
In [3]: b = np.random.rand(3600000)
In [4]: def f(x):
...: global A
...: global b
...: return np.linalg.norm(A.dot(x) - b)**2
现在我有了一个算法,在这个算法中,函数必须被求值几次。然而,由于方程系统的大小,在某个x
处的每个函数调用都需要很多时间
In [5]: import time
In [6]: def f(x):
...: global A
...: global b
...: start = time.time()
...: res = np.linalg.norm(A.dot(x) - b)**2
...: end = time.time()
...: return res, end - start
In [7]: test = np.random.rand(200)
In [8]: f(test)
Out[8]: (8820030785.528395, 7.467242956161499)
我的问题是:
Are there any possibilities for reducing the time of such a function call?
我曾想过用一个更高效的表达式替换np.linalg.norm(A.dot(x) - b)**2
,但我不知道这会是什么样子
技术信息上面的代码是在带有
Memory:
Memory Slots:
ECC: Disabled
Upgradeable Memory: No
BANK 0/DIMM0:
Size: 4 GB
Type: LPDDR3
Speed: 2133 MHz
Status: OK (...)
BANK 1/DIMM0:
Size: 4 GB
Type: LPDDR3
Speed: 2133 MHz
Status: OK (...)
np.show_config()
的结果是
blas_mkl_info:
libraries = ['blas', 'cblas', 'lapack', 'pthread', 'blas', 'cblas', 'lapack']
library_dirs = ['/Users/me/miniconda3/envs/magpy/lib']
define_macros = [('SCIPY_MKL_H', None), ('HAVE_CBLAS', None)]
include_dirs = ['/Users/me/miniconda3/envs/magpy/include']
blas_opt_info:
libraries = ['blas', 'cblas', 'lapack', 'pthread', 'blas', 'cblas', 'lapack', 'blas', 'cblas', 'lapack']
library_dirs = ['/Users/me/miniconda3/envs/magpy/lib']
define_macros = [('SCIPY_MKL_H', None), ('HAVE_CBLAS', None)]
include_dirs = ['/Users/me/miniconda3/envs/magpy/include']
lapack_mkl_info:
libraries = ['blas', 'cblas', 'lapack', 'pthread', 'blas', 'cblas', 'lapack']
library_dirs = ['/Users/me/miniconda3/envs/magpy/lib']
define_macros = [('SCIPY_MKL_H', None), ('HAVE_CBLAS', None)]
include_dirs = ['/Users/me/miniconda3/envs/magpy/include']
lapack_opt_info:
libraries = ['blas', 'cblas', 'lapack', 'pthread', 'blas', 'cblas', 'lapack', 'blas', 'cblas', 'lapack']
library_dirs = ['/Users/me/miniconda3/envs/magpy/lib']
define_macros = [('SCIPY_MKL_H', None), ('HAVE_CBLAS', None)]
include_dirs = ['/Users/me/miniconda3/envs/magpy/include']
在您的情况下
np.linalg.norm
只是因此,您最好做以下工作:
跳过不必要的sqrt/square。但与最初的
A@x
相比,这可能是个小问题在一台相当普通的Linux4Gb计算机上,您的测试用例给了我(在创建
A
时)虽然你显然有足够的记忆力,但你可能正在突破这一界限。在另一个例子中,我们已经看到,由于内存管理问题,使用非常大的数组的
dot/@
会减慢速度。通常,人们通过进行某种“块”处理来提高速度。如果您正在使用3d“批处理”进行matmul
,那么这很容易。你的普通案件就不那么明显了将
A
大小减少10:时间上没有太大不同:
事实上,正是
A.dot(x)
主导了时间安排;其余的可以忽略不计将
A
的大小加倍,大约使时间加倍(175-180范围)我不是图书馆专家,但我相信
MKL
是一个更快的选择,我没有(但你似乎有)性能问题似乎来自BLAS的默认实现速度缓慢
在您的计算机上使用的默认BLAS实现显然是英特尔MKL,它通常非常快,但在您的计算机上却出人意料地慢。 事实上,根据提供的硬件信息,执行时间应为170-200毫秒,而不是7.5秒
您可以尝试切换到另一个BLAS实现,如OpenBLAS、Apple Accelerate或BLIS。您可以在here和here中找到有关如何执行此操作的信息
如果切换到另一个BLAS实现无法解决问题,则使用以下回退NUBA实现:
这段代码不如使用基于快速BLAS实现的numpy函数好,但它应该仍然相对较快(请注意,第一次调用
f
会有点慢,因为包含了编译时间)请注意,对数组使用类型
np.float32
可以将执行速度提高2倍,尽管结果也应该不太准确相关问题 更多 >
编程相关推荐