PyTorch从零开始实现ResNet

文章目录

    • 代码实现
    • 参考

代码实现

本文实现 ResNet原论文 Deep Residual Learning for Image Recognition 中的50层,101层和152层残差连接。
在这里插入图片描述
代码中使用基础残差块这个概念,这里的基础残差块指的是上图中红色矩形圈出的内容:从上到下分别使用3, 4, 6, 3个基础残差块,每个基础残差块由三个卷积层组成,核大小分别为1x1, 3x3, 1x1 。

残差连接的结构

在这里插入图片描述

复现代码如下:

import torch
import torch.nn as nn# 基础残差块,后面ResNet要多次重复使用该块
class block(nn.Module):def __init__(self, in_channels, out_channels, identity_downsample=None, stride=1):super(block, self).__init__()self.expansion = 4  self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)self.bn1 = nn.BatchNorm2d(out_channels)self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1)self.bn2 = nn.BatchNorm2d(out_channels)self.conv3 = nn.Conv2d(out_channels, out_channels*self.expansion, kernel_size=1, stride=1, padding=0)self.bn3 = nn.BatchNorm2d(out_channels*self.expansion)self.relu = nn.ReLU()self.identity_downsample = identity_downsampledef forward(self, x):identity = xx = self.conv1(x)x = self.bn1(x)x = self.relu(x)x = self.conv2(x)x = self.bn2(x)x = self.relu(x)x = self.conv3(x)x = self.bn3(x)# x 和 identity形状一致,才能相加if self.identity_downsample is not None:identity = self.identity_downsample(identity)x += identityx = self.relu(x)return xclass ResNet(nn.Module):def __init__(self, block, layers, image_channels, num_classes):super(ResNet, self).__init__()# 初始化的层self.in_channels = 64self.conv1 = nn.Conv2d(image_channels, 64, kernel_size=7, stride=2, padding=3)self.bn1 = nn.BatchNorm2d(64)self.relu = nn.ReLU()self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)# ResNet layersself.layer1 = self._make_layer(block, layers[0], out_channels=64, stride=1)self.layer2 = self._make_layer(block, layers[1], out_channels=128, stride=2)self.layer3 = self._make_layer(block, layers[2], out_channels=256, stride=2)self.layer4 = self._make_layer(block, layers[3], out_channels=512, stride=2)self.avgpool = nn.AdaptiveAvgPool2d((1, 1))self.fc = nn.Linear(512*4, num_classes)def forward(self, x):x = self.conv1(x)x = self.bn1(x)x = self.relu(x)x = self.maxpool(x)x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)x = self.layer4(x)x = self.avgpool(x)x = x.reshape(x.shape[0], -1)x = self.fc(x)return x# 核心函数:调用block基础残差块,构造ResNet的每一层def _make_layer(self, block, num_residual_blocks, out_channels, stride):identity_downsample = Nonelayers = []# 修改形状,使得残差连接可以相加:x + identityif stride != 1 or self.in_channels != out_channels * 4:identity_downsample = nn.Sequential(nn.Conv2d(self.in_channels, out_channels*4, kernel_size=1,stride=stride),                                               nn.BatchNorm2d(out_channels*4))layers.append(block(self.in_channels, out_channels, identity_downsample, stride))self.in_channels = out_channels * 4for i in range(num_residual_blocks - 1):layers.append(block(self.in_channels, out_channels)) # 256 -> 64, 64*4(256) againreturn nn.Sequential(*layers)# 构造ResNet50层:默认图像通道3,分类类别为1000
def resnet50(img_channels=3, num_classes=1000):return ResNet(block, [3, 4, 6, 3], img_channels, num_classes)# 构造ResNet101层  
def resnet101(img_channels=3, num_classes=1000):return ResNet(block, [3, 4, 23, 3], img_channels, num_classes)# 构造ResNet152层  
def resnet152(img_channels=3, num_classes=1000):return ResNet(block, [3, 8, 36, 3], img_channels, num_classes)# 测试输出y的形状是否满足1000类
def test():net = resnet152()x = torch.randn(2, 3, 224, 224)y = net(x)print(y.shape) # [2, 1000]test()

