动手学深度学习(Pytorch版)代码实践 -卷积神经网络-29残差网络ResNet

29残差网络ResNet

在这里插入图片描述

import torch  
from torch import nn  
from torch.nn import functional as F 
import liliPytorch as lp  
import matplotlib.pyplot as plt# 定义一个继承自nn.Module的残差块类
class Residual(nn.Module):def __init__(self, input_channels, num_channels, use_1x1conv=False, strides=1):super().__init__()# 第一个卷积层,使用3x3的卷积核,填充为1,步幅为指定值self.conv1 = nn.Conv2d(input_channels, num_channels, kernel_size=3, padding=1, stride=strides)# 第二个卷积层,使用3x3的卷积核,填充为1self.conv2 = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1)# 可选的1x1卷积层,用于匹配输入输出通道数和步幅if use_1x1conv:self.conv3 = nn.Conv2d(input_channels, num_channels, kernel_size=1, stride=strides)else:self.conv3 = None# 批量归一化层self.bn1 = nn.BatchNorm2d(num_channels)self.bn2 = nn.BatchNorm2d(num_channels)# 为什么需要两个不同的批量归一化层?# 1.不同的位置,不同的输入特征# 2.独立的参数和统计数据def forward(self, X):# 先通过第一个卷积层、批量归一化层和ReLU激活函数Y = F.relu(self.bn1(self.conv1(X)))# 然后通过第二个卷积层和批量归一化层Y = self.bn2(self.conv2(Y))# 如果定义了conv3,则通过conv3调整Xif self.conv3:X = self.conv3(X)# 将输入X加到输出Y上实现残差连接Y += X# 通过ReLU激活函数return F.relu(Y)# 创建一个包含输入和输出形状一致的残差块实例,并测试其输出形状
# blk = Residual(3, 3)
# X = torch.rand(4, 3, 6, 6)
# Y = blk(X)
# print(Y.shape)  # 预期输出形状:torch.Size([4, 3, 6, 6])# 创建一个包含1x1卷积和步幅为2的残差块实例,并测试其输出形状
# blk = Residual(3, 6, use_1x1conv=True, strides=2)
# print(blk(X).shape)  # 预期输出形状:torch.Size([4, 6, 3, 3])# 定义一个包含初始卷积层、批量归一化层、ReLU激活函数和最大池化层的顺序容器
b1 = nn.Sequential(nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),nn.BatchNorm2d(64),nn.ReLU(),nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
)# 定义一个函数,用于创建由多个残差块组成的模块
def resnet_block(input_channels, num_channels, num_residuals, first_block=False):blk = []for i in range(num_residuals):# 如果是第一个残差块且不是第一个模块,则使用1x1卷积和步幅为2if i == 0 and not first_block:blk.append(Residual(input_channels, num_channels, use_1x1conv=True, strides=2))else:blk.append(Residual(num_channels, num_channels))return blk# 创建由残差块组成的各个模块
# *符号有多种用途,但在函数调用时,*符号主要用于将列表或元组解包。
# *resnet_block()的作用是将列表中的元素逐个传递给nn.Sequential
b2 = nn.Sequential(*resnet_block(64, 64, 2, first_block=True))
b3 = nn.Sequential(*resnet_block(64, 128, 2))
b4 = nn.Sequential(*resnet_block(128, 256, 2))
b5 = nn.Sequential(*resnet_block(256, 512, 2))# 创建整个ResNet模型
net = nn.Sequential(b1, b2, b3, b4, b5,nn.AdaptiveAvgPool2d((1, 1)),  # 自适应平均池化层nn.Flatten(),  # 展平层nn.Linear(512, 176)  # 全连接层,输出10类
)# 测试整个网络的输出形状
X = torch.rand(size=(1, 1, 96, 96))
for layer in net:X = layer(X)print(layer.__class__.__name__, 'output shape:\t', X.shape)
# Sequential output shape:         torch.Size([1, 64, 24, 24])
# Sequential output shape:         torch.Size([1, 64, 24, 24])
# Sequential output shape:         torch.Size([1, 128, 12, 12])
# Sequential output shape:         torch.Size([1, 256, 6, 6])
# Sequential output shape:         torch.Size([1, 512, 3, 3])
# AdaptiveAvgPool2d output shape:  torch.Size([1, 512, 1, 1])
# Flatten output shape:    torch.Size([1, 512])
# Linear output shape:     torch.Size([1, 10])# 设置训练参数
lr, num_epochs, batch_size = 0.05, 10, 256
# 加载训练和测试数据
train_iter, test_iter = lp.loda_data_fashion_mnist(batch_size, resize=96)
# 训练模型
lp.train_ch6(net, train_iter, test_iter, num_epochs, lr, lp.try_gpu())
# 显示训练结果
plt.show()# loss 0.009, train acc 0.998, test acc 0.920
# 2306.3 examples/sec on cuda:0

