基于pytorch 的RNN实现文本分类

首先,需要导入必要的库,包括torch、torchtext、numpy等:

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torchtext.datasets import AG_NEWS
from torchtext.data.utils import get_tokenizer
from collections import Counter

然后,我们需要加载数据集并进行数据预处理。在这里,我们使用AG News数据集,其中包含120,000个新闻文本,分为四个不同的类别:World、Sports、Business和Sci/Tech。我们首先定义一个函数来预处理数据:

# 加载数据集
train_dataset, test_dataset = AG_NEWS()# 定义tokenizer,用于将文本转换为单词列表
tokenizer = get_tokenizer('basic_english')# 定义函数preprocess,用于将文本转换为数值向量
def preprocess(dataset):# 定义空列表,用于存放文本data = []# 遍历数据集中的每个样本for (label, text) in dataset:# 将文本转换为单词列表tokens = tokenizer(text)# 将单词列表转换为数值向量vector = [vocab.stoi[token] for token in tokens]# 将标签和数值向量打包成元组,并添加到data列表中data.append((label, torch.tensor(vector)))return data# 统计数据集中所有单词的出现频率,并将出现频率最高的50000个单词作为词汇表
counter = Counter()
for (label, text) in train_dataset:tokens = tokenizer(text)counter.update(tokens)
vocab = torchtext.vocab.Vocab(counter, max_size=50000)# 使用preprocess函数将数据集转换为数值向量形式
train_data = preprocess(train_dataset)
test_data = preprocess(test_dataset)

接下来,我们定义一个RNN模型,用于对文本进行分类。这里我们使用LSTM作为我们的RNN模型,并将其应用于文本分类任务。LSTM是一种特殊的RNN模型,它能够在处理长序列时更好地保留先前的信息。下面是代码:

class LSTMModel(nn.Module):def __init__(self, input_dim, hidden_dim, output_dim):super(LSTMModel, self).__init__()self.embedding = nn.Embedding(input_dim, hidden_dim)self.lstm = nn.LSTM(hidden_dim, hidden_dim)self.fc = nn.Linear(hidden_dim, output_dim)def forward(self, x):# 将输入x的每个元素(即每个数值向量)通过embedding层转换为向量embedded = self.embedding(x)# 将embedding后的向量输入到LSTM中output, (hidden, cell) = self.lstm(embedded)# 取LSTM的最后一个输出作为模型的输出prediction = self.fc(hidden[-1])return prediction

在上面的代码中,我们首先定义了一个名为LSTMModel的类,它继承自nn.Module类。在__init__中,我们定义了三个层:embedding层、LSTM层和全连接层(也称为线性层)。embedding层用于将输入的数值向量转换为向量表示,LSTM层用于在处理序列数据时保留先前的信息,全连接层用于将LSTM输出转换为预测标签。

在forward函数中,我们首先通过embedding层将输入x转换为向量表示,然后将其输入到LSTM中。由于LSTM是一种可以处理序列数据的RNN模型,因此它能够保留先前的信息,并生成一个输出向量。在这里,我们选择使用LSTM的最后一个输出作为模型的输出向量。最后,我们将输出向量输入到全连接层中,以生成最终的预测标签。

接下来,我们需要训练我们的模型。我们首先定义一个函数,用于计算模型在测试集上的准确率:

def evaluate(model, data):correct = 0total = 0with torch.no_grad():for (label, text) in data:output = model(text.unsqueeze(0)) # 将输入张量增加一维,以便输入模型predicted = torch.argmax(output.squeeze()) # 取最大值作为预测结果if predicted == label:correct += 1total += 1return correct / total

在上面的代码中,我们定义了一个名为evaluate的函数,该函数接受一个模型和数据作为输入,并返回模型在数据上的准确率。在函数中,我们首先将输入张量的维度增加一维,以便输入到模型中。然后,我们使用torch.argmax函数找到输出向量中的最大值,并将其作为预测结果。最后,我们计算模型在测试集上的准确率。

现在我们可以开始训练我们的模型了。我们首先定义一些超参数:

input_dim = len(vocab)
hidden_dim = 128
output_dim = 4
batch_size = 64
learning_rate = 0.001
num_epochs = 5

这里,我们定义了词汇表的大小、隐藏层的维度、输出维度、批次大小、学习率和训练轮数等超参数。

接下来,我们实例化我们的模型,并定义损失函数和优化器:

model = LSTMModel(input_dim, hidden_dim, output_dim)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

在上面的代码中,我们实例化了我们的模型LSTMModel,并定义了损失函数CrossEntropyLoss和优化器Adam。

现在,我们可以开始训练我们的模型了。对于每个epoch,我们将训练集分成若干个小批次,并对每个小批次进行训练。在每个小批次训练结束后,我们将测试集输入到我们的模型中,并计算模型的准确率。最后,我们输出每个epoch的损失和准确率:

