简单搭建卷积神经网络实现手写数字10分类

搭建卷积神经网络实现手写数字10分类

1.思路流程

1.导入minest数据集

2.对数据进行预处理

3.构建卷积神经网络模型

4.训练模型,评估模型

5.用模型进行训练预测

一.导入minest数据集

 

MNIST--->raw--->test-->(0,1,2...) 10个文件夹

MNIST--->raw--->train-->(0,1,2...) 10个文件夹

共60000张图片.可自己去网上下载

二.对数据进行预处理

----读取图片,将图片先转为张量。

img=cv2.imread(path)

----将图片进行归一化,即将像素值标准化到0-1之间

img_tensor=util.transforms_train(img)

----裁剪,翻转等,实现数据增强。

数据增强:通过对原始图像进行旋转、翻转等操作,可以增加数据的多样性。这有助于模型学习到更具泛化性的特征,减少对特定方向或位置的依赖,从而提高模型的鲁棒性和准确性

transforms_train=transforms.Compose([# transforms.CenterCrop(10),# transforms.PILToTensor(),transforms.ToTensor(),#归一化,转tensortransforms.Resize((28,28)),transforms.RandomVerticalFlip()
])

ps:为什么要归一化

  1. 消除量纲影响:不同图像的像素值范围可能差异很大。归一化可以将像素值范围统一到一个特定的区间,例如 [0, 1] 或 [-1, 1],消除不同图像之间因像素值范围差异带来的影响,使模型更关注图像的特征和结构,而不是像素值的绝对大小。

  2. 提高训练稳定性:有助于优化算法的收敛性和稳定性。如果像素值范围较大且分布不均匀,可能导致梯度计算不稳定,从而影响模型的训练效率和效果。

  3. 缓解过拟合:一定程度上可以减少数据中的噪声和异常值对模型的影响,降低模型对某些特定像素值的过度依赖,从而提高模型的泛化能力,减少过拟合的风险。

三.构建卷积神经网络模型

常见卷积神经网络(CNN),主要由卷积,池化,全连接组成。卷积核在输入图像上滑动,通过卷积运算提取局部特征。卷积核在整个图像上重复使用,大大减少了模型的参数数量,降低了计算复杂度,同时也增强了模型对平移不变性的鲁棒性。池化层对特征进行压缩,提取主要特征,减少噪声和冗余信息。

x=torch.randn(2,3,28,28)

用x表示初始图形的信息。为了简单理解,简单表述。其中

2--->两张图片

3--->图片的通道数是3个,即 RGB

28,28--->图片的宽高是28px 28px

采用以上的神经网络conv为卷积操作,maxpool为池化。Linear为全连接。relu为激活函数。

进入全连接层时需要将展平。torch.Size([2, 16, 5, 5])--->torch.Size([2, 400])

x=torch.flatten(x,1)

因为全连接是只进行的线性的变化。所以要把每张图片的维数参数降为1。

使用print(summary(net, x))可查看网络的层次结构。其中-1就表示自己算,是多少张图片就是多少

输入的的是x=torch.randn(2,3,28,28),最终输出的是(2,10)

四.训练模型,评估模型

需要初始化之前的数据和网络,然后选择合适的优化器和损失函数,学习率和加载图片的批次去训练模型。使用loss_avg和accurary来评估模型的性能。对于pytorch来说优化器可以实现自动梯度清0,自动更新参数。我们需要主要的是就实现其中的维度的转化。loss越小越接近真实值。其中计算精度的方法使用one-hot编码。其中0表示[0,0,0,0,0,0,0,0,0,0],1表示[0,1,0,0,0,0,0,0,0,0],2表示[0,0,1,0,0,0,0,0,0,0].。。。其他依次类推。我们把用网络得出的参数,类似[0.1,0.2,0.1,0.5,0,0,0,0,0,0](数据我随便写的),然后用Python的argmax去处最大值的索引与one-hot真实值的索引相比,如果相等就是正确的结果。

----本次实验使用的是MSE损失函数

----lr(学习率)设为0.01

----使用的优化器Adam ,其实其他优化器你也可以随便试试。

