强化学习原理python篇06——DQN

强化学习原理python篇05——DQN

  • DQN 算法
    • 定义DQN网络
    • 初始化环境
    • 开始训练
    • 可视化结果

本章全篇参考赵世钰老师的教材 Mathmatical-Foundation-of-Reinforcement-Learning Deep Q-learning 章节,请各位结合阅读,本合集只专注于数学概念的代码实现。

DQN 算法

1)使用随机权重 ( w ← 1.0 ) (w←1.0) w1.0初始化目标网络 Q ( s , a , w ) Q(s, a, w) Q(s,a,w)和网络 Q ^ ( s , a , w ) \hat Q(s, a, w) Q^(s,a,w) Q Q Q Q ^ \hat Q Q^相同,清空回放缓冲区。

2)以概率ε选择一个随机动作a,否则 a = a r g m a x Q ( s , a , w ) a=argmaxQ(s,a,w) a=argmaxQ(s,a,w)

3)在模拟器中执行动作a,观察奖励r和下一个状态s’。

4)将转移过程(s, a, r, s’)存储在回放缓冲区中。

5)从回放缓冲区中采样一个随机的小批量转移过程。

6)对于回放缓冲区中的每个转移过程,如果片段在此步结束,则计算目标 y = r y=r y=r,否则计算 y = r + γ m a x Q ^ ( s , a , w ) y=r+\gamma max \hat Q(s, a, w) y=r+γmaxQ^(s,a,w)

7)计算损失: L = ( Q ( s , a , w ) – y ) 2 L=(Q(s, a, w)–y)^2 L=(Q(s,a,w)y)2

8)固定网络 Q ^ ( s , a , w ) \hat Q(s, a, w) Q^(s,a,w)不变,通过最小化模型参数的损失,使用SGD算法更新 Q ( s , a ) Q(s, a) Q(s,a)

9)每N步,将权重从目标网络 Q Q Q复制到 Q ^ ( s , a , w ) \hat Q(s, a, w) Q^(s,a,w)

10)从步骤2开始重复,直到收敛为止。

定义DQN网络

