营销网站建设流程/免费网站推广网站短视频

营销网站建设流程,免费网站推广网站短视频,杭州做网站工作室,幕墙设计培训乡网站建设一、分层强化学习原理 1. 分层学习核心思想 分层强化学习(Hierarchical Reinforcement Learning, HRL)通过时间抽象和任务分解解决复杂长程任务。核心思想是: 对比维度传统强化学习分层强化学习策略结构单一策略直接输出动作高层策略选择选…

一、分层强化学习原理

1. 分层学习核心思想

分层强化学习(Hierarchical Reinforcement Learning, HRL)通过时间抽象任务分解解决复杂长程任务。核心思想是:

对比维度传统强化学习分层强化学习
策略结构单一策略直接输出动作高层策略选择选项(Option)
时间尺度单一步长决策高层策略决策跨度长,底层策略执行
适用场景简单短程任务复杂长程任务(如迷宫导航、机器人操控)
2. Option-Critic 算法框架

Option-Critic 是 HRL 的代表性算法,其核心组件包括:


二、Option-Critic 实现步骤(基于 Gymnasium)

我们将以 Meta-World 机械臂多阶段任务 为例,实现 Option-Critic 算法:

  1. 定义选项集合:包含 reach(接近目标)、grasp(抓取)、move(移动) 三个选项

  2. 构建策略网络:高层策略 + 选项内部策略 + 终止条件网络

  3. 分层交互训练:高层选择选项,底层执行多步动作

  4. 联合梯度更新:优化高层和底层策略


