GRU--详解

GRU(Gated Recurrent Unit)(门控循环单元)是RNN(循环神经网络)的一种变体。GRU的设计简化了另一种RNN变体——LSTM(长短期记忆网络),与LSTM不同的是,GRU将输入门和遗忘门合并为一个单一的“重置门”和“更新门”,从而减少了模型的复杂性,同时仍能有效地捕捉长期依赖关系。

GRU的基本结构

GRU的结构主要由以下两个门组成:

  1. 重置门(Reset Gate):控制前一时刻的状态信息应该被遗忘的程度,决定当前时刻有多少过去的信息需要被遗忘。

  2. 更新门(Update Gate):决定前一时刻的状态信息对当前时刻的影响程度,控制当前时刻的隐藏状态应该保留多少前一时刻的记忆。

GRU的经典代码

在深度学习框架如PyTorch或TensorFlow中,GRU的实现非常简单。以下是用PyTorch实现一个简单GRU网络的代码:

import torch
import torch.nn as nn
​
class GRUNet(nn.Module):def __init__(self, input_size, hidden_size, num_layers, output_size):super(GRUNet, self).__init__()self.hidden_size = hidden_sizeself.num_layers = num_layersself.gru = nn.GRU(input_size, hidden_size, num_layers, batch_first=True)self.fc = nn.Linear(hidden_size, output_size)def forward(self, x):# 初始化隐藏状态h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)# 通过GRU层out, _ = self.gru(x, h0)# 取最后一个时间步的输出out = out[:, -1, :]# 全连接层out = self.fc(out)return out
​
# 使用示例
input_size = 10
hidden_size = 20
num_layers = 2
output_size = 1
model = GRUNet(input_size, hidden_size, num_layers, output_size)
​
# 生成随机输入数据
input_data = torch.randn(32, 5, input_size)  # (batch_size, sequence_length, input_size)
output = model(input_data)
print(output.shape)  # (batch_size, output_size)

处理文本生成任务的GRU示例

文本生成任务中,GRU通常作为生成器的一部分,输入是前一个时间步生成的字符或单词,输出是下一个时间步的预测字符或单词。下面是一个使用PyTorch的GRU实现文本生成的简单示例。

数据准备

使用字符级RNN来生成文本,首先需要将文本数据转化为字符的索引。

import torch
import torch.nn as nn
import torch.optim as optim
​
# 准备数据
text = "hello world"  # 简单的训练文本示例
chars = list(set(text))
char_to_idx = {ch: i for i, ch in enumerate(chars)}
idx_to_char = {i: ch for i, ch in enumerate(chars)}
input_size = len(chars)
​
# 将文本转化为索引
data = [char_to_idx[ch] for ch in text]
input_data = torch.tensor(data[:-1])  # 输入文本(去掉最后一个字符)
target_data = torch.tensor(data[1:])  # 目标文本(去掉第一个字符)
模型定义
class TextGenerationGRU(nn.Module):def __init__(self, input_size, hidden_size, output_size, num_layers=1):super(TextGenerationGRU, self).__init__()self.hidden_size = hidden_sizeself.num_layers = num_layersself.gru = nn.GRU(input_size, hidden_size, num_layers, batch_first=True)self.fc = nn.Linear(hidden_size, output_size)def forward(self, x, hidden):out, hidden = self.gru(x, hidden)out = self.fc(out)return out, hiddendef init_hidden(self, batch_size):return torch.zeros(self.num_layers, batch_size, self.hidden_size)
​
# 超参数
hidden_size = 128
output_size = input_size  # 输出大小和输入大小相同,都是字符集大小
num_layers = 1
​
model = TextGenerationGRU(input_size, hidden_size, output_size, num_layers)
​
# 损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
训练循环
num_epochs = 1000
seq_length = len(input_data)
input_data_one_hot = nn.functional.one_hot(input_data, num_classes=input_size).float().unsqueeze(0)
​
for epoch in range(num_epochs):# 初始化隐藏状态hidden = model.init_hidden(1)# 前向传播outputs, hidden = model(input_data_one_hot, hidden)loss = criterion(outputs.squeeze(0), target_data)# 反向传播及优化optimizer.zero_grad()loss.backward()optimizer.step()if (epoch + 1) % 100 == 0:print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')
文本生成

