pytorch里常用操作(持续更新)

对不起我脑子不太记事儿每次变换都得想想想所以干脆汇总一下算了,当然也有一些不是torch包里面的但是没有关系hhh 官方文档里有一堆不太常用的,这里整理的都是自己比较常用的

张量操作

torch.tensor:从Python列表或NumPy数组创建张量

torch.zeros/ones:创建全零/一张量

torch.zeros(10,4)就是创建[10,4]的全零张量

torch.rand:创建随机张量

torch.cat:沿指定维度拼接张量

torch.stack:在新的维度上堆叠张量

torch.stack(tensors, dim=0, out=None)

  • tensors:要堆叠的输入张量的列表或元组。
  • dim:指定要堆叠的新维度的索引。默认是0。
  • out:可选参数,用于指定结果张量的输出

 e.g

假设我们有两个张量 tensor1tensor2,它们的形状都是 (3, 4),并且我们想将它们堆叠在一个新的维度上,创建一个新的形状为 (2, 3, 4) 的张量。

# 创建两个张量
tensor1 = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]])
tensor2 = torch.tensor([[-1, -2, -3, -4], [-5, -6, -7, -8], [-9, -10, -11, -12]])# 使用torch.stack将它们堆叠在一个新的维度上
stacked_tensor = torch.stack((tensor1, tensor2), dim=0)

torch.reshape:改变张量的形状。

torch.transpose:交换张量的维度。

torch.transpose(input, dim0, dim1)

  • input:要进行维度交换的输入张量。
  • dim0:要交换的第一个维度的索引。
  • dim1:要交换的第二个维度的索引。

假设我们有一个形状为 (3, 4) 的张量,现在想要交换它的维度,创建一个新的张量,使其形状为 (4, 3)

# 创建一个形状为 (3, 4) 的张量
input_tensor = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]])# 使用 torch.transpose 进行维度交换
transposed_tensor = torch.transpose(input_tensor, 0, 1)

torch.arange: 用于创建一个包含指定范围内的数值的一维张量

# 开始,结尾,间隔;和索引比较像

arr = torch.arange(10,100,10) #[10, 20, 30, 40, 50, 60, 70, 80, 90]

torch.meshgrid: 网格图

根据提供的x,y轴的范围得到一张网格的里面的点的xy坐标

torch.flatten(tensor,dim) : 把tensor压缩,dim表示压缩的维度

torch.unsqueeze:在张量中插入新的维度

new_tensor = torch.unsqueeze(input, dim)

  • input 是要插入新维度的输入张量。
  • dim 是要插入新维度的位置,通常是一个非负整数。

y = torch.tensor([[1, 2], [3, 4],[5, 6]]) # torch.Size[3,2]
y_new_1 = torch.unsqueeze(y, 0) #torch.Size[1, 3, 2]
y_new_2 = torch.unsqueeze(y, 1) # torch.Size[3, 1, 2]
y_new_3 = torch.unsqueeze(y, 2) # torch.Size[3, 2, 1]

一些非torch包的操作

[:, :, :None]最后一维扩一维

A;2d,B:1d B[i*cols+j] = A[i,j] 把二阶张量变成一阶

permute(): 矩阵转置


数学操作

torch.add/sub:张量相加/减

torch.mul/div:张量相乘/除

torch.sum:计算张量的和

torch.mean:计算张量的平均值

torch.max/min:找到张量中的最大/小值

torch.abs:计算张量的绝对值

torch.exp:计算输入张量中元素的指数(exponential)

x = torch.tensor([1.0, 2.0, 3.0])

exp_x = torch.exp(x)     # [e^1.0, e^2.0, e^3.0]


索引和切片

tensor[idx]:根据索引获取张量中的元素。
tensor[start:end]:切片操作。
tensor[:, 1]:选取指定列。
tensor[condition]:使用布尔条件进行索引。


自动求导

torch.autograd.Variable:创建自动求导的变量。
backward():计算梯度。
grad:访问梯度值。
no_grad():上下文管理器,用于禁用梯度计算。


神经网络模块

