NLP项目实战01之电影评论分类

介绍:

欢迎来到本篇文章!在这里,我们将探讨一个常见而重要的自然语言处理任务——文本分类。具体而言,我们将关注情感分析任务,即通过分析电影评论的情感来判断评论是正面的、负面的。

展示:
训练展示如下:

在这里插入图片描述
在这里插入图片描述

实际使用如下:

请添加图片描述

实现方式:

选择PyTorch作为深度学习框架,使用电影评论IMDB数据集,并结合torchtext对数据进行预处理。

环境:

Windows+Anaconda
重要库版本信息
torch==1.8.2+cu102
torchaudio==0.8.2
torchdata==0.7.1
torchtext==0.9.2
torchvision==0.9.2+cu102

实现思路:

1、数据集
本次使用的是IMDB数据集,IMDB是一个含有50000条关于电影评论的数据集
数据如下:
请添加图片描述
请添加图片描述

2、数据加载与预处理
使用torchtext加载IMDB数据集,并对数据集进行划分
具体划分如下:

TEXT = data.Field(tokenize='spacy', tokenizer_language='en_core_web_sm')
LABEL = data.LabelField(dtype=torch.float)
# Load the IMDB dataset
train_data, test_data = datasets.IMDB.splits(TEXT, LABEL)

创建一个 Field 对象,用于处理文本数据。同时使用spacy分词器对文本进行分词,由于IMDB是英文的,所以使用en_core_web_sm语言模型。
创建一个 LabelField 对象,用于处理标签数据。设置dtype 参数为 torch.float,表示标签的数据类型为浮点型。

使用 datasets.IMDB.splits 方法加载 IMDB 数据集,并将文本字段 TEXT 和标签字段 LABEL 传递给该方法。返回的 train_data 和 test_data 包含了 IMDB 数据集的训练和测试部分。
下面是train_data的输出
请添加图片描述

3、构建词汇表与加载预训练词向量

TEXT.build_vocab(train_data,max_size=25000,vectors="glove.6B.100d",unk_init=torch.Tensor.normal_)
LABEL.build_vocab(train_data)

train_data:表示使用train_data中数据构建词汇表
max_size:限制词汇表的大小为 25000
vectors=“glove.6B.100d”:表示使用预训练的 GloVe 词向量,其中 “glove.6B.100d” 指的是包含 100 维向量的 6B 版 GloVe。
unk_init=torch.Tensor.normal_ :表示指定未知单词(UNK)的初始化方式,这里使用正态分布进行初始化。
LABEL.build_vocab(train_data):表示对标签进行类似的操作,构建标签的词汇表

train_iterator, valid_iterator, test_iterator = data.BucketIterator.splits( (train_data, valid_data, test_data), batch_size=BATCH_SIZE, device=device)

使用data.BucketIterator.splits 来创建数据加载器,包括训练、验证和测试集的迭代器。这将确保你能够方便地以批量的形式获取数据进行训练和评估。

4、定义神经网络
这里的网络定义比较简单,主要采用在词嵌入层(embedding)后接一个全连接层的方式完成对文本数据的分类。
具体如下:

class NetWork(nn.Module):def __init__(self,vocab_size,embedding_dim,output_dim,pad_idx):super(NetWork,self).__init__()self.embedding = nn.Embedding(vocab_size,embedding_dim,padding_idx=pad_idx)self.fc = nn.Linear(embedding_dim,output_dim)self.dropout = nn.Dropout(0.5)self.relu = nn.ReLU()def forward(self,x):embedded = self.embedding(x)embedded = embedded.permute(1,0,2) pooled = F.avg_pool2d(embedded, (embedded.shape[1], 1)).squeeze(1)pooled = self.relu(pooled)pooled = self.dropout(pooled)output = self.fc(pooled)return output

5、模型初始化

