PyTorch使用(4)-张量拼接操作

文章目录

  • 张量拼接操作
  • 1. torch.cat 函数的使用
      • 1.1. torch.cat 定义
      • 1.2. 语法
      • 1.3. 关键规则
    • 1.4. 示例代码
      • 1.4.1. 沿行拼接(dim=0)
      • 1.4.2. 沿列拼接(dim=1)
      • 1.4.3. 高维拼接(dim=2)
    • 1.5. 错误场景分析
      • 1.5.1. 维度数不一致
      • 1.5.2. 非拼接维度大小不匹配
      • 1.5.3. 设备或数据类型不一致
      • 1.6. 与 torch.stack 的区别
    • 1.7. 高级用法
      • 1.7.1. 批量拼接(Batch-wise Concatenation)
      • 1.7.2. 自动广播支持
    • 1.8. 总结
  • 2. torch.stack 函数的使用
    • 2.1. 函数定义
    • 2.2. 核心规则
    • 2.3. 使用示例
    • 2.4. 与 torch.cat 的对比
    • 2.4. 常见错误与调试
    • 2.5. 工程实践技巧
    • 2.7. 性能优化建议
    • 2.8. 总结

张量拼接操作

1. torch.cat 函数的使用

在 PyTorch 中,torch.cat 是用于沿指定维度拼接多个张量的核心函数

1.1. torch.cat 定义

功能: 将多个张量沿指定维度(dim)拼接,生成新张量。

输入要求:

所有输入张量的 维度数必须相同。

非拼接维度的大小必须一致。

张量必须位于 同一设备 且 数据类型相同。

1.2. 语法

torch.cat(tensors, dim=0, *, out=None) → Tensor

参数:

tensors (sequence of Tensors):需拼接的张量序列(列表或元组)。

dim (int, optional):拼接的维度索引,默认为 0。

out (Tensor, optional):可选输出张量。

1.3. 关键规则

规则示例
输入张量维度数必须相同不允许将 2D 张量与 3D 张量拼接
非拼接维度大小必须一致若 dim=1,所有张量的 dim=0、dim=2 等大小必须相同
拼接维度大小可以不同沿 dim=0 拼接形状为 (2, 3) 和 (3, 3) 的张量,结果为 (5, 3)
输出维度数与输入相同输入均为 3D 张量,输出仍为 3D 张量

1.4. 示例代码

1.4.1. 沿行拼接(dim=0)

import torchA = torch.tensor([[1, 2], [3, 4]])    # shape: (2, 2)
B = torch.tensor([[5, 6], [7, 8]])    # shape: (2, 2)
C = torch.cat([A, B], dim=0)          # shape: (4, 2)
print(C)
# 输出:
# tensor([[1, 2],
#         [3, 4],
#         [5, 6],
#         [7, 8]])

1.4.2. 沿列拼接(dim=1)

D = torch.tensor([[9], [10]])          # shape: (2, 1)
E = torch.cat([A, D], dim=1)          # shape: (2, 3)
print(E)
# 输出:
# tensor([[ 1,  2,  9],
#         [ 3,  4, 10]])

1.4.3. 高维拼接(dim=2)

F = torch.randn(2, 3, 4)              # shape: (2, 3, 4)
G = torch.randn(2, 3, 5)              # shape: (2, 3, 5)
H = torch.cat([F, G], dim=2)          # shape: (2, 3, 9)

1.5. 错误场景分析

1.5.1. 维度数不一致

A_2D = torch.randn(2, 3)
B_3D = torch.randn(2, 3, 4)
try:torch.cat([A_2D, B_3D], dim=0)  # 报错:维度数不同
except RuntimeError as e:print("错误:", e)

1.5.2. 非拼接维度大小不匹配

A = torch.randn(2, 3)
B = torch.randn(3, 3)              # dim=0 大小不同
try:torch.cat([A, B], dim=1)       # 报错:非拼接维度大小不一致
except RuntimeError as e:print("错误:", e)