torch.nn.Module:创建神经网络模块

torch.nn.Linear:定义全连接层

这个层通常用于神经网络中,用来实现从输入到输出的线性变换,其中包括权重矩阵和偏置项。

# 创建一个 Linear 层
input_size = 10
output_size = 5
linear_layer = nn.Linear(input_size, output_size)# 随机生成一个输入张量
input_data = torch.randn(1, input_size)  # 这里创建一个形状为 (1, input_size) 的随机输入张量# 使用线性层进行前向传播
output = linear_layer(input_data)# 查看权重和偏置
weights = linear_layer.weight
bias = linear_layer.biasprint("输入张量:", input_data)
print("输出张量:", output)
print("权重矩阵:", weights)
print("偏置项:", bias)

这个示例中,首先创建了一个 nn.Linear 层,指定输入特征的数量和输出特征的数量。然后,随机生成一个输入张量 input_data,并通过将其传递给 linear_layer 来进行前向传播。线性层会应用权重矩阵和偏置项,生成输出张量。

nn.Linear 层在神经网络中通常用于连接不同层之间的神经元,执行线性变换的作用,帮助网络学习数据的特征表示。

torch.nn.Conv2d:定义卷积层

torch.nn.ReLU:ReLU激活函数

torch.nn.CrossEntropyLoss:交叉熵损失函数

torch.nn.optim:包含各种优化器,如SGD、Adam等

torch.nn.LayerNorm:用于层归一化

层归一化是一种用于神经网络的正则化技术,有助于加速训练和提高模型的鲁棒性。输入输出的形状并不会改变

import torch
import torch.nn as nn# 创建一个 LayerNorm 层
input_size = 10
layer_norm = nn.LayerNorm(input_size)# 随机生成一个输入张量
input_data = torch.randn(1, input_size)  # 创建一个形状为 (1, input_size) 的随机输入张量# 使用 LayerNorm 层进行前向传播
output = layer_norm(input_data)print("输入张量:", input_data)
print("LayerNorm 后的输出张量:", output) #shape[1,input_size]

torch.nn.Parameter:将张量标记为模型参数(可训练的参数)

将张量封装为 torch.nn.Parameter 对象后,它会被自动注册为模型的可训练参数,并在反向传播(backpropagation)期间更新它的值。这对于构建神经网络模型非常有用,因为神经网络的权重和偏置通常需要在训练期间进行优化。

# 创建一个普通的张量
tensor_data = torch.tensor([1.0, 2.0, 3.0])

# 将张量包装为一个模型参数·1                                                                                               
parameter = nn.Parameter(tensor_data)

# 打印参数
print(parameter) #Parameter containing: tensor([1., 2., 3.], requires_grad=True)

一起使用的是register_buffer

用于将张量注册为模型的缓冲区(buffer)

 注册为缓冲区的张量不会被视为模型的可训练参数,也不会在反向传播期间更新。它们用于保存模型的固定状态信息,例如统计信息(均值、方差等)、预训练的权重或任何其他不需要进行梯度更新的张量。

register_buffer 的主要作用是将这些张量添加到模型的状态字典中,以便在保存和加载模型时一并保存和加载。这对于确保模型的一致性和可重现性非常有用。

class CustomModel(nn.Module):def __init__(self):super(CustomModel, self).__init__()# 创建一个常量张量作为缓冲区self.register_buffer('constant_tensor', torch.tensor([1.0, 2.0, 3.0]))# 创建模型实例
model = CustomModel()# 打印模型缓冲区
for name, buffer in model.named_buffers():print(name, buffer)

torch.nn.MultiheadAttention:实现多头注意力机制

多头注意力机制允许模型同时关注输入中的不同部分,以提高模型性能

定义多头注意力模块
- embed_dim: 输入的维度
- num_heads: 头的数量,用于并行处理不同部分的注意力
- dropout: 可选的丢弃率