vocab_size = len(TEXT.vocab)
embedding_dim  = 100
output = 1
pad_idx = TEXT.vocab.stoi[TEXT.pad_token]
model = NetWork(vocab_size,embedding_dim,output,pad_idx)
pretrained_embeddings = TEXT.vocab.vectors
model.embedding.weight.data.copy_(pretrained_embeddings)

定义模型的超参数,包括词汇表大小(vocab_size)、词向量维度(embedding_dim)、输出维度(output,在这个任务中是1,因为是二元分类,所以使用1),以及 PAD 标记的索引(pad_idx)

之后需要将预训练的词向量加载到嵌入层的权重中。TEXT.vocab.vectors 包含了词汇表中每个单词的预训练词向量,然后通过 copy_ 方法将这些词向量复制到模型的嵌入层权重中对网络进行初始化。这样做确保了模型的初始化状态良好。

6、训练模型

 total_loss = 0train_acc = 0 
model.train()
for batch in train_iterator:optimizer.zero_grad()preds = model(batch.text).squeeze(1)loss = criterion(preds,batch.label)total_loss += loss.item()batch_acc = (torch.round(torch.sigmoid(preds)) == batch.label).sum().item()train_acc += batch_accloss.backward()optimizer.step()average_loss = total_loss / len(train_iterator)train_acc /= len(train_iterator.dataset)

optimizer.zero_grad():表示将模型参数的梯度清零,以准备接收新的梯度。
preds = model(batch.text).squeeze(1):表示一次前向传播的过程,由于model输出的是torch.tensor(batch_size,1)所以使用squeeze(1)给其中的1维度数据去除,以匹配标签张量的形状
criterion(preds,batch.label):定义的损失函数 criterion 计算预测值 preds 与真实标签 batch.label 之间的损失

(torch.round(torch.sigmoid(preds)) == batch.label).sum().item():
通过比较模型的预测值与真实标签,计算当前批次的准确率,并将其累加到 train_acc 中
后面的就是进行反向传播更新参数,还有就是计算loss和train_acc的值了
7、模型评估:

model.eval()valid_loss = 0valid_acc = 0best_valid_acc = 0with torch.no_grad():for batch in valid_iterator:preds = model(batch.text).squeeze(1)loss = criterion(preds,batch.label)valid_loss += loss.item()batch_acc = ((torch.round(torch.sigmoid(preds)) == batch.label).sum().item())valid_acc += batch_acc

和训练模型的类似,这里就不解释了

8、保存模型
这里一共使用了两种保存模型的方式:

torch.save(model, "model.pth")
torch.save(model.state_dict(),"model.pth")

第一种方式叫做模型的全量保存
第二种方式叫做模型的参数保存

全量保存是保存了整个模型,包括模型的结构、参数、优化器状态等信息
参数量保存是保存了模型的参数(state_dict),不包括模型的结构
9、测试模型
测试模型的基本思路:
加载训练保存的模型、对待推理的文本进行预处理、将文本数据加载给模型进行推理

加载模型:

saved_model_path = "model.pth"
saved_model = torch.load(saved_model_path)

输入文本:
input_text = “Great service! The staff was very friendly and helpful.”

文本进行处理:

tokenizer = get_tokenizer("spacy", language="en_core_web_sm")
tokenized_text = tokenizer(input_text)
indexed_text = [TEXT.vocab.stoi[token] for token in tokenized_text]
tensor_text = torch.LongTensor(indexed_text).unsqueeze(1).to(device)

模型推理:

saved_model.eval()
with torch.no_grad():output = saved_model(tensor_text).squeeze(1)prediction = torch.round(torch.sigmoid(output)).item()probability = torch.sigmoid(output).item()

由于笔者能力有限,所以在描述的过程中难免会有不准确的地方,还请多多包含!

更多NLP和CV文章以及完整代码请到"陶陶name"获取。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/209046.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【基于LicheePi-4A的 人脸识别系统软件设计】