参考

[1] Deep Residual Learning for Image Recognition
[2] https://www.youtube.com/watch?v=DkNIBBBvcPs&list=PLhhyoLH6IjfxeoooqP9rhU3HJIAVAJ3Vz&index=19

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/43540.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

感觉和身边其他人有差距怎么办?

虽然清楚知识需要靠时间沉淀,但在看到自己做不出来的题别人会做,自己写不出的代码别人会写时还是会感到焦虑怎么办? 你是否也因为自身跟周围人的差距而产生过迷茫,这份迷茫如今是被你克服了还是仍旧让你感到困扰? 下…

LabVIEW开发最小化5G系统测试平台

LabVIEW开发最小化5G系统测试平台 由于具有大量存储能力和数据的应用程序的智能手机的激增,当前一代产品被迫提高其吞吐效率。正交频分复用由于其卓越的品质,如单抽头均衡和具有成本效益的实施,现在被广泛用作物理层技术。这些好处是以严格的…

ElasticSearch索引库、文档、RestClient操作

文章目录 一、索引库1、mapping属性2、索引库的crud 二、文档的crud三、RestClient 一、索引库 es中的索引是指相同类型的文档集合,即mysql中表的概念 映射:索引中文档字段的约束,比如名称、类型 1、mapping属性 mapping映射是对索引库中文…

公网远程连接Redis数据库详解

文章目录 1. Linux(centos8)安装redis数据库2. 配置redis数据库3. 内网穿透3.1 安装cpolar内网穿透3.2 创建隧道映射本地端口 4. 配置固定TCP端口地址4.1 保留一个固定tcp地址4.2 配置固定TCP地址4.3 使用固定的tcp地址连接 前言 洁洁的个人主页 我就问你有没有发挥&#xff0…

通讯录实现【C语言】

目录 前言 一、整体逻辑分析 二、实现步骤 1、创建菜单和多次操作问题 2、创建通讯录 3、初始化通讯录 4、添加联系人 5、显示联系人 6、删除指定联系人 ​7、查找指定联系人 8、修改联系人信息 9、排序联系人信息 三、全部源码 前言 我们上期已经详细的介绍了自定…

Java SpringBoot Vue ERP系统

系统介绍 该ERP系统基于SpringBoot框架和SaaS模式,支持多租户,专注进销存财务生产功能。主要模块有零售管理、采购管理、销售管理、仓库管理、财务管理、报表查询、系统管理等。支持预付款、收入支出、仓库调拨、组装拆卸、订单等特色功能。拥有商品库存…

ubuntu设置共享文件夹成功后却不显示找不到(已解决)

1.首先输下面命令查看是否真的设置成功共享文件夹 vmware-hgfsclient如果确实已经设置过共享文件夹将输出window下共享文件夹名字 2.确认自己已设置共享文件夹后输入下面的命令 //如果之前没有命令包则先执行sudo apt-get install open-vm-tools sudo vmhgfs-fuse .host:/ /mn…

十六、Spring Cloud Sleuth 分布式请求链路追踪

目录 一、概述1、为什么出出现这个技术?需要解决哪些问题2、是什么?3、解决 二、搭建链路监控步骤1、下载运行zipkin2、服务提供者3、服务调用者4、测试 一、概述 1、为什么出出现这个技术?需要解决哪些问题 2、是什么? 官网&am…

spss---如何使用信度分析以及案例分析

信度分析 问卷调查法是教育研究中广泛采用的一种调查方法,根据调查目的设计的调查问卷是问卷调查法获取信息的工具,其质量高低对调查结果的真实性、适用性等具有决定性的作用。 为了保证问卷具有较高的可靠性和有效性,在形成正式问卷之 前&…

CLion:最好用的c/c++编写工具(最详细安装教程)

