对抗式生成模仿学习(GAIL)

目录

1 预先基础知识 

1.1 对抗生成网络(GAN)

1.1.1 基本概念

1.1.2 损失函数

1.1.2.1 固定G,求解令损失函数最大的D

1.1.2.2 固定D,求解令损失函数最小的G

1.2 对抗式生成模仿学习特点

2 对抗式生成模仿学习(GAIL)详细说明

3 参考文献

1 预先基础知识 

1.1 对抗生成网络(GAN)

1.1.1 基本概念

在GAN生成对抗网络中,包含两个模型,一个生成模型,一个判别模型。

  • 生成模型:负责生成看起来真实自然,和原始数据相似的实例。
  • 判别模型:负责判断给出的实例是真实的还是人为伪造的。

生成模型努力去欺骗判别模型,判别模型努力不被欺骗,这样两种模型交替优化训练,都得到了提升。

对于辨别器,如果得到的是生成图片辨别器应该输出0,如果是真实的图片应该输出 1,得到误差梯度反向传播来更新参数。对于生成器,首先由生成器生成一张图片,然后输入给判别器判别并的到相应的误差梯度,然后反向传播这些图片梯度成为组成生成器的权重。直观上来说就是:辨别器不得不告诉生成器如何调整从而使它生成的图片变得更加真实。

1.1.2 损失函数

GAN模型的目标函数:

其中,参考GAN的架构图,字母 V是原始GAN论文中指定用来表示该交叉熵的字母,x 表示任意真实数据,z 表示与真实数据相同结构的任意随机数据,G(z)表示在生成器中基于 z 生成的假数据,而D(x)表示判别器在真实数据 x上判断出的结果,D(G(z))表示判别器在假数据 G(z)上判断出的结果,其中 D(x) 与D(G(z))都是样本为“真”的概率,即标签为1的概率。

上式,主要意思是先固定生成器G,从判别器D的角度令损失最大化,紧接着固定D,从生成器G的角度令损失最小化,即可让判别器和生成器在共享损失的情况下实现对抗。其中第一个期望\mathbb{E}_{x \sim p_{\text{data}}(x)} \left[ \log D(x) \right]是所有x都是真实数据时(log(D(x)))的期望,第二个期望\mathbb{E}_{z \sim p(z)} \left[ \log (1 - D(G(z))) \right]是所有数据都是生成数据时log(1-D(G(z)))的期望。可以看出,在求解最优解的过程中存在两个过程:

  • 固定G,求解令损失函数最大的D
  • 固定D,求解令损失函数最小的G

判别网络是一个2分类,目标是分清真实数据和伪造数据,也就是希望D(x) 趋近于1,D(G(z))趋近于0,这也就体现了对抗的思想。G网络的loss是log(1-D(G(z))),D的loss是-(log(D(x)))+log(1-D(G(z)))。

1.1.2.1 固定G,求解令损失函数最大的D

判别器D的输入x有两部分:一部分是真实数据,设其分布为P_{\text{data}}(x);另一部分是生成器生成的数据,参考架构图,生成器接收的数据z服从分布P(z),A输入z经过生成器的计算生成的数据分布设为P_{G}(x)

这两部分这两部分都是判别器D的输入,不同的是,G的输出来自分布P_{G}(x),而真实数据来自分布P_{\text{data}}(x),经过一系列推导后的结果:

可以看出,固定G,将最优的D带入后,此时V(G,D*),实际上是在度量P_{\text{data}}(x)P_{G}(x)之间的JS散度,同KL散度一样,他们之间的分布差异越大,JS散度值也越大。换句话说:保持G不变,最大化V(G,D)就等价于计算JS散度。对于判别器来说,尽可能找出生成器生成的数据与真实数据分布之间的差异,这个差异就是JS散度。

1.1.2.2 固定D,求解令损失函数最小的G

对于生成器来说,让生成器生成的数据分布接近真实数据分布。现在第一步已经求出了最优解的D*,代入损失函数:

