【机器学习】--模型评估指标之混淆矩阵,ROC曲线和AUC面积

一、前述

怎么样对训练出来的模型进行评估是有一定指标的,本文就相关指标做一个总结。

二、具体

1、混淆矩阵

混淆矩阵如图:

 第一个参数true,false是指预测的正确性。

 第二个参数true,postitives是指预测的结果。

 相关公式:

 

检测正列的效果:

检测负列的效果:

 

 

 公式解释:

 fp_rate:

tp_rate:

recall:(召回率)

值越大越好

presssion:(准确率)

TP:本来是正例,通过模型预测出来是正列

TP+FP:通过模型预测出来的所有正列数(其中包括本来是负例,但预测出来是正列)

 值越大越好

F1_Score:

 

准确率和召回率是负相关的。如图所示:

通俗解释:

实际上非常简单,精确率是针对我们预测结果而言的,它表示的是预测为正的样本中有多少是真正的正样本。那么预测为正就有两种可能了,一种就是把正类预测为正类(TP),另一种就是把负类预测为正类(FP),也就是

 

召回率是针对我们原来的样本而言的,它表示的是样本中的正例有多少被预测正确了。那也有两种可能,一种是把原来的正类预测成正类(TP),另一种就是把原来的正类预测为负类(FN)。

 

其实就是分母不同,一个分母是预测为正的样本数,另一个是原来样本中所有的正样本数。

2、ROC曲线

过程:对第一个样例,预测对,阈值是0.9,所以曲线向上走,以此类推。

          对第三个样例,预测错,阈值是0.7 ,所以曲线向右走,以此类推。

几种情况:

所以得出结论,曲线在对角线以上,则准确率好。

3、AUC面积

 

M是样本中正例数

N是样本中负例数

其中累加解释是把预测出来的所有概率结果按照分值升序排序,然后取正例所对应的索引号进行累加

通过AUC面积预测出来的可以知道好到底有多好,坏到底有多坏。因为正例的索引比较大,则AUC面积越大。

总结:

 4、交叉验证

为在实际的训练中,训练的结果对于训练集的拟合程度通常还是挺好的(初试条件敏感),但是对于训练集之外的数据的拟合程度通常就不那么令人满意了。因此我们通常并不会把所有的数据集都拿来训练,而是分出一部分来(这一部分不参加训练)对训练集生成的参数进行测试,相对客观的判断这些参数对训练集之外的数据的符合程度。这种思想就称为交叉验证。

 一般3折或者5折交叉验证就足够了。

 三、代码

 

