【深度学习】循环神经网络及文本生成模型构建

循环神经网络

词嵌入层

词嵌入层的作用就是将文本转换为向量。
​ 词嵌入层首先会根据输入的词的数量构建一个词向量矩阵,例如: 我们有 100 个词,每个词希望转换成 128 维度的向量,那么构建的矩阵形状即为: 100*128,输入的每个词都对应了一个该矩阵中的一个向量.

在这里插入图片描述

在 PyTorch 中,使用 nn.Embedding 词嵌入层来实现输入词的向量化。

nn.Embedding(num_embeddings=10, embedding_dim=4)
  • nn.Embedding 对象构建时,最主要有两个参数:
    • num_embeddings 表示词的数量
    • embedding_dim 表示用多少维的向量来表示每个词

​ 接下来,我们将会学习如何将词转换为词向量,其步骤如下:
​ 先将语料进行分词,构建词与索引的映射,我们可以把这个映射叫做词表,词表中每个词都对应了一个唯一的索引
​ 然后使用 nn.Embedding 构建词嵌入矩阵,词索引对应的向量即为该词对应的数值化后的向量表示。
​ 例如,我们的文本数据为: “北京冬奥的进度条已经过半,不少外国运动员在完成自己的比赛后踏上归途。”,

import torch
import torch.nn as nn
import jiebaif __name__ == '__main__':# 0.文本数据text = '北京冬奥的进度条已经过半,不少外国运动员在完成自己的比赛后踏上归途。'# 1. 文本分词words = jieba.lcut(text)print('文本分词:', words)# 2.分词去重并保留原来的顺序获取所有的词语unique_words = list(set(words))print("去重后词的个数:\n",len(unique_words))# 3. 构建词嵌入层:num_embeddings: 表示词的总数量;embedding_dim: 表示词嵌入的维度embed = nn.Embedding(num_embeddings=len(unique_words), embedding_dim=4)print("词嵌入的结果:\n",embed)# 4. 词语的词向量表示for i, word in enumerate(unique_words):# 获得词嵌入向量word_vec = embed(torch.tensor(i))print('%3s\t' % word, word_vec)

在这里插入图片描述

RNN网络原理

​ 文本数据是具有序列特性的
​ 例如: “我爱你”, 这串文本就是具有序列关系的,“爱” 需要在 “我” 之后,“你” 需要在 “爱” 之后, 如果颠倒了顺序,那么可能就会表达不同的意思。
​ 为了表示出数据的序列关系,需要使用循环神经网络(Recurrent Nearal Networks, RNN) 来对数据进行建模,RNN 是一个作用于处理带有序列特点的样本数据。
在这里插入图片描述

​ h 表示隐藏状态,
​ 每一次的输入都会包含两个值: 上一个时间步的隐藏状态、当前状态的输入值,输出当前时间步的隐藏状态和当前时间步的预测结果。上图有三个神经元处理’我爱你’这三个字,实际上是一个他们三个字重复输入到同一个神经元

在这里插入图片描述

​ 我们举个例子来理解上图的工作过程,假设我们要实现文本生成,也就是输入 “我爱” 这两个字,来预测出 “你”,其如下图所示:

在这里插入图片描述

将上图展开成不同时间步的形式,如下图所示:

在这里插入图片描述

​ 首先初始化出第一个隐藏状态h0,一般都是全0的一个向量,然后将 “我” 进行词嵌入,转换为向量的表示形式,送入到第一个时间步,然后输出隐藏状态 h1,然后将 h1 和 “爱” 输入到第二个时间步,得到隐藏状态 h2, 将 h2 送入到全连接网络,得到 “你” 的预测概率。

在这里插入图片描述

上述公式中:
Wih 表示输入数据的权重
bih 表示输入数据的偏置
Whh 表示输入隐藏状态的权重
bhh 表示输入隐藏状态的偏置
最后对输出的结果使用 tanh 激活函数进行计算,得到该神经元你的输出。