for epoch in range(num_epochs):np.random.shuffle(train_data)train_loss = 0train_correct = 0train_total = 0for i in range(0, len(train_data), batch_size):batch = train_data[i:i+batch_size]labels, texts = zip(*batch)labels = torch.tensor(labels)texts = nn.utils.rnn.pad_sequence(texts, batch_first=True)optimizer.zero_grad()output = model(texts)loss = criterion(output, labels)loss.backward()optimizer.step()train_loss += loss.item() * len(batch)train_correct += torch.sum(torch.argmax(output, dim=1) == labels).item()train_total += len(batch)train_accuracy = train_correct / train_totaltest_accuracy = evaluate(model, test_data)print('Epoch [%d/%d], Loss: %.4f, Train Acc: %.4f, Test Acc: %.4f'% (epoch+1, num_epochs, train_loss / len(train_data),train_accuracy, test_accuracy))

在上面的代码中,我们使用np.random.shuffle函数对训练数据进行随机化处理,并按照batch_size的大小将其分成若干个小批次。在每个小批次训练结束后,我们将记录损失值、训练集准确率和测试集准确率。最后,我们输出每个epoch的损失和准确率。

到此,我们就完成了基于PyTorch的RNN实现文本分类的代码和解释。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/219594.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

视频推拉流平台EasyDSS点播文件播放请求添加token验证的实现方法

EasyDSS视频直播点播平台可提供一站式的视频推拉流、转码、点播、直播、播放H.265编码视频等服务,搭配RTMP高清摄像头使用,可将设备的实时流推送到平台上,实现无人机视频推流直播等应用。今天我们来介绍下EasyDSS系统点播文件播放请求添加tok…

Linux---创建、删除文件及目录命令

1. 创建、删除文件及目录命令的使用 命令说明touch 文件名创建指定文件mkdir 目录名创建目录(文件夹)rm 文件名或者目录名删除指定文件或者目录rmdir 目录名删除空目录 touch命令效果图: mkdir命令效果图: rm命令效果图: rm删除目录效果图 说明: rm命令想要删除目录需要加上…

MATLAB算法实战应用案例精讲-【数模应用】漫谈机器学习(八)

目录 几个相关概念 01特征统计 02概率分布 03 降维 04 过采样和欠采样

【工具】VUE 前端列表拖拽功能代码

【工具】VUE 前端列表拖拽功能代码 使用组件 yarn add sortablejs --save Sortable.js中文网 (sortablejs.com) 以下代码只是举个例子&#xff0c; 大家可以举一反三去实现各自的业务功能 <template><div><el-button type"primary" click"切换…

HTML---表单

文章目录 前言一、pandas是什么&#xff1f;二、使用步骤 1.引入库2.读入数据总结 一.表单概念 HTML表单是网页上用于收集用户输入信息的一种元素。它由一系列输入字段&#xff08;input&#xff09;、选择字段&#xff08;select&#xff09;、文本区域&#xff08;textarea&a…

缓存雪崩问题与应对策略

目录 1. 缓存雪崩的原因 1.1 缓存同时失效 1.2 缓存层无法应对高并发 1.3 缓存和后端系统之间存在紧密关联 2. 缓存雪崩的影响 2.1 系统性能下降 2.2 数据库压力激增 2.3 用户请求失败率增加 3. 应对策略 3.1 多级缓存 3.2 限流与降级 3.3 异步缓存更新 3.4 并发控…

​Linux Ubuntu环境下安装配置Docker 和Docker、compose、mysql、中文版portainer

​Linux Ubuntu环境下安装配置Docker 和Docker、compose、mysql、中文版portainer 这篇文章探讨了在Linux Ubuntu环境下安装和配置Docker及其相关工具的过程。首先介绍了Docker的基本概念&#xff0c;然后详细讲解了在Ubuntu系统上的安装步骤。随后&#xff0c;文章涵盖了Dock…

智能优化算法应用:基于旗鱼算法3D无线传感器网络(WSN)覆盖优化 - 附代码

智能优化算法应用&#xff1a;基于旗鱼算法3D无线传感器网络(WSN)覆盖优化 - 附代码 文章目录 智能优化算法应用&#xff1a;基于旗鱼算法3D无线传感器网络(WSN)覆盖优化 - 附代码1.无线传感网络节点模型2.覆盖数学模型及分析3.旗鱼算法4.实验参数设定5.算法结果6.参考文献7.MA…

yaml 文件格式

