PyTorch 实现动态输入

使用 PyTorch 实现动态输入:支持训练和推理输入维度不一致的 CNN 和 LSTM/GRU 模型

在深度学习中,处理不同大小的输入数据是一个常见的挑战。许多实际应用需要模型能够灵活地处理可变长度的输入。本文将介绍如何使用 PyTorch 实现支持动态输入的 CNN 和 LSTM/GRU 模型,并打印每一层的输入和输出。

  • 卷积神经网络(CNN):CNN 通常用于处理图像数据。它通过卷积层提取局部特征,并能够处理不同大小的输入图像。通过使用全局池化层,CNN 可以将不同大小的特征图转换为固定大小的输出。

  • 长短期记忆网络(LSTM)和门控循环单元(GRU):LSTM 和 GRU 是处理序列数据的 RNN 变体。它们能够捕捉时间序列中的长期依赖关系,并支持可变长度的输入序列。

模型搭建

1. CNN 模型

我们将构建一个简单的 CNN 模型,支持动态输入大小,并打印每一层的输入和输出。

import torch
import torch.nn as nn
import torch.nn.functional as Fclass DynamicCNN(nn.Module):def __init__(self):super(DynamicCNN, self).__init__()self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3)self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3)self.pool = nn.AdaptiveAvgPool2d((1, 1))  # 自适应池化层self.fc = nn.Linear(32, 10)  # 输出10个类别def forward(self, x):print(f'Input to CNN: {x.shape}')x = F.relu(self.conv1(x))print(f'Output after conv1: {x.shape}')x = F.relu(self.conv2(x))print(f'Output after conv2: {x.shape}')x = self.pool(x)print(f'Output after pooling: {x.shape}')x = x.view(x.size(0), -1)  # 展平x = self.fc(x)print(f'Output after fc: {x.shape}')return x# 创建模型
cnn_model = DynamicCNN()# 测试动态输入
input_tensor_cnn = torch.randn(1, 3, 64, 64)  # 输入形状为 (batch_size, channels, height, width)
output_cnn = cnn_model(input_tensor_cnn)
Input to CNN: torch.Size([1, 3, 55, 64])
Output after conv1: torch.Size([1, 16, 53, 62])
Output after conv2: torch.Size([1, 32, 51, 60])
Output after pooling: torch.Size([1, 32, 1, 1])
Output after fc: torch.Size([1, 10])
Input to CNN: torch.Size([1, 3, 64, 64])
Output after conv1: torch.Size([1, 16, 62, 62])
Output after conv2: torch.Size([1, 32, 60, 60])
Output after pooling: torch.Size([1, 32, 1, 1])
Output after fc: torch.Size([1, 10])

2. LSTM/GRU 模型

接下来,我们将构建一个支持动态输入的 LSTM 模型,并打印每一层的输入和输出。

import torch
import torch.nn as nnclass DynamicLSTM(nn.Module):def __init__(self):super(DynamicLSTM, self).__init__()self.lstm = nn.LSTM(input_size=10, hidden_size=20, batch_first=True)self.fc = nn.Linear(20, 1)  # 输出一个值def forward(self, x):print(f'Input to LSTM: {x.shape}')x, _ = self.lstm(x)print(f'Output after LSTM: {x.shape}')x = self.fc(x[:, -1, :])  # 取最后一个时间步的输出print(f'Output after fc: {x.shape}')return x# 创建模型
lstm_model = DynamicLSTM()# 测试动态输入
input_tensor_lstm = torch.randn(5, 15, 10)  # 输入形状为 (batch_size, seq_length, input_size)
output_lstm = lstm_model(input_tensor_lstm)
Input to LSTM: torch.Size([5, 15, 10])
Output after LSTM: torch.Size([5, 15, 20])
Output after fc: torch.Size([5, 1])
Input to LSTM: torch.Size([5, 20, 10])
Output after LSTM: torch.Size([5, 20, 20])
Output after fc: torch.Size([5, 1])

代码说明

  1. DynamicCNN:该模型包含两个卷积层和一个全连接层。使用自适应平均池化层将特征图的大小调整为 (1, 1),从而支持不同大小的输入图像。每一层的输入和输出形状在前向传播中被打印出来。

  2. DynamicLSTM:该模型包含一个 LSTM 层和一个全连接层。LSTM 层能够处理可变长度的输入序列,输出的形状在前向传播中被打印出来。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/62903.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

了解Linux —— 理解其中的权限

前言 在了解Linux权限之前,先来探讨我们使用的shell 命令它到底是什么? Linux 是一个操作系统,我们称其为内核(kernel) ,正常情况下,我们一般用户操作并不是去直接使用内核,而是通过kernel 的外壳程序&…

WebHID API演示Demo教程:设备列表,设备连接,数据读写

1. 简介 WebHID API允许网页应用直接与HID(人机接口设备)进行通信。本教程将演示如何创建一个基础的WebHID应用,实现以下功能: 显示和获取HID设备列表连接/断开HID设备读取设备数据向设备发送数据 2. 兼容性和前提条件 2.1 浏览…

中安证件OCR识别技术助力鸿蒙生态:智能化证件识别新体验

在数字化和智能化的浪潮中,伴随国产化战略的深入推进,国产操作系统和软件生态的建设逐渐走向成熟。鸿蒙操作系统(HarmonyOS Next)作为华为推出的重要操作系统,凭借其开放、灵活和高效的特点,正在加速在多个…

PDF与PDF/A的区别及如何使用Python实现它们之间的相互转换