1.5.3. 设备或数据类型不一致

if torch.cuda.is_available():A_cpu = torch.randn(2, 3)B_gpu = torch.randn(2, 3).cuda()try:torch.cat([A_cpu, B_gpu], dim=0)  # 报错:设备不一致except RuntimeError as e:print("错误:", e)

1.6. 与 torch.stack 的区别

函数输入维度输出维度核心用途
torch.cat所有张量维度相同维度数与输入相同沿现有维度扩展张量
torch.stack所有张量形状严格相同新增一个维度创建新维度合并张量

示例对比:

A = torch.tensor([1, 2])          # shape: (2)
B = torch.tensor([3, 4])          # shape: (2)# cat 沿 dim=0
C_cat = torch.cat([A, B])         # shape: (4)# stack 沿 dim=0
C_stack = torch.stack([A, B])     # shape: (2, 2)

1.7. 高级用法

1.7.1. 批量拼接(Batch-wise Concatenation)

# 批量数据拼接(batch_size=2)
batch_A = torch.randn(2, 3, 4)    # shape: (2, 3, 4)
batch_B = torch.randn(2, 3, 5)    # shape: (2, 3, 5)
batch_C = torch.cat([batch_A, batch_B], dim=2)  # shape: (2, 3, 9)

1.7.2. 自动广播支持

torch.cat 不支持广播,必须显式匹配形状:

A = torch.randn(3, 1)            # shape: (3, 1)
B = torch.randn(1, 3)            # shape: (1, 3)
try:torch.cat([A, B], dim=1)     # 报错:非拼接维度大小不一致
except RuntimeError as e:print("错误:", e)

1.8. 总结

适用场景:合并同维度的特征、批量数据拼接等。

核心规则

1、输入张量维度数相同。2、非拼接维度大小严格一致。3、设备与数据类型一致。

优先使用 torch.cat:当需要在现有维度扩展时;需新增维度时选择 torch.stack。

2. torch.stack 函数的使用

2.1. 函数定义

torch.stack(tensors, dim=0, *, out=None) → Tensor

功能:将多个张量沿新维度堆叠(非拼接),要求所有输入张量形状严格相同。

  • 输入:
    • tensors (sequence of Tensors):形状相同的张量序列(列表/元组)。
    • dim (int):新维度的插入位置(支持负数索引)。
  • 输出:
    • 比输入张量多一维的新张量。

2.2. 核心规则

规则示例
输入张量形状必须完全相同(3, 4) 只能与 (3, 4) 堆叠,不能与 (3, 5) 堆叠
输出维度 = 输入维度 + 1输入(3, 4) → 输出 (n, 3, 4)(n为堆叠数量)
新维度大小 = 张量数量堆叠3个张量 → 新维度大小为3
设备/数据类型必须一致所有张量需在同一设备(CPU/GPU)且 dtype 相同

2.3. 使用示例

(1) 基础用法

import torch
# 定义两个相同形状的张量
A = torch.tensor([1, 2, 3])      # shape: (3,)
B = torch.tensor([4, 5, 6])      # shape: (3,)# 沿新维度0堆叠
C = torch.stack([A, B])          # shape: (2, 3)
print(C)
# tensor([[1, 2, 3],
#         [4, 5, 6]])# 沿新维度1堆叠
D = torch.stack([A, B], dim=1)   # shape: (3, 2)
print(D)
# tensor([[1, 4],
#         [2, 5],
#         [3, 6]])

(2) 高维张量堆叠

# 形状为 (2, 3) 的张量
X = torch.randn(2, 3)
Y = torch.randn(2, 3)# 沿dim=0堆叠(新增最外层维度)
Z0 = torch.stack([X, Y])         # shape: (2, 2, 3)# 沿dim=1堆叠(插入到第二维)
Z1 = torch.stack([X, Y], dim=1)  # shape: (2, 2, 3)# 沿dim=-1堆叠(插入到最后一维)
Z2 = torch.stack([X, Y], dim=-1) # shape: (2, 3, 2)

