python pytorch实现RNN,LSTM,GRU,文本情感分类

python pytorch实现RNN,LSTM,GRU,文本情感分类

数据集格式:
在这里插入图片描述
有需要的可以联系我

实现步骤就是:
1.先对句子进行分词并构建词表
2.生成word2id
3.构建模型
4.训练模型
5.测试模型

代码如下:


import pandas as pd
import torch
import matplotlib.pyplot as plt
import jieba
import numpy as np"""
作业:
一、完成优化
优化思路1 jieba
2 取常用的3000字
3 修改model:rnn、lstm、gru二、完成测试代码
"""# 了解数据
dd = pd.read_csv(r'E:\peixun\data\train.csv')
# print(dd.head())# print(dd['label'].value_counts())# 句子长度分析
# 确定输入句子长度为 500
text_len = [len(i) for i in dd['text']]
# plt.hist(text_len)
# plt.show()
# print(max(text_len), min(text_len))# 基本参数 config
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('my device:', DEVICE)MAX_LEN = 500
BATCH_SIZE = 16
EPOCH = 1
LR = 3e-4# 构建词表 word2id
vocab = []
for i in dd['text']:vocab.extend(jieba.lcut(i, cut_all=True))  # 使用 jieba 分词# vocab.extend(list(i))vocab_se = pd.Series(vocab)
print(vocab_se.head())
print(vocab_se.value_counts().head())vocab = vocab_se.value_counts().index.tolist()[:3000]  # 取频率最高的 3000 token
# print(vocab[:10])
# exit()WORD_PAD = "<PAD>"
WORD_UNK = "<UNK>"
WORD_PAD_ID = 0
WORD_UNK_ID = 1vocab = [WORD_PAD, WORD_UNK] + list(set(vocab))print(vocab[:10])
print(len(vocab))vocab_dict = {k: v for v, k in enumerate(vocab)}# 词表大小,vocab_dict: word2id; vocab: id2word
print(len(vocab_dict))import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import data
import pandas as pd# 定义数据集 Dataset
class Dataset(data.Dataset):def __init__(self, split='train'):# ChnSentiCorp 情感分类数据集path =  r'E:/peixun/data/' + str(split) + '.csv'self.data = pd.read_csv(path)def __len__(self):return len(self.data)def __getitem__(self, i):text = self.data.loc[i, 'text']label = self.data.loc[i, 'label']return text, label# 实例化 Dataset
dataset = Dataset('train')# 样本数量
print(len(dataset))
print(dataset[0])# 句子批处理函数
def collate_fn(batch):# [(text1, label1), (text2, label2), (3, 3)...]sents = [i[0][:MAX_LEN] for i in batch]labels = [i[1] for i in batch]inputs = []# masks = []for sent in sents:sent = [vocab_dict.get(i, WORD_UNK_ID) for i in list(sent)]pad_len = MAX_LEN - len(sent)# mask = len(sent) * [1] + pad_len * [0]# masks.append(mask)sent += pad_len * [WORD_PAD_ID]inputs.append(sent)# 只使用 lstm 不需要用 masks# masks = torch.tensor(masks)# print(inputs)inputs = torch.tensor(inputs)labels = torch.LongTensor(labels)return inputs.to(DEVICE), labels.to(DEVICE)# 测试 loader
loader = data.DataLoader(dataset,batch_size=BATCH_SIZE,collate_fn=collate_fn,shuffle=True,drop_last=False)inputs, labels = iter(loader).__next__()
print(inputs.shape, labels)# 定义模型
class Model(nn.Module):def __init__(self, vocab_size=5000):super().__init__()self.embed = nn.Embedding(vocab_size, 100, padding_idx=WORD_PAD_ID)# 多种 rnnself.rnn = nn.RNN(100, 100, 1, batch_first=True, bidirectional=True)self.gru = nn.GRU(100, 100, 1, batch_first=True, bidirectional=True)self.lstm = nn.LSTM(100, 100, 1, batch_first=True, bidirectional=True)self.l1 = nn.Linear(500 * 100 * 2, 100)self.l2 = nn.Linear(100, 2)def forward(self, inputs):out = self.embed(inputs)out, _ = self.lstm(out)out = out.reshape(BATCH_SIZE, -1)  # 16 * 100000out = F.relu(self.l1(out))  # 16 * 100out = F.softmax(self.l2(out))  # 16 * 2return out# 测试 Model
model = Model()
print(model)# 模型训练
dataset = Dataset()
loader = data.DataLoader(dataset,batch_size=BATCH_SIZE,collate_fn=collate_fn,shuffle=True)model = Model().to(DEVICE)# 交叉熵损失
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=LR)model.train()
for e in range(EPOCH):for idx, (inputs, labels) in enumerate(loader):# 前向传播,计算预测值out = model(inputs)# 计算损失loss = loss_fn(out, labels)# 反向传播,计算梯度loss.backward()# 参数更新optimizer.step()# 梯度清零optimizer.zero_grad()if idx % 10 == 0:out = out.argmax(dim=-1)acc = (out == labels).sum().item() / len(labels)print('>>epoch:', e,'\tbatch:', idx,'\tloss:', loss.item(),'\tacc:', acc)# 模型测试
test_dataset = Dataset('test')
test_loader = data.DataLoader(test_dataset,batch_size=BATCH_SIZE,collate_fn=collate_fn,shuffle=False)loss_fn = nn.CrossEntropyLoss()out_total = []
labels_total = []model.eval()
for idx, (inputs, labels) in enumerate(test_loader):out = model(inputs)loss = loss_fn(out, labels)out_total.append(out)labels_total.append(labels)if idx % 50 == 0:print('>>batch:', idx, '\tloss:', loss.item())correct=0
sumz=0
for i in range(len(out_total)):out = out_total[i].argmax(dim=-1)correct = (out == labels_total[i]).sum().item() +correctsumz=sumz+len(labels_total[i])#acc = (out_total == labels_total).sum().item() / len(labels_total)print('>>acc:', correct/sumz)

