Pytorch索引、切片、连接

文章目录

    • 1.torch.cat()
    • 2.torch.column_stack()
    • 3.torch.gather()
    • 4.torch.hstack()
    • 5.torch.vstack()
    • 6.torch.index_select()
    • 7.torch.masked_select()
    • 8.torch.reshape
    • 9.torch.stack()
    • 10.torch.where()
    • 11.torch.tile()
    • 12.torch.take()
    • 13.torch.scatter()


在这里插入图片描述


1.torch.cat()

  torch.cat() 是 PyTorch 库中的一个函数,用于沿指定维度连接张量。它接受一系列张量作为输入,并沿指定的维度进行连接。

torch.cat(tensors, dim=0, out=None)
"""
tensors:要连接的张量序列(例如,列表、元组)。
dim(可选):要沿其进行连接的维度。它指定了轴或维度编号。默认情况下,它设置为0,表示沿第一个维度进行连接。
out(可选):存储结果的输出张量。如果指定了 out,结果将存储在此张量中。如果未提供 out,则会创建一个新的张量来存储结果。
"""
import torch# 创建两个张量
tensor1 = torch.tensor([[1, 2], [3, 4]])
tensor2 = torch.tensor([[5, 6], [7, 8]])# 沿着维度0连接两个张量
result = torch.cat((tensor1, tensor2), dim=0)print(result)

2.torch.column_stack()

 torch.column_stack() 是 PyTorch 中的一个函数,用于按列堆叠张量来创建一个新的张量。它将输入张量沿着列的方向进行堆叠,并返回一个新的张量。

torch.column_stack(tensors)
"""
tensors:要堆叠的张量序列。它可以是一个包含多个张量的元组、列表或任意可迭代对象。
"""
import torchtensor1 = torch.tensor([1, 2, 3])
tensor2 = torch.tensor([4, 5, 6])result = torch.column_stack((tensor1, tensor2))print(result)

3.torch.gather()

torch.gather() 是 PyTorch 中的一个函数,用于根据给定的索引从输入张量中收集元素。它允许你按照指定的索引从输入张量中选择元素,并将它们组合成一个新的张量。

torch.gather(input, dim, index, out=None, sparse_grad=False)
"""
input:输入张量,从中收集元素。
dim:指定索引的维度。
index:包含要收集元素的索引的张量。
out(可选):输出张量,用于存储结果。
sparse_grad(可选):指定是否启用稀疏梯度。默认为 False
"""

在这里插入图片描述

import torch# 输入张量
input = torch.tensor([[1, 2], [3, 4]])# 索引张量
index = torch.tensor([[0, 0], [1, 0]])# 根据索引从输入张量中收集元素
result = torch.gather(input, 1, index)print(result)
import torch# 输入张量
input = torch.tensor([[1, 2], [3, 4]])# 索引张量
index = torch.tensor([[0, 0], [1, 0]])# 根据索引从输入张量中收集元素
result = torch.gather(input, 0, index)print(result)

4.torch.hstack()

  torch.hstack() 是 PyTorch 中的一个函数,用于沿着水平方向(列维度)堆叠张量来创建一个新的张量。它将输入张量沿着水平方向进行堆叠,并返回一个新的张量。

torch.hstack(tensors) -> Tensor
"""
tensors:要堆叠的张量序列。可以是一个包含多个张量的元组、列表或任意可迭代对象。
"""
import torchtensor1 = torch.tensor([[1, 2], [3, 4]])
tensor2 = torch.tensor([[5, 6], [7, 8]])result = torch.hstack((tensor1, tensor2))print(result)
# tensor([[1, 2, 5, 6],
#        [3, 4, 7, 8]])

5.torch.vstack()

torch.vstack()是PyTorch中用于沿垂直方向(行维度)堆叠张量的函数。它将输入张量沿垂直方向进行堆叠,并返回一个新的张量。

torch.vstack(tensors) -> Tensor
import torchtensor1 = torch.tensor([[1, 2], [3, 4]])
tensor2 = torch.tensor([[5, 6], [7, 8]])result = torch.vstack((tensor1, tensor2))print(result)
tensor([[1, 2],[3, 4],[5, 6],[7, 8]])

6.torch.index_select()

torch.index_select() 是 PyTorch 中的一个函数,用于按索引从输入张量中选择元素并返回一个新的张量。