在这里插入图片描述

Pytorch RNN层的使用

  • 输入数据和输出结果
    将RNN实例化就可以将数据送入其中进行处理,处理的方式如下所示:
output, hn = RNN(x, h0)
  • 输入数据:输入主要包括词嵌入的x 、初始的隐藏层h0
    • x的表示形式为[seq_len, batch, input_size],即[句子的长度,batch的大小,词向量的维度]
    • h0的表示形式为[num_layers, batch, hidden_size],即[隐藏层的层数,batch的大,隐藏层h的维数]\
  • 输出结果:主要包括输出结果output,最后一层的hn
    • output的表示形式与输入x类似,为[seq_len, batch, hidden_size],即[句子的长度,batch的大小,输出向量的维度]
    • hn的表示形式与输入h0一样,为[num_layers, batch, hidden_size],即[隐藏层的层数,batch的大,隐藏层h的维度]
import torch
import torch.nn as nn#  RNN层送入批量数据
def test():# 词向量维度 128, 隐藏向量维度 256rnn = nn.RNN(input_size=128, hidden_size=256)# 第一个数字: 表示句子长度,也就是词语个数# 第二个数字: 批量个数,也就是句子的个数# 第三个数字: 词向量维度inputs = torch.randn(5, 32, 128)hn = torch.zeros(1, 32, 256)# 获取输出结果output, hn = rnn(inputs, hn)print("输出向量的维度:\n",output.shape)print("隐含层输出的维度:\n",hn.shape)if __name__ == '__main__':test()

RNN及其变体LSTM、GRU

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

文本生成模型构建

项目需求

​ 文本生成任务是一种常见的自然语言处理任务,输入一个开始词能够预测出后面的词序列。本案例将会使用循环神经网络来实现周杰伦歌词生成任务。