运行结果如下:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/186292.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

命名管道:简单案例实现

&#x1f4df;作者主页&#xff1a;慢热的陕西人 &#x1f334;专栏链接&#xff1a;Linux &#x1f4e3;欢迎各位大佬&#x1f44d;点赞&#x1f525;关注&#x1f693;收藏&#xff0c;&#x1f349;留言 本博客主要内容讲解了什么是命名管道&#xff0c;匿名管道和命名管道的…

【GraphQL 】将GraphQL API添加到Postgres数据库的六种简单方法,比较Hasura、Prisma和其他

PostgreSQL是世界上最流行的开源SQL数据库之一&#xff0c;GraphQL是一种日益流行的API规范。 将经过验证和众所周知的PostgreSQL与GraphQL带来的API创建新方式集成在一起不是很好吗&#xff1f; 在本文中&#xff0c;我们讨论了六个不同的项目&#xff0c;它们试图将SQL与Gr…

深入了解Rabbit加密技术:原理、实现与应用

一、引言 在信息时代&#xff0c;数据安全愈发受到重视&#xff0c;加密技术作为保障信息安全的核心手段&#xff0c;得到了广泛的研究与应用。Rabbit加密技术作为一种新型加密方法&#xff0c;具有较高的安全性和便捷性。本文将对Rabbit加密技术进行深入探讨&#xff0c;分析…

【动态规划】LeetCode-面试题08.01三步问题

&#x1f388;算法那些事专栏说明&#xff1a;这是一个记录刷题日常的专栏&#xff0c;每个文章标题前都会写明这道题使用的算法。专栏每日计划至少更新1道题目&#xff0c;在这立下Flag&#x1f6a9; &#x1f3e0;个人主页&#xff1a;Jammingpro &#x1f4d5;专栏链接&…

六、初识FreeRTOS之FreeRTOS的任务挂起和恢复函数介绍

本节需要掌握以下内容&#xff1a; 1&#xff0c;任务的挂起与恢复的API函数&#xff08;熟悉&#xff09; 2&#xff0c;任务挂起与恢复实验&#xff08;掌握&#xff09; 3&#xff0c;课堂总结&#xff08;掌握&#xff09; 一、任务的挂起与恢复的API函数&#xff08;熟…

exceljs读取el-upload上传的excle数据并转为json输出

当使用 Element UI 的 el-upload 组件上传 Excel 文件时&#xff0c;您可以使用 exceljs 库将上传的 Excel 数据转换为 JSON 格式。以下是一个示例代码&#xff0c;演示了如何在 Vue 项目中实现这一功能&#xff1a; <template><el-uploadclass"upload-demo&quo…

C++ day41 动态规划 整数拆分 不同的二叉搜索树

题目1&#xff1a;343 整数拆分 题目链接&#xff1a;整数拆分 对题目的理解 将正整数n&#xff0c;拆分成k个正整数的和&#xff08;k>2&#xff09;使得这些整数的乘积最大化&#xff0c;返回最大乘积 动规五部曲 1&#xff09;dp数组的含义以及其下标i的含义 dp[i]…

Verilog 入门(四)(门电平模型化)

文章目录 内置基本门多输入门简单示例 内置基本门 Verilog HDL 中提供下列内置基本门&#xff1a; 多输入门 and&#xff0c;nand&#xff0c;or&#xff0c;nor&#xff0c;xor&#xff0c;xnor 多输出门 buf&#xff0c;not 三态门上拉、下拉电阻MOS 开关双向开关 门级逻辑…

OSG编程指南<十七>:OSG光照与材质