torch.index_select(input, dim, index, out=None) -> Tensor
"""
input:输入张量,从中选择元素。
dim:指定索引的维度。即要在 input 张量的哪个维度上进行索引。
index:指定要选择的索引的张量。它的形状可以与 input 张量的形状不同,但必须满足广播规则。
out(可选):输出张量,用于存储结果。如果提供了 out,则结果将存储在此张量中。
"""
import torch# 输入张量
input = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])# 索引张量
index = torch.tensor([0, 2])# 根据索引从输入张量中选择元素
result = torch.index_select(input, 0, index)print(result)
tensor([[1, 2, 3],[7, 8, 9]])

7.torch.masked_select()

torch.masked_select() 是 PyTorch 中的一个函数,用于根据给定的掩码从输入张量中选择元素并返回一个新的张量。

torch.masked_select(input, mask, out=None) -> Tensor
"""
input:输入张量,从中选择元素。
mask:掩码张量,用于指定要选择的元素。mask 张量的形状必须与 input 张量的形状相同,或者满足广播规则。
out(可选):输出张量,用于存储结果。如果提供了 out,则结果将存储在此张量中。
"""
import torch# 输入张量
input = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])# 掩码张量
mask = torch.tensor([[True, False, True], [False, True, False], [True, False, True]])# 根据掩码从输入张量中选择元素
result = torch.masked_select(input, mask)print(result)
tensor([1, 3, 5, 7, 9])

8.torch.reshape

torch.reshape() 是 PyTorch 中的一个函数,用于改变张量的形状而不改变元素的数量。它返回一个具有新形状的新张量,其中的元素与原始张量相同。

torch.reshape(input, shape) -> Tensor
"""
input:输入张量,要改变形状的张量。
shape:指定的新形状。可以是一个整数元组或传递一个张量,其中包含新的形状。
torch.reshape() 函数将输入张量重新排列为指定的新形状。新的形状应该满足以下条件:1. 新形状的元素数量与原始张量的元素数量相同。
2. 新形状中各维度的乘积与原始张量的元素数量相同。
"""
import torch# 输入张量
input = torch.tensor([[1, 2, 3], [4, 5, 6]])# 改变形状为 (3, 2)
result1 = torch.reshape(input, (3, 2))# 改变形状为 (1, 6)
result2 = torch.reshape(input, (1, 6))# 改变形状为 (6,)
result3 = torch.reshape(input, (6,))print(result1)
print(result2)
print(result3)

9.torch.stack()

torch.stack() 是 PyTorch 中的一个函数,用于沿着新的维度对给定的张量序列进行堆叠操作。

torch.stack(tensors, dim=0, *, out=None) -> Tensor
"""
tensors:张量的序列,要进行堆叠操作的张量。
dim(可选):指定新的维度的位置。默认值为 0。
out(可选):输出张量。如果提供了输出张量,则将结果存储在该张量中。
"""
import torch# 张量序列
tensor1 = torch.tensor([1, 2, 3])
tensor2 = torch.tensor([4, 5, 6])
tensor3 = torch.tensor([7, 8, 9])# 在维度 0 上进行堆叠操作
result = torch.stack([tensor1, tensor2, tensor3], dim=0)print(result)
tensor([[1, 2, 3],[4, 5, 6],[7, 8, 9]])

10.torch.where()

torch.where() 是 PyTorch 中的一个函数,用于根据给定的条件从两个张量中选择元素。

torch.where(condition, x, y) -> Tensor
"""
condition:条件张量,一个布尔张量,用于指定元素选择的条件。
x:张量,与 condition 形状相同的张量,当对应位置的 condition 元素为 True 时,选择 x 中的对应元素。
y:张量,与 condition 形状相同的张量,当对应位置的 condition 元素为 False 时,选择 y 中的对应元素。
"""
import torch# 条件张量
condition = torch.tensor([[True, False], [False, True]])# 选择的张量 x
x = torch.tensor([[1, 2], [3, 4]])# 选择的张量 y
y = torch.tensor([[5, 6], [7, 8]])# 根据条件选择元素
result = torch.where(condition, x, y)print(result)
#tensor([[1, 6],
#       [7, 4]])
import torch# 输入张量
input = torch.tensor([1.5, 0.8, -1.2, 2.7, -3.5])# 阈值
threshold = 0# 根据阈值选择元素
result = torch.where(input > threshold, torch.tensor(1), torch.tensor(0))print(result)#tensor([1, 1, 0, 1, 0])

11.torch.tile()

torch.tile() 是 PyTorch 中的一个函数,用于在指定维度上重复张量的元素。

