三招提升数据不平衡模型的性能(附python代码)

摘要: 本文的主要目标是处理数据不平衡问题。文中描述了用来克服数据不平衡问题的三种技术,分别是集成交叉验证、类别权重以及过大预测 。

       对于深度学习而言,数据集非常重要,但在实际项目中,或多或少会碰见数据不平衡问题。什么是数据不平衡呢?举例来说,现在有一个任务是判断西瓜是否成熟,这是一个二分类问题——西瓜是生的还是熟的,该任务的数据集由两部分数据组成,成熟西瓜与生西瓜,假设生西瓜的样本数量远远大于成熟西瓜样本的数量,针对这样的数据集训练出来的算法“偏向”于识别新样本为生西瓜,存心让你买不到甜的西瓜以解夏天之苦,这就是一个数据不平衡问题。针对数据不平衡问题有相应的处理办法,比如对多数样本进行采样使得其样本数量级与少样本数相近,或者是对少数样本重复使用等。最近恰好在面试中遇到一个数据不平衡问题,这也是面试中经常会出现的问题之一,现向读者分享此次解决问题的心得。


数据集

       训练数据中有三个标签,分别标记为[1、2、3],这意味着该问题是一个多分类问题。训练数据集有17个特征以及38829个独立数据点。而在测试数据中,有16个没有标签的特征和16641个数据点。该训练数据集非常不平衡,大部分数据是1类(95%),而2类和3类分别有3.0%和0.87%的数据,如下图所示。


算法

       经过初步观察,决定采用随机森林(RF)算法,因为它优于支持向量机、Xgboost以及LightGBM算法。在这个项目中选择RF还有几个原因:

  • 1机森林对过拟合具有很强的鲁棒性;
  • 2.参数化仍然非常直观;
  • 3.在这个项目中,有许多成功的用例将随机森林算法用于高度不平衡的数据集;
  • 4.个人有先前的算法实施经验;
           为了找到最佳参数,使用scikit-sklearn实现的GridSearchCV对指定的参数值执行网格搜索,更多细节可以在本人的Github上找到。

为了处理数据不平衡问题,使用了以下三种技术:

A.使用集成交叉验证(CV):

       在这个项目中,使用交叉验证来验证模型的鲁棒性。整个数据集被分成五个子集。在每个交叉验证中,使用其中的四个子集用于训练,剩余的子集用于验证模型,此外模型还对测试数据进行了预测。在交叉验证结束时,会得到五个测试预测概率。最后,对所有类别的概率取平均值。模型的训练表现稳定,每个交叉验证上具有稳定的召回率和f1分数。这项技术也帮助我在Kaggle比赛中取得了很好的成绩(前1%)。以下部分代码片段显示了集成交叉验证的实现:

for j, (train_idx, valid_idx) in enumerate(folds):X_train = X[train_idx]Y_train = y[train_idx]X_valid = X[valid_idx]Y_valid = y[valid_idx]clf.fit(X_train, Y_train)valid_pred = clf.predict(X_valid)recall  = recall_score(Y_valid, valid_pred, average='macro')f1 = f1_score(Y_valid, valid_pred, average='macro')recall_scores[i][j] = recallf1_scores[i][j] = f1train_pred[valid_idx, i] = valid_predtest_pred[:, test_col] = clf.predict(T)test_col += 1## Probabilitiesvalid_proba = clf.predict_proba(X_valid)train_proba[valid_idx, :] = valid_probatest_proba  += clf.predict_proba(T)test_proba /= self.n_splits

B.设置类别权重/重要性:

       代价敏感学习是使随机森林更适合从非常不平衡的数据中学习的方法之一。随机森林有倾向于偏向大多数类别。因此,对少数群体错误分类施加昂贵的惩罚可能是有作用的。由于这种技术可以改善模型性能,所以我给少数群体分配了很高的权重(即更高的错误分类成本)。然后将类别权重合并到随机森林算法中。我根据类别1中数据集的数量与其它数据集的数量之间的比率来确定类别权重。例如,类别1和类别3数据集的数目之间的比率约为110,而类别1和类别2的比例约为26。现在我稍微对数量进行修改以改善模型的性能,以下代码片段显示了不同类权重的实现:

from sklearn.ensemble import RandomForestClassifier
class_weight = dict({1:1.9, 2:35, 3:180})rdf = RandomForestClassifier(bootstrap=True,class_weight=class_weight, criterion='gini',max_depth=8, max_features='auto', max_leaf_nodes=None,min_impurity_decrease=0.0, min_impurity_split=None,min_samples_leaf=4, min_samples_split=10,min_weight_fraction_leaf=0.0, n_estimators=300,oob_score=False,random_state=random_state,verbose=0, warm_start=False)

