NLP项目实战01--电影评论分类

介绍:

欢迎来到本篇文章!在这里,我们将探讨一个常见而重要的自然语言处理任务——文本分类。具体而言,我们将关注情感分析任务,即通过分析电影评论的情感来判断评论是正面的、负面的。

展示:
训练展示如下:

在这里插入图片描述
在这里插入图片描述

实际使用如下:

请添加图片描述

实现方式:

选择PyTorch作为深度学习框架,使用电影评论IMDB数据集,并结合torchtext对数据进行预处理。

环境:

Windows+Anaconda
重要库版本信息
torch==1.8.2+cu102
torchaudio==0.8.2
torchdata==0.7.1
torchtext==0.9.2
torchvision==0.9.2+cu102

实现思路:

1、数据集
本次使用的是IMDB数据集,IMDB是一个含有50000条关于电影评论的数据集
数据如下:
请添加图片描述
请添加图片描述

2、数据加载与预处理
使用torchtext加载IMDB数据集,并对数据集进行划分
具体划分如下:

TEXT = data.Field(tokenize='spacy', tokenizer_language='en_core_web_sm')
LABEL = data.LabelField(dtype=torch.float)
# Load the IMDB dataset
train_data, test_data = datasets.IMDB.splits(TEXT, LABEL)

创建一个 Field 对象,用于处理文本数据。同时使用spacy分词器对文本进行分词,由于IMDB是英文的,所以使用en_core_web_sm语言模型。
创建一个 LabelField 对象,用于处理标签数据。设置dtype 参数为 torch.float,表示标签的数据类型为浮点型。

使用 datasets.IMDB.splits 方法加载 IMDB 数据集,并将文本字段 TEXT 和标签字段 LABEL 传递给该方法。返回的 train_data 和 test_data 包含了 IMDB 数据集的训练和测试部分。
下面是train_data的输出
请添加图片描述

3、构建词汇表与加载预训练词向量

TEXT.build_vocab(train_data,max_size=25000,vectors="glove.6B.100d",unk_init=torch.Tensor.normal_)
LABEL.build_vocab(train_data)

train_data:表示使用train_data中数据构建词汇表
max_size:限制词汇表的大小为 25000
vectors=“glove.6B.100d”:表示使用预训练的 GloVe 词向量,其中 “glove.6B.100d” 指的是包含 100 维向量的 6B 版 GloVe。
unk_init=torch.Tensor.normal_ :表示指定未知单词(UNK)的初始化方式,这里使用正态分布进行初始化。
LABEL.build_vocab(train_data):表示对标签进行类似的操作,构建标签的词汇表

train_iterator, valid_iterator, test_iterator = data.BucketIterator.splits( (train_data, valid_data, test_data), batch_size=BATCH_SIZE, device=device)

使用data.BucketIterator.splits 来创建数据加载器,包括训练、验证和测试集的迭代器。这将确保你能够方便地以批量的形式获取数据进行训练和评估。

4、定义神经网络
这里的网络定义比较简单,主要采用在词嵌入层(embedding)后接一个全连接层的方式完成对文本数据的分类。
具体如下:

class NetWork(nn.Module):def __init__(self,vocab_size,embedding_dim,output_dim,pad_idx):super(NetWork,self).__init__()self.embedding = nn.Embedding(vocab_size,embedding_dim,padding_idx=pad_idx)self.fc = nn.Linear(embedding_dim,output_dim)self.dropout = nn.Dropout(0.5)self.relu = nn.ReLU()def forward(self,x):embedded = self.embedding(x)embedded = embedded.permute(1,0,2) pooled = F.avg_pool2d(embedded, (embedded.shape[1], 1)).squeeze(1)pooled = self.relu(pooled)pooled = self.dropout(pooled)output = self.fc(pooled)return output

5、模型初始化

vocab_size = len(TEXT.vocab)
embedding_dim  = 100
output = 1
pad_idx = TEXT.vocab.stoi[TEXT.pad_token]
model = NetWork(vocab_size,embedding_dim,output,pad_idx)
pretrained_embeddings = TEXT.vocab.vectors
model.embedding.weight.data.copy_(pretrained_embeddings)

定义模型的超参数,包括词汇表大小(vocab_size)、词向量维度(embedding_dim)、输出维度(output,在这个任务中是1,因为是二元分类,所以使用1),以及 PAD 标记的索引(pad_idx)

之后需要将预训练的词向量加载到嵌入层的权重中。TEXT.vocab.vectors 包含了词汇表中每个单词的预训练词向量,然后通过 copy_ 方法将这些词向量复制到模型的嵌入层权重中对网络进行初始化。这样做确保了模型的初始化状态良好。

6、训练模型

 total_loss = 0train_acc = 0 
model.train()
for batch in train_iterator:optimizer.zero_grad()preds = model(batch.text).squeeze(1)loss = criterion(preds,batch.label)total_loss += loss.item()batch_acc = (torch.round(torch.sigmoid(preds)) == batch.label).sum().item()train_acc += batch_accloss.backward()optimizer.step()average_loss = total_loss / len(train_iterator)train_acc /= len(train_iterator.dataset)