attention = nn.MultiheadAttention(embed_dim, num_heads, dropout=0.1)# 输入数据 (query, key, value),通常是三个相同形状的张量
query = torch.randn(seq_length, batch_size, embed_dim)
key = torch.randn(seq_length, batch_size, embed_dim)
value = torch.randn(seq_length, batch_size, embed_dim)# 调用多头注意力模块
output, attention_weights = attention(query, key, value)# output 是注意力机制的输出,attention_weights 是注意力权重

e.g 

import torch
import torch.nn as nn# 定义多头注意力模块
embed_dim = 128
num_heads = 4
attention = nn.MultiheadAttention(embed_dim, num_heads)# 输入数据 (query, key, value)
seq_length = 10
batch_size = 32
query = torch.randn(seq_length, batch_size, embed_dim)
key = torch.randn(seq_length, batch_size, embed_dim)
value = torch.randn(seq_length, batch_size, embed_dim)# 调用多头注意力模块
output, attention_weights = attention(query, key, value)print("Output shape:", output.shape)
print("Attention weights shape:", attention_weights.shape)

数据加载和处理

torch.utils.data.Dataset:创建自定义数据集。


torch.utils.data.DataLoader:数据加载器。


transforms模块:用于数据预处理和转换。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/108771.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

idea使用debug无法启动,使用run可以启动

1、将调试断点清除 使用快捷键ctrl shift F8,将勾选的选项去除即可 2、Error running SampleApplication: Command line is too long. Shorten command line for SampleApplication or also for Spring Boot default configuration,报这种错误&#x…

vr火灾逃生安全科普软件开展消防突击教育安全有效

VR火灾逃生自救虚拟体验是一种利用虚拟现实技术来模拟火灾逃生自救场景的教育工具。以下是这个体验的几个优点:VR消防安全体验馆的出现,为城市的安全教育开辟了新的途径。这种创新的体验方式,能够让市民在模拟的火灾场景中学习并掌握消防安全…

前端面试基础面试题——10

1. 说说你对 promise 的了解 2.解构赋值及其原理 3.箭头函数需要注意的地方 4.箭头函数和普通函数有什么区别 5.ES6 都有什么 Iterator 遍历器 6.jQuery 一个对象可以同时绑定多个事件,这是如何实现的? 7.jQuery 库中的 $() 是什么? 8…

tcp/ip协议2实现的插图,数据结构2 (9 - 章)

