How do I link all of PyTorch into my pybind11 .so?

Posted 2024-09-30 18:15:57


I have a pybind11 C++ project that uses the PyTorch C++ API:

#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
#include <math.h>
#include <torch/torch.h>

...

void f()
{
...
   torch::Tensor dynamic_parameters = torch::full({1}, /*value=*/0.5, torch::dtype(torch::kFloat64).requires_grad(true));
   torch::optim::SGD optimizer({dynamic_parameters}, /*lr=*/0.01);
...
}

PYBIND11_MODULE(reson8, m)
{
    m.def("my_function", &my_function, "");
}

I compile it to a .so with distutils so that it can be imported from Python:

from distutils.core import setup, Extension

def configuration(parent_package='', top_path=None):
      import numpy
      from numpy.distutils.misc_util import Configuration
      from numpy.distutils.misc_util import get_info

      #Necessary for the half-float d-type.
      info = get_info('npymath')

      config = Configuration('',
                             parent_package,
                             top_path)
      config.add_extension('reson8',
                           ['reson8.cpp'],
                           extra_info=info,
                           include_dirs=["/home/ian/anaconda3/lib/python3.7/site-packages/pybind11/include",
                                          "/home/ian/anaconda3/lib/python3.8/site-packages/pybind11/include",
                                          "/home/ian/dev/hedgey/Engine/lib/libtorch/include",
                                          "/home/ian/dev/hedgey/Engine/lib/libtorch/include/torch/csrc/api/include"])

      return config


if __name__ == "__main__":
      from numpy.distutils.core import setup
      setup(configuration=configuration)

It compiles without errors, but when I run "import reson8" in Python I get the following error:

ImportError: undefined symbol: _ZTVN5torch5optim9OptimizerE

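I'm not sure whether PyTorch is being linked into my .so at all. The .so is 10 MB, which seems quite large if PyTorch isn't included, but perhaps all pybind11 .so files are that big.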

How can I fix this?


1 answer
User
#1 · posted 2024-09-30 18:15:57

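In the end I found that I needed to use the libtorch that ships with the Anaconda-installed PyTorch instead of my own separate copy, together with Torch's C++ extension helper (CppExtension from torch.utils.cpp_extension). Here is my working setup.py: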

from distutils.core import setup, Extension
from torch.utils.cpp_extension import BuildExtension, CppExtension

def configuration(parent_package='', top_path=None):
      import numpy
      from numpy.distutils.misc_util import Configuration
      from numpy.distutils.misc_util import get_info

      #Necessary for the half-float d-type.
      info = get_info('npymath')

      config = Configuration('',
                             parent_package,
                             top_path)

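      config.ext_modules.append(CppExtension(
                name='reson8',
                sources=['reson8.cpp'],
                # extra_info is a numpy.distutils add_extension option; the
                # setuptools.Extension built by CppExtension ignores it with a warning.
                extra_info=info,
                # Must match the C++ ABI that the installed torch was built with.
                extra_compile_args=['-g', '-D_GLIBCXX_USE_CXX11_ABI=0'],
                # CppExtension forwards keyword arguments to setuptools.Extension,
                # whose linker option is extra_link_args (extra_ldflags belongs to
                # torch.utils.cpp_extension.load). Recent CppExtension versions
                # already link torch_python themselves.
                extra_link_args=['-ltorch_python'],
                include_dirs=["/home/ian/anaconda3/lib/python3.7/site-packages/pybind11/include",
                              "/home/ian/anaconda3/lib/python3.8/site-packages/pybind11/include",
                              "/home/ian/anaconda3/lib"]
                ))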

      return config


if __name__ == "__main__":
      from numpy.distutils.core import setup
      setup(configuration=configuration)
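The question's original configuration gave the compiler only include_dirs, so the headers were found and compilation succeeded, but no libtorch library was ever passed to the linker. CppExtension fixes that because, in recent PyTorch releases, it adds torch's include and library directories plus the libraries c10, torch, torch_cpu and torch_python. As a minimal sketch, assuming that library layout and a GCC/Clang toolchain on Linux, the equivalent settings can be assembled by hand; the helper name libtorch_extension_kwargs is hypothetical and the rpath flag is an optional choice, not something the answer uses.

```python
from typing import Dict, List


def libtorch_extension_kwargs(torch_lib_dir: str, cxx11_abi: bool) -> Dict[str, List[str]]:
    """Hypothetical helper: build/link settings for a C++ extension that uses libtorch.

    torch_lib_dir: directory assumed to contain libc10.so, libtorch.so,
                   libtorch_cpu.so and libtorch_python.so (older releases
                   may lack libtorch_cpu).
    cxx11_abi:     must equal the _GLIBCXX_USE_CXX11_ABI value libtorch was
                   built with, or std::string-based symbols will not match.
    """
    return {
        # Where the linker searches for the -l libraries below.
        "library_dirs": [torch_lib_dir],
        # Linking these resolves libtorch symbols such as the Optimizer vtable.
        "libraries": ["c10", "torch", "torch_cpu", "torch_python"],
        # GCC/Clang on Linux: record the directory so the loader can find it at import.
        "extra_link_args": [f"-Wl,-rpath,{torch_lib_dir}"],
        # Compile with the same libstdc++ ABI as libtorch.
        "extra_compile_args": [f"-D_GLIBCXX_USE_CXX11_ABI={int(cxx11_abi)}"],
    }


if __name__ == "__main__":
    # Illustrative path only; with torch installed one would pass
    # os.path.join(os.path.dirname(torch.__file__), "lib") and
    # torch._C._GLIBCXX_USE_CXX11_ABI instead.
    for key, value in libtorch_extension_kwargs("/path/to/site-packages/torch/lib", False).items():
        print(key, value)
```

These keys are standard Extension options, so in the question's numpy.distutils setup they could be merged into the config.add_extension call next to include_dirs (which torch.utils.cpp_extension.include_paths() can supply). Letting CppExtension compute them, as the answer does, avoids hard-coding library names that have changed between PyTorch versions.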