参考:https://www.xrvm.cn/community/post/detail?spm=a2cl5.27438731.0.0.31d40dck0dckmg&id=4253195599836418048 1.前言 原先计划做基于深度学习的炸药抓取和智能填装方法研究,但是后来发现板卡不支持pyrealsense2等多个依赖包。因此改变策略,做一款基于LicheePi…

Android Studio的笔记--三元表达式、布尔运算符、与() 或(||) 非(!)

[TOC](三元表达式、布尔运算符、与(&&) 或(||) 非(!)) 表达式 int x 1; int y 2;x < y 结果 true x > y 结果 false x < y 结果 false x > y 结果 true x y 结果 false x ! y 结果 true 布尔运算符 boolean boolean a true; boolean b false; 与…

【Python】列表乘积的计算时间

概述 使用以下三种模式测量了计算列表乘积所需的时间。 使用 for 语句传递list使用math模块使用numpy 下面是实际运行的代码。 import timestart time.time() A [1] * 100000000 ans 1 for a in A:ans * a print("list loop:", time.time() - start)import m…

前端面试提问(4)

1、手撕防抖与节流、树与对象的转换、递归调用&#xff0c;链表头插法 1.1、防抖 防抖函数用于延迟执行某个函数&#xff0c;直到过了一定的间隔时间&#xff08;例如等待用户停止输入&#xff09;后再执行。 即后一次点击事件发生时间距离一次点击事件至少间隔一定时间。 …

笙默考试管理系统-MyExamTest----codemirror(49)

笙默考试管理系统-MyExamTest----codemirror&#xff08;49&#xff09; 目录 笙默考试管理系统-MyExamTest----codemirror&#xff08;49&#xff09; 一、 笙默考试管理系统-MyExamTest----codemirror 二、 笙默考试管理系统-MyExamTest----codemirror 三、 笙默考试…

有哪些已经上线的vue商城项目?

前言 下面是一些商城的项目&#xff0c;需要练手的同学可以挑选一些来练&#xff0c;废话少说&#xff0c;让我们直接开始正题~~ 1、newbee-mall-vue3-app 是一个基于 Vue 3 和 TypeScript 的电商前端项目&#xff0c;它是 newbee-mall 项目的升级版。该项目包含了商品列表、…

内网环境下 - 安装linux命令、搭建docker以及安装镜像

一 内网环境安装docker 先在外网环境下载好docker二进制文件docker二进制文件下载&#xff0c;要下载对应硬件平台的文件&#xff0c;否则不兼容 如下载linux平台下的文件&#xff0c;直接访问这里即可linux版本docker二进制文件 这里下载docker-24.0.5.tgz 将下载好的文件…

计算机存储单位 + 程序编译过程

C语言的编译过程 计算机存储单位 头文件包含的两种方式 使用 C/C 程序常用的IDE 常用的C语言编译器&#xff1a; 在选择编译器时&#xff0c;需考虑平台兼容性、性能优化、调试工具和开发人员的个人偏好等因素。 详细教程可转 爱编程的大丙

Java编程中通用的正则表达式(一)

正则表达式&#xff08;Regular Expression&#xff0c;简称RegEx&#xff09;&#xff0c;又称常规表示法、正则表示、正规表示式、规则表达式、常式、表达式等&#xff0c;是计算机科学中的一个概念。正则表达式是用于描述某种特定模式的字符序列&#xff0c;特别是用来匹配、…

持续集成和持续交付

引言 CI/CD 是一种通过在应用开发阶段引入自动化来频繁向客户交付应用的方法。CI/CD 的核心概念是持续集成、持续交付和持续部署。作为一种面向开发和运维团队的解决方案&#xff0c;CI/CD 主要针对在集成新代码时所引发的问题&#xff08;亦称&#xff1a;“集成地狱”&#…

力扣刷题笔记——反转链表

