深度学习 Pytorch 张量的索引、分片、合并以及维度调整

张量作为有序的序列,也是具备数值索引的功能,并且基本索引方法和python原生的列表、numpy中的数组基本一致。

不同的是,pytorch中还定义了一种采用函数来进行索引的方式。

作为pytorch中的基本数据类型,张量既具备了列表、数组的基本功能,同时还充当向量、矩阵等重要数据结构。因此pytorch中也设置了非常晚辈的张量合并与变换的操作。

import torch	# 导入torch
import numpy as np	# 导入numpy

6 张量的符号索引

6.1 一维张量索引

一维张量的索引过程和python原生对象类型的索引一致,基本格式遵循[start: end: step]

t1 = torch.arange(1, 11)	# 创建一维张量

从左到右,从零开始

t1[0]
# output : tensor(1)

**注:**张量索引出来的结果还是零维张量,而不是单独的数。

​ 要转化成单独的数,需要使用.item()方法


冒号分割,表示对某个区域进行索引,也就是所谓的切片

t1[1: 8]	# 索引其中2-9号元素,并且左闭右开
# output : tensor([2, 3, 4, 5, 6, 7, 8])

第二个冒号,表示索引的间隔

t1[1: 8: 2]		# 第三个参数表示每两个数取一个
# output : tensor([2, 4, 6, 8])

冒号前后没有值,表示索引这个区域

t1[1: : 2]		# 从第二个元素开始索引,一致到结尾,并且每隔两个取一个
# output : tensor([ 2,  4,  6,  8, 10])
t1[: 8: 2]		#从第一个元素开始索引到第九个元素(不包含),并且每隔两个数取一个
# output : tensor([1, 3, 5, 7])

在张量的索引中,step位必须大于0,也就是说不能逆序取数。


6.2 二维张量索引

二维张量的索引逻辑和一维张量基本相同,二维张量可以视为两个一维张量组合而成。

在实际的索引过程中,需要用逗号进行分割,表示分别对哪个一维张量进行索引、以及具体的一维张量的索引。

t2 = torch.arange(1, 10).reshape(3, 3)		# 创建二维张量
t2[0, 1]	# 表示索引第一行、第二列的元素
# output : tensor(2)
t2[0, : : 2]	# 表示索引第一行、每隔两个元素取一个
# output : tensor([1, 3])
t2[0, [0, 2]]	# 索引结果同上
t2[: : 2, : : 2]	# 表示每隔两行取一行、并且每一行中每隔两个元素取一个
# output : 
tensor([[1, 3],[7, 9]])
t2[[0, 2], 1]	# 索引第一行、第三行、第二列的元素
# output : tensor([2, 8])

6.3 三维张量索引

我们可以将三维张量视作矩阵组成的序列,则在索引过程中拥有三个维度,分别是索引矩阵,索引矩阵的行、索引矩阵的列。

t3 = torch.arange(1, 28).reshape(3, 3, 3)	# 创建三维张量
t3[1, 1, 1]		# 索引第二个矩阵中,第二行、第二个元素
# output : tensor(14)
t3[1, : : 2, : : 2]		#索引第二个矩阵,行和列都是每隔两个取一个
# output : 
tensor([[10, 12],[16, 18]])
# 每隔两个取一个矩阵,对于每个矩阵来说,行和列都是每隔两个取一个
t3[: : 2, : : 2, : : 2]		
# output : 
tensor([[[ 1,  3],[ 7,  9]],[[19, 21],[25, 27]]])

7 张量的函数索引

pytorch中,我们还可以使用index_select函数,通过指定index来对张量进行索引。

t1 = torch.arange(1, 11)
indices = torch.tensor([1, 2])
torch.index_select(t1, 0, indices)
# output : tensor([2, 3])

第二个参数dim代表索引的维度。

对于t1这个一维向量来说,由于只有一个维度,因此第二个参数化取值为0,代表在第一个维度上进行索引。


