Machine Learning: Implementing Entropy and Information Gain Calculations for Decision Trees

Gxjun
Published 2018-03-27 10:40:19
Part of the column: ml

This post does not go into the theory; it only covers the code implementation:

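1. Entropy formula:

             P is the proportion of positive examples in S, Q is the proportion of negative examples (P + Q = 1)

     Entropy(S) = -P*log2(P) - Q*log2(Q)    (0*log2(0) is taken to be 0)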
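As a quick check of the entropy formula, here is a minimal sketch in plain Java. The class name `EntropySketch` and its method names are illustrative, not part of the code in this post; the counts 9 and 5 are the numbers of Yes and No values in the PlayTennis row of the example data.

```java
public class EntropySketch {

    // p * log2(p), using the convention that 0 * log2(0) = 0
    static double pLog2(double p) {
        return p == 0 ? 0 : p * (Math.log(p) / Math.log(2));
    }

    // Binary entropy computed from raw positive and negative counts
    static double entropy(int positives, int negatives) {
        double total = positives + negatives;
        if (total == 0) {
            return 0; // empty set: treat the entropy as 0
        }
        double p = positives / total;
        double q = negatives / total;
        return -(pLog2(p) + pLog2(q));
    }

    public static void main(String[] args) {
        // 9 Yes and 5 No: approximately 0.940
        System.out.println(entropy(9, 5));
        // An even split is the most uncertain case for two classes
        System.out.println(entropy(7, 7));
    }
}
```

Working from counts rather than probabilities keeps the division in one place and makes the empty-set case easy to guard.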

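2. Information gain formula:

    Gain(S, A) = Entropy(S) - Σ (|Sv|/|S|) * Entropy(Sv), summed over every value v of attribute A

    where Sv is the subset of S in which attribute A takes the value v.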
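To make the weighting concrete, here is the calculation worked by hand for the Wind attribute of the example data in this post. The counts are read off the table: of the 14 samples, 9 are Yes and 5 are No; Wind = Weak covers 8 samples (6 Yes, 2 No) and Wind = Strong covers 6 samples (3 Yes, 3 No). Intermediate values are rounded to three decimals.

```latex
\begin{aligned}
\mathrm{Entropy}(S) &= -\tfrac{9}{14}\log_2\tfrac{9}{14} - \tfrac{5}{14}\log_2\tfrac{5}{14} \approx 0.940 \\
\mathrm{Entropy}(S_{\mathrm{Weak}}) &= -\tfrac{6}{8}\log_2\tfrac{6}{8} - \tfrac{2}{8}\log_2\tfrac{2}{8} \approx 0.811 \\
\mathrm{Entropy}(S_{\mathrm{Strong}}) &= -\tfrac{3}{6}\log_2\tfrac{3}{6} - \tfrac{3}{6}\log_2\tfrac{3}{6} = 1 \\
\mathrm{Gain}(S, \mathrm{Wind}) &= 0.940 - \tfrac{8}{14}(0.811) - \tfrac{6}{14}(1) \approx 0.048
\end{aligned}
```

Each subset's entropy is weighted by its own share of the samples, so the weight sits inside the sum rather than in front of it.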

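Example:

Input data, transposed so that each row holds one attribute. The first line gives the number of rows (four attributes plus the PlayTennis result) and the number of samples. Each following row is an attribute name followed by its 14 values. The last line lists the attribute names, with the result column last: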

 5  14
 Outlook       Sunny  Sunny  Overcast  Rain  Rain    Rain    Overcast  Sunny  Sunny    Rain    Sunny   Overcast   Overcast    Rain
 Temperature   Hot    Hot    Hot       Mild  Cool    Cool        Cool   Mild  Cool     Mild    Mild    Mild       Hot         Mild
 Humidity      High   High   High      High  Normal  Normal  Normal     High  Normal   Normal  Normal  High       Normal      High
 Wind          Weak   Strong Weak      Weak  Weak    Strong  Strong    Weak   Weak     Weak    Strong  Strong     Weak        Strong
 PlayTennis    No     No     Yes       Yes   Yes     No      Yes       No     Yes      Yes     Yes     Yes        Yes         No
 Outlook Temperature Humidity Wind PlayTennis
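The layout above can be read with a plain `Scanner`. The following is only a sketch of one way to do it without any third-party library; `InputFormatSketch`, `Dataset`, and `parse` are hypothetical names, and the three-row, four-sample input in `main` is a shortened excerpt of the table above used purely for illustration.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Scanner;

public class InputFormatSketch {

    // What the input describes: the table, the attribute names, and the result column name
    static class Dataset {
        final Map<String, List<String>> table = new LinkedHashMap<>();
        final List<String> attributes = new ArrayList<>();
        String target;
    }

    static Dataset parse(String text) {
        Scanner in = new Scanner(text);
        Dataset d = new Dataset();
        int rows = in.nextInt();    // number of rows, result row included
        int samples = in.nextInt(); // number of values in each row
        for (int i = 0; i < rows; i++) {
            String name = in.next();
            List<String> values = new ArrayList<>();
            for (int j = 0; j < samples; j++) {
                values.add(in.next());
            }
            d.table.put(name, values);
        }
        // The last line names the attributes first and the result column last
        for (int i = 1; i < rows; i++) {
            d.attributes.add(in.next());
        }
        d.target = in.next();
        return d;
    }

    public static void main(String[] args) {
        // Shortened excerpt (first four samples, two attributes), for illustration only
        String text = "3 4\n"
                + "Outlook Sunny Sunny Overcast Rain\n"
                + "Wind Weak Strong Weak Weak\n"
                + "PlayTennis No No Yes Yes\n"
                + "Outlook Wind PlayTennis\n";
        Dataset d = parse(text);
        System.out.println(d.attributes + " -> " + d.target);
        System.out.println(d.table.get("Wind"));
    }
}
```

Because `Scanner` splits on any whitespace, the irregular column spacing in the table does not matter.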
CountMap.java:

package com.qunar.data.tree;

/**
 * Author:     XiJun.Gong
 * Date:       2016-09-02 15:28
 * Version:    default 1.0.0
 * Class description:
 * <p>Counts how many times a given value occurs</p>
 */
public class CountMap<T> {

    private T key;       // the value being counted
    private int value;   // number of occurrences

    public CountMap() {
        this(null, 0);
    }

    public CountMap(T key, int value) {
        this.key = key;
        this.value = value;
    }

    public T getKey() {
        return key;
    }

    public void setKey(T key) {
        this.key = key;
    }

    public int getValue() {
        return value;
    }

    public void setValue(int value) {
        this.value = value;
    }
}
DecisionTree.java:

package com.qunar.data.tree;

import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.Maps;
import com.google.common.collect.Multimap;
import com.google.common.collect.Sets;

import java.util.*;

/**
 * Author:     XiJun.Gong
 * Date:       2016-09-02 14:24
 * Version:    default 1.0.0
 * Class description:
 * <p>Decision tree: entropy and information gain</p>
 *
 * @param <T> type of the values in the data table (for example Sunny, Yes)
 * @param <K> type of the attribute names (for example Outlook)
 */
public class DecisionTree<T, K> {

    private static final String positiveExampleType = "Yes";
    private static final String counterExampleType = "No";

    /** Returns p * log2(p), with 0 * log2(0) taken as 0. */
    public double pLog2(final double p) {
        if (0 == p) return 0;
        return p * (Math.log(p) / Math.log(2));
    }

    /**
     * Entropy calculation
     *
     * @param positiveExample number of positive examples
     * @param counterExample  number of negative examples
     * @return the entropy
     */
    public double entropy(final double positiveExample, final double counterExample) {
        double total = positiveExample + counterExample;
        if (0 == total) return 0;   // empty subset: avoid dividing by zero
        double positiveP = positiveExample / total;
        double counterP = counterExample / total;
        return -1d * (pLog2(positiveP) + pLog2(counterP));
    }

    /**
     * @param features values of one attribute, one per sample
     * @param results  the corresponding results
     * @return for each attribute value, the number of samples per result
     */
    public Multimap<T, CountMap<T>> merge(final List<T> features, final List<T> results) {
        Multimap<T, CountMap<T>> infoMap = ArrayListMultimap.create();
        Iterator<T> result = results.iterator();
        for (T feature : features) {
            T res = result.next();
            boolean found = false;
            for (CountMap<T> countMap : infoMap.get(feature)) {
                if (countMap.getKey().equals(res)) {
                    // Already seen this (value, result) pair: increment in place
                    countMap.setValue(countMap.getValue() + 1);
                    found = true;
                    break;
                }
            }
            if (!found)
                infoMap.put(feature, new CountMap<T>(res, 1));
        }
        return infoMap;
    }

    /**
     * Information gain
     *
     * @param infoMap   for each value of the attribute, the counts per result
     * @param dataTable the input data table
     * @param type      the attribute (for example Outlook, with values Sunny, Overcast, Rain)
     * @param entropyS  entropy of the whole sample set
     * @param totalSize total number of samples
     * @return the information gain
     */
    public double gain(Multimap<T, CountMap<T>> infoMap,
                       Map<K, List<T>> dataTable,
                       final K type,
                       double entropyS,
                       final int totalSize) {
        // Distinct values of this attribute
        Set<T> subTypes = Sets.newHashSet();
        subTypes.addAll(dataTable.get(type));
        for (T subType : subTypes) {
            Collection<CountMap<T>> countMaps = infoMap.get(subType);
            double subSize = 0;
            double positiveExample = 0;
            double counterExample = 0;
            for (CountMap<T> countMap : countMaps) {
                subSize += countMap.getValue();
                if (positiveExampleType.equals(countMap.getKey()))
                    positiveExample += countMap.getValue();
                else
                    counterExample += countMap.getValue();
            }
            // Subtract this subset's entropy, weighted by its share of the samples
            entropyS -= (subSize / totalSize) * entropy(positiveExample, counterExample);
        }
        return entropyS;
    }

    /**
     * Computes the information gain of every attribute
     *
     * @param dataTable  the data table
     * @param types      attribute list {Outlook, Temperature, Humidity, Wind}
     * @param resultType the result column (PlayTennis)
     * @return the information gain of each attribute
     */
    public Map<K, Double> calculate(Map<K, List<T>> dataTable, List<K> types, K resultType) {
        Map<K, Double> answer = Maps.newHashMap();
        List<T> results = dataTable.get(resultType);
        int totalSize = results.size();
        int positiveExample = 0;
        int counterExample = 0;
        for (T exampleType : results) {
            if (positiveExampleType.equals(exampleType)) {
                ++positiveExample;
            } else {
                ++counterExample;
            }
        }
        // Entropy of the whole sample set
        double entropyS = entropy(positiveExample, counterExample);

        for (K type : types) {
            Multimap<T, CountMap<T>> infoMap = merge(dataTable.get(type), results);
            answer.put(type, gain(infoMap, dataTable, type, entropyS, totalSize));
        }
        return answer;
    }

}

Main.java:

package com.qunar.data.tree;

import com.google.common.collect.Lists;
import com.google.common.collect.Maps;

import java.util.*;

/**
 * Author:     XiJun.Gong
 * Date:       2016-09-02 16:43
 * Version:    default 1.0.0
 * Class description:
 * <p>Reads the data table from standard input and prints the attributes by descending gain</p>
 */
public class Main {

    public static void main(String[] args) {

        Scanner scanner = new Scanner(System.in);
        while (scanner.hasNext()) {
            DecisionTree<String, String> dt = new DecisionTree<String, String>();
            Map<String, List<String>> dataTable = Maps.newHashMap();
            List<String> types = Lists.newArrayList();
            String resultType;
            int factorSize = scanner.nextInt();   // number of rows, result row included
            int demoSize = scanner.nextInt();     // number of samples
            String type;

            // Each row: attribute name followed by demoSize values
            for (int i = 0; i < factorSize; i++) {
                List<String> demos = Lists.newArrayList();
                type = scanner.next();
                for (int j = 0; j < demoSize; j++) {
                    demos.add(scanner.next());
                }
                dataTable.put(type, demos);
            }
            // Last line: attribute names, then the result column name
            for (int i = 1; i < factorSize; i++) {
                types.add(scanner.next());
            }
            resultType = scanner.next();
            Map<String, Double> ans = dt.calculate(dataTable, types, resultType);
            List<Map.Entry<String, Double>> list = new ArrayList<Map.Entry<String, Double>>(ans.entrySet());
            Collections.sort(list, new Comparator<Map.Entry<String, Double>>() {
                @Override
                public int compare(Map.Entry<String, Double> o1, Map.Entry<String, Double> o2) {
                    // Descending by gain; Double.compare returns 0 for equal values
                    return Double.compare(o2.getValue(), o1.getValue());
                }
            });

            for (Map.Entry<String, Double> entry : list) {
                System.out.println(entry.getKey() + "= " + entry.getValue());
            }
        }
    }

}
/*
 Sample input:
 5  14
 Outlook       Sunny  Sunny  Overcast  Rain  Rain    Rain    Overcast  Sunny  Sunny    Rain    Sunny   Overcast   Overcast    Rain
 Temperature   Hot    Hot    Hot       Mild  Cool    Cool        Cool   Mild  Cool     Mild    Mild    Mild       Hot         Mild
 Humidity      High   High   High      High  Normal  Normal  Normal     High  Normal   Normal  Normal  High       Normal      High
 Wind          Weak   Strong Weak      Weak  Weak    Strong  Strong    Weak   Weak     Weak    Strong  Strong     Weak        Strong
 PlayTennis    No     No     Yes       Yes   Yes     No      Yes       No     Yes      Yes     Yes     Yes        Yes         No
 Outlook Temperature Humidity Wind PlayTennis
 */

Output:

Outlook= 0.2467498197744391
Humidity= 0.15183550136234136
Wind= 0.04812703040826927
Temperature= 0.029222565658954647
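The four figures can be cross-checked without Guava. The sketch below recomputes each attribute's gain from the same table using only `java.util`; it assumes Java 8 or later, a two-class result with "Yes" as the positive label, and the hypothetical names `GainCheck`, `gain`, and `row`. It is an independent re-derivation, not the code from this post.

```java
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

public class GainCheck {

    static double log2(double x) {
        return Math.log(x) / Math.log(2);
    }

    // Binary entropy from counts; a zero count contributes nothing
    static double entropy(double pos, double neg) {
        double total = pos + neg;
        double h = 0;
        for (double c : new double[]{pos, neg}) {
            if (c > 0) {
                double p = c / total;
                h -= p * log2(p);
            }
        }
        return h;
    }

    // Information gain of one attribute row against the Yes/No result row
    static double gain(List<String> column, List<String> labels) {
        Map<String, int[]> counts = new LinkedHashMap<>(); // value -> {yes, no}
        int yes = 0;
        for (int i = 0; i < column.size(); i++) {
            int[] c = counts.computeIfAbsent(column.get(i), k -> new int[2]);
            if ("Yes".equals(labels.get(i))) {
                c[0]++;
                yes++;
            } else {
                c[1]++;
            }
        }
        int n = labels.size();
        double g = entropy(yes, n - yes);
        for (int[] c : counts.values()) {
            // Weight each subset's entropy by its share of the samples
            g -= ((double) (c[0] + c[1]) / n) * entropy(c[0], c[1]);
        }
        return g;
    }

    static List<String> row(String s) {
        return Arrays.asList(s.trim().split("\\s+"));
    }

    static final List<String> OUTLOOK = row("Sunny Sunny Overcast Rain Rain Rain Overcast Sunny Sunny Rain Sunny Overcast Overcast Rain");
    static final List<String> TEMPERATURE = row("Hot Hot Hot Mild Cool Cool Cool Mild Cool Mild Mild Mild Hot Mild");
    static final List<String> HUMIDITY = row("High High High High Normal Normal Normal High Normal Normal Normal High Normal High");
    static final List<String> WIND = row("Weak Strong Weak Weak Weak Strong Strong Weak Weak Weak Strong Strong Weak Strong");
    static final List<String> PLAY = row("No No Yes Yes Yes No Yes No Yes Yes Yes Yes Yes No");

    public static void main(String[] args) {
        System.out.println("Outlook= " + gain(OUTLOOK, PLAY));
        System.out.println("Humidity= " + gain(HUMIDITY, PLAY));
        System.out.println("Wind= " + gain(WIND, PLAY));
        System.out.println("Temperature= " + gain(TEMPERATURE, PLAY));
    }
}
```

The printed values should agree with the output above up to floating-point rounding in the last digits, with Outlook giving the largest gain.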
Originally published on the author's personal site/blog on 2016-09-02.
