从PyTorch官方的一篇教程说开去(2 - 源码)

先上图,上篇文章的运行结果,可以看到,算法在迭代了200来次左右达到人生巅峰,倒立摆金枪不倒,可以扛住连续200次操作。不幸的是,然后就出现了大幅度的回撤,每况愈下,在600次时候居然和100次的时候一个水平。

事实上,训练充满了随机性,也不乏非常漂亮的曲线,可以用来tree new bee,这也是AI领域很好水论文的体现吧。

下面两个图,分别来自,windows11本地运行 vs Colab云端运行。

呃,这个就是为啥G家主推的深度学习目前应用场景窄,被openAI狠揍的核心原因了 -

1)只能处理离散模型,数据量要求极高;

2)模型通用性差,不同的场景需要定制算法;

虽然G家多次宣称“霸权”,但事实上这个技术栈确实不适合解决通用问题。

开箱即食,以下为代码(单一python文件,windows11 + python 3.11.9 + GTX1080显卡) - 

import gymnasium as gym
import math
import random
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple, deque
from itertools import countimport torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as Fenv = gym.make("CartPole-v1", render_mode="human")# set up matplotlib
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:from IPython import displayplt.ion()# if GPU is to be used
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
##device = torch.device("cuda")Transition = namedtuple('Transition',('state', 'action', 'next_state', 'reward'))class ReplayMemory(object):def __init__(self, capacity):self.memory = deque([], maxlen=capacity)def push(self, *args):"""Save a transition"""self.memory.append(Transition(*args))def sample(self, batch_size):return random.sample(self.memory, batch_size)def __len__(self):return len(self.memory)class DQN(nn.Module):def __init__(self, n_observations, n_actions):super(DQN, self).__init__()self.layer1 = nn.Linear(n_observations, 128)self.layer2 = nn.Linear(128, 128)self.layer3 = nn.Linear(128, n_actions)# Called with either one element to determine next action, or a batch# during optimization. Returns tensor([[left0exp,right0exp]...]).def forward(self, x):x = F.relu(self.layer1(x))x = F.relu(self.layer2(x))return self.layer3(x)# BATCH_SIZE is the number of transitions sampled from the replay buffer
# GAMMA is the discount factor as mentioned in the previous section
# EPS_START is the starting value of epsilon
# EPS_END is the final value of epsilon
# EPS_DECAY controls the rate of exponential decay of epsilon, higher means a slower decay
# TAU is the update rate of the target network
# LR is the learning rate of the ``AdamW`` optimizer
BATCH_SIZE = 128
GAMMA = 0.99
EPS_START = 0.9
EPS_END = 0.05
EPS_DECAY = 1000
TAU = 0.005
LR = 1e-4# Get number of actions from gym action space
n_actions = env.action_space.n
# Get the number of state observations
state, info = env.reset()
n_observations = len(state)policy_net = DQN(n_observations, n_actions).to(device)
target_net = DQN(n_observations, n_actions).to(device)
target_net.load_state_dict(policy_net.state_dict())optimizer = optim.AdamW(policy_net.parameters(), lr=LR, amsgrad=True)
memory = ReplayMemory(10000)steps_done = 0def select_action(state):global steps_donesample = random.random()eps_threshold = EPS_END + (EPS_START - EPS_END) * \math.exp(-1. * steps_done / EPS_DECAY)steps_done += 1if sample > eps_threshold:with torch.no_grad():# t.max(1) will return the largest column value of each row.# second column on max result is index of where max element was# found, so we pick action with the larger expected reward.return policy_net(state).max(1)[1].view(1, 1)else:return torch.tensor([[env.action_space.sample()]], device=device, dtype=torch.long)episode_durations = []def plot_durations(show_result=False):plt.figure(1)durations_t = torch.tensor(episode_durations, dtype=torch.float)if show_result:plt.title('Result')else:plt.clf()plt.title('Training...')plt.xlabel('Episode')plt.ylabel('Duration')plt.plot(durations_t.numpy())# Take 100 episode averages and plot them tooif len(durations_t) >= 100:means = durations_t.unfold(0, 100, 1).mean(1).view(-1)means = torch.cat((torch.zeros(99), means))plt.plot(means.numpy())plt.pause(0.001)  # pause a bit so that plots are updatedif is_ipython:if not show_result:display.display(plt.gcf())display.clear_output(wait=True)else:display.display(plt.gcf())def optimize_model():if len(memory) < BATCH_SIZE:returntransitions = memory.sample(BATCH_SIZE)# Transpose the batch (see https://stackoverflow.com/a/19343/3343043 for# detailed explanation). This converts batch-array of Transitions# to Transition of batch-arrays.batch = Transition(*zip(*transitions))# Compute a mask of non-final states and concatenate the batch elements# (a final state would've been the one after which simulation ended)non_final_mask = torch.tensor(tuple(map(lambda s: s is not None,batch.next_state)), device=device, dtype=torch.bool)non_final_next_states = torch.cat([s for s in batch.next_stateif s is not None])state_batch = torch.cat(batch.state)action_batch = torch.cat(batch.action)reward_batch = torch.cat(batch.reward)# Compute Q(s_t, a) - the model computes Q(s_t), then we select the# columns of actions taken. These are the actions which would've been taken# for each batch state according to policy_netstate_action_values = policy_net(state_batch).gather(1, action_batch)# Compute V(s_{t+1}) for all next states.# Expected values of actions for non_final_next_states are computed based# on the "older" target_net; selecting their best reward with max(1)[0].# This is merged based on the mask, such that we'll have either the expected# state value or 0 in case the state was final.next_state_values = torch.zeros(BATCH_SIZE, device=device)with torch.no_grad():next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0]# Compute the expected Q valuesexpected_state_action_values = (next_state_values * GAMMA) + reward_batch# Compute Huber losscriterion = nn.SmoothL1Loss()loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))# Optimize the modeloptimizer.zero_grad()loss.backward()# In-place gradient clippingtorch.nn.utils.clip_grad_value_(policy_net.parameters(), 100)optimizer.step()num_episodes = 600for i_episode in range(num_episodes):# Initialize the environment and get it's statestate, info = env.reset()state = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)for t in count():action = select_action(state)observation, reward, terminated, truncated, _ = env.step(action.item())reward = torch.tensor([reward], device=device)done = terminated or truncatedif terminated:next_state = Noneelse:next_state = torch.tensor(observation, dtype=torch.float32, device=device).unsqueeze(0)# Store the transition in memorymemory.push(state, action, next_state, reward)# Move to the next statestate = next_state# Perform one step of the optimization (on the policy network)optimize_model()# Soft update of the target network's weights# θ′ ← τ θ + (1 −τ )θ′target_net_state_dict = target_net.state_dict()policy_net_state_dict = policy_net.state_dict()for key in policy_net_state_dict:target_net_state_dict[key] = policy_net_state_dict[key]*TAU + target_net_state_dict[key]*(1-TAU)target_net.load_state_dict(target_net_state_dict)if done:episode_durations.append(t + 1)plot_durations()breakprint('Complete')
plot_durations(show_result=True)
plt.ioff()
plt.show()

