Pytorch Note

cat函数:

cat函数不会增加维度,默认按照dim=0连接张量

stack函数:

stack函数会增加一个维度

nn.Linear的默认输入:

torch中默认输入一定要为tensor,并且默认是tensor.float32,此外device如果没有model.to(device)放到gpu上面默认会在cpu上运行,如果把模型放到了device上面,那么输入的向量也要放到gpu上面

torch的eval模式和train模式:

使用model.eval模式,模型会进入评估模式,在这个时候,会丢弃以下行为:

  1. Dropout:在评估模式下,Dropout 层不会丢弃任何神经元,所有的神经元都会参与计算。

  2. Batch Normalization:在评估模式下,Batch Normalization 层会使用训练过程中累积的均值和方差来进行归一化,而不是使用当前批次的数据。

使用model.train模式,模型会进入训练模式,这时候模型会启用Dropout和Batch Normalization

torch.gather函数:
torch.gather(input, dim, index) → Tensor

 假设input的shape为(a*b*c),index的shape需要为(a*b,x),这时候指定dim=2,就会把dim=2这一维度的向量按照x的下标收集起来1

import torch# 创建一个形状为 (3, 4) 的输入张量
input = torch.tensor([[1, 2, 3, 4],[5, 6, 7, 8],[9, 10, 11, 12]])# 创建一个形状为 (3, 2) 的索引张量
index = torch.tensor([[0, 1],[1, 2],[2, 3]])# 沿着第 1 维(列)收集元素
output = torch.gather(input, dim=1, index=index)print(output)"""
tensor([[ 1,  2],[ 6,  7],[11, 12]])
"""
torch.distributions.Categorical函数:

torch.distributions.Categorical(probs=None, logits=None)

probs代表概率,要求加起来为1,logits代表对数概率,不一定要加起来为1,torch会自动计算让他们加起来为1,虽然用np.random.choice也能实现这个效果,但是numpy是不能进行梯度计算的

action_dist = torch.distributions.Categorical(probs)
action = action_dist.sample()
item() 和 detach().cpu().numpy()

在深度学习训练后,需要计算每个epoch得到的模型的训练效果的时候,一般会用到detach() item() cpu() numpy()等函数。

  • item():返回的是tensor中的值,且只能返回单个值(标量),不能返回向量,使用返回loss等,得到的值因为是标量所以肯定是在cpu上,因为cuda上只能放tensor
  • detach(): 阻断反向传播,返回值任然是tensor
  • cpu():将tensor放到cpu上,返回值任然是tensor
  • numpy():将tensor转换为numpy,注意cuda上面的变量类型只能是tensor,不能是其他

在pytorch中反向传播只能对计算出的loss进行,loss肯定是一个具体的值,使用detach是为了把拿出的计算图和主图分离,计算出的loss不再对主干进行修改:

critic_loss = torch.mean(F.mse_loss(self.critic(states), td_target.detach()))
critic_loss.backward()

如上的critic_loss.backward()只会修改critic的参数,并不会修改td_target的参数

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/56418.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

RabbitMQ系列学习笔记(十)--通配符模式

文章目录 一、通配符模式原理二、通配符模式实战1、消费者代码2、生产者代码3、查看运行结果 本文参考: 尚硅谷RabbitMQ教程丨快速掌握MQ消息中间件rabbitmq RabbitMQ 详解 Centos7环境安装Erlang、RabbitMQ详细过程(配图) 一、通配符模式原理 通配符模式&#xff…

2024 睿抗机器人开发者大赛(RAICOM)-【网络安全】CTF 部分WP

文章目录 一、前言二、MICS你是黑客么循环的压缩包Goodtime 三、WEBpy 四、Crypto变异凯撒RSAcrypto3 一、前言 WP不完整,仅供参考! 除WEB,RE,PWN外,其余附件均已打包完毕 也是一个对MISC比较友好的一个比赛~ 123网…

写了一个SpringBoot的后端管理系统(仅后端)pine-manage-system

文章目录 前言正文🚀 技术栈🛠️ 功能模块📁 项目结构🌈 接口文档🚀 项目启动 附录项目功能代码示例1、数据库拦截器-打印sql执行时间2、数据记录变更拦截器3、用户角色数据权限拦截器4、实体转换器接口5、触发器模版6…

自动驾驶合集2

我自己的原文哦~ https://blog.51cto.com/whaosoft/12304421 #NeRF与自动驾驶 神经辐射场(Neural Radiance Fields)自2020年被提出以来,相关论文数量呈指数增长,不但成为了三维重建的重要分支方向,也逐渐作为自动驾驶…

C++学习笔记----9、发现继承的技巧(五)---- 多重继承(1)

我们前面提到过,多重继承常被认为是面向对象编程中复杂且没有必要的部分。这就仁者见仁,智者见智了,留给大家去评判。本节解释c中的多重继承。 1、多个类继承 从语法角度来说,定义一个有多个父类的类是很简单的。需要做的就是当声…

vue前端接包(axios)+ 前端导出excel(xlsx-js-style)

