PyTorch从零开始实现ResNet

文章目录

    • 代码实现
    • 参考

代码实现

本文实现 ResNet原论文 Deep Residual Learning for Image Recognition 中的50层,101层和152层残差连接。
在这里插入图片描述
代码中使用基础残差块这个概念,这里的基础残差块指的是上图中红色矩形圈出的内容:从上到下分别使用3, 4, 6, 3个基础残差块,每个基础残差块由三个卷积层组成,核大小分别为1x1, 3x3, 1x1 。

残差连接的结构

在这里插入图片描述

复现代码如下:

import torch
import torch.nn as nn# 基础残差块,后面ResNet要多次重复使用该块
class block(nn.Module):def __init__(self, in_channels, out_channels, identity_downsample=None, stride=1):super(block, self).__init__()self.expansion = 4  self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)self.bn1 = nn.BatchNorm2d(out_channels)self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1)self.bn2 = nn.BatchNorm2d(out_channels)self.conv3 = nn.Conv2d(out_channels, out_channels*self.expansion, kernel_size=1, stride=1, padding=0)self.bn3 = nn.BatchNorm2d(out_channels*self.expansion)self.relu = nn.ReLU()self.identity_downsample = identity_downsampledef forward(self, x):identity = xx = self.conv1(x)x = self.bn1(x)x = self.relu(x)x = self.conv2(x)x = self.bn2(x)x = self.relu(x)x = self.conv3(x)x = self.bn3(x)# x 和 identity形状一致,才能相加if self.identity_downsample is not None:identity = self.identity_downsample(identity)x += identityx = self.relu(x)return xclass ResNet(nn.Module):def __init__(self, block, layers, image_channels, num_classes):super(ResNet, self).__init__()# 初始化的层self.in_channels = 64self.conv1 = nn.Conv2d(image_channels, 64, kernel_size=7, stride=2, padding=3)self.bn1 = nn.BatchNorm2d(64)self.relu = nn.ReLU()self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)# ResNet layersself.layer1 = self._make_layer(block, layers[0], out_channels=64, stride=1)self.layer2 = self._make_layer(block, layers[1], out_channels=128, stride=2)self.layer3 = self._make_layer(block, layers[2], out_channels=256, stride=2)self.layer4 = self._make_layer(block, layers[3], out_channels=512, stride=2)self.avgpool = nn.AdaptiveAvgPool2d((1, 1))self.fc = nn.Linear(512*4, num_classes)def forward(self, x):x = self.conv1(x)x = self.bn1(x)x = self.relu(x)x = self.maxpool(x)x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)x = self.layer4(x)x = self.avgpool(x)x = x.reshape(x.shape[0], -1)x = self.fc(x)return x# 核心函数:调用block基础残差块,构造ResNet的每一层def _make_layer(self, block, num_residual_blocks, out_channels, stride):identity_downsample = Nonelayers = []# 修改形状,使得残差连接可以相加:x + identityif stride != 1 or self.in_channels != out_channels * 4:identity_downsample = nn.Sequential(nn.Conv2d(self.in_channels, out_channels*4, kernel_size=1,stride=stride),                                               nn.BatchNorm2d(out_channels*4))layers.append(block(self.in_channels, out_channels, identity_downsample, stride))self.in_channels = out_channels * 4for i in range(num_residual_blocks - 1):layers.append(block(self.in_channels, out_channels)) # 256 -> 64, 64*4(256) againreturn nn.Sequential(*layers)# 构造ResNet50层:默认图像通道3,分类类别为1000
def resnet50(img_channels=3, num_classes=1000):return ResNet(block, [3, 4, 6, 3], img_channels, num_classes)# 构造ResNet101层  
def resnet101(img_channels=3, num_classes=1000):return ResNet(block, [3, 4, 23, 3], img_channels, num_classes)# 构造ResNet152层  
def resnet152(img_channels=3, num_classes=1000):return ResNet(block, [3, 8, 36, 3], img_channels, num_classes)# 测试输出y的形状是否满足1000类
def test():net = resnet152()x = torch.randn(2, 3, 224, 224)y = net(x)print(y.shape) # [2, 1000]test()

