使用 BERT 进行文本分类 (02/3)

一、说明

        在使用BERT(1)进行文本分类中,我向您展示了一个BERT如何标记文本的示例。在下面的文章中,让我们更深入地研究是否可以使用 BERT 来预测文本是使用 PyTorch 传达积极还是消极的情绪。首先,我们需要准备数据,以便使用 PyTorch 框架进行分析。

二、什么是 PyTorch

        PyTorch 是用于构建深度学习模型的框架,深度学习模型是一种机器学习,通常用于图像识别和语言处理等应用程序。它由Facebook的人工智能研究小组于2016年开发,由于其灵活性,易用性和动态计算图构建而广受欢迎。

        PyTorch 提供了一个基于 Python 的科学计算包,它使用图形处理单元 (GPU) 的强大功能来加速张量运算的计算。它具有简单直观的API,允许开发人员快速构建和训练深度学习模型。PyTorch 还支持自动微分,使用户能够计算任意函数的梯度。

三、准备我们的数据集

        首先,让我们从Github下载我们的数据。这里有一个关于如何从Github下载CSV文件的小提醒。只需继续并单击以下链接:

github.com

        然后,右键单击“原始”,然后左键单击“将链接文件下载为...”。您将看到“垃圾邮件.csv”并下载它。下载后,将其保存到您的首选文件夹中以供以后使用。

        现在,让我们导入数据。我们看到一条错误消息,告诉我们部分数据未采用 UTF-8 编码。

import pandas as pd
df = pd.read_csv("spam.csv")ERROR: 
UnicodeDecodeError: 'utf-8' codec can't decode bytes in position 606-607: invalid continuation byte

我们可以通过了解数据包含的字符编码并在读取数据时调用该编码来修复此错误。

# Use chardet to know the character encoding 
import chardet
with open("spam.csv", 'rb') as rawdata:result = chardet.detect(rawdata.read(100000))
resultOutput: 
{'encoding': 'Windows-1252', 'confidence': 0.7270322499829184, 'language': ''}

似乎我们的数据是在“Windows-1252”中编码的。那让我们再读一遍。它奏效了!

df = pd.read_csv("spam.csv", encoding = 'Windows-1252')
df.head()

        如我们所见,我们实际上并不需要“v1”和“v2”以外的列。此外,如果我们将“v1”和“v2”重命名为“类别”和“消息”,则更容易理解。

df = df.loc[:, ['v1', 'v2']]
df = df.rename(columns={'v1': 'Category', 'v2': 'Message'})
df.head()

        现在,我们应该看看我们的数据集,看看每个类别中有多少条消息。

df['Category'].value_counts()Output: 
ham     4825
spam     747
Name: Category, dtype: int64

四、创建平衡数据集

        事实证明,正常邮件比垃圾邮件多。构建机器学习模型时,如果数据集不平衡,其中一个类中的数据数量明显多于另一个类,则可能会对模型的性能产生各种影响。一些潜在的后果。例如:

-1 有偏差模型:如果数据集不平衡,模型可能会偏向多数类,而对少数类表现不佳。这是因为模型更有可能预测多数类,这将导致少数类的准确性较差。

-2 泛化不良:不平衡的数据集可能导致模型泛化不良。这是因为该模型将在不代表数据真实世界分布的数据集上进行训练,因此它可能无法很好地概括看不见的数据。

-3 评估不准确:如果使用准确性作为指标评估模型,则可能会产生误导性结果。例如,始终预测不平衡数据集中多数类的模型可能具有很高的准确性,但对少数类没有用。

-4 过拟合:由于数据点数量较多,模型可能会过度拟合多数类,从而导致测试数据的性能不佳。

为了解决这些问题,可以使用各种技术来平衡数据集,例如对少数类进行过采样,对多数类进行欠采样,或同时使用两者的组合。在这篇文章中,我将使用欠采样方法。

df_spam = df[df['Category']=='spam']
df_ham = df[df['Category']=='ham']
df_ham_downsampled = df_ham.sample(df_spam.shape[0])
df_balanced = pd.concat([df_ham_downsampled, df_spam])
df_balanced['Category'].value_counts()Output: 
ham     747
spam    747
Name: Category, dtype: int64

五、标记数据

        当数据表示为数字而不是分类为用于训练和测试的模型时,机器学习算法在准确性和其他性能指标方面表现更好。我们需要用数值对分类值进行标签编码。在这里,我们创建了一个新列“标签”,如果邮件是垃圾邮件,我们将其标记为 1,否则为 0。

df_balanced['Label']=df_balanced['Category'].apply(lambda x: 1 if x=='spam' else 0)
df_balanced = df_balanced.reset_index(drop=True)display(df_balanced)

由作者创建

