Pytorch:张量的形状操作

文章目录

      • 一、维度改变
        • 1.flatten展开
          • a.函数的基本用法
          • b.示例
        • 2.unsqueeze增维
          • a.函数的基本用法
          • b.示例
        • 3.squeeze降维
          • a.函数的基本用法
          • b.示例
      • 二、张量变形
        • 1.view()
          • a.函数的基本用法
          • b.参数:
          • c.注意事项
          • d.示例
        • 2.reshape()
          • a.注意事项
          • b.示例
        • 3.reshape_as()
          • a.函数的基本用法
          • b.参数:
          • c.示例
          • d.注意
      • 三、维度重排
        • 1.permute
          • a.函数的基本用法
          • b.参数:
          • c.示例
          • d.注意
        • 2.transpose
          • a.函数的基本用法
          • b.参数:
          • c.示例
          • d.注意

维度改变和张量变形都不改变内存中存储的结构,因此改变后的张量的值顺序和没改变前是一样的。

一、维度改变

1.flatten展开
  • torch.flatten(tensor)
  • tensor.flatten()

torch.flatten() 是一个在 PyTorch 中常用于张量(tensor)处理的函数,它将输入张量展开成一个一维张量。该函数通常用于准备数据,将多维数据转换为一维,以便用于机器学习模型,特别是在模型的全连接层(fully connected layers)之前。
常用于展开成一维

a.函数的基本用法

只给定一个张量,将直接展开成一维。
torch.flatten(input, start_dim=0, end_dim=-1) 的参数解释如下:

  • input: 输入的张量。
  • start_dim: 开始展开的维度,默认为 0。这意味着从哪个维度开始将张量展开。
  • end_dim: 结束展开的维度,默认为 -1,即最后一个维度。这意味着展开将持续到哪个维度。
b.示例

考虑一个三维张量,例如形状为 (2, 3, 4) 的张量。如果使用 torch.flatten() 将其展开,可以有多种方式处理:

  1. 完全展开: 将整个张量展开成一维数组。

    import torch
    x = torch.randn(2, 3, 4)
    flat_x = torch.flatten(x)
    # 结果形状为 [24]
    
  2. 从特定维度开始展开: 指定从哪个维度开始展开。例如,从第一维(索引为 0 的维度)开始展开。

    flat_x = torch.flatten(x, start_dim=1)
    # 结果形状为 [2, 12],保留了第一个维度,其余维度被展开
    
2.unsqueeze增维
  • torch.unsqueeze(tensor)
  • tensor.unsqueeze()

torch.unsqueeze() 是 PyTorch 中用来增加张量的维度的函数。该函数可以在张量的指定位置插入一个维度,它非常有用于调整张量的形状,以满足特定操作或模型的需求,例如在单样本张量上应用需要批处理的模型。
常用于在第0个维度上增加大小为1的维度

a.函数的基本用法

torch.unsqueeze(input, dim) 的参数解释如下:

  • input: 输入的张量。
  • dim: 要插入新维度的索引位置。这个位置遵循 Python 的索引规则,支持负索引。
b.示例

假设有一个二维张量 x 形状为 (3, 4),表示一个包含3个样本,每个样本4个特征的数据集。如果需要在特定维度增加一个维度,可以使用 torch.unsqueeze() 如下:

import torch
x = torch.randn(3, 4)# 在第0维增加一个维度
x_unsqueezed = x.unsqueeze(0)
print(x_unsqueezed.shape)
# 输出: torch.Size([1, 3, 4])# 在第1维增加一个维度
x_unsqueezed = torch.unsqueeze(x, 1)
print(x_unsqueezed.shape)
# 输出: torch.Size([3, 1, 4])# 使用负索引,在最后一个维度后增加一个维度
x_unsqueezed = torch.unsqueeze(x, -1)
print(x_unsqueezed.shape)
# 输出: torch.Size([3, 4, 1])
3.squeeze降维
  • torch.squeeze(tensor)
  • tensor.squeeze()

torch.squeeze() 是 PyTorch 中的一个函数,用于减少张量的维度,特别是去除那些维度大小为1的维度。这个函数非常有用于去除由于某些操作(比如 unsqueeze)产生的单一维度,从而使张量的形状更加紧凑。

a.函数的基本用法

只给定一个张量,将直接去掉所有大小为1的维度。
torch.squeeze(input, dim=None) 的参数解释如下:

  • input: 输入的张量。
  • dim: 指定要压缩的维度。如果指定的维度大小为1,则该维度会被去除如果大小不为1,则该维度不会被压缩如果不指定 dim 参数,那么所有大小为1的维度都会被去除。
