(三)Pytorch快速搭建卷积神经网络模型实现手写数字识别(代码+详细注解)

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档

文章目录

  • 前言
    • Q1:卷积网络和传统网络的区别
    • Q2:卷积神经网络的架构
    • Q3:卷积神经网络中的参数共享,也是比传统网络的优势所在
    • 4、 具体的实现代码+网络搭建


前言

深度学习pytorch系列第三篇啦,之前更了FC,NN,这篇是卷积神经网络(cNN)模型实现手写数字识别,依然是重在理解哈,具体的理解内容我都以注释的形式放在了代码中,我就直接放代码了,因为我把一些知识点和理解的东西用注释的形式写了


首先是关于卷积神经网络的一些点

Q1:卷积网络和传统网络的区别

传统网络只适合结构化数据,不适合图像数据,由于图像数据的数据量大(表现为像素点多),传统网络需要使用的参数量太大

Q2:卷积神经网络的架构

卷积神经网络包括:输入层,卷积层,池化层,全连接层
重点介绍卷积层!!
卷积就是针对每个区域去计算特征。可以这样做的原因是:图片是有像素点构成的,针对每个像素点进行处理,需要的参数量过于庞大,并且相邻的像素点之间是存在联系的
特征图的个数与卷积核的个数一致。每个卷积核通过对输入特征图进行卷积操作,生成一个输出特征图。因此,卷积核的个数决定了输出的特征图的个数。
使用不同的卷积核学习同一个位置,可以得到不同的特征图,从而使特征多样化
卷积核的大小一般使用3*3
卷积核的大小规格一般是固定的,卷积核的数量理论上是越多越好
卷积层涉及的参数有:滑动窗口步长,卷积核尺寸,边缘填充,卷积核个数
卷积结果计算公式:长:h2=(h1-Fh+2p)/s +1 宽:w2=(w1-Fw+2p)/s +1
其中:w1,h1表示输入的宽度,长度;w2和h2表示输出特征图的宽度、长度,F表示卷积核的长和宽,s表示滑动窗口的补偿,p表示边界填充
经过卷积操作后,特征图的长和宽也可以保持不变
池化层的作用就是筛选好的特征,pool是只筛选位置的,channel是全部使用的
池化也称为下采样,(一次只能下采样原来的一半,不能直接224-16)
卷积神经网络由多个block组成,重点就在于怎么设计这个block的组成
关于卷积神经网络的层数,带权重参数的就算是一层,6个conn+1个fc,就可以说是7层网络结构

Q3:卷积神经网络中的参数共享,也是比传统网络的优势所在

同一个卷积核在各个位置上的参数都是一致的
权重参数的个数与输入数据的大小无关

4、 具体的实现代码+网络搭建