t2 = torch.arange(12).reshape(4,3)
t2
# output :
tensor([[ 0,  1,  2],[ 3,  4,  5],[ 6,  7,  8],[ 9, 10, 11]])
indices = torch.tensor([1, 2])# dim参数取值为0,代表在shape的第一个维度上索引
torch.index_select(t2, 0, indices)	
# output : 
tensor([[3, 4, 5],[6, 7, 8]])# dim参数取值为0,代表在shape的第二个维度上索引
torch.index_select(t2, 1, indices)	
# output : 
tensor([[ 1,  2],[ 4,  5],[ 7,  8],[10, 11]])

8 tensor.view()方法

该方法会返回一个类似视图的结果,且该结果会和原张量对象共享一块数据存储空间

通过.view()方法,还可以改变对象结构,生成一个不同结构、但共享一个存储空间的张量。

t = torch.arange(6).reshape(2, 3)
t
# output :
tensor([[0, 1, 2],[3, 4, 5]])
# 构建一个数据相同,但形状不同的“视图”
te = t.view(3, 2)	
te
# output :
tensor([[0, 1],[2, 3],[4, 5]])

当然,共享一个存储空间,也就代表二者是浅拷贝的关系,修改其中一个,另一个也会同步更改。

t[0] = 1
te
# output :
tensor([[1, 1],[1, 3],[4, 5]])

当然,维度也可以修改

tr = t.view(1, 2, 3)
tr
# output :
tensor([[[1, 1, 1],[3, 4, 5]]])

视图的作用就是节省空间,在接下来介绍的很多切分张量的方法中,返回结果都是“视图”,而不是新生成一个对象。


9 张量的分片函数

9.1 分块:chunk函数

chunk函数能够按照某维度,对张量进行均匀切分,返回结果是原张量的视图

t2 = torch.arange(12).reshape(4, 3)
t2
# output :
tensor([[ 0,  1,  2],[ 3,  4,  5],[ 6,  7,  8],[ 9, 10, 11]])
# 在第零个维度上,按行进行四等分
tc = torch.chunk(t2, 4, dim = 0)
tc
# output :
(tensor([[0, 1, 2]]),tensor([[3, 4, 5]]),tensor([[6, 7, 8]]),tensor([[ 9, 10, 11]]))

注:chunk返回结果是一个视图,不是新生成了一个对象

tc[0][0][0] = 1		# 修改tc中的值
t2
# output :
tensor([[ 1,  1,  2],[ 3,  4,  5],[ 6,  7,  8],[ 9, 10, 11]])

当原张量不能均分时,chunk不会报错,但会返回其他均分结果。

torch.chunk(t2, 3, dim = 0)	# 返回次一级均分结果
# output :
(tensor([[1, 1, 2],[3, 4, 5]]),tensor([[ 6,  7,  8],[ 9, 10, 11]]))
torch.chunk(t2, 5, dim = 0)	# 返回次一级均分结果
# output :
(tensor([[1, 1, 2]]),tensor([[3, 4, 5]]),tensor([[6, 7, 8]]),tensor([[ 9, 10, 11]]))

9.2 拆分 :split函数

split既能进行均分,也能自定义切分

t2 = torch.arange(12).reshape(4, 3)
t2
# output :
tensor([[ 0,  1,  2],[ 3,  4,  5],[ 6,  7,  8],[ 9, 10, 11]])

第二个参数只输入一个数值时表示均分,第三个参数表示按第几个维度进行切分

torch.split(t2, 2, 0)
# output :
(tensor([[1, 1, 2],[3, 4, 5]]),tensor([[ 6,  7,  8],[ 9, 10, 11]]))

第二个参数输入一个序列时,表示按照序列数值进行切分

torch.split(t2, [1, 3], 0)
# output :
(tensor([[1, 1, 2]]),tensor([[ 3,  4,  5],[ 6,  7,  8],[ 9, 10, 11]]))

当第二个参数输入一个序列时,序列的各数值的和必须等于对于维度下形状分量的取值。

