NLP(8)--利用RNN实现多分类任务

前言

仅记录学习过程,有问题欢迎讨论

循环神经网络RNN(recurrent neural network):
  • 主要思想:将整个序列划分成多个时间步,将每一个时间步的信息依次输入模型,同时将模型输出的结果传给下一个时间步
  • 自带了tanh的激活函数

代码

发现RNN效率高很多

import json
import randomimport numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import torch.utils.data as Data"""
构建一个 用RNN实现的 判断某个字符的位置 的任务5 分类任务 判断 a出现的位置 返回index +1 or -1
"""class TorchModel(nn.Module):def __init__(self, sentence_length, hidden_size, vocab, input_dim, output_size):super(TorchModel, self).__init__()#self.emb = nn.Embedding(len(vocab) + 1, input_dim)self.rnn = nn.RNN(input_dim, hidden_size, batch_first=True)self.pool = nn.MaxPool1d(sentence_length)self.leaner = nn.Linear(hidden_size, output_size)self.loss = nn.functional.cross_entropydef forward(self, x, y=None):# x = 15 * 4x = self.emb(x)  # output = 15 * 4 * 10x, h = self.rnn(x)  # output = 15 * 4 * 20 h = 1*15*20x = self.pool(x.transpose(1, 2)).squeeze()  # output = 15 * 20 * (1,被去除)y_pred = self.leaner(x)  # output = 15 * 5if y is not None:return self.loss(y_pred, y)else:return y_pred# 创建字符集 只有6个 希望a出现的概率大点def build_vocab():chars = "abcdef"vocab = {}for index, char in enumerate(chars):vocab[char] = index + 1# vocab['unk'] = len(vocab) + 1return vocab# 构建样本集
def build_dataset(vocab, data_size, sentence_length):dataset_x = []dataset_y = []for i in range(data_size):x, y = build_simple(vocab, sentence_length)dataset_x.append(x)dataset_y.append(y)return torch.LongTensor(dataset_x), torch.LongTensor(dataset_y)# 构建样本
def build_simple(vocab, sentence_length):# 随机生成 长度为4的字符串x = [random.choice(list(vocab.keys())) for _ in range(sentence_length)]if x.count('a') != 0:y = x.index('a')else:y = 4# 转化为 数字x = [vocab[char] for char in list(x)]return x, ydef main():batch_size = 15simple_size = 500vocab = build_vocab()# 每个样本的长度为4sentence_length = 4# 样本的向量维度为10input_dim = 10# rnn的隐藏层 随便设置为20hidden_size = 20# 5 分类任务output_size = 5# 学习率lr = 0.02# 轮次epoch_size = 25model = TorchModel(sentence_length, hidden_size, vocab, input_dim, output_size)# 优化函数optim = torch.optim.Adam(model.parameters(), lr=lr)# 样本x, y = build_dataset(vocab, simple_size, sentence_length)dataset = Data.TensorDataset(x, y)dataiter = Data.DataLoader(dataset, batch_size)for epoch in range(epoch_size):epoch_loss = []model.train()for x, y_true in dataiter:loss = model(x, y_true)loss.backward()optim.step()optim.zero_grad()epoch_loss.append(loss.item())print("第%d轮 loss = %f" % (epoch + 1, np.mean(epoch_loss)))# evaluateacc = evaluate(model, vocab, sentence_length)  # 测试本轮模型结果return# 评估效果
def evaluate(model, vocab, sentence_length):model.eval()x, y = build_dataset(vocab, 200, sentence_length)correct, wrong = 0, 0with torch.no_grad():y_pred = model(x)for y_p, y_t in zip(y_pred, y):  # 与真实标签进行对比if int(torch.argmax(y_p)) == int(y_t):correct += 1  # 正样本判断正确else:wrong += 1print("正确预测个数:%d / %d, 正确率:%f" % (correct, correct + wrong, correct / (correct + wrong)))return correct / (correct + wrong)if __name__ == '__main__':main()