yaml文件&#xff1a;是一种标记语言&#xff0c;以竖列形式展示序列化的时间格式&#xff0c;可读性高 类似于json格式。语法简单。 yaml通过缩进来表示数据结构&#xff0c;连续的项目用-减号来表示。 yaml文件使用的注意事项&#xff1a; 1&#xff0c;大小写敏感 2&am…

Apache Web 服务器监控工具

将Apache Web 服务器监控纳入 IT 基础架构管理策略有助于先发制人地识别性能瓶颈&#xff0c;这种主动监控方法提供必要的数据&#xff0c;以确保 Web 服务器能够胜任任务&#xff0c;并在需要时进行优化。保证客户获得流畅、无忧的用户体验可以大大有助于巩固他们对组织的信任…

SSL证书过期怎么更新?

一、概述 SSL证书是用于加密网站和客户端之间通信的一种数字证书&#xff0c;可以确保数据传输的安全性和保密性。然而&#xff0c;SSL证书是有有效期的&#xff0c;一旦过期就需要及时更新。本文将介绍如何更新SSL证书&#xff0c;以确保网站的安全性和正常运行。 二、SSL证…

【字符串】ABC324E

退役啦&#xff0c;接下来的博客全是图一乐啦 E - Joint Two Strings 题意 思路 统计两个指针的方案数一定是枚举一个&#xff0c;统计另一个 然后因为拼起来之后要包含 t 这个字符串&#xff0c;隐隐约约会感觉到和前缀后缀子序列有关 考虑预处理每个 s[i] 的最长公共前…

gRPC-Gateway:高效转换 RESTful 接口 | 开源日报 No.105

grpc-ecosystem/grpc-gateway Stars: 16.4k License: BSD-3-Clause gRPC-Gateway 是一个遵循 gRPC HTTP 规范的 gRPC 到 JSON 代理生成器。它是 Google 协议缓冲编译器 protoc 的插件&#xff0c;可以读取 protobuf 服务定义并生成反向代理服务器&#xff0c;将 RESTful HTTP…

pycharm中如何去除波浪线的设置

pycharm中&#xff0c;碰到恼人的红绿波浪线&#xff0c;打开’file-settings’&#xff0c;然后&#xff0c;参照如图设置&#xff0c;去除’effects’选项&#xff1a;

【Linux服务器Java环境搭建】09 在CentOS系统中安装和配置clickhouse数据库

一、安装环境 CentOS7 二、官网安装参考文档 官网安装参考文档 不同系统请参考如下建议 从RPM软件包安装&#xff1a; 建议在CentOS、RedHat和所有其他基于rpm的Linux发行版上使用官方预编译的rpm软件包从DEB软件包安装&#xff1a; 建议在Debian或Ubuntu上使用官方预编译…

C语言 联合体验证 主机字节序 +枚举

联合体应用&#xff1a;验证当前主机的大小端&#xff08;字节序&#xff09; //验证当前主机的大小端 #include <stdio.h>union MyData {unsigned int data;struct{unsigned char byte0;unsigned char byte1;unsigned char byte2;unsigned char byte3;}byte; };int main…

详细说说vuex

Vuex 是什么 Vuex有几个属性及作用注意事项vuex 使用举例Vuex3和Vuex4有哪些区别 创建 Store 的方式在组件中使用 Store辅助函数的用法响应式的改进Vuex4 支持多例模式 Vuex 是什么 Vuex是一个专门为Vue.js应用设计的状态管理构架&#xff0c;它统一管理和维护各个Vue组件的可…

lua脚本的基本语法,以及Redis中简单使用

Lua 脚本的基本语法如下&#xff1a; 变量与赋值&#xff1a; variable value变量名可以是字母、数字和下划线的组合&#xff0c;以字母或下划线开头。Lua 是动态类型语言&#xff0c;无需事先声明变量类型。 控制结构&#xff1a; a) 条件语句&#xff1a; if condition the…

【深度学习】Pytorch 系列教程(一):PyTorch数据结构:1、Tensor(张量)及其维度(Dimensions)、数据类型(Data Types)

文章目录 一、前言二、实验环境三、PyTorch数据结构0、分类1、Tensor&#xff08;张量&#xff09;1. 维度&#xff08;Dimensions&#xff09;0维&#xff08;标量&#xff09;1维&#xff08;向量&#xff09;2维&#xff08;矩阵&#xff09;3维张量 2. 数据类型&#xff08…

【AI应用】在VSCode中集成AI编程 ------CodeGeeX智能编程助手

本专栏主要记录人工智能的应用方面的内容,包括chatGPT、AI绘图等等; 在当今AI的热潮下,不学习AI,就要被AI淘汰;所以欢迎小伙伴加入本专栏和我一起探索AI的应用,通过AI来帮助自己提升生产力; 订阅后可私聊我获取 《从零注册并登录使用ChatGPT》《从零开始使用chatGPT的AP…