在最小化JS散度,JS散度越小,分部之间的差异越小,正好印证了第二个原则。

1.2 对抗式生成模仿学习特点

逆强化学习(Inverse Reinforcement Learning, IRL)作为一种典型的模仿学习方法,顾名思义,逆强化学习的学习过程与正常的强化学习利用奖励函数学习策略相反,不利用现有的奖励函数,而是试图学出一个奖励函数,并以之指导基于奖励函数的强化学习过程。IRL可以归结为解决从观察到的最优行为中提取奖励函数( Reward Function)的问题,这些最优行为也可以表示为专家策略 。基于IRL的方法交替地在两个过程中交替:一个阶段是使用示范数据来推断一个隐藏的奖励(Reward)或代价( Cost)函数,另一个阶段是使用强化学习基于推断的奖励函数来学习一个模仿策略。IRL的基本准则是:IRL选择奖励函数来优化策略,并且使得任何不同于\Pi _{E}的动作决策尽可能产生更大损失。

对抗式生成模仿学习(Generative Adversarial Imitation Learning,GAIL)是逆强化学习的一种重要实现方法之一。逆强化学习旨在从专家示范的行为中推断环境的奖励函数或者价值函数,而GAIL是逆强化学习的一种实现方式,它利用了生成对抗网络(GAN)的概念来进行模仿学习。

GAIL的关键点在于:

1生成对抗网络: GAIL使用生成对抗网络的框架,其中包括生成器和判别器。

2生成器与判别器: 生成器尝试生成与专家示范行为相似的状态-动作对,而判别器则尝试区分专家行为和生成器生成的行为。

3对抗优化: GAIL使用对抗训练的思想,通过生成器和判别器的对抗优化来使得生成器的输出逼近专家的行为。

GAIL的工作方式使得它在逆强化学习中发挥着重要作用,因为它提供了一种有效的方式来从专家示范中学习环境的奖励结构,以指导智能体的学习行为。通过对抗式生成模仿学习,智能体可以学习并模仿专家的行为,而无需显式地使用环境的奖励信号。

因此,GAIL作为逆强化学习的一种方法,为从专家示范中学习环境的奖励函数或者价值函数提供了一种有效的框架和方法。

2 对抗式生成模仿学习(GAIL)详细说明

 

生成式对抗模仿学习的整体优化流程如图所示。通过 GAIL 方法,策略生成器通过生成类似专家示教样本的探索样本,泛化示教样本的概率分布, 逼近专家示范行为数据,进而实现模仿专家技能的目的。该过程直接优化采样样本的概率分布,计算代价较小且算法通用性更强,实际模仿效果也更好。 

伪代码:

# 初始化策略 π、判别器 D、专家示范数据 D_expert、策略缓冲区 D_policy函数 GAIL_Training():初始化策略 π 的参数初始化判别器 D 的参数循环 直到收敛 或 达到最大迭代次数:# 使用当前策略 π 生成轨迹并存储在策略缓冲区 D_policy 中生成 trajectories 使用 π 并存储在 D_policy 中# 判别器训练循环 discriminator_updates 次数:# 从策略缓冲区 D_policy 中采样数据采样 (s_policy, a_policy) 从 D_policy 中# 从专家示范数据 D_expert 中采样数据采样 (s_expert, a_expert) 从 D_expert 中# 更新判别器 D计算 L_D = -[log(D(s_expert, a_expert)) + log(1 - D(s_policy, a_policy))]使用梯度下降法更新判别器参数以最小化 L_D# 策略更新采样 (s, a, ...) 从 D_policy 中计算伪奖励 r = -log(1 - D(s, a))# 使用伪奖励 r 更新策略 π计算 L_π 使用 PPO 或 其他强化学习方法使用梯度下降法更新策略 π 的参数以最大化 L_π

能够表征GAIL流程的主程序如下: 