b.示例

考虑一个张量 x,其形状包括一些大小为1的维度。以下是如何使用 torch.squeeze() 来去除这些维度的示例:

import torch
x = torch.randn(1, 3, 1, 5)# 去除所有大小为1的维度
squeezed_x = x.squeeze()
print(squeezed_x.shape)
# 输出: torch.Size([3, 5])# 只压缩第0维(大小为1)
squeezed_x = x.squeeze(0)
print(squeezed_x.shape)
# 输出: torch.Size([3, 1, 5])# 只压缩第2维(大小为1)
squeezed_x = torch.squeeze(x, 2)
print(squeezed_x.shape)
# 输出: torch.Size([1, 3, 5])# 尝试压缩一个不是大小为1的维度(没有变化)
squeezed_x = torch.squeeze(x, 1)
print(squeezed_x.shape)
# 输出: torch.Size([1, 3, 1, 5])

二、张量变形

1.view()

在 PyTorch 中,.view() 方法是一个非常重要且常用的功能,用于改变张量的形状而不改变其数据内容。此方法提供了一种高效的方式来重新排列张量的维度,使其适应不同的需求,例如输入到一个模型或对数据进行不同的操作。
view是共享内存的!

a.函数的基本用法

.view() 方法的基本用法是 tensor.view(*shape),其中 *shape 是希望张量拥有的新形状,由一组维度大小组成。

b.参数:
  • shape: 新的形状,是一个由整数构成的元组,其中的每个整数指定相应维度的大小。你也可以在某个位置使用 -1,让 PyTorch 自动计算该维度的大小。(注意某个位置是任意的某个位置,但是只能有一个)
c.注意事项
  1. 连续性.view() 要求张量在内存中是连续的(即一维数组中的元素顺序与多维视图中的顺序相同)。如果张量不是连续的,你可能需要首先调用 .contiguous() 方法来使其连续。

  2. 自动计算维度:使用 -1 作为形状参数的一部分,PyTorch 将自动计算该维度的正确大小,以便保持元素总数不变。

  3. 大小不变.view()要求张量变换形状之后的大小和变换之前的大小是一样的。即维度大小之积相等。比如tensor.Size([2,4])tensor.Size([8])是一样的。

d.示例
import torch
x = torch.randn(4, 4)  # 创建一个 4x4 的张量# 改变形状为 2x8
y = x.view(2, 8)
print(y.shape)
# 输出: torch.Size([2, 8])# 改变形状为 16(一维)
z = x.view(-1)#z = x.view(16)
print(z.shape)
# 输出: torch.Size([16])# 使用 -1 自动计算维度
w = x.view(-1, 8)
print(w.shape)
# 输出: torch.Size([2, 8])
import torch
x = torch.randn(2, 1)  # 创建一个 2×1 的张量# 改变形状为 2x8
y = x.view(2)
print(y)
# 输出: torch.Size([2, 8])
x[0][0]=2 #共享内存,y也会变
print(x)
print(y)
tensor([-0.5001,  0.5409])
tensor([[2.0000],[0.5409]])
tensor([2.0000, 0.5409])
2.reshape()

在 PyTorch 中,.reshape() 方法用于改变张量的形状而不改变其数据内容。
这一方法与 .view() 类似,都允许您重新排列张量的维度,但它们在处理非连续张量时的行为不同。
只有当非连续张量时,才会导致和.view不一样,如果是连续的,同样也是共享内存的。

a.注意事项
  1. 数据连续性:与 .view() 相比,.reshape() 可以处理非连续张量,如果必要,它会自动处理数据的内存复制。因此,如果原始张量不连续,而你尝试用 .view() 改变其形状可能会导致错误,但 .reshape() 会自动解决这个问题。

  2. 自动计算维度:使用 -1 作为形状参数的一部分时,PyTorch 会自动计算该维度的大小,以确保总元素数量与原张量相同。

