【Pytorch】学习记录分享6——PyTorch经典网络 ResNet与手写体识别

【Pytorch】学习记录分享5——PyTorch经典网络 ResNet

      • 1. ResNet (残差网络)基础知识
      • 2. 感受野
      • 3. 手写体数字识别
        • 3. 0 数据集(训练与测试集)
        • 3. 1 数据加载
        • 3. 2 函数实现:
        • 3. 3 训练及其测试:

1. ResNet (残差网络)基础知识

图1 56层error比20层error高,提出ResNet (残差网络)的方案
在这里插入图片描述

网络效果:

在这里插入图片描述
网络结构:
在这里插入图片描述
在这里插入图片描述

2. 感受野

在这里插入图片描述
在这里插入图片描述

3. 手写体数字识别

3. 0 数据集(训练与测试集)

mnist 用于手写体训练与测试,这里包含完整的链接

3. 1 数据加载
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets,transforms 
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline
### 首先读取数据
# - 分别构建训练集和测试集(验证集)
# - DataLoader来迭代取数据# 定义超参数 
input_size = 28  #图像的总尺寸28*28
num_classes = 10  #标签的种类数
num_epochs = 3  #训练的总循环周期
batch_size = 64  #一个撮(批次)的大小,64张图片# 训练集
train_dataset = datasets.MNIST(root='./data',  train=True,   transform=transforms.ToTensor(),  download=True) # 测试集
test_dataset = datasets.MNIST(root='./data', train=False, transform=transforms.ToTensor())# 构建batch数据
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=True)

在这里插入图片描述

3. 2 函数实现:
# 卷积网络模块构建
# 一般卷积层,relu层,池化层可以写成一个套餐
# 注意卷积最后结果还是一个特征图,需要把图转换成向量才能做分类或者回归任务class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv1 = nn.Sequential(         # 输入大小 (1, 28, 28)nn.Conv2d(in_channels=1,              # 灰度图out_channels=16,            # 要得到几多少个特征图kernel_size=5,              # 卷积核大小stride=1,                   # 步长padding=2,                  # 如果希望卷积后大小跟原来一样,需要设置padding=(kernel_size-1)/2 if stride=1),                              # 输出的特征图为 (16, 28, 28)nn.ReLU(),                      # relu层nn.MaxPool2d(kernel_size=2),    # 进行池化操作(2x2 区域), 输出结果为: (16, 14, 14))self.conv2 = nn.Sequential(         # 下一个套餐的输入 (16, 14, 14)nn.Conv2d(16, 32, 5, 1, 2),     # 输出 (32, 14, 14)nn.ReLU(),                      # relu层nn.MaxPool2d(2),                # 输出 (32, 7, 7))self.out = nn.Linear(32 * 7 * 7, 10)   # 全连接层得到的结果def forward(self, x):x = self.conv1(x)x = self.conv2(x)x = x.view(x.size(0), -1)           # flatten操作,结果为:(batch_size, 32 * 7 * 7)  output = self.out(x)return output# 准确率作为评估标准
def accuracy(predictions, labels):pred = torch.max(predictions.data, 1)[1] rights = pred.eq(labels.data.view_as(pred)).sum() return rights, len(labels) 
3. 3 训练及其测试:
# 训练网络模型
# 实例化
net = CNN() 
#损失函数
criterion = nn.CrossEntropyLoss() 
#优化器
optimizer = optim.Adam(net.parameters(), lr=0.001) #定义优化器,普通的随机梯度下降算法#开始训练循环
for epoch in range(num_epochs):#当前epoch的结果保存下来train_rights = []for batch_idx, (data, target) in enumerate(train_loader):  #针对容器中的每一个批进行循环net.train()  # 将模型设置为训练模式output = net(data)  # 使用模型进行前向传播loss = criterion(output, target)  # 计算损失optimizer.zero_grad()  # 梯度清零loss.backward()  # 反向传播计算梯度optimizer.step()  # 更新参数right = accuracy(output, target)  # 计算当前批次的准确率train_rights.append(right)  # 将准确率保存起来if batch_idx % 500 == 0:  # 每500个批次进行一次验证net.eval()  # 将模型设置为评估模式val_rights = []  # 存储验证集的准确率for (data, target) in test_loader:  # 在测试集上进行验证output = net(data)  # 使用模型进行前向传播right = accuracy(output, target)  # 计算验证集上的准确率val_rights.append(right)  # 将准确率保存起来#准确率计算train_r = (sum([tup[0] for tup in train_rights]), sum([tup[1] for tup in train_rights]))  # 计算训练集准确率的分子和分母val_r = (sum([tup[0] for tup in val_rights]), sum([tup[1] for tup in val_rights]))  # 计算验证集准确率的分子和分母print('当前epoch: {} [{}/{} ({:.0f}%)]\t损失: {:.6f}\t训练集准确率: {:.2f}%\t测试集正确率: {:.2f}%'.format(epoch, batch_idx * batch_size, len(train_loader.dataset),100. * batch_idx / len(train_loader), loss.data, 100. * train_r[0].numpy() / train_r[1],100. * val_r[0].numpy() / val_r[1]))  # 打印当前进度和准确率信息

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/237455.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

