随机森林和人工林的最优行为提取

oae的Python项目详细描述


OAE公司

This package implements this paper in which the author tries to address the problem of interpretability and actionability of tree-based models. The author of the paper presents a novel framework to post-process any tree-based classifier to extract an optimal actionable plan that can change a given input to a desired class with a minimum cost. Currently this package only supports scikit-learn's implementation of Random Forest.

安装

pip install oae

如何使用

importnumpyasnpimportpandasaspdfromoae.coreimport*fromoae.treeimport*fromoae.optimizerimport*fromsklearn.ensembleimportRandomForestClassifierfromsklearn.model_selectionimporttrain_test_splitasttsfromsklearn.metricsimportaccuracy_score,roc_auc_scoreSEED=41np.random.seed(SEED)
^{pr2}$
data.target.value_counts(normalize=True)
2    0.655222
4    0.344778
Name: target, dtype: float64

benign表示为2和{}分别转换为0和{}。在

# convert benignalbls,lbl_map=pd.factorize(data['target'])

让我们看看特性的数据类型

data.dtypes
code_number                     int64
clump_thickness                 int64
cell_size_uniformity            int64
cell_shape_uniformity           int64
marginal_adhesion               int64
single_epithelial_cell_size     int64
bare_nuclei                    object
bland_chromatin                 int64
normal_nucleoli                 int64
mitoses                         int64
target                          int64
dtype: object
data.bare_nuclei.unique()
array(['1', '10', '2', '4', '3', '9', '7', '?', '5', '8', '6'],
      dtype=object)

让我们将这个?替换为-1,并像其他人一样将其转换成int64

data=data.assign(bare_nuclei=data.bare_nuclei.str.replace('?','-1').astype(np.int))data=data.assign(target=lbls);data.head()
.dataframe tbody tr th {
    vertical-align: top;
}

.dataframe thead th {
    text-align: right;
}
data.iloc[:,1:-1].nunique()
clump_thickness                10
cell_size_uniformity           10
cell_shape_uniformity          10
marginal_adhesion              10
single_epithelial_cell_size    10
bare_nuclei                    11
bland_chromatin                10
normal_nucleoli                10
mitoses                         9
dtype: int64

所有感兴趣的特征(不包括代码编号和目标)都是范畴变量。让我们创建一个保持集并训练一个随机森林分类器。在

SEED=41np.random.seed(SEED)features=data.columns[1:-1]Xtr,Xte,ytr,yte=tts(data.loc[:,features],data.target,test_size=.2,random_state=SEED)
clf=RandomForestClassifier(n_estimators=10,n_jobs=-1,random_state=SEED)clf.fit(Xtr,ytr)print(f'train accuracy: {accuracy_score(ytr,clf.predict(Xtr))}')print(f'holdout accuracy: {accuracy_score(yte,clf.predict(Xte))}')
train accuracy: 0.998211091234347
holdout accuracy: 0.9714285714285714

让我们从holdout set中选择一个实例并查看地面。我们意识到分类器将其标记为malignant,我们想知道哪些特征可以改变,以便分类器将其标记为benign。在

instanceidx=4print(yte.iloc[instanceidx],' ',clf.predict_proba(Xte.iloc[instanceidx:instanceidx+1]))
1   [[0. 1.]]

现在,我们将尝试通过将此问题作为整数线性规划问题来提取最优行动问题。在

atm=ATMSKLEARN(clf,data.loc[:,features].values)instance=Instance(Xte.iloc[instanceidx],['categorical']*len(features))

我们将使用下面的成本函数,以便我们的OAE问题最小化更改特征的数量,即汉明距离。

但是我们不需要把自己局限于这个特定的成本函数,你可以设计你的成本函数并将其传递给求解器。在

在这个例子中,我们的输入具有基本标签1,我们希望找出如何以最小代价调整特征,以便分类器将其分类为标签0,其中z是目标阈值。在

$F(x)=\frac{1}{w{t}\sum{k=1}^{m}h}{t,k}\phi{t,k}\geq z$,其中$h{t}\R$

$F(x)$表示随机森林分类器的概率估计。在

opt=Optimizer(cost_matrix,combine,z=0.45,class_=0)v_i_j_sol,phi_t_k_sol=opt.solve(atm,instance)

该包将有助于建议对该特性进行的更改,以将其从malignant分类为benign。在

^{pr21}$
['no change, current value: 5',
 'no change, current value: 3',
 'no change, current value: 5',
 'no change, current value: 1',
 'no change, current value: 8',
 'current value: 10, proposed change: [-1, 1]',
 'current value: 5, proposed change: [3, 4]',
 'no change, current value: 3',
 'no change, current value: 1']

我们提取了一个行动计划,它说我们需要将当前值为1010更改为-1,将6th feature更改为3,然后我们的分类器将其分类为标签0。让我们来看看。在

X_transformed=atm.transform(v_i_j_sol,instance);X_transformed
.dataframe tbody tr th {
    vertical-align: top;
}

.dataframe thead th {
    text-align: right;
}
clf.predict_proba(X_transformed)
array([[0.6, 0.4]])

实际上,我们可以看到分类器将其标记为0,而且概率也大于z=0.45,因此它也满足了这一关注点。在

应用

  • 一个例子可以是在目标市场营销中,我们可以使用每个客户产生的行动计划来更好地了解我们可以利用哪些杠杆来获得预期的结果。在

欢迎加入QQ群-->: 979659372 Python中文网_新手群

推荐PyPI第三方库


热门话题
JFrame中的Java多线程   java Servlet异常映射   java无法从输出流读取   swing Java带来的小程序GUI问题   java什么原因导致错误“'void'类型此处不允许”以及如何修复它?   Java选择器select(长)与selectNow的区别   java自定义arraylist<mygames>获得不同   java Icepdf注释让页面消失   java反向整数数组   java I在生成同步“无法解析配置的所有依赖项”时遇到此错误:app:debugRuntimeClasspath   多个虚拟机上的java线程访问单个DB实例上的表,有时会导致性能低下和异常   swing更改Java中的默认按钮,使其看起来“更好”   java慢速MQ主题订阅。并行化不能提高性能   java运行Boggle Solver需要一个多小时。我的代码怎么了?   数据库中的java循环与应用程序中的java循环   正则表达式匹配${123…456}并在Java中提取2个数字?   java如何制作我们软件的试用版   Java内存参数计算   从另一个类调用方法时出现java问题