三、代码实现

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.distributions import Categorical, Normal
import gymnasium as gym
from metaworld.envs import ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE
import time
​
# ================== 配置参数优化 ==================
class OptionCriticConfig:num_options = 3                  # 选项数量(reach, grasp, move)option_length = 20               # 选项最大执行步长hidden_dim = 128                 # 网络隐藏层维度lr_high = 1e-4                   # 高层策略学习率lr_option = 3e-4                 # 选项策略学习率gamma = 0.99                     # 折扣因子entropy_weight = 0.01            # 熵正则化权重max_episodes = 5000              # 最大训练回合数device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
​
# ================== 高层策略网络 ==================
class HighLevelPolicy(nn.Module):def __init__(self, state_dim, num_options):super().__init__()self.net = nn.Sequential(nn.Linear(state_dim, OptionCriticConfig.hidden_dim),nn.ReLU(),nn.Linear(OptionCriticConfig.hidden_dim, num_options))def forward(self, state):return self.net(state)
​
# ================== 选项内部策略网络 ==================
class OptionPolicy(nn.Module):def __init__(self, state_dim, action_dim):super().__init__()self.net = nn.Sequential(nn.Linear(state_dim, OptionCriticConfig.hidden_dim),nn.ReLU(),nn.Linear(OptionCriticConfig.hidden_dim, action_dim))def forward(self, state):return self.net(state)
​
# ================== 终止条件网络 ==================
class TerminationNetwork(nn.Module):def __init__(self, state_dim):super().__init__()self.net = nn.Sequential(nn.Linear(state_dim, OptionCriticConfig.hidden_dim),nn.ReLU(),nn.Linear(OptionCriticConfig.hidden_dim, 1),nn.Sigmoid()  # 输出终止概率)def forward(self, state):return self.net(state)
​
# ================== 训练系统 ==================
class OptionCriticTrainer:def __init__(self):# 初始化环境self.env = ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE['pick-place-v2-goal-observable']()# 处理观测空间if isinstance(self.env.observation_space, gym.spaces.Dict):self.state_dim = sum([self.env.observation_space.spaces[key].shape[0] for key in ['observation', 'desired_goal']])self.process_state = self._process_dict_stateelse:self.state_dim = self.env.observation_space.shape[0]self.process_state = lambda x: xself.action_dim = self.env.action_space.shape[0]# 初始化网络self.high_policy = HighLevelPolicy(self.state_dim, OptionCriticConfig.num_options).to(OptionCriticConfig.device)self.option_policies = nn.ModuleList([OptionPolicy(self.state_dim, self.action_dim).to(OptionCriticConfig.device)for _ in range(OptionCriticConfig.num_options)])self.termination_networks = nn.ModuleList([TerminationNetwork(self.state_dim).to(OptionCriticConfig.device)for _ in range(OptionCriticConfig.num_options)])# 优化器self.optimizer_high = optim.Adam(self.high_policy.parameters(), lr=OptionCriticConfig.lr_high)self.optimizer_option = optim.Adam(list(self.option_policies.parameters()) + list(self.termination_networks.parameters()),lr=OptionCriticConfig.lr_option)def _process_dict_state(self, state_dict):return np.concatenate([state_dict['observation'], state_dict['desired_goal']])def select_option(self, state):state = torch.FloatTensor(state).to(OptionCriticConfig.device)logits = self.high_policy(state)dist = Categorical(logits=logits)option = dist.sample()return option.item(), dist.log_prob(option)def select_action(self, state, option):state = torch.FloatTensor(state).to(OptionCriticConfig.device)action_mean = self.option_policies[option](state)dist = Normal(action_mean, torch.ones_like(action_mean))  # 假设动作空间连续action = dist.sample()log_prob = dist.log_prob(action).sum(dim=-1)  # 沿最后一个维度求和得到标量return action.cpu().numpy(), log_prob  # 返回标量log概率def should_terminate(self, state, current_option):state = torch.FloatTensor(state).to(OptionCriticConfig.device)terminate_prob = self.termination_networks[current_option](state)return torch.bernoulli(terminate_prob).item() == 1def train(self):for episode in range(OptionCriticConfig.max_episodes):state_dict, _ = self.env.reset()state = self.process_state(state_dict)episode_reward = 0current_option, log_prob_high = self.select_option(state)option_step = 0while True:# 执行选项内部策略action, log_prob_option = self.select_action(state, current_option)next_state_dict, reward, terminated, truncated, _ = self.env.step(action)done = terminated or truncatednext_state = self.process_state(next_state_dict)episode_reward += reward# 判断是否终止选项terminate = self.should_terminate(next_state, current_option) or (option_step >= OptionCriticConfig.option_length)# 计算梯度if terminate or done:# 计算选项价值(添加detach防止梯度传递)with torch.no_grad():next_value = self.high_policy(torch.FloatTensor(next_state).to(OptionCriticConfig.device)).max().item()termination_output = self.termination_networks[current_option](torch.FloatTensor(state).to(OptionCriticConfig.device))# 计算delta时分离终止网络的梯度delta = reward + OptionCriticConfig.gamma * next_value - termination_output.detach()
​# 高层策略梯度计算loss_high = -log_prob_high * deltaself.optimizer_high.zero_grad()loss_high.backward(retain_graph=True)  # 保留计算图self.optimizer_high.step()
​# 选项策略梯度计算loss_option = -log_prob_option * deltaentropy = -log_prob_option * torch.exp(log_prob_option.detach())loss_option_total = loss_option + OptionCriticConfig.entropy_weight * entropyself.optimizer_option.zero_grad()loss_option_total.backward()  # 此时仍可访问保留的计算图self.optimizer_option.step()# 重置选项if not done:current_option, log_prob_high = self.select_option(next_state)option_step = 0else:breakelse:option_step += 1state = next_stateif (episode + 1) % 100 == 0:print(f"Episode {episode+1} | Reward: {episode_reward:.1f}")
​
if __name__ == "__main__":start = time.time()start_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(start))print(f"开始时间: {start_str}")print("初始化环境...")trainer = OptionCriticTrainer()trainer.train()end = time.time()end_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(end))print(f"训练完成时间: {end_str}")print(f"训练完成,耗时: {end - start:.2f}秒")

四、关键代码解析

  1. 高层策略选择选项

    select_option:基于当前状态选择选项,返回选项 ID 和选择概率的对数值。
  2. 选项内部策略执行

    select_action:根据当前选项生成动作,支持连续动作空间(使用高斯分布)。
  3. 终止条件判断

    should_terminate:根据终止网络输出概率判断是否终止当前选项。
  4. 梯度更新逻辑

    高层策略:基于选项的价值差(TD Error)更新。
    选项策略:结合 TD Error 和熵正则化更新。

五、训练输出示例