volatile关键字

1.什么是volatile? 1.1.volatile是一种同步机制,比synchronized或Lock更轻量级,因为使用volatile并不会发生线程“上下 文切换”等开销很大的行为,volatile关键字只是把被修饰的变量修改后刷新到“主内存”中; 1.2.如果一个变量被volatile修饰,那么JVM就知道这个变量可能会被…

Bash 脚本学习

文章目录 1、脚本编程基础2. 变量2.1 参数变量的引用2.2 环境变量 3 条件判断语句3.1 if 语句3.1.1 语法3.1.2 案例 3.2 case 语句3.2.1 语法3.2.2 案例 3.3 判断参数说明 4 循环语句4.1 for 循环4.1.1 语法4.1.2 案例 4.2 while循环4.2.1 语法4.2.2 案例4. 3 循环总结 5. 函数…

Prompt-to-Prompt:基于 cross-attention 控制的图像编辑技术

Hertz A, Mokady R, Tenenbaum J, et al. Prompt-to-prompt image editing with cross attention control[J]. arXiv preprint arXiv:2208.01626, 2022. Prompt-to-Prompt 是 Google 提出的一种全新的图像编辑方法,不同于任何传统方法需要用户指定编辑区域&#xff…

微信小程序开发系列-01创建一个最小的小程序项目

本文讲述了通过微信开发者工具,创建一个新的小程序项目,完全从零开始,不依赖开发者工具的模板。目的是为了更好的理解小程序工程项目的构成。 文章目录 创建一个空项目app.json全局配置pagessitemapLocation app.js 创建一个空项目 打开微信…

新型智慧视频监控系统:基于TSINGSEE青犀边缘计算AI视频识别技术的应用

边缘计算AI智能识别技术在视频监控领域的应用有很多。这项技术结合了边缘计算和人工智能技术,通过在摄像头或网关设备上运行AI算法,可以在现场实时处理和分析视频数据,从而实现智能识别和分析。目前来说,边缘计算AI视频智能技术可…

aws-waf-cdn 基于规则组的永黑解决方案

1. 新建waf 规则组 2. 为规则组添加规则 根据需求创建不同的规则 3. waf中附加规则组 (此时规则组所有规则都会附加到waf中,但是不会永黑) 此刻,可以选择测试下规则是否生效,测试前确认保护资源绑定无误 4. 创建堆…

大数据开发职业规划

大数据开发职业规划 我的学历是双非本,在学校学习的是大数据专业,目前是在企业做大数据全栈的工作,爬虫,数仓,风控项目,etl开发都做 .................................................................…

FFmpeg实现rtp推流

以下是一个简单的示例代码&#xff0c;演示了如何使用 UDP 或 TCP 进行音视频传输&#xff1a; 代码未经验证&#xff0c;供参考 #include <stdio.h> #include <stdlib.h> #include <string.h> #include <unistd.h> #include <sys/types.h> #in…

Mac设置ll永久生效,设置.bash_profile生效

Mac设置ll永久生效&#xff0c;设置.bash_profile生效 前言&#xff1a;Mac上自带的终端不好用&#xff0c;一般我推荐ITerm终端&#xff0c;官网下载即可 如果想只生效一次&#xff1a; 直接在终端执行alias llls -l即可 如果想永久生效&#xff1a; vim ~/.bash_profile&…

02 - Kbuild子系统(整理中)