b.示例
import torch
x = torch.randn(2, 3, 4)  # 创建一个 2x3x4 的张量# 改变形状为 6x4
y = x.reshape(6, 4)
print(y.shape)
# 输出: torch.Size([6, 4])# 改变形状为 1x24
z = x.reshape(1, 24)
print(z.shape)
# 输出: torch.Size([1, 24])# 使用 -1 自动计算维度
w = x.reshape(-1, 2)
print(w.shape)
# 输出: torch.Size([12, 2])
import torch
x = torch.randn(2, 2)  # 创建一个 2x1 的张量
x=x.transpose(0,1)
# 改变形状为 2x8
y = x.reshape(4)#转置后的x不是连续的,使用reshape产生复制,此时不能用.view()
print(y)
# 输出: torch.Size([2, 8])
x[0][0]=100
print(x)
print(y)
tensor([-0.5386, -0.3646, -0.1661, -0.2516])
tensor([[100.0000,  -0.1661],[ -0.3646,  -0.2516]])
tensor([-0.5386, -0.3646, -0.1661, -0.2516])
3.reshape_as()

在 PyTorch 中,.reshape_as() 是一个方便的方法,用于将一个张量重新塑形为与另一个张量相同的形状。这个方法实质上是 .reshape() 方法的一个简化版本,它以另一个张量的形状为目标形状。
换句话说,.reshape_as()相当于是省略了自指定参数的.reshape(),而可以直接用目标张量形状作为形状。

a.函数的基本用法

.reshape_as() 的基本用法非常直接:tensor1.reshape_as(tensor2)。这会将 tensor1 的形状修改为与 tensor2 相同的形状。

b.参数:
  • tensor2: 这是模型张量,tensor1 将改变形状以匹配 tensor2 的形状。
c.示例
import torch
x = torch.randn(2, 3, 4)  # 原始张量,形状为 2x3x4
y = torch.randn(6, 4)     # 目标张量,形状为 6x4# 将 x 的形状改变为与 y 相同
z = x.reshape_as(y)
print(z.shape)
# 输出: torch.Size([6, 4])
d.注意

虽然 .reshape_as() 很方便,但使用它时应确保两个张量具有相同的元素总数,因为改变形状的操作不会改变数据的总量。如果两个张量的总元素数量不匹配,尝试使用 .reshape_as() 将抛出错误。此外,如果原始张量在内存中是非连续的,.reshape_as() 会像 .reshape() 一样处理,可能需要在内部进行数据复制以确保连续性。

三、维度重排

permute方法可以按照指定顺序重新排列维度,而transpose方法可以交换张量的两个维度。用于需要进行维度重排或转置操作。如矩阵转置。

1.permute

在 PyTorch 中,.permute() 方法用于重新排列张量的维度,这是处理多维数据时一个非常有用的功能,尤其在需要对维度进行特定的重排序操作时。

a.函数的基本用法

.permute() 方法的调用格式为 tensor.permute(*dims),其中 *dims 是一个整数序列,代表新的维度排列顺序。

b.参数:
  • dims: 这个参数定义了张量的每个维度应该如何重新排列。序列中的每个整数都代表原始张量中一个维度的索引,这些索引的排列顺序确定了输出张量的形状。
c.示例
import torch
x = torch.randn(2, 3, 5)  # 创建一个形状为 [2, 3, 5] 的张量# 改变维度的排列顺序为 [2, 0, 1]
y = x.permute(2, 0, 1)
print(y.shape)
# 输出: torch.Size([5, 2, 3])# 将维度的排列顺序改为 [1, 2, 0]
z = x.permute(1, 2, 0)
print(z.shape)
# 输出: torch.Size([3, 5, 2])
d.注意
import torch
x = torch.tensor([[1,2,3,4],[2,4,2,4],[5,6,7,8]]) 
x = x.permute(1,0)
'''
tensor([[1, 2, 5],[2, 4, 6],[3, 2, 7],[4, 4, 8]])
'''

在 PyTorch 中,当使用 .permute() 方法重排张量维度时,张量的数据实际上在内存中的位置并没有改变。更准确地说.permute() 改变的是张量访问这些数据的方式,通过调整形状(shape)步长(stride) 的元信息,而不是数据本身。

  • 步长(Stride)
    • 步长是一个定义在每一维上的整数数组,表示为了在数据中从当前维度的一个元素移动到下一个元素,需要跨过的内存位置数。对于一个连续的张量,步长决定了元素在内存中的布局。

形状(Shape)和步长的调整当调用 .permute(1,0) 时,你实际上是告诉 PyTorch 以一个新的顺序来解释原始数据的内存布局。例如:

x = torch.tensor([[1, 2, 3, 4],[2, 4, 2, 4],[5, 6, 7, 8]])

原始的 x 的形状为 (3, 4),即有 3 行和 4 列。在 PyTorch 中,这意味着其步长为 (4, 1),其中 4 表示要从一行的开始移动到下一行的开始,在内存中需要跨过 4 个元素位置;1 表示在同一行中从一个元素移动到下一个元素,只需要跨过 1 个元素位置。