import torch
from torch import nn
import torch.nn.functional as F
from torch.optim import Adamfrom .ppo import PPO
from gail_airl_ppo.network import GAILDiscrimclass GAIL(PPO):def __init__(self, buffer_exp, state_shape, action_shape, device, seed,gamma=0.995, rollout_length=50000, mix_buffer=1,batch_size=64, lr_actor=3e-4, lr_critic=3e-4, lr_disc=3e-4,units_actor=(64, 64), units_critic=(64, 64),units_disc=(100, 100), epoch_ppo=50, epoch_disc=10,clip_eps=0.2, lambd=0.97, coef_ent=0.0, max_grad_norm=10.0):super().__init__(state_shape, action_shape, device, seed, gamma, rollout_length,mix_buffer, lr_actor, lr_critic, units_actor, units_critic,epoch_ppo, clip_eps, lambd, coef_ent, max_grad_norm)# Expert's buffer.self.buffer_exp = buffer_exp# Discriminator.self.disc = GAILDiscrim(state_shape=state_shape,action_shape=action_shape,hidden_units=units_disc,hidden_activation=nn.Tanh()).to(device)self.learning_steps_disc = 0self.optim_disc = Adam(self.disc.parameters(), lr=lr_disc)self.batch_size = batch_sizeself.epoch_disc = epoch_discdef update(self, writer):self.learning_steps += 1for _ in range(self.epoch_disc):self.learning_steps_disc += 1# Samples from current policy's trajectories.states, actions = self.buffer.sample(self.batch_size)[:2]# Samples from expert's demonstrations.states_exp, actions_exp = \self.buffer_exp.sample(self.batch_size)[:2]# Update discriminator.self.update_disc(states, actions, states_exp, actions_exp, writer)# We don't use reward signals here,states, actions, _, dones, log_pis, next_states = self.buffer.get()# Calculate rewards.rewards = self.disc.calculate_reward(states, actions)# Update PPO using estimated rewards.self.update_ppo(states, actions, rewards, dones, log_pis, next_states, writer)def update_disc(self, states, actions, states_exp, actions_exp, writer):# Output of discriminator is (-inf, inf), not [0, 1].logits_pi = self.disc(states, actions)logits_exp = self.disc(states_exp, actions_exp)# Discriminator is to maximize E_{\pi} [log(1 - D)] + E_{exp} [log(D)].loss_pi = -F.logsigmoid(-logits_pi).mean()loss_exp = -F.logsigmoid(logits_exp).mean()loss_disc = loss_pi + loss_expself.optim_disc.zero_grad()loss_disc.backward()self.optim_disc.step()if self.learning_steps_disc % self.epoch_disc == 0:writer.add_scalar('loss/disc', loss_disc.item(), self.learning_steps)# Discriminator's accuracies.with torch.no_grad():acc_pi = (logits_pi < 0).float().mean().item()acc_exp = (logits_exp > 0).float().mean().item()writer.add_scalar('stats/acc_pi', acc_pi, self.learning_steps)writer.add_scalar('stats/acc_exp', acc_exp, self.learning_steps)

3 参考文献

https://zhuanlan.zhihu.com/p/628915533

【强化学习】GAIL_gail算法-CSDN博客

代码:https://github.com/toshikwa/gail-airl-ppo.pytorch.git

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/852890.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【CS.PL】Lua 编程之道: 简介与环境设置 - 进度8%

1 初级阶段 —— 简介与环境设置 文章目录 1 初级阶段 —— 简介与环境设置1.1 什么是 Lua&#xff1f;特点?1.2 Lua 的应用领域1.3 安装 Lua 解释器1.3.1 安装1.3.2 Lua解释器的结构 1.4 Lua执行方式1.4.0 程序段1.4.1 使用 Lua REPL&#xff08;Read-Eval-Print Loop&#x…

【Gradio】Building With Blocks 控制布局

