使用 BERT 进行文本分类 (02/3)

一、说明

        在使用BERT(1)进行文本分类中,我向您展示了一个BERT如何标记文本的示例。在下面的文章中,让我们更深入地研究是否可以使用 BERT 来预测文本是使用 PyTorch 传达积极还是消极的情绪。首先,我们需要准备数据,以便使用 PyTorch 框架进行分析。

二、什么是 PyTorch

        PyTorch 是用于构建深度学习模型的框架,深度学习模型是一种机器学习,通常用于图像识别和语言处理等应用程序。它由Facebook的人工智能研究小组于2016年开发,由于其灵活性,易用性和动态计算图构建而广受欢迎。

        PyTorch 提供了一个基于 Python 的科学计算包,它使用图形处理单元 (GPU) 的强大功能来加速张量运算的计算。它具有简单直观的API,允许开发人员快速构建和训练深度学习模型。PyTorch 还支持自动微分,使用户能够计算任意函数的梯度。

三、准备我们的数据集

        首先,让我们从Github下载我们的数据。这里有一个关于如何从Github下载CSV文件的小提醒。只需继续并单击以下链接:

github.com

        然后,右键单击“原始”,然后左键单击“将链接文件下载为...”。您将看到“垃圾邮件.csv”并下载它。下载后,将其保存到您的首选文件夹中以供以后使用。

        现在,让我们导入数据。我们看到一条错误消息,告诉我们部分数据未采用 UTF-8 编码。

import pandas as pd
df = pd.read_csv("spam.csv")ERROR: 
UnicodeDecodeError: 'utf-8' codec can't decode bytes in position 606-607: invalid continuation byte

我们可以通过了解数据包含的字符编码并在读取数据时调用该编码来修复此错误。

# Use chardet to know the character encoding 
import chardet
with open("spam.csv", 'rb') as rawdata:result = chardet.detect(rawdata.read(100000))
resultOutput: 
{'encoding': 'Windows-1252', 'confidence': 0.7270322499829184, 'language': ''}

似乎我们的数据是在“Windows-1252”中编码的。那让我们再读一遍。它奏效了!

df = pd.read_csv("spam.csv", encoding = 'Windows-1252')
df.head()

        如我们所见,我们实际上并不需要“v1”和“v2”以外的列。此外,如果我们将“v1”和“v2”重命名为“类别”和“消息”,则更容易理解。

df = df.loc[:, ['v1', 'v2']]
df = df.rename(columns={'v1': 'Category', 'v2': 'Message'})
df.head()

        现在,我们应该看看我们的数据集,看看每个类别中有多少条消息。

df['Category'].value_counts()Output: 
ham     4825
spam     747
Name: Category, dtype: int64

四、创建平衡数据集

        事实证明,正常邮件比垃圾邮件多。构建机器学习模型时,如果数据集不平衡,其中一个类中的数据数量明显多于另一个类,则可能会对模型的性能产生各种影响。一些潜在的后果。例如:

-1 有偏差模型:如果数据集不平衡,模型可能会偏向多数类,而对少数类表现不佳。这是因为模型更有可能预测多数类,这将导致少数类的准确性较差。

-2 泛化不良:不平衡的数据集可能导致模型泛化不良。这是因为该模型将在不代表数据真实世界分布的数据集上进行训练,因此它可能无法很好地概括看不见的数据。

-3 评估不准确:如果使用准确性作为指标评估模型,则可能会产生误导性结果。例如,始终预测不平衡数据集中多数类的模型可能具有很高的准确性,但对少数类没有用。

-4 过拟合:由于数据点数量较多,模型可能会过度拟合多数类,从而导致测试数据的性能不佳。

为了解决这些问题,可以使用各种技术来平衡数据集,例如对少数类进行过采样,对多数类进行欠采样,或同时使用两者的组合。在这篇文章中,我将使用欠采样方法。

df_spam = df[df['Category']=='spam']
df_ham = df[df['Category']=='ham']
df_ham_downsampled = df_ham.sample(df_spam.shape[0])
df_balanced = pd.concat([df_ham_downsampled, df_spam])
df_balanced['Category'].value_counts()Output: 
ham     747
spam    747
Name: Category, dtype: int64

五、标记数据

        当数据表示为数字而不是分类为用于训练和测试的模型时,机器学习算法在准确性和其他性能指标方面表现更好。我们需要用数值对分类值进行标签编码。在这里,我们创建了一个新列“标签”,如果邮件是垃圾邮件,我们将其标记为 1,否则为 0。

df_balanced['Label']=df_balanced['Category'].apply(lambda x: 1 if x=='spam' else 0)
df_balanced = df_balanced.reset_index(drop=True)display(df_balanced)

由作者创建