参考

[1] Deep Residual Learning for Image Recognition
[2] https://www.youtube.com/watch?v=DkNIBBBvcPs&list=PLhhyoLH6IjfxeoooqP9rhU3HJIAVAJ3Vz&index=19

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/43540.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

感觉和身边其他人有差距怎么办?

虽然清楚知识需要靠时间沉淀,但在看到自己做不出来的题别人会做,自己写不出的代码别人会写时还是会感到焦虑怎么办? 你是否也因为自身跟周围人的差距而产生过迷茫,这份迷茫如今是被你克服了还是仍旧让你感到困扰? 下…

LabVIEW开发最小化5G系统测试平台

LabVIEW开发最小化5G系统测试平台 由于具有大量存储能力和数据的应用程序的智能手机的激增,当前一代产品被迫提高其吞吐效率。正交频分复用由于其卓越的品质,如单抽头均衡和具有成本效益的实施,现在被广泛用作物理层技术。这些好处是以严格的…

ElasticSearch索引库、文档、RestClient操作

文章目录 一、索引库1、mapping属性2、索引库的crud 二、文档的crud三、RestClient 一、索引库 es中的索引是指相同类型的文档集合,即mysql中表的概念 映射:索引中文档字段的约束,比如名称、类型 1、mapping属性 mapping映射是对索引库中文…

Elasticsearch在部署时,对Linux的设置有哪些优化方法?

部署Elasticsearch时,可以通过优化Linux系统的设置来提升性能和稳定性。以下是一些常见的优化方法: 1.文件描述符限制 Elasticsearch需要大量的文件描述符来处理数据和连接,所以确保调整系统的文件描述符限制。可以通过修改 /etc/security/…

Docker-compose搭建Git私服

1. 新建个专用的目录,然后在里面新建个docker-compose.yml文件: (gitlab-ce是社区版,当然还有ee,是商业版) version: 3.6 …

es自定义分词器支持数字字母分词,中文分词器jieba支持添加禁用词和扩展词典

自定义分析器,分词器 PUT http://xxx.xxx.xxx.xxx:9200/test_index/ {"settings": {"analysis": {"analyzer": {"char_test_analyzer": {"tokenizer": "char_test_tokenizer","filter": [&…

公网远程连接Redis数据库详解

文章目录 1. Linux(centos8)安装redis数据库2. 配置redis数据库3. 内网穿透3.1 安装cpolar内网穿透3.2 创建隧道映射本地端口 4. 配置固定TCP端口地址4.1 保留一个固定tcp地址4.2 配置固定TCP地址4.3 使用固定的tcp地址连接 前言 洁洁的个人主页 我就问你有没有发挥&#xff0…

ssh免密登陆报错ERROR: @ WARNING: REMOTE HOST IDENTIFICATION HAS CHANGED!

问题描述: 在日常的运维中需要做ssh的免密登陆有提示如下的报错内容: [rootpaas-harbor01 cce-v5.2.3]# ssh-copy-id 192.45.66.14 /usr/bin/ssh-copy-id: INFO: Source of key(s) to be installed: "/root/.ssh/id_rsa.pub" /usr/bin/ssh-c…

通讯录实现【C语言】

目录 前言 一、整体逻辑分析 二、实现步骤 1、创建菜单和多次操作问题 2、创建通讯录 3、初始化通讯录 4、添加联系人 5、显示联系人 6、删除指定联系人 ​7、查找指定联系人 8、修改联系人信息 9、排序联系人信息 三、全部源码 前言 我们上期已经详细的介绍了自定…

Java SpringBoot Vue ERP系统