默认情况下&#xff0c;块中的组件垂直排列。让我们看看如何重新排列组件。在底层&#xff0c;这种布局结构使用了网页开发的 flexbox 模型。 行 在 with gr.Row 子句中的元素将全部水平显示。例如&#xff0c;要并排显示两个按钮&#xff1a; with gr.Blocks() as demo:with g…

JAVA云HIS医院管理系统源码 云HIS运维平台源码 SaaS模式支撑电子病历4级,HIS与电子病历系统均拥有自主知识产权

JAVA云HIS医院管理系统源码 云HIS运维平台源码 SaaS模式支撑电子病历4级&#xff0c;HIS与电子病历系统均拥有自主知识产权 系统简介&#xff1a; SaaS模式Java版云HIS系统&#xff0c;在公立二甲医院应用三年&#xff0c;经过多年持续优化和打磨&#xff0c;系统运行稳定、功…

SSH密钥认证:实现远程服务器免密登录的两种方法|Linux scp命令详解:高效实现文件与目录的远程传输

简介&#xff1a; 服务器之间经常需要有一些跨服务器的操作&#xff0c;此时就需要我们在一台服务器上登录到另外一台服务器&#xff0c;若是人为操作时我们都可以每次输入密码进行远程登录&#xff0c;但要是程序需要跨服务器时&#xff0c;每次输入密码就不现实了&#xff0c…

最新Sublime Text软件安装包分享(汉化版本)

Sublime Text 是一款广受欢迎的跨平台文本编辑器&#xff0c;专为代码、标记和散文编辑而设计。它以其简洁的用户界面、强大的功能和高性能而著称&#xff0c;深受开发者和写作者的喜爱。 一、下载地址 链接&#xff1a;https://pan.baidu.com/s/1kErSkvc7WnML7fljQZlcOg?pwdk…

解决:安装MySQL 5.7 的时候报错:unknown variable ‘mysqlx_port=0.0‘

目录 1. 背景2. 解决步骤 1. 背景 吐槽1&#xff0c;没被收购之前可以随便下载&#xff0c;现在下载要注册登录吐槽2&#xff0c;5.7安装到初始化数据库的时候就会报错&#xff0c;而8.x的可以一镜到底&#xff0c;一开始以为是国区的特色问题&#xff0c;google了一圈&#x…

[Algorithm][贪心][最长递增子序列][递增的三元子序列][最长连续递增序列][买卖股票的最佳时机][买卖股票的最佳时机Ⅱ]详细讲解

目录 1.最长递增子序列1.题目链接2.算法原理详解3.代码实现 2.递增的三元子序列1.题目链接2.算法原理详解3.题目链接 3.最长连续递增序列1.题目链接2.算法原理详解3.代码实现 4.买卖股票的最佳时机1.题目链接2.算法原理详解3.代码实现 5.买卖股票的最佳时机 II1.题目链接2.算法…

厂里资讯之总体架构介绍以及环境搭建

本项目是本人根据黑马程序员的微服务项目黑马头条进行包装改造&#xff0c;作为实习简历上面的项目&#xff0c;为了进一步熟悉深挖这个项目&#xff0c;写了这一系列的博客来加深自己对项目的理解。 概述 项目背景 本项目主要着手于使用户获取学校最新最热的资讯&#xff0c…

使用 ML.NET CLI 自动进行模型训练

ML.NET CLI 可为 .NET 开发人员自动生成模型。 若要单独使用 ML.NET API(不使用 ML.NET AutoML CLI),需要选择训练程序(针对特定任务的机器学习算法的实现),以及要应用到数据的数据转换集(特征工程)。 每个数据集的最佳管道各不相同,从所有选择中选择最佳算法增加了复…

seata原理源码分析系列(一)架构, 组件

简介 SEATA开源的分布式事务解决方案&#xff0c;用于解决分布式系统中的数据一致性问题&#xff0c;由阿里巴巴开源。 分布式系统&#xff0c;数据存储在不同的资源管理器(数据库)&#xff0c;需要保证分布式事务的原子性&#xff0c;业界比较常用xa&#xff0c;数据库标准实现…

