pytorch-ResNet18简单复现

目录

  • 1. ResNet block
  • 2. ResNet18网络结构
  • 3. 完整代码
    • 3.1 网络代码
    • 3.2 训练代码

1. ResNet block

ResNet block有两个convolution和一个short cut层,如下图:
在这里插入图片描述
代码:

class ResBlk(nn.Module):def __init__(self, ch_in, ch_out, stride):super(ResBlk, self).__init__()self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=stride, padding=1)self.bn1 = nn.BatchNorm2d(ch_out)self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1)self. bn2 = nn.BatchNorm2d(ch_out)self.extra = nn.Sequential()if ch_in != ch_out:self.extra = nn.Sequential(nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=stride),nn.BatchNorm2d(ch_out))def forward(self, x):out = F.relu(self.bn1(self.conv1(x)))out = self.bn2(self.conv2(out))out = self.extra(x) + outout = F.relu(out)return out

2. ResNet18网络结构

在这里插入图片描述
在这里插入图片描述
从上图可以看出,resnet18有1个卷积层,4个残差层和1一个线性输出层,其中每个残差层有2个resnet块,每个块有2个卷积层。
对于cifar10数据来说,输入层[b, 64, 32,32],输出是10分类
代码:

class ResNet(nn.Module):def __init__(self, block, num_blocks, num_classes=1000):super(ResNet, self).__init__()self.in_planes = 64# 初始卷积层self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)self.bn1 = nn.BatchNorm2d(64)# 四个残差层self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)# 全连接层self.linear = nn.Linear(512 * block.expansion, num_classes)# 创建一个残差层def _make_layer(self, block, planes, num_blocks, stride):strides = [stride] + [1] * (num_blocks - 1)layers = []for stride in strides:layers.append(block(self.in_planes, planes, stride))self.in_planes = planes * block.expansionreturn nn.Sequential(*layers)def forward(self, x):out = F.relu(self.bn1(self.conv1(x)))out = F.max_pool2d(out, kernel_size=3, stride=2, padding=1)out = self.layer1(out)out = self.layer2(out)out = self.layer3(out)out = self.layer4(out)#out = F.avg_pool2d(out, 4)out = F.adaptive_avg_pool2d(out, [1, 1])out = out.view(out.size(0), -1)out = self.linear(out)return out

3. 完整代码

3.1 网络代码

import torch
from torch import nn
from torch.nn import functional as Fclass ResBlk(nn.Module):expansion = 1def __init__(self, ch_in, ch_out, stride):super(ResBlk, self).__init__()self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=stride, padding=1)self.bn1 = nn.BatchNorm2d(ch_out)self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1)self.bn2 = nn.BatchNorm2d(ch_out)self.extra = nn.Sequential()if ch_in != ch_out:self.extra = nn.Sequential(nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=stride),nn.BatchNorm2d(ch_out))def forward(self, x):out = F.relu(self.bn1(self.conv1(x)))out = self.bn2(self.conv2(out))out = self.extra(x) + outout = F.relu(out)return outclass ResNet(nn.Module):def __init__(self, block, num_blocks, num_classes=1000):super(ResNet, self).__init__()self.in_planes = 64# 初始卷积层self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)self.bn1 = nn.BatchNorm2d(64)# 四个残差层self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)# 全连接层self.linear = nn.Linear(512 * block.expansion, num_classes)# 创建一个残差层def _make_layer(self, block, planes, num_blocks, stride):strides = [stride] + [1] * (num_blocks - 1)layers = []for stride in strides:layers.append(block(self.in_planes, planes, stride))self.in_planes = planes * block.expansionreturn nn.Sequential(*layers)def forward(self, x):out = F.relu(self.bn1(self.conv1(x)))out = F.max_pool2d(out, kernel_size=3, stride=2, padding=1)out = self.layer1(out)out = self.layer2(out)out = self.layer3(out)out = self.layer4(out)#out = F.avg_pool2d(out, 4)out = F.adaptive_avg_pool2d(out, [1, 1])out = out.view(out.size(0), -1)out = self.linear(out)return outdef ResNet18():return ResNet(ResBlk, [2, 2, 2, 2], 10)if __name__ == '__main__':model = ResNet18()print(model)

3.2 训练代码

