从PyTorch官方的一篇教程说开去(2 - 源码)

先上图,上篇文章的运行结果,可以看到,算法在迭代了200来次左右达到人生巅峰,倒立摆金枪不倒,可以扛住连续200次操作。不幸的是,然后就出现了大幅度的回撤,每况愈下,在600次时候居然和100次的时候一个水平。

事实上,训练充满了随机性,也不乏非常漂亮的曲线,可以用来tree new bee,这也是AI领域很好水论文的体现吧。

下面两个图,分别来自,windows11本地运行 vs Colab云端运行。

呃,这个就是为啥G家主推的深度学习目前应用场景窄,被openAI狠揍的核心原因了 -

1)只能处理离散模型,数据量要求极高;

2)模型通用性差,不同的场景需要定制算法;

虽然G家多次宣称“霸权”,但事实上这个技术栈确实不适合解决通用问题。

开箱即食,以下为代码(单一python文件,windows11 + python 3.11.9 + GTX1080显卡) - 

import gymnasium as gym
import math
import random
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple, deque
from itertools import countimport torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as Fenv = gym.make("CartPole-v1", render_mode="human")# set up matplotlib
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:from IPython import displayplt.ion()# if GPU is to be used
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
##device = torch.device("cuda")Transition = namedtuple('Transition',('state', 'action', 'next_state', 'reward'))class ReplayMemory(object):def __init__(self, capacity):self.memory = deque([], maxlen=capacity)def push(self, *args):"""Save a transition"""self.memory.append(Transition(*args))def sample(self, batch_size):return random.sample(self.memory, batch_size)def __len__(self):return len(self.memory)class DQN(nn.Module):def __init__(self, n_observations, n_actions):super(DQN, self).__init__()self.layer1 = nn.Linear(n_observations, 128)self.layer2 = nn.Linear(128, 128)self.layer3 = nn.Linear(128, n_actions)# Called with either one element to determine next action, or a batch# during optimization. Returns tensor([[left0exp,right0exp]...]).def forward(self, x):x = F.relu(self.layer1(x))x = F.relu(self.layer2(x))return self.layer3(x)# BATCH_SIZE is the number of transitions sampled from the replay buffer
# GAMMA is the discount factor as mentioned in the previous section
# EPS_START is the starting value of epsilon
# EPS_END is the final value of epsilon
# EPS_DECAY controls the rate of exponential decay of epsilon, higher means a slower decay
# TAU is the update rate of the target network
# LR is the learning rate of the ``AdamW`` optimizer
BATCH_SIZE = 128
GAMMA = 0.99
EPS_START = 0.9
EPS_END = 0.05
EPS_DECAY = 1000
TAU = 0.005
LR = 1e-4# Get number of actions from gym action space
n_actions = env.action_space.n
# Get the number of state observations
state, info = env.reset()
n_observations = len(state)policy_net = DQN(n_observations, n_actions).to(device)
target_net = DQN(n_observations, n_actions).to(device)
target_net.load_state_dict(policy_net.state_dict())optimizer = optim.AdamW(policy_net.parameters(), lr=LR, amsgrad=True)
memory = ReplayMemory(10000)steps_done = 0def select_action(state):global steps_donesample = random.random()eps_threshold = EPS_END + (EPS_START - EPS_END) * \math.exp(-1. * steps_done / EPS_DECAY)steps_done += 1if sample > eps_threshold:with torch.no_grad():# t.max(1) will return the largest column value of each row.# second column on max result is index of where max element was# found, so we pick action with the larger expected reward.return policy_net(state).max(1)[1].view(1, 1)else:return torch.tensor([[env.action_space.sample()]], device=device, dtype=torch.long)episode_durations = []def plot_durations(show_result=False):plt.figure(1)durations_t = torch.tensor(episode_durations, dtype=torch.float)if show_result:plt.title('Result')else:plt.clf()plt.title('Training...')plt.xlabel('Episode')plt.ylabel('Duration')plt.plot(durations_t.numpy())# Take 100 episode averages and plot them tooif len(durations_t) >= 100:means = durations_t.unfold(0, 100, 1).mean(1).view(-1)means = torch.cat((torch.zeros(99), means))plt.plot(means.numpy())plt.pause(0.001)  # pause a bit so that plots are updatedif is_ipython:if not show_result:display.display(plt.gcf())display.clear_output(wait=True)else:display.display(plt.gcf())def optimize_model():if len(memory) < BATCH_SIZE:returntransitions = memory.sample(BATCH_SIZE)# Transpose the batch (see https://stackoverflow.com/a/19343/3343043 for# detailed explanation). This converts batch-array of Transitions# to Transition of batch-arrays.batch = Transition(*zip(*transitions))# Compute a mask of non-final states and concatenate the batch elements# (a final state would've been the one after which simulation ended)non_final_mask = torch.tensor(tuple(map(lambda s: s is not None,batch.next_state)), device=device, dtype=torch.bool)non_final_next_states = torch.cat([s for s in batch.next_stateif s is not None])state_batch = torch.cat(batch.state)action_batch = torch.cat(batch.action)reward_batch = torch.cat(batch.reward)# Compute Q(s_t, a) - the model computes Q(s_t), then we select the# columns of actions taken. These are the actions which would've been taken# for each batch state according to policy_netstate_action_values = policy_net(state_batch).gather(1, action_batch)# Compute V(s_{t+1}) for all next states.# Expected values of actions for non_final_next_states are computed based# on the "older" target_net; selecting their best reward with max(1)[0].# This is merged based on the mask, such that we'll have either the expected# state value or 0 in case the state was final.next_state_values = torch.zeros(BATCH_SIZE, device=device)with torch.no_grad():next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0]# Compute the expected Q valuesexpected_state_action_values = (next_state_values * GAMMA) + reward_batch# Compute Huber losscriterion = nn.SmoothL1Loss()loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))# Optimize the modeloptimizer.zero_grad()loss.backward()# In-place gradient clippingtorch.nn.utils.clip_grad_value_(policy_net.parameters(), 100)optimizer.step()num_episodes = 600for i_episode in range(num_episodes):# Initialize the environment and get it's statestate, info = env.reset()state = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)for t in count():action = select_action(state)observation, reward, terminated, truncated, _ = env.step(action.item())reward = torch.tensor([reward], device=device)done = terminated or truncatedif terminated:next_state = Noneelse:next_state = torch.tensor(observation, dtype=torch.float32, device=device).unsqueeze(0)# Store the transition in memorymemory.push(state, action, next_state, reward)# Move to the next statestate = next_state# Perform one step of the optimization (on the policy network)optimize_model()# Soft update of the target network's weights# θ′ ← τ θ + (1 −τ )θ′target_net_state_dict = target_net.state_dict()policy_net_state_dict = policy_net.state_dict()for key in policy_net_state_dict:target_net_state_dict[key] = policy_net_state_dict[key]*TAU + target_net_state_dict[key]*(1-TAU)target_net.load_state_dict(target_net_state_dict)if done:episode_durations.append(t + 1)plot_durations()breakprint('Complete')
plot_durations(show_result=True)
plt.ioff()
plt.show()

