头歌——人工智能(机器学习 --- 决策树2)

文章目录

  • 第5关:基尼系数
    • 代码
  • 第6关:预剪枝与后剪枝
    • 代码
  • 第7关:鸢尾花识别
    • 代码

第5关:基尼系数

基尼系数
在ID3算法中我们使用了信息增益来选择特征,信息增益大的优先选择。在C4.5算法中,采用了信息增益率来选择特征,以减少信息增益容易选择特征值多的特征的问题。但是无论是ID3还是C4.5,都是基于信息论的熵模型的,这里面会涉及大量的对数运算。能不能简化模型同时也不至于完全丢失熵模型的优点呢?当然有!那就是基尼系数!

CART算法使用基尼系数来代替信息增益率,基尼系数代表了模型的不纯度,基尼系数越小,则不纯度越低,特征越好。这和信息增益与信息增益率是相反的(它们都是越大越好)。

在这里插入图片描述
从公式可以看出,相比于信息增益和信息增益率,计算起来更加简单。举个例子,还是使用第二关中提到过的数据集,第一列是编号,第二列是性别,第三列是活跃度,第四列是客户是否流失的标签(0表示未流失,1表示流失)。

在这里插入图片描述
在这里插入图片描述

代码

import numpy as npdef calcGini(feature, label, index):'''计算基尼系数:param feature:测试用例中字典里的feature,类型为ndarray:param label:测试用例中字典里的label,类型为ndarray:param index:测试用例中字典里的index,即feature部分特征列的索引。该索引指的是feature中第几个特征,如index:0表示使用第一个特征来计算信息增益。:return:基尼系数,类型float'''# 计算子集的基尼指数def calcGiniIndex(label_subset):total = len(label_subset)if total == 0:return 0label_counts = np.bincount(label_subset)probabilities = label_counts / totalgini = 1.0 - np.sum(np.square(probabilities))return gini# 将feature和label转为numpy数组f = np.array(feature)l = np.array(label)# 得到指定特征列的值的集合unique_values = np.unique(f[:, index])total_gini = 0total_samples = len(label)# 按照特征的每个唯一值划分数据集for value in unique_values:# 获取该特征值对应的样本索引subset_indices = np.where(f[:, index] == value)[0]# 获取对应的子集标签subset_label = l[subset_indices]# 计算子集的基尼指数subset_gini = calcGiniIndex(subset_label)# 加权计算总的基尼系数weighted_gini = (len(subset_label) / total_samples) * subset_ginitotal_gini += weighted_ginireturn total_gini

第6关:预剪枝与后剪枝

为什么需要剪枝
决策树的生成是递归地去构建决策树,直到不能继续下去为止。这样产生的树往往对训练数据有很高的分类准确率,但对未知的测试数据进行预测就没有那么准确了,也就是所谓的过拟合。

决策树容易过拟合的原因是在构建决策树的过程时会过多地考虑如何提高对训练集中的数据的分类准确率,从而会构建出非常复杂的决策树(树的宽度和深度都比较大)。在之前的实训中已经提到过,模型的复杂度越高,模型就越容易出现过拟合的现象。所以简化决策树的复杂度能够有效地缓解过拟合现象,而简化决策树最常用的方法就是剪枝。剪枝分为预剪枝与后剪枝。

预剪枝
预剪枝的核心思想是在决策树生成过程中,对每个结点在划分前先进行一个评估,若当前结点的划分不能带来决策树泛化性能提升,则停止划分并将当前结点标记为叶结点。

想要评估决策树算法的泛化性能如何,方法很简单。可以将训练数据集中随机取出一部分作为验证数据集,然后在用训练数据集对每个结点进行划分之前用当前状态的决策树计算出在验证数据集上的正确率。正确率越高说明决策树的泛化性能越好,如果在划分结点的时候发现泛化性能有所下降或者没有提升时,说明应该停止划分,并用投票计数的方式将当前结点标记成叶子结点。

