深度学习基本单元结构与输入输出维度解析

深度学习基本单元结构与输入输出维度解析

在深度学习领域,模型的设计和结构是理解其性能和应用的关键。本文将介绍深度学习中的基本单元结构,包括卷积神经网络(CNN)、反卷积(转置卷积)、循环神经网络(RNN)、门控循环单元(GRU)和长短期记忆网络(LSTM),并详细讨论每个单元的输入和输出维度。我们将以 MNIST 数据集为例,展示这些基本单元如何组合在一起构建复杂的模型。
之前的博客:
深入理解 RNN、LSTM 和 GRU:结构、参数与应用
理解 Conv2d 和 ConvTranspose2d 的输入输出特征形状计算

1. 模型结构概述

我们构建的模型包含以下主要部分:

  • 卷积神经网络(CNN)
  • 反卷积(转置卷积)
  • 循环神经网络(RNN)
  • 门控循环单元(GRU)
  • 长短期记忆网络(LSTM)
  • 全连接层

2. 模型代码

以下是实现综合模型的代码:

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision import datasets
from torch.utils.data import DataLoader# 定义模型
class CombinedModel(nn.Module):def __init__(self):super(CombinedModel, self).__init__()# CNN 部分self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)  # 输入: (1, 28, 28) -> 输出: (32, 28, 28)self.pool = nn.MaxPool2d(kernel_size=2, stride=2)  # 最大池化层self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)  # 输入: (32, 28, 28) -> 输出: (64, 28, 28)# 反卷积部分self.deconv = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)  # 输入: (64, 14, 14) -> 输出: (32, 28, 28)# RNN 部分self.rnn_input_size = 32 * 14 * 14  # 输入到 RNN 的特征数self.rnn = nn.RNN(input_size=self.rnn_input_size, hidden_size=128, num_layers=1,batch_first=True)  # 输入: (batch_size, seq_len, input_size)# GRU 部分self.gru = nn.GRU(input_size=128, hidden_size=64, num_layers=1,batch_first=True)  # 输入: (batch_size, seq_len, input_size)# LSTM 部分self.lstm = nn.LSTM(input_size=64, hidden_size=32, num_layers=1,batch_first=True)  # 输入: (batch_size, seq_len, input_size)# 全连接层self.fc = nn.Linear(32, 10)  # 输出: (batch_size, 10)def forward(self, x):# CNN 部分print(f'Input shape: {x.shape}')  # 输入形状: (batch_size, 1, 28, 28)x = self.pool(torch.relu(self.conv1(x)))  # 输出: (batch_size, 32, 28, 28)print(f'After conv1 and pool: {x.shape}')x = self.pool(torch.relu(self.conv2(x)))  # 输出: (batch_size, 64, 14, 14)print(f'After conv2 and pool: {x.shape}')# 反卷积部分x = self.deconv(x)  # 输出: (batch_size, 32, 14, 14)print(f'After deconv: {x.shape}')# 将数据展平并调整形状以输入到 RNNx = x.view(x.size(0), -1)  # 展平为 (batch_size, 32 * 14 * 14)print(f'After flattening: {x.shape}')x = x.unsqueeze(1)  # 添加序列长度维度,变为 (batch_size, 1, 32 * 14 * 14)print(f'After unsqueeze for RNN: {x.shape}')# RNN 部分x, _ = self.rnn(x)  # 输出: (batch_size, 1, 128)print(f'After RNN: {x.shape}')# GRU 部分x, _ = self.gru(x)  # 输出: (batch_size, 1, 64)print(f'After GRU: {x.shape}')# LSTM 部分x, _ = self.lstm(x)  # 输出: (batch_size, 1, 32)print(f'After LSTM: {x.shape}')# 取最后一个时间步的输出x = x[:, -1, :]  # 输出: (batch_size, 32)print(f'After selecting last time step: {x.shape}')# 全连接层x = self.fc(x)  # 输出: (batch_size, 10)print(f'Output shape: {x.shape}')return x# 3. 数据加载
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))  # MNIST 数据集的均值和标准差
])# 下载 MNIST 数据集
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)# 4. 训练模型
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CombinedModel().to(device)
criterion = nn.CrossEntropyLoss()  # 损失函数
optimizer = optim.Adam(model.parameters(), lr=0.001)  # 优化器# 训练过程
num_epochs = 5
for epoch in range(num_epochs):model.train()running_loss = 0.0correct = 0total = 0for images, labels in train_loader:images, labels = images.to(device), labels.to(device)  # 将数据移动到设备optimizer.zero_grad()  # 清空梯度outputs = model(images)  # 前向传播loss = criterion(outputs, labels)  # 计算损失loss.backward()  # 反向传播optimizer.step()  # 更新参数running_loss += loss.item()_, predicted = torch.max(outputs.data, 1)  # 获取预测结果total += labels.size(0)correct += (predicted == labels).sum().item()avg_loss = running_loss / len(train_loader)accuracy = 100 * correct / totalprint(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {avg_loss:.4f}, Accuracy: {accuracy:.2f}%')# 5. 评估模型
model.eval()
correct = 0
total = 0
with torch.no_grad():for images, labels in test_loader:images, labels = images.to(device), labels.to(device)  # 将数据移动到设备outputs = model(images)_, predicted = torch.max(outputs.data, 1)  # 获取预测结果total += labels.size(0)correct += (predicted == labels).sum().item()print(f'Accuracy of the model on the test images: {100 * correct / total:.2f}%')

