昇思训练营打卡第二十五天(RNN实现情感分类)

RNN,即循环神经网络(Recurrent Neural Network),是一种深度学习模型,特别适用于处理序列数据。以下是对RNN的简要介绍:

RNN的特点:

  1. 记忆性:与传统的前馈神经网络不同,RNN具有内部状态(记忆),可以捕获到目前为止观察到的序列信息。
  2. 参数共享:在处理序列的不同时间步时,RNN使用相同的权重,这意味着模型的参数数量不会随着输入序列长度的增加而增加。
  3. 灵活性:RNN能够处理任意长度的输入序列。

RNN的结构:

  • 输入层:接收序列中的单个元素。
  • 隐藏层:包含循环单元,这些单元具有记忆功能,能够存储之前的信息。
  • 输出层:根据当前输入和隐藏层的状态输出结果。

RNN的类型:

  1. 简单RNN:基础模型,但容易受到梯度消失和梯度爆炸问题的影响。
  2. LSTM(长短期记忆网络):通过引入门控机制,解决了简单RNN的长期依赖问题。
  3. GRU(门控循环单元):LSTM的变体,结构更简单,但性能相似。

应用场景:

  • 自然语言处理:如语言模型、机器翻译、文本生成等。
  • 语音识别:将语音信号转换为文本。
  • 时间序列预测:如股票价格预测、天气预报等。

数据下载模块

import os
import shutil
import requests
import tempfile
from tqdm import tqdm
from typing import IO
from pathlib import Path# 指定保存路径为 `home_path/.mindspore_examples`
cache_dir = Path.home() / '.mindspore_examples'def http_get(url: str, temp_file: IO):"""使用requests库下载数据,并使用tqdm库进行流程可视化"""req = requests.get(url, stream=True)content_length = req.headers.get('Content-Length')total = int(content_length) if content_length is not None else Noneprogress = tqdm(unit='B', total=total)for chunk in req.iter_content(chunk_size=1024):if chunk:progress.update(len(chunk))temp_file.write(chunk)progress.close()def download(file_name: str, url: str):"""下载数据并存为指定名称"""if not os.path.exists(cache_dir):os.makedirs(cache_dir)cache_path = os.path.join(cache_dir, file_name)cache_exist = os.path.exists(cache_path)if not cache_exist:with tempfile.NamedTemporaryFile() as temp_file:http_get(url, temp_file)temp_file.flush()temp_file.seek(0)with open(cache_path, 'wb') as cache_file:shutil.copyfileobj(temp_file, cache_file)return cache_pathimdb_path = download('aclImdb_v1.tar.gz', 'https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/aclImdb_v1.tar.gz')
imdb_path

加载IMDB数据集

import re
import six
import string
import tarfileclass IMDBData():"""IMDB数据集加载器加载IMDB数据集并处理为一个Python迭代对象。"""label_map = {"pos": 1,"neg": 0}def __init__(self, path, mode="train"):self.mode = modeself.path = pathself.docs, self.labels = [], []self._load("pos")self._load("neg")def _load(self, label):pattern = re.compile(r"aclImdb/{}/{}/.*\.txt$".format(self.mode, label))# 将数据加载至内存with tarfile.open(self.path) as tarf:tf = tarf.next()while tf is not None:if bool(pattern.match(tf.name)):# 对文本进行分词、去除标点和特殊字符、小写处理self.docs.append(str(tarf.extractfile(tf).read().rstrip(six.b("\n\r")).translate(None, six.b(string.punctuation)).lower()).split())self.labels.append([self.label_map[label]])tf = tarf.next()def __getitem__(self, idx):return self.docs[idx], self.labels[idx]def __len__(self):return len(self.docs)
imdb_train = IMDBData(imdb_path, 'train')
len(imdb_train)
import mindspore.dataset as dsdef load_imdb(imdb_path):imdb_train = ds.GeneratorDataset(IMDBData(imdb_path, "train"), column_names=["text", "label"], shuffle=True, num_samples=10000)imdb_test = ds.GeneratorDataset(IMDBData(imdb_path, "test"), column_names=["text", "label"], shuffle=False)return imdb_train, imdb_test
imdb_train, imdb_test = load_imdb(imdb_path)
imdb_train

加载预训练词向量

预训练词向量是对输入单词的数值化表示,通过nn.Embedding层,采用查表的方式,输入单词对应词表中的index,获得对应的表达向量。 因此进行模型构造前,需要将Embedding层所需的词向量和词表进行构造。