import collections
import copy
import random
from collections import defaultdict
import math
import gym
import gym.spaces
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from gym.envs.toy_text import frozen_lake
from torch.utils.tensorboard import SummaryWriterclass Net(nn.Module):def __init__(self, obs_size, hidden_size, q_table_size):super(Net, self).__init__()self.net = nn.Sequential(# 输入为状态,样本为(1*n)nn.Linear(obs_size, hidden_size),nn.ReLU(),# nn.Linear(hidden_size, hidden_size),# nn.ReLU(),nn.Linear(hidden_size, q_table_size),)def forward(self, state):return self.net(state)class DQN:def __init__(self, env, tgt_net, net):self.env = envself.tgt_net = tgt_netself.net = netdef generate_train_data(self, batch_size, epsilon):state, _ = env.reset()train_data = []while len(train_data)<batch_size*2:q_table_tgt = self.tgt_net(torch.Tensor(state)).detach()if np.random.uniform(0, 1, 1) > epsilon:action = self.env.action_space.sample()else:action = int(torch.argmax(q_table_tgt))new_state, reward,terminated, truncted, info = env.step(action)train_data.append([state, action, reward, new_state, terminated])state = new_stateif terminated:state, _ = env.reset()continuerandom.shuffle(train_data)                return train_data[:batch_size]def calculate_y_hat_and_y(self, batch):# 6)对于回放缓冲区中的每个转移过程,如果片段在此步结束,则计算目标$y=r$,否则计算$y=r+\gamma max \hat Q(s, a, w)$ 。y = []state_space = []action_space = []for state, action, reward, new_state, terminated in batch:# y值if terminated:y.append(reward)else:# 下一步的 qtable 的最大值q_table_net = self.net(torch.Tensor(np.array([new_state]))).detach()y.append(reward + gamma * float(torch.max(q_table_net)))# y hat的值state_space.append(state)action_space.append(action)idx = [list(range(len(action_space))), action_space]y_hat = self.tgt_net(torch.Tensor(np.array(state_space)))[idx]return y_hat, torch.tensor(y)def update_net_parameters(self, update=True):self.net.load_state_dict(self.tgt_net.state_dict())

初始化环境

   # 初始化环境
env = gym.make("CartPole-v1")
# env = DiscreteOneHotWrapper(env)hidden_num = 64
# 定义网络
net = Net(env.observation_space.shape[0],hidden_num, env.action_space.n)
tgt_net = Net(env.observation_space.shape[0],hidden_num, env.action_space.n)
dqn = DQN(env=env, net=net, tgt_net=tgt_net)# 初始化参数
# dqn.init_net_and_target_net_weight()# 定义优化器
opt = optim.Adam(tgt_net.parameters(), lr=0.001)# 定义损失函数
loss = nn.MSELoss()# 记录训练过程
# writer = SummaryWriter(log_dir="logs/DQN", comment="DQN")

开始训练

gamma = 0.8
for i in range(10000):batch = dqn.generate_train_data(256, 0.8)y_hat, y = dqn.calculate_y_hat_and_y(batch)opt.zero_grad()l = loss(y_hat, y)l.backward()opt.step()print("MSE: {}".format(l.item()))if i % 5 == 0:dqn.update_net_parameters(update=True)

输出:

MSE: 0.027348674833774567
MSE: 0.1803671419620514
MSE: 0.06523636728525162
MSE: 0.08363766968250275
MSE: 0.062360599637031555
MSE: 0.004909628536552191
MSE: 0.05730309337377548
MSE: 0.03543371334671974
MSE: 0.08458714932203293

可视化结果

env = gym.make("CartPole-v1", render_mode = "human")
env = gym.wrappers.RecordVideo(env, video_folder="video")state, info = env.reset()
total_rewards = 0while True:q_table_state = dqn.tgt_net(torch.Tensor(state)).detach()# if np.random.uniform(0, 1, 1) > 0.9:#     action = env.action_space.sample()# else:action = int(torch.argmax(q_table_state))state, reward, terminated, truncted, info = env.step(action)if terminated:break

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/653255.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

[Python图像处理] 使用OpenCV创建深度图

使用OpenCV创建深度图 双目视觉创建深度图相关链接双目视觉 在传统的立体视觉中,两个摄像机彼此水平移动,用于获得场景上的两个不同视图(作为立体图像),就像人类的双目视觉系统: 通过比较这两个图像,可以以视差的形式获得相对深度信息,该视差编码对应图像点的水平坐标的…

基于Python 网络爬虫和可视化的房源信息的设计与实现

摘 要 一般来说&#xff0c;在房地产行业&#xff0c;房源信息采集&#xff0c;对企业来说至关重要&#xff0c;通过人工采集数据的方式进行数据收集&#xff0c;既耗时又费力&#xff0c;影响工作效率&#xff0c;还导致信息时效性变差&#xff0c;可靠性偏低&#xff0c;不利…

QWT开源库使用

源代码地址&#xff1a;Qwt Users Guide: Qwt - Qt Widgets for Technical Applications Qwt库包含GUI组件和实用程序类&#xff0c;它们主要用于具有技术背景的程序。除了2D图的框架外&#xff0c;它还提供刻度&#xff0c;滑块&#xff0c;刻度盘&#xff0c;指南针&#xf…

matlab appdesigner系列-仪器仪表4-旋钮(离散)

旋钮&#xff08;离散&#xff09;&#xff0c;或叫分档旋钮&#xff0c;跟旋钮的连续性相区别&#xff0c;呈分档性。 示例&#xff1a;模拟空调档位切换 操作步骤&#xff1a; 1&#xff09;将旋钮&#xff08;离散&#xff09;、信号灯、标签拖拽到画布上&#xff0c;并设…

CSS 星空按钮

<template><button class="btn" type="button"><strong>星空按钮</strong><div id="container-stars"><div id="stars"></div></div><div id="glow"><div class=…

Kafka-服务端-GroupMetadataManager

GroupMetadataManager是GroupCoordinator中负责管理Consumer Group元数据以及其对应offset信息的组件。 GroupMetadataManager底层使用Offsets Topic,以消息的形式存储Consumer Group的GroupMetadata信息以及其消费的每个分区的offset,如图所示。 consumer_offsets的某Partiti…

ffmpeg4.0.4 ffmpeg.c 讲解

ffmpeg.c 是 FFmpeg 中的一个核心文件&#xff0c;负责实现 FFmpeg 命令行工具的主要功能。这个文件包含了 FFmpeg 命令行工具的入口函数 main()&#xff0c;以及与命令行参数解析、多媒体处理、编解码、封装格式处理等相关的功能实现。 int main(int argc, char **argv) {int…

每日一题——LeetCode1365.有多少小于当前数字的数字

方法一 暴力循环 对于数组里的没一个元素都遍历一遍看有多少元素小于当前元素 var smallerNumbersThanCurrent function(nums) {let n nums.length;let ret [];for (let i 0; i < n; i) {let count 0;for (let j 0; j < n; j) {if (nums[j] < nums[i]) {count…

菜谱的未来:SpringBoot, Vue与MySQL的智能推荐系统设计

✍✍计算机编程指导师 ⭐⭐个人介绍&#xff1a;自己非常喜欢研究技术问题&#xff01;专业做Java、Python、微信小程序、安卓、大数据、爬虫、Golang、大屏等实战项目。 ⛽⛽实战项目&#xff1a;有源码或者技术上的问题欢迎在评论区一起讨论交流&#xff01; ⚡⚡ Java实战 |…

算法学习记录:动态规划

前言&#xff1a; 算法学习记录不是算法介绍&#xff0c;本文记录的是从零开始的学习过程&#xff08;见到的例题&#xff0c;代码的理解……&#xff09;&#xff0c;所有内容按学习顺序更新&#xff0c;而且不保证正确&#xff0c;如有错误&#xff0c;请帮助指出。 学习工具…

51单片机——电动车报警器

51单片机——电动车报警器 1.震动控制灯 硬件:震动传感器,51单片机 #include "reg52.h"sbit led1 P3^7;//根据原理图&#xff08;电路图&#xff09;&#xff0c;设备变量led1指向P3组IO口的第7口 sbit vibrate P3^3;//Do接到了P3.3口void Delay2000ms() //11.…

PINN物理信息网络 | 全局自适应物理信息神经网络SA-PINN

概述 本文提出的自适应加权方法在于权重适用于不同损失组件中的个别训练点,而不是整个损失组件。之前的方法可以被看作是这个方法的一个特例,当所有针对特定损失组件的自适应权重同时更新时。在之前的方法中,独立开发的极小极大加权方案[16]与SA-PINNs最为相近,因为它也通过…

Mac terminal/vi/vim 编译器 命令总结

一个程序员的自述&#xff1a; 纯纯的脚本编程&#xff0c;去工具化&#xff0c; 一个终端解决战斗&#xff0c; 乃我辈之云云尔。 你别管&#xff01;&#xff01; Mac terminal cd 切换路径ls 当前目录内容pwd 当前文件路径cp 复制 cp file.text /destinationmv 移动(或重命…

vue.js中如何使用动态组件。

使用场景&#xff1a; 在不同的情况下展示相应的组件。 在日常开发中&#xff0c;当我们考虑到要简化代码的情况下&#xff0c;我们要进行模块化&#xff0c;写很多组件&#xff0c;如何动态展示组件呢&#xff1f; 使用 <component is"" ></component>…

JavaScript浅拷贝和深拷贝

浅拷贝和深拷贝的区别 浅拷贝let a 10;let ba;a20console.log(b)//10 1&#xff0c;由于a和b基本类型并且都是在栈中的&#xff0c;它们分别进行保存&#xff0c;所以这里输出的b还是102&#xff0c;通过内存可以看出&#xff0c;它们的两个值是独立的&#xff0c;更改其中一…

CMake 完整入门教程(五)

CMake 使用实例 13.1 例子一 一个经典的 C 程序&#xff0c;如何用 cmake 来进行构建程序呢&#xff1f; //main.c #include <stdio.h> int main() { printf("Hello World!/n"); return 0; } 编写一个 CMakeList.txt 文件 ( 可看做 cmake 的…

深度学习之多分类问题

多分类问题&#xff1a; 我们在解决的时候会使用到一种叫做SoftMax的分类器。 前面我们在做糖尿病问题的时候&#xff0c;我们做出一个二分类网络&#xff0c;我们得到的是y1&#xff08;即一年后发病&#xff09;它的概率是多少&#xff0c;即P&#xff08;y1&#xff09;。这…

【Linux C | 网络编程】详细介绍 “三次握手(建立连接)、四次挥手(终止连接)、TCP状态”

&#x1f601;博客主页&#x1f601;&#xff1a;&#x1f680;https://blog.csdn.net/wkd_007&#x1f680; &#x1f911;博客内容&#x1f911;&#xff1a;&#x1f36d;嵌入式开发、Linux、C语言、C、数据结构、音视频&#x1f36d; &#x1f923;本文内容&#x1f923;&a…

Linux命令-apt-key命令(管理Debian Linux系统中的软件包密钥)

补充说明 apt-key命令 用于管理Debian Linux系统中的软件包密钥。每个发布的deb包&#xff0c;都是通过密钥认证 的&#xff0c;apt-key用来管理密钥。 语法 apt-key(参数)参数 操作指令&#xff1a;APT密钥操作指令。 实例 apt-key list # 列出已保存在系统中key。 apt-…

微服务架构的实现:选择最佳方案,构建未来的应用生态

目录 一、概述 1.1 什么是微服务架构&#xff1f; 1.2 微服务架构的优势和劣势 二、微服务架构的设计原则与最佳实践 2.1 单一职责原则 2.2 服务自治与自治团队 2.3. 松耦合与高内聚 2.4 服务边界的划分 2.5 服务间通信方式选择 2.6 高可用与容错设计 2.7 服务监控…