#!/usr/bin/python
# -*- coding: UTF-8 -*-
# 文件名: mnist_k_cross_validate.pyfrom sklearn.datasets import fetch_mldata
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import StratifiedKFold
from sklearn.base import clone
from sklearn.model_selection import cross_val_score
from sklearn.base import BaseEstimator #评估指标
from sklearn.model_selection import cross_val_predict
from sklearn.metrics import confusion_matrix
from sklearn.metrics import precision_score
from sklearn.metrics import recall_score
from sklearn.metrics import f1_score
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import roc_curve
from sklearn.metrics import roc_auc_score
from sklearn.ensemble import RandomForestClassifier# Alternative method to load MNIST, if mldata.org is down
from scipy.io import loadmat #利用Matlib加载本地数据
mnist_raw = loadmat("mnist-original.mat")
mnist = {"data": mnist_raw["data"].T,"target": mnist_raw["label"][0],"COL_NAMES": ["label", "data"],"DESCR": "mldata.org dataset: mnist_k_cross_validate-original",
}
print("Success!")
# mnist_k_cross_validate = fetch_mldata('MNIST_original', data_home='test_data_home')
print(mnist)X, y = mnist['data'], mnist['target'] # X 是70000行 784个特征 y是70000行 784个像素点
print(X.shape, y.shape)
#
some_digit = X[36000]
print(some_digit)
some_digit_image = some_digit.reshape(28, 28)#调整矩阵 28*28=784 784个像素点调整成28*28的矩阵 图片是一个28*28像素的图片 每一个像素点是一个rgb的值
print(some_digit_image)
#
plt.imshow(some_digit_image, cmap=matplotlib.cm.binary,interpolation='nearest')
plt.axis('off')
plt.show()
#
X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[:60000]#6/7作为训练,1/7作为测试
shuffle_index = np.random.permutation(60000)#返回一组随机的数据 shuffle 打乱60000中每行的值 即每个编号的值不是原先的对应的值
X_train, y_train = X_train[shuffle_index], y_train[shuffle_index] # Shuffle之后的取值
# #
y_train_5 = (y_train == 5)# 是5就标记为True,不是5就标记为false
y_test_5 = (y_test == 5)
print(y_test_5)
#这里可以直接写成LogGression
sgd_clf = SGDClassifier(loss='log', random_state=42)# log 代表逻辑回归 random_state或者random_seed 随机种子 写死以后生成的随机数就是一样的
sgd_clf.fit(X_train, y_train_5)#构建模型
print(sgd_clf.predict([some_digit]))# 测试模型 最终为5
# #
### K折交叉验证
##总共会运行3次
skfolds = StratifiedKFold(n_splits=3, random_state=42)# 交叉验证 3折 跑三次 在训练集中的开始1/3 中测试,中间1/3 ,最后1/3做验证
for train_index, test_index in skfolds.split(X_train, y_train_5):#可以把sgd_clf = SGDClassifier(loss='log', random_state=42)这一行放入进来,传不同的超参数 这里就不用克隆了clone_clf = clone(sgd_clf)# clone一个上一个一样的模型 让它不变了 每次初始随机参数w0,w1,w2都一样,所以设定随机种子是一样X_train_folds = X_train[train_index]#对应的是训练集中训练的X 没有阴影的y_train_folds = y_train_5[train_index]# 对应的是训练集中的训练y 没有阴影的X_test_folds = X_train[test_index]#对应的是训练集中的测试的X 阴影部分的y_test_folds = y_train_5[test_index]#对应的是训练集中的测试的Y 阴影部分的
clone_clf.fit(X_train_folds, y_train_folds)#构建模型y_pred = clone_clf.predict(X_test_folds)#验证print(y_pred)n_correct = sum(y_pred == y_test_folds)# 如若预测对了加和 因为true=1 false=0print(n_correct / len(y_pred))#得到预测对的精度 #用判断正确的数/总共预测的 得到一个精度
# #PS:这里可以把上面的模型生成直接放在交叉验证里面传一些超参数比如阿尔法,看最后的准确率则知道什么超参数最好。#这是Sk_learn里面的实现的函数cv是几折,score评估什么指标这里是准确率,结果类似上面一大推代码
print(cross_val_score(sgd_clf, X_train, y_train_5, cv=3, scoring='accuracy')) #这是Sk_learn里面的实现的函数cv是几折,score评估什么指标这里是准确率class Never5Classifier(BaseEstimator):#给定一个分类器,永远不会分成5这个类别 因为正负列样本不均匀,所以得出的结果是90%,所以只拿精度是不准确的。def fit(self, X, y=None):passdef predict(self, X):return np.zeros((len(X), 1), dtype=bool)never_5_clf = Never5Classifier()
print(cross_val_score(never_5_clf, X_train, y_train_5, cv=3, scoring='accuracy'))#给每一个结果一个结果
# #
# #
##混淆矩阵 可以准确地知道哪一个类别判断的不准
y_train_pred = cross_val_predict(sgd_clf, X_train, y_train_5, cv=3)#给每一个结果预测一个概率
print(confusion_matrix(y_train_5, y_train_pred))
# #
y_train_perfect_prediction = y_train_5
print(confusion_matrix(y_train_5, y_train_5))
#准确率,召回率,F1Score
print(precision_score(y_train_5, y_train_pred))
print(recall_score(y_train_5, y_train_pred))
print(sum(y_train_pred))
print(f1_score(y_train_5, y_train_pred))sgd_clf.fit(X_train, y_train_5)
y_scores = sgd_clf.decision_function([some_digit])
print(y_scores)threshold = 0 # Z的大小 wT*x的结果
y_some_digit_pred = (y_scores > threshold)
print(y_some_digit_pred)threshold = 200000
y_some_digit_pred = (y_scores > threshold)
print(y_some_digit_pred)y_scores = cross_val_predict(sgd_clf, X_train, y_train_5, cv=3, method='decision_function')
print(y_scores)#直接得出Score

