【Pytorch】学习记录分享6——PyTorch经典网络 ResNet与手写体识别

【Pytorch】学习记录分享5——PyTorch经典网络 ResNet

      • 1. ResNet (残差网络)基础知识
      • 2. 感受野
      • 3. 手写体数字识别
        • 3. 0 数据集(训练与测试集)
        • 3. 1 数据加载
        • 3. 2 函数实现:
        • 3. 3 训练及其测试:

1. ResNet (残差网络)基础知识

图1 56层error比20层error高,提出ResNet (残差网络)的方案
在这里插入图片描述

网络效果:

在这里插入图片描述
网络结构:
在这里插入图片描述
在这里插入图片描述

2. 感受野

在这里插入图片描述
在这里插入图片描述

3. 手写体数字识别

3. 0 数据集(训练与测试集)

mnist 用于手写体训练与测试,这里包含完整的链接

3. 1 数据加载
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets,transforms 
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline
### 首先读取数据
# - 分别构建训练集和测试集(验证集)
# - DataLoader来迭代取数据# 定义超参数 
input_size = 28  #图像的总尺寸28*28
num_classes = 10  #标签的种类数
num_epochs = 3  #训练的总循环周期
batch_size = 64  #一个撮(批次)的大小,64张图片# 训练集
train_dataset = datasets.MNIST(root='./data',  train=True,   transform=transforms.ToTensor(),  download=True) # 测试集
test_dataset = datasets.MNIST(root='./data', train=False, transform=transforms.ToTensor())# 构建batch数据
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=True)

在这里插入图片描述

3. 2 函数实现:
# 卷积网络模块构建
# 一般卷积层,relu层,池化层可以写成一个套餐
# 注意卷积最后结果还是一个特征图,需要把图转换成向量才能做分类或者回归任务class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv1 = nn.Sequential(         # 输入大小 (1, 28, 28)nn.Conv2d(in_channels=1,              # 灰度图out_channels=16,            # 要得到几多少个特征图kernel_size=5,              # 卷积核大小stride=1,                   # 步长padding=2,                  # 如果希望卷积后大小跟原来一样,需要设置padding=(kernel_size-1)/2 if stride=1),                              # 输出的特征图为 (16, 28, 28)nn.ReLU(),                      # relu层nn.MaxPool2d(kernel_size=2),    # 进行池化操作(2x2 区域), 输出结果为: (16, 14, 14))self.conv2 = nn.Sequential(         # 下一个套餐的输入 (16, 14, 14)nn.Conv2d(16, 32, 5, 1, 2),     # 输出 (32, 14, 14)nn.ReLU(),                      # relu层nn.MaxPool2d(2),                # 输出 (32, 7, 7))self.out = nn.Linear(32 * 7 * 7, 10)   # 全连接层得到的结果def forward(self, x):x = self.conv1(x)x = self.conv2(x)x = x.view(x.size(0), -1)           # flatten操作,结果为:(batch_size, 32 * 7 * 7)  output = self.out(x)return output# 准确率作为评估标准
def accuracy(predictions, labels):pred = torch.max(predictions.data, 1)[1] rights = pred.eq(labels.data.view_as(pred)).sum() return rights, len(labels) 
3. 3 训练及其测试:
# 训练网络模型
# 实例化
net = CNN() 
#损失函数
criterion = nn.CrossEntropyLoss() 
#优化器
optimizer = optim.Adam(net.parameters(), lr=0.001) #定义优化器,普通的随机梯度下降算法#开始训练循环
for epoch in range(num_epochs):#当前epoch的结果保存下来train_rights = []for batch_idx, (data, target) in enumerate(train_loader):  #针对容器中的每一个批进行循环net.train()  # 将模型设置为训练模式output = net(data)  # 使用模型进行前向传播loss = criterion(output, target)  # 计算损失optimizer.zero_grad()  # 梯度清零loss.backward()  # 反向传播计算梯度optimizer.step()  # 更新参数right = accuracy(output, target)  # 计算当前批次的准确率train_rights.append(right)  # 将准确率保存起来if batch_idx % 500 == 0:  # 每500个批次进行一次验证net.eval()  # 将模型设置为评估模式val_rights = []  # 存储验证集的准确率for (data, target) in test_loader:  # 在测试集上进行验证output = net(data)  # 使用模型进行前向传播right = accuracy(output, target)  # 计算验证集上的准确率val_rights.append(right)  # 将准确率保存起来#准确率计算train_r = (sum([tup[0] for tup in train_rights]), sum([tup[1] for tup in train_rights]))  # 计算训练集准确率的分子和分母val_r = (sum([tup[0] for tup in val_rights]), sum([tup[1] for tup in val_rights]))  # 计算验证集准确率的分子和分母print('当前epoch: {} [{}/{} ({:.0f}%)]\t损失: {:.6f}\t训练集准确率: {:.2f}%\t测试集正确率: {:.2f}%'.format(epoch, batch_idx * batch_size, len(train_loader.dataset),100. * batch_idx / len(train_loader), loss.data, 100. * train_r[0].numpy() / train_r[1],100. * val_r[0].numpy() / val_r[1]))  # 打印当前进度和准确率信息

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/237455.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Bash 脚本学习

