深度学习基本单元结构与输入输出维度解析

深度学习基本单元结构与输入输出维度解析

在深度学习领域,模型的设计和结构是理解其性能和应用的关键。本文将介绍深度学习中的基本单元结构,包括卷积神经网络(CNN)、反卷积(转置卷积)、循环神经网络(RNN)、门控循环单元(GRU)和长短期记忆网络(LSTM),并详细讨论每个单元的输入和输出维度。我们将以 MNIST 数据集为例,展示这些基本单元如何组合在一起构建复杂的模型。
之前的博客:
深入理解 RNN、LSTM 和 GRU:结构、参数与应用
理解 Conv2d 和 ConvTranspose2d 的输入输出特征形状计算

1. 模型结构概述

我们构建的模型包含以下主要部分:

  • 卷积神经网络(CNN)
  • 反卷积(转置卷积)
  • 循环神经网络(RNN)
  • 门控循环单元(GRU)
  • 长短期记忆网络(LSTM)
  • 全连接层

2. 模型代码

以下是实现综合模型的代码:

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision import datasets
from torch.utils.data import DataLoader# 定义模型
class CombinedModel(nn.Module):def __init__(self):super(CombinedModel, self).__init__()# CNN 部分self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)  # 输入: (1, 28, 28) -> 输出: (32, 28, 28)self.pool = nn.MaxPool2d(kernel_size=2, stride=2)  # 最大池化层self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)  # 输入: (32, 28, 28) -> 输出: (64, 28, 28)# 反卷积部分self.deconv = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)  # 输入: (64, 14, 14) -> 输出: (32, 28, 28)# RNN 部分self.rnn_input_size = 32 * 14 * 14  # 输入到 RNN 的特征数self.rnn = nn.RNN(input_size=self.rnn_input_size, hidden_size=128, num_layers=1,batch_first=True)  # 输入: (batch_size, seq_len, input_size)# GRU 部分self.gru = nn.GRU(input_size=128, hidden_size=64, num_layers=1,batch_first=True)  # 输入: (batch_size, seq_len, input_size)# LSTM 部分self.lstm = nn.LSTM(input_size=64, hidden_size=32, num_layers=1,batch_first=True)  # 输入: (batch_size, seq_len, input_size)# 全连接层self.fc = nn.Linear(32, 10)  # 输出: (batch_size, 10)def forward(self, x):# CNN 部分print(f'Input shape: {x.shape}')  # 输入形状: (batch_size, 1, 28, 28)x = self.pool(torch.relu(self.conv1(x)))  # 输出: (batch_size, 32, 28, 28)print(f'After conv1 and pool: {x.shape}')x = self.pool(torch.relu(self.conv2(x)))  # 输出: (batch_size, 64, 14, 14)print(f'After conv2 and pool: {x.shape}')# 反卷积部分x = self.deconv(x)  # 输出: (batch_size, 32, 14, 14)print(f'After deconv: {x.shape}')# 将数据展平并调整形状以输入到 RNNx = x.view(x.size(0), -1)  # 展平为 (batch_size, 32 * 14 * 14)print(f'After flattening: {x.shape}')x = x.unsqueeze(1)  # 添加序列长度维度,变为 (batch_size, 1, 32 * 14 * 14)print(f'After unsqueeze for RNN: {x.shape}')# RNN 部分x, _ = self.rnn(x)  # 输出: (batch_size, 1, 128)print(f'After RNN: {x.shape}')# GRU 部分x, _ = self.gru(x)  # 输出: (batch_size, 1, 64)print(f'After GRU: {x.shape}')# LSTM 部分x, _ = self.lstm(x)  # 输出: (batch_size, 1, 32)print(f'After LSTM: {x.shape}')# 取最后一个时间步的输出x = x[:, -1, :]  # 输出: (batch_size, 32)print(f'After selecting last time step: {x.shape}')# 全连接层x = self.fc(x)  # 输出: (batch_size, 10)print(f'Output shape: {x.shape}')return x# 3. 数据加载
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))  # MNIST 数据集的均值和标准差
])# 下载 MNIST 数据集
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)# 4. 训练模型
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CombinedModel().to(device)
criterion = nn.CrossEntropyLoss()  # 损失函数
optimizer = optim.Adam(model.parameters(), lr=0.001)  # 优化器# 训练过程
num_epochs = 5
for epoch in range(num_epochs):model.train()running_loss = 0.0correct = 0total = 0for images, labels in train_loader:images, labels = images.to(device), labels.to(device)  # 将数据移动到设备optimizer.zero_grad()  # 清空梯度outputs = model(images)  # 前向传播loss = criterion(outputs, labels)  # 计算损失loss.backward()  # 反向传播optimizer.step()  # 更新参数running_loss += loss.item()_, predicted = torch.max(outputs.data, 1)  # 获取预测结果total += labels.size(0)correct += (predicted == labels).sum().item()avg_loss = running_loss / len(train_loader)accuracy = 100 * correct / totalprint(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {avg_loss:.4f}, Accuracy: {accuracy:.2f}%')# 5. 评估模型
model.eval()
correct = 0
total = 0
with torch.no_grad():for images, labels in test_loader:images, labels = images.to(device), labels.to(device)  # 将数据移动到设备outputs = model(images)_, predicted = torch.max(outputs.data, 1)  # 获取预测结果total += labels.size(0)correct += (predicted == labels).sum().item()print(f'Accuracy of the model on the test images: {100 * correct / total:.2f}%')

