代码如下:
# coding=utf-8 from __future__ import print_function
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix
import sklearn
import sklearn.datasets
import sklearn.ensemble
import numpy as np
import lime
import lime.lime_tabular
class Randomforest:
    """Thin wrapper around scikit-learn's ``RandomForestClassifier``.

    It keeps the train/test split produced by ``split_dataset`` and the fitted
    classifier produced by ``random_forest_classifier``. It also exposes a
    ``predict_proba`` method, so an instance can be passed wherever a
    probability function is expected.
    """

    def __init__(self):
        # Fitted classifier; stays None until random_forest_classifier() runs.
        self.clf = None
        # Same object as self.clf once trained; kept for callers that read
        # this attribute name.
        self.trained_model = None
        # Filled in by split_dataset(). They are initialised here so reading
        # them before a split yields None instead of an AttributeError.
        self.train_x = None
        self.test_x = None
        self.train_y = None
        self.test_y = None

    def split_dataset(self, dataset, train_percentage, feature_headers, target_header,
                      random_state=None):
        """Split *dataset* into train and test parts.

        Args:
            dataset: table indexed by column labels (the script below passes
                the DataFrame returned by ``pd.read_csv``).
            train_percentage: forwarded to ``train_test_split`` as
                ``train_size`` (e.g. 0.7 keeps 70% of the rows for training).
            feature_headers: list of column labels used as features.
            target_header: column label of the target.
            random_state: optional seed forwarded to ``train_test_split``.
                The default ``None`` keeps the original unseeded behaviour.

        Returns:
            Tuple ``(train_x, test_x, train_y, test_y)``. The same values are
            also stored on the instance.
        """
        (self.train_x, self.test_x,
         self.train_y, self.test_y) = train_test_split(
            dataset[feature_headers],
            dataset[target_header],
            train_size=train_percentage,
            random_state=random_state,
        )
        return self.train_x, self.test_x, self.train_y, self.test_y

    def random_forest_classifier(self, features, target, random_state=None):
        """Fit a new ``RandomForestClassifier`` on *features* / *target*.

        Args:
            features: training feature rows.
            target: training labels aligned with *features*.
            random_state: optional seed for the forest. The default ``None``
                keeps the original unseeded behaviour.

        Returns:
            The fitted classifier. It is also stored as ``self.clf`` and
            ``self.trained_model``.
        """
        self.clf = RandomForestClassifier(random_state=random_state)
        self.clf.fit(features, target)
        self.trained_model = self.clf
        return self.clf

    def predictProba(self, input):
        """Return class probabilities from the fitted classifier.

        Args:
            input: feature rows in the same column order used for training.

        Returns:
            Whatever ``self.clf.predict_proba`` returns. Columns follow
            ``self.clf.classes_``.

        Raises:
            AttributeError: if no model has been trained yet. This is the same
                exception type the unguarded call raised, now with a clear
                message.
        """
        if self.clf is None:
            raise AttributeError(
                "model has not been trained yet; call random_forest_classifier() first"
            )
        return self.clf.predict_proba(input)

    def predict_proba(self, input):
        """scikit-learn-style alias for :meth:`predictProba`."""
        return self.predictProba(input)
# Train a random forest on a 70/30 split of the CSV and build a LIME
# explainer for it.
rf = Randomforest()

# Column layout of the CSV. The first column ("vectorName") and the last
# column ("result") are excluded from the features; "result" is the target.
Headers = [
    "vectorName", "abmessungen_Lange", "starrflugler", "tragflachen", "triebwerke", "rumpf",
    "leitwerk", "drehflugler", "drehflugler_Rumpf_Cockpit", "doppeldecker",
    "tragflachen_Stellung_Gerade", "hochDecker", "triebwerke_triebwerksart",
    "rumpf_Rumpfformen", "drehflugler_Rotor", "drehflugler_Triebwerk", "drehflugler_Rumpf",
    "drehflugler_Heckausleger", "drehflugler_Triebwerk_Lufteinlass",
    "drehflugler_Triebwerk_Luftauslass", "result",
]
feature_names = Headers[1:-1]

dataset = pd.read_csv("filename.csv")
train_x, test_x, train_y, test_y = rf.split_dataset(dataset, 0.7, feature_names, Headers[-1])
trained_model = rf.random_forest_classifier(train_x, train_y)
predictions = trained_model.predict(test_x)

# LIME matches class_names positionally to the columns of predict_proba.
# scikit-learn orders those columns by trained_model.classes_ (sorted), so
# derive the names from it. A hard-coded ['1', '0'] swaps the labels for a
# 0/1 target.
class_names = [str(label) for label in trained_model.classes_]

# LimeTabularExplainer documents training_data as a 2-D NumPy array; passing
# the pandas DataFrame itself fails, so hand over the underlying array.
# NOTE: rows later given to explainer.explain_instance should likewise be
# NumPy rows (e.g. test_x.values[i]), not DataFrame/Series objects.
explainer = lime.lime_tabular.LimeTabularExplainer(
    train_x.values,
    feature_names=feature_names,
    class_names=class_names,
    categorical_features=None,
    categorical_names=None,
    discretize_continuous=True,
    kernel_width=3,
)
报错如下：
[原始报错信息在页面抓取时丢失]
我在运行上面代码中的 lime 解释器（LimeTabularExplainer）时遇到了这个错误。我不确定这是 pandas 包的问题还是 lime 包的问题。我尝试了很多调试方法，比如把数据帧（DataFrame）而不是列表传给 lime 解释器函数等等，但都没有效果。如果有人能尽早指出问题所在就太好了。谢谢。
我想问题在于你传入的是一个 pandas 数据帧（DataFrame）——我也遇到过同样的问题。
最简单的解决办法是把传给解释器的 train_x 换成它底层的 NumPy 数组，即使用：
train_x.values
希望这能帮到你。
干杯
相关问题 更多 >
编程相关推荐