PyTorch中的Flatten

在 PyTorch 中,Flatten 操作是将多维张量转换为一维向量的重要操作,常用于卷积神经网络(CNN)的全连接层之前。以下是 PyTorch 中实现 Flatten 的各种方法及其应用场景。

一、基本 Flatten 方法

1. 使用 torch.flatten() 函数

import torch# 创建一个4D张量 (batch_size, channels, height, width)
x = torch.randn(32, 3, 28, 28)  # 32张28x28的RGB图像# 展平整个张量
flattened = torch.flatten(x)  # 输出形状: [75264] (32*3*28*28)# 从指定维度开始展平
flattened = torch.flatten(x, start_dim=1)  # 输出形状: [32, 2352] (保持batch维度)

2. 使用 nn.Flatten 层

import torch.nn as nnflatten = nn.Flatten()  # 默认从第1维开始展平(保持batch维度)
x = torch.randn(32, 3, 28, 28)
output = flatten(x)  # 输出形状: [32, 2352]

 可以指定开始和结束维度:

flatten = nn.Flatten(start_dim=1, end_dim=2)
x = torch.randn(32, 3, 28, 28)
output = flatten(x)  # 输出形状: [32, 84, 28] (合并了第1和2维)

二、不同场景下的 Flatten 应用

1. CNN 中的典型用法

class CNN(nn.Module):def __init__(self):super().__init__()self.conv_layers = nn.Sequential(nn.Conv2d(1, 16, 3),nn.ReLU(),nn.MaxPool2d(2),nn.Conv2d(16, 32, 3),nn.ReLU(),nn.MaxPool2d(2))self.flatten = nn.Flatten()self.fc = nn.Linear(32 * 5 * 5, 10)  # 计算展平后的尺寸def forward(self, x):x = self.conv_layers(x)x = self.flatten(x)  # 形状从 [B, 32, 5, 5] 变为 [B, 800]x = self.fc(x)return x

 2. 手动计算展平后的尺寸

# 计算卷积层输出尺寸的辅助函数
def conv_output_size(input_size, kernel_size, stride=1, padding=0):return (input_size - kernel_size + 2 * padding) // stride + 1# 计算经过多层卷积和池化后的尺寸
h, w = 28, 28  # 输入尺寸
h = conv_output_size(h, 3)  # conv1: 26
w = conv_output_size(w, 3)  # conv1: 26
h = conv_output_size(h, 2, 2)  # pool1: 13
w = conv_output_size(w, 2, 2)  # pool1: 13
h = conv_output_size(h, 3)  # conv2: 11
w = conv_output_size(w, 3)  # conv2: 11
h = conv_output_size(h, 2, 2)  # pool2: 5
w = conv_output_size(w, 2, 2)  # pool2: 5
print(f"展平后的特征数: {32 * h * w}")  # 32 * 5 * 5 = 800

三、高级用法

1. 部分展平

# 只展平图像空间维度,保留通道维度
x = torch.randn(32, 3, 28, 28)
flattened = x.flatten(start_dim=2)  # 形状: [32, 3, 784]

 2. 自定义 Flatten 层

class ChannelLastFlatten(nn.Module):"""将通道维度移到最后的展平层"""def forward(self, x):# 输入形状: [B, C, H, W]x = x.permute(0, 2, 3, 1)  # [B, H, W, C]return x.reshape(x.size(0), -1)  # [B, H*W*C]

3. 展平特定维度

# 展平批量维度和通道维度
x = torch.randn(32, 3, 28, 28)
flattened = x.flatten(end_dim=1)  # 形状: [96, 28, 28] (32*3=96)

四、注意事项

  1. 维度计算:确保展平后的尺寸与全连接层的输入尺寸匹配

  2. 批量维度:通常保留第0维(batch维度)不被展平

  3. 内存连续性view()需要连续内存,必要时先调用contiguous()

  4. 替代方法x.view(x.size(0), -1)flatten(start_dim=1)的常见替代写法

五、性能比较

方法优点缺点
torch.flatten()官方推荐,可读性好
nn.Flatten()可作为网络层使用需要实例化对象
x.view()最简洁需要手动计算尺寸
x.reshape()自动处理内存连续性性能略低于view

六、示例代码