import torch
import re
import jieba
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import timedef build_dataset():"""获取数据集:return: unique_words, word_to_idx, count, corpus_idx"""# 数据集位置file_name = '../data/jaychou_lyrics.txt'# 分词结果存储位置unique_words = []all_words = []# 遍历数据集中的每一行文本for line in open(file_name, 'r', encoding='utf-8'):# 使用jieba分词,分割结果是一个列表words = jieba.lcut(line)# print(words)# 所有的分词结果存储到all_sentences,其中包含重复的词组all_words.append(words)# 遍历分词结果,去重后存储到unique_wordsfor word in words:if word not in unique_words:unique_words.append(word)# 语料中词的数量word_count = len(unique_words)# 词到索引映射word_to_index = {word: idx for idx, word in enumerate(unique_words)}# 词表索引表示corpus_idx = []# 遍历每一行的分词结果for words in all_words:temp = []# 获取每一行的词,并获取相应的索引for word in words:temp.append(word_to_index[word])# 在每行词之间添加空格隔开temp.append(word_to_index[' '])# 获取当前文档中每个词对应的索引corpus_idx.extend(temp)return unique_words, word_to_index, word_count, corpus_idxclass LyricsDataset(torch.utils.data.Dataset):def __init__(self, corpus_idx, num_chars):# 文档数据中词的索引self.corpus_idx = corpus_idx# 每个句子中词的个数self.num_chars = num_chars# 词的数量self.word_count = len(self.corpus_idx)# 句子数量self.number = self.word_count // self.num_charsdef __len__(self):# 返回句子数量return self.numberdef __getitem__(self, idx):# idx指词的索引,并将其修正索引值到文档的范围里面start = min(max(idx, 0), self.word_count - self.num_chars - 2)# 输入值x = self.corpus_idx[start: start + self.num_chars]# 网络预测结果(目标值)y = self.corpus_idx[start + 1: start + 1 + self.num_chars]# 返回结果return torch.tensor(x), torch.tensor(y)# 模型构建
class TextGenerator(nn.Module):def __init__(self, word_count):super(TextGenerator, self).__init__()# 初始化词嵌入层: 词向量的维度为128self.ebd = nn.Embedding(word_count, 128)# 循环网络层: 词向量维度 128, 隐藏向量维度 128, 网络层数1self.rnn = nn.RNN(128, 128, 1)# 输出层: 特征向量维度128与隐藏向量维度相同,词表中词的个数self.out = nn.Linear(128, word_count)def forward(self, inputs, hidden):# 输出维度: (batch, seq_len,词向量维度 128)embed = self.ebd(inputs)# 修改维度: (seq_len, batch,词向量维度 128)output, hidden = self.rnn(embed.transpose(0, 1), hidden)# 输入维度: (seq_len*batch,词向量维度 ) 输出维度: (seq_len*batch, 128)output = self.out(output.reshape((-1, output.shape[-1])))# 网络输出结果return output, hiddendef init_hidden(self, bs):# 隐藏层的初始化:[网络层数, batch, 隐藏层向量维度]return torch.zeros(1, bs, 128)# 模型训练
def train():# 构建词典index_to_word, word_to_index, word_count, corpus_idx = build_dataset()# 数据集dataset = LyricsDataset(corpus_idx, 32)# 初始化模型model = TextGenerator(word_count)# 损失函数criterion = nn.CrossEntropyLoss()# 优化方法optimizer = optim.Adam(model.parameters(), lr=1e-3)# 训练轮数epoch = 20for epoch_idx in range(epoch):# 数据加载器lyrics_dataloader = DataLoader(dataset, shuffle=True, batch_size=2)# 训练时间start = time.time()iter_num = 0  # 迭代次数# 训练损失total_loss = 0.0# 遍历数据集for x, y in lyrics_dataloader:# 隐藏状态的初始化hidden = model.init_hidden(x.size(0))# 模型计算output, hidden = model(x, hidden)# 计算损失# y:[batch,seq_len]->[seq_len,batch]->[seq_len*batch]y = torch.transpose(y, 0, 1).contiguous().view(-1)loss = criterion(output, y)optimizer.zero_grad()loss.backward()optimizer.step()iter_num += 1  # 迭代次数加1total_loss += loss.item()# 打印训练信息print('epoch %3s loss: %.5f time %.2f' % (epoch_idx + 1, total_loss / iter_num, time.time() - start))# 模型存储torch.save(model.state_dict(), '../model/lyrics_model_%d.pth' % epoch)def predict(start_word, sentence_length):# 构建词典index_to_word, word_to_index, word_count, _ = build_dataset()# 构建模型model = TextGenerator(word_count)# 加载参数model.load_state_dict(torch.load('../model/lyrics_model_10.pth'))# 隐藏状态hidden = model.init_hidden(bs=1)# 将起始词转换为索引word_idx = word_to_index[start_word]# 产生的词的索引存放位置generate_sentence = [word_idx]# 遍历到句子长度,获取每一个词for _ in range(sentence_length):# 模型预测output, hidden = model(torch.tensor([[word_idx]]), hidden)# 获取预测结果word_idx = torch.argmax(output)generate_sentence.append(word_idx)# 根据产生的索引获取对应的词,并进行打印for idx in generate_sentence:print(index_to_word[idx], end='')if __name__ == '__main__':# train()predict('回忆', 100)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/61328.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

51单片机基础01 单片机最小系统

目录 一、什么是51单片机 二、51单片机的引脚介绍 1、VCC GND 2、XTAL1 2 3、RST 4、EA 5、PSEN 6、ALE 7、RXD、TXD 8、INT0、INT1 9、T0、T1 10、MOSI、MISO、SCK 11、WR、RD 12、通用IO P0 13、通用IO P1 14、通用IO P2 三、51单片机的最小系统 1、供电与…

vue 录音流程