Adam 算法的主要优点包括:

  1. 自适应学习率:能够为每个参数自适应地调整学习率。

  2. 偏差校正:在初始阶段对梯度估计进行校正,加速初期的学习速率。

  3. 适应性强:在很多不同的模型和数据集上都表现出良好的性能。

  4. 实现简单,计算高效,对内存需求少。

使用tensorboard进行可视化

五.用模型进行训练预测

需要读取之前训练好的模型,然后用这个模型来实现预测一个自己手写的图片

    # 加载整个模型loaded_model = torch.load('whole_model.pth')
​# 保存模型参数torch.save(loaded_model.state_dict(),'model_params.pth')

代码附上:

dataset.py

import glob
import os.path
​
import cv2
import torch
import util
​
class DataAndLabel:def __init__(self,path='D:\\0MNIST\\raw',is_train=True):super().__init__()# 拼接路径#data里面是path,labelclas='train' if is_train==True else 'test'path=os.path.join(path,clas)paths=glob.glob(os.path.join(path,'*','*'))# print(paths)# print(path)self.data=[]for path in paths:label=int(path.split('\\')[-2])self.data.append((path,label))def __getitem__(self, idx):#返回一个tensor,one-hotpath,label =self.data[idx]img=cv2.imread(path)# cv2.cvtColor(img,cv2.COLOR_BGR2GRAY)img_tensor=util.transforms_train(img)one_hot=torch.zeros(10)one_hot[label]=1return img_tensor,one_hotdef __len__(self):return len(self.data)
# if __name__ == '__main__':
#     data=DataAndLabel()
#     print(data[0])
#     print()

lenet5.py

import torch
import torch.nn as nn
from torchkeras import summary
class Net(nn.Module):def __init__(self):super().__init__()self.conv1=nn.Conv2d(3,6,5,1)self.maxpool1=nn.MaxPool2d(2)self.conv2=nn.Conv2d(6,16,3,1)self.maxpool2=nn.MaxPool2d(2)self.layer1=nn.Linear(16*5*5,10)self.layer2=nn.Linear(10,10)self.relu=nn.Softmax()def forward(self,x):x=self.conv1(x)x=self.relu(x)x=self.maxpool1(x)x=self.conv2(x)x=self.relu(x)x=self.maxpool2(x)# print(x.shape)x=torch.flatten(x,1)# print(x.shape)
​x=self.layer1(x)x=self.layer2(x)return x
if __name__ == '__main__':x=torch.randn(2,3,28,28)net=Net()out=net(x)# print(out.shape)# print(summary(net, x))

train_and_test

import torch
import tqdm
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from lenet5 import Net
import torch.nn as nn
from dataset import DataAndLabel
class TrainAndTest(Dataset):def __init__(self):super().__init__()# self.writer=SummaryWriter("logs")net=Net()self.net=netself.loss=nn.MSELoss()self.opt = torch.optim.Adam(net.parameters(), lr=0.1)self.train_data=DataAndLabel(is_train=True)self.test_data=DataAndLabel(is_train=False)self.train_loader=DataLoader(self.train_data,batch_size=100,shuffle=False)self.test_loader=DataLoader(self.test_data,batch_size=100,shuffle=False)# 拿到数据,网络def train(self,epoch):loss_sum = 0accurary_sum = 0for img_tensor, label in tqdm.tqdm(self.train_loader, desc='train...', total=len(self.train_loader)):out = self.net(img_tensor)loss = self.loss(out, label)self.opt.zero_grad()loss.backward()self.opt.step()loss_sum += loss.item()accurary_sum += torch.mean(torch.eq(torch.argmax(label, dim=1), torch.argmax(out, dim=1)).to(torch.float32)).item()loss_avg = loss_sum / len(self.train_loader)accurary_avg = accurary_sum / len(self.train_loader)print(f'train---->loss_avg={round(loss_avg, 3)},accurary_avg={round(accurary_avg, 3)}')# self.writer.add_scalars('loss',{'loss_avg':loss_avg},epoch)def train1(self):sum_loss = 0sum_acc = 0for img_tensors, targets in tqdm.tqdm(self.train_loader, desc="train...", total=len(self.train_loader)):out = self.net(img_tensors)loss = self.loss(out, targets)self.opt.zero_grad()loss.backward()self.opt.step()sum_loss += loss.item()pred_cls = torch.argmax(out, dim=1)target_cls = torch.argmax(targets, dim=1)accuracy =torch.mean(torch.eq(pred_cls, target_cls).to(torch.float32))sum_acc += accuracy.item()avg_loss = sum_loss / len(self.train_loader)avg_acc = sum_acc / len(self.train_loader)print(f'train:loss{round(avg_loss, 3)} acc:{round(avg_acc, 3)}')
​
​def run(self):for epoch in range(10):self.train1()# self.test(epoch)
if __name__ == '__main__':tt=TrainAndTest()tt.run()