(3) 批量数据构建

# 模拟批量图像数据(单张图像shape: (3, 32, 32))
image1 = torch.randn(3, 32, 32)
image2 = torch.randn(3, 32, 32)
image3 = torch.randn(3, 32, 32)# 构建batch维度(batch_size=3)
batch = torch.stack([image1, image2, image3])  # shape: (3, 3, 32, 32)

2.4. 与 torch.cat 的对比

特性 torch.stack torch.cat
输入要求 所有张量形状严格相同 仅需非拼接维度相同
输出维度 比输入多1维 与输入维度相同
内存开销 更高(新增维度) 更低(复用现有维度)
典型场景 构建batch、新增序列维度 合并特征、扩展现有维度
示例对比:

A = torch.tensor([1, 2])
B = torch.tensor([3, 4])# stack -> 新增维度
stacked = torch.stack([A, B])    # shape: (2, 2)# cat -> 沿现有维度扩展
concatenated = torch.cat([A, B]) # shape: (4)

2.4. 常见错误与调试

(1) 形状不匹配

A = torch.randn(2, 3)
B = torch.randn(2, 4)  # 第二维不同
try:torch.stack([A, B])
except RuntimeError as e:print("Error:", e)  # Sizes of tensors must match

(2) 设备不一致

A_cpu = torch.randn(3, 4)
B_gpu = torch.randn(3, 4).cuda()
try:torch.stack([A_cpu, B_gpu])
except RuntimeError as e:print("Error:", e)  # Expected all tensors to be on the same device

(3) 空张量处理

empty_tensors = [torch.tensor([]) for _ in range(3)]
try:torch.stack(empty_tensors)  # 可能引发未定义行为
except RuntimeError as e:print("Error:", e)

2.5. 工程实践技巧

(1) 批量数据预处理

# 从数据加载器中逐批读取数据并堆叠
batch_images = []
for image in dataloader:batch_images.append(image)if len(batch_images) == batch_size:batch = torch.stack(batch_images)  # shape: (batch_size, C, H, W)process_batch(batch)batch_images = []

(2) 序列建模中的时间步堆叠

# RNN输入序列构建(T个时间步,每个步长特征dim=D)
time_steps = [torch.randn(1, D) for _ in range(T)]
input_seq = torch.stack(time_steps, dim=1)  # shape: (1, T, D)

(3) 多任务输出合并

# 多任务学习中的输出堆叠
task1_out = torch.randn(batch_size, 10)
task2_out = torch.randn(batch_size, 5)
multi_out = torch.stack([task1_out, task2_out], dim=1)  # shape: (batch_size, 2, ...)

2.7. 性能优化建议

避免循环中频繁堆叠:优先在内存中收集所有张量后一次性堆叠。

# 低效做法
result = None
for x in data_stream:if result is None:result = x.unsqueeze(0)else:result = torch.stack([result, x.unsqueeze(0)])# 高效做法
tensor_list = [x for x in data_stream]
result = torch.stack(tensor_list)

显存不足时考虑分块处理:

chunk_size = 1000
for i in range(0, len(big_list), chunk_size):chunk = torch.stack(big_list[i:i+chunk_size])process(chunk)

2.8. 总结

核心用途:构建batch、新增维度、多任务输出整合。

关键检查点:

  • 输入张量形状完全一致。
  • 设备与数据类型统一。
  • 合理选择 dim 参数控制维度扩展位置。

优先选择场景:当需要显式创建新维度时使用;若仅需扩展现有维度,用 torch.cat 更高效。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/77060.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

linux命令之yes(Linux Command Yes)

linux命令之yes 简介与功能 yes 命令在 Linux 系统中用于重复输出一行字符串,直到被杀死(kill)。该命令最常见的用途是自动化控制脚本中的交互式命令,以便无需用户介入即可进行连续的确认操作。 用法示例 基本用法非常简单&am…