可以对model 优化一下

 def __init__(self, sentence_length, hidden_size, vocab, input_dim, output_size):super(TorchModel, self).__init__()# Embedding 层 变为稀疏self.emb = nn.Embedding(len(vocab) + 1, input_dim)self.rnn = nn.RNN(input_dim, input_dim, batch_first=True)self.pool = nn.AvgPool1d(sentence_length)self.leaner = nn.Linear(input_dim, sentence_length + 1)self.loss = nn.functional.cross_entropydef forward(self, x, y=None):# x = 15 * 4x = self.emb(x)  # output = 15 * 4 * 10x, h = self.rnn(x)  # output = 15 * 4 * 20 h = 1*15*20# x = self.pool(x.transpose# (1, 2)).squeeze()  # output = 15 * 20 * (1,被去除)# rnn 最后一维度包含之前所有信息h = h.squeeze()y_pred = self.leaner(h)  # output = 15 * 5if y is not None:return self.loss(y_pred, y)else:return y_pred

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/3574.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

JavaEE——spring MVC请求处理

目录 主要目的: 1. Spring web 项目搭建 2. 添加依赖 3. 配置插件 4. 配置设置类 5. 编写controller层类 6. 编写测试的http请求 主要目的: 创建一个spring web项目; 创建控制类; 掌握如何配置MVC; 编写htt…

【机器学习-18】特征筛选:提升模型性能的关键步骤

一、引言 在机器学习领域,特征筛选是一个至关重要的预处理步骤。随着数据集的日益庞大和复杂,特征的数量往往也随之激增。然而,并非所有的特征都对模型的性能提升有所贡献,有些特征甚至可能是冗余的、噪声较大的或者与目标变量无关…

在Visual Studio中查看C项目使用的C语言版本

在Visual Studio中查看C项目使用的C语言版本,可以通过以下步骤进行: 打开Visual Studio。 打开你的C项目。 右键点击项目名称,选择“属性”。 在弹出的属性页中,找到“配置属性” -> “C/C” -> “语言”。 在右侧的“…

(十三)PostgreSQL的扩展(extensions)

PostgreSQL的扩展(extensions) 基础信息 OS版本:Red Hat Enterprise Linux Server release 7.9 (Maipo) DB版本:16.2 pg软件目录:/home/pg16/soft pg数据目录:/home/pg16/data 端口:5777在Post…

Django框架之python后端框架介绍

一、网络框架及MVC、MTV模型 1、网络框架 网络框架(Web framework)是一种软件框架,用于帮助开发人员构建Web应用程序和Web服务。它提供了一系列预先编写好的代码和工具,以简化开发过程并提高开发效率。网络框架通常包括以下功能…

常用组件(启停活动页面、活动之间传递信息、收发应用广播、操作后台服务)

启停活动页面 Activity的启动和结束 页面跳转可以使用startActivity接口,具体格式为startActivity(new Intent(this, 目标页面.class));。 关闭一个页面可以直接调用finish();方法即可退出页面。 Activity的生命周期 页面在安卓有个新的名字叫活动,因…

充电机是什么?其技术原理和行业应用

充电机是一种能够为电池充电的设备,通常由一个变压器和整流器组成。变压器将电网中的交流电转换为直流电,而整流器则将直流电转换为稳定的直流电,这种直流电可以被用来给电池充电。 充电机可以分为很多种不同类型,包括输入输出式、输入输出隔离式和车载充电机等。不同类型的充…

Vue2与Vue3实例的深入比较:响应式系统、模板编译和性能分析

I. 响应式系统的差异 A. Vue2的响应式系统 数据劫持(Object.defineProperty) Vue2的核心响应式机制依赖于JavaScript的Object.defineProperty方法。这个方法允许开发者为对象的属性提供getter和setter,从而实现对属性访问和修改的监控。当…

GoLand 2021.1.3 下载与安装

当前环境:Windows 8.1 x64 1 浏览器打开网站 https://www.jetbrains.com/go/download/other.html 找到 2021.1.3 版本。 2 解压 goland-2021.1.3.win.zip 到 goland-2021.1.3.win。 3 打开 bin 目录下的 goland64.exe,选择 Evaluate for free -- Evalu…