文章目录 1、脚本编程基础2. 变量2.1 参数变量的引用2.2 环境变量 3 条件判断语句3.1 if 语句3.1.1 语法3.1.2 案例 3.2 case 语句3.2.1 语法3.2.2 案例 3.3 判断参数说明 4 循环语句4.1 for 循环4.1.1 语法4.1.2 案例 4.2 while循环4.2.1 语法4.2.2 案例4. 3 循环总结 5. 函数…

Prompt-to-Prompt:基于 cross-attention 控制的图像编辑技术

Hertz A, Mokady R, Tenenbaum J, et al. Prompt-to-prompt image editing with cross attention control[J]. arXiv preprint arXiv:2208.01626, 2022. Prompt-to-Prompt 是 Google 提出的一种全新的图像编辑方法,不同于任何传统方法需要用户指定编辑区域&#xff…

微信小程序开发系列-01创建一个最小的小程序项目

本文讲述了通过微信开发者工具,创建一个新的小程序项目,完全从零开始,不依赖开发者工具的模板。目的是为了更好的理解小程序工程项目的构成。 文章目录 创建一个空项目app.json全局配置pagessitemapLocation app.js 创建一个空项目 打开微信…

新型智慧视频监控系统:基于TSINGSEE青犀边缘计算AI视频识别技术的应用

边缘计算AI智能识别技术在视频监控领域的应用有很多。这项技术结合了边缘计算和人工智能技术,通过在摄像头或网关设备上运行AI算法,可以在现场实时处理和分析视频数据,从而实现智能识别和分析。目前来说,边缘计算AI视频智能技术可…

aws-waf-cdn 基于规则组的永黑解决方案

1. 新建waf 规则组 2. 为规则组添加规则 根据需求创建不同的规则 3. waf中附加规则组 (此时规则组所有规则都会附加到waf中,但是不会永黑) 此刻,可以选择测试下规则是否生效,测试前确认保护资源绑定无误 4. 创建堆…

02 - Kbuild子系统(整理中)

1. Kbuild简介 Kernel build,用来编译 Linux 内核基于 GNU make 设计,对 Makefile 进行扩充 菜单式配置:Kconfig预定义目标和变量:xx_defconfig、menuconfig、obj-y跨平台工具、递归式 Makefile Linux 模块化设计、高度可以裁剪 …

java开发面试:常见业务场景之单点登录SSO(JWT)、权限认证、上传数据的安全性的控制、项目中遇到的问题、日志采集(ELK)、快速定位系统的瓶颈

单点登录(SSO) 单点登录,Single Sign On(简称SSO),只需要登录一次,就可以访问所有信任的应用系统。 如果是单个tomcat服务,session可以共享,如果是多个tomcat,那么服务s…

tcp 的限制 (TCP_WRAPPERS)

#江南的江 #每日鸡汤:青春是打开了就合不上的书,人生是踏上了就回不了头的路,爱情是扔出了就收不回的赌注。 #初心和目标:拿到高级网络工程师 TCP_WRAPPERs Tcp_wrappers 对于七层模型中是位于第四层的安全工具,他…

微信小程序 动态设置状态栏样式