C.过大预测标签而不是过小预测(Over-Predict a Label than Under-Predict):

       这项技术是可选的,通过实践发现,这种方法对提高少数类别的表现非常有效。简而言之,如果将模型错误分类为类别3,则该技术能最大限度地惩罚该模型,对于类别2和类别1惩罚力度稍差一些。 为了实施该方法,我改变了每个类别的概率阈值,将类别3、类别2和类别1的概率设置为递增顺序(即,P3= 0.25,P2= 0.35,P1= 0.50),以便模型被迫过度预测类别。该算法的详细实现可以在Github上找到。

最终结果

       以下结果表明,上述三种技术如何帮助改善模型性能:
1.使用集成交叉验证的结果:



2.使用集成交叉验证+类别权重的结果:



3.使用集成交叉验证+类别权重+过大预测标签的结果:


结论

       由于在实施过大预测技术方面的经验很少,因此最初的时候处理起来非常棘手。但是,研究该问题有助于提升我解决问题的能力。对于每个任务而言,起初可能确实是陌生的,这个时候不要害怕,一次次尝试就好。由于时间的限制(48小时),无法将精力分散于模型的微调以及特征工程,存在改进的地方还有很多,比如删除不必要的功能并添加一些额外功能。此外,也尝试过LightGBM和XgBoost算法,但在实践过程中发现,随机森林的效果优于这两个算法。在后面的研究中,可以进一步尝试一些其他算法,比如神经网络、稀疏编码等。

原文链接

本文为云栖社区原创内容,未经允许不得转载。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/521674.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

为什么说「中台」程序员将来会最值钱?

戳蓝字“CSDN云计算”关注我们哦!今年在国内互联网公司中真的是很流行中台这个概念,不,是非常流行,是相当流行。作为程序员真的非常有必要了解一下。国内中台概念的由来国内中台的这个概念最早是由阿里巴巴提出来的。据说故事是这…

varclus变量聚类对变量有啥要求_互助问答第208期:递归的双变量probit模型问题...

递归的双变量probit模型的stata命令是什么?比如二元被解释变量为y1,内生的二元变量为t1,x1和x2是其他外生协变量,iv1为内生二元解释变量的工具变量,那么,递归双变量probit模型是否可以写为:bipr…

Vue + Spring Boot 项目实战(六):使用 Element 辅助前端开发

文章目录一、安装并引入 Element1.安装 Element2.引入 Element二、优化登录页面1.使用 Form 组件2.添加样式3.设置背景4.完整代码之前我们实现了登录功能,但不得不说登录页面实在是太简陋了。在这个看脸的社会,如果代码写的烂,界面也做得不好…

不止 5G 和鸿蒙,华为最新大招,扔出 AI 计算核弹

戳蓝字“CSDN云计算”关注我们哦!华为发布全世界最快AI产品,集成1024颗业内最强芯片,训练ResNet-50只需59.8秒。近日,华为全联接大会开幕,推出又一重量级AI产品Atlas900。此前接受外媒采访时,任正非就已经预…

阿里90后工程师利用ARM硬件特性开启安卓8终端“上帝模式”

摘要: 本文以安卓8终端为载体,介绍阿里安全潘多拉实验室成员研究并提出的内核空间镜像攻击利用技巧。文/图 阿里安全潘多拉实验室 团控编者按:团控,阿里安全潘多拉实验室研究人员,该实验室主要聚焦于移动安全领域&…

神龙X-Dragon,这技术“范儿”如何?

戳蓝字“CSDN云计算”关注我们哦!在CSDN总部会议室,阿晶首次见到了阿里云智能研究员、弹性计算技术负责人张献涛——这位不仅仅在阿里云智能内部,在业内也是响当当的虚拟化技术大牛。现在回想起来,当时聊了没两句,阿晶…

python 如何判断一个函数执行完成_三步搞定 Python 中的文件操作

当程序运行时,变量是保存数据的好方法,但变量、序列以及对象中存储的数据是暂时的,程序结束后就会丢失,如果希望程序结束后数据仍然保持,就需要将数据保存到文件中。Python 提供了内置的文件对象,以及对文件…

一位资深程序员大牛给予Java初学者的学习路线建议

摘要: java学习这一部分其实也算是今天的重点,这一部分用来回答很多群里的朋友所问过的问题,那就是我你是如何学习Java的,能不能给点建议?今天我是打算来点干货,因此咱们就不说一些学习方法和技巧了&#x…