1. Kbuild简介 Kernel build&#xff0c;用来编译 Linux 内核基于 GNU make 设计&#xff0c;对 Makefile 进行扩充 菜单式配置&#xff1a;Kconfig预定义目标和变量&#xff1a;xx_defconfig、menuconfig、obj-y跨平台工具、递归式 Makefile Linux 模块化设计、高度可以裁剪 …

java开发面试:常见业务场景之单点登录SSO(JWT)、权限认证、上传数据的安全性的控制、项目中遇到的问题、日志采集(ELK)、快速定位系统的瓶颈

单点登录&#xff08;SSO&#xff09; 单点登录&#xff0c;Single Sign On&#xff08;简称SSO&#xff09;,只需要登录一次&#xff0c;就可以访问所有信任的应用系统。 如果是单个tomcat服务&#xff0c;session可以共享&#xff0c;如果是多个tomcat&#xff0c;那么服务s…

tcp 的限制 (TCP_WRAPPERS)

#江南的江 #每日鸡汤&#xff1a;青春是打开了就合不上的书&#xff0c;人生是踏上了就回不了头的路&#xff0c;爱情是扔出了就收不回的赌注。 #初心和目标&#xff1a;拿到高级网络工程师 TCP_WRAPPERs Tcp_wrappers 对于七层模型中是位于第四层的安全工具&#xff0c;他…

微信小程序 动态设置状态栏样式

onLoad(options) {//修改状态栏标题wx.setNavigationBarTitle({title: 页面标题, //页面标题success: () > {}, //接口调用成功的回调函数fail: () > {}, //接口调用失败的回调函数complete: () > {} //接口调用结束的回调函数&#xff08;调用成功、失败…

[CVPR 2023:3D Gaussian Splatting:实时的神经场渲染]

文章目录 前言小结 原文地址&#xff1a;https://blog.csdn.net/qq_45752541/article/details/132854115 前言 mesh 和点是最常见的3D场景表示&#xff0c;因为它们是显式的&#xff0c;非常适合于快速的基于GPU/CUDA的栅格化。相比之下&#xff0c;最近的神经辐射场&#xf…

从0开始学Git指令

从0开始学Git指令 因为网上的git文章优劣难评&#xff0c;大部分没有实操展示&#xff0c;所以打算自己从头整理一份完整的git实战教程&#xff0c;希望对大家能够起到帮助&#xff01; 初始化一个Git仓库&#xff0c;使用git init命令。 添加文件到Git仓库&#xff0c;分两步…

Ceph基本环境配置

基本环境准备 准备三台服务器&#xff08;服务器至少需要添加两块磁盘&#xff09;以及一台客户端&#xff0c;最好配置时间同步同时再次配置hosts解析。 node1192.168.134.160node2192.168.134.161node3192.168.134.162node4192.168.134.163客户端 配置免密登录 [rootnode…

智能优化算法应用:基于金鹰算法3D无线传感器网络(WSN)覆盖优化 - 附代码

智能优化算法应用&#xff1a;基于金鹰算法3D无线传感器网络(WSN)覆盖优化 - 附代码 文章目录 智能优化算法应用&#xff1a;基于金鹰算法3D无线传感器网络(WSN)覆盖优化 - 附代码1.无线传感网络节点模型2.覆盖数学模型及分析3.金鹰算法4.实验参数设定5.算法结果6.参考文献7.MA…

多模态大模型机器人VIMA讲解与资料整理 #持续更新

简单录制了一个视频&#xff0c;欢迎查看&#xff0c;希望对您有帮助&#xff1a; 具身智能开源工作试玩&#xff1a;VIMA_哔哩哔哩_bilibili 参考资料&#xff1a; 多模态prompt通往通用机器人操控——VIMA论文解读 - 知乎 (zhihu.com) 40. VIMA&#xff1a;基于多模态输入…

数据库学习日常案例20231221-oracle libray cache lock分析

1 问题概述&#xff1a; 阻塞的源头为两个ddl操作导致大量的libray cache lock 其中1133为gis sde的create table as语句。 其中697为alter index语句。

第二十一章网络通讯

网络程序设计基础 局域网与互联网 为了实现两台计算机的通信&#xff0c;必须用一个网络线路连接两台计算机。如下图所示 网络协议 1.IP协议 IP是Internet Protocol的简称&#xff0c;是一种网络协议。Internet 网络采用的协议是TCP/IP协议&#xff0c;其全称是Transmission…