3. 每个基本单元的输入输出维度

3.1 CNN 部分

  1. 输入(batch_size, 1, 28, 28)

    • 这是 MNIST 数据集的输入形状,其中 1 表示单通道(灰度图像),28x28 是图像的高度和宽度。
  2. 卷积层 1

    • self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
    • 输入形状(batch_size, 1, 28, 28)
    • 输出形状(batch_size, 32, 28, 28)
    • 32 个特征图,空间维度保持不变。
  3. 最大池化层 1

    • 输入形状(batch_size, 32, 28, 28)
    • 输出形状(batch_size, 32, 14, 14)
    • 高度和宽度减半。
  4. 卷积层 2

    • self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
    • 输入形状(batch_size, 32, 14, 14)
    • 输出形状(batch_size, 64, 14, 14)
    • 64 个特征图,空间维度保持不变。
  5. 最大池化层 2

    • 输入形状(batch_size, 64, 14, 14)
    • 输出形状(batch_size, 64, 7, 7)
    • 高度和宽度再次减半。

3.2 反卷积部分

  1. 反卷积层
    • self.deconv = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
    • 输入形状(batch_size, 64, 7, 7)
    • 输出形状(batch_size, 32, 14, 14)
    • 高度和宽度翻倍。

3.3 RNN 部分

  1. 展平

    • x = x.view(x.size(0), -1)
    • 输入形状(batch_size, 32, 14, 14)
    • 输出形状(batch_size, 6272) # 这里的 6272 是 32 * 14 * 14
    • 将特征图展平为一个向量。
  2. 添加序列长度维度

    • x = x.unsqueeze(1)
    • 输入形状(batch_size, 6272)
    • 输出形状(batch_size, 1, 6272)
    • 添加序列长度维度,表示只有一个时间步。
  3. RNN

    • 输入形状(batch_size, 1, 6272)
    • 输出形状(batch_size, 1, 128)
    • RNN 输出的隐藏状态,隐藏层大小为 128。

3.4 GRU 和 LSTM 部分

  1. GRU

    • 输入形状(batch_size, 1, 128)
    • 输出形状(batch_size, 1, 64)
    • GRU 输出的隐藏状态,隐藏层大小为 64。
  2. LSTM

    • 输入形状(batch_size, 1, 64)
    • 输出形状(batch_size, 1, 32)
    • LSTM 输出的隐藏状态,隐藏层大小为 32。

3.5 全连接层

  1. 全连接层
    • self.fc = nn.Linear(32, 10)
    • 输入形状(batch_size, 32)
    • 输出形状(batch_size, 10)
    • 最终输出的类别数(10 类,表示 MNIST 的数字 0-9)。

4. 可视化模型结构

from torchinfo import summary
model = CombinedModel()
summary(model, input_size=(64,1, 28, 28))

或者

