Inheriting in Cython from a git submodule to override a scikit-learn method

Posted 2024-10-03 17:24:34


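I am using scikit-learn and want to override the build method of the TreeBuilder class used by its regression trees, which is implemented in Cython. To do this, I believe I need access to the Cython source, so I added scikit-learn as a git submodule.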

My project structure therefore looks like this:

.
|-- setup.py
|-- MyNewTree
|   |-- __init__.py
|   |-- MyNewTree.pyx
|   `-- scikitlearn
|       `-- sklearn
|           `-- tree
|               |-- __init__.py
|               |-- _tree.pxd
|               |-- _tree.pyx
|               |-- setup.py
|               `-- tree.py

In my setup.py I am doing the following:

from setuptools import setup, find_packages
from setuptools.extension import Extension
from Cython.Build import cythonize
import numpy

extensions = [
    Extension(
        "newtree.MyNewTree",
        ["newtree/MyNewTree.pyx"],
        include_dirs=['modulenetwork/scikitlearn/sklearn/tree', numpy.get_include()]
    )
]

setup(
    name = 'MyNewTree',
    version = '0.0.1',
    packages = find_packages(),
    ext_modules = cythonize(extensions)
)

And finally, the MyNewTree.pyx file:

# cython: cdivision=True
# cython: boundscheck=False
# cython: wraparound=False

import numpy as np
cimport numpy as np
np.import_array()

from .scikitlearn.sklearn.tree._tree cimport BestFirstTreeBuilder

cdef class TreeBuilder(BestFirstTreeBuilder):
    cpdef build(self):
        print('This is an overridden build method!')

What I expect this to produce is a TreeBuilder class whose build method differs from the original scikit-learn implementation, while everything else stays the same.
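In plain Python terms, the intent is roughly the following sketch. BaseBuilder is a hypothetical stand-in written only for illustration; it is not the real scikit-learn class, and its constructor argument and describe method are assumptions.

```python
class BaseBuilder:
    """Hypothetical stand-in for the library's builder (not the real class)."""

    def __init__(self, max_depth):
        self.max_depth = max_depth

    def build(self):
        # Original behaviour that the subclass wants to replace.
        return "original build"

    def describe(self):
        # Behaviour that should be inherited unchanged.
        return f"builder with max_depth={self.max_depth}"


class TreeBuilder(BaseBuilder):
    def build(self):
        # Only this method is overridden; everything else is inherited.
        return "overridden build"
```

The Cython version differs in that the base class must be visible at compile time through a .pxd declaration, which is where the build fails.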

To compile, I run python setup.py build_ext --inplace

However, I get the following errors:

Error compiling Cython file:
------------------------------------------------------------
...

import numpy as np
cimport numpy as np
np.import_array()

from .scikitlearn.sklearn.tree._tree cimport BestFirstTreeBuilder
^
------------------------------------------------------------

newtree/MyNewTree.pyx:9:0: 'newtree/scikitlearn/sklearn/tree/_tree.pxd' not found

Error compiling Cython file:
------------------------------------------------------------
...

import numpy as np
cimport numpy as np
np.import_array()

from .scikitlearn.sklearn.tree._tree cimport BestFirstTreeBuilder
^
------------------------------------------------------------

newtree/MyNewTree.pyx:9:0: 'newtree/scikitlearn/sklearn/tree/_tree/BestFirstTreeBuilder.pxd' not found

Error compiling Cython file:
------------------------------------------------------------
...
cimport numpy as np
np.import_array()

from .scikitlearn.sklearn.tree._tree cimport BestFirstTreeBuilder

cdef class TreeBuilder(BestFirstTreeBuilder):
    ^
------------------------------------------------------------

newtree/MyNewTree.pyx:11:5: 'BestFirstTreeBuilder' is not a type name
Traceback (most recent call last):
  File "setup.py", line 18, in <module>
    ext_modules = cythonize(extensions)
  File "/Users/__/miniconda3/lib/python3.6/site-packages/Cython/Build/Dependencies.py", line 1039, in cythonize
    cythonize_one(*args)
  File "/Users/__/miniconda3/lib/python3.6/site-packages/Cython/Build/Dependencies.py", line 1161, in cythonize_one
    raise CompileError(None, pyx_file)
Cython.Compiler.Errors.CompileError: newtree/MyNewTree.pyx

The files reported as not found clearly do exist. Is something wrong with my setup script? How do I correctly import the class into my code?
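One thing that may be worth checking: the error names a path under newtree/, while the layout listed earlier uses MyNewTree/, and that listing shows no __init__.py in scikitlearn/ or sklearn/. The sketch below rests on an assumption that is not confirmed here, namely that the cimport lookup only descends through directories it treats as packages (those holding an __init__.py, __init__.pyx or __init__.pxd). The helpers missing_package_markers and pxd_exists are hypothetical names for illustration, not part of Cython or scikit-learn.

```python
import os

# Files assumed to mark a directory as a package for this check.
PACKAGE_MARKERS = ("__init__.py", "__init__.pyx", "__init__.pxd")


def missing_package_markers(root, dotted_module):
    """Return the directories on the dotted path that hold no package marker.

    root is the directory the dotted name is resolved against;
    dotted_module is e.g. "scikitlearn.sklearn.tree._tree".
    """
    missing = []
    current = root
    # The last component is the module file itself, not a directory.
    for part in dotted_module.split(".")[:-1]:
        current = os.path.join(current, part)
        has_marker = any(
            os.path.isfile(os.path.join(current, marker))
            for marker in PACKAGE_MARKERS
        )
        if not has_marker:
            missing.append(current)
    return missing


def pxd_exists(root, dotted_module):
    """Check whether the .pxd file for the dotted module exists under root."""
    path = os.path.join(root, *dotted_module.split(".")) + ".pxd"
    return os.path.isfile(path)


if __name__ == "__main__":
    # Directory and module names taken from the question; adjust as needed.
    module = "scikitlearn.sklearn.tree._tree"
    print("pxd present:", pxd_exists("MyNewTree", module))
    for directory in missing_package_markers("MyNewTree", module):
        print("no package marker in:", directory)
```

If the .pxd is present but some directories are listed as lacking a marker, that would be consistent with a file that exists on disk yet is reported as not found. It does not rule out other causes, such as the class not being declared in _tree.pxd at all.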