precisions, recalls, thresholds = precision_recall_curve(y_train_5, y_scores)
print(precisions, recalls, thresholds)def plot_precision_recall_vs_threshold(precisions, recalls, thresholds):plt.plot(thresholds, precisions[:-1], 'b--', label='Precision')plt.plot(thresholds, recalls[:-1], 'r--', label='Recall')plt.xlabel("Threshold")plt.legend(loc='upper left')plt.ylim([0, 1])# plot_precision_recall_vs_threshold(precisions, recalls, thresholds)
# plt.savefig('./temp_precision_recall')

y_train_pred_90 = (y_scores > 70000)
print(precision_score(y_train_5, y_train_pred_90))
print(recall_score(y_train_5, y_train_pred_90))fpr, tpr, thresholds = roc_curve(y_train_5, y_scores)def plot_roc_curve(fpr, tpr, label=None):plt.plot(fpr, tpr, linewidth=2, label=label)plt.plot([0, 1], [0, 1], 'k--')plt.axis([0, 1, 0, 1])plt.xlabel('False Positive Rate')plt.ylabel('True positive Rate')plot_roc_curve(fpr, tpr)
plt.show()
# plt.savefig('img_roc_sgd')print(roc_auc_score(y_train_5, y_scores))forest_clf = RandomForestClassifier(random_state=42)
y_probas_forest = cross_val_predict(forest_clf, X_train, y_train_5, cv=3, method='predict_proba')
y_scores_forest = y_probas_forest[:, 1]fpr_forest, tpr_forest, thresholds_forest = roc_curve(y_train_5, y_scores_forest)
plt.plot(fpr, tpr, 'b:', label='SGD')
plt.plot(fpr_forest, tpr_forest, label='Random Forest')
plt.legend(loc='lower right')
plt.show()
# plt.savefig('./img_roc_forest')print(roc_auc_score(y_train_5, y_scores_forest))#
#

 

 

 

 

acc 看中整体
auc看中正例

 

转载于:https://www.cnblogs.com/LHWorldBlog/p/8656439.html

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/414390.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

05-Flutter移动电商实战-dio基础_引入和简单的Get请求

05-Flutter移动电商实战-dio基础_引入和简单的Get请求 这篇开始我们学习Dart第三方Http请求库dio,这是国人开源的一个项目,也是国内用的最广泛的Dart Http请求库。 1、dio介绍和引入 dio是一个强大的Dart Http请求库,支持Restful API、 Fo…

Jmeter——for循环控制器和if逻辑控制器

有时我们不仅仅需要用例按照简单的顺序跑,需要内嵌循环,或者条件分支,让某些用例在满足一定条件时才执行。 1、for循环控制器 此处记录两种应用的场景,一种是直接定义好要循环的变量,循环次数是固定的,写死…

06-Flutter移动电商实战-dio基础_Get_Post请求和动态组件协作

06-Flutter移动电商实战-dio基础_Get_Post请求和动态组件协作 上篇文章中,我们只看到了 dio 的使用方式,但并未跟应用关联起来,所以这一篇将 dio 网络请求与应用界面结合起来,当然这也是为以后的实战作基础准备,基础打…

Cannot merge new index 66395 into a non-jumbo instruction!,uses or overrides a deprecated API.