torch.tile(input, reps) -> Tensor
"""
input:输入张量,要重复的张量。
reps:重复的次数,可以是一个整数或一个元组。
"""
import torch# 输入张量
input = torch.tensor([1, 2, 3])# 在维度 0 上重复 2 次
result = torch.tile(input, 2)print(result)#tensor([1, 2, 3, 1, 2, 3])
import torch# 输入张量
input = torch.tensor([[1, 2], [3, 4]])# 在维度 0 和维度 1 上重复
result = torch.tile(input, (2, 3))print(result)
tensor([[1, 2, 1, 2, 1, 2],[3, 4, 3, 4, 3, 4],[1, 2, 1, 2, 1, 2],[3, 4, 3, 4, 3, 4]])

12.torch.take()

torch.take() 是 PyTorch 中的一个函数,用于在给定索引处提取张量的元素。

torch.take(input, indices) -> Tensor
"""
input:输入张量,要从中提取元素的张量。
indices:索引张量,包含要提取的元素的索引。它可以是一个一维整数张量或一个具有相同形状的张量。
"""
import torch# 输入张量
input = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])# 索引张量
indices = torch.tensor([1, 4, 7])# 提取元素
result = torch.take(input, indices)print(result)# tensor([2, 5, 8])
import torch# 输入张量
input = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])# 索引张量
indices = torch.tensor([[0, 2], [1, 2]])# 提取部分元素
result = torch.take(input, indices)print(result)
tensor([[1, 3],[2, 3]])

13.torch.scatter()

torch.scatter() 是 PyTorch 中的一个函数,用于根据索引在张量中进行散射操作。散射操作是指根据给定的索引,将源张量的值散布(写入)到目标张量的指定位置。

在这里插入图片描述

torch.scatter(input, dim, index, src)
"""
input:输入张量,表示目标张量,散射操作将在此张量上进行。
dim:整数值,表示散射操作沿着的维度。
index:索引张量,指定散射操作的目标位置。
src:源张量,包含要散射到目标张量中的值。
"""
import torch# 创建目标张量
target = torch.zeros(3, 4)# 创建索引张量和源张量
index = torch.tensor([[0, 1, 2, 0], [2, 1, 0, 2]])
source = torch.tensor([1, 2, 3, 4])# 执行散射操作
torch.scatter(target, dim=1, index=index, src=source)print(target)
# 输出:
# tensor([[1., 4., 3., 1.],
#         [0., 3., 2., 0.],
#         [3., 2., 1., 3.]])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/20459.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

基于排队理论的客户结账等待时间MATLAB模拟仿真

目录 1.程序功能描述 2.测试软件版本以及运行结果展示 3.核心程序 4.本算法原理 4.1 排队系统的组成 4.2 基本概念 4.3 常见的排队模型 5.完整程序 1.程序功能描述 基于排队理论的客户结账等待时间MATLAB模拟仿真,分析平均队长,平均等待时长&…

智慧交通视频AI监控识别解决方案

背景分析 随着社会的进步和科技的不断发展,互联网技术和AI视觉分析技术日益成熟,为传统交通监控领域带来了新的发展机遇。AI视觉分析技术的引入,不仅提升了交通监控的智能化和自动化水平,还显著减轻了交管部门的工作负担&#xf…

雷卯解析AECQ101与AECQ200

AEC(汽车电子委员会)推出了AECQ101和AECQ200这两项行业标准,作为汽车电子元件的“品质通行证”。上海雷卯已率先申请AECQ101证书。 鉴于有些客户不清楚AECQ101和AECQ200的区别,哪些供应商应该提供什么类别证书。本文将带您解析这…

本地知识库开源框架Fastgpt、MaxKB产品体验

本地知识库开源框架Fastgpt、MaxKB产品体验 背景fastgpt简介知识库共享部署 MaxKB总结 背景 上一篇体验了Quivr、QAnything两个开源知识库模型框架,这次介绍两款小众但是体验比较好的产品。 fastgpt 简介 FastGPT 是一个基于 LLM 大语言模型的知识库问答系统&am…

第四范式Q1业务进展:驰而不息 用科技锻造不朽价值

5月28日,第四范式发布今年前三个月的核心业务进展,公司坚持科技创新,业务稳步拓展,用人工智能为千行万业贡献价值。 今年前三个月,公司总收入人民币8.3亿元,同比增长28.5%,毛利润人民币3.4亿元&…

python猜数字游戏

猜数字游戏 计算机随机产生一个1~100的随机数,人输入自己猜的数字, 计算机给出对应的提示“大一点”,”小一点“或”恭喜你猜对了“,直到猜中为止。 如果猜的次数超过7次,计算机温馨提示“智商余额明显不足” import …

SLAM精度评估—evo