util.py

from torchvision import transforms
​
transforms_train=transforms.Compose([# transforms.CenterCrop(10),# transforms.PILToTensor(),transforms.ToTensor(),#归一化,转tensortransforms.Resize((28,28)),transforms.RandomVerticalFlip()
])
transforms_test=transforms.Compose([transforms.ToTensor(),  # 归一化,转tensortransforms.Resize((28, 28)),
])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/46244.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

VRRP虚拟路由冗余技术

VRRP虚拟路由冗余技术:是一种路由容错协议,用于在网络中提供路由器的冗余备份。它通过将多个路由器虚拟成一个虚拟路由器并且多个路由器之间共享一个虚拟IP地址来实现冗余和高可用性。当承担转发业务的主路由器出现故障时,其他备份路由器可以…

安全防御:防火墙概述

目录 一、信息安全 1.1 恶意程序一般会具备一下多个或全部特点 1.2 信息安全五要素: 二、了解防火墙 2.1 防火墙的核心任务 2.2 防火墙的分类 2.3 防火墙的发展历程 2.3.1 包过滤防火墙 2.3.2 应用代理防火墙 2.3.3 状态检测防火墙 补充防御设备 三、防…

骑士人才系统74cms专业版实现本地VUE打包和在线升级方法以及常见问题

骑士人才系统我就不多说了目前来说我接触的人才系统里面除了phpyun就是骑士人才了,两个历史都很悠久,总起来说功能方面各分伯仲,前几期我作过Phpyun的配置教程这次我们针对骑士人才系统说说怎么使用VUE源码本地一键打包后台和在线升级方式&am…

每日Attention学习10——Scale-Aware Modulation

模块出处 [ICCV 23] [link] [code] Scale-Aware Modulation Meet Transformer 模块名称 Scale-Aware Modulation (SAM) 模块作用 改进的自注意力 模块结构 模块代码 import torch import torch.nn as nn import torch.nn.functional as Fclass SAM(nn.Module):def __init__…

redisTemplate报错为nil,通过redis-cli查看前缀有乱码