optimizer.zero_grad():表示将模型参数的梯度清零,以准备接收新的梯度。
preds = model(batch.text).squeeze(1):表示一次前向传播的过程,由于model输出的是torch.tensor(batch_size,1)所以使用squeeze(1)给其中的1维度数据去除,以匹配标签张量的形状
criterion(preds,batch.label):定义的损失函数 criterion 计算预测值 preds 与真实标签 batch.label 之间的损失

(torch.round(torch.sigmoid(preds)) == batch.label).sum().item():
通过比较模型的预测值与真实标签,计算当前批次的准确率,并将其累加到 train_acc 中
后面的就是进行反向传播更新参数,还有就是计算loss和train_acc的值了
7、模型评估:

model.eval()valid_loss = 0valid_acc = 0best_valid_acc = 0with torch.no_grad():for batch in valid_iterator:preds = model(batch.text).squeeze(1)loss = criterion(preds,batch.label)valid_loss += loss.item()batch_acc = ((torch.round(torch.sigmoid(preds)) == batch.label).sum().item())valid_acc += batch_acc

和训练模型的类似,这里就不解释了

8、保存模型
这里一共使用了两种保存模型的方式:

torch.save(model, "model.pth")
torch.save(model.state_dict(),"model.pth")

第一种方式叫做模型的全量保存
第二种方式叫做模型的参数保存

全量保存是保存了整个模型,包括模型的结构、参数、优化器状态等信息
参数量保存是保存了模型的参数(state_dict),不包括模型的结构
9、测试模型
测试模型的基本思路:
加载训练保存的模型、对待推理的文本进行预处理、将文本数据加载给模型进行推理

加载模型:

saved_model_path = "model.pth"
saved_model = torch.load(saved_model_path)

输入文本:
input_text = “Great service! The staff was very friendly and helpful.”

文本进行处理:

tokenizer = get_tokenizer("spacy", language="en_core_web_sm")
tokenized_text = tokenizer(input_text)
indexed_text = [TEXT.vocab.stoi[token] for token in tokenized_text]
tensor_text = torch.LongTensor(indexed_text).unsqueeze(1).to(device)

模型推理:

saved_model.eval()
with torch.no_grad():output = saved_model(tensor_text).squeeze(1)prediction = torch.round(torch.sigmoid(output)).item()probability = torch.sigmoid(output).item()

由于笔者能力有限,所以在描述的过程中难免会有不准确的地方,还请多多包含!

更多NLP和CV文章以及完整代码请到"陶陶name"获取。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/209595.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

图像叠加中文字体

目录 1) 前言2) freetype下载3) Demo3.1) 下载3.2) 编译3.3) 运行3.4) 结果3.5) 更详细的使用见目录中说明 4) 积少成多 1) 前言 最近在做图片、视频叠加文字,要求支持中文,基本原理是将图片或视频解码后叠加文字,之后做图片或视频编码即可。…

ASP.NET Core概述-微软已经收购了mono,为什么还搞.NET Core呢

一、.NET Core概述 1、相关历程 .NET在设计之初也是考虑像Java一样跨平台,.NET Framework是在Windows下运行的,大部分类是可以兼容移植到Linux下,但是没有人做这个工作。 2001年米格尔为Gnome寻找桌面开发技术,在研究了微软的.…

数据库版本管理框架-Flyway(从入门到精通)

一、flyway简介 Flyway是一个简单开源数据库版本控制器(约定大于配置),主要提供migrate、clean、info、validate、baseline、repair等命令。它支持SQL(PL/SQL、T-SQL)方式和Java方式,支持命令行客户端等&am…

TCP对数据的拆分

应用程序的数据一般都比较大,因此TCP会按照网络包的大小对数据进行拆分。 当发送缓冲区中的数据超过MSS的长度,数据会被以MSS长度为单位进行拆分,拆分出来的数据块被放进单独的网路包中。 根据发送缓冲区中的数据拆分情况,当判断…

JWT介绍及演示

JWT 介绍 cookie(放在浏览器) cookie 是一个非常具体的东西,指的就是浏览器里面能永久存储的一种数据,仅仅是浏览器实现的一种数据存储功能。 cookie由服务器生成,发送给浏览器,浏览器把cookie以kv形式保存到某个目录下的文本…

JavaScript 金额元转化为万