import torch
from torchvision import datasets
from torch.utils.data import DataLoader
from torchvision import transforms
from torch import nn, optim
import syssys.path.append('.')
#from Lenet5 import Lenet5
from resnet import ResNet18def main():batchz = 128cifar_train = datasets.CIFAR10('cifa', True, transform=transforms.Compose([transforms.Resize((32, 32)),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])]), download=True)cifar_train = DataLoader(cifar_train, batch_size=batchz, shuffle=True)cifar_test = datasets.CIFAR10('cifa', False, transform=transforms.Compose([transforms.Resize((32, 32)),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])]), download=True)cifar_test = DataLoader(cifar_test, batch_size=batchz, shuffle=True)device = torch.device('cuda')#model = Lenet5().to(device)model = ResNet18().to(device)crition = nn.CrossEntropyLoss().to(device)optimizer = optim.Adam(model.parameters(), lr=1e-3)for epoch in range(1000):model.train()for batch, (x, label) in enumerate(cifar_train):x, label = x.to(device), label.to(device)logits = model(x)loss = crition(logits, label)optimizer.zero_grad()loss.backward()optimizer.step()# testmodel.eval()with torch.no_grad():total_correct = 0total_num = 0for x, label in cifar_test:x, label = x.to(device), label.to(device)logits = model(x)pred = logits.argmax(dim=1)correct = torch.eq(pred, label).float().sum().item()total_correct += correcttotal_num += x.size(0)acc = total_correct / total_numprint(epoch, 'test acc:', acc)if __name__ == '__main__':main()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/864530.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Java学习 (六) 面向对象--this、继承、方法重写、super

