wordpress转静态页面/开封搜索引擎优化

wordpress转静态页面,开封搜索引擎优化,微信 购物网站开发,网站规划与建设论文一、整体概述 此代码利用 REINFORCE 算法(一种基于策略梯度的强化学习算法)来解决 OpenAI Gym 中的 CartPole-v1 环境问题。CartPole-v1 环境的任务是控制一个小车,使连接在小车上的杆子保持平衡。代码通过构建一个神经网络作为策略网络&…

一、整体概述

此代码利用 REINFORCE 算法(一种基于策略梯度的强化学习算法)来解决 OpenAI Gym 中的 CartPole-v1 环境问题。CartPole-v1 环境的任务是控制一个小车,使连接在小车上的杆子保持平衡。代码通过构建一个神经网络作为策略网络,在与环境的交互中不断学习,以找到能获得最大累计奖励的策略。

二、依赖库

  1. gym:OpenAI 开发的强化学习环境库,用于创建和管理各种强化学习任务的环境,这里使用其 CartPole-v1 环境。
  2. torch:PyTorch 深度学习框架,用于构建神经网络模型、进行张量运算和自动求导。
  3. torch.nn:PyTorch 中用于定义神经网络层和模型结构的模块。
  4. torch.optim:PyTorch 中的优化器模块,用于更新神经网络的参数。
  5. numpy:用于进行数值计算和数组操作。
  6. torch.distributions.Categorical:PyTorch 中用于处理分类分布的模块,用于从策略网络输出的动作概率分布中采样动作。
  7. matplotlib.pyplot:用于绘制训练过程中的奖励曲线,可视化训练进度。

三、代码详细解释

3.1 REINFORCE 专用策略网络类 REINFORCEPolicy

收起

python

class REINFORCEPolicy(nn.Module):def __init__(self, input_dim, output_dim):super().__init__()self.net = nn.Sequential(nn.Linear(input_dim, 64),nn.ReLU(),nn.Linear(64, output_dim))def forward(self, x):return self.net(x)

  • 功能:定义一个简单的前馈神经网络作为策略网络,用于根据环境状态输出动作的概率分布。
  • 参数
    • input_dim:输入状态的维度,即环境状态的特征数量。
    • output_dim:输出的维度,即环境中可用动作的数量。
  • 结构
    • 包含两个全连接层(nn.Linear),中间使用 ReLU 激活函数(nn.ReLU)引入非线性。
    • 第一个全连接层将输入维度映射到 64 维,第二个全连接层将 64 维映射到输出维度。
  • 前向传播方法 forward:接收输入状态 x,并通过定义的网络层计算输出。

3.2 REINFORCE 训练函数 reinforce_train

收起

python

def reinforce_train(env, policy_net, optimizer, num_episodes=1000, gamma=0.99, lr_decay=0.995, baseline=True):...

  • 功能:使用 REINFORCE 算法训练策略网络。
  • 参数
    • env:OpenAI Gym 环境对象。
    • policy_net:策略网络模型。
    • optimizer:用于更新策略网络参数的优化器。
    • num_episodes:训练的总回合数,默认为 1000。
    • gamma:折扣因子,用于计算未来奖励的折扣,默认为 0.99。
    • lr_decay:学习率衰减因子,默认为 0.995。
    • baseline:布尔值,指示是否使用基线来降低方差,默认为 True
  • 训练流程
    1. 数据收集阶段
      • 每个回合开始时,重置环境状态。
      • 在回合中,不断与环境交互,直到回合结束。
      • 对于每个时间步,将状态转换为张量,通过策略网络得到动作的 logits,使用 Categorical 分布采样动作,并记录动作的对数概率。
      • 执行动作,获取下一个状态、奖励和回合是否结束的信息。
      • 将状态、动作、奖励和对数概率存储在 episode_data 字典中。
    2. 计算蒙特卡洛回报
      • 从最后一个时间步开始,反向计算每个时间步的累计折扣奖励 G
      • 将计算得到的回报存储在 returns 列表中,并转换为张量。
    3. 可选基线处理
      • 如果 baseline 为 True,对回报进行标准化处理,以降低方差。
    4. 计算策略梯度损失
      • 对于每个时间步的对数概率和回报,计算策略梯度损失。
      • 将所有时间步的损失相加得到总损失。
    5. 参数更新
      • 清零优化器的梯度。
      • 进行反向传播计算梯度。
      • 使用优化器更新策略网络的参数。
    6. 学习率衰减
      • 如果 lr_decay 不为 None,每 100 个回合衰减一次学习率。
    7. 记录训练进度
      • 记录每个回合的总奖励。
      • 每 50 个回合输出一次平均奖励和当前学习率。
      • 如果平均奖励达到环境的奖励阈值,输出解决信息并提前结束训练。