例如,上述代码中是按照第一个维度进行切分,第一个维度有四行,因此序列的求和必须等于4,也就是1 + 3 = 4


序列中每个分量的取值表示切块大小

torch.split(t2,[1, 1, 1, 1], 0)
# output :
(tensor([[1, 1, 2]]),tensor([[3, 4, 5]]),tensor([[6, 7, 8]]),tensor([[ 9, 10, 11]]))
torch.split(t2,[1, 2], 1)
# output :
(tensor([[1],[3],[6],[9]]),tensor([[ 1,  2],[ 4,  5],[ 7,  8],[10, 11]]))

当然,split函数返回结果也是view

ts = torch.split(t2,[1, 2], 1)
ts[0][0] = 1
t2
# output :
tensor([[ 1,  1,  2],[ 3,  4,  5],[ 6,  7,  8],[ 9, 10, 11]])

10 张量的合并操作

张量的合并操作类似列表的追加元素,可以拼接、也可以堆叠。

拼接函数:cat

a = torch.zeros(2, 3)
b = torch.ones(2, 3)
c = torch.zeros(3, 3)
# dim默认取值为0,按行进行拼接
torch.cat([a, b])	
# output :
tensor([[0., 0., 0.],[0., 0., 0.],[1., 1., 1.],[1., 1., 1.]])
# 按列进行拼接
torch.cat([a, b], 1)	
# output :
tensor([[0., 0., 0., 1., 1., 1.],[0., 0., 0., 1., 1., 1.]])
# 形状不匹配时将报错
torch.cat([a, c], 1)
# output :
RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 2 but got size 3 for tensor number 1 in the list.

拼接的本质是实现元素的堆积,也就是构成a、b两个二维张量的各一维张量的堆积,最终还是构成二维向量


堆叠函数:stack

a = torch.zeros(2, 3)
b = torch.ones(2, 3)
c = torch.zeros(3, 3)
# 堆叠之后,生成一个三维张量
torch.stack([a,b])
# output :
tensor([[[0., 0., 0.],[0., 0., 0.]],[[1., 1., 1.],[1., 1., 1.]]])

注意对比和**cat**函数的区别,拼接之后维度不变,堆叠之后维度升高

对于两个二维张量,拼接是把一个个元素单独提取出来之后放到二维张量中,而堆叠则是直接将两个二维张量封装到一个三维张量中。

因此,堆叠的要求更高,参与堆叠的张量必须形状完全相同

# 维度不匹配将报错
torch.stack([a, c])
# output :
RuntimeError: stack expects each tensor to be equal size, but got [2, 3] at entry 0 and [3, 3] at entry 1

11 张量维度变换

在实际操作张量进行计算时,往往需要另外进行降维和升维的操作。

squeeze函数:删除不必要的维度

t = torch.zeros(1, 1, 3, 1)
# output :
tensor([[[[0.],[0.],[0.]]]])
t.shape
# output :
torch.Size([1, 1, 3, 1])
torch.squeeze(t)
# output :
tensor([0., 0., 0.])
torch.squeeze(t).shape
# output :
torch.Size([3])

简单理解,squeeze就相对于提出了shape返回结果中的1.

t1 = torch.zeros(1, 1, 3, 2, 1, 2)
torch.squeeze(t1)
torch.squeeze(t1).shape
# output :
torch.Size([3, 2, 2])

unsqueeze函数:手动升维

t = torch.zeros(1, 2, 1, 2)
t.shape
# output :
torch.Size([1, 2, 1, 2])
# 在第1个维度索引上升高1个维度
torch.unsqueeze(t, dim = 0)
# output :
tensor([[[[[0., 0.]],[[0., 0.]]]]])
torch.unsqueeze(t, dim = 0).shape
# output :
torch.Size([1, 1, 2, 1, 2])
# 在第3个维度索引上升高1个维度
torch.unsqueeze(t, dim = 2).shape
# output :
torch.Size([1, 2, 1, 1, 2])

