ValueError: mat must be of type numpy.ndarray when using XGBoost

Posted 2024-06-01 09:58:39

I built a machine learning model with XGBoost, and it works well. I want to use the Treelite library (https://treelite.readthedocs.io/en/latest/tutorials/first.html) to convert this model to C code.

I wrote my code following the documentation linked above.

When I pass the XGBoost dfeatures object to the treelite runtime module, I get this error:

Traceback (most recent call last):
  File "treelite_model.py", line 176, in <module>
    predict(data)
  File "treelite_model.py", line 157, in predict
    batch = treelite.runtime.Batch.from_npy2d(dfeatures)
  File "/Users/karim/Documents/vnev/p36_avidhrt/treelite/python/treelite/runtime/../../../runtime/native/python/treelite_runtime/predictor.py", line 117, in from_npy2d
    raise ValueError('mat must be of type numpy.ndarray')
ValueError: mat must be of type numpy.ndarray
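
The last frame of the traceback suggests that from_npy2d rejects anything that is not a numpy.ndarray. The sketch below reproduces that kind of type check in isolation. FakeDMatrix and check_mat are hypothetical names used only for illustration. FakeDMatrix stands in for any wrapper object, such as an xgb.DMatrix, that holds data without being an array itself. None of this is the Treelite source code.

```python
import numpy as np


class FakeDMatrix:
    """Hypothetical stand-in for a wrapper such as xgb.DMatrix.

    It stores an array but is not itself a numpy.ndarray.
    """

    def __init__(self, data):
        self.data = data


def check_mat(mat):
    """Hypothetical check using the same message as the traceback."""
    if not isinstance(mat, np.ndarray):
        raise ValueError('mat must be of type numpy.ndarray')
    return mat


features = np.zeros((1, 5))

# A plain ndarray passes the check unchanged.
checked = check_mat(features)

# A wrapper around the same data fails, even though it contains an ndarray.
try:
    check_mat(FakeDMatrix(features))
    message = None
except ValueError as err:
    message = str(err)

print(message)
```

If this reading is right, the error is about the type of the object handed to from_npy2d, not about the values inside it.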

Here is the code snippet:

def predict():
    data = genfromtxt('AFIB.csv', delimiter=',')
    features_noise = np.zeros((5, ))

    snr, rr_num, var, fr, fr2 = find_noise_features(data)
    features_noise[0] = snr
    features_noise[1] = rr_num
    features_noise[2] = var
    features_noise[3] = fr
    features_noise[4] = fr2
    features = extract_basic_features(data, 30000)
    features = np.hstack((features, features_noise.reshape(1, -1)))

    bst = xgb.Booster({'nthread': 4})
    bst.load_model("xgb_model.bin")
    dfeatures = xgb.DMatrix(features)
    prediction = bst.predict(dfeatures,ntree_limit=420)
    prediction = prediction.astype('int8')
    result = data_preprocess.encoder.inverse_transform(prediction)
    # print(prediction)
    # print(result)
    model = treelite.Model.from_xgboost(bst)
    toolchain = 'gcc'
    # model.export_lib(toolchain=toolchain, libpath='./afibmodel.dylib', verbose=True)
    model.export_lib(toolchain=toolchain, libpath='./afibmodel.dylib',params={'parallel_comp': 32}, verbose=True)
    predictor = treelite.runtime.Predictor('./afibmodel.dylib', verbose=True)

    # batch = treelite.runtime.Batch.from_csr(dfeatures)
    batch = treelite.runtime.Batch.from_npy2d(dfeatures)

    out_pred = predictor.predict(batch)
    print(out_pred)
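
In the snippet above, dfeatures is an xgb.DMatrix, while features is still the NumPy array produced by np.hstack. One likely fix is therefore to keep the DMatrix for bst.predict and pass the array itself to from_npy2d. The sketch below only prepares such an array. prepare_for_treelite is a hypothetical helper, the float32 cast is an assumption rather than a documented requirement, and the feature values are placeholders.

```python
import numpy as np


def prepare_for_treelite(features):
    """Hypothetical helper: return a 2-D ndarray with one row per sample."""
    # float32 is an assumption made for this sketch, not a documented need.
    arr = np.asarray(features, dtype=np.float32)
    if arr.ndim == 1:
        # Treat a flat vector as a single sample.
        arr = arr.reshape(1, -1)
    if arr.ndim != 2:
        raise ValueError('expected a 1-D or 2-D array of features')
    return arr


# Placeholder values standing in for the outputs of extract_basic_features
# and find_noise_features in the snippet above.
basic = np.ones((1, 3))
features_noise = np.zeros((5,))
features = np.hstack((basic, features_noise.reshape(1, -1)))

dense = prepare_for_treelite(features)
print(type(dense).__name__, dense.shape)

# With xgboost and treelite installed, each library would then receive its
# own input type (both calls are taken from the snippet above):
#   dfeatures = xgb.DMatrix(features)                   # for bst.predict
#   batch = treelite.runtime.Batch.from_npy2d(dense)    # for predictor.predict
```

The commented-out line that tries Batch.from_csr(dfeatures) would presumably hit the same kind of problem, since a DMatrix is not a SciPy sparse matrix either.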