(20) 20 九章1 IP选项处理 ip_dooptions (21) 21 九章2 IP选项处理 ip_rtaddr,save_rte,ip_srcroute与结构体 (22)九章3 IP选项处理 ip_pcbopts, ip_insertoptions , iptime 与结构 (23&#xf…

安装 mysql

gpt: 要在 Debian 11 上安装 MySQL 数据库服务器,您可以使用以下步骤: 1. **更新软件包列表**:在安装任何软件之前,始终建议首先更新软件包列表,以确保获取最新的软件包信息。在终端中运行以下命令: bash…

课时4作业1

Description 输入一个整型数,判断是否是对称数,如果是,输出yes,否则输出no,不用考虑这个整型数过大,int类型存不下,不用考虑负值; 例如 12321是对称数,输出yes&#xf…

过滤器(Filter)和拦截器(Interceptor)有什么不同?

过滤器(Filter)和拦截器(Interceptor)是用于处理请求和响应的中间件组件,但它们在实现方式和应用场景上有一些不同。 实现方式: 过滤器是Servlet规范中定义的一种组件,通常以Java类的形式实现。过滤器通过在…

编译添加了ALPHA开发板的NXP官方uboot

一. 简介 之前文章学习了 如何在NXP(恩智浦)官方 uboot 中添加正点原子的 ALPHA 开发板。 如何在NXP(恩智浦)官方 uboot 中添加正点原子的 ALPHA 开发板,文章如下: 向NXP官方uboot添加Nand版开发板-CSDN博…

【webrtc 】FEC 1: 音频RED rfc2198及视频ULPFEC的RED封装

1 参考和引用 M79 代码。 ULPFEC报文构建流程 与大神的分析: WebRTC-FEC协议总结 一致 CrystalShaw 大神的文章 ULPFEC在WebRTC中的实现 WebRTC研究:FEC之RED封装 本文是大神们文章和代码的学习笔记。red封包(rfc2189)1.1 RED(Redundant Coding) 封装 Ulpfec 非均等保护前向纠…

HarmonyOS云开发基础认证---练习题二

【判断题】 2/2 Serverless是云计算下一代的默认计算范式。 正确(True) 【判断题】 2/2 接入认证服务后,用户每次收到验证码短信都需要开发者买单。 错误(False) 【判断题】 2/2 认证服务手机号码登录需要填写国家码。 正确(True) 【判断题】 2/2 在Cloud Functi…

大数据Flink(九十八):SQL函数的归类和引用方式

文章目录 SQL函数的归类和引用方式 一、SQL 函数的归类

Vue_Bug Failed to fetch extension, trying 4 more times

Bug描述: 启动electron时出现Failed to fetch extension, trying 4 more times的问题 解决方法: 去src/background.js文件中进行代码注释工作 app.on(ready, async() > {// if (isDevelopment && !process.env.IS_TEST) {// // Install V…

小程序长期订阅

准备工作 ::: tip 管理后台配置 小程序类目:住建(硬性要求) 功能-》订阅消息-》我的模版 申请模版:1、预约进度通知 2、申请结果通知 3、业务办理进度提醒 ::: 用户订阅一次后,可长期下发多条消息。目前长期性订阅…

【SA8295P 源码分析 (一)】41 - SA8295所有镜像位置、拷贝脚本、生成QFIL包 及 Fastboot 下载命令介绍

【SA8295P 源码分析】41 - SA8295所有镜像位置、拷贝脚本、生成QFIL包 及 Fastboot 下载命令介绍 一、SA8295 各镜像位置二、SA8295 QNX 侧镜像拷贝脚本三、SA8295 Android 侧镜像拷贝脚本四、使用QFIL 下载整包五、Fastboot 下载命令整理系列文章汇总见:《【SA8295P 源码分析…

STM32如何使用PWM?

一:PWM介绍 PWM 是 Pulse Width Modulation 的缩写,中文意思就是脉冲宽度调制,简 称脉宽调制。它是利用微处理器的数字输出来对模拟电路进行控制的一种非常有 效的技术,其控制简单、灵活和动态响应好等优点而成为电力电子技术最广…

Vue之Vue的介绍安装开发实例生命周期钩子

博主心得: keyup必须与change一起使用v-on.click可以直接写成clickclick“setVal”里的setVal换成数字之后有惊喜VS Code是真的狗,一些报错根本不会直接显示总结:VS code太狗了 1.vue介绍 1.1 什么是vue vue是一个构建用户界面UI的渐进式jav…

【配置环境】SQLite数据库安装和编译以及VS下C++访问SQLite数据库

一,环境 Windows 11 家庭中文版,64 位操作系统, 基于 x64 的处理器SQLite - 3.43.2Microsoft Visual Studio Community 2022 (64 位) - Current 版本 17.5.3 二,SQLite简介 简要介绍 SQLite(Structured Query Language for Lite&a…

Babel 在Powershell 上无法查看版本

ES6 模块语法不能应用在ES5环境中 (ES6模块化语法不能在node.js中执行),此时需要Babel进行转码 通过npm install -g babel-cli 安装好后,想通过 babel --version产看版本。但是无法查看 首先,我们要以管理员方式运行PowerShell,&…

密码学二: md5 网站服务器与用户通信过程 ca原理 签名原理 Flame 病毒原理

md5被破解? MD5(Message Digest Algorithm 5)是一个较早的哈希函数,但由于其弱点和漏洞,它已经被认为不再适合用于安全性要求较高的应用。MD5的一些安全性问题包括: 碰撞攻击: MD5已经被证明容易受到碰撞攻…

9-k8s-亲和力与反亲和力

文章目录 一、概念二、实操节点亲和力1三、实操pod亲和力2 一、概念 节点亲和力概念(反亲和力相反) ps:官方文档http://kubernetes.p2hp.com/docs/concepts/scheduling-eviction/assign-pod-node.html 节点亲和力(Node Affinity&a…