我正在尝试使用RANSAC回归算法的修改版本,但没有更改sklearn的函数(为了避免使用sklearn 0.24.1的问题和避免函数中的一些警告),我在github中找到了一个算法,并尝试对其进行Cythonization以提高速度,然后才能进行修改。令我惊讶的是,速度提升非常差,是因为代码中充满了numpy调用,还是我做错了什么?以下代码是python版本、cython版本(经过适当修改以避免错误):
##Python version
import numpy as np
def ransac_polyfit(x,y,order, t, n=0.8,k=100,f=0.9):
besterr = np.inf
bestfit = np.array([None])
for kk in range(k):
maybeinliers = np.random.randint(len(x), size=int(n*len(x)))
maybemodel = np.polyfit(x[maybeinliers], y[maybeinliers], order)
alsoinliers = np.abs(np.polyval(maybemodel, x)-y) < t
if sum(alsoinliers) > len(x)*f:
bettermodel = np.polyfit(x[alsoinliers], y[alsoinliers], order)
thiserr = np.sum(np.abs(np.polyval(bettermodel, x[alsoinliers])-y[alsoinliers]))
if thiserr < besterr:
bestfit = bettermodel
besterr = thiserr
return bestfit
##Cython version
cimport cython
cimport numpy as np
import numpy as np
np.import_array()
@cython.boundscheck(False)
@cython.wraparound(False)
cdef ransac_polyfit(np.ndarray[np.npy_int64, ndim=1] x,
np.ndarray[np.npy_int64, ndim=1] y,
np.npy_intp order,
np.npy_float32 t,
np.npy_float32 n=0.8,
np.npy_intp k=100,
np.npy_float32 f=0.9):
cdef np.npy_intp kk, i, ransac_control
cdef np.npy_float64 thiserr, besterr = -1.0
cdef np.ndarray[np.npy_int64, ndim=1] maybeinliers
cdef np.ndarray[np.npy_bool, ndim=1] alsoinliers
cdef np.ndarray[np.npy_float64, ndim=1] bestfit, maybemodel, bettermodel
bestfit = np.zeros(1, dtype = np.float64)
for kk in range(k):
maybeinliers = np.random.randint(len(x), size=int(n*len(x)))
maybemodel = np.polyfit(x[maybeinliers], y[maybeinliers], order)
alsoinliers = np.abs(polyval(maybemodel, x)-y) < t
if sum(alsoinliers) > len(x)*f:
bettermodel = np.polyfit(x[alsoinliers], y[alsoinliers], order)
thiserr = np.sum(np.abs(polyval(bettermodel, x[alsoinliers])-y[alsoinliers]))
if (thiserr < besterr) or (besterr == -1.0):
bestfit = bettermodel
besterr = thiserr
#Since I can't return an empty array, I had to set it as an array with a zero and the ransac_control variable will tell if the function was able or not to find a good model
if (besterr == -1.0):
ransac_control = 0
return ransac_control, bestfit
else:
ransac_control = 1
return ransac_control, bestfit
**PS:无法发送cython代码的HTML页面图像,因为这是我的第一个问题
目前没有回答
相关问题 更多 >
编程相关推荐