【深度学习】卷积网络代码实战ResNet

        ResNet (Residual Network) 是由微软研究院的何凯明等人在2015年提出的一种深度卷积神经网络结构。ResNet的设计目标是解决深层网络训练中的梯度消失和梯度爆炸问题,进一步提高网络的表现。下面是一个ResNet模型实现,使用PyTorch框架来展示如何实现基本的ResNet结构。这个例子包括了一个基本的残差块(Residual Block)以及ResNet-18的实现,代码结构分为model.py(模型文件)和train.py(训练文件)。

model.py 

      首先,我们导入所需要的包 

import torch
from torch import nn
from torch.nn import functional as F

        然后,定义Resnet Block(ResBlk)类。

class ResBlk(nn.Module):def __init__(self):super(ResBlk, self).__init__()self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1)self.bn1 = nn.BatchNorm2d(ch_out)self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1)self.bn2 = nn.BatchNorm2d(ch_out)self.extra = nn.Sequential()if ch_out != ch_inself.extra = nn.Sequential(nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1)nn.BatchNorm2d(ch_out))def forward(self, x):out = F.relu(self.bn1(self.conv1(x)))out = F.relu(self.bn2(self.conv2(x)))out = self.extra(x) + outreturn out

        最后,根据ResNet18的结构对ResNet Block进行堆叠。

class Resnet18(nn.Module):def __init__(self):super(Resnet18, self).__init__()self.conv1 = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)nn.BatchNorm2d(64))self.blk1 = ResBlk(64, 128)self.blk2 = ResBlk(128, 256)self.blk3 = ResBlk(256, 512)self.blk4 = ResBlk(512, 1024)self.outlayer = nn.Linear(512, 10)def forward(self, x):x = F.relu(self.conv1(x))x = self.blk1(x)x = self.blk2(x)x = self.blk3(x)x = self.blk4(x)# print('after conv1:', x.shape)x = F.adaptive_avg_pool2d(x, [1,1])x = x.view(x.size(0), -1)x = self.outlayer(x)return x

        其中,在网络结构搭建过程中,需要用到中间阶段的图片参数,用下述测试过程求得。

def main():tmp = torch.randn(2, 3, 32, 32)out = blk(tmp)print('block', out.shape)x = torch.randn(2, 3, 32, 32)model = ResNet18()out = model(x)print('resnet:', out.shape)

train.py

        首先,导入所需要的包

import torch
from torchvision import datasets
from torchvision import transforms
from torch import nn, optimizer

        然后,定义main()函数

def main():batchsz = 32cifar_train = datasets.CIFAR10('cifar', True, transform=transforms.Compose([transforms.Resize((32, 32)),transforms.ToTensor()]), download=True)cifar_train = DataLoader(cifar_train, batch_size=batchsz, shuffle=True)cifar_test = datasets.CIFAR10('cifar', False, transform=transforms.Compose([transforms.Resize((32, 32)),transforms.ToTensor()]), download=True)cifar_test = DataLoader(cifar_test, batch_size=batchsz, shuffle=True)x, label = iter(cifar_train).next()print('x:', x.shape, 'label:', label.shape)device = torch.device('cuda')model = ResNet18().to(device)criteon = nn.CrossEntropyLoss()optimizer = optim.Adam(model.parameters(), lr=1e-3)print(model)for epoch in range(100):for batchidx, (x, label) in enumerate(cifar_train):x, label = x.to(device), label.to(device)logits = model(x)loss = criteon(logitsm label)optimizer.zero_grad()loss.backward()optimizer.step()print(loss.item())with torch.no_grad():total_correct = 0total_num = 0for x, label in cifar_test:x, label = x.to(device), label.to(device)logits = model(x)pred = logits.argmax(dim=1)total_correct += torch.eq(pred, label).floot().sum().item()total_num += x.size(0)acc = total_correct / total_numprint(epoch, acc)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/891087.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

雷电「模拟器」v9 最新清爽去广

前言 雷电模拟器9是基于安卓9内核开发的全新版本模拟器 安装环境 [名称]:雷电「模拟器」 [大小]:579MB [版本]:9.1.34 [语言]:简体中文 [安装环境]:Windows 通过网盘分享的文件:雷电模拟器 链接:…

大模型 API 接入初探

文章目录 大模型 API 接入初探一、使用大模型 API 的前置步骤(一)注册账户与获取凭证(二)理解 API 文档 二、三个常用 API(一)列出模型(二)FIM 补全(三)对话补…

实时在线翻译谷歌插件

Real - time Translation插件的安装 1、下载插件并解压 2、打开谷歌浏览器,在地址栏输入 “chrome://extensions/” 进入扩展程序页面. 3、开启页面右上角的 “开发者模式”. 4、点击 “加载已解压的扩展程序” 按钮,选择之前解压的文件夹,点…

[数据集][图像分类]常见鱼类分类数据集2w张8类别

数据集类型:图像分类用,不可用于目标检测无标注文件 数据集格式:仅仅包含jpg图片,每个类别文件夹下面存放着对应图片 图片数量(jpg文件个数):7554(剩余1w多为测试集) 分类类别数:…

uniapp开发小程序内嵌h5页面,video视频两边有细小黑色边框