论文解读-面向高效生成大语言模型服务:从算法到系统综述

一、简要介绍 在快速发展的人工智能(AI)领域中,生成式大型语言模型(llm)站在了最前沿,彻底改变了论文与数据交互的方式。然而,部署这些模型的计算强度和内存消耗在服务效率方面带来了重大挑战&a…

Linux CentOS 7 服务器集群硬件常用查看命令

(一)查看内核:uname -a [rootcdh1 ~]# uname -a Linux cdh1.macro.com 3.10.0-1062.el7.x86_64 #1 SMP Wed Aug 7 18:08:02 UTC 2019 x86_64 x86_64 x86_64 GNU/Linux(二)查看系统:cat /etc/redhat-releas…

react-创建组件的两种方式

一、函数式组件 <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>hello_react</title> </he…

ABS8-ASEMI新能源专用整流桥ABS8

编辑&#xff1a;ll ABS8-ASEMI新能源专用整流桥ABS8 型号&#xff1a;KBL410 品牌&#xff1a;ASEMI 封装&#xff1a;ABS-4 最大重复峰值反向电压&#xff1a;800V 最大正向平均整流电流(Vdss)&#xff1a;1A 功率(Pd)&#xff1a;小功率 芯片个数&#xff1a;4 引脚…

夜神、雷电、android studio手机模拟器资源占用情况

夜神、雷电、android studio手机模拟器内存资源占用情况 由于开发电脑只有16G内存&#xff0c;出于开发需要和本身硬件资源的限制&#xff0c;对多个手机模拟器进行了机器资源占用&#xff08;主要是内存&#xff09;的简单比较。 比较的模拟器包括&#xff1a; 1. Android S…

vue2知识点————(vue插槽,透传 Attributes )

vue 插槽 插槽&#xff08;slot&#xff09;是一种强大的特性&#xff0c;允许在组件的模板中定义带有特定用途的“插槽”&#xff0c;然后在组件的使用者中填充内容。插槽能够使组件更加灵活&#xff0c;让组件的结构更容易复用和定 具名插槽&#xff08;Named Slots&#x…

PHP利用phpmailer实现邮件发送功能

要在PHP中实现发送邮件验证码的功能,你需要使用一些特定的库来帮助你处理邮件发送的任务。PHPMailer是一个常用的库,它可以帮助你轻松地发送电子邮件。 以下是一个简单的例子,展示了如何使用PHPMailer库来发送包含验证码的电子邮件: 首先,你需要安装PHPMailer库。你可以通…

微信小程序有的机型无法播放m3u8格式的直播流,使用H5在微信环境里播放

我这测试鸿蒙的还有苹果X及部分机型在微信小程序里无法播放&#xff0c;不知道什么原因&#xff1b; 直播流地址有的是hevc有的是h.264&#xff0c;音频都是aac&#xff1b; <head><meta charset"UTF-8"><title>前端播放m3u8格式视频</title&g…

MATLAB 向量

MATLAB 向量 向量是一维数字数组。MATLAB允许创建两种类型的向量 行向量 列向量 行向量 行向量通过将元素集括在方括号中并使用空格或逗号定界元素来创建。 示例 r [7 8 9 10 11] MATLAB将执行上述语句并返回以下结果- r 7 8 9 10 11 列向量 列向量 通过将元素集括在方…

c++ 新特性 std::bind 简单实验

1.概要 std::bind 是 C11 引入的一个功能&#xff0c;它用于绑定函数/可调用对象到特定的参数&#xff0c;并生成一个新的可调用对象。这个新的可调用对象可以稍后调用&#xff0c;就像调用原始函数/可调用对象一样&#xff0c;但会带有预先绑定的参数。 std::placeholders::…

操作系统安全:Linux安全审计,Linux日志详解

「作者简介」&#xff1a;2022年北京冬奥会网络安全中国代表队&#xff0c;CSDN Top100&#xff0c;就职奇安信多年&#xff0c;以实战工作为基础对安全知识体系进行总结与归纳&#xff0c;著作适用于快速入门的 《网络安全自学教程》&#xff0c;内容涵盖系统安全、信息收集等…