老项目运行没问题。一打包就报错 解决方法——添加dexOptions android {compileSdkVersion 27dexOptions{jumboMode true}

08-Flutter移动电商实战-dio基础_伪造请求头获取数据

08-Flutter移动电商实战-dio基础_伪造请求头获取数据 在很多时候,后端为了安全都会有一些请求头的限制,只有请求头对了,才能正确返回数据。这虽然限制了一些人恶意请求数据,但是对于我们聪明的程序员来说,就是形同虚设…

工作269:uni--客流分析优化

<template><view><view><view id"main"><!-- 第一步 设置盒子大小 --></view></view><view><view v-for"(item,index) in 10"><view style"display: flex;justify-content: space-between;…

09-Flutter移动电商实战-移动商城数据请求实战

09-Flutter移动电商实战-移动商城数据请求实战 1、URL接口管理文件建立 第一步需要在建立一个URL的管理文件&#xff0c;因为课程的接口会一直进行变化&#xff0c;所以单独拿出来会非常方便变化接口。当然工作中的URL管理也是需要这样配置的&#xff0c;以为我们会不断的切换…

Android 串口开发,发送串口命令,读卡,反扫码,USB通讯,实现demo。——持续更新

应用场景&#xff1a;APP发送串口命令到打印机&#xff0c;打印相应数据小票 // 串口 implementation com.github.licheedev.Android-SerialPort-API:serialport:1.0.1 1、获取全部串口地址devicePath private String[] mDevices; public void getcuankou(){SerialPortFinder…

10-Flutter移动电商实战-使用FlutterSwiper制作轮播效果

10-Flutter移动电商实战-使用FlutterSwiper制作轮播效果 1、引入flutter_swiper插件 flutter最强大的siwiper, 多种布局方式&#xff0c;无限轮播&#xff0c;Android和IOS双端适配. 好牛X得介绍&#xff0c;一般敢用“最”的一般都是神级大神&#xff0c;看到这个介绍后我也是…

工作271:打开弹出框调用当前页面接口

<template><div><button-dialogopen"open"ref"dialog"width"1000px"title"内容关联"ok-button-text"确认关联":destroy-on-close"true"ok"confirmAssociation"><!----><cus…

11-Flutter移动电商实战-首页_屏幕适配方案和制作

11-Flutter移动电商实战-首页_屏幕适配方案和制作 1、flutter_ScreenUtil插件简介 flutter_ScreenUtil屏幕适配方案&#xff0c;让你的UI在不同尺寸的屏幕上都能显示合理的布局。 插件会让你先设置一个UI稿的尺寸&#xff0c;他会根据这个尺寸&#xff0c;根据不同屏幕进行缩…

12-Flutter移动电商实战-首页导航区域编写

12-Flutter移动电商实战-首页导航区域编写 1、导航单元素的编写 从外部看&#xff0c;导航是一个GridView部件&#xff0c;但是每一个导航又是一个上下关系的Column。小伙伴们都知道Flutter有多层嵌套的问题&#xff0c;如果我们都写在一个组件里&#xff0c;那势必造成嵌套严…

Day3_操作记录

python基础&#xff1a; 回顾 1. 条件判断 if &#xff1a; x else&#xff1a; xx 循环 while for for i in range(5): 2. 数据类型&#xff1a; int 类型 float 小数类…

Android 全局悬浮按钮,悬浮按钮点击事件

实现效果&#xff1a; 实现方法&#xff1a; 在自定义baseActivity里面添加viwe即可。在子activity里刷新悬浮View即可 public abstract class BaseActivity extends BaseCommonActivity {LinearLayout saoli,ewmli;ImageView imageView;private QrCodeDialog mMQrCodeDialog;p…

13-Flutter移动电商实战-ADBanner组件的编写

1、AdBanner组件的编写 我们还是把这部分单独出来&#xff0c;需要说明的是&#xff0c;这个Class你也是可以完全独立成一个dart文件的。代码如下&#xff1a; 广告图片class AdBanner extends StatelessWidget { final String advertesPicture; AdBanner({Key key, this.adv…

android远程调试工具,android投屏工具

远程调试工具 说明&#xff1a;远程对方电脑需下载安装两个软件&#xff1a;Android studio 和向日葵 android投屏工具 说明&#xff1a;涉及硬件开发时&#xff0c;有的硬件没有安卓屏&#xff0c;可以使用 两个工具exe文件下载连接&#xff1a; https://download.csdn.net/…