迭代次数是600,如果是使用CPU而不是GPU的话,建议设置在50以内,否则你懂的... ...

前置条件是安装老黄家的Cuda,以及准备好python环境(cuda暂不支持python 3.12),安装好需要的库,需要的可以看我此前的博文。

您的进步和反馈是我最大的动力,小伙伴来个三连呗!共勉。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/873184.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

JVM知识点总结(全网最详细)!!!!

JVM知识总结 运行时数据区域程序计数器Java虚拟机栈局部变量表 StackOverflowError异常和OutOfMemoryError异常本地方法栈Java堆方法区运行时常量池 对象的创建对象的内存分配对象的内存布局对象头实例数据对齐填充 对象的访问定位使用句柄直接指针使用句柄和直接指针的优缺点 …

PHP房产中介租房卖房平台微信小程序系统源码

​&#x1f3e0;【租房卖房新选择】揭秘房产中介小程序&#xff0c;一键搞定置业大事&#xff01;&#x1f3e1; &#x1f50d;【开篇&#xff1a;告别繁琐&#xff0c;拥抱便捷】&#x1f50d; 还在为找房子跑断腿&#xff1f;为卖房发愁吗&#xff1f;今天给大家安利一个超…

【.NET全栈】ASP.NET开发Web应用——AJAX开发技术

