Pytorch框架下的CNN和RNN

1.CNN

建立了3层(3层=2层+1层全连接层)。分别是conv1、conv2和分类问题中的全连接层线性层out

class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv1 = nn.Sequential(         # input shape (1, 28, 28)nn.Conv2d(in_channels=1,              # input heightout_channels=16,            # n_filterskernel_size=5,              # filter sizestride=1,                   # filter movement/steppadding=2,                  # if want same width and length of this image after Conv2d, padding=(kernel_size-1)/2 if stride=1),                              # output shape (16, 28, 28)nn.ReLU(),                      # activationnn.MaxPool2d(kernel_size=2),    # choose max value in 2x2 area, output shape (16, 14, 14))self.conv2 = nn.Sequential(         # input shape (16, 14, 14)nn.Conv2d(16, 32, 5, 1, 2),     # output shape (32, 14, 14)nn.ReLU(),                      # activationnn.MaxPool2d(2),                # output shape (32, 7, 7))self.out = nn.Linear(32 * 7 * 7, 10)   # fully connected layer, output 10 classesdef forward(self, x):x = self.conv1(x)x = self.conv2(x)x = x.view(x.size(0), -1)           # flatten the output of conv2 to (batch_size, 32 * 7 * 7)output = self.out(x)return output, x    # return x for visualization

2.RNN

2.1RNN分类问题代码

设计了rnn层【输入(INPUT_SIZE),隐藏层1层(hidden_size)】和分类问题的全连接线性层out

class RNN(nn.Module):def __init__(self):super(RNN, self).__init__()self.rnn = nn.LSTM(     # LSTM 效果要比 nn.RNN() 好多了input_size=28,      # 图片每行的数据像素点hidden_size=64,     # rnn hidden unitnum_layers=1,       # 有几层 RNN layersbatch_first=True,   # input & output 会是以 batch size 为第一维度的特征集 e.g. (batch, time_step, input_size))self.out = nn.Linear(64, 10)    # 输出层def forward(self, x):# x shape (batch, time_step, input_size)# r_out shape (batch, time_step, output_size)# h_n shape (n_layers, batch, hidden_size)   LSTM 有两个 hidden states, h_n 是分线, h_c 是主线# h_c shape (n_layers, batch, hidden_size)r_out, (h_n, h_c) = self.rnn(x, None)   # None 表示 hidden state 会用全0的 state# 选取最后一个时间点的 r_out 输出# 这里 r_out[:, -1, :] 的值也是 h_n 的值out = self.out(r_out[:, -1, :])return outrnn = RNN()
print(rnn)
"""
RNN ((rnn): LSTM(28, 64, batch_first=True)(out): Linear (64 -> 10)
)
"""

2.2RNN回归问题代码

具体参考:https://mofanpy.com/tutorials/machine-learning/torch/RNN-regression

class RNN(nn.Module):def __init__(self):super(RNN, self).__init__()self.rnn = nn.RNN(  # 这回一个普通的 RNN 就能胜任input_size=1,hidden_size=32,     # rnn hidden unitnum_layers=1,       # 有几层 RNN layersbatch_first=True,   # input & output 会是以 batch size 为第一维度的特征集 e.g. (batch, time_step, input_size))self.out = nn.Linear(32, 1)def forward(self, x, h_state):  # 因为 hidden state 是连续的, 所以我们要一直传递这一个 state# x (batch, time_step, input_size)# h_state (n_layers, batch, hidden_size)# r_out (batch, time_step, output_size)r_out, h_state = self.rnn(x, h_state)   # h_state 也要作为 RNN 的一个输入outs = []    # 保存所有时间点的预测值for time_step in range(r_out.size(1)):    # 对每一个时间点计算 outputouts.append(self.out(r_out[:, time_step, :]))return torch.stack(outs, dim=1), h_staternn = RNN()
print(rnn)
"""
RNN ((rnn): RNN(1, 32, batch_first=True)(out): Linear (32 -> 1)
)
"""

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/6051.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

C++:输入输出运算符重载

在C中,输入输出运算符是用于从标准输入设备(通常是键盘)读取数据或将数据输出到标准输出设备(通常是屏幕)的运算符。常用的输入输出运算符包括: 输入运算符 (>>): 用于从输入流&#xff0…

逻辑漏洞:水平越权、垂直越权靶场练习

目录 1、身份认证失效漏洞实战 2、YXCMS检测数据比对弱(水平越权) 3、MINICMS权限操作无验证(垂直越权) 1、身份认证失效漏洞实战 上一篇学习了水平越权和垂直越权的相关基本知识,在本篇还是继续学习,这…

深度学习:基于Keras,使用长短期记忆人工神经网络模型(LSTM)对股票市场进行预测分析

前言 系列专栏:机器学习:高级应用与实践【项目实战100】【2024】✨︎ 在本专栏中不仅包含一些适合初学者的最新机器学习项目,每个项目都处理一组不同的问题,包括监督和无监督学习、分类、回归和聚类,而且涉及创建深度学…

Electron-Builder 打包 Vue 项目避坑指南

最近在开发一个基于 Vue 的 Electron 项目,在打包时遇到了诸多问题,为了解决这些问题也查阅了非常多的资料,排除了很多坑。现在将可能遇到的问题整理成避坑指南,供大家参考(此避坑指南后续还会继续更新)。 …

文章解读与仿真程序复现思路——电力自动化设备EI\CSCD\北大核心《计及高阶方程分段线性化的港口电-氢综合能源系统优化调度》

本专栏栏目提供文章与程序复现思路,具体已有的论文与论文源程序可翻阅本博主免费的专栏栏目《论文与完整程序》 论文与完整源程序_电网论文源程序的博客-CSDN博客https://blog.csdn.net/liang674027206/category_12531414.html 电网论文源程序-CSDN博客电网论文源…

Django运行不提示网址问题

问题描述:运行django项目不提示网址信息,也就是web没有起来,无法访问。 (my-venv-3.8) PS D:\Project\MyGitCode\public\it_blog\blog> python .\manage.py runserver INFO autoreload 636 Watching for file changes with StatReloader …

clang:在 Win10 上编译 MIDI 音乐程序

先从 Microsoft C Build Tools - Visual Studio 下载 1.73GB 安装 "Microsoft C Build Tools“ 访问 Swift.org - Download Swift 找到 Windows 10:x86_64 下载 swift-5.10-RELEASE-windows10.exe 大约490MB 建议安装在 D:\Swift\ ,安装后大约占…

SQL 基础 | UNION 用法介绍

在SQL中,UNION操作符用于合并两个或多个SELECT语句的结果集,形成一个新的结果集。 使用UNION时,合并的结果集列数必须相同,并且列的数据类型也需要兼容。 默认情况下,UNION会去除重复的行,只保留唯一的行。…

Flutter笔记:使用Flutter私有类涉及的授权协议问题

Flutter笔记 使用Flutter私有类涉及的授权协议问题 - 文章信息 - Author: 李俊才 (jcLee95) Visit me at CSDN: https://jclee95.blog.csdn.netMy WebSite:http://thispage.tech/Email: 291148484163.com. Shenzhen ChinaAddress of this article:https://blog.cs…

【跟马少平老师学AI】-【神经网络是怎么实现的】(七-1)词向量

一句话归纳: 1)神经网络不仅可以处理图像,还可以处理文本。 2)神经网络处理文本,先要解决文本的表示(图像的表示用像素RGB)。 3)独热编码词向量: 词表:{我&am…

ensp 配置s5700 ssh登陆

#核心配置 sys undo info-center enable sysname sw1 vlan 99 stelnet server enable telnet server enable int g 0/0/1 port lin acc port de vlan 99 q user-interface vty 0 4 protocol inbound ssh authentication-mode aaa q aaa local-user admin0 password cipher adm…

Java集合框架-容器源码分析

Java集合框架-容器&源码分析 文章目录 Java集合框架-容器&源码分析[TOC](文章目录)前言一、集合框架概述二、Collection接口及其子接口(List/Set)及实现类2.1 Collection接口中方法2.2 遍历:Iterator迭代器接口&foreach(5.0新特性)2.3 Connection子接口…

SQL 基础 | AS 的用法介绍

SQL(Structured Query Language)是一种用于管理和操作数据库的标准编程语言。 在SQL中,AS关键字有几种不同的用法,主要用于重命名表、列或者查询结果。 以下是AS的一些常见用法: 重命名列:在SELECT语句中&a…

C++基础—模版

C模板是C语言中实现泛型编程的核心机制,它允许程序员定义通用的代码框架,这些框架在编译时可以根据提供的具体类型参数生成相应的特定类型实例。 泛型编程的特点代码复用和安全性! 模板主要分为两大类:函数模板和类模板。 函数模板 基本语…

C++深度解析教程笔记7

C深度解析教程笔记7 第13课 - 进阶面向对象(上)类和对象小结 第14课 - 进阶面向对象(下)类之间的基本关系继承组合 类的表示法实验-类的继承 第15课 - 类与封装的概念实验-定义访问级别cmd 实验小结 第16课 - 类的真正形态实验-st…

Web,Sip,Rtsp,Rtmp,WebRtc,专业MCU融屏视频混流会议直播方案分析

随着万物互联,视频会议直播互动深入业务各方面,主流SFU并不适合管理,很多业务需要各种监控终端,互动SIP硬件设备,Web在线业务平台能相互融合,互联互通, 视频混流直播,录存直播推广&a…

vue3+vite项目中,图片显示为src=“[object Object]“

查了半天&#xff0c;网上都是教人改webpack配置&#xff08;很无语……&#xff09; 解决方法&#xff1a; 在原图片&#xff1b;路径后面加上?url // example <img src"/assets/imgs/stop.svg?url" alt"" />

c++ 筛选裁决文书 1985-2021的数据 分析算法的差异

c cpp 并行计算筛选过滤 裁决文书网1985-2021 的300g数据 数据 数据解压以后大概300g&#xff0c;最开始是使用python代码进行计算&#xff0c;但是python实在太慢了&#xff0c;加上多进程也不行&#xff0c; 于是 使用c 进行 计算 c这块最开始使用的是 i7-9700h 用的是单线…

【翻译】Elasticsearch-索引模块

索引块限制对指定的索引的可用的操作类型&#xff08;就是指对该索引能进行什么操作&#xff09;。这些块有不同的风格&#xff0c;可以阻止写、读或元数据操作。块可以通过动态索引设置来设置/移除&#xff0c;也可以通过专用API添加&#xff0c;这也可以确保写入块一旦成功返…

基于Spring Boot的心灵治愈交流平台设计与实现

基于Spring Boot的心灵治愈交流平台设计与实现 开发语言&#xff1a;Java框架&#xff1a;springbootJDK版本&#xff1a;JDK1.8数据库工具&#xff1a;Navicat11开发软件&#xff1a;eclipse/myeclipse/idea 系统部分展示 系统功能界面图&#xff0c;在系统首页可以查看首页…