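<p>I found a way to iterate over all but one axis in Cython using the NumPy C-API (code below). It is not pretty, though, and whether it is worth the effort depends on the inner function and on the size of the data.</p>
<p>If anyone knows a more elegant way to do this, please let me know.</p>
<p>I compared it with Eelco's solution. For large arguments the two run at comparable speed, while for smaller arguments the C-API solution is faster:</p>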
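<p>To make the idea of "all but one axis" concrete, here is a minimal pure-Python sketch of the same iteration pattern. It is not the C-API code from this answer, and <code>iter_all_but_axis</code> is a hypothetical helper name used only for illustration. It assumes a regular <code>ndarray</code> input.</p>

```python
import numpy as np

def iter_all_but_axis(arr, axis):
    """Yield a 1-D view along `axis` for every index of the other axes.

    Hypothetical helper for illustration; not part of NumPy.
    """
    arr = np.asarray(arr)
    axis = axis % arr.ndim  # normalise negative axes
    # Move the chosen axis to the end, then loop over the remaining axes.
    moved = np.moveaxis(arr, axis, -1)
    for idx in np.ndindex(moved.shape[:-1]):
        # Indexing all leading axes with integers leaves a 1-D view.
        yield moved[idx]

a = np.arange(24).reshape(2, 3, 4)
slices = list(iter_all_but_axis(a, axis=1))
```

<p>Each yielded slice is a view into the original array, much like the memoryviews built from <code>PyArray_ITER_DATA</code> in the code further down, but the loop itself runs in Python and is therefore slow for many small slices.</p>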
<pre><code>In [5]: y=linspace(-1,1,100);
In [6]: %timeit transf.apply_along(f, x, y, axis=1)
1 loops, best of 3: 5.28 s per loop
In [7]: %timeit transf.transfnd(f, x, y, axis=1)
1 loops, best of 3: 5.16 s per loop
</code></pre>
<p>As you can see, for this input the two functions run at roughly the same speed.</p>
<p>For smaller input arrays, however, the C-API approach is faster.</p>
<h2>Code</h2>
<pre><code>#cython: boundscheck=False
#cython: wraparound=False
#cython: cdivision=True
import numpy as np
cimport numpy as np
np.import_array()

cdef extern from "complex.h":
    double complex cexp(double complex z) nogil

# 1D transform: out[i] = sum_j f[j] * exp(-1j * x[j] * y[i])
cdef void transf1d(double complex[:] f,
                   double[:] x,
                   double[:] y,
                   double complex[:] out,
                   int Nx,
                   int Ny) nogil:
    cdef int i, j
    for i in range(Ny):
        out[i] = 0
        for j in range(Nx):
            out[i] = out[i] + f[j]*cexp(-1j*x[j]*y[i])

def transfnd(F, x, y, axis=-1, out=None):
    # Make sure everything is a numpy array.
    F = np.asanyarray(F, dtype=complex)
    x = np.asanyarray(x, dtype=float)
    y = np.asanyarray(y, dtype=float)

    # Calculate absolute axis.
    cdef int ax = axis
    if ax < 0:
        ax = np.ndim(F) + ax

    # Calculate lengths of the axes `x`, and `y`.
    cdef int Nx = np.size(x), Ny = np.size(y)

    # Output array.
    if out is None:
        shape = list(np.shape(F))
        shape[axis] = Ny
        out = np.empty(shape, dtype=complex)
    else:
        out = np.asanyarray(out, dtype=complex)

    # Error check.
    assert np.shape(F)[axis] == Nx, \
        'Array length mismatch between `F`, and `x`!'
    assert np.shape(out)[axis] == Ny, \
        'Array length mismatch between `out`, and `y`!'
    f_shape = list(np.shape(F))
    o_shape = list(np.shape(out))
    f_shape[axis] = 0
    o_shape[axis] = 0
    assert f_shape == o_shape, 'Array shape mismatch between `F`, and `out`!'

    # Construct iterator over all but one axis.
    cdef np.flatiter itf = np.PyArray_IterAllButAxis(F, &ax)
    cdef np.flatiter ito = np.PyArray_IterAllButAxis(out, &ax)
    cdef int f_stride = F.strides[axis]
    cdef int o_stride = out.strides[axis]

    # Memoryview to access one slice per iteration.
    cdef double complex[:] fdat
    cdef double complex[:] odat
    cdef double[:] xdat = x
    cdef double[:] ydat = y

    while np.PyArray_ITER_NOTDONE(itf):
        # View the current `x`, and `y` axes.
        fdat = <double complex[:Nx]> np.PyArray_ITER_DATA(itf)
        fdat.strides[0] = f_stride
        odat = <double complex[:Ny]> np.PyArray_ITER_DATA(ito)
        odat.strides[0] = o_stride

        # Perform the 1D-transformation on one slice.
        transf1d(fdat, xdat, ydat, odat, Nx, Ny)

        # Go to next step.
        np.PyArray_ITER_NEXT(itf)
        np.PyArray_ITER_NEXT(ito)

    return out

# For comparison
def apply_along(F, x, y, axis=-1):
    # Make sure everything is a numpy array.
    F = np.asanyarray(F, dtype=complex)
    x = np.asanyarray(x, dtype=float)
    y = np.asanyarray(y, dtype=float)

    # Calculate absolute axis.
    cdef int ax = axis
    if ax < 0:
        ax = np.ndim(F) + ax

    # Calculate lengths of the axes `x`, and `y`.
    cdef int Nx = np.size(x), Ny = np.size(y)

    # Error check.
    assert np.shape(F)[axis] == Nx, \
        'Array length mismatch between `F`, and `x`!'

    def wrapper(f):
        out = np.empty(Ny, complex)
        transf1d(f, x, y, out, Nx, Ny)
        return out

    return np.apply_along_axis(wrapper, axis, F)
</code></pre>
<p>Use the following <code>setup.py</code> to build the extension:</p>
<pre><code># setuptools replaces distutils, which was removed in Python 3.12.
from setuptools import setup
from Cython.Build import cythonize
import numpy as np

setup(
    name = 'transf',
    ext_modules = cythonize('transf.pyx'),
    include_dirs = [np.get_include()],
)
</code></pre>
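<p>To sanity-check the output of the compiled functions, a pure-NumPy reference can be handy. The sketch below is an assumption about how such a check could look, not part of the original module. <code>transf_reference</code> is a hypothetical name, and it computes the same sum as <code>transf1d</code>, <code>out[i] = sum_j f[j]*exp(-1j*x[j]*y[i])</code>, along the chosen axis.</p>

```python
import numpy as np

def transf_reference(F, x, y, axis=-1):
    """Pure-NumPy reference for the transform along `axis` (hypothetical helper)."""
    F = np.asarray(F, dtype=complex)
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    # kernel[j, i] = exp(-1j * x[j] * y[i])
    kernel = np.exp(-1j * np.outer(x, y))
    # Contract the chosen axis of F with the first axis of the kernel.
    # tensordot appends the new axis at the end, so move it back into place.
    out = np.tensordot(F, kernel, axes=([axis], [0]))
    return np.moveaxis(out, -1, axis)

rng = np.random.default_rng(0)
F = rng.standard_normal((2, 5, 3))
x = np.linspace(-1, 1, 5)
y = np.linspace(-1, 1, 4)
ref = transf_reference(F, x, y, axis=1)
```

<p>With the extension built, <code>np.allclose(transf.transfnd(F, x, y, axis=1), ref)</code> would be the comparison to run. The reference builds the full <code>Nx</code> by <code>Ny</code> kernel in memory, so it is meant for testing on small inputs rather than as a replacement.</p>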