我可以从决策树中经过训练的树中提取底层决策规则(或‘决策路径’)作为文本列表吗?
类似于:
if A>0.4 then if B<0.2 then if C>0.8 then class='X'
谢谢你的帮助。
发布于 2016-09-29 21:49:10
我认为这个答案比这里的其他答案更正确:
from sklearn.tree import _tree
def tree_to_code(tree, feature_names):
tree_ = tree.tree_
feature_name = [
feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
for i in tree_.feature
]
print "def tree({}):".format(", ".join(feature_names))
def recurse(node, depth):
indent = " " * depth
if tree_.feature[node] != _tree.TREE_UNDEFINED:
name = feature_name[node]
threshold = tree_.threshold[node]
print "{}if {} <= {}:".format(indent, name, threshold)
recurse(tree_.children_left[node], depth + 1)
print "{}else: # if {} > {}".format(indent, name, threshold)
recurse(tree_.children_right[node], depth + 1)
else:
print "{}return {}".format(indent, tree_.value[node])
recurse(0, 1)
这将打印出一个有效的Python函数。下面是一个树的示例输出,它试图返回其输入,一个介于0和10之间的数字。
def tree(f0):
if f0 <= 6.0:
if f0 <= 1.5:
return [[ 0.]]
else: # if f0 > 1.5
if f0 <= 4.5:
if f0 <= 3.5:
return [[ 3.]]
else: # if f0 > 3.5
return [[ 4.]]
else: # if f0 > 4.5
return [[ 5.]]
else: # if f0 > 6.0
if f0 <= 8.5:
if f0 <= 7.5:
return [[ 7.]]
else: # if f0 > 7.5
return [[ 8.]]
else: # if f0 > 8.5
return [[ 9.]]
以下是我在其他答案中看到的一些绊脚石:
使用tree_.threshold == -2
判断一个节点是否是叶子并不是一个好主意。如果它是一个阈值为-2的实际决策节点呢?相反,您应该查看tree.feature
或者tree.children_*
..。
这条线features = [feature_names[i] for i in tree_.feature]
在我的sklearn版本中崩溃,因为tree.tree_.feature
是-2 (专用于叶节点)。
在递归函数中不需要有多个if语句,只需要一个就可以了。
发布于 2014-03-08 05:31:16
我创建了自己的函数来从sklearn创建的决策树中提取规则:
import pandas as pd
import numpy as np
from sklearn.tree import DecisionTreeClassifier
# dummy data:
df = pd.DataFrame({'col1':[0,1,2,3],'col2':[3,4,5,6],'dv':[0,1,0,1]})
# create decision tree
dt = DecisionTreeClassifier(max_depth=5, min_samples_leaf=1)
dt.fit(df.ix[:,:2], df.dv)
此函数首先从节点(在子数组中由-1标识)开始,然后递归查找父节点。我称之为节点的“谱系”。在此过程中,我获取了需要创建if/then/else SAS逻辑的值:
def get_lineage(tree, feature_names):
left = tree.tree_.children_left
right = tree.tree_.children_right
threshold = tree.tree_.threshold
features = [feature_names[i] for i in tree.tree_.feature]
# get ids of child nodes
idx = np.argwhere(left == -1)[:,0]
def recurse(left, right, child, lineage=None):
if lineage is None:
lineage = [child]
if child in left:
parent = np.where(left == child)[0].item()
split = 'l'
else:
parent = np.where(right == child)[0].item()
split = 'r'
lineage.append((parent, split, threshold[parent], features[parent]))
if parent == 0:
lineage.reverse()
return lineage
else:
return recurse(left, right, parent, lineage)
for child in idx:
for node in recurse(left, right, child):
print node
下面的元组集包含创建SAS if/then/else语句所需的所有内容。我不喜欢用do
块,这就是我创建描述节点整个路径的逻辑的原因。元组后面的单个整数是路径中终端节点的ID。前面的所有元组组合在一起创建该节点。
In [1]: get_lineage(dt, df.columns)
(0, 'l', 0.5, 'col1')
1
(0, 'r', 0.5, 'col1')
(2, 'l', 4.5, 'col2')
3
(0, 'r', 0.5, 'col1')
(2, 'r', 4.5, 'col2')
(4, 'l', 2.5, 'col1')
5
(0, 'r', 0.5, 'col1')
(2, 'r', 4.5, 'col2')
(4, 'r', 2.5, 'col1')
6
发布于 2015-05-07 22:58:43
我修改了提交的代码Zelazny7要打印一些伪代码,请执行以下操作:
def get_code(tree, feature_names):
left = tree.tree_.children_left
right = tree.tree_.children_right
threshold = tree.tree_.threshold
features = [feature_names[i] for i in tree.tree_.feature]
value = tree.tree_.value
def recurse(left, right, threshold, features, node):
if (threshold[node] != -2):
print "if ( " + features[node] + " <= " + str(threshold[node]) + " ) {"
if left[node] != -1:
recurse (left, right, threshold, features,left[node])
print "} else {"
if right[node] != -1:
recurse (left, right, threshold, features,right[node])
print "}"
else:
print "return " + str(value[node])
recurse(left, right, threshold, features, 0)
如果你调用get_code(dt, df.columns)
在同一示例中,您将获得:
if ( col1 <= 0.5 ) {
return [[ 1. 0.]]
} else {
if ( col2 <= 4.5 ) {
return [[ 0. 1.]]
} else {
if ( col1 <= 2.5 ) {
return [[ 1. 0.]]
} else {
return [[ 0. 1.]]
}
}
}
https://stackoverflow.com/questions/20224526
复制相似问题