目录 一.前言介绍 1.下载安装 1.1右上角点击下载 1.2选择自己操作系统,然后点击下载 1.3选择next 1.4 更改路径 1.5D盘最好 1.6 按照我的选择配置环境 1.7install安装 1.8 安装完成 2、mingw64安装 2.1下载资源压缩包 2.2mingw64放入到合适的位置,…

Redis五大基本数据类型及其使用场景

文章目录 **一 什么是NoSQL?****二 redis是什么?****三 redis五大基本类型**1 String(字符串)**应用场景** 2 List(列表)**应用场景** 3 Set(集合)4 sorted set(有序集合…

高级艺术二维码制作教程

最近不少关于二维码制作的,而且都是付费。大概就是一个好看的二维码,扫描后跳转网址。本篇文章使用Python来实现,这么简单花啥钱呢?学会,拿去卖便宜点吧。 文章目录 高级二维码制作环境安装普通二维码艺术二维码动态 …

【LVS】2、部署LVS-DR群集

LVS-DR数据包的流向分析 1.客户端发送请求到负载均衡器,请求的数据报文到达内核空间; 2.负载均衡服务器和正式服务器在同一个网络中,数据通过二层数据链路层来传输; 3.内核空间判断数据包的目标IP是本机VIP,此时IP虚…

SASS 学习笔记

SASS 学习笔记 总共会写两个练手项目,成品在 https://goldenaarcher.com/scss-study 可以看到,代码在 https://github.com/GoldenaArcher/scss-study。 什么是 SASS SASS 是 CSS 预处理,它提供了变量(虽然现在 CSS 也提供了&am…

C++ 面向对象三大特性——继承

✅<1>主页&#xff1a;我的代码爱吃辣 &#x1f4c3;<2>知识讲解&#xff1a;C 继承 ☂️<3>开发环境&#xff1a;Visual Studio 2022 &#x1f4ac;<4>前言&#xff1a;面向对象三大特性的&#xff0c;封装&#xff0c;继承&#xff0c;多态&#xff…

【数仓建设系列之一】什么是数据仓库?

一、什么是数据仓库&#xff1f; 数据仓库(Data Warehouse&#xff0c;简称DW)简单来讲&#xff0c;它是一个存储和管理大量结构化和非结构化数据的存储集合&#xff0c;它以主题为向导&#xff0c;通过整合来自不同数据源下的数据(比如各业务数据&#xff0c;日志文件数据等)…

MySQL- sql语句基础

文章目录 1.select后对表进行修改&#xff08;delete&#xff09;2.函数GROUP_CONCAT()3.使用正则表达式3.DATE_FORMAT()4.count() 加条件 1.select后对表进行修改&#xff08;delete&#xff09; 报错&#xff1a;You can’t specify target table ‘Person’ for update in …

proteus结合keil-arm编译器构建STM32单片机项目进行仿真

proteus是可以直接创建设计图和源码的&#xff0c;但是源码编译它需要借助keil-arm编译器&#xff0c;也就是我们安装keil-mdk之后自带的编译器。 下面给出一个完整的示例&#xff0c;主要是做一个LED灯闪烁的效果。 新建工程指定路径&#xff0c;Schematic,PCB layout都选择默…

【Docker】 使用Docker-Compose 搭建基于 WordPress 的博客网站

引 本文将使用流行的博客搭建工具 WordPress 搭建一个私人博客站点。部署过程中使用到了 Docker 、MySQL 。站点搭建完成后经行了发布文章的体验。 WordPress WordPress 是一个广泛使用的开源内容管理系统&#xff08;CMS&#xff09;&#xff0c;用于构建和管理网站、博客和…

大数据平台运维实训室建设方案

一、概况 本实训室的主要目的是培养大数据平台运维项目的实践能力,以数据计算、分析、挖掘和可视化的案例训练为辅助。同时,实训室也承担相关考评员与讲师培训考试、学生认证培训考试、社会人员认证培训考试、大数据技能大赛训练、大数据专业课程改革等多项任务。 实训室旨在培…