pytorch小记(十七):PyTorch 中的 `expand` 与 `repeat`:详解广播机制与复制行为(附详细示例)

pytorch小记(十七):PyTorch 中的 `expand` 与 `repeat`:详解广播机制与复制行为(附详细示例)

    • 🚀 PyTorch 中的 `expand` 与 `repeat`:详解广播机制与复制行为(附详细示例)
      • 🔍 一、基础定义
        • 1. `tensor.expand(*sizes)`
        • 2. `tensor.repeat(*sizes)`
      • 📌 二、维度行为详解
        • 使用 `expand`
        • 使用 `repeat`
      • ⚠️ 三、重点报错案例解释
        • 📌 示例 1:`expand(1, 4)` 报错
        • ✅ 示例 2:`expand(2, 4)` 正确
      • 🔁 四、repeat 的多种使用场景举例
      • 🔍 五、输入维度对 `expand` 和 `repeat` 的影响总结
      • 🎯 六、常见错误总结
      • ✅ 七、维度补齐技巧
      • 🎓 八、结语:如何选择?
    • 问题
      • 1. PyTorch 自动**广播一维 tensor**
      • 2. 和二维 `[1, 2, 3]` 效果一样?
      • 🔎 为什么以前会报错?
    • 📌 总结规律(适用于新版本 PyTorch)


🚀 PyTorch 中的 expandrepeat:详解广播机制与复制行为(附详细示例)

在使用 PyTorch 构建神经网络时,经常会遇到不同维度张量需要对齐的问题,expand()repeat() 就是两种非常常用的方式来处理张量的形状变化。本博客将详细解释两者的区别、作用、使用规则以及典型的报错原因,配合实际例子,帮助你深入理解广播机制。


🔍 一、基础定义

1. tensor.expand(*sizes)
  • 功能:沿指定维度进行“虚拟复制”,不占用额外内存
  • 要求:只能扩展 原始维度中为1的维度,否则会报错。
2. tensor.repeat(*sizes)
  • 功能真正复制数据,生成新的内存区域。
  • 不限制是否为1的维度,任意维度都能复制。

📌 二、维度行为详解

以一个张量为例:

a = torch.tensor([[1], [2]])  # shape: (2, 1)
使用 expand
print(a.expand(2, 3))

结果:

tensor([[1, 1, 1],[2, 2, 2]])
  • 第1维为 1,可以扩展成3列。
  • 数据并没有真实复制,只是通过 广播机制 显示为多列。
使用 repeat
print(a.repeat(1, 3))

结果:

tensor([[1, 1, 1],[2, 2, 2]])
  • 每一行的元素真实地复制了3份,占用了新内存。

⚠️ 三、重点报错案例解释

📌 示例 1:expand(1, 4) 报错
c = torch.tensor([[7], [8]])  # shape: (2, 1)
print(c.expand(1, 4))

错误原因

RuntimeError: The expanded size of the tensor (1) must match the existing size (2) at non-singleton dimension 0.

解释:

  • 原 tensor 的第0维是2,而你想扩展为1。
  • 非1的维度不能进行expand扩展,会触发报错。

✅ 示例 2:expand(2, 4) 正确
c = torch.tensor([[7], [8]])  # shape: (2, 1)
print(c.expand(2, 4))

输出:

tensor([[7, 7, 7, 7],[8, 8, 8, 8]])
  • 第0维是2,不变 ✅
  • 第1维是1,被扩展为4 ✅

🔁 四、repeat 的多种使用场景举例

a = torch.tensor([[1, 2, 3]])  # shape: (1, 3)
print(a.repeat(2, 3))

输出:

tensor([[1, 2, 3, 1, 2, 3],[1, 2, 3, 1, 2, 3]])

解释:

  • (2, 3) 的含义是:行重复2次,列重复3次。
  • 数据真实复制!

🔍 五、输入维度对 expandrepeat 的影响总结

操作输入维度形状输入参数说明
expand必须是显式维度尺寸必须与原tensor维度数一致,且非1的维度不能变
repeat任意形状每个维度对应复制几次
自动广播可扩展1维为任意数目expand底层用到
内存行为不复制数据expand 是 zero-copy
内存行为真正复制repeat 用得多就要小心内存

🎯 六、常见错误总结

错误场景示例错误原因
expand 维度不对tensor(2, 1).expand(1, 4)非1维度不能扩展
expand 维数不匹配tensor(2, 1).expand(4)参数数目与维度数不一致
repeat 维度数对不上tensor(2, 1).repeat(3)参数不够,需要补齐

✅ 七、维度补齐技巧

有时原始张量的维度太少,需要先 .unsqueeze() 添加维度:

x = torch.tensor([1, 2, 3])   # shape: (3,)
x = x.unsqueeze(0)            # shape: (1, 3)
x = x.expand(2, 3)

🎓 八、结语:如何选择?

  • 如果你只是想“假装复制”以减少内存开销 ➜ expand()
  • 如果你真的需要重复数据去喂模型 ➜ repeat()
  • 如果你想安全无脑复制 ➜ repeat() 更通用但代价大
  • 如果你要配合 broadcasting ➜ expand() 是你的最优选择