力扣&#xff08;LeetCode&#xff09;官网 - 全球极客挚爱的技术成长平台 经典问题反转链表 这里给出四种解法 1.双指针 这种方法是用一个next指针记录当前节点的下一个节点&#xff0c;一个pre指针记录当前节点的前一个节点。 只需要遍历一遍链表就可以完成链表的反转 c…

idea__SpringBoot微服务05——JSR303校验(新注解)(新的依赖),配置文件优先级,多环境切换

JSR303校验&#xff0c;配置文件优先级&#xff0c;多环境切换 一、JSR303数据校验二、配置文件优先级三、多环境切换一、properties多环境切换二、yaml多环境切换————————创作不易&#xff0c;如觉不错&#xff0c;随手点赞&#xff0c;关注&#xff0c;收藏(*&#x…

电脑待机怎么设置?让你的电脑更加节能

在日常使用电脑的过程中&#xff0c;合理设置待机模式是一项省电且环保的好习惯。然而&#xff0c;许多用户对于如何设置电脑待机感到困扰。那么电脑待机怎么设置呢&#xff1f;本文将深入探讨三种常用的电脑待机设置方法&#xff0c;通过详细的步骤&#xff0c;帮助用户更好地…

【C语言期末】题目+笔记

文章目录 题目1.下面哪个不是C语言的基本数据类型&#xff1f;&#xff08; B &#xff09;2.C语言的标识符应以字母或&#xff08; A &#xff09;开头。3.如果需要在C程序里调用标准函数库中的printf函数&#xff0c;则应该在程序的开头包含哪个头文件&#xff1f;&#xff0…

【数据结构】顺序表的定义和运算

目录 1.初始化 2.插入 3.删除 4.查找 5.修改 6.长度 7.遍历 8.完整代码 &#x1f308;嗨&#xff01;我是Filotimo__&#x1f308;。很高兴与大家相识&#xff0c;希望我的博客能对你有所帮助。 &#x1f4a1;本文由Filotimo__✍️原创&#xff0c;首发于CSDN&#x1f4da;。 &…

web前端开发html/css练习

目标图&#xff1a; 素材&#xff1a; 代码&#xff1a; <!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN" "http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd"> <html xmlns"http://www.w3.org/1999/xhtml"…

使用RSA工具进行对信息加解密

我们在开发中需要对用户敏感数据进行加解密&#xff0c;比如密码 这边科普一下RSA算法 RSA是非对称加密算法&#xff0c;与对称加密算法不同;在对称加密中&#xff0c;相同的密钥用于加密和解密数据,因此密钥的安全性至关重要;而在RSA非对称加密中&#xff0c;有两个密钥&…

【USRP】5G / 6G OAI 系统 5g / 6G OAI system

面向5G/6G科研应用 USRP专门用于5G/6G产品的原型开发与验证。该系统可以在实验室搭建一个真实的5G 网络&#xff0c;基于开源的代码&#xff0c;专为科研用户设计。 软件无线电架构&#xff0c;构建真实5G移动通信系统 X410 采用了目前流行的异构式系统&#xff0c;融合了FP…

SQLite基本使用

目录 1. 概述2. 引入SQLite3. 连接数据库创建游标4. 创建数据库文件5. 新增单条数据6. 批量新增数据7. 查询单条数据8.查询全部数据9. 查询指定条数的数据10. 修改数据11. 删除数据12. 事务回滚13. 关闭数据库关闭游标1. 概述 SQLite是一个进程内的库,实现了自给自足的、无服务…

【嵌入式开发 Linux 常用命令系列 4.2 -- .repo 各个目录介绍】

文章目录 概述.repo 目录结构manifests/default.xmlManifest 文件的作用default.xml 文件内容示例linkfile 介绍 .repo/projects 子目录配置和管理configHEADhooksinfo/excludeobjectsrr-cache 工作区中的对应目录 概述 repo 是一个由 Google 开发的版本控制工具&#xff0c;它…