Python绘制多分类ROC曲线

目录

1 数据集介绍

1.1 数据集简介

1.2 数据预处理

 2随机森林分类

2.1 数据加载

2.2 参数寻优

2.3 模型训练与评估

3 绘制十分类ROC曲线

第一步,计算每个分类的预测结果概率

第二步,画图数据准备

第三步,绘制十分类ROC曲线


1 数据集介绍

1.1 数据集简介

分类数据集为某公司手机上网满意度数据集,数据如图所示,共7020条样本,关于手机满意度分类的特征有网络覆盖与信号强度、手机上网速度、手机上网稳定性等75个特征。

1.2 数据预处理

常规数据处理流程,详细内容见上期随机森林处理流程:

xxx 链接

  • 缺失值处理

  • 异常值处理

  • 数据归一化

  • 分类特征编码

处理完后的数据保存为 手机上网满意度.csv文件(放置文末)

 2随机森林分类

2.1 数据加载

第一步,导入包

import pandas as pd
import numpy as np
from sklearn.ensemble import RandomForestClassifier# 随机森林回归
from sklearn.model_selection import train_test_split,GridSearchCV,cross_val_score
from sklearn.metrics import accuracy_score # 引入准确度评分函数
from sklearn.metrics import mean_squared_error
from sklearn import preprocessing
import warnings
warnings.filterwarnings('ignore')
import matplotlib.pyplot as plt
import matplotlib
matplotlib.rc("font", family='Microsoft YaHei')

第二步,加载数据

net_data = pd.read_csv('手机上网满意度.csv')
Y = net_data.iloc[:,3]   
X= net_data.iloc[:, 4:]
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.3, random_state=200)  # 随机数种子
print(net_data.shape)
net_data.describe()

2.2 参数寻优

第一步,对最重要的超参数n_estimators即决策树数量进行调试,通过不同数目的树情况下,在训练集和测试集上的均方根误差来判断

## 分析随着树数目的变化,在测试集和训练集上的预测效果
rfr1 = RandomForestClassifier(random_state=1)
n_estimators = np.arange(50,500,50) # 420,440,2
train_mse = []
test_mse = []
for n in n_estimators:rfr1.set_params(n_estimators = n) # 设置参数rfr1.fit(X_train,Y_train) # 训练模型rfr1_lab = rfr1.predict(X_train)rfr1_pre = rfr1.predict(X_test)train_mse.append(mean_squared_error(Y_train,rfr1_lab))test_mse.append(mean_squared_error(Y_test,rfr1_pre))## 可视化不同数目的树情况下,在训练集和测试集上的均方根误差
plt.figure(figsize=(12,9))
plt.subplot(2,1,1)
plt.plot(n_estimators,train_mse,'r-o',label='trained MSE',color='darkgreen')
plt.xlabel('Number of trees')
plt.ylabel('MSE')
plt.grid()
plt.legend()plt.subplot(2,1,2)
plt.plot(n_estimators,test_mse,'r-o',label='test MSE',color='darkgreen')
index = np.argmin(test_mse)
plt.annotate('MSE:'+str(round(test_mse[index],4)),xy=(n_estimators[index],test_mse[index]),xytext=(n_estimators[index]+2,test_mse[index]+0.000002),arrowprops=dict(facecolor='red',shrink=0.02))
plt.xlabel('Number of trees')
plt.ylabel('MSE')
plt.grid()
plt.legend()
plt.tight_layout()
plt.show()

以及最优参数和最高得分进行分析,如下所示

###调n_estimators参数
ScoreAll = []
for i in range(50,500,50):DT = RandomForestClassifier(n_estimators = i,random_state = 1) #,criterion = 'entropy'score = cross_val_score(DT,X_train,Y_train,cv=6).mean()ScoreAll.append([i,score])
ScoreAll = np.array(ScoreAll)max_score = np.where(ScoreAll==np.max(ScoreAll[:,1]))[0][0] ##这句话看似很长的,其实就是找出最高得分对应的索引
print("最优参数以及最高得分:",ScoreAll[max_score])  
plt.figure(figsize=[20,5])
plt.plot(ScoreAll[:,0],ScoreAll[:,1],'r-o',label='最高得分',color='orange')
plt.xlabel('n_estimators参数')
plt.ylabel('分数')
plt.grid()
plt.legend()
plt.show()