开始时间: 2025-03-24 08:29:46
初始化环境...
Episode 100 | Reward: 2.7
Episode 200 | Reward: 4.9
Episode 300 | Reward: 2.2
Episode 400 | Reward: 2.8
Episode 500 | Reward: 3.0
Episode 600 | Reward: 3.3
Episode 700 | Reward: 3.2
Episode 800 | Reward: 4.7
Episode 900 | Reward: 5.3
Episode 1000 | Reward: 7.5
Episode 1100 | Reward: 6.3
Episode 1200 | Reward: 3.7
Episode 1300 | Reward: 7.8
Episode 1400 | Reward: 3.8
Episode 1500 | Reward: 2.4
Episode 1600 | Reward: 2.3
Episode 1700 | Reward: 2.5
Episode 1800 | Reward: 2.7
Episode 1900 | Reward: 2.7
Episode 2000 | Reward: 3.9
Episode 2100 | Reward: 4.5
Episode 2200 | Reward: 4.1
Episode 2300 | Reward: 4.7
Episode 2400 | Reward: 4.0
Episode 2500 | Reward: 4.3
Episode 2600 | Reward: 3.8
Episode 2700 | Reward: 3.3
Episode 2800 | Reward: 4.6
Episode 2900 | Reward: 5.2
Episode 3000 | Reward: 7.7
Episode 3100 | Reward: 7.8
Episode 3200 | Reward: 3.3
Episode 3300 | Reward: 5.3
Episode 3400 | Reward: 4.5
Episode 3500 | Reward: 3.9
Episode 3600 | Reward: 4.1
Episode 3700 | Reward: 4.0
Episode 3800 | Reward: 5.2
Episode 3900 | Reward: 8.2
Episode 4000 | Reward: 2.2
Episode 4100 | Reward: 2.2
Episode 4200 | Reward: 2.2
Episode 4300 | Reward: 2.2
Episode 4400 | Reward: 6.9
Episode 4500 | Reward: 5.6
Episode 4600 | Reward: 2.0
Episode 4700 | Reward: 1.6
Episode 4800 | Reward: 1.7
Episode 4900 | Reward: 1.9
Episode 5000 | Reward: 3.1
训练完成时间: 2025-03-24 12:41:48
训练完成,耗时: 15122.31秒

在下一篇文章中,我们将探索 逆向强化学习(Inverse RL),并实现 GAIL 算法!


注意事项

  1. 安装依赖:

    pip install metaworld gymnasium torch
  2. Meta-World 需要 MuJoCo 许可证:

    export MUJOCO_PY_MUJOCO_PATH=/path/to/mujoco
  3. 训练时间较长(推荐 GPU 加速):

    CUDA_VISIBLE_DEVICES=0 python option_critic.py

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/898983.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Spring Boot整合Spring Data JPA

Spring Data作为Spring全家桶中重要的一员,在Spring项目全球使用市场份额排名中多次居前位,而在Spring Data子项目的使用份额排名中,Spring Data JPA也一直名列前茅。Spring Boot为Spring Data JPA提供了启动器,使Spring Data JPA…

Oracle归档配置及检查

配置归档位置到 USE_DB_RECOVERY_FILE_DEST,并设置存储大小 startup mount; !mkdir /db/archivelog ALTER SYSTEM SET db_recovery_file_dest_size100G SCOPEBOTH; ALTER SYSTEM SET db_recovery_file_dest/db/archivelog SCOPEBOTH; ALTER SYSTEM SET log_archive…

Four.meme是什么,一篇文章读懂

一、什么是Four.meme? Four.meme 是一个运行在 BNB 链的去中心化平台旨在为 meme 代币供公平启动服务。它允许用户以极低的成本创建和推出 meme 代币,无需预售或团队分配,它消除了传统的预售、种子轮和团队分配,确保所有参与者有…

Java 集合 List、Set、Map 区别与应用

一、核心特性对比 二、底层实现与典型差异 ‌List‌ ‌ArrayList‌:动态数组结构,随机访问快(O(1)),中间插入/删除效率低(O(n))‌‌LinkedList‌:双向链表结构,头尾操作…

欢迎来到未来:探索 Dify 开源大语言模型应用开发平台

欢迎来到未来:探索 Dify 开源大语言模型应用开发平台 如果你对 AI 世界有所耳闻,那么你一定听说过大语言模型(LLM)。这些智能巨兽能够生成文本、回答问题、甚至编写代码!但是,如何将它们变成真正的实用工具…

python多线程和多进程的区别有哪些

python多线程和多进程的区别有七种: 1、多线程可以共享全局变量,多进程不能。 2、多线程中,所有子线程的进程号相同;多进程中,不同的子进程进程号不同。 3、线程共享内存空间;进程的内存是独立的。 4、同一…

【MySQL报错】:Column count doesn’t match value count at row 1

MySQL报错:Column count doesn’t match value count at row 1 意思是存储的数据与数据库表的字段类型定义不相匹配. 由于类似 insert 语句中,前后列数不等造成的 主要有3个易错点: 要传入表中的字段数和values后面的值的个数不相等。 由于类…

PostgreSQL 连接数超限问题

目录标题 **PostgreSQL 连接数超限问题解决方案****一、错误原因分析****二、查看连接数与配置****三、排查连接泄漏(应用侧问题)****四、服务侧配置调整****1. 调整最大连接数****2. 释放无效连接(谨慎操作)****3. 使用连接池工具…