// 先在请求处加上: responseType: arraybuffer, // 指定响应类型为ArrayBufferconst data new Uint8Array(response.data); // 将ArrayBuffer转换为Uint8Arrayconst val { columns: [], data: [] }let offset 0; // 用于跟踪当前解析到的位置while (offset …

DASCTF 2024金秋十月赛RE题wp

目录 RE1:ezRERE2:ezelfRE3:ezAndroid 3题RE,差一点就AK了,可能好久没打比赛了,技能有所下降,还是需要经常摸一摸工具。 RE1:ezRE 执行的时候dump出来,然后静态分析 发…

深入理解Spring Boot的事务注解及其实现原理

深入理解Spring Boot的事务注解及其实现原理 在现代企业级应用开发中,事务管理是一个至关重要的概念。Spring Boot 提供了强大的事务管理功能,使得开发者可以轻松地管理数据库事务。本文将详细介绍Spring Boot中的事务注解及其实现原理。 1. 什么是事务…

Java项目-基于springboot框架的游戏分享系统项目实战(附源码+文档)

作者:计算机学长阿伟 开发技术:SpringBoot、SSM、Vue、MySQL、ElementUI等,“文末源码”。 开发运行环境 开发语言:Java数据库:MySQL技术:SpringBoot、Vue、Mybaits Plus、ELementUI工具:IDEA/…

[ACTF2020] 新生赛]Exec1

目录 0x01命令执行 [ACTF2020 新生赛]Exec1 1、解法1 2、解法2 3、总结 3.1php命令注入函数 3.2java命令注入函数 3.3常见管道符 0x02SQL注入 [极客大挑战 2019]EasySQL1 0x01命令执行 [ACTF2020 新生赛]Exec1 1、解法1 ping本地,有回显,TTL…

红队-安全见闻篇(上)

声明 学习视频来自B站UP主 泷羽sec的个人空间-泷羽sec个人主页-哔哩哔哩视频,如涉及侵权马上删除文章 笔记的只是方便各位师傅学习知识,以下网站只涉及学习内容,其他的都与本人无关,切莫逾越法律红线,否则后果自负 一.编程与开发 1.后端语言学习 C语⾔:⼀种通⽤的…

Pytest-Bdd-Playwright 系列教程(1):从零开始教你写自动化测试框架「喂饭教程」

Pytest-Bdd-Playwright 系列教程(1):从零开始教你写自动化测试框架「喂饭教程」 前言一、项目结构二、安装依赖三、BDD特性文件四、页面对象五、步骤定义六、测试脚本七、Pytest配置八、运行测试 前言 最近收到一些小伙伴在后台的留言&#x…

生成式AI时代的内容安全与系统构建:合合信息文档图像篡改检测创新方案

目录 一、生成式AI时代的内容安全与图像识别1.图像内容安全的重要性2.伪造文档与证件检测的应用场景3.人脸伪造检测技术 二、系统构建加速与文档解析1.TextIn文档解析平台2.TextIn文档解析输出的示例 三、合合信息的行业影响力总结 一、生成式AI时代的内容安全与图像识别 随着…

python-----函数详解(一)

一、概念及作用: 概念:由若干条语句组成语句块,其中包括函数名称、参数列表,它是组织代码的最小单元,完成一定的功能 作用:把一个代码封装成一个函数,一般按功能组织一段代码 目的就是为了重…

autMan奥特曼机器人-安装或更新golang依赖

autMan2.3.4及以上需要更新中间件或安装golang依赖,参照下列步骤: 一、直装版本 ssh下进入autMan文件夹下plugin/scripts下面输入以下指令: go get -u github.com/hdbjlizhe/middleware二、docker版本 从后台进入web终端,依次输入…

速盾:高防cdn的好处体现在什么地方?

随着互联网的迅猛发展,网络安全问题也日益受到关注。为了保护网站免受恶意攻击和DDoS攻击的影响,许多网站选择使用高防CDN(Content Delivery Network)服务。高防CDN的好处体现在以下几个方面。 首先,高防CDN可以有效防…

Ubuntu 上安装 Redmine 5.1 指南

文章目录 官网安装文档:命令步骤相关介绍GemRubyRailsBundler 安装 Redmine更新系统包列表和软件包:安装必要的依赖:安装 Ruby:安装 bundler下载 Redmine 源代码:安装 MySQL配置 Redmine 的数据库配置文件:…

Hudi 核心知识点详解

‌数据写入‌: ‌近实时写入‌:Hudi支持近实时写入,可以减少碎片化工具的使用,并通过CDC(Change Data Capture)增量导入RDBMS数据。此外,Hudi还限制小文件的大小和数量,优化存储效率…

Java开发者的成长轨迹:从入门到权威的二十年征程

在Java开发的漫长征途中,流传着一句耳熟能详的话:“三年入门,五年入行,十年精英,十五年专家,二十年权威”。这句话不仅是对Java开发者职业生涯的高度概括,更是对技术成长路径的一种深刻洞察。它…

Node.js:深入探秘 CommonJS 模块化的奥秘

在Node.js出现之前,服务端JavaScript基本上处于一片荒芜的境况,而当时也没有出现ES6的模块化规范。因此,Node.js采用了当时比较先进的一种模块化规范来实现服务端JavaScript的模块化机制,它就是CommonJS,有时也简称为C…