# 读取数据
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets,transforms
# transforms  进行预处理,比如进行tensor转换
import matplotlib.pyplot as plt
import numpy as np
#全连接:batch*28*28,全连接各个像素点之间无关
# cnn:batch*1*28*28  ,多了一个参数channel,卷积会综合考虑一个窗口之间的关系,因此各个像素点并不是独立的,卷积网络更适合处理图像数据
# 定义超参数
input_size = 28  #图像的总尺寸28*28
num_classes = 10  #标签的种类数
num_epochs = 3  #训练的总循环周期
batch_size = 64  #一个撮(批次)的大小,64张图片
# 训练集
train_dataset = datasets.MNIST(root='./data',train=True,transform=transforms.ToTensor(),download=True)
# 测试集
test_dataset = datasets.MNIST(root='./data',train=False,transform=transforms.ToTensor())# 构建batch数据
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,batch_size=batch_size,shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,batch_size=batch_size,shuffle=True)
# 卷积网络模块构建
# 一般卷积层,relu层,池化层可以写成一个套餐
# 注意卷积最后结果还是一个特征图,需要把图转换成向量才能做分类或者回归任务
# 定义一个网络
class CNN(nn.Module):def __init__(self):#         构造函数# 卷积网络一般是组合进行的:conv pool relu可以当一个组合super(CNN, self).__init__()self.conv1 = nn.Sequential(  # 输入大小 (1, 28, 28)nn.Conv2d(  # 2d卷积做任务in_channels=1,  # 灰度图out_channels=16,  # 要得到几多少个特征图,就是卷积核的个数,相当于有16个卷积核kernel_size=5,  # 卷积核大小 5*5的stride=1,  # 步长padding=2,  # 如果希望卷积后大小跟原来一样,需要设置padding=(kernel_size-1)/2 if stride=1,一般是这么希望的#                                             如果不能整除pytorch采用向下取整),  # 输出的特征图为 (16, 28, 28)nn.ReLU(),  # relu层nn.MaxPool2d(kernel_size=2),  # 进行池化操作(2x2 区域), 输出结果为: (16, 14, 14),一般是pooling后是之前的一半)self.conv2 = nn.Sequential(  # 下一个套餐的输入 (16, 14, 14)nn.Conv2d(16, 32, 5, 1, 2),  # 输出 (32, 14, 14)nn.ReLU(),  # relu层nn.Conv2d(32, 32, 5, 1, 2),nn.ReLU(),nn.MaxPool2d(2),  # 输出 (32, 7, 7))self.conv3 = nn.Sequential(  # 下一个套餐的输入 (32, 7, 7)nn.Conv2d(32, 64, 5, 1, 2),  # 输出 (64, 7, 7)nn.ReLU(),  # 输出 (64, 7, 7))# 只有pool的时候才会筛选特征self.out = nn.Linear(64 * 7 * 7, 10)  # 全连接层得到的结果,最后的任务是10分类任务,进行一个wx+b的操作去做分类def forward(self, x):x = self.conv1(x)x = self.conv2(x)x = self.conv3(x)x = x.view(x.size(0), -1)  # flatten操作,结果为:(batch_size, 32 * 7 * 7),和reshape操作一样# reshape操作:总的大小是不变的,提供一个维度后,后边的维度自动计算# 比如当前的x:64*7*7,x.size:64,也就是要从三维转成两维,总的大小不变,就变为64*49这样,-1可以简单的看成一个占位符号# 变换维度,开始是64*7*7,转成batchsize*特征个数,比如64*49output = self.out(x)return output
# 定义准确率
def accuracy(predictions, labels):pred = torch.max(predictions.data, 1)[1] # 最大值是多少,最大值的索引,只要索引就可以rights = pred.eq(labels.data.view_as(pred)).sum()return rights, len(labels)
# 训练网络模型
# 实例化
net = CNN()
# 损失函数
criterion = nn.CrossEntropyLoss()
# 优化器,学习率是0.001
optimizer = optim.Adam(net.parameters(), lr=0.001)  # 定义优化器,普通的随机梯度下降算法
# 开始训练循环
for epoch in range(num_epochs):# 当前epoch的结果保存下来train_rights = []for batch_idx, (data, target) in enumerate(train_loader):  # 针对容器中的每一个批进行循环net.train()output = net(data)loss = criterion(output, target)optimizer.zero_grad()loss.backward()optimizer.step()right = accuracy(output, target)train_rights.append(right)# 每一个batch都进行训练,每一百个batch进行一次评估if batch_idx % 100 == 0:net.eval()val_rights = []for (data, target) in test_loader:output = net(data)right = accuracy(output, target)val_rights.append(right)# 准确率计算train_r = (sum([tup[0] for tup in train_rights]), sum([tup[1] for tup in train_rights]))val_r = (sum([tup[0] for tup in val_rights]), sum([tup[1] for tup in val_rights]))print('当前epoch: {} [{}/{} ({:.0f}%)]\t损失: {:.6f}\t训练集准确率: {:.2f}%\t测试集正确率: {:.2f}%'.format(epoch, batch_idx * batch_size, len(train_loader.dataset),100. * batch_idx / len(train_loader),loss.data,100. * train_r[0].numpy() / train_r[1],100. * val_r[0].numpy() / val_r[1]))

实现结果
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/179094.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Linux Nmap命令解析(Nmap指令)(功能:主机发现、ping扫描、arp扫描、端口扫描、服务版本检测、操作系统识别等)

文章目录 Linux Nmap 命令解析简介Nmap 的核心功能主机发现端口扫描服务版本检测OS 指纹识别(操作系统指纹识别)脚本扫描 安装 NmapNmap 命令结构Nmap 命令文档英文中文 主机发现Ping 扫描ARP 扫描关于nmap -PR(ARP Ping Scan)和n…

CentOS7.9虚拟机EDA环境,支持模拟集成电路、数字集成电路、数模混合设计全流程,包含工艺库

目录 前言一、配置准备工作1.1 网盘文件说明1.2 EDA工具介绍 二、虚拟机运行2.1 虚拟机工具启动2.2 软件配置使用2.3 Module工具切换环境变量和软件版本 获取方法附录:部分EDA工具运行效果图 前言 搭建了CentOS7.9虚拟机环境,工具包括但不限于&#xff…

json处理由fastjson换jackjson

fastjson没有jackjson稳定,所以换成jackjson来处理对象转json和json转对象问题。 首先下载jackjson包,三个都要引用 然后修改实现类 package JRT.Core.Util;import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.ja…

安防视频监控汇聚EasyNVR视频集中存储平台级联上级时下级未回复原因是什么?该如何解决?

安防监控系统EasyNVR视频云存储平台可实现设备接入、实时直播、录像、检索与回放、视频云存储、视频分发等视频能力服务,可覆盖全终端平台(pc、手机、平板等终端),在智慧工厂、智慧工地、智慧社区、智慧校园等场景中有大量落地应用…

unity3d地图、地面跟着NPC跑

清除烘焙后,再 将地图、地面的设置为非静态。只设置NPC的寻路路面为静态,再烘焙

VS2010配置opencv2.4.10

1.下载opencv2.4.10,百度网盘链接如下: 链接:https://pan.baidu.com/s/1UdoQJbRUEB_G2urT703xYQ 提取码:7lbd 2.运行opencv-2.4.10.exe,将文件提取到一个自定义目录里: 3.添加系统环境变量 在“系统变量…