3.3 主程序部分

收起

python

if __name__ == "__main__":...

  • 功能:创建环境,初始化策略网络和优化器,进行训练,保存模型,并可视化训练进度。
  • 步骤
    1. 创建环境:使用 gym.make 创建 CartPole-v1 环境,并获取状态维度和动作维度。
    2. 初始化网络和优化器
      • 创建 REINFORCEPolicy 策略网络实例。
      • 使用 Adam 优化器,设置较高的初始学习率(lr = 1e-2)。
    3. 训练模型:调用 reinforce_train 函数进行训练,设置训练回合数为 800。
    4. 保存模型:使用 torch.save 保存训练好的策略网络的参数。
    5. 可视化训练进度
      • 使用 matplotlib.pyplot 绘制每个回合的总奖励曲线。
      • 设置 x 轴标签为 “回合数”,y 轴标签为 “总奖励”,标题为 “REINFORCE 训练进度”。
      • 显示绘制的图形。

四、注意事项

  • 代码使用了新版 Gym API,确保你的 Gym 库版本支持 env.reset() 和 env.step() 的返回值格式。
  • 可以根据实际情况调整超参数,如 num_episodesgammalr_decay 和初始学习率,以获得更好的训练效果。
  • 训练可能需要一定的时间,尤其是在计算资源有限的情况下,可以适当减少 num_episodes 来加快训练速度。

完整代码

import gym
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.distributions import Categorical
import matplotlib.pyplot as plt# REINFORCE专用策略网络
class REINFORCEPolicy(nn.Module):def __init__(self, input_dim, output_dim):super().__init__()self.net = nn.Sequential(nn.Linear(input_dim, 64),nn.ReLU(),nn.Linear(64, output_dim))def forward(self, x):return self.net(x)def reinforce_train(env, policy_net, optimizer, num_episodes=1000, gamma=0.99, lr_decay=0.995, baseline=True):rewards_history = []lr = optimizer.param_groups[0]['lr']for ep in range(num_episodes):# 数据收集阶段state, _ = env.reset()episode_data = {'states': [], 'actions': [], 'rewards': [], 'log_probs': []}done = Falsewhile not done:state_tensor = torch.FloatTensor(state)logits = policy_net(state_tensor)policy = Categorical(logits=logits)action = policy.sample()log_prob = policy.log_prob(action)next_state, reward, terminated, truncated, _ = env.step(action.item())done = terminated or truncated# 存储轨迹数据episode_data['states'].append(state_tensor)episode_data['actions'].append(action)episode_data['rewards'].append(reward)episode_data['log_probs'].append(log_prob)state = next_state# 计算蒙特卡洛回报returns = []G = 0for r in reversed(episode_data['rewards']):G = r + gamma * Greturns.insert(0, G)returns = torch.tensor(returns)# 可选基线(降低方差)if baseline:returns = (returns - returns.mean()) / (returns.std() + 1e-8)# 计算策略梯度损失policy_loss = []for log_prob, G in zip(episode_data['log_probs'], returns):policy_loss.append(-log_prob * G)total_loss = torch.stack(policy_loss).sum()  # 使用 torch.stack() 代替 torch.cat()# 参数更新optimizer.zero_grad()total_loss.backward()optimizer.step()# 学习率衰减if lr_decay:new_lr = lr * (0.99 ** (ep//100))optimizer.param_groups[0]['lr'] = new_lr# 记录训练进度total_reward = sum(episode_data['rewards'])rewards_history.append(total_reward)# 进度输出if (ep+1) % 50 == 0:avg_reward = np.mean(rewards_history[-50:])print(f"Episode {ep+1} | Avg Reward: {avg_reward:.1f} | LR: {optimizer.param_groups[0]['lr']:.2e}")if avg_reward >= env.spec.reward_threshold:print(f"Solved at episode {ep+1}!")breakreturn rewards_historyif __name__ == "__main__":env = gym.make('CartPole-v1')state_dim = env.observation_space.shape[0]action_dim = env.action_space.n# 初始化REINFORCE专用网络policy_net = REINFORCEPolicy(state_dim, action_dim)optimizer = optim.Adam(policy_net.parameters(), lr=1e-2)  # 更高初始学习率# 训练rewards = reinforce_train(env, policy_net, optimizer, num_episodes=800)# 保存与测试(同原代码)torch.save(policy_net.state_dict(), 'reinforce_cartpole.pth')plt.plot(rewards)plt.xlabel('Episode')plt.ylabel('Total Reward')plt.title('REINFORCE Training Progress')plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/71204.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【iOS】小蓝书学习(七)