onLoad(options) {//修改状态栏标题wx.setNavigationBarTitle({title: 页面标题, //页面标题success: () > {}, //接口调用成功的回调函数fail: () > {}, //接口调用失败的回调函数complete: () > {} //接口调用结束的回调函数(调用成功、失败…

[CVPR 2023:3D Gaussian Splatting:实时的神经场渲染]

文章目录 前言小结 原文地址:https://blog.csdn.net/qq_45752541/article/details/132854115 前言 mesh 和点是最常见的3D场景表示,因为它们是显式的,非常适合于快速的基于GPU/CUDA的栅格化。相比之下,最近的神经辐射场&#xf…

从0开始学Git指令

从0开始学Git指令 因为网上的git文章优劣难评,大部分没有实操展示,所以打算自己从头整理一份完整的git实战教程,希望对大家能够起到帮助! 初始化一个Git仓库,使用git init命令。 添加文件到Git仓库,分两步…

智能优化算法应用:基于金鹰算法3D无线传感器网络(WSN)覆盖优化 - 附代码

智能优化算法应用:基于金鹰算法3D无线传感器网络(WSN)覆盖优化 - 附代码 文章目录 智能优化算法应用:基于金鹰算法3D无线传感器网络(WSN)覆盖优化 - 附代码1.无线传感网络节点模型2.覆盖数学模型及分析3.金鹰算法4.实验参数设定5.算法结果6.参考文献7.MA…

数据库学习日常案例20231221-oracle libray cache lock分析

1 问题概述: 阻塞的源头为两个ddl操作导致大量的libray cache lock 其中1133为gis sde的create table as语句。 其中697为alter index语句。

案例125:基于微信小程序的个人健康数据管理系统的设计与实现

文末获取源码 开发语言:Java 框架:SSM JDK版本:JDK1.8 数据库:mysql 5.7 开发软件:eclipse/myeclipse/idea Maven包:Maven3.5.4 小程序框架:uniapp 小程序开发软件:HBuilder X 小程序…

龙蜥开源操作系统能解决CentOS 停服造成的空缺吗?

龙蜥开源操作系统能解决CentOS 停服造成的空缺吗? 本文图片来源于龙蜥,仅做介绍时引用用途,版权归属龙蜥和相关设计人员。 一、《国产服务器操作系统发展报告(2023)》称操作系统已步入 2.0 时代,服务器操作…

MacOS+Homebrew+iTerm2+oh my zsh+powerlevel10k美化教程

MacOS终端 你是否已厌倦了MacOS终端的大黑屏? 你是否对这种美观的终端抱有兴趣? 那么,接下来我将会教你用最简单的方式来搭建一套自己的终端。 Homebrew的安装 官网地址:Homebrew — The Missing Package Manager for macOS (o…

Hbase的安装配置

注:本文默认已经完成hadoop的下载以及环境配置 1.上传zookeeper和hbase压缩包到指令路径并且解压 (理论上讲,hbase其实内置了zookeeper,我们也可以不另外下载,另外下载的目的在于减少组件间依赖性) cd /home mkir hbase cd /hom…

泽攸科技SEM台式扫描电子显微镜

泽攸科技是一家国产的科学仪器公司,专注于研发、生产和销售原位电镜解决方案、扫描电镜整机、台阶仪、探针台等仪器。目前台式扫描电镜分为三个系列:ZEM15、ZEM18、ZEM20。 ZEM15台式扫描电镜: ZEM18台式扫描电镜: ZEM20台式扫描…

WeakMap 和 WeakSet:解决内存泄漏避免循环引用(下)

🤍 前端开发工程师(主业)、技术博主(副业)、已过CET6 🍨 阿珊和她的猫_CSDN个人主页 🕠 牛客高级专题作者、在牛客打造高质量专栏《前端面试必备》 🍚 蓝桥云课签约作者、已在蓝桥云…

Debian在升级过程中报错

当我们在升级的过程中出现如下报错信息 报错信息如下所示: The following signatures couldnt be verified because the public key is not available: NO_PUBKEY ED444FF07D8D0BF6 W: GPG error: http://mirrors.jevincanders.net/kali kali-rolling InRelease: …