3. 每个基本单元的输入输出维度

3.1 CNN 部分

  1. 输入(batch_size, 1, 28, 28)

    • 这是 MNIST 数据集的输入形状,其中 1 表示单通道(灰度图像),28x28 是图像的高度和宽度。
  2. 卷积层 1

    • self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
    • 输入形状(batch_size, 1, 28, 28)
    • 输出形状(batch_size, 32, 28, 28)
    • 32 个特征图,空间维度保持不变。
  3. 最大池化层 1

    • 输入形状(batch_size, 32, 28, 28)
    • 输出形状(batch_size, 32, 14, 14)
    • 高度和宽度减半。
  4. 卷积层 2

    • self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
    • 输入形状(batch_size, 32, 14, 14)
    • 输出形状(batch_size, 64, 14, 14)
    • 64 个特征图,空间维度保持不变。
  5. 最大池化层 2

    • 输入形状(batch_size, 64, 14, 14)
    • 输出形状(batch_size, 64, 7, 7)
    • 高度和宽度再次减半。

3.2 反卷积部分

  1. 反卷积层
    • self.deconv = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
    • 输入形状(batch_size, 64, 7, 7)
    • 输出形状(batch_size, 32, 14, 14)
    • 高度和宽度翻倍。

3.3 RNN 部分

  1. 展平

    • x = x.view(x.size(0), -1)
    • 输入形状(batch_size, 32, 14, 14)
    • 输出形状(batch_size, 6272) # 这里的 6272 是 32 * 14 * 14
    • 将特征图展平为一个向量。
  2. 添加序列长度维度

    • x = x.unsqueeze(1)
    • 输入形状(batch_size, 6272)
    • 输出形状(batch_size, 1, 6272)
    • 添加序列长度维度,表示只有一个时间步。
  3. RNN

    • 输入形状(batch_size, 1, 6272)
    • 输出形状(batch_size, 1, 128)
    • RNN 输出的隐藏状态,隐藏层大小为 128。

3.4 GRU 和 LSTM 部分

  1. GRU

    • 输入形状(batch_size, 1, 128)
    • 输出形状(batch_size, 1, 64)
    • GRU 输出的隐藏状态,隐藏层大小为 64。
  2. LSTM

    • 输入形状(batch_size, 1, 64)
    • 输出形状(batch_size, 1, 32)
    • LSTM 输出的隐藏状态,隐藏层大小为 32。