六、训练、验证和测试数据集:谁是谁

        要记住的一件事是,当我们使用 train_test_split 库来训练模型时,我们实际上是将数据集拆分为 TRAINING 数据集和 VALIDATION 数据集,而不是 TRAINING 数据集和 TESTING 数据集。下面提醒一下这些数据集的含义。

  1. 训练集:用于构建我们的模型。我们将使用训练集来找到具有反向传播规则的“最佳”权重和偏差。在此阶段,我们通常会创建多个算法,以便在交叉验证阶段比较它们的性能。
  2. 交叉验证集:此数据集用于比较基于训练集创建的预测算法的性能。我们选择性能最佳的算法。
  3. 测试集:这是“未来”数据集。现在我们已经选择了我们喜欢的预测算法,但我们还不知道它将如何在完全看不见的真实世界数据上执行。因此,我们将我们选择的预测算法应用于我们的测试集,以查看它将如何执行,以便我们可以了解我们的算法在野外的性能。

        因此,在测试集中,我们没有数据的标签,而是使用我们的模型来预测标签。我们只能将手头的数据集拆分为训练集和验证集,因为我们还没有“未来”数据。

七、拆分为训练数据集和验证数据集

        现在我们了解了这三种类型的数据的真正含义,我们可以使用scikit-learn的train_test_split来拆分数据。

from sklearn.model_selection import train_test_split
X_train, X_val, y_train, y_val = train_test_split(df_balanced['Message'],df_balanced['Label'], stratify=df_balanced['Label'], test_size=.2)X_train.head()Output: 
708                      ;-) ok. I feel like john lennon.
1386    Cashbin.co.uk (Get lots of cash this weekend!)...
1492    REMINDER FROM O2: To get 2.50 pounds free call...
119     Back in brum! Thanks for putting us up and kee...
89                       Sorry, I can't help you on this.
Name: Message, dtype: object

八、总结

        我们已经学会了如何下载和拆分数据。在下一篇文章中,我们将首先对其进行标记,并使用DistilBERT训练分类器。达门·

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/39627.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

3.1 Qt样式选择器

本期内容 3.1 样式选择器 3.1.1 Universal Selector (通用选择器) 3.1.2 Type Selector (类型选择器) 3.1.3 Property Selector (属性选择器) 3.1.4 Class Selector (类选择器) 3.1.5 ID Selector (ID选择器) 3.1.6 Descendant Selector (后裔选择器) 3.1.7 Chil…

前端跨域的原因以及解决方案(vue),一文让你真正理解跨域

跨域这个问题,可以说是前端的必需了解的,但是多少人是知其然不知所以然呢? 下面我们来梳理一下vue解决跨域的思路。 什么情况会跨域? ​ 跨域的本质就是浏览器基于同源策略的一种安全手段。所谓同源就是必须有以下三个相同点:协议相同、域名…

WinCC V7.5 中的C脚本对话框不可见,将编辑窗口移动到可见区域的具体方法

WinCC V7.5 中的C脚本对话框不可见,将编辑窗口移动到可见区域的具体方法 由于 Windows 系统更新或使用不同的显示器,在配置C动作时,有可能会出现C脚本编辑窗口被移动到不可见区域的现象。 由于该窗口无法被关闭,故无法进行进一步…

KafkaStream:Springboot中集成

1、在kafka-demo中创建配置类 配置kafka参数 package com.heima.kafkademo.config;import lombok.Data; import org.apache.kafka.common.serialization.Serdes; import org.apache.kafka.streams.StreamsConfig; import org.springframework.boot.context.properties.Configu…

8月11日上课内容 nginx的多实例和动静分离

多实例部署 在一台服务器上有多个tomcat的服务。 配置多实例之前,看单个实例是否访问正常。 1.安装好 jdk 2.安装 tomcat cd /opt tar zxvf apache-tomcat-9.0.16.tar.gz mkdir /usr/local/tomcat mv apache-tomcat-9.0.16 /usr/local/tomcat/tomcat1 cp -a /u…

Linux系统管理:虚拟机ESXi安装

目录 一、理论 1.VMware Workstation 2.VMware vSphere Client 3.ESXi 二、实验 1.ESXi 7安装 一、理论 1.VMware Workstation 它是一款专业的虚拟机软件,可以在一台物理机上运行多个操作系统,支持Windows、Linux等操作系统,可以模拟…

使用selenium如何实现自动登录

回顾使用requests如何实现自动登录一文中,提到好多网站在我们登录过后,在之后的某段时间内访问该网页时,不会给出请登录的提示,时间到期后就会提示请登录!这样在使用爬虫访问网页时还要登录,打乱我们的节奏…

item_get_sales-获取商品销量详情

一、接口参数说明: item_get_sales-获取商品销量详情,点击更多API调试,请移步注册API账号点击获取测试key和secret 公共参数 请求地址: https://api-gw.onebound.cn/taobao/item_get_sales 名称类型必须描述keyString是调用key&#xff08…

ACM模式刷Leetcode题目