持续集成交付CICD:GitLab Webhook触发Jenkins流水线

目录 一、实验 1.Jenkins远程下载GiaLab仓库代码 2.curl远程触发Jenkins流水线 3.GitLab Webhook触发Jenkins流水线 二、问题 1.GitLab配置Webhook时报错 一、实验 1.Jenkins远程下载GiaLab仓库代码 (1) Jenkins添加选项参数 (2)添加字符参数 (3)查看构建参数情况 (4)添…

C++ 背包理论基础01 + 滚动数组

背包问题的重中之重是01背包 01背包 有n件物品和一个最多能背重量为w 的背包。第i件物品的重量是weight[i],得到的价值是value[i] 。每件物品只能用一次,求解将哪些物品装入背包里物品价值总和最大。 每一件物品其实只有两个状态,取或者不…

桥接设计模式

package com.jmj.pattern.bridge;/*** 视频文件(实现化角色)*/ public interface VideoFile {void decode(String fileName); }package com.jmj.pattern.bridge;public class RmvFile implements VideoFile{Overridepublic void decode(String fileName) {System.out.println(&…

语文老师怎么和家长沟通

作为一位语文老师,深知教育不单单是传授知识,更是引导学生发展潜能,培养品格。而在这个过程中,与家长建立良好的沟通关系是至关重要的。 建立信任关系 与家长沟通的第一步是建立信任关系。作为老师,需要展现出专业、热…

堆排序(详解)

在上篇文章中,我们说利用堆的插入和删除也可以排序数据,但排序的只是堆里面的数组;同时每次排序数据都要单独写一个堆的实现,很不方便,这次就来着重讲讲如何使用堆排序。 1.建堆 给了你数据,要利用堆对数据…

开发定制化抖音票务小程序的技术解析

通过定制化抖音票务小程序,可以为用户提供更加个性化的活动体验,同时也为企业和品牌提供了更多的营销机会。 一、小程序开发框架的选择 在开发定制化抖音票务小程序之前,选择合适的小程序开发框架至关重要。目前,主流的小程序框…

Unity之ARFoundation如何实现BodyTracking人体跟踪

前言 ARBodyTracking,就是指通过手机AR扫描并精确的捕获人物的肢体部位的技术。如下图所示 这项技术目前是有苹果的ARKit提供,苹果的body tracking 功能需要使用配备 TrueDepth 摄像头的设备,配备 A12 仿生芯片、运行 iOS 13 或更高版本的设备,比如 iPhone X 及更新机型。…

【开源】基于JAVA的城市桥梁道路管理系统

项目编号: S 025 ,文末获取源码。 \color{red}{项目编号:S025,文末获取源码。} 项目编号:S025,文末获取源码。 目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块三、系统展示四、核心代码4.1 查询城市桥…

SpringBoot RestTemplate 的使用

一、简介 RestTemplate 在JDK HttpURLConnection、Apache HttpComponents、OkHttp等基础上&#xff0c;封装了更高级别的API&#xff0c;默认依赖JDK HttpURLConnection&#xff0c;连接方式默认长连接。 二、使用 2.1、引入依赖 <dependency><groupId>org.spri…

1-2-3图片的排列

目录 1.展示效果 2.基础方法源码展示 ①div部分展示 ②css部分展示 3.接口方法源码展示 scoped使用 1.展示效果 2.基础方法源码展示 ①div部分展示 <view class"container"> <view class"cover"> <im…

独家精品!git action发布electron成功的关键

首先来说git action真心是个坑爹货&#xff0c;使用起来太费劲了&#xff0c;各种报错一大堆。 再加上electron这个更坑爹的东西&#xff0c;二者合璧要把你累死一层皮。 昨天经过反复测试&#xff0c;通过无数次的失败&#xff0c;查找&#xff0c;试验&#xff0c;再失败&a…

在Linux中对Docker中的服务设置自启动

先在Linux中安装docker&#xff0c;然后对docker中的服务设置自启动。 安装docker 第一步&#xff0c;卸载旧版本docker。 若系统中已安装旧版本docker&#xff0c;则需要卸载旧版本docker以及与旧版本docker相关的依赖项。 命令&#xff1a;yum -y remove docker docker-c…

封装一个基于ThreeJS渲染基础模型的类,非常简单,可拖动可缩放

工作需求要求threeJS渲染一个模型以供可视化大屏展示&#xff0c;抛出模型精度不谈&#xff0c;只说业务实现 1.Three.JS的引入 ThreeJS官网地址:Three.js – JavaScript 3D Library 查看文档 中文切换及安装创建步骤 如果是自己研究学习用的&#xff0c;在官网安装完后&…

linux的基本指令

目录 ls指令&#xff1a; pwd指令&#xff1a; cd指令&#xff1a; touch指令&#xff1a; mkdir指令&#xff1a; rmdir指令: rm指令&#xff1a; man指令&#xff1a; mv指令&#xff1a; cat指令&#xff1a; more指令&#xff1a; less指令&#xff1a; head指…