import torch
import torch.nn as nn
from torchviz import make_dot
model = CombinedModel()
dummy_input = torch.randn(1, 1, 28, 28)  # (batch_size, channels, height, width)
output = model(dummy_input)
dot = make_dot(output, params=dict(model.named_parameters()))
dot.render("model_structure", format="png")  # 生成 model_structure.png

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/62407.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Sofia-SIP 使用教程

Sofia-SIP 是一个开源的 SIP 协议栈,广泛用于 VoIP 和即时通讯应用。以下是一些基本的使用教程,帮助你快速上手 Sofia-SIP。 1. 安装 Sofia-SIP 首先,你需要安装 Sofia-SIP 库。你可以从其官方 GitHub 仓库克隆源代码并编译安装&#xff1a…

【八股文】小米

文章目录 一、vector 和 list 的区别?二、include 双引号和尖括号的区别?三、set 的底层数据结构?四、set 和 multiset 的区别?五、map 和 unordered_map 的区别?六、虚函数和纯虚函数的区别?七、extern C …

【leetcode100】找到字符串中所有字母异位词

1、题目描述 给定两个字符串 s 和 p,找到 s 中所有 p 的 异位词 异位词 的子串,返回这些子串的起始索引。不考虑答案输出的顺序。 示例 1: 输入: s "cbaebabacd", p "abc" 输出: [0,6] 解释: 起始索引等于 0 的子串是 "…

乐鑫发布 esp-iot-solution v2.0 版本

今天,乐鑫很高兴地宣布,esp-iot-solution v2.0 版本已经发布,release/v2.0 分支下的正式版本组件将为用户提供为期两年的 Bugfix 维护(直到 2027.01.25 ESP-IDF v5.3 EOL)。该版本将物联网开发中常用的功能进行了分类整…

c/c++ 用easyx图形库写一个射击游戏

#include <graphics.h> #include <conio.h> #include <stdlib.h> #include <time.h>// 定义游戏窗口的大小 #define WINDOW_WIDTH 800 #define WINDOW_HEIGHT 600// 定义玩家和目标的尺寸 #define PLAYER_SIZE 50 #define TARGET_SIZE 20// 玩家的结构…

面经-综合面/hr面

面经-综合面/hr面 概述1.大学期间遇到的困难&#xff0c;怎么解决的2. 大学期间印象最深/最难忘的是什么3. 大学里面担任了什么职务没&#xff1f;做了什么工作&#xff1f;4. 大学最大的遗憾是什么&#xff1f;5. 对自己的未来规划6. 对自己的评价7. 自己的优缺点8. 对公司的认…

pyspark实现基于协同过滤的电影推荐系统

最近在学一门大数据的课&#xff0c;课程要求很开放&#xff0c;任意做一个大数据相关的项目即可&#xff0c;不知道为什么我就想到推荐算法&#xff0c;一直到着手要做之前还没有新的更好的来代替&#xff0c;那就这个吧。 推荐算法 推荐算法的发展由来已久&#xff0c;但和…

虚拟现实(VR)与增强现实(AR)有什么区别?

虚拟现实&#xff08;Virtual Reality&#xff0c;VR&#xff09;与增强现实&#xff08;Augmented Reality&#xff0c;AR&#xff09;在多个方面存在显著差异。以下是对这两者的详细比较&#xff1a; 一、概念定义 虚拟现实&#xff08;VR&#xff09;&#xff1a; 是一种…

【图像去噪】论文精读:Deep Image Prior(DIP)

请先看【专栏介绍文章】:【图像去噪(Image Denoising)】关于【图像去噪】专栏的相关说明,包含适配人群、专栏简介、专栏亮点、阅读方法、定价理由、品质承诺、关于更新、去噪概述、文章目录、资料汇总、问题汇总(更新中) 文章目录 前言Abstract1. Introduction2. Method3…

十、Spring Boot集成Spring Security之HTTP请求授权

文章目录 往期回顾&#xff1a;Spring Boot集成Spring Security专栏及各章节快捷入口前言一、HTTP请求授权工作原理二、HTTP请求授权配置1、添加用户权限2、配置ExceptionTranslationFilter自定义异常处理器3、HTTP请求授权配置 三、测试接口1、测试类2、测试 四、总结 往期回顾…

Unity3d C# 实现一个基于UGUI的自适应尺寸图片查看器(含源码)