function dealNum(price){if (price 0) {return 0元}const BASE 10000const decimal 0const SIZES ["", "万", "亿", "万亿"];let i undefined;let str "";if (price) {if ((price > 0 && price < BASE…

通过命令行输入参数控制激励

1)在命令行的仿真参数&#xff08;SIM_OPT&#xff09;加上&#xff1a;“var_a100 var_b99” 2)在环境中调用&#xff1a; $test$plusargs("var_a")&#xff1b;如果命令行存在这个字符&#xff0c;返回1&#xff0c;否则返回0&#xff1b; $value$plusargs(&qu…

蓝牙物联网对接技术难点有哪些?

#物联网# 蓝牙物联网对接技术难点主要包括以下几个方面&#xff1a; 1、设备兼容性&#xff1a;蓝牙技术有多种版本和规格&#xff0c;如蓝牙4.0、蓝牙5.0等&#xff0c;不同版本之间的兼容性可能存在问题。同时&#xff0c;不同厂商生产的蓝牙设备也可能存在兼容性问题。 2、…

0-1背包问题

二维版: import java.io.BufferedReader; import java.io.IOException; import java.io.InputStreamReader;public class Main {static int N 1010;static int[][] dp new int[N][N]; //dp[i][j] 只选前i件物品,体积 < j的最优解static int[] w new int[N]; //存储价…

字符串函数`strlen`、`strcpy`、`strcmp`、`strstr`、`strcat`的使用以及模拟实现

文章目录 &#x1f680;前言&#x1f680;库函数strlen✈️strlen的模拟实现 &#x1f680;库函数strcpy✈️strcpy的模拟实现 &#x1f680;strcmp✈️strcmp的模拟实现 &#x1f680;strstr✈️strstr的模拟实现 &#x1f680;strcat✈️strcat的模拟实现 &#x1f680;前言 …

ReactJS和VueJS的简介以及它们之间的区别

本文主要介绍ReactJS和VueJS的简介以及它们之间的区别。 目录 ReactJS简介ReactJS的优缺点ReactJS的应用场景VueJS简介VueJS的优缺点VueJS的应用场景ReactJS和VueJS的区别 ReactJS简介 ReactJS是一个由Facebook开发的基于JavaScript的前端框架。它是一个用于构建用户界面的库&…

【C语言】——函数递归,用递归简化并实现复杂问题

文章目录 前言一、什么是递归二、递归的限制条件三、递归举例1.求n的阶乘2. 举例2&#xff1a;顺序打印一个整数的每一位 四、递归的优劣总结 前言 不多废话了&#xff0c;直接开始。 一、什么是递归 递归是学习C语言函数绕不开的⼀个话题&#xff0c;那什么是递归呢&#xf…

电商平台商品销量API接口,30天销量API接口接口超详细接入方案说明

电商平台商品销量API接口的作用主要是帮助开发者获取电商平台上的商品销量信息。通过这个接口&#xff0c;开发者可以在自己的应用或网站中实时获取商品的销量数据&#xff0c;以便进行销售分析、库存管理、市场预测等操作。 具体来说&#xff0c;电商平台商品销量API接口的使…

RocketMq集成SpringBoot(待完善)

环境 jdk1.8, springboot2.7.3 Maven依赖 <parent><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-parent</artifactId><version>2.7.3</version><relativePath/> <!-- lookup parent from…

Unity中后处理简介

文章目录 前言一、后处理的原理二、我们看一下Unity文档中&#xff0c;内置的后处理前后的效果后处理前&#xff1a;后处理后&#xff1a; 前言 我们在这篇文章中&#xff0c;了解一下Unity中的后处理效果 后期处理概述 一、后处理的原理 在后处理的过程中&#xff0c;我们主…

Java当中常用的算法

文章目录 算法二叉树左右变换数据二分法实现 冒泡排序算法插入排序算法快速排序算法希尔排序算法归并排序算法桶排序算法基数排序算法分治算法汉诺塔问题动态规划算法引子代码实现背包问题 KMP算法什么是KMP算法暴力匹配KMP算法实现 今天我们来看看常用的算法&#xff0c;开干。…

《微信小程序开发从入门到实战》学习四十五

4.4 云函数 云函数是开发者提前定义好的、保存在云端并且将在云端运行的JS函数。 开发者先定义好云函数&#xff0c;再使用微信开发工具将云函数上传到云空间&#xff0c;在云开发控制台中可看到已经上传的云函数。 云函数运行在云端Node.js环境中。 小程序端通过wx.cloud.…

IP地址定位技术为网络安全建设提供全新方案

随着互联网的普及和数字化进程的加速&#xff0c;网络安全问题日益引人关注。网络攻击、数据泄露、欺诈行为等安全威胁层出不穷&#xff0c;对个人隐私、企业机密和社会稳定构成严重威胁。在这样的背景下&#xff0c;IP地址定位技术应运而生&#xff0c;为网络安全建设提供了一…

短视频ai剪辑分发矩阵系统源码3年技术团队开发搭建打磨

如果您需要搭建这样的系统&#xff0c;建议您寻求专业的技术支持&#xff0c;以确保系统的稳定性和安全性。 在搭建短视频AI剪辑分发矩阵系统时&#xff0c;您需要考虑以下几个方面&#xff1a; 1. 技术实现&#xff1a;您需要选择适合您的需求和预算的技术栈&#xff0c;例如使…

肖sir__ 项目讲解__项目数据

项目时间&#xff1a; 情况一&#xff1a;项目时间开始到上线的时间&#xff0c;这个时间一般比较长&#xff08;一年&#xff0c;二年&#xff0c;三年&#xff09; 情况二&#xff1a;项目的版本的时间或则是周期&#xff08;1个月&#xff0c;2个月&#xff0c;3个月&…