很明显,决策树个数设置在400的时候回归森林预测模型的测试集均方根误差最小,得分最高,效果最显著。因此,我们通过网格搜索进行小范围搜索,构建随机森林预测模型时选取的决策树个数为400。

第二步,在确定决策树数量大概范围后,搜索决策树的最大深度的最高得分,如下所示

# 探索max_depth的最佳参数
ScoreAll = []  
for i in range(4,14,2):  DT = RandomForestClassifier(n_estimators = 400,random_state = 1,max_depth =i ) #,criterion = 'entropy'  score = cross_val_score(DT,X_train,Y_train,cv=6).mean()  ScoreAll.append([i,score])  
ScoreAll = np.array(ScoreAll)  max_score = np.where(ScoreAll==np.max(ScoreAll[:,1]))[0][0] 
print("最优参数以及最高得分:",ScoreAll[max_score])    
plt.figure(figsize=[20,5])  
plt.plot(ScoreAll[:,0],ScoreAll[:,1]) 
plt.xlabel('max_depth最佳参数')
plt.ylabel('分数')
plt.grid()
plt.legend() 
plt.show()  

决策树的深度最终设置为10。

2.3 模型训练与评估

# 随机森林 分类模型  
model = RandomForestClassifier(n_estimators=400,max_depth=10,random_state=1) # min_samples_leaf=11
# 模型训练
model.fit(X_train, Y_train)
# 模型预测
y_pred = model.predict(X_test)
print('训练集模型分数:', model.score(X_train,Y_train))
print('测试集模型分数:', model.score(X_test,Y_test))
print("训练集准确率: %.3f" % accuracy_score(Y_train, model.predict(X_train)))
print("测试集准确率: %.3f" % accuracy_score(Y_test, y_pred))

绘制混淆矩阵:

# 混淆矩阵
from sklearn.metrics import confusion_matrix
import matplotlib.ticker as tickercm = confusion_matrix(Y_test, y_pred,labels=[1,2,3,4,5,6,7,8,9,10]) # ,
print('混淆矩阵:\n', cm)
labels=['1','2','3','4','5','6','7','8','9','10']
from sklearn.metrics import ConfusionMatrixDisplay
cm_display = ConfusionMatrixDisplay(cm,display_labels=labels).plot()

3 绘制十分类ROC曲线

第一步,计算每个分类的预测结果概率

from sklearn.metrics import roc_curve,auc
df = pd.DataFrame()
pre_score = model.predict_proba(X_test)
df['y_test'] = Y_test.to_list()df['pre_score1'] = pre_score[:,0]
df['pre_score2'] = pre_score[:,1]
df['pre_score3'] = pre_score[:,2]
df['pre_score4'] = pre_score[:,3]
df['pre_score5'] = pre_score[:,4]
df['pre_score6'] = pre_score[:,5]
df['pre_score7'] = pre_score[:,6]
df['pre_score8'] = pre_score[:,7]
df['pre_score9'] = pre_score[:,8]
df['pre_score10'] = pre_score[:,9]pre1 = df['pre_score1']
pre1 = np.array(pre1)pre2 = df['pre_score2']
pre2 = np.array(pre2)pre3 = df['pre_score3']
pre3 = np.array(pre3)pre4 = df['pre_score4']
pre4 = np.array(pre4)pre5 = df['pre_score5']
pre5 = np.array(pre5)pre6 = df['pre_score6']
pre6 = np.array(pre6)pre7 = df['pre_score7']
pre7 = np.array(pre7)pre8 = df['pre_score8']
pre8 = np.array(pre8)pre9 = df['pre_score9']
pre9 = np.array(pre9)pre10 = df['pre_score10']
pre10 = np.array(pre10)