一旦训练完成,可以使用训练好的GRU模型来生成新文本。以下是生成新文本的代码:

def generate_text(model, start_char, char_to_idx, idx_to_char, hidden_size, num_generate):input_char = torch.tensor([char_to_idx[start_char]])input_char_one_hot = nn.functional.one_hot(input_char, num_classes=len(char_to_idx)).float().unsqueeze(0)hidden = model.init_hidden(1)generated_text = start_charfor _ in range(num_generate):output, hidden = model(input_char_one_hot, hidden)predicted_idx = torch.argmax(output, dim=2).item()predicted_char = idx_to_char[predicted_idx]generated_text += predicted_charinput_char = torch.tensor([predicted_idx])input_char_one_hot = nn.functional.one_hot(input_char, num_classes=len(char_to_idx)).float().unsqueeze(0)return generated_text
​
# 使用训练好的模型生成文本
generated_text = generate_text(model, 'h', char_to_idx, idx_to_char, hidden_size, num_generate=20)
print("Generated Text:", generated_text)
总结

GRU 是一种强大的循环神经网络架构,在处理序列数据(如文本生成、语言模型等)时非常有效。其结构相比 LSTM 简化了门控机制,但仍能有效捕捉长时间依赖。通过PyTorch等框架,可以快速构建并训练GRU模型,并应用于诸如文本生成等任务。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/56486.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【OpenGauss源码学习 —— (VecSortAgg)】

VecSortAgg 概述SortAggRunner::SortAggRunner 函数SortAggRunner::init_phase 函数SortAggRunner::init_indexForApFun 函数SortAggRunner::set_key 函数BaseAggRunner::initialize_sortstate 函数SortAggRunner::BindingFp 函数SortAggRunner::buildSortAgg 函数SortAggRunne…

python从0快速上手(一)python环境搭建 windows macos linux

Python环境搭建超详细指南 Python是一种广泛使用的高级编程语言,它以其简洁的语法和强大的功能而受到开发者的喜爱。对于初学者来说,搭建一个合适的Python开发环境是开始Python之旅的第一步。本文将为你提供一个超级详细的Python环境搭建指南&#xff0…

基于SpringBoot+Vue+Uniapp家具购物小程序的设计与实现

详细视频演示 请联系我获取更详细的演示视频 项目运行截图 技术框架 后端采用SpringBoot框架 Spring Boot 是一个用于快速开发基于 Spring 框架的应用程序的开源框架。它采用约定大于配置的理念,提供了一套默认的配置,让开发者可以更专注于业务逻辑而…

3个方法快速恢复微信已过期或被清理图片

微信作为现在国内用户数量最多的社交软件,已经成为了许多人日常生活和工作中必不可少的一部分。但微信中的图片有时会因为多种原因而消失,如过期、被清理或者误删。遇到这种情况,那么已过期或被清理的图片还能恢复吗?下面小编就来…

学习之上下文管理器

one_file open(demo.txt, w) one_file.write("xxxxx") # raise ValueError # 如果抛出异常将会报错 one_file.close()with open(demo.txt, w) as f: # open--返回的是IO--IO中实现了__enter__方法和__exit__方法f.write("aaaa")class MyContextManger:d…

论文速读:通过目标感知双分支蒸馏进行跨域目标检测(CVPR2022)

原文标题:Cross Domain Object Detection by Target-Perceived Dual Branch Distillation 中文标题:通过目标感知双分支蒸馏进行跨域目标检测 论文地址: https://arxiv.org/abs/2205.01291 代码地址: GitHub - Feobi1999/TDD 这篇…

做个工作中的退让者,生活中的前进者

先来分享一下什么是退让者原则,退让者原则,也被称为“幸福者退让原则”,是一种在面对冲突和挑衅时采取的策略,其核心理念是在拥有幸福生活的背景下,选择退让而非直接对抗,以保护个人及家庭幸福为优先。 为…

在IDEA中配置Selenium和WebDriver

前言: 在当今自动化测试和网络爬虫的领域,Selenium是一个被广泛使用的工具。它不仅能够模拟用户与浏览器的交互,还能进行网页测试和数据抓取。而为了使用Selenium与谷歌/Edge浏览器进行自动化测试,配置合适的WebDriver至关重要。本…

【前端】Bootstrap:栅格系统 (Grid System)

