Exporting a Python scikit-learn model to PMML

Posted 2024-05-19 17:03:52


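I want to export a Python scikit-learn model to PMML.

Which Python package is best suited for this?

I read about Augustus, but I could not find any examples that use scikit-learn models.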


Tags: model, example, article, scikit, pmml, augustus
3 answers

You can try Nyoka. It exports scikit-learn models, and a few other model types as well.

Nyoka is a Python library that supports Scikit-learn, XGBoost, LightGBM, Keras and Statsmodels.

Besides about 500 Python classes which each cover a PMML tag and all constructor parameters/attributes as defined in the standard, Nyoka also provides an increasing number of convenience classes and functions that make the Data Scientist’s life easier, for example by reading or writing any PMML file in one line of code from within your favorite Python environment.

It can be installed from PyPI with:

pip install nyoka

Example code

Example 1

import pandas as pd
from sklearn import datasets
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.impute import SimpleImputer  # sklearn.preprocessing.Imputer was removed in scikit-learn 0.22
from sklearn_pandas import DataFrameMapper  # from the sklearn-pandas package
from sklearn.ensemble import RandomForestClassifier

# Load iris into a DataFrame with a 'Species' target column
iris = datasets.load_iris()
irisd = pd.DataFrame(iris.data, columns=iris.feature_names)
irisd['Species'] = iris.target

features = irisd.columns.drop('Species')
target = 'Species'

# Preprocessing (scale the sepal columns, impute the petal columns) followed by a random forest
pipeline_obj = Pipeline([
    ("mapping", DataFrameMapper([
        (['sepal length (cm)', 'sepal width (cm)'], StandardScaler()),
        (['petal length (cm)', 'petal width (cm)'], SimpleImputer())
    ])),
    ("rfc", RandomForestClassifier(n_estimators=100))
])

pipeline_obj.fit(irisd[features], irisd[target])

# Export the fitted pipeline (preprocessing + model) to a PMML file
from nyoka import skl_to_pmml

skl_to_pmml(pipeline_obj, features, target, "rf_pmml.pmml")
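Before exporting a pipeline like this, it can help to save the scikit-learn predictions so that the PMML file can later be scored with a PMML evaluator and compared against them. The sketch below makes a few assumptions. It replaces `sklearn_pandas.DataFrameMapper` with scikit-learn's own `ColumnTransformer`, using the same column groups, so it runs without extra packages. It also adds `random_state=0` for repeatability, and the output file name `reference_predictions.csv` is hypothetical.

```python
import pandas as pd
from sklearn import datasets
from sklearn.compose import ColumnTransformer
from sklearn.ensemble import RandomForestClassifier
from sklearn.impute import SimpleImputer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

iris = datasets.load_iris()
irisd = pd.DataFrame(iris.data, columns=iris.feature_names)
irisd["Species"] = iris.target
features = list(irisd.columns.drop("Species"))

# ColumnTransformer stands in for DataFrameMapper (same column groups as Example 1)
pipeline_obj = Pipeline([
    ("mapping", ColumnTransformer([
        ("scale", StandardScaler(), ["sepal length (cm)", "sepal width (cm)"]),
        ("impute", SimpleImputer(), ["petal length (cm)", "petal width (cm)"]),
    ])),
    ("rfc", RandomForestClassifier(n_estimators=100, random_state=0)),
])
pipeline_obj.fit(irisd[features], irisd["Species"])

# Reference scores: inputs, one probability column per class, and the predicted label
proba = pipeline_obj.predict_proba(irisd[features])
reference = irisd[features].copy()
for i, cls in enumerate(pipeline_obj.classes_):
    reference[f"probability_{iris.target_names[cls]}"] = proba[:, i]
reference["predicted_Species"] = pipeline_obj.predict(irisd[features])

# Hypothetical file name; a PMML engine's output can be diffed against it later
reference.to_csv("reference_predictions.csv", index=False)
```

The probability columns follow the `probability_<class>` naming used by the OutputFields in the PMML file shown further down. That naming is a convenience for comparison, not something either library requires.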

Example 2

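from keras import applications
from keras.layers import Flatten, Dense
from keras.models import Model

# MobileNet backbone pre-trained on ImageNet, without its classification head
model = applications.MobileNet(weights='imagenet', include_top=False, input_shape=(224, 224, 3))

# New head: flatten, a 1024-unit dense layer, then one output per class
activType = 'sigmoid'
x = model.output
x = Flatten()(x)
x = Dense(1024, activation="relu")(x)
predictions = Dense(2, activation=activType)(x)
model_final = Model(inputs=model.input, outputs=predictions, name='predictions')

# Convert to PMML; predictedClasses gives the labels of the two outputs
from nyoka import KerasToPmml
cnn_pmml = KerasToPmml(model_final, dataSet='image', predictedClasses=['cats', 'dogs'])

# Write the PMML file (second argument: indentation level); the with-block closes the file
with open('2classMBNet.pmml', "w") as f:
    cnn_pmml.export(f, 0)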

More examples can be found on Nyoka's GitHub page.

SkLearn2PMML is a thin wrapper around the JPMML-SkLearn command-line application. For a list of supported Scikit-Learn Estimator and Transformer types, please refer to the documentation of the JPMML-SkLearn project.

As @user1808924 pointed out, it supports Python 2.7 or 3.4+. It also requires Java 1.7+.

Installation (requires git):

pip install git+https://github.com/jpmml/sklearn2pmml.git

Here is an example of exporting a classification tree to PMML. First, grow the tree:

# example tree & viz from http://scikit-learn.org/stable/modules/tree.html
from sklearn import datasets, tree
iris = datasets.load_iris()
clf = tree.DecisionTreeClassifier() 
clf = clf.fit(iris.data, iris.target)

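The SkLearn2PMML conversion has two parts: an estimator (our clf) and a mapper (for preprocessing steps such as discretization or PCA). Our mapper is very basic, because we do not apply any transformations.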

from sklearn_pandas import DataFrameMapper
default_mapper = DataFrameMapper([(i, None) for i in iris.feature_names + ['Species']])

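from sklearn2pmml import sklearn2pmml
# Older sklearn2pmml call signature; recent releases expect a fitted
# sklearn2pmml.pipeline.PMMLPipeline instead of separate estimator/mapper arguments
sklearn2pmml(estimator=clf,
             mapper=default_mapper,
             pmml="D:/workspace/IrisClassificationTree.pmml")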

It is possible (though undocumented) to pass mapper=None, but the predictor names are then lost (you get x1 instead of sepal length, and so on).

Let's look at the .pmml file:

<?xml version="1.0" encoding="UTF-8" standalone="yes"?>
<PMML xmlns="http://www.dmg.org/PMML-4_3" version="4.3">
    <Header>
        <Application name="JPMML-SkLearn" version="1.1.1"/>
        <Timestamp>2016-09-26T19:21:43Z</Timestamp>
    </Header>
    <DataDictionary>
        <DataField name="sepal length (cm)" optype="continuous" dataType="float"/>
        <DataField name="sepal width (cm)" optype="continuous" dataType="float"/>
        <DataField name="petal length (cm)" optype="continuous" dataType="float"/>
        <DataField name="petal width (cm)" optype="continuous" dataType="float"/>
        <DataField name="Species" optype="categorical" dataType="string">
            <Value value="setosa"/>
            <Value value="versicolor"/>
            <Value value="virginica"/>
        </DataField>
    </DataDictionary>
    <TreeModel functionName="classification" splitCharacteristic="binarySplit">
        <MiningSchema>
            <MiningField name="Species" usageType="target"/>
            <MiningField name="sepal length (cm)"/>
            <MiningField name="sepal width (cm)"/>
            <MiningField name="petal length (cm)"/>
            <MiningField name="petal width (cm)"/>
        </MiningSchema>
        <Output>
            <OutputField name="probability_setosa" dataType="double" feature="probability" value="setosa"/>
            <OutputField name="probability_versicolor" dataType="double" feature="probability" value="versicolor"/>
            <OutputField name="probability_virginica" dataType="double" feature="probability" value="virginica"/>
        </Output>
        <Node id="1">
            <True/>
            <Node id="2" score="setosa" recordCount="50.0">
                <SimplePredicate field="petal width (cm)" operator="lessOrEqual" value="0.8"/>
                <ScoreDistribution value="setosa" recordCount="50.0"/>
                <ScoreDistribution value="versicolor" recordCount="0.0"/>
                <ScoreDistribution value="virginica" recordCount="0.0"/>
            </Node>
            <Node id="3">
                <SimplePredicate field="petal width (cm)" operator="greaterThan" value="0.8"/>
                <Node id="4">
                    <SimplePredicate field="petal width (cm)" operator="lessOrEqual" value="1.75"/>
                    <Node id="5">
                        <SimplePredicate field="petal length (cm)" operator="lessOrEqual" value="4.95"/>
                        <Node id="6" score="versicolor" recordCount="47.0">
                            <SimplePredicate field="petal width (cm)" operator="lessOrEqual" value="1.6500001"/>
                            <ScoreDistribution value="setosa" recordCount="0.0"/>
                            <ScoreDistribution value="versicolor" recordCount="47.0"/>
                            <ScoreDistribution value="virginica" recordCount="0.0"/>
                        </Node>
                        <Node id="7" score="virginica" recordCount="1.0">
                            <SimplePredicate field="petal width (cm)" operator="greaterThan" value="1.6500001"/>
                            <ScoreDistribution value="setosa" recordCount="0.0"/>
                            <ScoreDistribution value="versicolor" recordCount="0.0"/>
                            <ScoreDistribution value="virginica" recordCount="1.0"/>
                        </Node>
                    </Node>
                    <Node id="8">
                        <SimplePredicate field="petal length (cm)" operator="greaterThan" value="4.95"/>
                        <Node id="9" score="virginica" recordCount="3.0">
                            <SimplePredicate field="petal width (cm)" operator="lessOrEqual" value="1.55"/>
                            <ScoreDistribution value="setosa" recordCount="0.0"/>
                            <ScoreDistribution value="versicolor" recordCount="0.0"/>
                            <ScoreDistribution value="virginica" recordCount="3.0"/>
                        </Node>
                        <Node id="10">
                            <SimplePredicate field="petal width (cm)" operator="greaterThan" value="1.55"/>
                            <Node id="11" score="versicolor" recordCount="2.0">
                                <SimplePredicate field="sepal length (cm)" operator="lessOrEqual" value="6.95"/>
                                <ScoreDistribution value="setosa" recordCount="0.0"/>
                                <ScoreDistribution value="versicolor" recordCount="2.0"/>
                                <ScoreDistribution value="virginica" recordCount="0.0"/>
                            </Node>
                            <Node id="12" score="virginica" recordCount="1.0">
                                <SimplePredicate field="sepal length (cm)" operator="greaterThan" value="6.95"/>
                                <ScoreDistribution value="setosa" recordCount="0.0"/>
                                <ScoreDistribution value="versicolor" recordCount="0.0"/>
                                <ScoreDistribution value="virginica" recordCount="1.0"/>
                            </Node>
                        </Node>
                    </Node>
                </Node>
                <Node id="13">
                    <SimplePredicate field="petal width (cm)" operator="greaterThan" value="1.75"/>
                    <Node id="14">
                        <SimplePredicate field="petal length (cm)" operator="lessOrEqual" value="4.8500004"/>
                        <Node id="15" score="virginica" recordCount="2.0">
                            <SimplePredicate field="sepal width (cm)" operator="lessOrEqual" value="3.1"/>
                            <ScoreDistribution value="setosa" recordCount="0.0"/>
                            <ScoreDistribution value="versicolor" recordCount="0.0"/>
                            <ScoreDistribution value="virginica" recordCount="2.0"/>
                        </Node>
                        <Node id="16" score="versicolor" recordCount="1.0">
                            <SimplePredicate field="sepal width (cm)" operator="greaterThan" value="3.1"/>
                            <ScoreDistribution value="setosa" recordCount="0.0"/>
                            <ScoreDistribution value="versicolor" recordCount="1.0"/>
                            <ScoreDistribution value="virginica" recordCount="0.0"/>
                        </Node>
                    </Node>
                    <Node id="17" score="virginica" recordCount="43.0">
                        <SimplePredicate field="petal length (cm)" operator="greaterThan" value="4.8500004"/>
                        <ScoreDistribution value="setosa" recordCount="0.0"/>
                        <ScoreDistribution value="versicolor" recordCount="0.0"/>
                        <ScoreDistribution value="virginica" recordCount="43.0"/>
                    </Node>
                </Node>
            </Node>
        </Node>
    </TreeModel>
</PMML>
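Because PMML is plain XML, you can inspect the file with Python's standard library alone, without a PMML engine. In this minimal sketch, the PMML text is a trimmed copy of the output above (DataDictionary, MiningSchema and the first split only), inlined so the snippet runs on its own. With the real file you would use `ET.parse("D:/workspace/IrisClassificationTree.pmml").getroot()` instead of `ET.fromstring`.

```python
import xml.etree.ElementTree as ET

# Trimmed copy of the PMML above
PMML_TEXT = """<PMML xmlns="http://www.dmg.org/PMML-4_3" version="4.3">
    <DataDictionary>
        <DataField name="sepal length (cm)" optype="continuous" dataType="float"/>
        <DataField name="sepal width (cm)" optype="continuous" dataType="float"/>
        <DataField name="petal length (cm)" optype="continuous" dataType="float"/>
        <DataField name="petal width (cm)" optype="continuous" dataType="float"/>
        <DataField name="Species" optype="categorical" dataType="string">
            <Value value="setosa"/>
            <Value value="versicolor"/>
            <Value value="virginica"/>
        </DataField>
    </DataDictionary>
    <TreeModel functionName="classification" splitCharacteristic="binarySplit">
        <MiningSchema>
            <MiningField name="Species" usageType="target"/>
            <MiningField name="sepal length (cm)"/>
            <MiningField name="sepal width (cm)"/>
            <MiningField name="petal length (cm)"/>
            <MiningField name="petal width (cm)"/>
        </MiningSchema>
        <Node id="1">
            <True/>
            <Node id="2" score="setosa" recordCount="50.0">
                <SimplePredicate field="petal width (cm)" operator="lessOrEqual" value="0.8"/>
            </Node>
        </Node>
    </TreeModel>
</PMML>"""

NS = {"pmml": "http://www.dmg.org/PMML-4_3"}  # PMML 4.3 namespace from the file header
root = ET.fromstring(PMML_TEXT)

# DataDictionary: every field with its operational type
optypes = {f.get("name"): f.get("optype")
           for f in root.findall("pmml:DataDictionary/pmml:DataField", NS)}

# Class labels declared for the categorical target
classes = [v.get("value") for v in root.findall(
    "pmml:DataDictionary/pmml:DataField[@name='Species']/pmml:Value", NS)]

# MiningSchema: usageType defaults to "active" when the attribute is absent
mining = root.findall("pmml:TreeModel/pmml:MiningSchema/pmml:MiningField", NS)
targets = [m.get("name") for m in mining if m.get("usageType") == "target"]
inputs = [m.get("name") for m in mining if m.get("usageType", "active") == "active"]

# First split below the root node (id="1")
first = root.find("pmml:TreeModel/pmml:Node/pmml:Node/pmml:SimplePredicate", NS)
print(targets, classes)
print(first.get("field"), first.get("operator"), first.get("value"))
```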

The first split (node 1) is on petal width at 0.8. Node 2 (petal width <= 0.8) captures all the setosas and nothing else.
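To cross-check this against the fitted scikit-learn object, you can read the root split directly from the `clf.tree_` arrays. This is a minimal sketch; `random_state=0` is an addition for repeatability. Petal length (cut at about 2.45) and petal width (cut at 0.8) separate setosa equally well, so scikit-learn may pick either one depending on its random feature ordering. In both cases the left child should hold exactly the 50 setosa samples.

```python
from sklearn import datasets, tree

iris = datasets.load_iris()
clf = tree.DecisionTreeClassifier(random_state=0).fit(iris.data, iris.target)

t = clf.tree_
root_feature = iris.feature_names[t.feature[0]]  # node 0 corresponds to node id="1" in the PMML
root_threshold = t.threshold[0]
left = t.children_left[0]                         # samples with feature <= threshold

# Majority class of the left child (argmax works whether tree_.value holds counts or fractions)
left_class = iris.target_names[clf.classes_[t.value[left].argmax()]]

print(f"root split: {root_feature} <= {root_threshold:.2f}")
print(f"left child: {t.n_node_samples[left]} samples, gini={t.impurity[left]:.1f}, class={left_class}")
```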

You can compare the PMML output with the graphviz output:

from io import StringIO  # sklearn.externals.six was removed in scikit-learn 0.23
import pydotplus # this might be pydot for python 2.7

# Render the fitted tree (clf from above) as DOT text, then convert it to a PDF
dot_data = StringIO()
tree.export_graphviz(clf,
                     out_file=dot_data,
                     feature_names=iris.feature_names,
                     class_names=iris.target_names,
                     filled=True, rounded=True,
                     special_characters=True)
graph = pydotplus.graph_from_dot_data(dot_data.getvalue())
graph.write_pdf("D:/workspace/iris.pdf")
# for in-line display, you can also do:
# from IPython.display import Image
# Image(graph.create_png())
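If graphviz and pydotplus are not installed, scikit-learn 0.21 and later can print the same tree as indented text with `sklearn.tree.export_text`. The text form is easy to read side by side with the nested Node elements of the PMML file. This minimal sketch refits the tree so it runs on its own; `random_state=0` is an addition, so the chosen splits may differ slightly from the PMML shown above.

```python
from sklearn import datasets, tree

iris = datasets.load_iris()
clf = tree.DecisionTreeClassifier(random_state=0).fit(iris.data, iris.target)

# One line per split condition; leaves show the predicted class label (0, 1 or 2)
report = tree.export_text(clf, feature_names=list(iris.feature_names))
print(report)
```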