public void set(String key, String value, long timeout) {redisTemplate.opsForValue().set(key, value, timeout, TimeUnit.SECONDS);} 改完之后 public void set(String key, String value, long timeout) {redisTemplate.setKeySerializer(new StringRedisSerializer()…

opencascade AIS_InteractiveContext源码学习8 trihedron display attributes

AIS_InteractiveContext 前言 交互上下文(Interactive Context)允许您在一个或多个视图器中管理交互对象的图形行为和选择。类方法使这一操作非常透明。需要记住的是,对于已经被交互上下文识别的交互对象,必须使用上下文方法进行…

最优化(10):牛顿类、拟牛顿类算法

4.4 牛顿类算法——介绍了经典牛顿法及其收敛性,并介绍了修正牛顿法和非精确牛顿法; 4.5 拟牛顿类算法——引入割线方程,介绍拟牛顿算法以及拟牛顿矩阵更新方式,然后给出了拟牛顿法的全局收敛性,最后介绍了有限内存BFG…

Java中创建线程的方式

文章目录 创建线程ThreadRunnableCallable线程池创建方式自定义线程池线程池工作原理阻塞队列线程池参数合理配置线程池参数 创建线程 在Java中创建一个线程,有且仅有一种方式,创建一个Thread类实例,并调用它的start方法。 Thread 最经典也…

在Linux上设置MySQL允许远程连接的完整指南

个人名片 🎓作者简介:java领域优质创作者 🌐个人主页:码农阿豪 📞工作室:新空间代码工作室(提供各种软件服务) 💌个人邮箱:[2435024119qq.com] &#x1f4f1…

【Linux】多线程_6

文章目录 九、多线程7. 生产者消费者模型生产者消费者模型的简单代码结果演示 未完待续 九、多线程 7. 生产者消费者模型 生产者消费者模型的简单代码 Makefile: cp:Main.ccg -o $ $^ -stdc11 -lpthread .PHONY:clean clean:rm -f cpThread.hpp: #i…

【Linux】Linux必备的基础指令

目录 Linux必备的基础指令一 、 什么是Linux二、 Linux常用命令2.1 ls2.2 pwd2.3 cd2.4 touch2.5 cat2.6 mkdir2.7 rm 三、 Linux重要指令3.1 cp3.2 mv3.3 tail3.4 vim3.5 grep3.6 ps3.7 netstat Linux必备的基础指令 一 、 什么是Linux 系统编程&⽹络编程 Linux⾃⾝提供…

快速掌握块级盒子垂直水平居中的几种方式

大家好!今天我们来聊聊Web开发中如何实现块级盒子水平居中。在前端开发中,经常需要将一个块级盒子在父元素中进行垂直和水平居中对齐,本文将介绍几种常见且高效的实现方法。 一、子元素有固定宽高 第一种情况 子元素有固定宽高(…

编译x-Wrt 全过程

参考自;​​​​​​c编译教程 | All about X-Wrt 需要详细了解的小伙伴还请参看原文 ^-^ 概念: x-wrt(基于openwrt深度定制的发行版本) 编译系统: ubuntu22.04 注意: 特别注意的是,整个编译过程,都是用 …

汽车的驱动力,是驱动汽车行驶的力吗?

一、地面对驱动轮的反作用力? 汽车发动机产生的转矩,经传动系传至驱动轮上。此时作用于驱动轮上的转矩Tt产生一个对地面的圆周力F0,地面对驱动轮的反作用力Ft(方向与F0相反)即是驱动汽车的外力,此外力称为汽车的驱动力。 即汽车…

知识图谱研究综述笔记

推荐导读:知识图谱Knowledge Graph Embeddings 论文标题:A Survey on Knowledge Graphs:Representation, Acquisition and Applications发表期刊:IEEE TRANSACTIONS ON NEURAL NETWORKS AND LEARNING SYSTEMS, 2021本文作者:Shaoxiong Ji, Shirui Pan, M…

Swiper轮播图实现

如上图,列表左右滚动轮播,用户鼠标移动到轮播区域,动画停止,鼠标移开轮播继续。 此例子实现技术框架是用的ReactCSS。 主要用的是css的transform和transition来实现左右切换动画效果。 React代码: import React, { us…

二叉树六道基本习题,你都会了吗?

Hello大家好呀,本博客目的在于记录暑假学习打卡,后续会整理成一个专栏,主要打算在暑假学习完数据结构,因此会发一些相关的数据结构实现的博客和一些刷的题,个人学习使用,也希望大家多多支持,有不…

手把手教你写UART(verilog)

最近工作用uart用的比较多,为了让自己更好的掌握这个协议,写了这篇文章,解读了uart程序的编写过程(程序参考了米联客的教程)。 最基础的概念 UART是用来让两个设备之间传输数据的协议,毕竟我不能直接给你一…

鸿蒙HarmonyOS应用开发为何选择ArkTS不是Java?

前言 随着智能设备的快速发展,操作系统的需求也变得越来越多样化。为了满足不同设备的需求,华为推出了鸿蒙HarmonyOS。 与传统的操作系统不同,HarmonyOS采用了一种新的开发语言——ArkTS。 但是,刚推出鸿蒙系统的时候&#xff0…

JavaScript进阶(四)---js解构

目录 一.定义: 二.类型: 1.数组解构: 1.1变量和值不匹配的情况 1.2多维数组 2.对象解构 3.对象数组解构 4.函数参数解构 5.扩展运算符 一.定义: JavaScript 中的解构(Destructuring)是一种语法糖&…