pytorch-ResNet18简单复现

目录

  • 1. ResNet block
  • 2. ResNet18网络结构
  • 3. 完整代码
    • 3.1 网络代码
    • 3.2 训练代码

1. ResNet block

ResNet block有两个convolution和一个short cut层,如下图:
在这里插入图片描述
代码:

class ResBlk(nn.Module):def __init__(self, ch_in, ch_out, stride):super(ResBlk, self).__init__()self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=stride, padding=1)self.bn1 = nn.BatchNorm2d(ch_out)self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1)self. bn2 = nn.BatchNorm2d(ch_out)self.extra = nn.Sequential()if ch_in != ch_out:self.extra = nn.Sequential(nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=stride),nn.BatchNorm2d(ch_out))def forward(self, x):out = F.relu(self.bn1(self.conv1(x)))out = self.bn2(self.conv2(out))out = self.extra(x) + outout = F.relu(out)return out

2. ResNet18网络结构

在这里插入图片描述
在这里插入图片描述
从上图可以看出,resnet18有1个卷积层,4个残差层和1一个线性输出层,其中每个残差层有2个resnet块,每个块有2个卷积层。
对于cifar10数据来说,输入层[b, 64, 32,32],输出是10分类
代码:

class ResNet(nn.Module):def __init__(self, block, num_blocks, num_classes=1000):super(ResNet, self).__init__()self.in_planes = 64# 初始卷积层self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)self.bn1 = nn.BatchNorm2d(64)# 四个残差层self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)# 全连接层self.linear = nn.Linear(512 * block.expansion, num_classes)# 创建一个残差层def _make_layer(self, block, planes, num_blocks, stride):strides = [stride] + [1] * (num_blocks - 1)layers = []for stride in strides:layers.append(block(self.in_planes, planes, stride))self.in_planes = planes * block.expansionreturn nn.Sequential(*layers)def forward(self, x):out = F.relu(self.bn1(self.conv1(x)))out = F.max_pool2d(out, kernel_size=3, stride=2, padding=1)out = self.layer1(out)out = self.layer2(out)out = self.layer3(out)out = self.layer4(out)#out = F.avg_pool2d(out, 4)out = F.adaptive_avg_pool2d(out, [1, 1])out = out.view(out.size(0), -1)out = self.linear(out)return out

3. 完整代码

3.1 网络代码

import torch
from torch import nn
from torch.nn import functional as Fclass ResBlk(nn.Module):expansion = 1def __init__(self, ch_in, ch_out, stride):super(ResBlk, self).__init__()self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=stride, padding=1)self.bn1 = nn.BatchNorm2d(ch_out)self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1)self.bn2 = nn.BatchNorm2d(ch_out)self.extra = nn.Sequential()if ch_in != ch_out:self.extra = nn.Sequential(nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=stride),nn.BatchNorm2d(ch_out))def forward(self, x):out = F.relu(self.bn1(self.conv1(x)))out = self.bn2(self.conv2(out))out = self.extra(x) + outout = F.relu(out)return outclass ResNet(nn.Module):def __init__(self, block, num_blocks, num_classes=1000):super(ResNet, self).__init__()self.in_planes = 64# 初始卷积层self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)self.bn1 = nn.BatchNorm2d(64)# 四个残差层self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)# 全连接层self.linear = nn.Linear(512 * block.expansion, num_classes)# 创建一个残差层def _make_layer(self, block, planes, num_blocks, stride):strides = [stride] + [1] * (num_blocks - 1)layers = []for stride in strides:layers.append(block(self.in_planes, planes, stride))self.in_planes = planes * block.expansionreturn nn.Sequential(*layers)def forward(self, x):out = F.relu(self.bn1(self.conv1(x)))out = F.max_pool2d(out, kernel_size=3, stride=2, padding=1)out = self.layer1(out)out = self.layer2(out)out = self.layer3(out)out = self.layer4(out)#out = F.avg_pool2d(out, 4)out = F.adaptive_avg_pool2d(out, [1, 1])out = out.view(out.size(0), -1)out = self.linear(out)return outdef ResNet18():return ResNet(ResBlk, [2, 2, 2, 2], 10)if __name__ == '__main__':model = ResNet18()print(model)

3.2 训练代码