问题

a = torch.tensor([[1, 2, 3]])  # shape: (1, 3)
print(a.shape)
print(a.repeat(6, 4))a = torch.tensor([1, 2, 3])  # shape: (1, 3)
print(a.shape)
print(a.repeat(6, 4))

为什么维度不同但是输出是一样的?

1. PyTorch 自动广播一维 tensor

在新版 PyTorch 中(大约 1.8 起),当你对 一维张量 调用 .repeat(m, n),PyTorch 会自动地把它当作 shape 为 (1, 3),然后再执行 repeat。这相当于隐式地:

a = torch.tensor([1, 2, 3])    # shape: (3,)
a = a.unsqueeze(0)             # shape: (1, 3)
print(a.repeat(6, 4))          # 🔁 repeat(6, 4) 等价于 (6 rows, 12 columns)

2. 和二维 [1, 2, 3] 效果一样?

是的。你对比的两个 tensor:

a1 = torch.tensor([[1, 2, 3]])  # shape: (1, 3)
a2 = torch.tensor([1, 2, 3])    # shape: (3,)
print(a1.repeat(6, 4))
print(a2.repeat(6, 4))  # 现在两者结果完全一致!

输出都是 shape: (6, 12),值为重复的 [1, 2, 3]

tensor([[1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3],...[1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3]])

🔎 为什么以前会报错?

在早期版本的 PyTorch 中(<1.8),repeat(6, 4) 要求参数个数和维度完全一致。所以对 a = torch.tensor([1,2,3])(一维)来说,你只能:

a.repeat(6)  # 正确,对一维张量
a.repeat(6, 4)  # 错误(旧版本)

📌 总结规律(适用于新版本 PyTorch)

原始 tensorrepeat 维度自动行为结果
[1,2,3] (1维)repeat(6,4)自动 unsqueeze → (1,3)
[[1,2,3]](2维)repeat(6,4)直接 repeat
[1,2,3](1维)repeat(6)沿第0维重复
[[1,2,3]](2维)repeat(6)报错,维度不匹配

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/75152.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Databricks: Why did your cluster disappear?

You may found that you created a cluster many days ago, and you didnt delete it, but it is disapear. Why did this happen? Who deleted the cluster? Actually, 30 days after a compute is terminated, it is permanently deleted automaticlly. If your workspac…

C语言【输出字符串中的大写字母】

题目 输出字符串中的大写字母 思路&#xff08;注意事项&#xff09; 纯代码 #include<stdio.h> #include<string.h>int main(){char str[20], ans[20];fgets(str, sizeof(str), stdin);str[strcspn(str, "\n")] \0;for (int i 0, j 0; i < strl…

基于队列构建优先级抢占机制的LED灯框架设计与实现

文章目录 前言一、LED 显示框架概述1. 框架结构图2. 基本机制 二、核心结构与接口设计1. 状态命令结构2. 状态项结构3. LED框架配置结构4. LED运行控制器 三、LED框架逻辑流程1. 初始化逻辑2. 优先级抢占判断与处理逻辑3. 执行队列命令并处理tick4. 队列为空时的默认状态回滚 四…

PyQt6实例_A股财报数据维护工具_解说并数据与完整代码分享

目录 1 20250403之前的财报数据 2 整个项目代码 3 工具使用方法 3.1 通过akshare下载 3.2 增量更新 3.3 查看当前数据情况 3.4 从数据库中下载数据 视频 1 20250403之前的财报数据 通过网盘分享的文件&#xff1a;财报三表数据20250403之前.7z 链接: https://pan.ba…

React 之 Redux 第三十一节 useDispatch() 和 useSelector()使用以及详细案例

使用 Redux 实现购物车案例 由于 redux 5.0 已经将 createStore 废弃&#xff0c;我们需要先将 reduxjs/toolkit 安装一下&#xff1b; yarn add reduxjs/toolkit// 或者 npm install reduxjs/toolkit使用 vite 创建 React 项目时候 配置路径别名 &#xff1a; // 第一种写法…

Spring Boot 中集成 Knife4j:解决文件上传不显示文件域的问题

Spring Boot 中集成 Knife4j&#xff1a;解决文件上传不显示文件域的问题 在使用 Knife4j 为 Spring Boot 项目生成 API 文档时&#xff0c;开发者可能会遇到文件上传功能不显示文件域的问题。本文将详细介绍如何解决这一问题&#xff0c;并提供完整的解决方案。 Knife4j官网…

OpenCV 图形API(17)计算输入矩阵 src 中每个元素的平方根函数sqrt()

操作系统&#xff1a;ubuntu22.04 OpenCV版本&#xff1a;OpenCV4.9 IDE:Visual Studio Code 编程语言&#xff1a;C11 描述 计算数组元素的平方根。 cv::gapi::sqrt 函数计算每个输入数组元素的平方根。对于多通道数组&#xff0c;每个通道会独立处理。其精度大约与内置的 …

大学论文书写规范与格式说明