系统介绍 该ERP系统基于SpringBoot框架和SaaS模式,支持多租户,专注进销存财务生产功能。主要模块有零售管理、采购管理、销售管理、仓库管理、财务管理、报表查询、系统管理等。支持预付款、收入支出、仓库调拨、组装拆卸、订单等特色功能。拥有商品库存…

ubuntu设置共享文件夹成功后却不显示找不到(已解决)

1.首先输下面命令查看是否真的设置成功共享文件夹 vmware-hgfsclient如果确实已经设置过共享文件夹将输出window下共享文件夹名字 2.确认自己已设置共享文件夹后输入下面的命令 //如果之前没有命令包则先执行sudo apt-get install open-vm-tools sudo vmhgfs-fuse .host:/ /mn…

十六、Spring Cloud Sleuth 分布式请求链路追踪

目录 一、概述1、为什么出出现这个技术?需要解决哪些问题2、是什么?3、解决 二、搭建链路监控步骤1、下载运行zipkin2、服务提供者3、服务调用者4、测试 一、概述 1、为什么出出现这个技术?需要解决哪些问题 2、是什么? 官网&am…

spss---如何使用信度分析以及案例分析

信度分析 问卷调查法是教育研究中广泛采用的一种调查方法,根据调查目的设计的调查问卷是问卷调查法获取信息的工具,其质量高低对调查结果的真实性、适用性等具有决定性的作用。 为了保证问卷具有较高的可靠性和有效性,在形成正式问卷之 前&…

CLion:最好用的c/c++编写工具(最详细安装教程)

目录 一.前言介绍 1.下载安装 1.1右上角点击下载 1.2选择自己操作系统,然后点击下载 1.3选择next 1.4 更改路径 1.5D盘最好 1.6 按照我的选择配置环境 1.7install安装 1.8 安装完成 2、mingw64安装 2.1下载资源压缩包 2.2mingw64放入到合适的位置,…

Redis五大基本数据类型及其使用场景

文章目录 **一 什么是NoSQL?****二 redis是什么?****三 redis五大基本类型**1 String(字符串)**应用场景** 2 List(列表)**应用场景** 3 Set(集合)4 sorted set(有序集合…

高级艺术二维码制作教程

最近不少关于二维码制作的,而且都是付费。大概就是一个好看的二维码,扫描后跳转网址。本篇文章使用Python来实现,这么简单花啥钱呢?学会,拿去卖便宜点吧。 文章目录 高级二维码制作环境安装普通二维码艺术二维码动态 …

【LVS】2、部署LVS-DR群集

LVS-DR数据包的流向分析 1.客户端发送请求到负载均衡器,请求的数据报文到达内核空间; 2.负载均衡服务器和正式服务器在同一个网络中,数据通过二层数据链路层来传输; 3.内核空间判断数据包的目标IP是本机VIP,此时IP虚…

批量将Excel中的第二列内容从拼音转换为汉字

要批量将Excel中的第二列内容从拼音转换为汉字,您可以使用Python的openpyxl库来实现。下面是一个示例代码,演示如何读取Excel文件并将第二列内容进行拼音转汉字: from openpyxl import load_workbook from xpinyin import Pinyin # 打开Exce…

Android kotlin系列讲解(入门篇)使用Intent在Activity之间穿梭

<<返回总目录 上一篇:Android kotlin系列讲解(入门篇)Activity的理解与基本用法 文章目录 1、使用显式Intent2、使用隐式Intent3、更多隐式Intent的用法4、向下一个Activity传递数据5、返回数据给上一个Activity1、使用显式Intent 你应该已经对创建Activity的流程比较…

SASS 学习笔记

SASS 学习笔记 总共会写两个练手项目&#xff0c;成品在 https://goldenaarcher.com/scss-study 可以看到&#xff0c;代码在 https://github.com/GoldenaArcher/scss-study。 什么是 SASS SASS 是 CSS 预处理&#xff0c;它提供了变量&#xff08;虽然现在 CSS 也提供了&am…