举个例子,假如上一关中所提到的用来决定是否买西瓜的决策树模型已经出现过拟合的情况,模型如下:
在这里插入图片描述
假设当模型在划分是否便宜这个结点前,模型在验证数据集上的正确率为0.81。但在划分后,模型在验证数据集上的正确率降为0.67。此时就不应该划分是否便宜这个结点。所以预剪枝后的模型如下:

在这里插入图片描述
从上图可以看出,预剪枝能够降低决策树的复杂度。这种预剪枝处理属于贪心思想,但是贪心有一定的缺陷,就是可能当前划分会降低泛化性能,但在其基础上进行的后续划分却有可能导致性能显著提高。所以有可能会导致决策树出现欠拟合的情况。

后剪枝
后剪枝是先从训练集生成一棵完整的决策树,然后自底向上地对非叶结点进行考察,若将该结点对应的子树替换为叶结点能够带来决策树泛化性能提升,则将该子树替换为叶结点。

后剪枝的思路很直接,对于决策树中的每一个非叶子结点的子树,我们尝试着把它替换成一个叶子结点,该叶子结点的类别我们用子树所覆盖训练样本中存在最多的那个类来代替,这样就产生了一个简化决策树,然后比较这两个决策树在测试数据集中的表现,如果简化决策树在验证数据集中的准确率有所提高,那么该子树就可以替换成叶子结点。该算法以bottom-up的方式遍历所有的子树,直至没有任何子树可以替换使得测试数据集的表现得以改进时,算法就可以终止。

从后剪枝的流程可以看出,后剪枝是从全局的角度来看待要不要剪枝,所以造成欠拟合现象的可能性比较小。但由于后剪枝需要先生成完整的决策树,然后再剪枝,所以后剪枝的训练时间开销更高。

代码