《算法笔记》10.3小节——图算法专题->图的遍历 问题 B: 连通图

题目描述 给定一个无向图和其中的所有边&#xff0c;判断这个图是否所有顶点都是连通的。 输入 每组数据的第一行是两个整数 n 和 m&#xff08;0<n<1000&#xff09;。n 表示图的顶点数目&#xff0c;m 表示图中边的数目。如果 n 为 0 表示输入结束。随后有 m 行数据…

使用Prometheus监控systemd服务并可视化

实训背景 你是一家企业的运维工程师&#xff0c;需将服务器的systemd服务监控集成到Prometheus&#xff0c;并通过Grafana展示实时数据。需求如下&#xff1a; 数据采集&#xff1a;监控所有systemd服务的状态&#xff08;运行/停止&#xff09;、资源占用&#xff08;CPU、内…

OpenCV--图像边缘检测

在计算机视觉和图像处理领域&#xff0c;边缘检测是极为关键的技术。边缘作为图像中像素值发生急剧变化的区域&#xff0c;承载了图像的重要结构信息&#xff0c;在物体识别、图像分割、目标跟踪等众多应用场景中发挥着核心作用。OpenCV 作为强大的计算机视觉库&#xff0c;提供…

Rollup详解

Rollup 是一个 JavaScript 模块打包工具&#xff0c;专注于 ES 模块的打包&#xff0c;常用于打包 JavaScript 库。下面从它的工作原理、特点、使用场景、配置和与其他打包工具对比等方面进行详细讲解。 一、 工作原理 Rollup 的核心工作是分析代码中的 import 和 export 语句…

Chapter 7: Compiling C++ Sources with CMake_《Modern CMake for C++》_Notes