import zipfile
import numpy as npdef load_glove(glove_path):glove_100d_path = os.path.join(cache_dir, 'glove.6B.100d.txt')if not os.path.exists(glove_100d_path):glove_zip = zipfile.ZipFile(glove_path)glove_zip.extractall(cache_dir)embeddings = []tokens = []with open(glove_100d_path, encoding='utf-8') as gf:for glove in gf:word, embedding = glove.split(maxsplit=1)tokens.append(word)embeddings.append(np.fromstring(embedding, dtype=np.float32, sep=' '))# 添加 <unk>, <pad> 两个特殊占位符对应的embeddingembeddings.append(np.random.rand(100))embeddings.append(np.zeros((100,), np.float32))vocab = ds.text.Vocab.from_list(tokens, special_tokens=["<unk>", "<pad>"], special_first=False)embeddings = np.array(embeddings).astype(np.float32)return vocab, embeddings
glove_path = download('glove.6B.zip', 'https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/glove.6B.zip')
vocab, embeddings = load_glove(glove_path)
len(vocab.vocab())
idx = vocab.tokens_to_ids('the')
embedding = embeddings[idx]
idx, embedding

数据集预处理

通过加载器加载的IMDB数据集进行了分词处理,但不满足构造训练数据的需要,因此要对其进行额外的预处理。其中包含的预处理如下:

  • 通过Vocab将所有的Token处理为index id。
  • 将文本序列统一长度,不足的使用<pad>补齐,超出的进行截断。
  • import mindspore as mslookup_op = ds.text.Lookup(vocab, unknown_token='<unk>')
    pad_op = ds.transforms.PadEnd([500], pad_value=vocab.tokens_to_ids('<pad>'))
    type_cast_op = ds.transforms.TypeCast(ms.float32)
    imdb_train = imdb_train.map(operations=[lookup_op, pad_op], input_columns=['text'])
    imdb_train = imdb_train.map(operations=[type_cast_op], input_columns=['label'])imdb_test = imdb_test.map(operations=[lookup_op, pad_op], input_columns=['text'])
    imdb_test = imdb_test.map(operations=[type_cast_op], input_columns=['label'])
    imdb_train, imdb_valid = imdb_train.split([0.7, 0.3])
    imdb_train = imdb_train.batch(64, drop_remainder=True)
    imdb_valid = imdb_valid.batch(64, drop_remainder=True)

    Embedding

    Embedding层又可称为EmbeddingLookup层,其作用是使用index id对权重矩阵对应id的向量进行查找,当输入为一个由index id组成的序列时,则查找并返回一个相同长度的矩阵

  • RNN(循环神经网络)

    循环神经网络(Recurrent Neural Network, RNN)是一类以序列(sequence)数据为输入,在序列的演进方向进行递归(recursion)且所有节点(循环单元)按链式连接的神经网络。

Dense

在经过LSTM编码获取句子特征后,将其送入一个全连接层,即nn.Dense,将特征维度变换为二分类所需的维度1,经过Dense层后的输出即为模型预测结果。

import math
import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore.common.initializer import Uniform, HeUniformclass RNN(nn.Cell):def __init__(self, embeddings, hidden_dim, output_dim, n_layers,bidirectional, pad_idx):super().__init__()vocab_size, embedding_dim = embeddings.shapeself.embedding = nn.Embedding(vocab_size, embedding_dim, embedding_table=ms.Tensor(embeddings), padding_idx=pad_idx)self.rnn = nn.LSTM(embedding_dim,hidden_dim,num_layers=n_layers,bidirectional=bidirectional,batch_first=True)weight_init = HeUniform(math.sqrt(5))bias_init = Uniform(1 / math.sqrt(hidden_dim * 2))self.fc = nn.Dense(hidden_dim * 2, output_dim, weight_init=weight_init, bias_init=bias_init)def construct(self, inputs):embedded = self.embedding(inputs)_, (hidden, _) = self.rnn(embedded)hidden = ops.concat((hidden[-2, :, :], hidden[-1, :, :]), axis=1)output = self.fc(hidden)return output
hidden_size = 256
output_size = 1
num_layers = 2
bidirectional = True
lr = 0.001
pad_idx = vocab.tokens_to_ids('<pad>')model = RNN(embeddings, hidden_size, output_size, num_layers, bidirectional, pad_idx)
loss_fn = nn.BCEWithLogitsLoss(reduction='mean')
optimizer = nn.Adam(model.trainable_params(), learning_rate=lr)
def forward_fn(data, label):logits = model(data)loss = loss_fn(logits, label)return lossgrad_fn = ms.value_and_grad(forward_fn, None, optimizer.parameters)def train_step(data, label):loss, grads = grad_fn(data, label)optimizer(grads)return lossdef train_one_epoch(model, train_dataset, epoch=0):model.set_train()total = train_dataset.get_dataset_size()loss_total = 0step_total = 0with tqdm(total=total) as t:t.set_description('Epoch %i' % epoch)for i in train_dataset.create_tuple_iterator():loss = train_step(*i)loss_total += loss.asnumpy()step_total += 1t.set_postfix(loss=loss_total/step_total)t.update(1)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/45596.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