1.问题如图 2.原因分析 是否为设置上述属性呢? 设置了,但是仍然有黑边。经过选中页面元素分析后,判断video元素本身就有这种特点,就是视频资源无法完全铺满元素容器。 3.解决方案

【SpringMVC】SpringMVC 快速入门

通常,Web 应用的工作流程如下: 用户通过浏览器访问前端页面; 前端页面通过异步请求向后端服务器发送数据; 后端采用“表现层-业务层-数据层”三层架构进行开发: 表现层接收页面请求将请求参数传递给业务层业务层访问…

OpenGL变换矩阵和输入控制

在前面的文章当中我们已经成功播放了动画,让我们的角色动了起来,这一切变得比较有意思了起来。不过我们发现,角色虽然说是动了起来,不过只是在不停地原地踏步而已,而且我们也没有办法通过键盘来控制这个角色来进行移动…

【Spring MVC 核心机制】核心组件和工作流程解析

在 Web 应用开发中,处理用户请求的逻辑常常会涉及到路径匹配、请求分发、视图渲染等多个环节。Spring MVC 作为一款强大的 Web 框架,将这些复杂的操作高度抽象化,通过组件协作简化了开发者的工作。 无论是处理表单请求、生成动态页面&#x…

Verilog 过程赋值

关键词:阻塞赋值,非阻塞赋值,并行 过程性赋值是在 initial 或 always 语句块里的赋值,赋值对象是寄存器、整数、实数等类型。 这些变量在被赋值后,其值将保持不变,直到重新被赋予新值。 连续性赋值总是处…

[并发与并行] python如何构建并发管道处理多阶段任务?

文章目录 1. 背景&目标2. show me the code3.知识点4. 总结 1. 背景&目标 背景:一个任务可分为多个阶段(各个阶段非CPU密集型任务,而是属于IO密集型任务),希望每个阶段能够交给各自的线程去执行。 目标:构建支持多并发的…

07 基于OpenAMP的核间通信方案

引言 ZYNQ7020有两个CPU核心,这两个核心可以采用SMP或AMP方式进行调度,当采用AMP方式进行调度时核0和核1可以运行不同的操作系统,如核0运行Linux系统,提供有些复杂的用户交互工作,核1运行实时操作系统,对设…

7.若依参数设置、通知公告、日志管理

参数设置 对系统中的参数进行动态维护。 关闭验证码校验功能 打开页面注册功能 需要修改前端页面代码 通知公告 促进组织内部信息传递 若依只提供了一个半成品,只实现了管理员可以添加通知公告。 日志管理 追踪用户行为和系统运行状况。 登录日志 和操作日志…

基于Docker+模拟器的Appium自动化测试(二)

模拟器的设置 打开“夜神模拟器”的系统设置,切换到“手机与网络”页,选中网络设置下的“开启网络连接”和“开启网络桥接模式”复选框,而后选择“静态IP”单选框,在IP地址中输入“192.168.0.105”,网关等内容不再赘述…

大数据技术-Hadoop(三)Mapreduce的介绍与使用

目录 一、概念和定义 二、WordCount案例 1、WordCountMapper 2、WordCountReducer 3、WordCountDriver 三、序列化 1、为什么序列化 2、为什么不用Java的序列化 3、Hadoop序列化特点: 4、自定义bean对象实现序列化接口(Writable) 4…

【数据仓库】SparkSQL数仓实践

文章目录 集成hive metastoreSQL测试spark-sql 语法SQL执行流程两种数仓架构的选择hive on spark数仓配置经验 spark-sql没有元数据管理功能,只有sql 到RDD的解释翻译功能,所以需要和hive的metastore服务集成在一起使用。 集成hive metastore 在spark安…

基本算法——回归

本节将通过分析能源效率数据集(Tsanas和Xifara,2012)学习基本的回归算法。我们将基 于建筑的结构特点(比如表面、墙体与屋顶面积、高度、紧凑度)研究它们的加热与冷却负载要 求。研究者使用一个模拟器设计了12种不…

V-Express - 一款针对人像视频生成的开源软件

V-Express是腾讯AI Lab开发的一款针对人像视频生成的开源软件。它旨在通过条件性丢弃(Conditional Dropout)技术,实现渐进式训练,以改善使用单一图像生成人像视频时的控制信号平衡问题。 在生成过程中,不同的控制信号&…

Java与SQL Server数据库连接的实践与要点

本文还有配套的精品资源,点击获取 简介:Java和SQL Server数据库交互是企业级应用开发中的重要环节。本文详细探讨了使用Java通过JDBC连接到SQL Server数据库的过程,包括加载驱动、建立连接、执行SQL语句、处理异常、资源管理、事务处理和连…

学习记录—正则表达式-基本语法

正则表达式简介-《菜鸟教程》 正则表达式是一种用于匹配和操作文本的强大工具,它是由一系列字符和特殊字符组成的模式,用于描述要匹配的文本模式。 正则表达式可以在文本中查找、替换、提取和验证特定的模式。 本期内容将介绍普通字符,特殊…

企业安装加密软件有什么好处?

加密软件为企业的安全提供了很多便利,从以下几点我们看看比较重要的几个优点: 1、数据保护:企业通常拥有大量的商业机密、客户数据、技术文档等敏感信息。加密软件可以对这些信息进行加密处理,防止未经授权的人员访问。即使数据被…