Chapter 7: Compiling C Sources with CMake 1. Understanding the Compilation Process Key Points: Four-stage process: Preprocessing → Compilation → Assembly → LinkingCMake abstracts low-level commands but allows granular controlToolchain configuration (c…

5分钟上手GitHub Copilot:AI编程助手实战指南

引言 近年来&#xff0c;AI编程工具逐渐成为开发者提升效率的利器。GitHub Copilot作为由GitHub和OpenAI联合推出的智能代码补全工具&#xff0c;能够根据上下文自动生成代码片段。本文将手把手教你如何快速安装、配置Copilot&#xff0c;并通过实际案例展示其强大功能。 一、…

谢志辉和他的《韵之队诗集》:探寻生活与梦想交织的诗意世界

大家好&#xff0c;我是谢志辉&#xff0c;一个扎根在文字世界&#xff0c;默默耕耘的写作者。写作于我而言&#xff0c;早已不是简单的爱好&#xff0c;而是生命中不可或缺的一部分。无数个寂静的夜晚&#xff0c;当世界陷入沉睡&#xff0c;我独自坐在书桌前&#xff0c;伴着…

Logo语言的死锁

Logo语言的死锁现象研究 引言 在计算机科学中&#xff0c;死锁是一个重要的研究课题&#xff0c;尤其是在并发编程中。它指的是两个或多个进程因争夺资源而造成的一种永久等待状态。在编程语言的设计与实现中&#xff0c;如何避免死锁成为了优化系统性能和提高程序可靠性的关…

深入理解矩阵乘积的导数:以线性回归损失函数为例

深入理解矩阵乘积的导数&#xff1a;以线性回归损失函数为例 在机器学习和数据分析领域&#xff0c;矩阵微积分扮演着至关重要的角色。特别是当我们涉及到优化问题&#xff0c;如最小化损失函数时&#xff0c;对矩阵表达式求导变得必不可少。本文将通过一个具体的例子——线性…

real_time_camera_audio_display_with_animation

视频录制 import cv2 import pyaudio import wave import threading import os import tkinter as tk from PIL import Image, ImageTk # 视频录制设置 VIDEO_WIDTH = 640 VIDEO_HEIGHT = 480 FPS = 20.0 VIDEO_FILENAME = _video.mp4 AUDIO_FILENAME = _audio.wav OUTPUT_…

【Pandas】pandas DataFrame astype

Pandas2.2 DataFrame Conversion 方法描述DataFrame.astype(dtype[, copy, errors])用于将 DataFrame 中的数据转换为指定的数据类型 pandas.DataFrame.astype pandas.DataFrame.astype 是一个方法&#xff0c;用于将 DataFrame 中的数据转换为指定的数据类型。这个方法非常…

Johnson

理论 全源最短路算法 Floyd 算法&#xff0c;时间复杂度为 O(n)跑 n 次 Bellman - Ford 算法&#xff0c;时间复杂度是 O(nm)跑 n 次 Heap - Dijkstra 算法&#xff0c;时间复杂度是 O(nmlogm) 第 3 种算法被 Johnson 做了改造&#xff0c;可以求解带负权边的全源最短路。 J…

Exce格式化批处理工具详解:高效处理,让数据更干净!

Exce格式化批处理工具详解&#xff1a;高效处理&#xff0c;让数据更干净&#xff01; 1. 概述 在数据分析、报表整理、数据库管理等工作中&#xff0c;数据清洗是不可或缺的一步。原始Excel数据常常存在格式不统一、空值、重复数据等问题&#xff0c;影响数据的准确性和可用…

(三十七)Dart 中使用 Pub 包管理系统与 HTTP 请求教程

Dart 中使用 Pub 包管理系统与 HTTP 请求教程 Pub 包管理系统简介 Pub 是 Dart 和 Flutter 的包管理系统&#xff0c;用于管理项目的依赖。通过 Pub&#xff0c;开发者可以轻松地添加、更新和管理第三方库。 使用 Pub 包管理系统 1. 找到需要的库 访问以下网址&#xff0c…

代码随想录算法训练营第三十五天 | 416.分割等和子集

416. 分割等和子集 题目链接&#xff1a;416. 分割等和子集 - 力扣&#xff08;LeetCode&#xff09; 文章讲解&#xff1a;代码随想录 视频讲解&#xff1a;动态规划之背包问题&#xff0c;这个包能装满吗&#xff1f;| LeetCode&#xff1a;416.分割等和子集_哔哩哔哩_bilibi…

HTTP 教程 : 从 0 到 1 全面指南 教程【全文三万字保姆级详细讲解】

目录 HTTP 的请求-响应 HTTP 方法 HTTP 状态码 HTTP 版本 安全性 HTTP/HTTPS 简介 HTTP HTTPS HTTP 工作原理 HTTPS 作用 HTTP 与 HTTPS 区别 HTTP 消息结构 客户端请求消息 服务器响应消息 实例 HTTP 请求方法 各个版本定义的请求方法 HTTP/1.0 HTTP/1.1 …

spring功能汇总

1.创建一个dao接口&#xff0c;实现类&#xff1b;service接口&#xff0c;实现类并且service里用new创建对象方式调用dao的方法 2.使用spring分别获取dao和service对象(IOC) 注意 2中的service里面获取dao的对象方式不用new的(DI) 运行测试&#xff1a; 使用1的方式创建servic…

Vue.js 实现下载模板和导入模板、数据比对功能核心实现。

在前端开发中&#xff0c;数据比对是一个常见需求&#xff0c;尤其在资产管理等场景中。本文将基于 Vue.js 和 Element UI&#xff0c;通过一个简化的代码示例&#xff0c;展示如何实现“新建比对”和“开始比对”功能的核心部分。 一、功能简介 我们将聚焦两个核心功能&…

volatile关键字用途说明

volatile 关键字在 C# 中用于指示编译器和运行时系统&#xff0c;某个字段可能会被多个线程同时访问&#xff0c;并且该字段的读写操作不应被优化&#xff08;例如缓存到寄存器或重排序&#xff09;&#xff0c;以确保所有线程都能看到最新的值。这使得 volatile 成为一种轻量级…