文章目录 前言一、ASP.NET AJAX基础1、AJAX技术简介2、ASP.NET AJAX技术架构 二、ASP.NET AJAX服务器端扩展1、声明ScriptManager控件2、使用ScriptManager分发自定义脚本3、在ScriptManager中注册Web服务4、处理ScriptManager中的异常5、编程控制ScriptManager控件6、使用Upda…

如何高效定制视频扩散模型?卡内基梅隆提出VADER:通过奖励梯度进行视频扩散对齐

论文链接&#xff1a;https://arxiv.org/pdf/2407.08737 git链接&#xff1a;https://vader-vid.github.io/ 亮点直击&#xff1a; 引入奖励模型梯度对齐方法&#xff1a;VADER通过利用奖励模型的梯度&#xff0c;对多种视频扩散模型进行调整和对齐&#xff0c;包括文本到视频和…

如何评估 5G 毫米波相控阵天线模块

5G 新无线电 (5G NR) 是空中接口或无线接入网络 (RAN) 技术的行业标准和全球规范。它涵盖 6 GHz 及以下频率&#xff08;称为 FR1&#xff09;和 24 GHz 至 50 GHz 或更高频段&#xff08;称为 FR2 或 mmWave&#xff09;的运行。该技术可用于固定或移动接入、回程和日益流行的…

一些简单的基本知识(与C基本一致)

一、注释 1.单行注释&#xff1a;//&#xff08;快捷键&#xff1a;ctrlshift&#xff1f;&#xff0c;可以选择多行&#xff09; 2.多行注释&#xff1a;/* 文本 */ 二、变量 变量的作用是给一段内存空间起名&#xff0c;方便操作内存中的数据。 通过赋予某数据的…

逆向案例二十五——webpack所需模块函数很多,某翼云登录参数逆向。

解决步骤&#xff1a; 网址&#xff1a;aHR0cHM6Ly9tLmN0eXVuLmNuL3dhcC9tYWluL2F1dGgvbG9naW4 不说废话&#xff0c;密码有加密&#xff0c;直接搜索找到疑似加密位置打上断点。 再控制台打印&#xff0c;分析加密函数 有三个处理过程&#xff0c;b[g]得到的是用户名,b[f] 对…

数据结构初阶·排序算法(内排序)

目录 前言&#xff1a; 1 冒泡排序 2 选择排序 3 插入排序 4 希尔排序 5 快速排序 5.1 Hoare版本 5.2 挖坑法 5.3 前后指针法 5.4 非递归快排 6 归并排序 6.1递归版本归并 6.2 非递归版本归并 7 计数排序 8 排序总结 前言&#xff1a; 目前常见的排序算法有9种…

Torch-Pruning 库入门级使用介绍

项目地址&#xff1a;https://github.com/VainF/Torch-Pruning Torch-Pruning 是一个专用于torch的模型剪枝库&#xff0c;其基于DepGraph 技术分析出模型layer中的依赖关系。DepGraph 与现有的修剪方法&#xff08;如 Magnitude Pruning 或 Taylor Pruning&#xff09;相结合…

TCP重传机制详解

1.什么是TCP重传机制 在 TCP 中&#xff0c;当发送端的数据到达接收主机时&#xff0c;接收端主机会返回⼀个确认应答消息&#xff0c;表示已收到消息。 但是如果传输的过程中&#xff0c;数据包丢失了&#xff0c;就会使⽤重传机制来解决。TCP的重传机制是为了保证数据传输的…

React安装(学习版)

