
Published 2024-10-03 19:25:08



How to extract sklearn decision tree rules to pandas boolean conditions?


First, follow the scikit-learn documentation on decision tree structure to read the arrays that describe the fitted tree (clf is the trained classifier; a child index of -1 marks a leaf):

n_nodes = clf.tree_.node_count
children_left = clf.tree_.children_left
children_right = clf.tree_.children_right
feature = clf.tree_.feature
threshold = clf.tree_.threshold

We then define two functions. The first, which is recursive, finds the path from the tree's root to a given node (here, each leaf). The second turns such a path into the rule string that selects the rows reaching that node:

def find_path(node_numb, path, x):
    # Tentatively add the current node to the path
    path.append(node_numb)
    if node_numb == x:
        return True
    left = False
    right = False
    if children_left[node_numb] != -1:
        left = find_path(children_left[node_numb], path, x)
    if children_right[node_numb] != -1:
        right = find_path(children_right[node_numb], path, x)
    if left or right:
        return True
    # x is not below this node, so drop it from the path
    path.remove(node_numb)
    return False
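To illustrate what find_path computes, the sketch below recovers the same root-to-node path on a small hand-written tree, this time without recursion, by following parent links. The toy children_left/children_right lists and the helper name path_to_node are assumptions made for this example; they only mimic the layout of clf.tree_, where -1 means "no child".

```python
# Toy tree in the same array layout as clf.tree_ (hypothetical data):
# node 0 is the root, -1 means "no child", so nodes 2, 3 and 4 are leaves.
children_left = [1, 2, -1, -1, -1]
children_right = [4, 3, -1, -1, -1]


def path_to_node(target, children_left, children_right):
    """Return the list of node ids from the root (0) down to `target`."""
    # Build a child -> parent map in one pass over the arrays.
    parent = {}
    for node, (left, right) in enumerate(zip(children_left, children_right)):
        if left != -1:
            parent[left] = node
        if right != -1:
            parent[right] = node
    # Walk upwards from the target, then reverse to get root-to-node order.
    path = [target]
    while path[-1] in parent:
        path.append(parent[path[-1]])
    return path[::-1]


print(path_to_node(3, children_left, children_right))  # [0, 1, 3]
```

The parent map is built once and can be reused for every leaf, which avoids searching the whole tree again for each path.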

def get_rule(path, column_names):
    mask = ''
    for index, node in enumerate(path):
        # Skip the last node: a leaf has no split condition
        if index != len(path) - 1:
            # Did we go left (<= threshold) or right (> threshold)?
            if children_left[node] == path[index + 1]:
                mask += "(df['{}']<= {}) \t ".format(column_names[feature[node]], threshold[node])
            else:
                mask += "(df['{}']> {}) \t ".format(column_names[feature[node]], threshold[node])
    # Replace every tab except the last one with &
    mask = mask.replace("\t", "&", mask.count("\t") - 1)
    # Drop the trailing tab
    mask = mask.replace("\t", "")
    return mask
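The same rule string can also be assembled by collecting one condition per split and joining them with " & ", which avoids the tab placeholder. The following is a minimal sketch under assumptions: build_rule is a hypothetical helper that takes the tree arrays as parameters instead of reading globals, and the toy arrays are invented to mimic clf.tree_.

```python
def build_rule(path, column_names, children_left, feature, threshold):
    """Join one condition per split along `path` into a pandas-style mask string."""
    conditions = []
    # Pair each node with its successor; the last node has no split to record.
    for node, next_node in zip(path, path[1:]):
        op = "<=" if children_left[node] == next_node else ">"
        conditions.append(
            "(df['{}'] {} {})".format(column_names[feature[node]], op, threshold[node])
        )
    return " & ".join(conditions)


# Hypothetical toy tree: node 0 splits on 'insulin', node 1 on 'bp'.
children_left = [1, 2, -1, -1, -1]
feature = [0, 1, -2, -2, -2]          # -2 marks a leaf in scikit-learn's arrays
threshold = [127.5, 26.5, -2.0, -2.0, -2.0]
columns = ["insulin", "bp"]

print(build_rule([0, 1, 3], columns, children_left, feature, threshold))
# (df['insulin'] <= 127.5) & (df['bp'] > 26.5)
```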

Finally, we use these two functions: first to store the root-to-leaf path of each leaf, then to store the rule that leads to each leaf:

import numpy as np

# Leaves
leave_id = clf.apply(X_test)

paths = {}
for leaf in np.unique(leave_id):
    path_leaf = []
    find_path(0, path_leaf, leaf)
    paths[leaf] = np.unique(np.sort(path_leaf))

rules = {}
for key in paths:
    rules[key] = get_rule(paths[key], pima.columns)

With the data you gave the output is:

rules = {3: "(df['insulin']<= 127.5) & (df['bp']<= 26.450000762939453) & (df['bp']<= 9.100000381469727)  ",
4: "(df['insulin']<= 127.5) & (df['bp']<= 26.450000762939453) & (df['bp']> 9.100000381469727)  ",
6: "(df['insulin']<= 127.5) & (df['bp']> 26.450000762939453) & (df['skin']<= 27.5)  ",
7: "(df['insulin']<= 127.5) & (df['bp']> 26.450000762939453) & (df['skin']> 27.5)  ",
10: "(df['insulin']> 127.5) & (df['bp']<= 28.149999618530273) & (df['insulin']<= 145.5)  ",
11: "(df['insulin']> 127.5) & (df['bp']<= 28.149999618530273) & (df['insulin']> 145.5)  ",
13: "(df['insulin']> 127.5) & (df['bp']> 28.149999618530273) & (df['insulin']<= 158.5)  ",
14: "(df['insulin']> 127.5) & (df['bp']> 28.149999618530273) & (df['insulin']> 158.5)  "}

Since the rules are strings, you can't index with them directly as df[rules[3]]. Evaluate the string first with the eval function, like so: df[eval(rules[3])]
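As a small illustration of that last step, the sketch below evaluates one rule string against a made-up DataFrame. The data values are hypothetical, and the rule is a shortened version of the strings shown above. Passing {"df": df} to eval is a choice made for this example so that the name df inside the string resolves explicitly.

```python
import pandas as pd

# Hypothetical data; the column names match those used in the rules above.
df = pd.DataFrame({"insulin": [100.0, 130.0, 90.0], "bp": [30.0, 20.0, 10.0]})

rule = "(df['insulin']<= 127.5) & (df['bp']> 26.450000762939453)"

# eval turns the string into a boolean Series, one value per row.
mask = eval(rule, {"df": df})

# Index the frame with the mask to keep only the rows that satisfy the rule.
subset = df[mask]
print(subset)
```

Note that eval runs whatever Python the string contains, so it should only be used on rule strings you generated yourself.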



Reply 1 · posted 2024-10-03 19:25:08


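node_id = 3

def datatree_path_summarystats(node_id):
    # Find a leaf whose root-to-leaf path passes through node_id
    for leaf, path in paths.items():
        if node_id in path:
            ruleskey, leaf_path = leaf, path

    # Node ids increase along a path, so the ids smaller than node_id
    # are its ancestors: one rule condition per ancestor
    numberofsteps = sum(x < node_id for x in leaf_path)

    # Keep only the conditions that lead down to node_id
    stringsubset = rules[ruleskey]
    datasubset = "&".join(stringsubset.split('&')[:numberofsteps])
    return datasubset

datasubset = datatree_path_summarystats(node_id)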