老板新招的牛人,竟然用1天搭建了一套完整的仓库管理系统!

仓储管理系统是什么&#xff1f; 仓储管理系统&#xff08;WMS&#xff09;是一个全面的软件解决方案&#xff0c;旨在帮助企业优化仓库管理流程、管理和控制日常仓库运营。通过数学模型和信息手段&#xff0c;对仓库管理的各个环节进行优化和调控&#xff0c;涵盖了从货物入库…

使用网关和Spring Security进行认证和授权

个人名片 &#x1f393;作者简介&#xff1a;java领域优质创作者 &#x1f310;个人主页&#xff1a;码农阿豪 &#x1f4de;工作室&#xff1a;新空间代码工作室&#xff08;提供各种软件服务&#xff09; &#x1f48c;个人邮箱&#xff1a;[2435024119qq.com] &#x1f4f1…

jquery发送jsonp请求

使用 jQuery 发送 JSONP 请求相对来说比较简单&#xff0c;以下是示例代码&#xff1a; $.ajax({url: "http://example.com/data",dataType: "jsonp",jsonp: "callback",jsonpCallback: "myCallback" }).done(function(response) {//…

Linux命令更新-sort 和 uniq 命令

简介 sort 和 uniq 都是 Linux 系统中常用的文本处理命令。 sort 命令用于对文件内容进行排序。 uniq 命令用于去除文件中重复出现的行。 1. sort 命令 命令格式 sort [选项] [文件]选项&#xff1a; -n: 按照数字进行排序 -r: 反向排序 -c: 统计每个元素出现的次数 -…

怎么录制视频?电脑录制,试试这3种方法

在数字化快速发展的时代&#xff0c;视频已经成为我们传递信息、分享生活、表达情感的重要载体。每一个人都希望自己能够掌握视频录制技巧&#xff0c;轻松驾驭影像的力量&#xff0c;创造出属于自己的视觉盛宴。 那么&#xff0c;怎么录制视频呢&#xff1f;首先选择一款好用…

vue脚手架配置代理请求

在 Vue 脚手架中&#xff0c;可以通过配置vue.config.js文件来设置代理请求&#xff0c;以解决跨域问题或实现其他代理需求。以下是两种常见的配置方式&#xff1a; 方法一&#xff1a; 在vue.config.js中添加如下配置&#xff1a; module.exports {devServer: {proxy: http…

《信息与电脑(理论版)》是什么级别的期刊?是正规期刊吗?能评职称吗?

问题解答 问&#xff1a;《信息与电脑(理论版)》是不是核心期刊&#xff1f; 答&#xff1a;不是&#xff0c;是知网收录的正规学术期刊。 问&#xff1a;《信息与电脑(理论版)》级别&#xff1f; 答&#xff1a;省级。主管单位&#xff1a;北京电子控股有限责任公司 主办…

AI安全入门-人工智能数据与模型安全

参考 人工智能数据与模型安全 from 复旦大学视觉与学习实验室 文章目录 0. 计算机安全学术知名公众号1. 概述数据安全模型安全 3. 人工智能安全基础3.1 基本概念攻击者攻击方法受害者受害数据受害模型防御者防御方法威胁模型目标数据替代数据替代模型 3.2 威胁模型3.2.1 白盒威…

实践致知第16享:设置Word中某一页横着的效果及操作

一、背景需求 小姑电话说&#xff1a;现在有个word文档,里面有个表格太长&#xff08;如下图所示&#xff09;&#xff0c;希望这一个设置成横的&#xff0c;其余页还是保持竖的&#xff01; 二、解决方案 1、将鼠标放置在该页的最前面闪烁&#xff0c;然后选择“页面”》“↘…

