此问题类似于(可能是简单的扩展)此处链接的问题:
How to extract sklearn decision tree rules to pandas boolean conditions?
以上环节的解决方案综合如下:
First of all let's use the scikit documentation on decision tree structure to get information about the tree that was constructed:
n_nodes = clf.tree_.node_count children_left = clf.tree_.children_left children_right = clf.tree_.children_right feature = clf.tree_.feature threshold = clf.tree_.threshold
We then define two recursive functions. The first one will find the path from the tree's root to create a specific node (all the leaves in our case). The second one will write the specific rules used to create a node using its creation path:
def find_path(node_numb, path, x): path.append(node_numb) if node_numb == x: return True left = False right = False if (children_left[node_numb] !=-1): left = find_path(children_left[node_numb], path, x) if (children_right[node_numb] !=-1): right = find_path(children_right[node_numb], path, x) if left or right : return True path.remove(node_numb) return False def get_rule(path, column_names): mask = '' for index, node in enumerate(path): #We check if we are not in the leaf if index!=len(path)-1: # Do we go under or over the threshold ? if (children_left[node] == path[index+1]): mask += "(df['{}']<= {}) \t ".format(column_names[feature[node]], threshold[node]) else: mask += "(df['{}']> {}) \t ".format(column_names[feature[node]], threshold[node]) # We insert the & at the right places mask = mask.replace("\t", "&", mask.count("\t") - 1) mask = mask.replace("\t", "") return mask
Finally, we use those two functions to first store the path of creation of each leaf. And then to store the rules used to create each leaf :
Leaves leave_id = clf.apply(X_test) paths ={} for leaf in np.unique(leave_id): path_leaf = [] find_path(0, path_leaf, leaf) paths[leaf] = np.unique(np.sort(path_leaf)) rules = {} for key in paths: rules[key] = get_rule(paths[key], pima.columns)
With the data you gave the output is:
rules = {3: "(df['insulin']<= 127.5) & (df['bp']<= 26.450000762939453) & (df['bp']<= 9.100000381469727) ", 4: "(df['insulin']<= 127.5) & (df['bp']<= 26.450000762939453) & (df['bp']> 9.100000381469`727)", 6: "(df['insulin']<= 127.5) & (df['bp']> 26.450000762939453) & (df['skin']<= 27.5) ", 7: "(df['insulin']<= 127.5) & (df['bp']> 26.450000762939453 & (df['skin']> 27.5) ", 10: "(df['insulin']> 127.5) & (df['bp']<= 28.149999618530273) &(df['insulin']<= 145.5) ", 11: "(df['insulin']> 127.5) & (df['bp']<= 28.149999618530273) & (df['insulin']> 145.5) ", 13: "(df['insulin']> 127.5) & (df['bp']> 28.149999618530273) & (df['insulin']<= 158.5) ", 14: "(df['insulin']> 127.5) & (df['bp']> 28.149999618530273) & (df['insulin']> 158.5) "}
Since the rules are strings, you can't directly call them using df[rules[3]], you have to use the eval function like so df[eval(rules[3])]
上面发布的解决方案在为每个终止节点找到路径方面非常有效。我想知道是否有可能将每个节点(叶子和终止节点)的路径存储为与上面链接中完全相同的输出格式(字典/列表格式)
谢谢
所以我找到了一个解决问题的方法(虽然我不认为这是最好的/最有效的方法),但它也不是我问题的直接答案(我没有存储每个节点的路径-只是创建一个函数来解析存储的信息)。它是上述解决方案的第二部分,允许您为要查找的特定节点提取子集数据
此函数在包含要查找的节点id的路径中运行。然后,它将根据节点数拆分规则,创建逻辑以基于一个特定节点对数据帧进行子集划分
相关问题 更多 >
编程相关推荐