3.5 全连接层

  1. 全连接层
    • self.fc = nn.Linear(32, 10)
    • 输入形状(batch_size, 32)
    • 输出形状(batch_size, 10)
    • 最终输出的类别数(10 类,表示 MNIST 的数字 0-9)。

4. 可视化模型结构

from torchinfo import summary
model = CombinedModel()
summary(model, input_size=(64,1, 28, 28))

或者

import torch
import torch.nn as nn
from torchviz import make_dot
model = CombinedModel()
dummy_input = torch.randn(1, 1, 28, 28)  # (batch_size, channels, height, width)
output = model(dummy_input)
dot = make_dot(output, params=dict(model.named_parameters()))
dot.render("model_structure", format="png")  # 生成 model_structure.png

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/62407.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

乐鑫发布 esp-iot-solution v2.0 版本

今天,乐鑫很高兴地宣布,esp-iot-solution v2.0 版本已经发布,release/v2.0 分支下的正式版本组件将为用户提供为期两年的 Bugfix 维护(直到 2027.01.25 ESP-IDF v5.3 EOL)。该版本将物联网开发中常用的功能进行了分类整…

面经-综合面/hr面

面经-综合面/hr面 概述1.大学期间遇到的困难,怎么解决的2. 大学期间印象最深/最难忘的是什么3. 大学里面担任了什么职务没?做了什么工作?4. 大学最大的遗憾是什么?5. 对自己的未来规划6. 对自己的评价7. 自己的优缺点8. 对公司的认…

pyspark实现基于协同过滤的电影推荐系统

最近在学一门大数据的课,课程要求很开放,任意做一个大数据相关的项目即可,不知道为什么我就想到推荐算法,一直到着手要做之前还没有新的更好的来代替,那就这个吧。 推荐算法 推荐算法的发展由来已久,但和…

十、Spring Boot集成Spring Security之HTTP请求授权

文章目录 往期回顾:Spring Boot集成Spring Security专栏及各章节快捷入口前言一、HTTP请求授权工作原理二、HTTP请求授权配置1、添加用户权限2、配置ExceptionTranslationFilter自定义异常处理器3、HTTP请求授权配置 三、测试接口1、测试类2、测试 四、总结 往期回顾…

Unity3d C# 实现一个基于UGUI的自适应尺寸图片查看器(含源码)

前言 Unity3d实现的数字沙盘系统中,总有一些图片或者图片列表需要点击后弹窗显示大图,这个弹窗在不同尺寸分辨率的图片查看处理起来比较麻烦,所以,需要图片能够根据容器的大小自适应地进行缩放,兼容不太尺寸下的横竖图…

DVWA 在 Windows 环境下的部署指南