迭代次数是600,如果是使用CPU而不是GPU的话,建议设置在50以内,否则你懂的... ...

前置条件是安装老黄家的Cuda,以及准备好python环境(cuda暂不支持python 3.12),安装好需要的库,需要的可以看我此前的博文。

您的进步和反馈是我最大的动力,小伙伴来个三连呗!共勉。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/873184.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

高性能内存对象缓存

1&#xff1a;数据存储方式与数据过期方式 数据存储方式多种多样&#xff0c;以下为常见的几种&#xff1a; 1.关系型数据库&#xff1a;如 MySQL、Oracle 等&#xff0c;通过表格形式组织数据&#xff0c;具有严格的数据结构和关系约束&#xff0c;适用于结构化数据的存储和管…

设计模式第一天|了解设计模式、设计模式七大原则

文章目录 了解设计模式概念优点核心原则 设计模式七大原则单一职责原则里氏替换原则依赖倒置原则接口隔离原则迪米特法则开闭原则合成复用原则 了解设计模式 概念 软件设计模式(Software Design Patten),又称设计模式,是一套被反复使用,多数人只晓的,经过分类编目的,代码设计…

JVM知识点总结(全网最详细)!!!!

JVM知识总结 运行时数据区域程序计数器Java虚拟机栈局部变量表 StackOverflowError异常和OutOfMemoryError异常本地方法栈Java堆方法区运行时常量池 对象的创建对象的内存分配对象的内存布局对象头实例数据对齐填充 对象的访问定位使用句柄直接指针使用句柄和直接指针的优缺点 …