第二步,画图数据准备

y_list = df['y_test'].to_list()
pre_list=[pre1,pre2,pre3,pre4,pre5,pre6,pre7,pre8,pre9,pre10]lable_names=['1','2','3','4','5','6','7','8','9','10']
colors1 = ["r","b","g",'gold','pink','y','c','m','orange','chocolate']
colors2 = "skyblue"# "mistyrose","skyblue","palegreen"
my_list = []
linestyles =["-", "--", ":","-", "--", ":","-", "--", ":","-"]

第三步,绘制十分类ROC曲线

plt.figure(figsize=(12,5),facecolor='w')
for i in range(10):roc_auc = 0#添加文本信息if i==0:fpr, tpr, threshold = roc_curve(y_list,pre_list[i],pos_label=1)# 计算AUC的值roc_auc = auc(fpr, tpr)plt.text(0.3, 0.01, "class "+lable_names[i]+' :ROC curve (area = %0.2f)' % roc_auc)elif i==1:fpr, tpr, threshold = roc_curve(y_list,pre_list[i],pos_label=2)# 计算AUC的值roc_auc = auc(fpr, tpr)plt.text(0.3, 0.11, "class "+lable_names[i]+' :ROC curve (area = %0.2f)' % roc_auc)elif i==2:fpr, tpr, threshold = roc_curve(y_list,pre_list[i],pos_label=3)# 计算AUC的值roc_auc = auc(fpr, tpr)plt.text(0.3, 0.21, "class "+lable_names[i]+' :ROC curve (area = %0.2f)' % roc_auc)elif i==3:fpr, tpr, threshold = roc_curve(y_list,pre_list[i],pos_label=4)# 计算AUC的值roc_auc = auc(fpr, tpr)plt.text(0.3, 0.31, "class "+lable_names[i]+' :ROC curve (area = %0.2f)' % roc_auc)elif i==4:fpr, tpr, threshold = roc_curve(y_list,pre_list[i],pos_label=5)# 计算AUC的值roc_auc = auc(fpr, tpr)plt.text(0.3, 0.41, "class "+lable_names[i]+' :ROC curve (area = %0.2f)' % roc_auc)elif i==5:fpr, tpr, threshold = roc_curve(y_list,pre_list[i],pos_label=6)# 计算AUC的值roc_auc = auc(fpr, tpr)plt.text(0.6, 0.01, "class "+lable_names[i]+' :ROC curve (area = %0.2f)' % roc_auc)elif i==6:fpr, tpr, threshold = roc_curve(y_list,pre_list[i],pos_label=7)# 计算AUC的值roc_auc = auc(fpr, tpr)plt.text(0.6, 0.11, "class "+lable_names[i]+' :ROC curve (area = %0.2f)' % roc_auc)elif i==7:fpr, tpr, threshold = roc_curve(y_list,pre_list[i],pos_label=8)# 计算AUC的值roc_auc = auc(fpr, tpr)plt.text(0.6, 0.21, "class "+lable_names[i]+' :ROC curve (area = %0.2f)' % roc_auc)elif i==8:fpr, tpr, threshold = roc_curve(y_list,pre_list[i],pos_label=9)# 计算AUC的值roc_auc = auc(fpr, tpr)plt.text(0.6, 0.31, "class "+lable_names[i]+' :ROC curve (area = %0.2f)' % roc_auc)elif i==9:fpr, tpr, threshold = roc_curve(y_list,pre_list[i],pos_label=10)# 计算AUC的值roc_auc = auc(fpr, tpr)plt.text(0.6, 0.41, "class "+lable_names[i]+' :ROC curve (area = %0.2f)' % roc_auc)my_list.append(roc_auc)# 添加ROC曲线的轮廓plt.plot(fpr, tpr, color = colors1[i],linestyle = linestyles[i],linewidth = 3,label = "class:"+lable_names[i])  #  lw = 1,#绘制面积图plt.stackplot(fpr, tpr, colors=colors2, alpha = 0.5,edgecolor = colors1[i]) #  alpha = 0.5,# 添加对角线
plt.plot([0, 1], [0, 1], color = 'black', linestyle = '--',linewidth = 3)
plt.xlabel('1-Specificity')
plt.ylabel('Sensitivity')
plt.grid()
plt.legend()
plt.title("手机上网稳定性ROC曲线和AUC数值")
plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/210986.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【数据结构】——排序篇(上)