2025最新-智慧小区物业管理系统

目录 1. 项目概述 2. 技术栈 3. 功能模块 3.1 管理员端 3.1.1 核心业务处理模块 3.1.2 基础信息模块 3.1.3 数据统计分析模块 3.2 业主端 5. 系统架构 5.1 前端架构 5.2 后端架构 5.3 数据交互流程 6. 部署说明 6.1 环境要求 6.2 部署步骤 7. 使用说明 7.1 管…

智能汽车图像及视频处理方案,支持视频智能包装能力

美摄科技的智能汽车图像及视频处理方案,通过深度学习算法与先进的色彩管理技术,能够自动调整图像中的亮度、对比度、饱和度等关键参数,确保在各种光线条件下,图像都能呈现出最接近人眼的自然色彩与细节层次。这不仅提升了驾驶者的…

React - LineChart组件编写(用于查看每日流水图表)

一、简单版本 LineChart.tsx // src/component/LineChart/LineChart.tsx import React, {useEffect,useRef,useImperativeHandle,forwardRef,useMemo,useCallback, } from react; import * as echarts from echarts/core; import type { ComposeOption } from echarts/core; …

Web前端考核 JavaScript知识点详解

一、JavaScript 基础语法 1.1 变量声明 关键字作用域提升重复声明暂时性死区var函数级✅✅❌let块级❌❌✅const块级❌❌✅ 1.1.1变量提升的例子 在 JavaScript 中,var 声明的变量会存在变量提升的现象,而 let 和 const 则不会。变量提升是指变量的声…

使用 Go 构建 MCP Server

一个互联网技术玩家,一个爱聊技术的家伙。在工作和学习中不断思考,把这些思考总结出来,并分享,和大家一起交流进步。 一、MCP 介绍 1. 基本介绍 MCP(Model Context Protocol,模型上下文协议)是…

CES Asia 2025赛逸展:科技浪潮中的创新与商贸盛会

在科技发展日新月异的当下,CES Asia 2025第七届亚洲消费电子技术贸易展(赛逸展)正积极筹备,将在北京举办,有望成为亚洲消费电子领域极具影响力的年度盛会。作为亚洲科技领域的重要展会,此次得到了数十家电子…

Windows桌面采集技术

在进入具体的方式讨论前,我们先看看 Windows 桌面图形界面的简化架构,如下图: 在 Windows Vista 之前,Windows 界面的复合画面经由 Graphics Device Interface(以下简称 GDI)技术直接渲染到桌面上。 在 Wi…

ElementPlus 快速入门

目录 前言 为什么要学习 ElementPlus? 正文 步骤 1 创建 一个工程化的vue 项目 ​2 安装 element-Plus :Form 表单 | Element Plus 1 点击 当前界面的指南 2 点击左边菜单栏上的安装,选择包管理器 3 运行该命令 demo(案例1 ) 步骤 …

[蓝桥杯 2023 省 A] 异或和之和

题目来自洛谷网站&#xff1a; 暴力思路&#xff1a; 先进性预处理&#xff0c;找到每个点位置的前缀异或和&#xff0c;在枚举区间。 暴力代码&#xff1a; #include<bits/stdc.h> #define int long long using namespace std; const int N 1e520;int n; int arr[N…

python学习笔记--实现简单的爬虫(二)

任务&#xff1a;爬取B站上最爱欢迎的编程课程 网址&#xff1a;编程-哔哩哔哩_bilibili 打开网页的代码模块&#xff0c;如下图&#xff1a; 标题均位于class_"bili-video-card__info--tit"的h3标签中&#xff0c;下面通过代码来实现&#xff0c;需要说明的是URL中…

windows清除电脑开机密码,可保留原本的系统和资料,不重装系统

前言 很久的一台电脑没有使用了&#xff0c;开机密码忘了&#xff0c;进不去系统 方法 1.将一个闲置u盘设置成pe盘&#xff08;注意&#xff0c;这个操作会清空原来u盘的数据&#xff0c;需要在配置前将重要数据转移走&#xff0c;数据无价&#xff0c;别因为配置这个丢了重…

5.4 位运算专题:LeetCode 137. 只出现一次的数字 II

1. 题目链接 LeetCode 137. 只出现一次的数字 II 2. 题目描述 给定一个整数数组 nums&#xff0c;其中每个元素均出现 三次&#xff0c;除了一个元素只出现 一次。请找出这个只出现一次的元素。 要求&#xff1a; 时间复杂度为 O(n)&#xff0c;空间复杂度为 O(1)。 示例&a…