小蓝书学习(七) 前言第47条:熟悉系统框架第48条:多用枚举块,少用for循环第50条:构建缓存使选用NSCache而非NSDictionary第51条:精简initialize与load的实现代码第52条:别忘了NSTimer…

SyntaxError: positional argument follows keyword argument

命令行里面日常练手爬虫不注意遇到的问题,报错说参数位置不正确 修改代码后,运行如下图: 结果: 希望各位也能顺利解决问题,祝你好运!

drawDB:一款免费数据库设计工具

drawDB 是一款基于 Web 的免费数据库设计工具,通过拖拽、复制、粘贴等方式进行数据库建模设计,同时可以生成相应的 SQL 脚本。 功能特性 drawDB 目前可以支持 MySQL、MariaDB、PostgreSQL、SQL Server 以及 SQLite 数据库,核心功能包括&…

FPGA开发,使用Deepseek V3还是R1(9):FPGA的全流程(详细版)

以下都是Deepseek生成的答案 FPGA开发,使用Deepseek V3还是R1(1):应用场景 FPGA开发,使用Deepseek V3还是R1(2):V3和R1的区别 FPGA开发,使用Deepseek V3还是R1&#x…

Hive-05之查询 分组、排序、case when、 什么情况下Hive可以避免进行MapReduce

一、目标 掌握hive中select查询语句中的基本语法掌握hive中select查询语句的分组掌握hive中select查询语句中的join掌握hive中select查询语句中的排序 二、要点 1. 基本查询 注意 SQL 语言大小写不敏感SQL 可以写在一行或者多行关键字不能被缩写也不能分行各子句一般要分行…

人工智能之数学基础:矩阵的范数

本文重点 在前面课程中,我们学习了向量的范数,在矩阵中也有范数,本文来学习一下。矩阵的范数对于分析线性映射函数的特性有重要的作用。 矩阵范数的本质 矩阵范数是一种映射,它将一个矩阵映射到一个非负实数。 矩阵的范数 前面我们学习了向量的范数,只有当满足几个条…

I2C驱动(十一) -- gpio模拟的i2c总线驱动i2c-gpio.c分析

相关文章 I2C驱动(一) – I2C协议 I2C驱动(二) – SMBus协议 I2C驱动(三) – 驱动中的几个重要结构 I2C驱动(四) – I2C-Tools介绍 I2C驱动(五) – 通用驱动i2c-dev.c分析 I2C驱动(六) – I2C驱动程序模型 I2C驱动(七) – 编写I2C设备驱动之i2c_driver I2C驱动(八) – 编写I2C…

(KTransformers) RTX4090单卡运行 DeepSeek-R1 671B

安装环境为:ubuntu 22.04 x86_64 下载模型 编辑文件vim url.list 写入如下内容 https://modelscope.cn/models/unsloth/DeepSeek-R1-GGUF/resolve/master/DeepSeek-R1-Q4_K_M/DeepSeek-R1-Q4_K_M-00001-of-00009.gguf https://modelscope.cn/models/unsloth/Dee…

海康威视摄像头ISUP(原EHOME协议) 摄像头实时预览springboot 版本java实现,并可以在浏览器vue前端播放(附带源码)