vue 录音流程 RecordRTC npm install recordrtcimport RecordRTC from "recordrtc";<!-- 音频播放器&#xff0c;用于播放录音 --> <audio v-show"false" ref"audioPlayer" controls></audio>async startRecording() {// 检查…

QT使用libssh2库实现sftp文件传输

本篇文章通过用户名和密码来连接服务器端,通过密匙连接服务器端可以参考另外一篇文章: https://blog.csdn.net/u012372584/article/details/143826199?sharetype=blogdetail&sharerId=143826199&sharerefer=PC&sharesource=u012372584&spm=1011.2480.3001.…

【前端知识】前端打包工具webpack深度解读

webpackandesign搭建前端脚手架 webpack概述一、核心功能二、主要特点三、核心概念四、使用场景五、安装与配置六、常用命令 配置文件详解一、基本结构二、主要配置项及其作用三、示例配置 加载器一、加载器的定义与作用二、常见的加载器类型及作用三、加载器的配置与使用四、加…

用vscode编写verilog时,如何有信号定义提示、信号定义跳转(go to definition)、模块跳转(跨文件跳转)这些功能

&#xff08;一&#xff09;方法一&#xff1a;安装插件SystemVerilog - Language Support 安装一个vscode插件即可&#xff0c;插件叫SystemVerilog - Language Support。虽然说另一个插件“Verilog-HDL/SystemVerilog/Bluespec SystemVerilog”也有信号提示及定义跳转功能&am…

从零开始搭建Java开发环

目录 引言 一、JDK安装 二、IDE选择与配置 三、构建工具配置 四、测试环境搭建 五、其他建议 引言 随着Java技术的不断进步与应用范围的不断扩大&#xff0c;越来越多的开发者加入到了Java开发的行列。一个高效稳定的开发环境是提高开发效率的基础。本文将详细介绍如何从零…

uniapp vue3小程序报错Cannot read property ‘__route__‘ of undefined

在App.vue里有监听应用的生命周期 <script>// 只能在App.vue里监听应用的生命周期export default {onError: function(err) {console.log(AppOnError:, err); // 当 uni-app 报错时触发}} </script>在控制台打印里无意发现 Cannot read property ‘__route__‘ of …

Vue3插槽v-slot使用方式

在 Vue 3 中&#xff0c;v-slot 是用来定义和使用插槽的指令。插槽是 Vue 的一个功能&#xff0c;允许你在组件内部定义占位内容&#xff0c;便于在父组件中提供动态内容。以下是 v-slot 的详细使用方法&#xff1a; 1. 基础使用 <template><BaseComponent><te…

Android 网络请求(二)OKHttp网络通信

学习笔记 OkHttp 是一个非常强大且流行的 HTTP 客户端库&#xff0c;广泛用于 Android 开发中进行网络请求。与 HttpURLConnection 相比&#xff0c;OkHttp 提供了更简单、更高效的 API&#xff0c;特别是在处理复杂的 HTTP 请求时。 如何使用 OkHttp 进行网络请求 以下是使…

Vue 3 国际化 (i18n) 最佳实践指南

1. 安装依赖 npm install vue-i18n@9 2. 项目结构建议 src/ ├── i18n/ │ ├── index.ts # i18n 配置文件 │ ├── languages/ # 语言文件目录 │ │ ├── zh-CN.ts # 中文 │ │ ├── en-US.ts # 英文 │ │ └─…

Ubuntu20.04升级glibc升级及降级的心路历程

想使用pip安装Isaac Sim&#xff0c;无奈此方法只支持 GLIBC>2.34 。使用的是Ubuntu20.04&#xff0c;使用 ldd --version 查看GLIBC版本&#xff0c;如果版本低于 2.34 则需要升级GLIBC&#xff0c;基于此开始了长达一天的尝试。 请注意&#xff0c;升级GLIBC是一个危险操作…

Android开发实战班 - 网络编程 - WebSocket 实时通信