前言 Unity3d实现的数字沙盘系统中&#xff0c;总有一些图片或者图片列表需要点击后弹窗显示大图&#xff0c;这个弹窗在不同尺寸分辨率的图片查看处理起来比较麻烦&#xff0c;所以&#xff0c;需要图片能够根据容器的大小自适应地进行缩放&#xff0c;兼容不太尺寸下的横竖图…

用 llama.cpp 体验 Meta 的 Llama AI 模型

继续体验 Meta 开源的 Llama 模型&#xff0c;前篇 试用 Llama-3.1-8B-Instruct AI 模型 直接用 Python 的 Tranformers 和 PyTorch 库加载 Llama 模型进行推理。模型训练出来的精度是 float32, 加载时采用的精度是 torch.bfloat16。 注&#xff1a;数据类型 torch.float32, t…

Axios与FastAPI结合:构建并请求用户增删改查接口

在现代Web开发中&#xff0c;FastAPI以其高性能和简洁的代码结构成为了构建RESTful API的热门选择。而Axios则因其基于Promise的HTTP客户端特性&#xff0c;成为了前端与后端交互的理想工具。本文将介绍FastAPI和Axios的结合使用&#xff0c;通过一个用户增删改查&#xff08;C…

深入理解B-树与B+树:数据结构中的高效索引利器

一、引言 在数据库系统中&#xff0c;索引是提高查询效率的关键技术。而B-树和B树作为常用的索引数据结构&#xff0c;以其高效的查询、插入和删除操作备受青睐。下面我们将分别探讨B-树和B树的结构及其优缺点。 二、B-树 B-树简介 B-树&#xff08;Balanced Tree&#xff…

DVWA 在 Windows 环境下的部署指南

目录预览 一、靶场介绍二、前置准备1. 环境准备2.靶场下载 三、安装步骤1.配置Phpstudy2.配置数据库3.配置DVWA4.登入DVWA靶场 四、参考链接 一、靶场介绍 DVWA 一共包含了十个攻击模块&#xff0c;分别是&#xff1a; Brute Force&#xff08;暴力&#xff08;破解&#xff…

Spring Bean 初始化如何保证线程安全

创作内容丰富的干货文章很费心力,感谢点过此文章的读者,点一个关注鼓励一下作者,激励他分享更多的精彩好文,谢谢大家! Spring Bean 中的参数通常有几种初始化方法: 通过构造函数注入: @Service public void MyService {private MyData myData;public MyService(MyData…

虚拟机ubuntu-20.04.6-live-server搭建OpenStack:Victoria(二:OpenStack环境准备-compute node)

文章目录 Host networkinga. 配置网络接口b. 验证连通性 Network Time Protocol (NTP)a. 安装并配置组件b. 验证操作 OpenStack packagesa. 下载Victoria云存储仓库b. 安装示例c. 安装客户端 沉浸版指令及内容&#xff1a; Host networking a. 配置网络接口 切换至超级用户模…

微软企业邮箱:安全可靠的企业级邮件服务!

微软企业邮箱的设置步骤&#xff1f;如何注册使用烽火域名邮箱&#xff1f; 微软企业邮箱作为一款专为企业设计的邮件服务&#xff0c;不仅提供了高效便捷的通信工具&#xff0c;更在安全性、可靠性和功能性方面树立了行业标杆。烽火将深入探讨微软企业邮箱的多重优势。 微软…

使用UE5.5的Animator Kit变形器

UE5.5版本更新了AnimatorKit内置插件&#xff0c;其中包含了一些内置变形器&#xff0c;可以辅助我们的动画制作。 操作步骤 首先打开UE5.5&#xff0c;新建第三人称模板场景以便测试&#xff0c;并开启AnimatorKit组件。 新建Sequence&#xff0c;放入测试角色 点击角色右…

JS异步进化与Promise

JavaScript 是单线程的&#xff0c;但它并不是无法处理异步操作。相反&#xff0c;JavaScript 的单线程特性和其事件循环机制使得它在处理异步任务方面非常高效 回调函数(Callback Functions) 一开始JS使用回调的形式来处理异步的结果,但是异步的弊端很大 例如:无法更好的处理…