139题单词拆分 链接: link #include<iostream> #include<sstream> #include<string> #include<vector> #include<algorithm> #include<unordered_set> using namespace std;int main() {// 实现输入第一行为s字符串。// 第二行为wordDic…

【代码随想录day22】爬楼梯

题目 假设你正在爬楼梯。需要 n 阶你才能到达楼顶。 每次你可以爬 1 或 2 个台阶。你有多少种不同的方法可以爬到楼顶呢&#xff1f; 示例 1&#xff1a; 输入&#xff1a;n 2 输出&#xff1a;2 解释&#xff1a;有两种方法可以爬到楼顶。 1. 1 阶 1 阶 2. 2 阶 示例 2…

Spring的三种异常处理方式

1.SpringMVC 异常的处理流程 异常分为编译时异常和运行时异常&#xff0c;编译时异常我们 try-cache 进行捕获&#xff0c;捕获后自行处理&#xff0c;而运行时异常是不 可预期的&#xff0c;就需要规范编码来避免&#xff0c;在SpringMVC 中&#xff0c;不管是编译异常还是运行…

java:JDBC

文章目录 什么是JDBCJDBC使用步骤详解各个对象DriverManagerConnectionStatementResultSetPreparedStatement JDBC控制事务操作步骤示例 什么是JDBC 我们知道&#xff0c;数据库有很多种&#xff0c;比如 mysql&#xff0c;Oracle&#xff0c;DB2等等&#xff0c;如果每一种数…

C# WPF 中 外部图标引入iconfont,无法正常显示问题 【小白记录】

wpf iconfont 外部图标引入&#xff0c;无法正常显示问题。 1. 检查资源路径和引入格式是否正确2. 检查资源是否包含在程序集中 1. 检查资源路径和引入格式是否正确 正确的格式&#xff0c;注意字体文件 “xxxx.ttf” 应写为 “#xxxx” <TextBlock Text"&#xe7ae;…

不重启Docker能添加自签SSL证书镜像仓库吗?

应用背景 在企业应用Docker规划初期配置非安全镜像仓库时&#xff0c;有时会遗漏一些仓库没配置&#xff0c;但此时应用程序已经在Docker平台上部署起来了&#xff0c;体量越大就越不会让人去直接重启Docker。 那么&#xff0c;不重启Docker能添加自签SSL证书镜像仓库吗&…

经典人体模型SMPL介绍(一)

SMPL是马普所提出的经典人体模型&#xff0c;目前已成为姿态估计、人体重建等领域必不可少的基础先验。SMPL基于蒙皮和BlendShape实现&#xff0c;从数千个三维人体扫描结果得来&#xff0c;后通过PCA统计学习得来。 论文&#xff1a;SMPL: A Skinned Multi-Person Linear Mode…

Python读取svn版本

本文将详细介绍如何使用Python读取svn版本。 一、安装svn库 首先&#xff0c;我们需要使用Python来连接svn服务器&#xff0c;并获取版本号。这里我们使用pysvn库来完成这个工作。 pip install pysvn需要注意的是&#xff0c;如果你需要安装特定版本的pysvn&#xff0c;你可…

2023连锁收银系统该如何选?值得推荐的5款连锁收银系统

现在不管是连锁店还是零售店&#xff0c;只要是开店做生意赚钱的&#xff0c;都少不了要和钱打交道&#xff0c;尤其是对连锁店来说&#xff0c;收银工作更是重中之重。 连锁店涉及的门店较多&#xff0c;必须要有一套足够优秀的连锁收银系统&#xff0c;才能做好每个门店的收银…

【ARM 嵌入式 编译系列 5 -- GCC 内建函数 __builtin 详细介绍】

文章目录 什么是GCC内建函数?GCC 常见内建函数GCC内建函数使用示例上篇文章:ARM 嵌入式 编译系列 4.2 – GCC 链接规范 extern “C“ 介绍 下篇文章:ARM 嵌入式 编译系列 6 – GCC objcopy, objdump, readelf, nm 介绍 什么是GCC内建函数? GCC提供了一些专门的功能,用于…

使用 `tailwindcss-patch@2` 来提取你的类名吧

使用 tailwindcss-patch2 来提取你的类名吧 使用 tailwindcss-patch2 来提取你的类名吧 安装使用方式 命令行 Cli 开始提取吧 Nodejs API 的方式来使用 配置 初始化 What’s next? tailwindcss-patch 是一个 tailwindcss 生态的扩展项目。也是 tailwindcss-mangle 项目重要…

redis的Key的过期策略是如何实现的?

Key的过期策略 一个redis中可能同时存在很多很多key&#xff0c;这些key可能有很大一部分都有过期时间&#xff0c;此时&#xff0c;redis服务器咋知道哪些key已经过期要被删除&#xff0c;哪些key还没有过期&#xff1f; 如果直接遍历所有的key&#xff0c;显然是行不通的&am…