1.首先说了一下为什么要用ISUP协议来取流 ISUP主要就是用来解决摄像头没有公网ip的情况,如果摄像头或者所在局域网的路由器有公网ip的话,其实采用rtsp直接取流是最方便也是性能最好的,但是项目的摄像头没有公网IP所以被迫使用ISUP,ISUP是海康…

SpringBoot原理-03.自动配置-方案

一.自动配置原理 探究自动配置原理,就是探究spring是如何在运行时将要依赖JAR包提供的配置类和bean对象注入到IOC容器当中。我们当前准备一个maven项目itheima-utils,这里面定义了bean对象以及配置类,用来模拟第三方提供的依赖,首…

高频 SQL 50 题(基础版)_2356. 每位教师所教授的科目种类的数量

高频 SQL 50 题(基础版)_2356. 每位教师所教授的科目种类的数量 select teacher_id ,count(distinct(subject_id)) as cnt from Teacher group by teacher_id

神经网络之词嵌入模型(基于torch api调用)

一、Word Embedding(词嵌入)简介 Word Embedding(词嵌入): 词嵌入技术是自然语言处理(NLP)领域的一项重大创新,它极大地推动了计算机理解和处理人类语言的能力。 通过将单词、句子甚…

SpringBoot @Value 注解使用

Value 注解用于将配置文件中的属性值注入到Spring管理的Bean中。 1. 基本用法 Value 可以直接注入配置文件中的属性值。 配置文件 (application.properties 或 application.yml) 配置文件定义需要注入的数据。 consumer:username: lisiage: 23hobby: sing,read,sleepsubje…

Redis面试常见问题——使用场景问题

目录 Redis面试常见问题 如果发生了缓存穿透、击穿、雪崩,该如何解决? 缓存穿透 什么是布隆过滤器? 缓存击穿 缓存雪崩 双写一致性(redis做为缓存,mysql的数据如何与redis进行同步呢?) …

在Ubuntu 22.04 LTS 上安装 MySQL两种方式:在线方式和离线方式

Ubuntu安装MySQL 介绍: Ubuntu 是一款基于Linux操作系统的免费开源发行版,广受欢迎。它以稳定性、安全性和用户友好性而闻名,适用于桌面和服务器环境。Ubuntu提供了大量的软件包和应用程序,拥有庞大的社区支持和活跃的开发者社区…

用Java编写sql

1.概念 通过Java代码操作mysql数据库 数据库编程,是需要数据库服务器,提供一些API,供程序员调用的 2.安装 2.1下载 在程序中操作mysql需要先安装mysql的驱动包 并且要把驱动包引入到项目中 在中央仓库可以下载到驱动包(mvnrepository.…

Redis数据结构-List列表

1.List列表 列表类型适用于存储多个有序的字符串(这里的有序指的是强调数据排列顺序的重要,不是升序降序的意思),列表中的每个字符串称为元素(element),一个列表最多可以存储2^32-1个元素。在R…

Linux实操——在服务器上直接从百度网盘下载(/上传)文件

Linux Linux实操——在服务器上直接从百度网盘下载(/上传)文件 文章目录 Linux前言一、下载并安装bypy工具二、认证并授权网盘账号三、将所需文件转移至目的文件夹下四、下载文件五、上传文件六、更换绑定的百度云盘账户 前言 最近收到一批很大的数据&…

题解 | 牛客周赛82 Java ABCDEF

目录 题目地址 做题情况 A 题 B 题 C 题 D 题 E 题 F 题 牛客竞赛主页 题目地址 牛客竞赛_ACM/NOI/CSP/CCPC/ICPC算法编程高难度练习赛_牛客竞赛OJ 做题情况 A 题 判断字符串第一个字符和第三个字符是否相等 import java.io.*; import java.math.*; import java.u…

基金 word-->pdf图片模糊的解决方法

1. 首先需要Adobe或福昕等pdf阅读器。 2. word中 [文件]--[打印],其中打印机选择pdf阅读器,例如此处我选择福昕阅读器。 3. 选择 [打印机属性]--[编辑]--[图像],将所有的采样、压缩均设置为 关闭。点击[另存为],保存为 基金报告…