1、OSG光照 OSG 全面支持 OpenGL 的光照特性&#xff0c;包括材质属性&#xff08;material property&#xff09;、光照属性&#xff08;light property&#xff09;和光照模型&#xff08;lighting model&#xff09;。与 OpenGL 相似&#xff0c;OSG 中的光源也是不可见的&a…

工博会新闻稿汇总

23届工博会媒体报道汇总 点击文章标题即可进入详情页 9月23日&#xff0c;第23届工博会圆满落幕&#xff01;本届工博会规模之大、能级之高、新展品之多创下历史之最。高校展区在规模、能级和展品上均也创下新高。工博会系列报道深入探讨了高校科技发展的重要性和多方面影响。…

【合集】MQ消息队列——Message Queue消息队列的合集文章 RabbitMQ入门到使用

前言 RabbitMQ作为一款常用的消息中间件&#xff0c;在微服务项目中得到大量应用&#xff0c;其本身是微服务中的重点和难点。本篇博客是Message Queue相关的学习博客文章的合集篇&#xff0c;目前主要是RabbitMQ入门到使用文章&#xff0c;后续会扩展其他MQ。 目录 前言一、R…

自定义链 SNAT / DNAT 实验举例

参考原理图 实验前的环境搭建 1. 准备三台虚拟机&#xff0c;定义为内网&#xff0c;外网以及网卡服务器 2. 给网卡服务器添加网卡 3. 将三台虚拟机的防火墙和安全终端全部关掉 systemctl stop firewalld && setenforce 0 4. 给内网虚拟机和外网虚拟机 yum安装 httpd…

阿里云国际短信业务网络超时排障指南

选取一台或多台线上的应用服务器或选取相同网络环境下的机器&#xff0c;执行以下操作。 获取公网出口IP。 curl ifconfig.me 测试连通性。 &#xff08;推荐&#xff09;执行MTR命令&#xff08;可能需要sudo权限&#xff09;&#xff0c;检测连通性&#xff0c;执行30秒。 m…

【华为OD题库-052】数字序列比大小-java

题目 A&#xff0c;B两个人玩一个数字比大小的游戏&#xff0c;在游戏前&#xff0c;两个人会拿到相同长度的两个数字序列&#xff0c;两个数字序列是不完全相同的&#xff0c;且其中的数字是随机的。 A&#xff0c;B各自从数字序列中挑选出一个数字进行大小比较&#xff0c;赢…

Scrapy框架中间件(一篇文章齐全)

1、Scrapy框架初识&#xff08;点击前往查阅&#xff09; 2、Scrapy框架持久化存储&#xff08;点击前往查阅&#xff09; 3、Scrapy框架内置管道&#xff08;点击前往查阅&#xff09; 4、Scrapy框架中间件 Scrapy 是一个开源的、基于Python的爬虫框架&#xff0c;它提供了…

HashMap的实现原理

1.HashMap实现原理 HashMap的数据结构&#xff1a; *底层使用hash表数据结构&#xff0c;即数组链表红黑树 当我们往HashMap中put元素时&#xff0c;利用key的hashCode重新hash计算出当前对象的元素在数组中的下标 存储时&#xff0c;如果出现hash值相同的key&#xff0c;此时…

自动化测试 —— 如何优雅实现方法的依赖!

在 seldom 3.4.0 版本实现了该功能。 在复杂的测试场景中&#xff0c;常常会存在用例依赖&#xff0c;以一个接口自动化平台为例&#xff0c;依赖关系&#xff1a; 创建用例 --> 创建模块 --> 创建项目 --> 登录。 用例依赖的问题 •用例的依赖对于的执行顺序有严格…

SpringBoot——Spring Security 框架

优质博文&#xff1a;IT-BLOG-CN 一、Spring Security 简介 Spring Security是一个能够为基于Spring的企业应用系统提供声明式的安全访问控制解决方案的安全框架。它提供了一组可以在Spring应用上下文中配置的 Bean&#xff0c;充分利用了Spring IoC&#xff0c;DI&#xff0…

什么是 Proxy?

目录 Proxy 的作用 1. 流量过滤 2. 记录日志 3. 加快访问速度 4. 隐藏 IP 地址 Proxy 的分类 1. 按协议分类 - HTTP 代理&#xff1a;只支持 HTTP 协议的代理服务器&#xff0c;它可以缓存 HTTP 请求和响应并过滤 HTTP 流量。 - FTP 代理&#xff1a;只支持 FTP 协议的…

异常数据检测 | Python实现孤立森林(IsolationForest)异常检测

孤立森林(IsolationForest)异常检测 IsolationForest[6]算法它是一种集成算法(类似于随机森林)主要用于挖掘异常(Anomaly)数据,或者说离群点挖掘,总之是在一大堆数据中,找出与其它数据的规律不太符合的数据。该算法不采样任何基于聚类或距离的方法,因此他和那些基于距离的的…