C语言 | Leetcode C语言题解之第151题反转字符串中的单词

题目&#xff1a; 题解&#xff1a; void myResverse(char* s,int start,int end){while(start<end){char temp s[start];s[start] s[end];s[end] temp;start;end--;} } char* reverseWords(char* s) {int start 0;int end strlen(s)-1;myResverse(s,start,end);if(s[…

面试题:Redis是什么?有什么作用?怎么测试?

有些测试朋友来问我&#xff0c;redis要怎么测试&#xff1f;首先我们需要知道&#xff0c;redis是什么&#xff1f;它能做什么&#xff1f; redis是一个key-value类型的高速存储数据库。 redis常被用做&#xff1a;缓存、队列、发布订阅等。 所以&#xff0c;“redis要怎么测试…

Linux系统使用Docker安装Dashy导航页结合内网穿透一键发布公网

文章目录 简介1. 安装Dashy2. 安装cpolar3.配置公网访问地址4. 固定域名访问 简介 Dashy 是一个开源的自托管的导航页配置服务&#xff0c;具有易于使用的可视化编辑器、状态检查、小工具和主题等功能。你可以将自己常用的一些网站聚合起来放在一起&#xff0c;形成自己的导航…

机器视觉:工业镜头的主要参数

工业镜头是图像采集系统的重要光学设备。它的作用是将目标物体的像成在相机的感光面上。 一、工业镜头原理 镜头是对光线进行调制和变换&#xff0c;使目标能够成像到相机的感光芯片上。将不同折射率的硝材加工成高精度的曲面&#xff0c;再把这些曲面进行组合后设计成能够满…

秋招突击——6/14——复习{(树形DP)树的最长路径}——新作{非递归求二叉树的深度、重复区间合并}

文章目录 引言复习树形DP——树的最长路径 新作使用dfs非递归计算二叉树的深度多个区间合并删除问题实现思路实现代码参考思路 总结 引言 这两天可能有点波动&#xff0c;但是算法题还是尽量保证复习和新作一块弄&#xff0c;数量上可能有所差别。 复习 树形DP——树的最长路…

React state(及组件) 的保留与重置

当在树中相同的位置渲染相同的组件时&#xff0c;React 会一直保留着组件的 state return (<div><Counter />{showB && <Counter />} </div> ) // 当 showB 为 false, 第二个计数器停止渲染&#xff0c;它的 state 完全消失了。这是因为 React…

vite.config.js如何使用env的环境变量

了解下环境变量在vite中 官方文档走起 https://cn.vitejs.dev/guide/env-and-mode.html#env-variables-and-modes 你见到的.env,.env.production等就是放置环境变量的 官方文档说到.env.[mode] # 只在指定模式下加载,比如.env.development只在开发环境加载 至于为什么是deve…

windows下open webui+ollama+sd webui

原文&#xff1a;https://wangguo.site/Blog/2024/Q2/2024-06-14/ 说明&#xff1a;安装使用环境是在Windows下 1、给ollama一个好看的交互界面&#xff08;open webui&#xff09; 1.1、ollama安装 安装&#xff1a;在ollama官网下载windows版本进行安装 模型列表&#xff1…

【SQLAlChemy】表之间的关系,外键如何使用?

表之间的关系 数据库表之间的关系分为三种&#xff1a; 一对一关系&#xff08;One-to-One&#xff09;&#xff1a;在这种关系中&#xff0c;表A的每一行都与表B的一行关联&#xff0c;反之亦然。例如&#xff0c;每个人都有一个唯一的社保号&#xff0c;每个社保号也只属于…

南师大GIS专业2024排名NO.1!!!

南师大GIS 666 学科专业实力666&#xff0c;研究方向多多多&#xff01; 有学术方向有开发应用方向&#xff0c; 有GIS&#xff08;建模、数字地形、基础理论和三维GIS等&#xff09;、 有Cartography &#xff08;叙事地图、动态地图、地图风格迁移等&#…