import torch
from torchvision import datasets
from torch.utils.data import DataLoader
from torchvision import transforms
from torch import nn, optim
import syssys.path.append('.')
#from Lenet5 import Lenet5
from resnet import ResNet18def main():batchz = 128cifar_train = datasets.CIFAR10('cifa', True, transform=transforms.Compose([transforms.Resize((32, 32)),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])]), download=True)cifar_train = DataLoader(cifar_train, batch_size=batchz, shuffle=True)cifar_test = datasets.CIFAR10('cifa', False, transform=transforms.Compose([transforms.Resize((32, 32)),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])]), download=True)cifar_test = DataLoader(cifar_test, batch_size=batchz, shuffle=True)device = torch.device('cuda')#model = Lenet5().to(device)model = ResNet18().to(device)crition = nn.CrossEntropyLoss().to(device)optimizer = optim.Adam(model.parameters(), lr=1e-3)for epoch in range(1000):model.train()for batch, (x, label) in enumerate(cifar_train):x, label = x.to(device), label.to(device)logits = model(x)loss = crition(logits, label)optimizer.zero_grad()loss.backward()optimizer.step()# testmodel.eval()with torch.no_grad():total_correct = 0total_num = 0for x, label in cifar_test:x, label = x.to(device), label.to(device)logits = model(x)pred = logits.argmax(dim=1)correct = torch.eq(pred, label).float().sum().item()total_correct += correcttotal_num += x.size(0)acc = total_correct / total_numprint(epoch, 'test acc:', acc)if __name__ == '__main__':main()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/864530.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Java学习 (六) 面向对象--this、继承、方法重写、super