import numpy as np
from copy import deepcopyclass DecisionTree(object):def __init__(self):# 决策树模型self.tree = {}def calcInfoGain(self, feature, label, index):# 计算信息增益的代码def calcInfoEntropy(feature, label):# 计算信息熵label_set = set(label)result = 0for l in label_set:count = 0for j in range(len(label)):if label[j] == l:count += 1p = count / len(label)result -= p * np.log2(p)return resultdef calcHDA(feature, label, index, value):# 计算条件熵count = 0sub_feature = []sub_label = []for i in range(len(feature)):if feature[i][index] == value:count += 1sub_feature.append(feature[i])sub_label.append(label[i])pHA = count / len(feature)e = calcInfoEntropy(sub_feature, sub_label)return pHA * ebase_e = calcInfoEntropy(feature, label)f = np.array(feature)f_set = set(f[:, index])sum_HDA = 0for value in f_set:sum_HDA += calcHDA(feature, label, index, value)return base_e - sum_HDAdef getBestFeature(self, feature, label):max_infogain = 0best_feature = 0for i in range(len(feature[0])):infogain = self.calcInfoGain(feature, label, i)if infogain > max_infogain:max_infogain = infogainbest_feature = ireturn best_featuredef calc_acc_val(self, the_tree, val_feature, val_label):# 计算验证集准确率result = []def classify(tree, feature):if not isinstance(tree, dict):return treet_index, t_value = list(tree.items())[0]f_value = feature[t_index]if isinstance(t_value, dict):classLabel = classify(tree[t_index][f_value], feature)return classLabelelse:return t_valuefor f in val_feature:result.append(classify(the_tree, f))result = np.array(result)return np.mean(result == val_label)def createTree(self, train_feature, train_label):# 创建决策树if len(set(train_label)) == 1:return train_label[0]if len(train_feature[0]) == 1 or len(np.unique(train_feature, axis=0)) == 1:vote = {}for l in train_label:if l in vote.keys():vote[l] += 1else:vote[l] = 1max_count = 0vote_label = Nonefor k, v in vote.items():if v > max_count:max_count = vvote_label = kreturn vote_labelbest_feature = self.getBestFeature(train_feature, train_label)tree = {best_feature: {}}f = np.array(train_feature)f_set = set(f[:, best_feature])for v in f_set:sub_feature = []sub_label = []for i in range(len(train_feature)):if train_feature[i][best_feature] == v:sub_feature.append(train_feature[i])sub_label.append(train_label[i])tree[best_feature][v] = self.createTree(sub_feature, sub_label)return treedef post_cut(self, val_feature, val_label):# 剪枝相关代码def get_non_leaf_node_count(tree):non_leaf_node_path = []def dfs(tree, path, all_path):for k in tree.keys():if isinstance(tree[k], dict):path.append(k)dfs(tree[k], path, all_path)if len(path) > 0:path.pop()else:all_path.append(path[:])dfs(tree, [], non_leaf_node_path)unique_non_leaf_node = []for path in non_leaf_node_path:isFind = Falsefor p in unique_non_leaf_node:if path == p:isFind = Truebreakif not isFind:unique_non_leaf_node.append(path)return len(unique_non_leaf_node)def get_the_most_deep_path(tree):non_leaf_node_path = []def dfs(tree, path, all_path):for k in tree.keys():if isinstance(tree[k], dict):path.append(k)dfs(tree[k], path, all_path)if len(path) > 0:path.pop()else:all_path.append(path[:])dfs(tree, [], non_leaf_node_path)max_depth = 0result = Nonefor path in non_leaf_node_path:if len(path) > max_depth:max_depth = len(path)result = pathreturn resultdef set_vote_label(tree, path, label):for i in range(len(path)-1):tree = tree[path[i]]tree[path[len(path)-1]] = labelacc_before_cut = self.calc_acc_val(self.tree, val_feature, val_label)for _ in range(get_non_leaf_node_count(self.tree)):path = get_the_most_deep_path(self.tree)tree = deepcopy(self.tree)step = deepcopy(tree)for k in path:step = step[k]vote_label = sorted(step.items(), key=lambda item: item[1], reverse=True)[0][0]set_vote_label(tree, path, vote_label)acc_after_cut = self.calc_acc_val(tree, val_feature, val_label)if acc_after_cut > acc_before_cut:set_vote_label(self.tree, path, vote_label)acc_before_cut = acc_after_cutdef fit(self, train_feature, train_label, val_feature, val_label):# 训练决策树模型self.tree = self.createTree(train_feature, train_label)self.post_cut(val_feature, val_label)def predict(self, feature):# 预测函数result = []def classify(tree, feature):if not isinstance(tree, dict):return treet_index, t_value = list(tree.items())[0]f_value = feature[t_index]if isinstance(t_value, dict):classLabel = classify(tree[t_index][f_value], feature)return classLabelelse:return t_valuefor f in feature:result.append(classify(self.tree, f))return np.array(result)

第7关:鸢尾花识别

掌握如何使用sklearn提供的DecisionTreeClassifier

在这里插入图片描述
数据简介:
鸢尾花数据集是一类多重变量分析的数据集。通过花萼长度,花萼宽度,花瓣长度,花瓣宽度4个属性预测鸢尾花卉属于(Setosa,Versicolour,Virginica)三个种类中的哪一类(其中分别用0,1,2代替)。

数据集中部分数据与标签如下图所示:
在这里插入图片描述
在这里插入图片描述
DecisionTreeClassifier
DecisionTreeClassifier的构造函数中有两个常用的参数可以设置:

criterion:划分节点时用到的指标。有gini(基尼系数),entropy(信息增益)。若不设置,默认为gini
max_depth:决策树的最大深度,如果发现模型已经出现过拟合,可以尝试将该参数调小。若不设置,默认为None

和sklearn中其他分类器一样,DecisionTreeClassifier类中的fit函数用于训练模型,fit函数有两个向量输入:

X:大小为[样本数量,特征数量]的ndarray,存放训练样本;
Y:值为整型,大小为[样本数量]的ndarray,存放训练样本的分类标签。

DecisionTreeClassifier类中的predict函数用于预测,返回预测标签,predict函数有一个向量输入:

X:大小为[样本数量,特征数量]的ndarray,存放预测样本。

DecisionTreeClassifier的使用代码如下:

from sklearn.tree import DecisionTreeClassifier
clf = tree.DecisionTreeClassifier()
clf.fit(X_train, Y_train)
result = clf.predict(X_test)

代码

#********* Begin *********#
import pandas as pd
from sklearn.tree import DecisionTreeClassifiertrain_df = pd.read_csv('./step7/train_data.csv').as_matrix()
train_label = pd.read_csv('./step7/train_label.csv').as_matrix()
test_df = pd.read_csv('./step7/test_data.csv').as_matrix()dt = DecisionTreeClassifier()
dt.fit(train_df, train_label)
result = dt.predict(test_df)result = pd.DataFrame({'target':result})
result.to_csv('./step7/predict.csv', index=False)
#********* End *********#

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/57673.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

银河麒麟相关

最近安装了银河麒麟server版本,整理下遇到的一些小问题 1、vmware安装Kylin-Server-V10-SP3-General-Release-2303-X86_64虚拟机完成后,桌面窗口很小,安装vmwaretools后解决,下载地址http://softwareupdate.vmware.com/cds/vmw-de…

leetcode-71-简化路径

题解: 1、以"/"作为分隔符对字符串进行分割得到数组names; 2、初始化一个栈stack(python中的栈使用列表实现); 3、遍历数组names;如果当前元素为".."且栈不为空,则将弹出栈顶元素&a…

网络安全行业10大副业汇总,总有一个适合你

网络安全行业这10大副业汇总 总有一个适合你 引言 在当今的网络安全行业中,除了全职工作外,许多师傅还通过副业来增加收入、不断拓展自身技能,并积累更多实际操作经验,为职业发展增添了无限可能。 本文提供了10种适合各种类型…

Android13关于获取外部存储文件的相关问题及解决方案记录

