一、决策树ID3算法
相比于logistic回归、BP网络、支持向量机等基于超平面的方法,决策树更像一种算法,里面的数学原理并不是很多,较好理解。
决策树就是一个不断地属性选择、属性划分地过程,直到满足某一情况就停止划分。
- 当前样本全部属于同一类别了(信息增益为0);
- 已经是空叶子了(没有样本了);
- 当前叶子节点所有样本所有属性上取值相同,无法划分了(信息增益为0)。
信息增益如何计算?根据信息熵地变化量,信息熵减少最大地属性就是我们要选择地属性。
信息熵定义:
E n t ( D ) = − ∑ k = 1 ∣ y ∣ p k l o g 2 p k Ent(D)=-\sum_{k=1}^{|y|}p_klog_2p_k Ent(D)=−k=1∑∣y∣pklog2pk
信息增益定义:
G a i n ( D , a ) = E n t ( D ) − ∑ v = 1 v ∣ D v ∣ ∣ D ∣ E n t ( D v ) Gain(D,a)=Ent(D)-\sum_{v=1}^v\frac{|D^v|}{|D|}Ent(D^v) Gain(D,a)=Ent(D)−v=1∑v∣D∣∣Dv∣Ent(Dv)
信息增益越大,则意味着属性a来划分所获得的“纯度提升”越大。
ID3就是以信息增益作为属性选择和划分的标准的。有了决策树生长和停止生长的条件,剩下的其实就是一些编程技巧了,我们就可以进行编码了。
除此之外,决策树还有C4.5等其它实现的算法,包括基尼系数、增益率、剪枝、预剪枝等防止过拟合的方法,但决策树最本质、朴素的思想还是在ID3中体现的最好。
具体可以参考这篇博客:机器学习06:决策树学习.
二、基于weka平台实现ID3决策树
package weka.classifiers.myf;import weka.classifiers.Classifier;
import weka.core.*;/*** @author YFMan* @Description 自定义的 ID3 分类器* @Date 2023/5/25 18:07*/
public class myId3 extends Classifier {// 当前节点 的 后续节点private myId3[] m_Successors;// 当前节点的划分属性 (如果为空,说明当前节点是叶子节点;否则,说明当前节点是中间节点)private Attribute m_Attribute;// 当前节点的类别分布 (如果为中间节点,全为 0;为叶子节点,为类别分布)private double[] m_Distribution;// 当前节点的类别 (如果为中间节点,为 0;为叶子节点,为类别分布)// (用于获取类别的索引,对于算法本身没用,但对于可视化 决策树有用)private double m_ClassValue;// 当前节点的类别属性 (如果为中间节点,为 null;为叶子节点,为类别属性)// (用于获取类别的名称,对于算法本身没用,但对于可视化 决策树有用)private Attribute m_ClassAttribute;/** @Author YFMan* @Description 根据训练数据 建立 决策树* @Date 2023/5/25 18:43* @Param [data]* @return void**/public void buildClassifier(Instances data) throws Exception {// 建树makeTree(data);}/** @Author YFMan* @Description 根据训练数据 建立 决策树* @Date 2023/5/25 18:43* @Param [data] 训练数据* @return void**/private void makeTree(Instances data) throws Exception {// 如果是空叶子,拒绝建树 (拒判)if (data.numInstances() == 0) {m_Attribute = null;m_ClassValue = Instance.missingValue();m_Distribution = new double[data.numClasses()];return;}// 计算 所有属性的 信息增益double[] infoGains = new double[data.numAttributes()];// 遍历所有属性for(int i = 0; i < data.numAttributes(); i++) {// 如果是类别属性,跳过if (i == data.classIndex()) {infoGains[i] = 0;} else {// 计算信息增益infoGains[i] = computeInfoGain(data, data.attribute(i));}}// 选择信息增益最大的属性m_Attribute = data.attribute(Utils.maxIndex(infoGains));// 如果信息增益为 0,说明当前节点包含的样例都属于同一类别,直接设置为叶子节点if (Utils.eq(infoGains[m_Attribute.index()], 0)) {// 设置为叶子节点m_Attribute = null;m_Distribution = new double[data.numClasses()];// 遍历所有样例for (int i = 0; i < data.numInstances(); i++) {// 获取当前样例的类别Instance inst = data.instance(i);// 统计类别分布m_Distribution[(int) inst.classValue()]++;}// 归一化Utils.normalize(m_Distribution);// 设置类别m_ClassValue = Utils.maxIndex(m_Distribution);m_ClassAttribute = data.classAttribute();} else { // 否则,递归建树// 划分数据集Instances[] splitData = splitData(data, m_Attribute);// 创建叶子m_Successors = new myId3[m_Attribute.numValues()];// 叶子再去长叶子,递归调用for (int j = 0; j < m_Attribute.numValues(); j++) {m_Successors[j] = new myId3();m_Successors[j].makeTree(splitData[j]);}}}/** @Author YFMan* @Description 根据 instance 进行分类* @Date 2023/5/25 18:33* @Param [instance] 待分类的实例* @return double[] 类别分布**/public double[] distributionForInstance(Instance instance)throws NoSupportForMissingValuesException {// 如果到达叶子节点,返回类别分布if (m_Attribute == null) {// 如果 m_Distribution 全为 0(是空叶子),随机返回一个类别分布if (Utils.eq(Utils.sum(m_Distribution), 0)) {// 在 0~类别数-1 之间随机选择一个类别m_Distribution = new double[m_ClassAttribute.numValues()];m_Distribution[(int) Math.round(Math.random() * m_ClassAttribute.numValues())] = 1.0;}return m_Distribution;} else {// 否则,递归调用return m_Successors[(int) instance.value(m_Attribute)].distributionForInstance(instance);}}/** @Author YFMan* @Description 计算当前数据集 选择某个属性的 信息增益* @Date 2023/5/25 18:29* @Param [data, att] 当前数据集,选择的属性* @return double 信息增益**/private double computeInfoGain(Instances data, Attribute att)throws Exception {// 计算 data 的信息熵double infoGain = computeEntropy(data);// 计算 data 按照 att 属性进行划分的信息熵// 划分数据集Instances[] splitData = splitData(data, att);// 遍历划分后的数据集for (Instances instances : splitData) {// 计算概率double probability = (double) instances.numInstances() / data.numInstances();// 计算信息熵infoGain -= probability * computeEntropy(instances);}// 返回信息增益return infoGain;}/** @Author YFMan* @Description 计算信息熵* @Date 2023/5/25 18:18* @Param [data] 计算的数据集* @return double 信息熵**/private double computeEntropy(Instances data) throws Exception {// 计不同类别的数量double[] classCounts = new double[data.numClasses()];// 遍历数据集for(int i=0;i<data.numInstances();i++){// 获取类别int classIndex = (int) data.instance(i).classValue();// 数量加一classCounts[classIndex]++;}// 计算信息熵double entropy = 0;// 遍历类别for (double classCount : classCounts) {// 注意:这里是大于 0,因为 log2(0) = -Infinity;// 如果是等于 0,那么计算结果就是 NaN,熵就出错了if(classCount > 0){// 计算概率double probability = classCount / data.numInstances();// 计算信息熵entropy -= probability * Utils.log2(probability);}}// 返回信息熵return entropy;}/** @Author YFMan* @Description 根据属性划分数据集* @Date 2023/5/25 18:23* @Param [data, att] 数据集,属性* @return weka.core.Instances[] 划分后的数据集**/private Instances[] splitData(Instances data, Attribute att) {// 定义划分后的数据集Instances[] splitData = new Instances[att.numValues()];// 遍历划分后的数据集for(int i=0;i<splitData.length;i++){// 创建数据集 (这里主要是为了初始化 数据集 header)// Constructor copying all instances and references to the header// information from the given set of instances.splitData[i] = new Instances(data,0);}// 遍历数据集for(int i=0;i<data.numInstances();i++){// 获取实例Instance instance = data.instance(i);// 获取实例的属性值double value = instance.value(att);// 将实例添加到对应的数据集中splitData[(int) value].add(instance);}// 返回划分后的数据集return splitData;}private String toString(int level) {StringBuffer text = new StringBuffer();if (m_Attribute == null) {if (Instance.isMissingValue(m_ClassValue)) {text.append(": null");} else {text.append(": " + m_ClassAttribute.value((int) m_ClassValue));}} else {for (int j = 0; j < m_Attribute.numValues(); j++) {text.append("\n");for (int i = 0; i < level; i++) {text.append("| ");}text.append(m_Attribute.name() + " = " + m_Attribute.value(j));text.append(m_Successors[j].toString(level + 1));}}return text.toString();}public String toString() {if ((m_Distribution == null) && (m_Successors == null)) {return "Id3: No model built yet.";}return "Id3\n\n" + toString(0);}/*** Main method.** @param args the options for the classifier*/public static void main(String[] args) {runClassifier(new myId3(), args);}
}