注意理解维度和shape返回结果一一对应的关系,shape返回的序列有多少元素,张量就有多少维度。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/68449.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

强推未发表!3D图!Transformer-LSTM+NSGAII工艺参数优化、工程设计优化!

目录 效果一览基本介绍程序设计参考资料 效果一览 基本介绍 1.Transformer-LSTMNSGAII多目标优化算法,工艺参数优化、工程设计优化!(Matlab完整源码和数据) Transformer-LSTM模型的架构:输入层:多个变量作…

SpringCloud系列教程:微服务的未来(十一)服务注册、服务发现、OpenFeign快速入门

本篇博客将通过实例演示如何在 Spring Cloud 中使用 Nacos 实现服务注册与发现,并使用 OpenFeign 进行服务间调用。你将学到如何搭建一个完整的微服务通信框架,帮助你快速开发可扩展、高效的分布式系统。 目录 前言 服务注册和发现 服务注册 ​编辑 …

跨境电商使用云手机用来做什么呢?

随着跨境电商的发展,越来越多的卖家开始尝试使用云手机来协助他们的业务,这是因为云手机具有许多优势。那么,具体来说,跨境电商使用云手机可以做哪些事情呢? (一)实现多账号登录和管理 跨境电商…

一体机cell服务器更换内存步骤

一体机cell服务器更换内存步骤: #1、确认grdidisk状态 cellcli -e list griddisk attribute name,asmmodestatus,asmdeactivationoutcome #2、offline griddisk cellcli -e alter griddisk all inactive #3、确认全部offline后进行关机操作 shutdown -h now #4、开…

“AI开放式目标检测系统:开启智能识别新时代

嘿,朋友们!今天咱们来聊聊一个超酷炫的技术——AI开放式目标检测系统。这可不是什么高大上、遥不可及的玩意儿,它已经悄悄地走进了我们的生活,改变着我们对世界的认知和互动方式呢。 先来说说,什么是AI开放式目标检测系…

【鱼皮大佬API开放平台项目】Spring Cloud Gateway HTTPS 配置问题解决方案总结

问题背景 项目架构为前后端分离的微服务架构: 前端部署在 8000 端口API 网关部署在 9000 端口后端服务包括: api-backend (9001端口)api-interface (9002端口) 初始状态: 前端已配置 HTTPS(端口 8000)后端服务未配…

【游戏设计原理】68 - 玩家错误

一、错误类型 玩家错误类型 行为错误(performance errors)和运动控制错误(motor control errors)是玩家在游戏中常犯的错误。 运动控制错误 错误发生在玩家协调或掌握输入设备时,可能包括不小心按错键或未能及时把握战…

2.使用Spring BootSpring AI快速构建AI应用程序

Spring AI 是基于 Spring Boot3.x 框架构建,Spring Boot官方提供了非常便捷的工具Spring Initializr帮助开发者快速的搭建Spring Boot应用程序,IDEA也集成了此工具。本文使用的开发工具IDEASpring Boot 3.4Spring AI 1.0.0-SNAPSHOTMaven。 1.创建Spring Boot项目 …

Ubuntu离线docker compose安装DataEase 2.10.4版本笔记

1、先准备一个可以正常上网的相同版本的Ubuntu系统,可以使用虚拟机。Ubuntu系统需要安装好docker compose或docker-compose 2、下载dataease-online-installer-v2.10.4-ce.tar在线安装包,解压并执行install.sh进行安装和启动 3、导出docker镜像 sudo d…

【报错解决】Sql server 2022连接数据库时显示证书链是由不受信任的颁发机构颁发的