Android的学习路上... 测试设备:vivo X90s安卓版本: Android13开发环境:AndroidStudio FlamingoSDK:33 最近我在Android13的环境下尝试写一个文件选择器,以便日后的开发使用。但是我们知道,从Android13 (A…

django restful API

文章目录 项目地址一、django环境安装以及初识restful1.1 安装python 3.10的虚拟环境1.2 创建django工程文件1.3 创建一个book app1.4 序列化(Django JsonResponse)1.4.1创建一个Models1.4.2 创建django的超级用户admin1.4.3 添加serializers.py生成序列化器1.5 FBV创建视图1…

用docker Desktop 下载使用thingsboard/tb-gateway

1、因为正常的docker pull thingsboard/tb-gateway 国内不行了,所以需要其它工具来下载 2、在win下用powershell管理员下运行 docker search thingsboard/tb-gateway 可以访问到了 docker pull thingsboard/tb-gateway就可以下载了 3、docker Desktop就可以看到…

ComfyUI初体验

ComfyUI 我就不过多介绍了,安装和基础使用可以看下面大佬的视频,感觉自己靠图文描述的效果不一定好,大家看视频比较方便。 ComfyUI全球爆红,AI绘画进入“工作流时代”?做最好懂的Comfy UI入门教程:Stable D…

分布式ID生成策略

文章目录 分布式ID必要性1.UUID2.基于DB的自增主键方案3.数据库多主模式4.号段模式5.Redis6.Zookeeper7.ETCD8.雪花算法9.百度(Uidgenerator)10.美团(Leaf)11.滴滴(TinyID) 分布式ID必要性 业务量小于500W的时候单独一个mysql即可提供服务,再大点的时候就进行读写分…

nuScenes数据集使用的相机的外参和内参

因为需要用不同数据集测试对比效果,而一般的模型代码里实现的检测结果可视化都是使用open3d的Visualizer在点云上画的3d框,展示出来的可视化效果很差,可能是偷懒,没有实现将检测结果投影到各相机的图像上,所以检测效果…

[ARM-2D 专题]4. 快速搭建ARM2D的PC仿真开发环境及避坑手法

有几种情况你需要使用pc仿真开发环境: 手上没有合适的硬件条件只想快速的了解一下ARM-2D开发过程中,加速开发过程,避免频繁的下载代码 无论如何,pc仿真开发环境,你都值得拥有。 第一步,先下载源代码&#…

基于ElementPlus的Form组件封装

前言 我们在项目开发过程中遇到最多就是表单页面的开发,那么使用频率比较高的就是Form组件,无论是vue亦或者是react,我们在项目中使用到UI库都会有Form组件。多数情况下都是用到了Form组件,我们先根据UI库或者其他类似的页面直接…

h5页面与小程序页面互相跳转

小程序跳转h5页面 一个home页 /pages/home/home 一个含有点击事件的元素&#xff1a;<button type"primary" bind:tap"toWebView">点击跳转h5页面</button>toWebView(){ wx.navigateTo({ url: /pages/webview/webview }) } 一个webView页 /pa…

物联网行业应用实训室建设方案

一、建设背景 随着物联网技术的迅猛发展和广泛应用&#xff0c;物联网产业已跃升为新时代的经济增长引擎&#xff0c;对于产业升级和社会信息化水平的提升具有举足轻重的地位。因此&#xff0c;为了满足这一领域的迫切需求&#xff0c;培养具备物联网技术应用能力的优秀人才成…

自动发现-实现运维管理自动化

nVisual-Discovery是一款自动化工具软件&#xff0c;通过多种自动发现技术&#xff0c;协助运维管理人员快速建立可视化的网络文档&#xff0c;提升网络管理的效率与准确性。 01 IP扫描发现 当我们新接手一个网络运维项目&#xff0c;通常缺乏精准的网络文档数据&#xff0c;…

4.2-6 使用Hadoop WebUI

文章目录 1. 查看HDFS集群状态1.1 端口号说明1.2 用主机名访问1.3 主节点状态1.4 用IP地址访问1.5 查看数据节点 2. 操作HDFS文件系统2.1 查看HDFS文件系统2.2 在HDFS上创建目录2.3 上传文件到HDFS2.4 删除HDFS文件和目录 3. 查看YARN集群状态4. 实战总结 1. 查看HDFS集群状态 …

Docker部署MySQL主从复制

1. 主从复制概念及优势 1.1 概念 MySQL主从复制是一种数据库复制技术&#xff0c;它允许将一个数据库服务器&#xff08;主服务器&#xff09;上的数据更改复制到一个或多个数据库服务器&#xff08;从服务器&#xff09;。这种技术在数据库管理和维护中扮演着重要的角色&…

CSS 网格布局

网格布局是一个二维布局系统&#xff0c;允许开发者以行和列的形式创建灵活的网络&#xff0c;并将内容放置在网络的单元格中。有些元素可能只占据网络的一个单元&#xff0c;另一些元素则可能占据多行或多列。 网格的大小既可以精确定义&#xff0c;也可以根据自身内容自动计…

C/C++(六)多态

本文将介绍C的另一个基于继承的重要且复杂的机制&#xff0c;多态。 一、多态的概念 多态&#xff0c;就是多种形态&#xff0c;通俗来说就是不同的对象去完成某个行为&#xff0c;会产生不同的状态。 多态严格意义上分为静态多态与动态多态&#xff0c;我们平常说的多态一般…

实战应用WPS WebOffice开放平台服务

概述 根据公司的业务需要&#xff0c;主要功能是在线编辑文档&#xff0c;前端的小伙伴进行的技术调研&#xff0c;接入的是WPS WebOffice&#xff0c;这里只阐述技术介入的步骤、流程和遇到的坑进行的一些总结。 实践 WPS WebOffice 开放平台进行认证 在开始之前&#xff…

【NOIP提高组】加分二叉树

【NOIP提高组】加分二叉树 &#x1f490;The Begin&#x1f490;点点关注&#xff0c;收藏不迷路&#x1f490; 设一个n个节点的二叉树tree的中序遍历为&#xff08;l,2,3,…,n&#xff09;&#xff0c;其中数字1,2,3,…,n为节点编号。每个节点都有一个分数&#xff08;均为正整…