Bootstrap的栅格系统是该框架的核心部分之一,能够让开发者轻松创建响应式网页布局,适配各种屏幕尺寸和设备。栅格系统通过将页面划分为12列的布局结构,开发者可以根据内容的重要性和设计需求灵活控制元素的宽度和排列。 在这篇文章中&#x…

Java--练习--DVD管理系统

一、详细代码 package demo2.Test;import java.text.ParseException; import java.text.SimpleDateFormat; import java.util.Date; import java.util.Scanner;public class DVD2 {//DVD管理系统//声明三个数组 用来存储 名称 状态 借出日期 借出次数String[] names new St…

学习使用linux的bash命令可以删除ps aux进程中今天之前指定运行进程的脚本

学习使用linux的bash命令可以删除ps aux进程中今天之前指定运行进程的脚本 脚本注意事项: 脚本 #!/bin/bash# 获取今天的日期(格式 YYYY-MM-DD) TODAY$(date %Y-%m-%d)# 使用 ps aux 查找所有名为 qipa250 的进程 # 并提取出 PID 和启动时间…

游戏引擎哪家强?选哪一个更有钱途

游戏引擎乃是构筑及开发视频游戏的软件架构。其供应一整套工具与库,以处置常见的游戏开发事务,诸如渲染图形、模拟物理、管控音频等等。凭借对游戏引擎的运用,开发人员能够将精力倾注于构建其游戏的独特之处,而非再度发明此类基础…

【游戏模组】极品飞车12无间风云冬季mod,冬天版本的无间风云你体验过吗

各位好,今天小编给大家带来一款新的高清重置魔改MOD,本次高清重置的游戏叫《极品飞车12无间风云》。 《极品飞车12:无间风云》是由Black Box游戏制作室开发的竞速类游戏,于2008年11月18日在北美首发、2008年11月21日在欧洲先后推…

【深入学习Redis丨第八篇】详解Redis数据持久化机制

前言 Redis支持两种数据持久化方式:RDB方式和AOF方式。前者会根据配置的规则定时将内存中的数据持久化到硬盘上,后者则是在每次执行写命令之后将命令记录下来。两种持久化方式可以单独使用,但是通常会将两者结合使用。 一、持久化 1.1、什么…

MySQL【知识改变命运】04

复习: 1:CURD 1.1Create (创建) 语法: insert [into] 表名 [column[,column]] valuse(value_list)[,vaule_list]... value_list:value,[value]...创建一个实例表: 1.1.1单⾏数据全列插⼊ values_l…

Python爬虫之正则表达式于xpath的使用教学及案例

正则表达式 常用的匹配模式 \d # 匹配任意一个数字 \D # 匹配任意一个非数字 \w # 匹配任意一个单词字符(数字、字母、下划线) \W # 匹配任意一个非单词字符 . # 匹配任意一个字符(除了换行符) [a-z] # 匹配任意一个小写字母 […

CSS之一

目录 简介 CSS 语法规范 CSS 代码风格 1.样式格式书写 2.样式大小写 CSS 基础选择器 选择器分类 标签选择器 类选择器 案例之画盒子 多类型使用 id选择器 通配符选择器 font-family设置字体 字体系列 字体大小 字体粗细 文字样式 字体复合属性 示例 CSS 文…

【力扣 | SQL题 | 每日3题】力扣1107,1112, 1077

今天三道mid题都可以用窗口函数轻松秒杀。 1. 力扣1107:每日新用户统计 1.1 题目: Traffic 表: ------------------------ | Column Name | Type | ------------------------ | user_id | int | | activity | enum …

mysql模糊查询优化

mysql模糊查询优化 一、合理使用索引 如下SQL举例: SELECT username,age FROM WHERE username LIKE ‘hysen%’ 如果username字段有索引,前缀匹配会走索引,如 ‘%hysen’或’%hysen%’ 则无法走索引。 二、使用反向索引 对于需要使用后缀…

解决关于HTML+JS + Servlet 实现前后端请求Session不一致的问题

1、前后端不分离情况 在处理session过程中,如果前后端项目在一个容器中,session是可以被获取的。例如如下项目结构: 结构 后端的代码是基本的设置值、获取值、销毁值的内容: 运行结果 由此可见,在前后统一的项目中&a…