import torch
import torch.nn as nn# 定义一个包含Flatten的完整模型
class ImageClassifier(nn.Module):def __init__(self):super().__init__()self.features = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2, stride=2),nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2, stride=2),nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2, stride=2))self.flatten = nn.Flatten()self.classifier = nn.Sequential(nn.Linear(256 * 4 * 4, 1024),  # 假设输入图像是32x32nn.ReLU(inplace=True),nn.Dropout(0.5),nn.Linear(1024, 10))def forward(self, x):x = self.features(x)x = self.flatten(x)x = self.classifier(x)return x# 使用示例
model = ImageClassifier()
input_tensor = torch.randn(16, 3, 32, 32)  # batch=16, 3通道, 32x32图像
output = model(input_tensor)
print(output.shape)  # 输出形状: [16, 10]

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/77125.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Spring Boot + MyBatis + Maven论坛内容管理系统源码

项目描述 xxxForum是一个基于Spring Boot MyBatis Maven开发的一个论坛内容管理系统,主要实现了的功能有: 前台页面展示数据、广告展示内容模块:发帖、评论、帖子分类、分页、回帖统计、访问统计、表单验证用户模块:权限、资料…

探索AI编程规范化的利器:Awesome Cursor Rules

在AI辅助编程逐渐成为开发者标配的今天,如何让AI生成的代码既符合项目规范又保持高质量,成为开发者面临的新挑战。GitHub仓库**awesome-cursorrules**正是为解决这一问题而生的开源项目,它通过系统化的规则模板库,重新定义了AI编程的规范边界。本文将深入解析这一工具的核心…

AnimateCC基础教学:json数据结构的测试

一.核心代码: const user1String {"name": "张三", "age": 30, "gender": "男"}; const user1Obj JSON.parse(user1String); console.log("测试1:", user1Obj.name, user1Obj.age, user1Obj.gender);/*const u…

阿里云域名证书自动更新acme.sh

因为阿里云的免费证书只有三个月的有效期,每次更换都比较繁琐,所以找到了 acme.sh,还有一种 certbot 我没有去了解,就直接使用了 acme.sh 来更新证书,acme.sh 的主要特点就是: 支持多种 DNS 服务商自动化续…

PDF 中提取数学公式

✅ 方法一:使用 doc2x extract_formula_imgs Pix2Text 一键运行脚本(自动提取识别) 👉 适合你如果用 Python 的话,只需要运行一段脚本即可: ✅ 🔁 一步搞定脚本(仅需安装一次&…

SQL并行产生进程数量问题

有一些数据库性能问题可能是因为同时启动的并行进程过多造成的,特别常见于RAC节点重启,很多时候是因为瞬间启动了几百个并行进程,导致OS各项指标“彪高”,后台进程失去响应。最近遇到的一个,是因为SQL语句中写了/* par…

【Vue-组件】学习笔记

目录 <<回到导览组件1.项目1.1.Vue Cli1.2.项目目录1.3.运行流程1.4.组件的组成1.5.注意事项 2.组件2.1.组件注册2.2.scoped样式冲突2.3.data是一个函数2.4.props详解2.5.data和prop的区别 3.组件通信3.1.父子通信3.1.1.父传子&#xff08;props&#xff09;3.1.2.子传父…

【Kafka基础】单机安装与配置指南,从零搭建环境

学习Kafka&#xff0c;掌握Kafka的单机部署是理解其分布式特性的第一步。本文将手把手带你完成Kafka单机环境的安装、配置及基础验证&#xff0c;涵盖常见问题排查技巧。 1 环境准备 1.1 系统要求 操作系统&#xff1a;CentOS 7.9依赖组件&#xff1a;JDK 8&#xff08;Kafka …

OpenCV 图形API(21)逐像素操作

操作系统&#xff1a;ubuntu22.04 OpenCV版本&#xff1a;OpenCV4.9 IDE:Visual Studio Code 编程语言&#xff1a;C11 算法描述 在OpenCV的G-API模块中&#xff0c;逐像素操作指的是对图像中的每个像素单独进行处理的操作。这些操作可以通过G-API的计算图&#xff08;Graph …

CubeMX配置STM32VET6实现网口通信(无操作系统版-附源码)

下面是使用CubeMX配置STM32F407VET6,实现以太网通讯(PHY芯片为LAN8720)的具体步骤总结: 一、硬件连接方式: 硬件原理图: 使用外部晶振为PHY芯片提供时钟。 STM32F407VET6 与 LAN8720 采用 RMII 模式连接。STM32F407VET6引脚功能(RMII)LAN8720引脚PA1ETH_REF_CLKREF_CL…

Android Compose 中获取和使用 Context 的完整指南

在 Android Jetpack Compose 中&#xff0c;虽然大多数 UI 组件不再需要直接使用 Context&#xff0c;但有时你仍然需要访问它来执行一些 Android 平台特定的操作。以下是几种在 Compose 中获取和使用 Context 的方法&#xff1a; 1. 使用 LocalContext 这是 Compose 中最常用…

在VMware下Hadoop分布式集群环境的配置--基于Yarn模式的一个Master节点、两个Slaver(Worker)节点的配置

你遇到的大部分ubuntu中配置hadoop的问题这里都有解决方法&#xff01;&#xff01;&#xff01;&#xff08;近10000字&#xff09; 概要 在Docker虚拟容器环境下&#xff0c;进行Hadoop-3.2.2分布式集群环境的配置与安装&#xff0c;完成基于Yarn模式的一个Master节点、两个…

PID灯控算法

根据代码分析&#xff0c;以下是针对PID算法和光敏传感器系统的优化建议&#xff0c;分为算法优化、代码结构优化和系统级优化三部分&#xff1a; 一、PID算法优化 1. 增量式PID 输出平滑 // 修改PID计算函数 uint16_t PID_calculation_fun(void) {if(PID_Str_Val.Tdata >…

文件映射mmap与管道文件

在用户态申请内存&#xff0c;内存内容和磁盘内容建立一一映射 读写内存等价于读写磁盘 支持随机访问 简单来说&#xff0c;把磁盘里的数据与内存的用户态建立一一映射关系&#xff0c;让读写内存等价于读写磁盘&#xff0c;支持随机访问。 管道文件&#xff1a;进程间通信机…

在 Java 中调用 ChatGPT API 并实现流式接收(Server-Sent Events, SSE)

文章目录 简介OkHttp 流式获取 GPT 响应通过 SSE 流式推送前端后端代码消息实体接口接口实现数据推送给前端 前端代码创建 sseClient.jsvue3代码 优化后端代码 简介 用过 ChatGPT 的伙伴应该想过自己通过调用ChatGPT官网提供的接口来实现一个自己的问答机器人&#xff0c;但是…

硬盘分区格式之GPT(GUID Partition Table)笔记250407

硬盘分区格式之GPT&#xff08;GUID Partition Table&#xff09;笔记250407 GPT&#xff08;GUID Partition Table&#xff09;硬盘分区格式详解 GPT&#xff08;GUID Partition Table&#xff09;是替代传统 MBR 的现代分区方案&#xff0c;专为 UEFI&#xff08;统一可扩展固…

Vite环境下解决跨域问题

在 Vite 开发环境中&#xff0c;可以通过配置代理来解决跨域问题。以下是具体步骤&#xff1a; 在项目根目录下找到 vite.config.js 文件&#xff1a;如果没有&#xff0c;则需要创建一个。配置代理&#xff1a;在 vite.config.js 文件中&#xff0c;使用 server.proxy 选项来…

交换机与ARP

交换机与 ARP&#xff08;Address Resolution Protocol&#xff0c;地址解析协议&#xff09; 的关系主要体现在 局域网&#xff08;LAN&#xff09;内设备通信的地址解析与数据帧转发 过程中。以下是二者的核心关联&#xff1a; 1. 基本角色 交换机&#xff1a;工作在 数据链…

【Spring】小白速通AOP-日志记录Demo

这篇文章我将通过一个最常用的AOP场景-方法调用日志记录&#xff0c;带你彻底理解AOP的使用。例子使用Spring BootSpring AOP实现。 如果对你有帮助可以点个赞和关注。谢谢大家的支持&#xff01;&#xff01; 一、Demo实操步骤&#xff1a; 1.首先添加Maven依赖 <!-- Sp…

git功能点管理

需求&#xff1a; 功能模块1 已经完成&#xff0c;已经提交并推送到远程&#xff0c;准备交给测试。功能模块2 已经完成&#xff0c;但不提交给测试&#xff0c;继续开发。功能模块3 正在开发中。 管理流程&#xff1a; 创建并开发功能模块1&#xff1a; git checkout main…