六、训练、验证和测试数据集:谁是谁

        要记住的一件事是,当我们使用 train_test_split 库来训练模型时,我们实际上是将数据集拆分为 TRAINING 数据集和 VALIDATION 数据集,而不是 TRAINING 数据集和 TESTING 数据集。下面提醒一下这些数据集的含义。

  1. 训练集:用于构建我们的模型。我们将使用训练集来找到具有反向传播规则的“最佳”权重和偏差。在此阶段,我们通常会创建多个算法,以便在交叉验证阶段比较它们的性能。
  2. 交叉验证集:此数据集用于比较基于训练集创建的预测算法的性能。我们选择性能最佳的算法。
  3. 测试集:这是“未来”数据集。现在我们已经选择了我们喜欢的预测算法,但我们还不知道它将如何在完全看不见的真实世界数据上执行。因此,我们将我们选择的预测算法应用于我们的测试集,以查看它将如何执行,以便我们可以了解我们的算法在野外的性能。

        因此,在测试集中,我们没有数据的标签,而是使用我们的模型来预测标签。我们只能将手头的数据集拆分为训练集和验证集,因为我们还没有“未来”数据。

七、拆分为训练数据集和验证数据集

        现在我们了解了这三种类型的数据的真正含义,我们可以使用scikit-learn的train_test_split来拆分数据。

from sklearn.model_selection import train_test_split
X_train, X_val, y_train, y_val = train_test_split(df_balanced['Message'],df_balanced['Label'], stratify=df_balanced['Label'], test_size=.2)X_train.head()Output: 
708                      ;-) ok. I feel like john lennon.
1386    Cashbin.co.uk (Get lots of cash this weekend!)...
1492    REMINDER FROM O2: To get 2.50 pounds free call...
119     Back in brum! Thanks for putting us up and kee...
89                       Sorry, I can't help you on this.
Name: Message, dtype: object

八、总结

        我们已经学会了如何下载和拆分数据。在下一篇文章中,我们将首先对其进行标记,并使用DistilBERT训练分类器。达门·

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/39627.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

3.1 Qt样式选择器

本期内容 3.1 样式选择器 3.1.1 Universal Selector (通用选择器) 3.1.2 Type Selector (类型选择器) 3.1.3 Property Selector (属性选择器) 3.1.4 Class Selector (类选择器) 3.1.5 ID Selector (ID选择器) 3.1.6 Descendant Selector (后裔选择器) 3.1.7 Chil…

前端跨域的原因以及解决方案(vue),一文让你真正理解跨域

跨域这个问题,可以说是前端的必需了解的,但是多少人是知其然不知所以然呢? 下面我们来梳理一下vue解决跨域的思路。 什么情况会跨域? ​ 跨域的本质就是浏览器基于同源策略的一种安全手段。所谓同源就是必须有以下三个相同点:协议相同、域名…

WinCC V7.5 中的C脚本对话框不可见,将编辑窗口移动到可见区域的具体方法

WinCC V7.5 中的C脚本对话框不可见,将编辑窗口移动到可见区域的具体方法 由于 Windows 系统更新或使用不同的显示器,在配置C动作时,有可能会出现C脚本编辑窗口被移动到不可见区域的现象。 由于该窗口无法被关闭,故无法进行进一步…

KafkaStream:Springboot中集成

1、在kafka-demo中创建配置类 配置kafka参数 package com.heima.kafkademo.config;import lombok.Data; import org.apache.kafka.common.serialization.Serdes; import org.apache.kafka.streams.StreamsConfig; import org.springframework.boot.context.properties.Configu…

8月11日上课内容 nginx的多实例和动静分离

多实例部署 在一台服务器上有多个tomcat的服务。 配置多实例之前,看单个实例是否访问正常。 1.安装好 jdk 2.安装 tomcat cd /opt tar zxvf apache-tomcat-9.0.16.tar.gz mkdir /usr/local/tomcat mv apache-tomcat-9.0.16 /usr/local/tomcat/tomcat1 cp -a /u…

Linux系统管理:虚拟机ESXi安装

目录 一、理论 1.VMware Workstation 2.VMware vSphere Client 3.ESXi 二、实验 1.ESXi 7安装 一、理论 1.VMware Workstation 它是一款专业的虚拟机软件,可以在一台物理机上运行多个操作系统,支持Windows、Linux等操作系统,可以模拟…

使用selenium如何实现自动登录

回顾使用requests如何实现自动登录一文中,提到好多网站在我们登录过后,在之后的某段时间内访问该网页时,不会给出请登录的提示,时间到期后就会提示请登录!这样在使用爬虫访问网页时还要登录,打乱我们的节奏…

item_get_sales-获取商品销量详情

一、接口参数说明: item_get_sales-获取商品销量详情,点击更多API调试,请移步注册API账号点击获取测试key和secret 公共参数 请求地址: https://api-gw.onebound.cn/taobao/item_get_sales 名称类型必须描述keyString是调用key&#xff08…