一、this 关键字 1、this 演示 vi Person.java public class Person {String name;int age;//显示声明构造器public Person(String s , int i){name s;age i;}public void setAge(int age){age age;}}vi PersonTest.java public class PersonTest {public static void m…

第二十一章 函数(Python)

文章目录 前言一、定义函数二、函数参数三、参数类型四、函数返回值五、函数类型1、无参数,无返回值2、无参数,有返回值3、有参数,无返回值4、有参数,有返回值 六、函数的嵌套七、全局变量和局部变量1、局部变量2、全局变量 前言 …

了解 .NET 中的会话管理

在 Web 开发领域,跨多个请求维护状态是一个关键方面。HTTP 的无状态特性要求开发人员实现持久保存用户数据的机制。这就是会话发挥作用的地方。在本文中,我们将探讨什么是会话、它们在 .NET 中的工作方式,并提供实际示例来说明它们的用法。 …

set_source_files_properties QT_QML_SINGLETON_TYPE

目录 前言 QT_QML_SINGLETON_TYPE 属性 基本用法 示例 1. 创建一个基本的 CMake 项目 2. 编辑 CMakeLists.txt 3. 创建 main.cpp 4. 创建 MySingleton.qml 5. 创建 qml.qrc 6. 创建 main.qml 构建和运行项目 结论 前言 在使用 Qt 和 CMake 构建项目时,…

weapp.socket.io.js

!function(t,e){if(“object"typeof exports&&"object"typeof module)module.exportse();else if("function"typeof define&&define.amd)define([],e);else{var re();for(var n in r)(“object"typeof exports?exports:t)[n]r[…

探索未知:sklearn处理未知类别数据的策略

探索未知:sklearn处理未知类别数据的策略 在机器学习项目中,我们经常遇到带有未知类别的数据,这些数据可能因为各种因素而缺失或无法归类。有效地处理这些未知类别对于构建鲁棒的模型至关重要。本文将深入探讨sklearn如何处理带有未知类别的…

解决Java中的NoRouteToHostException异常的方法

解决Java中的NoRouteToHostException异常的方法 大家好,我是免费搭建查券返利机器人省钱赚佣金就用微赚淘客系统3.0的小编,也是冬天不穿秋裤,天冷也要风度的程序猿! 在Java开发中,网络编程是非常重要的一部分&#x…

某腾X滑块验证码

⚠️前言⚠️ 本文仅用于学术交流。 学习探讨逆向知识,欢迎私信共享学习心得。 如有侵权,联系博主删除。 请勿商用,否则后果自负。 网址 aHR0cHM6Ly9jbG91ZC50ZW5jZW50LmNvbS9wcm9kdWN0L2NhcHRjaGE= 1. 先整体分析一下 1_1. 验证码信息下发接口 cap_union_prehandle ua:…

JS基础与Chrome介绍

导言 在Web开发中后端负责程序架构和数据管理,前端负责页面展示和用户交互;在这种前后端分离的开发方式中,以接口为标准来进行联调整合,为了保证接口在调用时数据的安全性,也为了防止请求参数被篡改,大多数…

白骑士的Python教学基础篇 1.5 数据结构

系列目录​​​​​​​ 上一篇:白骑士的Python教学基础篇 1.4 函数与模块 数据结构是编程语言中用于存储和组织数据的基本构件。在Python中,常见的数据结构包括列表(List)、元组(Tuple)、字典&#xff08…

深入理解 “androidx.databinding.DataBindingUtil“ 细节和使用

介绍 数据绑定(Data Binding)是 Android 中的一个强大功能,它允许你使用声明性格式而不是编程方式将布局中的 UI 组件绑定到应用中的数据源。androidx.databinding.DataBindingUtil 类是一个工具类,它提供了用于处理数据绑定的方…

容器技术-docker5

一、docker-compose 常用命令和指令 1. 概要 默认的模板文件是 docker-compose.yml,其中定义的每个服务可以通过 image 指令指定镜像或 build 指令(需要 Dockerfile)来自动构建。 注意如果使用 build 指令,在 Dockerfile 中设置…

【面试干货】Static关键字的用法详解

【面试干货】Static关键字的用法详解 1、Static修饰内部类2、Static修饰方法3、Static修饰变量4、Static修饰代码块5、总结 💖The Begin💖点点关注,收藏不迷路💖 在Java编程语言中,static是一个关键字,它可…

MT19937 64bit 机器上的实现及原理解析

1, mt19937 实现源码 mt19937-64bit_ex.cpp /*References:T. Nishimura, Tables of 64-bit Mersenne TwistersACM Transactions on Modeling and Computer Simulation 10. (2000) 348--357.M. Matsumoto and T. Nishimura,Mersenne Twister: a 623-dimensionally e…

electron vite react 创建一个项目

要使用 Electron、Vite 和 React 创建一个项目,你可以按照以下步骤操作: 1. 安装 Node.js 和 npm 首先,确保你的计算机上安装了 Node.js 和 npm(Node Package Manager)。你可以从 Node.js 官网 下载并安装。 2. 初始化一个新的项目 在你的工作目录下,创建一个新的文件…

水果商城外卖微信小程序模板

手机微信水果外卖,水果电商,水果商城网页小程序模板。包含:主页、列表页、详情页、购物车、个人中心。 水果商城外卖小程序模板

[C++][设计模式][迭代器模式]详细讲解

目录 1.动机2.模式定义3.要点总结4.代码感受 1.动机 在软件构建过程中,集合对象内部结构常常变化各异。但对于这些集合对象,我们希望不暴露其内部结构的同时,可以让外部客户代码透明地访问其中包含的元素; 同时这种”透明遍历“也…

实现Java中的线程安全集合类

实现Java中的线程安全集合类 大家好,我是免费搭建查券返利机器人省钱赚佣金就用微赚淘客系统3.0的小编,也是冬天不穿秋裤,天冷也要风度的程序猿! 一、介绍 在多线程编程中,保证数据的线程安全性是至关重要的。Java提…

可燃气体报警器检测机构:严格遵守的安全标准

随着工业、商业和家庭领域对安全要求的不断提高,可燃气体报警器作为预防火灾和爆炸事故的重要设备,其性能稳定性和可靠性越来越受到关注。 可燃气体报警器检测机构应运而生,为确保这些设备的有效运行发挥着不可替代的作用。 接下来&#xf…

超强风冷制动电阻器-大功率对流冷却电阻器

风冷制动电阻 EAK 的风冷制动电阻器的制造功率范围为 5 kW 至 1200 kW。这些电阻器用于从螺旋桨、起重机、绞盘、顶部驱动器等倾倒多余的电力。 风冷电阻器是独立的单元,不需要进一步安装。该装置由内置风扇冷却。它也可以在没有风扇的情况下制作(非强…