Vue + Spring Boot 项目实战(七):前端路由与登录拦截器

文章目录前言一、前端路由二、使用 History 模式三、后端登录拦截器3.1. LoginController3.2. LoginInterceptor3.3. WebConfigurer3.4. 效果检验四、Vuex 与前端登录拦截器4.1. 引入 Vuex4.2. 修改路由配置4.3. 使用钩子函数判断是否拦截4.4. 修改 Login.vue4.5. 效果检验前言…

高手如何实践HBase?不容错过的滴滴内部技巧

摘要: HBase和Phoenix的优势大家众所周知,想要落地实践却问题一堆?replication的随机发送、Connection的管理是否让你头痛不已?本次分享中,滴滴以典型的应用场景带大家深入探究HBase和Phoenix,并分享内核改…

JS 打印 data数据_数据表格 Data Table - 复杂内容的15个设计点

表格是桌面应用中常见的内容型组件,它包含大量的信息和丰富的交互形式,表格具有极高的空间利用率,结构化的展示保证了数据可读性。高效、清晰且易用是进行表格设计的原则性要求。本文将从表格的内容组织到交互作一次汇总,作为数据…

神龙X-Dragon,这技术“范儿”如何?| 问底中国IT技术演进

在CSDN总部会议室,阿晶首次见到了阿里云智能研究员、弹性计算技术负责人张献涛——这位不仅仅在阿里云智能内部,在业内也是响当当的虚拟化技术大牛。现在回想起来,当时聊了没两句,阿晶就问了这样一个问题,“阿里云这款…

干货 | 蚂蚁金服是如何实现经典服务化架构往 Service Mesh 方向的演进的?

摘要: 小蚂蚁说: 蚂蚁金服在服务化上面已经经过多年的沉淀,支撑了每年双十一的高峰峰值。Service Mesh 作为微服务的一个新方向,在最近两年成为领域的一个大热点,但是如何从经典服务化架构往 Service Mesh 的方向上演进…

Vue + Spring Boot 项目实战(八):导航栏与图书页面设计

文章目录前言一、导航栏的实现1.路由配置2.使用 NavMenu 组件二、图书管理页面2.1. LibraryIndex.vue2.SideMenu.vue3.Books.vue前言 之前讲过使用 Element 辅助前端页面的开发,但是只用到了比较少的内容,这一篇我们来做一下系统的核心页面——图书管理…

pmsm simulink foc 仿真_仿真软件教程

很多朋友都建议我做个视频的整理,方便没看过之前内容的朋友方便查找,我觉得这个确实很有必要。下面内容是关于仿真软件方面:仿真环境:Simlpis 8.0类型简介VMC和CMC的LLC控制器仿真对比 第一节图文电压模式和电流模式LLC控制器的简…

日志采集中的关键技术分析

摘要: 从日志投递的方式来看,日志采集又可以分为推模式和拉模式,本文主要分析的是推模式的日志采集。概述日志从最初面向人类演变到现在的面向机器发生了巨大的变化。最初的日志主要的消费者是软件工程师,他们通过读取日志来排查问…

限时早鸟票 | 2019 中国大数据技术大会(BDTC)超豪华盛宴抢先看!

2019 年12月5-7 日,由中国计算机学会主办,CCF 大数据专家委员会承办,CSDN、中科天玑数据科技股份有限公司协办的 2019 中国大数据技术大会,将于北京长城饭店隆重举行。届时,超过百位技术专家及行业领袖将齐聚于此&…

机器学习和数据科学领域必读的10本免费书籍

摘要: 暑期来了,别出去溜达了,看书学习一波~在这个暑假,有兴趣的可以阅读一下这些免费的有关机器学习和数据科学的书籍,他们能给你打开一扇看清机器学习和数据科学的窗。如果在阅读完这一文章后想知晓更多免…

microsoft账号登陆一直在加载_英雄联盟手游下载,附带拳头账号注册教程

欢迎关注【花卷来了】公众号。如果喜欢本期节目请点赞、再看、分享给朋友吧~软件资源请回复文章底部今日关键词获取/排版:萌萌哒花卷/来源:采集自网络今日主题:最新英雄联盟手游下载,附带拳头账号注册教程英雄联盟手游今天正式公测…

机器学习者都应该知道的五种损失函数!

摘要: 还不知道这五种损失函数?你怎么在机器学习这个圈子里面混?在机器学习中,所有的机器学习算法都或多或少的依赖于对目标函数最大化或者最小化的过程,我们常常把最小化的函数称为损失函数,它主要用于衡量…