当你调用 x.permute(1, 0) 时,你是在指示 PyTorch 将原来的列视为行,将原来的行视为列。这就改变了形状为 (4, 3)。这时,步长变为 (1, 4)。这意味着:

  • 要从列的一个元素到下一个元素(现在变成了“行”移动),你只需要移动一个数据位置(原来的行移动)。
  • 要从一行移动到下一行(现在是原来的列跨行移动),你需要跨过 4 个数据位置。
2.transpose

在 PyTorch 中,.transpose() 方法用于交换张量中的两个维度,这是处理多维数组时一个常用的功能,尤其是在需要对特定的维度进行转置操作时。

a.函数的基本用法

.transpose() 方法的调用格式为 tensor.transpose(dim0, dim1),其中 dim0dim1 是要交换的维度的索引。

b.参数:
  • dim0: 第一个要交换的维度的索引。
  • dim1: 第二个要交换的维度的索引。
c.示例
import torch
x = torch.randn(2, 3, 5)  # 创建一个形状为 [2, 3, 5] 的张量# 交换维度 0 和 1
y = x.transpose(0, 1)
print(y.shape)
# 输出: torch.Size([3, 2, 5])# 交换维度 1 和 2
z = x.transpose(1, 2)
print(z.shape)
# 输出: torch.Size([2, 5, 3])
d.注意

.permute() 类似,.transpose() 也是返回原始数据的一个新视图,并不复制数据。因此,输出张量与输入张量共享同一块内存空间,只是它们的形状和步长(stride)不同。同样,.transpose() 会导致张量在内存中可能变为非连续,因此在某些情况下,可能需要调用 .contiguous() 来使张量在内存中连续。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/824570.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

深入理解 pytest Fixture 方法及其应用

当涉及到编写自动化测试时,测试框架和工具的选择对于测试用例的设计和执行非常重要。在Python 中,pytest是一种广泛使用的测试框架,它提供了丰富的功能和灵活的扩展性。其中一个很有用的功 能是fixture方法,它允许我们初始化测试环…

css中backface-visibility使用

backface-visibility 是一个 CSS 属性,用于控制元素的背面是否可见。它主要用于在进行3D转换时控制元素的背面可见性。当一个元素被旋转或进行其他3D变换时,通常浏览器默认会进行背面剪裁(backface culling),使得元素的…

DAY29| 491.递增子序列 ,46.全排列 ,47.全排列II

文章目录 491.递增子序列46.全排列47.全排列II 491.递增子序列 文字讲解:递增子序列 视频讲解:递增子序列 **状态:这题看了文字讲解才AC,掌握了如何在回溯里通过Set集合来对同层节点去重 思路: 代码: cla…

HTML5漫画风格个人介绍源码

源码介绍 HTML5漫画风格个人介绍源码,源码由HTMLCSSJS组成,记事本打开源码文件可以进行内容文字之类的修改,双击html文件可以本地运行效果,也可以上传到服务器里面,重定向这个界面 效果截图 源码下载 HTML5漫画风格…

设计模式———单例模式

单例也就是只能有一个实例,即只创建一个实例对象,不能有多个。 可能会疑惑,那我写代码的时候注意点,只new一次不就得了。理论上是可以的,但在实际中很难实现,因为你无法预料到后面是否会脑抽一下~~因此我们…

「Python大数据」数据采集-某东产品数据评论获取

前言 本文主要介绍通过python实现数据采集、脚本开发、办公自动化。数据内容范围:星级评分是1-3分、获取数据页面是前50页。 友情提示 法律分析:下列三种情况,爬虫有可能违法,严重的甚至构成犯罪: 爬虫程序规避网站经营者设置的反爬虫措施或者破解服务器防抓取措施,非法…

arm 作业 24/4/17

1、主机向从机发送多个字节的数据 主机发送起始信号 主机发送8bit从机地址1bit写标志(0) 从机回应应答信号 主机发送8bit从机的寄存器地址 从机回应应答信号 主机发送8bit数据 从机回应应答 主机发送8bit数据 从机回应应答 ………… 主机发起…

【Pytorch】Conv1d