前言:前面我们已经学过了许许多多的排序方法,如冒泡排序,选择排序,堆排序等等,那么我们就来将排序的方法总结一下。 我们的排序方法包括以下几种,而快速排序和归并排序我们后面进行详细的讲解。 直接插入…

Qt实现二维码生成和识别

一、简介 QZxing开源库: 生成和识别条码和二维码 下载地址:https://gitcode.com/mirrors/ftylitak/qzxing/tree/master 二、编译与使用 1.下载并解压,解压之后如图所示 2.编译 打开src目录下的QZXing.pro,选择合适的编译器进行编译 最后生…

MIT6S081-Lab2总结

大家好,我叫徐锦桐,个人博客地址为www.xujintong.com,github地址为https://github.com/xjintong。平时记录一下学习计算机过程中获取的知识,还有日常折腾的经验,欢迎大家访问。 Lab2就是了解一下xv6的系统调用流程&…

解决服务端渲染程序SSR运行时报错: ReferenceError: document is not defined

现象: 原因: 该错误表明在服务端渲染 (SSR) 过程中,有一些代码尝试在没有浏览器环境的情况下执行与浏览器相关的操作。这在服务端渲染期间是一个常见的问题,因为在服务端渲染期间是没有浏览器 API。 解决办法: 1. 修…

【2023传智杯-新增场次】第六届传智杯程序设计挑战赛AB组-DEF题复盘解题分析详解【JavaPythonC++解题笔记】

本文仅为【2023传智杯-第二场】第六届传智杯程序设计挑战赛-题目解题分析详解的解题个人笔记,个人解题分析记录。 本文包含:第六届传智杯程序设计挑战赛题目、解题思路分析、解题代码、解题代码详解 文章目录 一.前言二.赛题题目D题题目-E题题目-F题题目-二.赛题题解D题题解-…

深入理解Sentinel系列-1.初识Sentinel

👏作者简介:大家好,我是爱吃芝士的土豆倪,24届校招生Java选手,很高兴认识大家📕系列专栏:Spring源码、JUC源码、Kafka原理、分布式技术原理🔥如果感觉博主的文章还不错的话&#xff…

如何搭建自己的直播电商系统?

当下,传统的图文电商模式已经走向没落,视频电商备受追捧。抖音、快手、小红书、京东、淘宝、拼多多都在发力直播电商业务,尤其是以抖音为首的直播电商备受用户欢迎,它具有实时直播和强互动的特点,是传统电商所不具备的…

最长子串问题(LCS)--动态规划解法

题目描述: 如果Z既是X的子串,又是Y的子串,则称Z为X和Y的公共子串。 如果给定X、Y,求出最长Z及其长度。 注意:这里求的不是子序列,两者的意思并不相同。子串要求连续,子序列并不需要。 如果想…

simulinkveristandlabview联合仿真环境搭建

目录 开篇废话 软件版本 明确需求 软件安装 matlab2020a veristand2020 R4 VS2017 VS2010 软件安装验证 软件资源分享 开篇废话 推免之后接到的第一个让人难绷的活,网上开源的软件资料和成功的案例很少,查来查去就那么几篇,而且版本…

SpringData

1.为什么要学习SpringData? 是因为对数据存储的框架太多了,全部都要学习成本比较高,SpringData对这些数据存储层做了一个统一,学习成本大大降低。