一、this 关键字 1、this 演示 vi Person.java public class Person {String name;int age;//显示声明构造器public Person(String s , int i){name s;age i;}public void setAge(int age){age age;}}vi PersonTest.java public class PersonTest {public static void m…

某腾X滑块验证码

⚠️前言⚠️ 本文仅用于学术交流。 学习探讨逆向知识,欢迎私信共享学习心得。 如有侵权,联系博主删除。 请勿商用,否则后果自负。 网址 aHR0cHM6Ly9jbG91ZC50ZW5jZW50LmNvbS9wcm9kdWN0L2NhcHRjaGE= 1. 先整体分析一下 1_1. 验证码信息下发接口 cap_union_prehandle ua:…

JS基础与Chrome介绍

导言 在Web开发中后端负责程序架构和数据管理,前端负责页面展示和用户交互;在这种前后端分离的开发方式中,以接口为标准来进行联调整合,为了保证接口在调用时数据的安全性,也为了防止请求参数被篡改,大多数…

深入理解 “androidx.databinding.DataBindingUtil“ 细节和使用

介绍 数据绑定(Data Binding)是 Android 中的一个强大功能,它允许你使用声明性格式而不是编程方式将布局中的 UI 组件绑定到应用中的数据源。androidx.databinding.DataBindingUtil 类是一个工具类,它提供了用于处理数据绑定的方…

容器技术-docker5

一、docker-compose 常用命令和指令 1. 概要 默认的模板文件是 docker-compose.yml,其中定义的每个服务可以通过 image 指令指定镜像或 build 指令(需要 Dockerfile)来自动构建。 注意如果使用 build 指令,在 Dockerfile 中设置…

【面试干货】Static关键字的用法详解

【面试干货】Static关键字的用法详解 1、Static修饰内部类2、Static修饰方法3、Static修饰变量4、Static修饰代码块5、总结 💖The Begin💖点点关注,收藏不迷路💖 在Java编程语言中,static是一个关键字,它可…

MT19937 64bit 机器上的实现及原理解析

1, mt19937 实现源码 mt19937-64bit_ex.cpp /*References:T. Nishimura, Tables of 64-bit Mersenne TwistersACM Transactions on Modeling and Computer Simulation 10. (2000) 348--357.M. Matsumoto and T. Nishimura,Mersenne Twister: a 623-dimensionally e…

水果商城外卖微信小程序模板

手机微信水果外卖,水果电商,水果商城网页小程序模板。包含:主页、列表页、详情页、购物车、个人中心。 水果商城外卖小程序模板

[C++][设计模式][迭代器模式]详细讲解

目录 1.动机2.模式定义3.要点总结4.代码感受 1.动机 在软件构建过程中,集合对象内部结构常常变化各异。但对于这些集合对象,我们希望不暴露其内部结构的同时,可以让外部客户代码透明地访问其中包含的元素; 同时这种”透明遍历“也…

可燃气体报警器检测机构:严格遵守的安全标准

随着工业、商业和家庭领域对安全要求的不断提高,可燃气体报警器作为预防火灾和爆炸事故的重要设备,其性能稳定性和可靠性越来越受到关注。 可燃气体报警器检测机构应运而生,为确保这些设备的有效运行发挥着不可替代的作用。 接下来&#xf…

超强风冷制动电阻器-大功率对流冷却电阻器

风冷制动电阻 EAK 的风冷制动电阻器的制造功率范围为 5 kW 至 1200 kW。这些电阻器用于从螺旋桨、起重机、绞盘、顶部驱动器等倾倒多余的电力。 风冷电阻器是独立的单元,不需要进一步安装。该装置由内置风扇冷却。它也可以在没有风扇的情况下制作(非强…

HTMLCSS(入门)

HTML <html> <head><title>第一个页面</title></head><body>键盘敲烂&#xff0c;工资过万</body> </html> <!DOCTYPE>文档类型声明&#xff0c;告诉浏览器使用哪种HTML版本显示网页 <!DOCTYPE html>当前页面采取…

D - Intersecting Intervals(abc355)

题意&#xff1a;有n个区间&#xff0c;找出俩俩区间相交的个数 分析&#xff1a; 设初始俩俩相交&#xff0c;找出不相交的&#xff08;不同区间l>r)&#xff0c;减去即可 #include<bits/stdc.h> using namespace std; typedef long long ll; int main(){ ios:…

【前端CSS3】一篇搞懂各类常用选择器(黑马程序员)

文章目录 一、前言&#x1f680;&#x1f680;&#x1f680;二、正文&#xff1a;2.1 基础选择器2.1.1 标签选择器2.1.2 类选择器2.1.3 id选择器2.1.4 通配符选择题2.1.5 类选择器与id选择器区别☀️☀️☀️2.1.6 基础选择器总结&#x1f680; 2.2 复合类选择器2.2.1 后代选择…

Linux系统编程:信号

目录 1.信号概念 2.信号产生 2.1 终端 2.2 系统调用 2.3 硬件异常 2.4 软件条件 2.5 小结 3. 进程退出时的核心转储问题 4. 信号捕捉初识 5. 阻塞信号 5.1 相关概念 5.2 在内核中的表示 6. 信号捕捉 6.1 知识铺垫 6.2 信号捕捉流程 6.3 sigset_t 6.4 信号集操…

CTFHUB-SSRF-上传文件

通过file协议访问flag.php文件内容 ?urlfile:///var/www/html/flag.php 右键查看页面源代码&#xff0c;发现需要从内部上传一个文件这样才能正常获取到flag ?urlhttp://127.0.0.1/flag.php 发现无提交按钮&#xff0c;构造一个 <input type"submit" name&qu…

RabbitMQ 之 延迟队列

目录 ​编辑一、延迟队列概念 二、延迟队列使用场景 三、整合 SpringBoot 1、创建项目 2、添加依赖 3、修改配置文件 4、添加 Swagger 配置类 四、队列 TTL 1、代码架构图 2、配置文件代码类 3、生产者 4、消费者 5、结果展示 五、延时队列优化 1、代码架构图 …

Python番外篇之责任转移:有关于虚拟机编程语言的往事

编程之痛 如果&#xff0c;你像笔者一样&#xff0c;有过学习或者使用汇编语言与C、C等语言的经历&#xff0c;一定对下面所说的痛苦感同身受。 汇编语言 将以二进制表示的一条条CPU的机器指令&#xff0c;以人类可读的方式进行表示。虽然&#xff0c;人类可读了&#xff0c…

使用deep修改前端框架中的样式

目录 1.deep的作用 2.使用方式 3.特别说明 scoped 的实现原理&#xff1a; !important 1.deep的作用 /deep/、::v-deep 和 :deep 都是用于穿透组件作用域的选择器。它们的主要目的是允许开发者在父组件中直接选择并样式化子组件内部的元素&#xff0c;即使这些元素被封装在…

JVM原理(九):JVM虚拟机工具之可视化故障处理工具

1. JHSDB:基于服务性代理的调试工具 JHSDB是一款基于服务性代理实现的进程外调试工具。 服务性代理是HotSpot虚拟机中一组用于映射Java虚拟机运行信息的、主要基于Java语言实现的API集合。 2. JConsole:Java监视与管理控制台 JConsole是一款基于JMX的可视化监视、管理工具。…