conv1d 先看看官方文档 再来个简单的例子 import torch import numpy as np import torch.nn as nndata np.arange(1, 13).reshape([1, 4, 3]) data torch.tensor(data, dtypetorch.float) print("[data]:\n", data) conv nn.Conv1d(in_channels4, out_channels1…

二叉树前序遍历​​​​​​​到底部为何会返回到顶部?函数调用栈

前序遍历是一种二叉树的遍历方式,其遍历顺序是先访问根节点,然后递归地遍历左子树,最后递归地遍历右子树。具体来说,前序遍历的顺序是根节点->左子树->右子树。 前序遍历到底部为何会返回到顶部是因为在进行递归遍历时&…

启明智显应用分享|基于ESP32-S3方案的SC01PLUS彩屏与chatgpt融合应用DEMO

今天将带大家真实体验科技与智慧的完美融合——SC01PLUS与ChatGPT的深度融合DEMO效果呈现。 彩屏的清晰显示与ChatGPT的精准回答,将为我们带来前所未有的便捷与高效。 SC01PLUS是启明智显基于ESP32-S3打造的一款3.5寸480*320分辨率的彩屏产品,您可以看…

32、模拟队列

模拟队列 题目描述 实现一个队列,队列初始为空,支持四种操作: (1) “push x” – 向队尾插入一个数x; (2) “pop” – 从队头弹出一个数; (3) “empty” – 判断队列是否为空; (4) “query” – 查询…

【Git】git命令大全(持续更新)

本文架构 0.描述git简介术语 1.常用命令2. 信息管理新建git库命令更改存在库设置获取当前库信息 3.工作空间相关将工作空间文件添加到缓存区(增)从工作空间中移除文件(删)撤销提交 4.远程仓库相关同步远程仓库分支 (持…

高版本Android studio 使用Markdown无法预览(已解决)

目录 概述 解决方法 概述 本人升级Android studio 当前版本为Android Studio Jellyfish | 2023.3.1 RC 2导致Markdown无法预览。 我尝试了很多网上的方法都无法Markdown解决预览问题,包括升级插件、安装各种和Markdown相关的插件及使用“Choose Boot Java Runtim…

yolov5 自训练pt模型转onnx,再转rknn,并部署 注意事项

yolov5 部署到rk3588 教程来自 yolov5训练pt模型并转换为rknn模型,部署在RK3588开发板上——从训练到部署全过程_yolov5 rknn-CSDN博客 1.通过android studio 部署代码在rk3588板子上运行代码 项目来源 rknn-toolkit2/rknpu2/examples/rknn_yolov5_android_apk…

使用AWK进行文本处理

awk 的基本概念 awk 是一种强大的文本处理语言,广泛用于模式匹配和数据提取。这种编程语言设计用于对文本文件进行操作,尤其适用于格式化的文本,如 CSV 或空格分隔的表格数据。下面详细介绍 awk 的一些基本概念: 1. 记录和字段 …

一文了解OCI标准、runC、docker、contianerd、CRI的关系

docker和contanerd都是流行的容器运行时(container runtime);想讲清楚他们两之间的关系,让我们先从runC和OCI规范说起。 一、OCI标准和runC 1、OCI(open container initiative) OCI是容器标准化组织为了…

利用动态规划优化10年投资回报:策略、证明与算法分析

利用动态规划优化10年投资回报:策略、证明与算法分析 a. 存在最优投资策略的证明b. 最优子结构性质的证明c. 最优投资策略规划算法设计d. 新限制条款下最优子结构性质的证明 在面对投资策略规划问题时,我们的目标是在10年后获得最大的回报。Amalgamated投…

Java上传文件到服务器

1、使用jsch <!--sftp文件上传--><dependency><groupId>com.jcraft</groupId><artifactId>jsch</artifactId><version>0.1.55</version></dependency> 2、配置类 package com.base.jsch;import lombok.Data; import o…

数据结构与算法-哈希表

哈希表 哈希表&#xff08;hash table&#xff09;&#xff0c;又称散列表&#xff0c;它通过建立键 key 与值 value 之间的映射&#xff0c;实现高效的元素查询。具体而言&#xff0c;我们向哈希表中输入一个键 key &#xff0c;则可以在 时间内获取对应的值 value 。 1.基础…

牛客 NC205 跳跃游戏(三)【中等 贪心 Java,Go,PHP】

题目 题目链接&#xff1a; https://www.nowcoder.com/practice/14abdfaf0ec4419cbc722decc709938b 思路 参考答案Java import java.util.*;public class Solution {/*** 代码中的类名、方法名、参数名已经指定&#xff0c;请勿修改&#xff0c;直接返回方法规定的值即可*** …