Spring的三种异常处理方式

1.SpringMVC 异常的处理流程 异常分为编译时异常和运行时异常,编译时异常我们 try-cache 进行捕获,捕获后自行处理,而运行时异常是不 可预期的,就需要规范编码来避免,在SpringMVC 中,不管是编译异常还是运行…

java:JDBC

文章目录 什么是JDBCJDBC使用步骤详解各个对象DriverManagerConnectionStatementResultSetPreparedStatement JDBC控制事务操作步骤示例 什么是JDBC 我们知道,数据库有很多种,比如 mysql,Oracle,DB2等等,如果每一种数…

C# WPF 中 外部图标引入iconfont,无法正常显示问题 【小白记录】

wpf iconfont 外部图标引入&#xff0c;无法正常显示问题。 1. 检查资源路径和引入格式是否正确2. 检查资源是否包含在程序集中 1. 检查资源路径和引入格式是否正确 正确的格式&#xff0c;注意字体文件 “xxxx.ttf” 应写为 “#xxxx” <TextBlock Text"&#xe7ae;…

经典人体模型SMPL介绍(一)

SMPL是马普所提出的经典人体模型&#xff0c;目前已成为姿态估计、人体重建等领域必不可少的基础先验。SMPL基于蒙皮和BlendShape实现&#xff0c;从数千个三维人体扫描结果得来&#xff0c;后通过PCA统计学习得来。 论文&#xff1a;SMPL: A Skinned Multi-Person Linear Mode…

2023连锁收银系统该如何选?值得推荐的5款连锁收银系统

现在不管是连锁店还是零售店&#xff0c;只要是开店做生意赚钱的&#xff0c;都少不了要和钱打交道&#xff0c;尤其是对连锁店来说&#xff0c;收银工作更是重中之重。 连锁店涉及的门店较多&#xff0c;必须要有一套足够优秀的连锁收银系统&#xff0c;才能做好每个门店的收银…

使用 `tailwindcss-patch@2` 来提取你的类名吧

使用 tailwindcss-patch2 来提取你的类名吧 使用 tailwindcss-patch2 来提取你的类名吧 安装使用方式 命令行 Cli 开始提取吧 Nodejs API 的方式来使用 配置 初始化 What’s next? tailwindcss-patch 是一个 tailwindcss 生态的扩展项目。也是 tailwindcss-mangle 项目重要…

2023年上半年网络工程师上午真题及答案解析

1.固态硬盘的存储介质是( )。 A.光盘 B.闪存 C.软盘 D.磁盘 2.虚拟存储技术把( )有机地结合起来使用&#xff0c;从而得到一个更大容量的“内存”。 A.内存与外存 B.Cache与内存 C.寄存器与Cache D.Cache与外存 3.下列接口协议中&…

找不到msvcp140.dll无法继续执行代码怎么解决?分享三个解决方法

当你在运行某个程序或游戏时遇到msvcp140.dll缺失的错误提示&#xff0c;你可能会感到困惑和烦恼。在修复msvcp140.dll的过程中&#xff0c;我遇到了一些挑战&#xff0c;但最终成功解决了这个问题。以下是我总结的三个解决方法&#xff0c;希望能帮助你解决这个问题。 找不到m…

Mongodb (四十一)

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 目录 前言 一、概述 1.1 相关概念 1.2 特性 二、应用场景 三、安装 四、目录结构 五、默认数据库 六、 数据库操作 6.1 库操作 6.2 文档操作 七、MongoDB数据库备份 7.1 备…

小游戏扫雷实现教学(详解)

目录 【前言】 一、模块化程序设计&#xff08;多文件编程&#xff09;介绍 1.概述 2.传统编程的方式 3.模块化程序设计的方法 二、扫雷代码设计思路 三、扫雷代码设计 1.创建菜单函数 2.实现9x9扫雷 3.初始化棋盘 4.打印棋盘 5.随机布置雷的位置 6.排查雷的信息 7.回…

TEC2083BS-PD码转换器(解决博世矩阵控制PELCO派尔高球机的问题)

TEC2083BS-PD码转换器 使用说明 1.设备概述 控制码转换器在安防工程中起着非常重要的角色&#xff0c;随着高速球型摄像机在安防工程中大范围的使用&#xff0c;而高速球厂家都因为某些原因很少使用博世、飞利浦的协议。为此&#xff0c;工程商经常会遇到博世协议和PELCO协议之…

RabbitMQ工作流程详解

1 生产者发送消息的流程 (1)生产者连接RabbitMQ&#xff0c;建立TCP连接(Connection)&#xff0c;开启信道(Channel) (2)生产者声明一个Exchange (交换器)&#xff0c;并设置相关属性&#xff0c;比如交换器类型、是否持久化等 (3)生产者声明一个队列井设置相关属性&#xf…