目录 概述 PDF/A 是什么?与 PDF 有何不同? 用于实现 PDF 与 PDF/A 相互转换的 Python 库 Python 实现 PDF 转 PDF/A 将 PDF 转换为 PDF/A-1a 将 PDF 转换为 PDF/A-1b 将 PDF 转换为 PDF/A-2a 将 PDF 转换为 PDF/A-2b 将 PDF 转换为 PDF/A-3a 将…

k8s Quality of Service

文章目录 QoS 分类规则QoS 类别影响创建 QoS 分类的案例1. Guaranteed QoS 示例示例 YAML 文件: 2. Burstable QoS 示例示例 YAML 文件: 3. BestEffort QoS 示例示例 YAML 文件: 4. 混合 QoS 示例(多个容器)示例 YAML …

面向对象(二)——类和对象(上)

1 类的定义 做了关于对象的很多介绍,终于进入代码编写阶段。 本节中重点介绍类和对象的基本定义,属性和方法的基本使用方式。 【示例】类的定义方式 // 每一个源文件必须有且只有一个public class,并且类名和文件名保持一致! …

【HarmonyOS】鸿蒙应用地理位置获取,地理名称获取

【HarmonyOS】鸿蒙应用地理位置获取,地理名称获取 一、前言 首先要理解地理专有名词,当我们从系统获取地理位置,一般会拿到地理坐标,是一串数字,并不是地理位置名称。例如 116.2305,33.568。 这些数字坐…

关于数据库数据国际化方案

方案一:每个表设计一个翻译表 数据库国际化的应用场景用到的比较少,主要用于对数据库的具体数据进行翻译,在需要有大量数据翻译的场景下使用,举个例子来说,力扣题目的中英文切换。参考方案可见: https://b…

「Mac畅玩鸿蒙与硬件37」UI互动应用篇14 - 随机颜色变化器

本篇将带你实现一个随机颜色变化器应用。用户点击“随机颜色”按钮后,界面背景会随机变化为淡色系颜色,同时显示当前的颜色代码,页面还会展示一只猫咪图片作为装饰,提升趣味性。 关键词 UI互动应用随机颜色生成状态管理用户交互…

跟着官方文档快速入门RAGAS

官网: Ragas Ragas(Retrieval-Augmented Generation, RAG)是一个基于简单手写提示的评估框架,通过这些提示全自动地衡量答案的准确性、 相关性和上下文相关性。这种评估方法不需要访问人工注释的数据集或参考答案,使得评估过程更…

掌握 Spring Boot 中的缓存:技术和最佳实践

缓存是一种用于将经常访问的数据临时存储在更快的存储层(通常在内存中)中的技术,以便可以更快地满足未来对该数据的请求,从而提高应用程序的性能和效率。在 Spring Boot 中,缓存是一种简单而强大的方法,可以…

我谈冈萨雷斯对频域滤波的误解——快速卷积与频域滤波之间的关系

在Rafael Gonzalez和Richard Woods所著的《数字图像处理》中,Gonzalez对频域滤波是有误解的,在频域设计滤波器不是非得图像和滤波器的尺寸相同,不是非得在频域通过乘积实现。相反,FIR滤波器设计都是构造空域脉冲响应。一般的原则是…

Hello World C#

using System; using System.Collections.Generic; using System.Linq; using System.Text; using System.Threading.Tasks; using System; 引入了System命名空间,基本输入输出。一般只用这个,后面的不用 using System.Collections.Generic; 包含了定…

AI高中数学教学视频生成技术:利用通义千问、MathGPT、视频多模态大模型,语音大模型,将4个模型融合 ,生成高中数学教学视频,并给出实施方案。

大家好,我是微学AI,今天给大家介绍一下AI高中数学教学视频生成技术:利用通义千问、MathGPT、视频多模态大模型,语音大模型,将4个模型融合 ,生成高中数学教学视频,并给出实施方案。本文利用专家模…

Linux下,用ufw实现端口关闭、流量控制(二)

本文是 网安小白的端口关闭实践 的续篇。 海量报文,一手掌握,你值得拥有,让我们开始吧~ ufw 与 iptables的关系 理论介绍: ufw(Uncomplicated Firewall)是一个基于iptables的前端工具&#xf…

@staticmethod、@classmethod

staticmethod 静态方法:staticmethod将一个普通函数嵌入到类中,使其成为类的静态方法。静态方法不需要一个类实例即可被调用,同时它也不需要访问类实例的状态。参数:静态方法可以接受任何参数,但通常不使用self或cls作…

MySQL之数据完整性

数据的完整性约束可以分为三类: 实体完整性、域完整性和引用完整性。 说来说去(说主键,外键,以及⼀些约束) 1、实体完整性 (实体就是行) 什么是关系型数据库? 一个表代表一类事务&#xff0…

echarts的双X轴,父级居中的相关配置

前言:折腾了一个星期,在最后一天中午,都快要放弃了,后来坚持下来,才有下面结果。 这个效果就相当是复合表头,第一行是子级,第二行是父级。 子级是奇数个时,父级label居中很简单&…

配置宝塔php curl 支持http/2 发送苹果apns消息推送

由于宝塔面板默认的php编译的curl未加入http2的支持,如果服务需要使用apns推送等需要http2.0的访问就会失败,所以重新编译php让其支持http2.0 编译方法: 一、安装nghttp2 git clone https://github.com/tatsuhiro-t/nghttp2.git cd nghttp…

记录一次网关异常

记一次网关异常 网关时不时就会出现下面的异常。关键是不知道什么时候就会报错,并且有时候就算什么都不操作,也会导致这个异常。 ERROR org.springframework.scheduling.support.TaskUtils$LoggingErrorHandler - Unexpected error occurred in schedul…