大学论文书写规范与格式说明 (适用于人文社科、理工科通用框架) 一、论文整体结构 1. 基本组成部分 封面 包含论文标题、作者姓名、学院/专业、学号、指导教师、提交日期等(按学校模板填写)。 中英文摘要 中文摘要:300~500字,概述研究背景、方法、结论与创新点,末尾附…

C# 串口通信

1. 导入 using System.IO.Ports;2. 初始化定义 SerialPort sp new SerialPort(); // 设置串口 sp.PortName "COM3"; // 串口 sp.BaudRate 9600; // 波特率 sp.Parity Parity.None; // 校验位 sp.DataBits 8; // 数据位 sp.StopBits StopBits.One; // 停…

android14 keycode 上报 0 解决办法

驱动改完后发现上报了keycode=0 04-07 13:02:33.201 2323 2662 D WindowManager: interceptKeyTq keycode=0 interactive=false keyguardActive=true policyFlags=2000000 04-07 13:02:33.458 2323 2662 D WindowManager: interceptKeyTq keycode=0 interactive=false key…

C++day9

思维导图 牛客练习 练习&#xff1a; 将我们写的 myList 迭代器里面 operator[] 和 operator 配合异常再写一遍 #include <iostream> #include <cstring> #include <cstdlib> #include <unistd.h> #include <sstream> #include <vector>…

批量合并多张 jpg/png 图片为长图或者 PDF 文件,支持按文件夹合并图片

我们经常会碰到需要将多张图片拼成一张图片的场景&#xff0c;比如将多张图片拼成九宫格图片&#xff0c;或者将多张图片拼成一张长图。还有可能会碰到需要将多张图片合并成一个完整的 PDF 文件来方便我们进行打印或者传输等操作。那这些将图片合并成一张图片或者一个完整的文档…

程序化广告行业(73/89):买卖双方需求痛点及应对策略深度剖析

程序化广告行业&#xff08;73/89&#xff09;&#xff1a;买卖双方需求痛点及应对策略深度剖析 大家好&#xff01;一直以来&#xff0c;我都热衷于在技术领域探索学习&#xff0c;也深知知识的分享能让我们共同进步。写这篇博客的目的&#xff0c;就是希望能和大家一起深入了…

[随笔] nn.Embedding的前向传播与反向传播

nn.Embedding的前向传播与反向传播 nn.Embedding的前向计算过程 embedding module 的前向过程其实是一个索引&#xff08;查表&#xff09;的过程 表的形式是一个 matrix&#xff08;embedding.weight, learnable parameters&#xff09; matrix.shape: (v, h) v&#xff1a;…

构建实时、融合的湖仓一体数据分析平台:基于 Delta Lake 与 Apache Iceberg

1. 执行摘要 挑战&#xff1a; 传统数据仓库在处理现代数据需求时面临诸多限制&#xff0c;包括高昂的存储和计算成本、处理海量多样化数据的能力不足、以及数据从产生到可供分析的端到端延迟过高。同时&#xff0c;虽然数据湖提供了低成本、灵活的存储&#xff0c;但往往缺乏…

Maven error:Could not transfer artifact

问题描述 当项目从私有仓库下载依赖时&#xff0c;Maven 报错&#xff0c;无法从远程仓库下载指定的依赖包&#xff0c;错误信息如下&#xff1a; Could not transfer artifact com.ding.abcd:zabk-java:pom from/to releases (http://192.1122.101/repory/mavenleases/): 此…

Dify 生成提示词的 Prompt

Dify 生成提示词的 Prompt **第1次提示词****第2次提示词****第3次提示词**总结 Dify 生成提示词是&#xff0c;会和LLM进行3次交互&#xff0c;下面是和LLM进行交互是的Prompt。 以下是每次提示词的概要、目标总结以及原始Prompt&#xff1a; 第1次提示词 概要&#xff1a; …

sqli-labs靶场 less4

文章目录 sqli-labs靶场less 4 联合注入 sqli-labs靶场 每道题都从以下模板讲解&#xff0c;并且每个步骤都有图片&#xff0c;清晰明了&#xff0c;便于复盘。 sql注入的基本步骤 注入点注入类型 字符型&#xff1a;判断闭合方式 &#xff08;‘、"、’、“”&#xf…

【什么是动态链接?这里的动态是什么意思?链接了什么?】

动态链接&#xff08;Dynamic Linking&#xff09;详解 1. 什么是动态链接&#xff1f; 动态链接是 Java 虚拟机&#xff08;JVM&#xff09;在运行时将字节码中的符号引用&#xff08;Symbolic Reference&#xff09;转换为直接引用&#xff08;Direct Reference&#xff09;…

AWS S3深度剖析:云存储的瑞士军刀

1. 引言 在当今数据驱动的世界中,高效、可靠、安全的数据存储解决方案至关重要。Amazon Simple Storage Service (S3)作为AWS生态系统中的核心服务之一,为企业和开发者提供了一个强大而灵活的对象存储平台。本文将全面解析S3的核心特性,帮助读者深入理解如何充分利用这一&q…