运行结果:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/32948.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

技术性屏蔽百度爬虫已经一周了!

很久前明月就发现百度爬虫只抓取、只收录就是不给流量了,加上百度搜索体验越来越差,反正明月已经很久没有用过百度搜索,目前使用的浏览器几乎默认搜索都已经修改成其他搜索引擎了,真要搜索什么,一般都是必应谷歌结合着…

uni-CMS:全端开源内容管理系统的技术探索

摘要 本文介绍了uni-CMS,一个基于uniCloud开发的开源内容管理系统(CMS)。该系统旨在帮助开发者快速搭建并管理内容丰富的网站、小程序和移动应用。通过其全端渲染、内容安全检测、广告解锁付费内容以及AI生成文章等特性,uni-CMS不…

cesium 添加 Echarts 饼图

cesium 添加 Echarts 饼图 1、实现思路 1、首先创建echarts饼图,拿到创建好的canvas 2、用echarts里面生成的canvas添加到cesium billboard中 2、示例代码 <!DOCTYPE html> <html lang="en"><head><

微信小程序毕业设计-在线厨艺平台系统项目开发实战(附源码+论文)

大家好&#xff01;我是程序猿老A&#xff0c;感谢您阅读本文&#xff0c;欢迎一键三连哦。 &#x1f49e;当前专栏&#xff1a;微信小程序毕业设计 精彩专栏推荐&#x1f447;&#x1f3fb;&#x1f447;&#x1f3fb;&#x1f447;&#x1f3fb; &#x1f380; Python毕业设计…

微软搁置水下数据中心项目——项目纳蒂克相比陆地服务器故障更少

“我的团队努力了&#xff0c;并且成功了&#xff0c;”COI负责人诺埃尔沃尔什说。 微软已悄然终止了始于2013年的水下数据中心&#xff08;UDC&#xff09;项目“纳蒂克”。该公司向DatacenterDynamics确认了这一消息&#xff0c;微软云运营与创新部门负责人诺埃尔沃尔什表示…

STM32 Customer BootLoader 刷新项目 (二) 方案介绍

STM32 Customer BootLoader 刷新项目 (二) 方案介绍 文章目录 STM32 Customer BootLoader 刷新项目 (二) 方案介绍1. 需求分析2. STM32 Memery介绍3. BootLoader方案介绍4. 支持指令 1. 需求分析 首先在开始编程之前&#xff0c;我们先详细设计一下BootLoder的方案。 本项目做…

openjudge_2.5基本算法之搜索_917:Knight Moves

题目 917:Knight Moves 总时间限制: 1000ms 内存限制: 65536kB 描述 Background Mr Somurolov, fabulous chess-gamer indeed, asserts that no one else but him can move knights from one position to another so fast. Can you beat him? The Problem Your task is to wr…

python3GUI--ktv点歌软件By:PyQt5(附下载地址)

文章目录 一&#xff0e;前言二&#xff0e;展示1.启动2.搜索2.服务1.首页2.天气预报3.酒水饮料4.酒水饮料2 3.服务4.灯光5.调音6.排行榜7.分类点歌9.歌手点歌10.歌手个人页 三&#xff0e;心得体会1.关于代码2.关于设计3.关于打包 四&#xff0e;总结 文件大小&#xff1a;33.…

绘制口罩maskTheFace数据源是300w_lp

官网下载mask the face 代码&#xff0c;增加代码draw_face.py import argparse import cv2 import scipy.io from tqdm import tqdm from utils.aux_functions_2 import *# 设置命令行输入参数 parser argparse.ArgumentParser(description"MaskTheFace - Python code…

JavaScript之类(1)

