深度学习二分类评估详细解析与代码实战

深度学习二分类的实战代码:使用 Trainer API 微调模型. https://huggingface.co/learn/nlp-course/zh-CN/chapter3/3

如果你刚接触 自然语言处理,huggingface 是你绕不过去的坎。但是目前它已经被墙了,相信读者的实力,自行解决吧。

设置代理,如果不设置的话,那么huggingface的包无法下载;

import os
os.environ['HTTP_PROXY'] = 'http://127.0.0.1:7890'
os.environ['HTTPS_PROXY'] = 'http://127.0.0.1:7890'

在探讨二分类问题时,经常会遇到四种基本的分类结果,它们根据样例的真实类别与分类器的预测类别来定义。以下是对这些分类结果的详细解释:

这四个定义均由两个字母组成,它们各自代表了不同的含义。

第一个字母(True/False)用于表示算法预测的正确性,而第二个字母(Positive/Negative)则用于表示算法预测的结果。

  • 第1个字母(True/False):描述的是分类器是否预测正确。True表示分类器判断正确,而False则表示分类器判断错误。
  • 第2个字母(Positive/Negative):表示的是分类器的预测结果。Positive代表分类器预测为正例,而Negative则代表分类器预测为负例。
  1. 真正例(True Positive,TP):当样例的真实类别为正例时,如果分类器也预测其为正例,那么我们就称这个样例为真正例。简而言之,真实情况与预测结果均为正例。
  2. 假正例(False Positive,FP):有时,分类器可能会将真实类别为负例的样例错误地预测为正例。这种情况下,我们称该样例为假正例。它代表了分类器的“过度自信”或“误报”现象。
  3. 假负例(False Negative,FN):与假正例相反,假负例指的是真实类别为正例的样例被分类器错误地预测为负例。这种情况下的“遗漏”或“漏报”是分类器性能评估中需要重点关注的问题。
  4. 真负例(True Negative,TN):当样例的真实类别和预测类别均为负例时,我们称其为真负例。这意味着分类器正确地识别了负例。

数据准备

做深度学习的同学应该都默认装了 torch,跳过 torch的安装

!pip install evaluate

导包

import torch
import random
import evaluate

随机生成二分类的预测数据 pred 和 label;

label = torch.tensor([random.choice([0, 1]) for i in range(20)])
pred = torch.tensor([random.choice([0, 1, label[i]]) for i in range(20)])
sum(label == pred)

下述是随机生成的 label 和 pred

# label
tensor([0, 1, 1, 0, 0, 0, 1, 0, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 0])# pred
tensor([0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0])

