pytorch实现门控循环单元 (GRU)

 人工智能例子汇总:AI常见的算法和例子-CSDN博客  

特性GRULSTM
计算效率更快,参数更少相对较慢,参数更多
结构复杂度只有两个门(更新门和重置门)三个门(输入门、遗忘门、输出门)
处理长时依赖一般适用于中等长度依赖更适合处理超长时序依赖
训练速度训练更快,梯度更稳定训练较慢,占用更多内存

例子:

import torch
import torch.nn as nn
import torch.optim as optim
import random
import matplotlib.pyplot as plt# 🏁 迷宫环境(5×5)
class MazeEnv:def __init__(self, size=5):self.size = sizeself.state = (0, 0)  # 起点self.goal = (size-1, size-1)  # 终点self.actions = [(0,1), (0,-1), (1,0), (-1,0)]  # 右、左、下、上def reset(self):self.state = (0, 0)  # 重置起点return self.statedef step(self, action):dx, dy = self.actions[action]x, y = self.statenx, ny = max(0, min(self.size-1, x+dx)), max(0, min(self.size-1, y+dy))reward = 1 if (nx, ny) == self.goal else -0.1done = (nx, ny) == self.goalself.state = (nx, ny)return (nx, ny), reward, done# 🤖 GRU 策略网络
class GRUPolicy(nn.Module):def __init__(self, input_size, hidden_size, output_size):super(GRUPolicy, self).__init__()self.gru = nn.GRU(input_size, hidden_size, batch_first=True)self.fc = nn.Linear(hidden_size, output_size)def forward(self, x, hidden):out, hidden = self.gru(x, hidden)out = self.fc(out[:, -1, :])  # 只取最后时间步return out, hidden# 🎯 训练参数
env = MazeEnv(size=5)
policy = GRUPolicy(input_size=2, hidden_size=16, output_size=4)
optimizer = optim.Adam(policy.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()# 🎓 训练
num_episodes = 500
epsilon = 1.0  # 初始的ε值,控制探索的概率
epsilon_min = 0.01  # 最小ε值
epsilon_decay = 0.995  # ε衰减率
best_path = []  # 用于存储最佳路径for episode in range(num_episodes):state = env.reset()hidden = torch.zeros(1, 1, 16)  # GRU 初始状态states, actions, rewards = [], [], []logits_list = []  for _ in range(20):  # 最多 20 步state_tensor = torch.tensor([[state[0], state[1]]], dtype=torch.float32).unsqueeze(0)logits, hidden = policy(state_tensor, hidden)logits_list.append(logits)# ε-greedy 策略if random.random() < epsilon:action = random.choice(range(4))  # 随机选择动作else:action = torch.argmax(logits, dim=1).item()  # 选择最大值对应的动作next_state, reward, done = env.step(action)states.append(state)actions.append(action)rewards.append(reward)if done:print(f"Episode {episode} - Reached Goal!")# 找到最优路径best_path = states + [next_state]  # 当前 episode 的路径breakstate = next_state# 计算损失logits = torch.cat(logits_list, dim=0)  # (T, 4)action_tensor = torch.tensor(actions, dtype=torch.long)  # (T,)loss = loss_fn(logits, action_tensor)  optimizer.zero_grad()loss.backward()optimizer.step()# 衰减 εepsilon = max(epsilon_min, epsilon * epsilon_decay)if episode % 100 == 0:print(f"Episode {episode}, Loss: {loss.item():.4f}, Epsilon: {epsilon:.4f}")# 🧐 确保 best_path 已经记录
if len(best_path) == 0:print("No path found during training.")
else:print(f"Best path: {best_path}")# 🚀 测试路径(只绘制最佳路径)
fig, ax = plt.subplots(figsize=(6,6))# 初始化迷宫图
maze = [[0 for _ in range(5)] for _ in range(5)]  # 5×5 迷宫
ax.imshow(maze, cmap="coolwarm", origin="upper")# 画网格
ax.set_xticks(range(5))
ax.set_yticks(range(5))
ax.grid(True, color="black", linewidth=0.5)# 画出最佳路径(红色)
for (x, y) in best_path:ax.add_patch(plt.Rectangle((y, x), 1, 1, color="red", alpha=0.8))# 画起点和终点
ax.text(0, 0, "S", ha="center", va="center", fontsize=14, color="white", fontweight="bold")
ax.text(4, 4, "G", ha="center", va="center", fontsize=14, color="white", fontweight="bold")plt.title("GRU RL Agent - Best Path")
plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/894641.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

PAT甲级1032、sharing

题目 To store English words, one method is to use linked lists and store a word letter by letter. To save some space, we may let the words share the same sublist if they share the same suffix. For example, loading and being are stored as showed in Figure …

最小生成树kruskal算法

文章目录 kruskal算法的思想模板 kruskal算法的思想 模板 #include <bits/stdc.h> #define lowbit(x) ((x)&(-x)) #define int long long #define endl \n #define PII pair<int,int> #define IOS ios::sync_with_stdio(0),cin.tie(0),cout.tie(0); using na…

为何在Kubernetes容器中以root身份运行存在风险?

作者&#xff1a;马辛瓦西奥内克&#xff08;Marcin Wasiucionek&#xff09; 引言 在Kubernetes安全领域&#xff0c;一个常见的建议是让容器以非root用户身份运行。但是&#xff0c;在容器中以root身份运行&#xff0c;实际会带来哪些安全隐患呢&#xff1f;在Docker镜像和…

ConcurrentHashMap线程安全:分段锁 到 synchronized + CAS

专栏系列文章地址&#xff1a;https://blog.csdn.net/qq_26437925/article/details/145290162 本文目标&#xff1a; 理解ConcurrentHashMap为什么线程安全&#xff1b;ConcurrentHashMap的具体细节还需要进一步研究 目录 ConcurrentHashMap介绍JDK7的分段锁实现JDK8的synchr…

[ESP32:Vscode+PlatformIO]新建工程 常用配置与设置

2025-1-29 一、新建工程 选择一个要创建工程文件夹的地方&#xff0c;在空白处鼠标右键选择通过Code打开 打开Vscode&#xff0c;点击platformIO图标&#xff0c;选择PIO Home下的open&#xff0c;最后点击new project 按照下图进行设置 第一个是工程文件夹的名称 第二个是…

述评:如果抗拒特朗普的“普征关税”

题 记 美国总统特朗普宣布对美国三大贸易夥伴——中国、墨西哥和加拿大&#xff0c;分别征收10%、25%的关税。 他威胁说&#xff0c;如果这三个国家不解决他对非法移民和毒品走私的担忧&#xff0c;他就要征收进口税。 去年&#xff0c;中国、墨西哥和加拿大这三个国家&#…

九. Redis 持久化-AOF(详细讲解说明,一个配置一个说明分析,步步讲解到位 2)

九. Redis 持久化-AOF(详细讲解说明&#xff0c;一个配置一个说明分析&#xff0c;步步讲解到位 2) 文章目录 九. Redis 持久化-AOF(详细讲解说明&#xff0c;一个配置一个说明分析&#xff0c;步步讲解到位 2)1. Redis 持久化 AOF 概述2. AOF 持久化流程3. AOF 的配置4. AOF 启…

基于Springboot框架的学术期刊遴选服务-项目演示

项目介绍 本课程演示的是一款 基于Javaweb的水果超市管理系统&#xff0c;主要针对计算机相关专业的正在做毕设的学生与需要项目实战练习的 Java 学习者。 1.包含&#xff1a;项目源码、项目文档、数据库脚本、软件工具等所有资料 2.带你从零开始部署运行本套系统 3.该项目附…

新版231普通阿里滑块 自动化和逆向实现 分析

声明: 本文章中所有内容仅供学习交流使用&#xff0c;不用于其他任何目的&#xff0c;抓包内容、敏感网址、数据接口等均已做脱敏处理&#xff0c;严禁用于商业用途和非法用途&#xff0c;否则由此产生的一切后果均与作者无关&#xff01; 逆向过程 补环境逆向 部分补环境 …

java-(Oracle)-Oracle,plsqldev,Sql语法,Oracle函数

卸载好注册表,然后安装11g 每次在执行orderby的时候相当于是做了全排序,思考全排序的效率 会比较耗费系统的资源,因此选择在业务不太繁忙的时候进行 --给表添加注释 comment on table emp is 雇员表 --给列添加注释; comment on column emp.empno is 雇员工号;select empno,en…

泰山派Linux环境下自动烧录脚本(EMMC 2+16G)

脚本名字&#xff1a; download.sh 输入./download -h获取帮助信息 &#xff0c;其中各个IMG/TXT烧录的地址和路径都在前几行修改即可 #!/bin/bash# # DownLoad.sh 多镜像烧录脚本 # 版本&#xff1a;1.1 # 作者&#xff1a;zhangqi # 功能&#xff1a;通过参数选择烧录指定镜…

正大杯攻略|分层抽样+不等概率三阶段抽样

首先&#xff0c;先进行分层抽样&#xff0c;确定主城区和郊区的比例 然后对主城区分别进行不等概率三阶段抽样 第一阶段&#xff0c;使用PPS抽样&#xff0c;确定行政区&#xff08;根据分层抽样比例合理确定主城区和郊区行政区数量&#xff09; 第二阶段&#xff0c;使用分…

开源智慧园区管理系统对比其他十种管理软件的优势与应用前景分析

内容概要 在当今数字化快速发展的时代&#xff0c;园区管理软件的选择显得尤为重要。而开源智慧园区管理系统凭借其独特的优势&#xff0c;逐渐成为用户的新宠。与传统管理软件相比&#xff0c;它不仅灵活性高&#xff0c;而且具有更强的可定制性&#xff0c;让各类园区&#…

计算机网络 应用层 笔记1(C/S模型,P2P模型,FTP协议)

应用层概述&#xff1a; 功能&#xff1a; 常见协议 应用层与其他层的关系 网络应用模型 C/S模型&#xff1a; 优点 缺点 P2P模型&#xff1a; 优点 缺点 DNS系统&#xff1a; 基本功能 系统架构 域名空间&#xff1a; DNS 服务器 根服务器&#xff1a; 顶级域…

人类心智逆向工程:AGI的认知科学基础

文章目录 引言:为何需要逆向工程人类心智?一、逆向工程的定义与目标1.1 什么是逆向工程?1.2 AGI逆向工程的核心目标二、认知科学的四大支柱与AGI2.1 神经科学:大脑的硬件解剖2.2 心理学:心智的行为建模2.3 语言学:符号与意义的桥梁2.4 哲学:意识与自我模型的争议三、逆向…

游戏引擎学习第86天

仓库: https://gitee.com/mrxiao_com/2d_game_2 回顾 继续之前的工作。 昨天已经让地形系统基本运行起来&#xff0c;但目前仍然需要进一步完善&#xff0c;使其能够生成更多的地块。目前的情况是&#xff0c;仅仅有一个地块位于中心区域&#xff0c;而真正需要的是让地块覆盖…

Python在线编辑器

from flask import Flask, render_template, request, jsonify import sys from io import StringIO import contextlib import subprocess import importlib import threading import time import ast import reapp Flask(__name__)RESTRICTED_PACKAGES {tkinter: 抱歉&…

力扣动态规划-20【算法学习day.114】

前言 ###我做这类文章一个重要的目的还是记录自己的学习过程&#xff0c;我的解析也不会做的非常详细&#xff0c;只会提供思路和一些关键点&#xff0c;力扣上的大佬们的题解质量是非常非常高滴&#xff01;&#xff01;&#xff01; 习题 1.网格中的最小路径代价 题目链接…

从通讯工具到 AI 助理,AI手机如何发展?

随着AI进军各行各业&#xff0c;全面AI化时代已经到来。手机&#xff0c;作为现代人类的“数字器官”之一&#xff0c;更是首当其冲地融入了这一变革浪潮之中。 2024年年初&#xff0c;OPPO联合IDC发布了《AI手机白皮书》&#xff0c;公布OPPO已迈向AI手机这一全新阶段。到如今…

游戏引擎 Unity - Unity 打开项目、Unity Editor 添加简体中文语言包模块、Unity 项目设置为简体中文

Unity Unity 首次发布于 2005 年&#xff0c;属于 Unity Technologies Unity 使用的开发技术有&#xff1a;C# Unity 的适用平台&#xff1a;PC、主机、移动设备、VR / AR、Web 等 Unity 的适用领域&#xff1a;开发中等画质中小型项目 Unity 适合初学者或需要快速上手的开…