SSMS 20在连接Sql server 2022数据库时有如下报错: A connection was successfully established with the server, but then an error occurred during the login process. (provider: SSL Provider, error: 0 - 证书链是由不受信任的颁发机构颁发的。 原因是尝试使…

LSA更新、撤销

LSA的新旧判断&#xff1a; 1.seq&#xff0c;值越大越优先 2.chksum&#xff0c;值越大越优先 3.age&#xff0c;本地的LSA age和收到的LSA age作比较 如果差值<900s&#xff0c;认为age一致&#xff0c;保留本地的&#xff1a;我本地有一条LSA是100 你给的是400 差值小于…

【FlutterDart】MVVM(Model-View-ViewModel)架构模式例子-dio版本(31 /100)

动图更精彩 dio & http 在Flutter中&#xff0c;dio和http是两个常用的HTTP请求库&#xff0c;它们各有优缺点。以下是对这两个库的详细对比&#xff1a; 功能特性 http&#xff1a; 功能&#xff1a;提供了基本的HTTP请求和响应功能&#xff0c;如GET、POST、PUT、DELE…

递归40题!再见递归

简介&#xff1a;40个问题&#xff0c;有难有易&#xff0c;均使用递归完成&#xff0c;需要C/C的指针、字符串、数组、链表等基础知识作为基础。 1、数字出现的次数 由键盘录入一个正整数&#xff0c;求该整数中每个数字出现的次数。 输入&#xff1a;19931003 输出&#xf…

STM32 FreeRTOS 的任务挂起与恢复以及查看任务状态

目录 任务的挂起与恢复的API函数 任务挂起函数 任务恢复函数 任务恢复函数&#xff08;中断中恢复&#xff09; 函数说明 注意事项 查看任务状态 任务的挂起与恢复的API函数 vTaskSuspend()&#xff1a;挂起任务, 类似暂停&#xff0c;可恢复 vTaskResume()&#xff1a…

openharmony标准系统方案之瑞芯微RK3568移植案例

标准系统方案之瑞芯微RK3568移植案例 ​本文章是基于瑞芯微RK3568芯片的DAYU200开发板&#xff0c;进行标准系统相关功能的移植&#xff0c;主要包括产品配置添加&#xff0c;内核启动、升级&#xff0c;音频ADM化&#xff0c;Camera&#xff0c;TP&#xff0c;LCD&#xff0c…

sunrays-framework 微调

文章目录 1.common-log4j2-starter 动态获取并打印日志存储的根目录的绝对路径以及应用的访问地址1.目录2.log4j2.xml 配置LOG_HOME3.LogHomePrinter.java 配置监听器4.spring.factories 注册监听器5.测试1.common-log4j2-starter-demo 配置2.启动测试 2.common-minio-starter …

ElasticSearch上

安装ElasticSearch Lucene&#xff1a;Java语言的搜索引擎类库&#xff0c;易扩展&#xff1b;高性能&#xff08;基于倒排索引&#xff09;Elasticsearch基于Lucene&#xff0c;支持分布式&#xff0c;可水平扩展&#xff1b;提供Restful接口&#xff0c;可被任何语言调用Ela…

element-ui textarea备注 textarea 多行输入框

发现用这个组件&#xff0c;为了给用户更好的体验&#xff0c;要加下属性 1. 通过设置 autosize 属性可以使得文本域的高度能够根据文本内容自动进行调整&#xff0c;并且 autosize 还可以设定为一个对象&#xff0c;指定最小行数和最大行数。:autosize"{ minRows: 3, ma…

.netframwork模拟启动webapi服务并编写对应api接口

在.NET Framework环境中模拟启动Web服务&#xff0c;可以使用几种不同的方法。一个常见的选择是利用HttpListener类来创建一个简单的HTTP服务器&#xff0c;或者使用Owin/Katana库来自托管ASP.NET Web API或MVC应用。下面简要介绍Owin/Katana示例代码。这种方法更加灵活&#x…

路由环路的产生原因与解决方法(1)

路由环路 路由环路就是数据包不断在这个网络传输&#xff0c;始终到达不了目的地&#xff0c;导致掉线或者网络瘫痪。 TTL &#xff08;生存时间&#xff09;&#xff1a;数据包每经过一个路由器的转发&#xff0c;其数值减1&#xff0c;当一个数据包的TTL值为0是&#xff0c;路…