优点:计算复杂度不高,输出结果易于理解,对中间值的缺失不敏感,可以处理不相关特征数据。既能用于分类,也能用于回归
缺点:可能会产生过度匹配问题
【二十个问题的游戏】
游戏的规则很简单:参与游戏的一方在脑海里想某个事物,其他参与者向他提问题,只允许提20个问题,问题的答案也只能用对或错回答。问问题的人通过推断分解,逐步缩小待猜测事物的范围。决策树的工作原理与20个问题类似,用户输人一系列数据 ,然后给出游戏的答案。
我们经常使用决策树处理分类问题。近来的调查表明决策树也是最经常使用的数据挖掘算法。它之所以如此流行,一个很重要的原因就是使用者基本上不用了解机器学习算法,也不用深究它是如何工作的。
如果以前没有接触过决策树,完全不用担心,它的概念非常简单。即使不知道它也可以通过简单的图形了解其工作原理。
决策树分类的思想类似于找对象。现想象一个女孩的母亲要给这个女孩介绍男朋友,于是有了下面的对话:
女儿:多大年纪了?
母亲:26。
女儿:长的帅不帅?
母亲:挺帅的。
女儿:收入高不?
母亲:不算很高,中等情况。
女儿:是公务员不?
母亲:是,在税务局上班呢。
女儿:那好,我去见见。
这个女孩的决策过程就是典型的分类树决策。相当于通过年龄、长相、收入和是否公务员对将男人分为两个类别:见和不见。假设这个女孩对男人的要求是:30岁以下、长相中等以上并且是高收入者或中等以上收入的公务员,那么这个可以用下图表示女孩的决策逻辑:
上图完整表达了这个女孩决定是否见一个约会对象的策略,其中绿色节点表示判断条件,橙色节点表示决策结果,箭头表示在一个判断条件在不同情况下的决策路径,图中红色箭头表示了上面例子中女孩的决策过程。
这幅图基本可以算是一颗决策树,说它“基本可以算”是因为图中的判定条件没有量化,如收入高中低等等,还不能算是严格意义上的决策树,如果将所有条件量化,则就变成真正的决策树了。
有了上面直观的认识,我们可以正式定义决策树了:
决策树(decision tree)是一个树结构(可以是二叉树或非二叉树)。其每个非叶节点表示一个特征属性上的测试,每个分支代表这个特征属性在某个值域上的输出,而每个叶节点存放一个类别。使用决策树进行决策的过程就是从根节点开始,测试待分类项中相应的特征属性,并按照其值选择输出分支,直到到达叶子节点,将叶子节点存放的类别作为决策结果。
可以看到,决策树的决策过程非常直观,容易被人理解。目前决策树已经成功运用于医学、制造产业、天文学、分支生物学以及商业等诸多领域。
之前介绍的K-近邻算法可以完成很多分类任务,但是它最大的缺点就是无法给出数据的内在含义,决策树的主要优势就在于数据形式非常容易理解。
决策树算法能够读取数据集合,构建类似于上面的决策树。决策树很多任务都是为了数据中所蕴含的知识信息,因此决策树可以使用不熟悉的数据集合,并从中提取出一系列规则,机器学习算法最终将使用这些机器从数据集中创造的规则。专家系统中经常使用决策树,而且决策树给出结果往往可以匹敌在当前领域具有几十年工作经验的人类专家。
知道了决策树的定义以及其应用方法,下面介绍决策树的构造算法。
不同于逻辑斯蒂回归和贝叶斯算法,决策树的构造过程不依赖领域知识,它使用属性选择度量来选择将元组最好地划分成不同的类的属性。所谓决策树的构造就是进行属性选择度量确定各个特征属性之间的拓扑结构。
构造决策树的关键步骤是分裂属性。所谓分裂属性就是在某个节点处按照某一特征属性的不同划分构造不同的分支,其目标是让各个分裂子集尽可能地“纯”。尽可能“纯”就是尽量让一个分裂子集中待分类项属于同一类别。分裂属性分为三种不同的情况:
1、属性是离散值且不要求生成二叉决策树。此时用属性的每一个划分作为一个分支。
2、属性是离散值且要求生成二叉决策树。此时使用属性划分的一个子集进行测试,按照“属于此子集”和“不属于此子集”分成两个分支。
3、属性是连续值。此时确定一个值作为分裂点split_point,按照>split_point和<=split_point生成两个分支。
构造决策树的关键性内容是进行属性选择度量,属性选择度量是一种选择分裂准则,它决定了拓扑结构及分裂点split_point的选择。
属性选择度量算法有很多,一般使用自顶向下递归分治法,并采用不回溯的贪心策略。这里介绍常用的ID3算法。
划分数据集的大原则是:将无序的数据变得更加有序。
我们可以使用多种方法划分数据集,但是每种方法都有各自的优缺点。组织杂乱无章数据的一种方法就是使用信息论度量信息,信息论是量化处理信息的分支科学。我们可以在划分数据之前使用信息论量化度量信息的内容。
在划分数据集之前之后信息发生的变化称为信息增益,知道如何计算信息增益,我们就可以计算每个特征值划分数据集获得的信息增益,获得信息增益最高的特征就是最好的选择。
在可以评测哪种数据划分方式是最好的数据划分之前,我们必须学习如何计算信息增益。集合信息的度量方式称为香农熵或者简称为熵,这个名字来源于信息论之父克劳德•香农。
entropy
熵定义为信息的期望值,在明晰这个概念之前,我们必须知道信息的定义。如果待分类的事务可能划分在多个分类之中,则符号x的信息定义为:
其中p(x)是选择该分类的概率
为了计算熵,我们需要计算所有类别所有可能值包含的信息期望值,通过下面的公式得到:
其中n是分类的数目。
在决策树当中,设D为用类别对训练元组进行的划分,则D的熵(entropy)表示为:
其中pi表示第i个类别在整个训练元组中出现的概率,可以用属于此类别元素的数量除以训练元组元素总数量作为估计。熵的实际意义表示是D中元组的类标号所需要的平均信息量。
现在我们假设将训练元组D按属性A进行划分,则A对D划分的期望信息为:
而信息增益即为两者的差值:
ID3算法就是在每次需要分裂时,计算每个属性的增益率,然后选择增益率最大的属性进行分裂。下面我们继续用SNS社区中不真实账号检测的例子说明如何使用ID3算法构造决策树。为了简单起见,我们假设训练集合包含10个元素:
其中s、m和l分别表示小、中和大。
设L、F和H表示日志密度、好友密度、是否使用真实头像,下面计算各属性的信息增益。
因此日志密度的信息增益是0.276。
用同样方法得到F和H的信息增益分别为0.553和0.033。
因为F具有最大的信息增益,所以第一次分裂选择F为分裂属性,分裂后的结果如下图表示:
在上图的基础上,再递归使用这个方法计算子节点的分裂属性,最终就可以得到整个决策树。
计算上图的信息熵,确定下一个分类的特征
import numpy as np
import math as math
#求解账号是否真实的熵
#no 0.3 yes 0.7
info_D = -(0.3*math.log2(0.3)+0.7*math.log2(0.7))
info_D
0.8812908992306927
#绘制二叉树:日志密度,好友密度,是否真实头像
# s 0.3 m 0.4 l 0.3
#s ----> 2/3 no 1/3 yes
# m-----> 1/4 no 3/4 yes
# l----> 0 no 1 yes
info_L_D = -(0.3*(2/3*math.log2(2/3)+1/3*math.log2(1/3))
+ 0.4*(1/4*math.log2(1/4)+3/4*math.log2(3/4))
+ 0.3*(1*math.log2(1)))
info_L_D
0.6
#信息增益
#按照日志密度进行划分
info_D - info_L_D
0.2812908992306927
#好友密度进行划分,求解信息熵
# s 0.4 m 0.4 l 0.2
# s 3/4 no 1/4 yes
# m 1 yes
# l yes 1
info_F_D = -(0.4*(3/4*math.log2(3/4)+1/4*math.log2(1/4))
+ 0.4*(1*math.log2(1))
+ 0.2*(1*math.log2(1)))
info_F_D
0.32451124978365314
#求解以好友密度进行划分的信息增益
info_D - info_F_D
# 跟刚才的按照日志密度进行划分相比,按照好友密度进行划分,信息增益大,所以选择按照好友密度进行划分
0.5567796494470396
【注意】 参数max_depth越大,越容易过拟合
from sklearn.tree import DecisionTreeClassifier
import sklearn.datasets as datasets
iris = datasets.load_iris()
x_data = iris.data
y_target = iris.target
from sklearn.model_selection import train_test_split
X_train,x_test,y_train,y_test = train_test_split(x_data,y_target,test_size = 0.1)
使用决策树算法
#max_depth 不进行声明,那么,有多少个属性,树的深度就是多少
#如果属性太多,此时就需要限定树的深度
#max_depth 最大的深度,属性200,max_depth 根据信息增益,进行选择,最重要的100个属性,然后进行数据的分类
tree = DecisionTreeClassifier(max_depth=5)
# tree.fit(X_train,y_train)
tree.fit(x_data,y_target).score(x_data,y_target)
1.0
tree.score(x_test,y_test)
0.73333333333333328
使用KNN算法
from sklearn.neighbors import KNeighborsClassifier
knn = KNeighborsClassifier(5)
knn.fit(x_data,y_target).score(x_data,y_target)
0.96666666666666667
使用逻辑斯蒂回归算法
# 决策树 :信息论 信息熵将划分展示
# 逻辑斯底:概率论
from sklearn.linear_model import LogisticRegression
lrg = LogisticRegression()
lrg.fit(x_data,y_target).score(x_data,y_target)
0.95999999999999996
使用RandomState生成固定随机数
创建-100到100之间的角度
生成正弦值和余弦值
添加噪声
from sklearn.tree import DecisionTreeRegressor
#-100 ~ 100 无序
# np.sort() 排序
# 原始数据是二维的,所以,axis指定排序的维度
#sort 聚合操作
x_data = np.sort(200*np.random.rand(100,1) - 100,axis = 0)
x_data
array([[-98.25897143], [-97.61353655], [-93.7845092 ], [-92.638643 ], [-88.32920567], [-87.74078938], [-85.36860635], [-85.24688731], [-83.34651117], [-80.38829665], [-80.1209589 ], [-76.71435762], [-76.40968096], [-72.09652444], [-70.85309375], [-70.01671507], [-68.23069059], [-68.17500301], [-67.26146831], [-66.14433432], [-64.77602684], [-63.63303187], [-61.12805278], [-60.57529221], [-57.99092137], [-51.99071746], [-50.19521255], [-44.9918581 ], [-44.71462095], [-39.84789847], [-37.44971539], [-36.8924444 ], [-36.78921826], [-34.68277363], [-28.20571676], [-25.55771831], [-24.02080064], [-21.71298869], [-20.88656494], [-19.59548861], [-17.61823774], [-15.95916267], [-13.57216989], [-13.38717073], [-12.84505049], [-11.22206503], [-10.81969008], [ -9.30435017], [ -9.21530701], [ -8.5172761 ], [ -5.52538632], [ -4.34231591], [ -0.69508746], [ 9.8174906 ], [ 10.77651454], [ 12.5120284 ], [ 17.91758471], [ 19.00066293], [ 20.59787053], [ 21.78355563], [ 26.58743695], [ 33.43388355], [ 34.36501088], [ 34.8101705 ], [ 36.77783926], [ 36.80159782], [ 36.98169858], [ 39.02306412], [ 39.96763632], [ 40.69821854], [ 41.55507059], [ 47.71671998], [ 51.27621405], [ 55.58728132], [ 57.27564337], [ 57.96236177], [ 60.7774663 ], [ 61.67580216], [ 63.05125616], [ 67.07730481], [ 68.4992789 ], [ 69.44889837], [ 71.78064256], [ 71.92220066], [ 72.4169162 ], [ 74.15564907], [ 75.85083296], [ 81.39358055], [ 81.91868744], [ 82.01406723], [ 85.05218396], [ 87.549674 ], [ 88.95254518], [ 89.55029014], [ 90.0691102 ], [ 91.40841773], [ 93.06252197], [ 93.66440172], [ 94.27864841], [ 99.13917816]])
#使用随机创造的数据,生成圆上的点
dot_x = np.pi*np.sin(x_data)
dot_y = np.pi*np.cos(x_data)
y_target = np.c_[dot_x.ravel(),dot_y.ravel()]
y_target
array([[ 2.40039799, -2.02674466], [ 0.69834979, -3.06299069], [ 1.40390294, 2.81045565], [ 3.1392839 , -0.12042003], [-1.12024869, 2.93507194], [ 0.69724898, 3.06324146], [ 1.63028268, -2.68547627], [ 1.29215378, -2.86355426], [-3.12763298, -0.2958316 ], [ 3.02124737, 0.86120193], [ 3.14142451, 0.03250308], [-3.04027126, 0.79142596], [-2.66283282, 1.66701102], [-0.50084351, -3.10141261], [-3.09775105, -0.52301321], [-2.46416492, 1.94871641], [ 2.43002753, 1.99112296], [ 2.53708412, 1.852784 ], [ 3.01682494, -0.87656812], [ 0.53425316, -3.09583235], [-2.9251391 , -1.14593439], [-2.25621917, 2.18611057], [ 3.11384612, -0.41661341], [ 2.43138888, -1.98946035], [-3.11567374, 0.40271768], [-3.10420152, -0.48325701], [ 0.2205778 , 3.13383947], [-2.65966684, 1.67205762], [-2.10046718, 2.33615968], [-2.63127963, -1.71638337], [ 0.77540516, 3.0443967 ], [ 2.26818391, 2.17369412], [ 2.48009394, 1.92840308], [ 0.39247032, -3.11698114], [-0.21539794, -3.13419976], [-1.29527834, 2.86214228], [ 2.81662601, 1.39148207], [-0.86263966, -3.02083717], [-2.8063128 , -1.41216602], [-2.13206423, 2.30735921], [ 2.96229747, 1.04613485], [ 0.78089279, -3.04299373], [-2.65335831, 1.68205056], [-2.29867648, 2.14142262], [-0.86421034, 3.0203882 ], [ 3.06135773, 0.70547379], [ 3.09312481, -0.54971203], [-0.37742122, -3.11883915], [-0.65327044, -3.07292078], [-2.47547465, -1.93432925], [ 2.15929329, 2.28189765], [ 2.92890972, -1.13626239], [-2.01204121, 2.41273591], [-1.20227481, -2.90243685], [-3.06651553, -0.68270558], [-0.17063709, 3.13695511], [-2.52206066, 1.87318297], [ 0.47291216, 3.10579434], [ 3.09222241, -0.55476568], [ 0.64749834, -3.07414221], [ 3.12044311, 0.36392195], [ 2.83270647, -1.35844708], [ 0.60105412, -3.08355936], [-0.78531012, -3.04185674], [-2.50187621, 1.90005784], [-2.45603173, 1.958957 ], [-2.06540159, 2.36721791], [ 3.04636592, 0.76763211], [ 2.40741432, -2.01840548], [ 0.44611971, -3.10975588], [-2.0581787 , -2.37350055], [-1.7552408 , -2.60551993], [ 2.6616224 , 1.66894296], [-2.57606179, 1.79819634], [ 2.08794297, 2.34735991], [ 3.10290328, 0.49152378], [-2.78135005, -1.46071774], [-2.87524509, 1.26592656], [ 0.68375841, 3.06628095], [-2.80528889, -1.41419894], [-1.81451694, 2.5645921 ], [ 1.02947578, 2.96812803], [ 1.43953199, -2.79237391], [ 1.0311686 , -2.96754035], [-0.50139733, -3.10132312], [-2.97388529, 1.01272438], [ 1.37386107, 2.82526285], [-0.89180601, 3.01235563], [ 0.73845712, 3.0535693 ], [ 1.02590807, 2.96936307], [-0.71371108, -3.05944781], [-1.2664294 , 2.87502365], [ 2.62291537, 1.72913834], [ 3.14124395, -0.04680645], [ 2.70466227, -1.59825105], [-0.93509697, -2.99919957], [-2.91098715, 1.18142211], [-1.73053462, 2.62199434], [ 0.09696181, 3.14009599], [-3.09139166, 0.55937645]])
#扰乱数据
#添加噪声
y_target[::5] += np.random.randn(20,2)*0.2
import matplotlib.pyplot as plt
%matplotlib inline
plt.scatter(y_target[:,0],y_target[:,1])
<matplotlib.collections.PathCollection at 0x7fa170c83f60>
dot_x.ravel()
array([ 2.40039799, 0.69834979, 1.40390294, 3.1392839 , -1.12024869, 0.69724898, 1.63028268, 1.29215378, -3.12763298, 3.02124737, 3.14142451, -3.04027126, -2.66283282, -0.50084351, -3.09775105, -2.46416492, 2.43002753, 2.53708412, 3.01682494, 0.53425316, -2.9251391 , -2.25621917, 3.11384612, 2.43138888, -3.11567374, -3.10420152, 0.2205778 , -2.65966684, -2.10046718, -2.63127963, 0.77540516, 2.26818391, 2.48009394, 0.39247032, -0.21539794, -1.29527834, 2.81662601, -0.86263966, -2.8063128 , -2.13206423, 2.96229747, 0.78089279, -2.65335831, -2.29867648, -0.86421034, 3.06135773, 3.09312481, -0.37742122, -0.65327044, -2.47547465, 2.15929329, 2.92890972, -2.01204121, -1.20227481, -3.06651553, -0.17063709, -2.52206066, 0.47291216, 3.09222241, 0.64749834, 3.12044311, 2.83270647, 0.60105412, -0.78531012, -2.50187621, -2.45603173, -2.06540159, 3.04636592, 2.40741432, 0.44611971, -2.0581787 , -1.7552408 , 2.6616224 , -2.57606179, 2.08794297, 3.10290328, -2.78135005, -2.87524509, 0.68375841, -2.80528889, -1.81451694, 1.02947578, 1.43953199, 1.0311686 , -0.50139733, -2.97388529, 1.37386107, -0.89180601, 0.73845712, 1.02590807, -0.71371108, -1.2664294 , 2.62291537, 3.14124395, 2.70466227, -0.93509697, -2.91098715, -1.73053462, 0.09696181, -3.09139166])
创建不同深度的决策树
进行数据训练
tree_2 = DecisionTreeRegressor(max_depth=2)
tree_5 = DecisionTreeRegressor(max_depth=5)
tree_20 = DecisionTreeRegressor(max_depth=20)
tree_2.fit(x_data,y_target)
tree_5.fit(x_data,y_target)
tree_20.fit(x_data,y_target)
DecisionTreeRegressor(criterion='mse', max_depth=20, max_features=None, max_leaf_nodes=None, min_impurity_split=1e-07, min_samples_leaf=1, min_samples_split=2, min_weight_fraction_leaf=0.0, presort=False, random_state=None, splitter='best')
创建-100到100的预测数据,间隔为0.01
对数据进行预测
x_test.shape
(20000,)
#创造 预测数据
x_test = np.arange(-100,100,0.01).reshape((20000,1))
y2_ = tree_2.predict(x_test)
plt.scatter(y2_[:,0],y2_[:,1])
<matplotlib.collections.PathCollection at 0x7fa170c3e630>
y5_ = tree_5.predict(x_test)
plt.scatter(y5_[:,0],y5_[:,1])
<matplotlib.collections.PathCollection at 0x7fa170b643c8>
y20_ = tree_20.predict(x_test)
plt.scatter(y20_[:,0],y20_[:,1])
<matplotlib.collections.PathCollection at 0x7fa170c5ccf8>
显示图片
分析lenses.txt文件
import sklearn.datasets as datasets
import pandas as pd
len = pd.read_csv('../data/lenses.txt')
len
young myope no reduced no lenses | |
---|---|
0 | young\tmyope\tno\tnormal\tsoft |
1 | young\tmyope\tyes\treduced\tno lenses |
2 | young\tmyope\tyes\tnormal\thard |
3 | young\thyper\tno\treduced\tno lenses |
4 | young\thyper\tno\tnormal\tsoft |
5 | young\thyper\tyes\treduced\tno lenses |
6 | young\thyper\tyes\tnormal\thard |
7 | pre\tmyope\tno\treduced\tno lenses |
8 | pre\tmyope\tno\tnormal\tsoft |
9 | pre\tmyope\tyes\treduced\tno lenses |
10 | pre\tmyope\tyes\tnormal\thard |
11 | pre\thyper\tno\treduced\tno lenses |
12 | pre\thyper\tno\tnormal\tsoft |
13 | pre\thyper\tyes\treduced\tno lenses |
14 | pre\thyper\tyes\tnormal\tno lenses |
15 | presbyopic\tmyope\tno\treduced\tno lenses |
16 | presbyopic\tmyope\tno\tnormal\tno lenses |
17 | presbyopic\tmyope\tyes\treduced\tno lenses |
18 | presbyopic\tmyope\tyes\tnormal\thard |
19 | presbyopic\thyper\tno\treduced\tno lenses |
20 | presbyopic\thyper\tno\tnormal\tsoft |
21 | presbyopic\thyper\tyes\treduced\tno lenses |
22 | presbyopic\thyper\tyes\tnormal\tno lenses |