android11 屏蔽usb通过otg转接口外接鼠标设备

硬件平台&#xff1a;QCS6125 软件平台&#xff1a;Android11 需求&#xff1a;Android设备通过接usb转接线连接鼠标功能屏蔽。 考虑到屏蔽的层面可以从两个层面去做&#xff0c;一个是驱动层面不识别&#xff0c;一个就是Android系统层面不识别加载&#xff0c;本篇只讲后者。…

重置Kafka

重置kafka 1、关闭kafka kill -9 进程号 2、删除元数据 1&#xff09;zk zkCli.sh 2&#xff09;删除预kafka有关的所有信息 ls / rmr /config rmr /brokers 3、删除kafka的数据 所有节点都要删除 rm -rf /usr/local/soft/kafka_2.11-2.0.0/data 4、 重启 kafka-server-sta…

PHP房产中介租房卖房平台微信小程序系统源码

​&#x1f3e0;【租房卖房新选择】揭秘房产中介小程序&#xff0c;一键搞定置业大事&#xff01;&#x1f3e1; &#x1f50d;【开篇&#xff1a;告别繁琐&#xff0c;拥抱便捷】&#x1f50d; 还在为找房子跑断腿&#xff1f;为卖房发愁吗&#xff1f;今天给大家安利一个超…

IPython与Pandas:数据分析的动态组

IPython与Pandas&#xff1a;数据分析的动态组合 前言 欢迎来到"iPython与Pandas&#xff1a;数据分析的动态组合"教程&#xff01;无论你是数据分析新手还是希望提升技能的专业人士&#xff0c;这里都是你开始的地方。让我们开始这段数据分析之旅吧&#xff01; …

【.NET全栈】ASP.NET开发Web应用——AJAX开发技术

文章目录 前言一、ASP.NET AJAX基础1、AJAX技术简介2、ASP.NET AJAX技术架构 二、ASP.NET AJAX服务器端扩展1、声明ScriptManager控件2、使用ScriptManager分发自定义脚本3、在ScriptManager中注册Web服务4、处理ScriptManager中的异常5、编程控制ScriptManager控件6、使用Upda…

如何高效定制视频扩散模型?卡内基梅隆提出VADER:通过奖励梯度进行视频扩散对齐

论文链接&#xff1a;https://arxiv.org/pdf/2407.08737 git链接&#xff1a;https://vader-vid.github.io/ 亮点直击&#xff1a; 引入奖励模型梯度对齐方法&#xff1a;VADER通过利用奖励模型的梯度&#xff0c;对多种视频扩散模型进行调整和对齐&#xff0c;包括文本到视频和…

如何评估 5G 毫米波相控阵天线模块

5G 新无线电 (5G NR) 是空中接口或无线接入网络 (RAN) 技术的行业标准和全球规范。它涵盖 6 GHz 及以下频率&#xff08;称为 FR1&#xff09;和 24 GHz 至 50 GHz 或更高频段&#xff08;称为 FR2 或 mmWave&#xff09;的运行。该技术可用于固定或移动接入、回程和日益流行的…

Flutter 插件之 package_info_plus