1. 安装Node.js和npm 首先&#xff0c;确保你的电脑上已经安装了Node.js和npm&#xff08;Node Package Manager&#xff09;。你可以从 Node.js官网 下载安装包并按照提示进行安装。安装完成后&#xff0c;可以在命令行终端中验证Node.js和npm是否正确安装&#xff1a; node …

【Node.js】初识 Node.js

Node.js 概念 Node.js 是一个开源与跨平台的 JavaScript运行时环境 &#xff0c;在浏览器外运行 V8 JavaScript 引擎(Google Chrome的内核)&#xff0c;利用事件驱动、非阻塞和异步输入输出 等技术提高性能。 可以理解为 Node.js就是一个服务器端的、非阻塞式 l/O 的、事件驱…

01 MySQL

学习资料&#xff1a;B站视频-黑马程序员JavaWeb基础教程 文章目录 JavaWeb整体介绍 MySQL1、数据库相关概念2、MySQL3、SQL概述4、DDL:数据库操作5、DDL:表操作6、DML7、DQL8、约束9、数据库设计10、多表查询11、事务 JavaWeb整体介绍 JavaWeb Web&#xff1a;全球广域网&…

PostgreSQL的逻辑架构

一、PostgreSql的逻辑架构&#xff1a; 一个server可以有多个database&#xff1b;一个database有多个schema&#xff0c;默认的schema是public&#xff1b;schema下才是对象&#xff0c;其中对象包含&#xff1a;表、视图、触发器、索引等&#xff1b;与user之间的关系&#x…

Mysql笔记-20240718

零、 help、\h、? 调出帮助 mysql> \hFor information about MySQL products and services, visit:http://www.mysql.com/ For developer information, including the MySQL Reference Manual, visit:http://dev.mysql.com/ To buy MySQL Enterprise support, training, …

免费恢复软件有哪些?电脑免费使用的 5 大数据恢复软件

您是否在发现需要的文件时不小心删除了回收站中的文件&#xff1f;您一定对误操作感到后悔。文件永远消失了吗&#xff1f;还有机会找回它们吗&#xff1f;当然有&#xff01;您可以查看这篇文章&#xff0c;挑选 5 款功能强大的免费数据恢复软件&#xff0c;用于 Windows 和 M…

<数据集>混凝土缺陷检测数据集<目标检测>

数据集格式&#xff1a;VOCYOLO格式 图片数量&#xff1a;7353张 标注数量(xml文件个数)&#xff1a;7353 标注数量(txt文件个数)&#xff1a;7353 标注类别数&#xff1a;6 标注类别名称&#xff1a;[exposed reinforcement, rust stain, Crack, Spalling, Efflorescence…

又缩水Unity7月闪促限时4折活动模块化角色模板编辑器场景美术插件拖尾怪物3D模型UI载具AI对话TPS飞机RPG和FPS202407

Flash Deals are Coming Back! 限时抢购又回来了&#xff01; July 17, 2024 8:00:00 PT to July 24, 2024 7:59:00 PT 太平洋时间 2024 年 7 月 17 日 8&#xff1a;00&#xff1a;00 至 2024 年 7 月 24 日 7&#xff1a;59&#xff1a;00&#xff08;太平洋时间&#xff09;…

云计算实训室的核心功能有哪些?

在当今数字化转型浪潮中&#xff0c;云计算技术作为推动行业变革的关键力量&#xff0c;其重要性不言而喻。唯众&#xff0c;作为教育实训解决方案的领先者&#xff0c;深刻洞察到市场对云计算技能人才的迫切需求&#xff0c;精心打造了云计算实训室。这一实训平台不仅集成了先…

软件著作权申请教程(超详细)(2024新版)软著申请

目录 一、注册账号与实名登记 二、材料准备 三、申请步骤 1.办理身份 2.软件申请信息 3.软件开发信息 4.软件功能与特点 5.填报完成 一、注册账号与实名登记 首先我们需要在官网里面注册一个账号&#xff0c;并且完成实名认证&#xff0c;一般是注册【个人】的身份。中…