evo是一款用于SLAM轨迹精度的评估工具。核心功能是(1)能够绘制(传感器运动)轨迹,(2)评估估计轨迹与真值(ground truth)的误差。evo支持多种数据集的轨迹格式(TUM、KITT、…

用户购物性别模型标签(USG)之决策树模型

一、USG模型引入: 首先了解一下,如何通过大数据来确定用户的真实性别, 经常谈论的用户精细化运营,到底是什么? 简单来讲,就是将网站的每个用户标签化,制作一个属于用户自己的网络身份证。然后,运营人员 通…

D3D 顶点格式学习

之前D3D画三角形的代码中有这一句, device.VertexFormat CustomVertex.TransformedColored.Format; 这是设置顶点格式; 画出的三角形如下, 顶点格式是描述一个三维模型的顶点信息的格式;可以包含以下内容, 位置…

Xcode设置cocoapods库的最低兼容版本

目录 前言 1.使用cocoapods遇到的问题 2.解决办法 1.用法解释 1. config.build_settings: 2.IPHONEOS_DEPLOYMENT_TARGET 2.使用实例 3.注意事项 1.一致性 2.pod版本 前言 这篇文章主要是介绍如何设置cocoapods三方库如何设置最低兼容的版本。 1.使用cocoapods遇到的…

qt学习笔记

qt的对象树 在 Qt中创建对象的时候会提供一个 Parent 对象指针,Q0bject是以对象树的形式组织起来的。 当你创建一个 Q0biect 对象时,会看到 Q0biect 的构造函数接收一个Q0b.ject指针作为参数,这个参数就是 parent,也就是父对象指…

三次样条插值的实现(Matlab)

一、问题描述 三次样条插值的实现。 二、实验目的 掌握三次样条插值方法的原理,能够编写代码获得自然、抛物线端点以及非纽结三次样条。 三、实验内容及要求 找出并画出三次样条S,满足S(0) 1, S(1) 3, S(2) 3, S(3) 4, S(4) 2,其中…

Spring Boot 开发 -- 过滤器与拦截器详解

引言 在Web开发中,经常需要对请求进行预处理或在响应后进行后处理,Spring Boot提供了过滤器和拦截器两种机制来实现这一需求。虽然它们都可以用来处理HTTP请求和响应,但在使用场景、执行顺序和配置方式上存在明显的差异。本文将详细讲解Spri…

LeetCode 2928.给小朋友们分糖果 I:Java提交的运行时间超过了61%的用户

【LetMeFly】2928.给小朋友们分糖果 I:Java提交的运行时间超过了61%的用户 力扣题目链接:https://leetcode.cn/problems/distribute-candies-among-children-i/ 给你两个正整数 n 和 limit 。 请你将 n 颗糖果分给 3 位小朋友,确保没有任何…

易语言贪吃蛇游戏(附带源码)

易语言贪吃蛇游戏 效果图源码说明源码领取下期更新预报 效果图 源码说明 本源码用易语言来编写,供大家研究,保留版权,谢谢! 源码领取 易语言贪吃蛇游戏源码领取地址:https://www.123pan.com/s/ji8kjv-TKPU3.html提取…

Oracle中rman的增量备份使用分享

继上次使用RMAN的全量备份和异机还原以后,开始研究一下增量备份和还原的方法。相比于全量RMAN的备份还原,增量的备份还原就相对简单。本实践教程直接上操作,还是回归到一个问题,就是关于两个数据库创建时候,必须保持or…

泄漏libc基地址

拿libc基地址 方法一:格式化字符串 格式化字符串,首先确定输入的 AAAA 在栈上的位置(x)。使用 elf.got[fun] 获得got地址。利用格式化字符串,构造payload泄漏got地址处的值,recv接受到的字符串中&#xf…

linux部署运维1——centos7.9离线安装部署web或java项目所需的依赖环境,包括mysql8.0,nginx1.20,redis5.0等工具

在实际项目部署运维过程中,如果是云服务器,基本安装项目所需的依赖环境都是通过yum联网拉取网络资源实现自动化安装的;但是对于一些特殊场合,在没有外部网络的情况下,就无法使用yum命令联网操作,只能通过编…

网络报文协议头学习

vxlan:就是通过Vxlan_header头在原始报文前面套了一层UDPIP(4/6)Eth_hdr 需求背景:VXLAN:简述VXLAN的概念,网络模型及报文格式_vxlan报文格式-CSDN博客 如果服务器作为VTEP,那从服务器发送到接…

jmeter之MD5加密请求秒杀接口教程

前言: 有时候在项目中,需要使用MD5加密的方法才可以登录,或者在某一个接口中遇到 登录获取token后才可以进行关联,下面介绍下遇到的常见使用 一、第一种方法:使用jmeter自带的函数助手digest 选择工具,选…