SQL命令---修改字段的数据类型

介绍 使用sql语句修改字段的数据类型。 命令 alter table 表明 modify 字段名 数据类型;例子 有一张a表,表里有一个id字段,长度为11。使用命令将长度修改为12 下面使用命令进行修改: alter table a modify id int(12) NOT NULL;下面使修…

stm32使用多串口不输出无反应的问题(usart1、usart2)

在使用stm32c8t6单片机时,由于需要使用两个串口usart1 、usart2。usart1用作程序烧录、调试作用,串口2用于与其它模块进行通信。 使用串口1时,正常工作,使用串口2时,无反应。查阅了相关资料串口2在PA2\PA3 引脚上。RX…

[仅供学习,禁止用于违法]编写一个程序来手动设置Windows的全局代理开或关,实现对所有网络请求拦截和数据包捕获(抓包或VPN的应用)

文章目录 介绍一、实现原理二、通过注册表设置代理2.1 开启代理2.2 关闭代理2.3 添加代理地址2.4 删除代理设置信息 三、代码实战3.1 程序控制代理操作控制3.1.1 开启全局代理3.1.2 添加代理地址3.1.3 关闭代理开关3.1.4 删除代理信息 3.2 拦截所有请求 介绍 有一天突发奇想&am…

在git使用SSH密钥进行github身份认证学习笔记

1.生成ssh密钥对 官网文档:Https://docs.github.com/zh/authentication(本节内容对应的官方文档,不清晰的地方可参考此内容) 首先,启动我们的git bush(在桌面右键,点击 Git Bush Here &#xf…

iOS_制作 cocopods库

文章目录 1.创建项目2.配置项目3.发布 1.创建项目 在 github 上创建仓库&#xff0c;克隆到本地&#xff1a; git clone https://github.com/mxh-mo/MOOXXX.git在项目目录下执行&#xff1a; pod lib create <库名称>进行一些配置的选择&#xff1a; # 希望在那个平台…

随机分词与tokenizer(BPE->BBPE->Wordpiece->Unigram->sentencepiece->bytepiece)

0 tokenizer综述 根据不同的切分粒度可以把tokenizer分为: 基于词的切分&#xff0c;基于字的切分和基于subword的切分。 基于subword的切分是目前的主流切分方式。subword的切分包括: BPE(/BBPE), WordPiece 和 Unigram三种分词模型。其中WordPiece可以认为是一种特殊的BPE。完…

求Sn=m+mm+mmm+...+mm..mmm(有n个m)的值

题目&#xff1a;求 的值 一、做这个题我们其实可以直接一个for求解&#xff1a; a,aa,aaa...我们很容易知道它们后一项与前一项的关系就是&#xff1b; public static void Sum(int m,int n){long sum 0L;long curAn 0;for (int i 0; i < n; i){curAn m 10* curAn;/…

Qexo博客后台管理部署

Qexo博客后台管理部署 个人主页 个人博客 参考文档 https://www.oplog.cn/qexo/本地部署 采用本地Docker部署管理本地Hexo 下载代码包 若无法下载使用科学工具下载到本地在上传到服务器 wget https://github.com/Qexo/Qexo/archive/refs/tags/3.0.1.zip# 解压 unzip Qexo…

联合体和枚举

联合体&#xff1a; 联合体是什么&#xff1f; 联合体也是一种自定义类型&#xff0c;这种类型定义的变量也包含一系列类型&#xff0c;特征是这些类型公用一块内存空间(所以叫联合体也叫公用体)可以理解为结构体公用一块内存。 //联合-联合体-共用体 //联合也是一种特殊的自…

TOMCAT9安装

1、官网下载 2、解压到任意盘符&#xff0c;注意路径不要有中文 3、环境变量 path 下 配置 %CATALINA_HOME%\bin 4、找到tomcat9/bin&#xff0c; 点击 start.bat启动 tomcat