目录预览 一、靶场介绍二、前置准备1. 环境准备2.靶场下载 三、安装步骤1.配置Phpstudy2.配置数据库3.配置DVWA4.登入DVWA靶场 四、参考链接 一、靶场介绍 DVWA 一共包含了十个攻击模块,分别是: Brute Force(暴力(破解&#xff…

微软企业邮箱:安全可靠的企业级邮件服务!

微软企业邮箱的设置步骤?如何注册使用烽火域名邮箱? 微软企业邮箱作为一款专为企业设计的邮件服务,不仅提供了高效便捷的通信工具,更在安全性、可靠性和功能性方面树立了行业标杆。烽火将深入探讨微软企业邮箱的多重优势。 微软…

使用UE5.5的Animator Kit变形器

UE5.5版本更新了AnimatorKit内置插件,其中包含了一些内置变形器,可以辅助我们的动画制作。 操作步骤 首先打开UE5.5,新建第三人称模板场景以便测试,并开启AnimatorKit组件。 新建Sequence,放入测试角色 点击角色右…

应用案例丨坤驰科技双通道触发采集实时FFT数据处理系统

双通道触发采集实时FFT数据处理系统 应用案例 双通道采集,每路通道需要2GSPS的采样率,每2毫秒采集一次,每次采集数据量为65536*2 Sample。采集的信号频率满足奈奎斯特采样定律。采集数据后,每路通道的数据均做运算以及FFT实时处理…

OGRE 3D----3. OGRE绘制自定义模型

在使用OGRE进行开发时,绘制自定义模型是一个常见的需求。本文将介绍如何使用OGRE的ManualObject类来创建和绘制自定义模型。通过ManualObject,开发者可以直接定义顶点、法线、纹理坐标等,从而灵活地构建各种复杂的几何体。 Ogre::ManualObject 是 Ogre3D 引擎中的一个类,用…

如何用Excel做数据可视化自动化报表?

作为一个经常需要做数据报表的人,我最常用的工具是Excel,对于我来说用Excel处理繁琐冗杂的数据并不难,但是我发现身边很多人用Excel做的数据报表非常的耗时,而且最后的成品也是难以直视,逻辑和配色等都非常的“灾难”。…

基于FPGA的SD NAND读写测试(图文并茂+源代码+详细注释)

本实验所使用的源代码已同步至个人主页的资源处,可供读者自行学习...... 什么是SD NAND? 1.SD NAND 卡介绍 SD NAND 卡是一种基于 NAND 闪存技术的存储设备,其外观和接口类似于标准的 SD 卡。它将 NAND 闪存芯片和必要的控制电路集成在一个小…

机器学习6-梯度下降法

梯度下降法 目的 梯度下降法(Gradient Descent)是一个算法,但不是像多元线性回归那样是一个具体做回归任务的算法,而是一个非常通用的优化算法来帮助一些机器学习算法求解出最优解的,所谓的通用就是很多机器学习算法都是用它,甚…

(0基础保姆教程)-JavaEE开课啦!--11课程(初识Spring MVC + Vue2.0 + Mybatis)-实验9

一、什么是Spring MVC? Spring MVC 是一个基于 Java 的 Web 框架,遵循 MVC 设计模式,用于构建企业级应用程序。它通过控制器(Controller)处理用户请求,模型(Model)处理业务逻辑,视图(View)展示数据,实现了请…

微前端-MicroApp

微前端即是由一个主应用来集成多个微应用(可以不区分技术栈进行集成) 下面是使用微前端框架之一 MicroApp 对 react微应用 的详细流程 第一步 创建主应用my-mj-app 利用脚手架 npx create-react-app my-mj-app 快速创建 安装 npm install --save rea…

知识库助手的构建之路:ChatGLM3-6B和LangChain的深度应用

ChatGLM3-6B和LangChain构建知识库助手 安装依赖库 使用pip命令安装以下库: pip install modelscope langchain0.1.7 chromadb0.5.0 sentence-transformers2.7.0 unstructured0.13.7 markdown3.0.0 docx2txt0.8 pypdf4.2.0依赖库简介: ModelScope&a…

shell(2)永久环境变量和字符串显位

shell(2)永久环境变量和字符串显位 声明! 学习视频来自B站up主 ​泷羽sec​​ 有兴趣的师傅可以关注一下,如涉及侵权马上删除文章 笔记只是方便各位师傅的学习和探讨,文章所提到的网站以及内容,只做学习…

Java实现IP代理池

文章目录 Java实现IP代理池一、引言二、构建IP代理池1、代理IP的获取2、代理IP的验证1. 导入必要的库2. 设置代理IP和端口3. 发起HTTP请求4. 检查请求结果5. 完整的验证方法 注意事项 三、使用IP代理池四、总结 Java实现IP代理池 一、引言 在网络爬虫或者需要频繁请求网络资源…

微服务保护和分布式事务

文章目录 一、微服务保护1.1 微服务保护方案:1.1.1 请求限流:1.1.2 线程隔离:1.1.3 服务熔断: 1.2 Sentinel:1.2.1 介绍和安装:1.2.2 微服务整合: 1.3 请求限流:1.4 线程隔离&#x…

后端 Java发送邮件 JavaMail 模版 20241128测试可用

配置授权码 依赖 <dependency><groupId>javax.mail</groupId><artifactId>javax.mail-api</artifactId><version>1.5.5</version> </dependency> <dependency><groupId>com.sun.mail</groupId><artifa…