使用 random.choice([0, 1, label[i]] 是为了提高 pred 的 准确率; 因为 label[i] 是真实的 label;

下述的是计算TP、TN、FP、FN的值:

Tips:

pred : 与第2个字母(Positive/Negative)保持一致,

label: 根据第一个字母是否预测正确,再判断填什么

TP = sum((label == 1) & (pred == 1))
TN = sum((label == 0) & (pred == 0))
FP = sum((label == 0) & (pred == 1))
FN = sum((label == 1) & (pred == 0))
标签Value
TP6
TN8
FP2
FN4

准确率 Accuracy

准确率(Accuracy): 分母通常指的是所有样本的数量,即包括真正例(True Positives, TP)、假正例(False Positives, FP)、假负例(False Negatives, FN)和真负例(True Negatives, TN)的总和。而分子中的第一个字母为“T”(True),意味着我们计算的是算法预测正确的样本数量,即TP和TN的总和。

然而,准确率作为一个评价指标存在一个显著的缺陷,那就是它对数据样本的均衡性非常敏感。当数据集中的正负样本数量存在严重不均衡时,准确率往往不能准确地反映模型的性能优劣。

例如,假设有一个测试集,其中包含90%正样本和仅10%负样本。若模型将所有样本都预测为正样本,那么它的准确率将轻松达到90%。从准确率这一指标来看,模型似乎表现得非常好。但实际上,这个模型对于负样本的预测能力几乎为零。

因此,在处理样本不均衡的问题时,需要采用其他更合适的评价指标,如精确度(Precision)、召回率(Recall)、F1分数(F1 Score)等,来更全面地评估模型的性能。这些指标能够更准确地反映模型在各类样本上的预测能力,从而帮助我们做出更准确的决策。

精准率的公式如下:
A c c u r a c y = T P + T N T P + T N + F P + F N = T P + T N 所有样本数 Accuracy = \frac{TP + TN}{TP + TN + FP +FN} = \frac{TP + TN}{所有样本数} Accuracy=TP+TN+FP+FNTP+TN=所有样本数TP+TN

accuracy = evaluate.load("accuracy")
accuracy.compute(predictions=pred, references=label)

Output:

{'accuracy': 0.7}

下述三种方法都可以用来计算 accuracy:

print((TP + TN) / (TP + TN + FP +FN),(TP + TN) / len(label),sum((label == pred)) / 20
)

Output:

tensor(0.7000) tensor(0.7000) tensor(0.7000)

使用公式计算出来的与通过evaluate库,算出来的结果一致,都是 0.7。

precision 精准率

P r e c i s i o n = T P T P + F P Precision = \frac{TP}{TP + FP} Precision=TP+FPTP

precision = evaluate.load("precision")
precision.compute(predictions=pred, references=label)

Output:

{'precision': 0.75}
TP / (TP + FP)

recall 召回率

R e c a l l = T P T P + F N Recall = \frac{TP}{TP + FN} Recall=TP+FNTP

recall = evaluate.load("recall")
recall.compute(predictions=pred, references=label)

Output:

{'recall': 0.6}
TP / (TP + FN)

F1

f1 = evaluate.load("f1")
f1.compute(predictions=pred, references=label)

Output:

{'f1': 0.6666666666666666}

F 1 = 2 × P r e c i s i o n × R e c a l l P r e c i s i o n + R e c a l l F1 = \frac{2 \times {Precision} \times {Recall}}{{Precision} + {Recall}} F1=Precision+Recall2×Precision×Recall

2 * 0.7500 * 0.6000 / (0.7500 + 0.6000)

Output:

0.6666666666666665

希望这篇文章,通过代码实战,能够帮到你加深印象与理解!概念说再多遍,不如代码实现一遍。

参考资料

  • 如何在python代码中使用代理下载Hungging face模型. https://www.jianshu.com/p/209528bed023
  • [机器学习] 二分类模型评估指标—精确率Precision、召回率Recall、ROC|AUC. https://blog.csdn.net/zwqjoy/article/details/78793162
  • 使用 Trainer API 微调模型. https://huggingface.co/learn/nlp-course/zh-CN/chapter3/3
  • Huggingface Evaluate 文档. https://huggingface.co/docs/evaluate/index

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/36047.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

渗透第二次作业

cs与msf权限传递,以及mimikatz抓取win2012明文密码 1、准备三台虚拟机: 一台安装有cs的kali,网络模式为nat, 一台Win2012,有两张网卡,一张为NAT模式,一张为仅主机模式,分别对应内外网, 一台…

QT的TCP服务端与多客户端通信

目的 TCP通信可以说是最基础的东西了,也是面试经常问的问题,记得10年前,面试浪潮时,就是问的TCP连接的过程。 时间长了不用,感觉一些东西模糊了,基础的东西还是需要清晰的,而且,现在是QT的TCP,用法也有一些自己的特点。 这里主要说的就是服务端与多客户端的通信,这也…

G8 - ACGAN

🍨 本文为🔗365天深度学习训练营 中的学习记录博客🍖 原作者:K同学啊 目录 模型结构 模型结构 之前几期打卡中,已经介绍过GAN CGAN SGAN,而ACGAN属于上述几种GAN的缝合怪,其模型的结构图如下&a…

Python 中的抽象语法树

Abstract Syntax Trees in Python 注:机翻,未校对。 Requirement: All examples are compatible with at least Python v3.6, except for using ast.dump() with the attribute indent which has been added in Python v3.9. 要求:所有示例至…

如何检测和处理Android应用程序中的内存泄漏问题。

在Android开发中,内存泄漏是一个不容忽视的问题。它不仅会影响应用程序的性能,还可能导致应用崩溃,给用户带来不良体验。因此,作为开发者,我们必须了解如何检测和处理内存泄漏。下面,我将从技术难点、面试官…

结题阶段(2024年6月)

课题研究大事记 序号 时间 内容安排 负责人 备注 1 2022.4 课题审定会议 全体成员 2 2022.4 开题报告撰写 郭书艳 3 2022.4 课题申报 郭书艳 4 2022.5 课题立项报告会 郭书艳、陈晓忠 5 2022.6 课题推进安排会 俞峰 6 2022.7 当下课堂模式…

第二十课,认识列表与定义列表

一,列表的作用 思考一个问题:如果我想要在程序中,记录5名学生的信息,如姓名。 如何做呢? 这就是列表的作用,能帮助我们更加高效的存储各种数据 思考:如果一个班级100位学生,每个人…

利用SHAP算法解释BERT模型的输出

1 何为SHAP? 传统的 feature importance 只告诉哪个特征重要,但并不清楚该特征如何影响预测结果。SHAP 算法的最大优势是能反应每一个样本中特征的影响力,且可表现出影响的正负性。SHAP算法的主要思想为:控制变量法,如果某个特征…

VMware完美安装Ubuntu20.04

一、官网下载Ubuntu20.04 下载地址为:https://releases.ubuntu.com/https://releases.ubuntu.com/ 下载完后镜像为ubuntu-20.04.4-desktop-amd64.iso 二、Ubuntu安装 2.1、打开VMware player,并创建新虚拟机。 2.2、点击浏览按钮选择需要安装的镜像 2…

Linux系统上部署Whisper。

Whisper是一个开源的自动语音识别(ASR)模型,最初由OpenAI发布。要在本地Linux系统上部署Whisper,你可以按照以下步骤进行: 1. 创建虚拟环境 为了避免依赖冲突,建议在虚拟环境中进行部署。创建并激活一个新…

问题 N: 二叉树的创建和文本显示

问题 N: 二叉树的创建和文本显示 题目描述 编一个程序,读入先序遍历字符串,根据此字符串建立一棵二叉树(以指针方式存储)。 例如如下的先序遍历字符串: A ST C # # D 10 # G # # F # # # 各结点数据(长度不…

数据结构实训:表达式求值器(非常详细)

表达式求值器 问题描述: 设计一个表达式求值器,能够解析和计算由数字、运算符和括号组成的算术表达式。要求实现基本的四则运算,如加、减、乘、除,并处理运算符的优先级和括号。 设计要点: 1. 使用栈作为数据结构来处…

ElementUI组件

目录 1、安装ElementUI 2、在main.js文件中加入 3、使用组件 终端运行: Element,一套为开发者、设计师和产品经理准备的基于Vue2.0的桌面端组件库. 1、安装ElementUI 控制台输入 npm i element-ui -S 2、在main.js文件中加入 import ElementUI from…

老司机开发技巧,如何扩展三方包功能

前言 最近碰上有个业务,查询的sql如下: sql 复制代码 select * from table where (sku_id,batch_no) in ((#{skuId},#{batchNo}),...); 本来也没什么,很简单常见的一种sql。 问题是我们使用的是mybatis-plus,然后写的时候有没…

【智能制造-5】数采和电机

既然可以采集PLC的数据,为什么要采集电机的数据? 采集PLC(可编程逻辑控制器)的数据和采集电机的数据是两个不同的概念和目的。 PLC是用于控制和监控工业自动化过程的设备,它可以接收传感器的输入信号并根据预设的逻辑…

多线程软件不响应处理

多线程的问题,基本上由于写法不规范造成的问题,从而影响软件正常运行,或时不时出现软件不响应,但是其它CPU,内存保存不变的情况. 出现这样的情况,多半是软件运行时死锁或多个线程相互等待,从而引起的软件未响应的情况发生. 解决办法: 1.while,do while循环增加延时时间Sleep…

重庆交通大学24计算机考研数据速览,专硕第二年招生,复试线321分!

重庆交通大学(Chongqing Jiaotong University,CQJTU),是由重庆市人民政府和中华人民共和国交通运输部共建的一所交通特色、以工为主的多科性大学,入选“中西部高校基础能力建设工程”、“卓越工程师教育培养计划”、国…

企业级堡垒机JumpServer

文章目录 JumpServer是什么生产应用场景 Docker安装JumpServer1.Docker安装2.MySQL服务安装3.Redis服务安装4.key生成5.JumpServer安装6.登录验证 系统设置邮箱服务器用户和用户组创建系统审计员资产管理用户创建资产节点资产授权查看用户的资产监控仪表盘 命令过滤器创建命令过…

Model3C芯片方案--86彩屏中控面板Modbus协议说明

一、概述 Model3C芯片是一款基于RISC-V的高性能、国产自主、工业级高清显示与智能控制MCU,配备强大的2D图形加速处理器、PNG/JPEG解码引擎,并支持工业宽温。基于Model3C芯片的86彩屏中控面板,通过集成Modbus协议,实现了与多种控制…

前端存储都有哪些

cookie 、sessionStorage、localStorange、http缓存 、indexDB cookie 由服务器设置,在客户端存储,然后每次发起同源请求时,发送给服务器端。cookie最多能存储4K数据,它的生存时间由expires属性指定,并且cookie只能被…