Python面经

文章目录 Python基本概念1. Python是**解释型**语言还是**编译型**语言2. Python是**面向对象**语言还是面向过程语言3. Python基本数据类型4.append和 extend区别5.del、pop和remove区别6. sort和sorted区别介绍一下Python 中的字符串编码is 和 的区别*arg 和**kwarg作用浅拷…

Electron 进程间通信

文章目录 渲染进程到主进程&#xff08;单向&#xff09;渲染进程到主进程&#xff08;双向&#xff09;主进程到渲染进程 &#xff08;单向&#xff0c;可模拟双向&#xff09; 渲染进程到主进程&#xff08;单向&#xff09; send &#xff08;render 发送&#xff09;on &a…

【Stable Diffusion】(基础篇三)—— 图生图基础

图生图基础 本系列笔记主要参考B站nenly同学的视频教程&#xff0c;传送门&#xff1a;B站第一套系统的AI绘画课&#xff01;零基础学会Stable Diffusion&#xff0c;这绝对是你看过的最容易上手的AI绘画教程 | SD WebUI 保姆级攻略_哔哩哔哩_bilibili 本文主要讲解如何使用S…

客户端与服务端之间的通信连接

目录 那什么是Socket? 什么是ServerSocket? 代码展示&#xff1a; 代码解析&#xff1a; 补充&#xff1a; 输入流&#xff08;InputStream&#xff09;&#xff1a; 输出流&#xff08;OutputStream&#xff09;&#xff1a; BufferedReader 是如何提高读取效率的&a…

K8s集群初始化遇到的问题

kubectl describe pod coredns-545d6fc579-s9g5s -n kube-system 找到原因1&#xff1a;CoreDNS Pod 处于 Pending 状态的原因是集群中的节点都带有 node.kubernetes.io/not-ready 污点 journalctl -u kubelet -f 14:57:59.178592 3553 remote_image.go:114] "PullIma…

《简历宝典》12 - 简历中“项目经历”,内功学习 - 下篇

这一小节呢&#xff0c;我们继续说简历中 “项目经历” 的一些内功心法。因为项目经历比较核心&#xff0c;所以说完了&#xff0c;内功呢&#xff0c;我们会着重说一下 实战部分。 目录 1 所用技术的考虑 2 自我成长的突出 3 综合使用STAR法则 4 小节 1 所用技术的考虑 …

如何评估AI模型:评估指标的分类、方法及案例解析

如何评估AI模型&#xff1a;评估指标的分类、方法及案例解析 引言第一部分&#xff1a;评估指标的分类第二部分&#xff1a;评估指标的数学基础第三部分&#xff1a;评估指标的选择与应用第四部分&#xff1a;评估指标的局限性第五部分&#xff1a;案例研究第六部分&#xff1a…

pear-admin-fast项目修改为集成PostgreSQL启动

全局搜索代码中的sysdate()&#xff0c;修改为now() 【前者是mysql特有的&#xff0c;后者是postgre特有的】修改application-dev.yml中的数据库url使用DBeaver把mysql中的数据库表导出csv&#xff0c;再从postgre中导入csv脚本转换后出现了bpchar(xx)类型&#xff0c;那么一定…

用友U8存货分类按层级取数SQL语句

SELECT cInvCCode 分类编码, cInvCName 分类名称, iInvCGrade 分类层级, ss.bInvCEnd 是否是末级, aa.* FROM InventoryClass ss LEFT JOIN ( SELECT * FROM ( SELECT cInvCCode AS 一级分类编码, …

python数据可视化(6)——绘制散点图

课程学习来源&#xff1a;b站up&#xff1a;【蚂蚁学python】 【课程链接&#xff1a;【【数据可视化】Python数据图表可视化入门到实战】】 【课程资料链接&#xff1a;【链接】】 Python绘制散点图查看BMI与保险费的关系 散点图: 用两组数据构成多个坐标点&#xff0c;考察…

如何降低老年人患帕金森病的风险?

降低老年人患帕金森病风险的方法 避免接触有害物质&#xff1a;长期接触某些化学物质、农药或其他有害物质可能会增加患帕金森病的风险。应减少这些物质的暴露&#xff0c;例如在工作或生活中采取防护措施。 健康饮食&#xff1a;均衡饮食&#xff0c;多吃富含抗氧化剂的食物&a…