class基础语法结构&#xff1a; 代码&#xff1a; class MyClass {constructor() { ... }method1() { ... }method2() { ... }method3() { ... }... } 解释&#xff1a; 属性解释class是我们定义的类型(类)MyClass是我们定义的类的名称 constructor()我们可以在其中初始化对象m…

微软Edge浏览器全解析

微软Edge浏览器全解析(一) 解决浏览器的主页被篡改后无法通过浏览器的自带设置来恢复的问题 相信各位都有发现新买的联想电脑浏览器的主页设置不太满意,但从浏览器自带的设置上又无法解决此问题,网上找了许多方法都无济于事,特别对有着强迫症的小伙伴们更是一种煎熬。 通…

CVE-2023-50563(sql延时注入)

简介 SEMCMS是一套支持多种语言的外贸网站内容管理系统&#xff08;CMS&#xff09;。SEMCMS v4.8版本存在SQLI&#xff0c;该漏洞源于SEMCMS_Function.php 中的 AID 参数包含 SQL 注入 过程 打开靶场 目录扫描&#xff0c;发现安装install目录&#xff0c;进入&#xff0c;…

免费一年SSL证书申请——建议收藏

免费一年SSL证书申请——建议收藏 获取免费一年期SSL证书其实挺简单的 准备你的网站&#xff1a; 确保你的网站已经有了域名&#xff0c;而且这个域名已经指向你的服务器。还要检查你的服务器支持HTTPS&#xff0c;也就是443端口要打开&#xff0c;这是HTTPS默认用的。 验证域…

开源技术:在线教育系统源码及教育培训APP开发指南

本篇文章&#xff0c;小编将探讨如何利用开源技术开发在线教育系统及教育培训APP&#xff0c;旨在为有志于此的开发者提供全面的指导和实践建议。 一、在线教育系统的基本构架 1.1架构设计 包括前端、后端和数据库三个主要部分。 1.2前端技术 在前端开发中&#xff0c;HTML…

[实践篇]13.29 再来聊下Pass Through设备透传

写在前面 为什么要再聊天Pass Through? 因为在QNX + Linux Android的技术方案下,我们会遇到LA发生reboot或异常panic后,无法正常开机。而再次异常的原因确实最头疼的Memory Corruption。观察下来是由于一些DMA外设如使用UART的一些设备在重启或panic后,没有正常走Shutdow…

使用Inno Setup 5.5制作软件安装包-精品(二)

上一篇 使用Inno Setup 6制作软件安装包&#xff08;一&#xff09;-CSDN博客 文章简单的说了一下使用Inno Setup 6制作软件安装包&#xff0c;具体有很多的细节&#xff0c;都可以参考上篇的案例。本节说一下&#xff0c;Inno Setup 5 增强版制作软件精品安装包&#xff0c;…

如何搭建饥荒服务器

《饥荒》是由Klei Entertainment开发的一款动作冒险类求生游戏&#xff0c;于2013年4月23日在PC上发行&#xff0c;2015年7月9日在iOS发布口袋版。游戏讲述的是关于一名科学家被恶魔传送到了一个神秘的世界&#xff0c;玩家将在这个异世界生存并逃出这个异世界的故事。《饥荒》…

力扣SQL50 销售分析III having + 条件计数

Problem: 1084. 销售分析III &#x1f468;‍&#x1f3eb; 参考题解 Code select s.product_id,p.product_name from sales s left join product p on s.product_id p.product_id group by product_id having count(if(sale_date between 2019-01-01 and 2019-03-31,1,nu…

【SpringBoot Actuator】⭐️Actuator 依赖实现服务健康检查,线程信息收集

目录 &#x1f378;前言 &#x1f37b;一、Actuator 了解 &#x1f37a;二、使用 2.1 依赖引入 2.2 测试场景搭建 &#x1f379;三、测试 3.1 项目启动测试 3.2 服务健康检查 3.3 线程转储 3.4 内存使用&#xff0c;垃圾回收信息获取 &#x1f49e;️四、章末 &#x1…

MySQL的自增 ID 用完了,怎么办?

MySQL 自增 ID 一般用的数据类型是 INT 或 BIGINT&#xff0c;正常情况下这两种类型可以满足大多数应用的需求。 当然也有不正常的情况&#xff0c;当达到其最大值时&#xff0c;尝试插入新的记录会导致错误&#xff0c;错误信息类似于&#xff1a; ERROR 167 (22003): Out o…