当使用Flutter开发应用时,通常需要获取应用程序的基本信息,例如包名、版本号和构建号。Flutter提供了一个名为 package_info_plus 的插件,它能方便地帮助我们获取这些信息。 1. 添加依赖 首先,需要在项目的 pubspec.yaml 文件中添加 package_info_plus 的依赖。打开 pubs…

C语言结构体字节对齐技术详解

C语言结构体字节对齐技术详解&#xff08;第一部分&#xff09; 在C语言中&#xff0c;结构体字节对齐是一个重要的概念&#xff0c;它涉及到内存中数据的布局和访问效率。字节对齐可以帮助提高程序的性能&#xff0c;减少内存碎片&#xff0c;并确保数据的一致性和正确性。本…

一些简单的基本知识(与C基本一致)

一、注释 1.单行注释&#xff1a;//&#xff08;快捷键&#xff1a;ctrlshift&#xff1f;&#xff0c;可以选择多行&#xff09; 2.多行注释&#xff1a;/* 文本 */ 二、变量 变量的作用是给一段内存空间起名&#xff0c;方便操作内存中的数据。 通过赋予某数据的…

逆向案例二十五——webpack所需模块函数很多,某翼云登录参数逆向。

解决步骤&#xff1a; 网址&#xff1a;aHR0cHM6Ly9tLmN0eXVuLmNuL3dhcC9tYWluL2F1dGgvbG9naW4 不说废话&#xff0c;密码有加密&#xff0c;直接搜索找到疑似加密位置打上断点。 再控制台打印&#xff0c;分析加密函数 有三个处理过程&#xff0c;b[g]得到的是用户名,b[f] 对…

【ASP.NET网站传值问题】“object”不包含“GetEnumerator”的公共定义,因此 foreach 语句不能作用于“object”类型的变量等

问题一&#xff1a;不允许遍历 原因&#xff1a;实体未强制转化 后端: ViewData["CateGroupList"] grouplist; 前端加上&#xff1a;var catelist ViewData["CateGroupList"] as List<Catelogue>; 这样就可以遍历catelist了 问题二&#xff1a…

数据结构初阶·排序算法(内排序)

目录 前言&#xff1a; 1 冒泡排序 2 选择排序 3 插入排序 4 希尔排序 5 快速排序 5.1 Hoare版本 5.2 挖坑法 5.3 前后指针法 5.4 非递归快排 6 归并排序 6.1递归版本归并 6.2 非递归版本归并 7 计数排序 8 排序总结 前言&#xff1a; 目前常见的排序算法有9种…

探索Eureka的高级用法:在服务中实现分布式锁

在分布式系统中&#xff0c;实现分布式锁是一种常见需求&#xff0c;用于确保多个服务实例不会同时访问共享资源或执行相同的任务。虽然Eureka本身是一个服务发现工具&#xff0c;并不直接提供分布式锁功能&#xff0c;但我们可以通过结合其他技术&#xff08;如Redis、Zookeep…

Torch-Pruning 库入门级使用介绍

项目地址&#xff1a;https://github.com/VainF/Torch-Pruning Torch-Pruning 是一个专用于torch的模型剪枝库&#xff0c;其基于DepGraph 技术分析出模型layer中的依赖关系。DepGraph 与现有的修剪方法&#xff08;如 Magnitude Pruning 或 Taylor Pruning&#xff09;相结合…

TCP重传机制详解

1.什么是TCP重传机制 在 TCP 中&#xff0c;当发送端的数据到达接收主机时&#xff0c;接收端主机会返回⼀个确认应答消息&#xff0c;表示已收到消息。 但是如果传输的过程中&#xff0c;数据包丢失了&#xff0c;就会使⽤重传机制来解决。TCP的重传机制是为了保证数据传输的…

React安装(学习版)

1. 安装Node.js和npm 首先&#xff0c;确保你的电脑上已经安装了Node.js和npm&#xff08;Node Package Manager&#xff09;。你可以从 Node.js官网 下载安装包并按照提示进行安装。安装完成后&#xff0c;可以在命令行终端中验证Node.js和npm是否正确安装&#xff1a; node …