在现代应用开发中&#xff0c;实时通信是许多应用的核心功能之一&#xff0c;例如聊天应用、实时通知、在线游戏等。WebSocket 是一种在单个 TCP 连接上进行全双工通信的协议&#xff0c;能够实现服务器与客户端之间的实时双向数据交换。相比于传统的 HTTP 请求&#xff0c;Web…

如何从android的webview 取得页面上的数据

要从Android的WebView中获取页面上的数据&#xff0c;通常有几种常见的方法&#xff1a; JavaScript Interface&#xff1a;通过JavaScript和Android Interface进行通信。这种方法允许你在JavaScript中调用Android的方法&#xff0c;反之亦然。 Evaluate JavaScript&#xff…

力扣--LCR 140.训练计划||

题目 给定一个头节点为 head 的链表用于记录一系列核心肌群训练项目编号&#xff0c;请查找并返回倒数第 cnt 个训练项目编号。 示例 1&#xff1a; 输入&#xff1a;head [2,4,7,8], cnt 1 输出&#xff1a;8 提示&#xff1a; 1 < head.length < 100 0 < hea…

奶龙IP联名异军突起:如何携手品牌营销共创双赢?

在快节奏的互联网消费时代&#xff0c;年轻消费群体对产品和品牌的要求越来越挑剔。因此在品牌年轻化的当下&#xff0c;一方面需要品牌自身形象也要不断追求时代感&#xff0c;另一方面品牌也需要不断引领消费者需求&#xff0c;提升竞争力和产品力。 奶龙作为近年来异军突起…

Java LinkedList 详解

LinkedList 是 Java 集合框架中常用的数据结构之一&#xff0c;位于 java.util 包中。它实现了 List、Deque 和 Queue 接口&#xff0c;是一个双向链表结构&#xff0c;适合频繁的插入和删除操作。 1. LinkedList 的特点 数据结构&#xff1a;基于双向链表实现&#xff0c;每个…

ROM修改进阶教程------安卓14去除修改系统应用后导致的卡logo验证步骤 适用安卓13 14 安卓15可借鉴参考

上期的博文解析了安卓14 安卓15去除系统应用签名验证的步骤解析。我们要明白。修改系统应用后有那些验证。其中签名验证 去卡logo验证 与可降级安装应用验证等等的区别。有些要相互结合使用。今天的博文将对修改系统应用后卡logo验证做个步骤解析。 通过博文了解💝💝�…

【Spring boot】微服务项目的搭建整合swagger的fastdfs和demo的编写

文章目录 1. 微服务项目搭建2. 整合 Swagger 信息3. 部署 fastdfsFastDFS安装环境安装开始图片测试FastDFS和nginx整合在Storage上安装nginxnginx安装不成功排查:4. springboot 整合 fastdfs 的demodemo编写1. 微服务项目搭建 版本总结: spring boot: 2.6.13springfox-boot…

Docker 篇-Docker 详细安装、了解和使用 Docker 核心功能(数据卷、自定义镜像 Dockerfile、网络)

&#x1f525;博客主页&#xff1a; 【小扳_-CSDN博客】 ❤感谢大家点赞&#x1f44d;收藏⭐评论✍ 文章目录 1.0 Docker 概述 1.1 Docker 主要组成部分 1.2 Docker 安装 2.0 Docker 常见命令 2.1 常见的命令介绍 2.2 常见的命令演示 3.0 数据卷 3.1 数据卷常见的命令 3.2 常见…

部门管理系统功能完善(删除部门、添加部门、根据 ID 查询部门 和 修改部门)

一、目标 继续实现 删除部门、添加部门、根据 ID 查询部门 和 修改部门 的详细功能实现&#xff0c;分为 Controller 层、Service 层 和 Mapper 层。 二、代码分析 总体代码&#xff1a; Controller 层&#xff1a; package com.zhang.Controller; Slf4j RequestMapping(&qu…