Pytorch从零开始实战01

Pytorch从零开始实战——MNIST手写数字识别

文章目录

  • Pytorch从零开始实战——MNIST手写数字识别
    • 环境准备
    • 数据集
    • 模型选择
    • 模型训练
    • 可视化展示

环境准备

本系列基于Jupyter notebook,使用Python3.7.12,Pytorch1.7.0+cu110,torchvision0.8.0,需读者自行配置好环境且有一些深度学习理论基础。

导入需要用到的包

import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import torchvision
import torch.nn.functional as F
import random
from time import time
import random
import numpy as np
import pandas as pd
import datetime
import gc
import os
os.environ['KMP_DUPLICATE_LIB_OK']='True'  # 用于避免jupyter环境突然关闭
torch.backends.cudnn.benchmark=True  # 用于加速GPU运算的代码

创建设备对象

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type=‘cuda’)

设置随机数种子

torch.manual_seed(428)
torch.cuda.manual_seed(428)
torch.cuda.manual_seed_all(428)
random.seed(428)
np.random.seed(428)

数据集

本次实战使用MNIST数据集,这是一个包含了手写数字的灰度图像的数据集,每个图像都是28x28像素大小,并且标记了相应的数字,也是很多计算机视觉初学者第一个使用的数据集。

导入训练集与测试集,使用torchvision.datasets可以在线下载很多常见数据集,只需要将后面参数设置download=True即可直接下载,train=True为训练集,train=False为测试集

# 导入训练集和测试集
train_data = torchvision.datasets.MNIST('data', train=True, transform=torchvision.transforms.ToTensor(),download=True)
test_data = torchvision.datasets.MNIST('data', train=False, transform=torchvision.transforms.ToTensor(),download=True)

定义一个函数,随机查看5张图片

# 随机展示5个图片 data = torchvision.datasets....  需要接受tensor格式的对象
def plotsample(data):fig, axs = plt.subplots(1, 5, figsize=(10, 10)) #建立子图for i in range(5):num = random.randint(0, len(data) - 1) #首先选取随机数,随机选取五次#抽取数据中对应的图像对象,make_grid函数可将任意格式的图像的通道数升为3,而不改变图像原始的数据#而展示图像用的imshow函数最常见的输入格式也是3通道npimg = torchvision.utils.make_grid(data[num][0]).numpy()nplabel = data[num][1] #提取标签 #将图像由(3, weight, height)转化为(weight, height, 3),并放入imshow函数中读取axs[i].imshow(np.transpose(npimg, (1, 2, 0))) axs[i].set_title(nplabel) #给每个子图加上标签axs[i].axis("off") #消除每个子图的坐标轴plotsample(train_data)

在这里插入图片描述

使用DataLoder将它按照batch_size批量划分,并将训练集顺序打乱。

batch_size = 32
train_dl = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_dl = torch.utils.data.DataLoader(test_data, batch_size=batch_size)

模型选择

由于数据集较为简单,所以本次实验使用简单的卷积神经网络。

第一次卷积和池化:
self.conv1 是第一个卷积层,将输入特征图的通道数从1增加到32,同时使用3x3的卷积核进行卷积。由于没有填充(padding)操作,卷积后的特征图大小减小为原来的大小减2(28x28 -> 26x26)。
self.pool1 是第一个最大池化层,将特征图的大小减半,从26x26变为13x13。
第二次卷积和池化:
self.conv2 是第二个卷积层,将输入特征图的通道数从32增加到64,同样使用3x3的卷积核进行卷积。由于没有填充操作,卷积后的特征图大小再次减小为原来的大小减2(13x13 -> 11x11)。
self.pool2 是第二个最大池化层,将特征图的大小再次减半,从11x11变为5x5。
全连接层:
在进入全连接层之前,需要将最后一个池化层的输出拉平成一个一维向量。这是通过 torch.flatten(x, start_dim=1) 完成的,它将5x5x64的三维张量转换为长度为5x5x64 = 1600的一维向量。
然后,self.fc1 是第一个全连接层,将1600个输入特征映射到64个输出特征。
最后进行10分类输出结果。

num_classes = 10 # 10分类
class Model(nn.Module):def __init__(self):super().__init__()self.conv1 = nn.Conv2d(1, 32, kernel_size=3)self.pool1 = nn.MaxPool2d(2)self.conv2 = nn.Conv2d(32, 64, kernel_size=3)self.pool2 = nn.MaxPool2d(2)self.fc1 = nn.Linear(1600, 64)self.fc2 = nn.Linear(64, num_classes)def forward(self, x):x = self.pool1(F.relu(self.conv1(x)))x = self.pool2(F.relu(self.conv2(x)))x = torch.flatten(x, start_dim=1) # 拉平x = F.relu(self.fc1(x))x = self.fc2(x)return x

将模型转移到GPU中,并使用summary查看模型

from torchinfo import summary
# 将模型转移到GPU中
model = Model().to(device)
summary(model)

在这里插入图片描述

模型训练

定义损失函数、学习率、优化算法

loss_fn = nn.CrossEntropyLoss()
learn_rate = 0.01
opt = torch.optim.SGD(model.parameters(), lr=learn_rate)

定义训练函数,返回一个epoch的模型的准确率和损失

def train(dataloader, model, loss_fn, optimizer):size = len(dataloader.dataset)num_batches = len(dataloader)train_loss, train_acc = 0, 0for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)loss = loss_fn(pred, y)optimizer.zero_grad()loss.backward()optimizer.step()train_acc += (pred.argmax(1) == y).type(torch.float).sum().item()train_loss += loss.item()train_acc /= sizetrain_loss /= num_batchesreturn train_acc, train_loss

定义测试函数,与训练函数类似,只是停止梯度更新,节省计算内存消耗

def test (dataloader, model, loss_fn):size = len(dataloader.dataset) num_batches = len(dataloader)         test_loss, test_acc = 0, 0with torch.no_grad():for X, target in dataloader:X, target = X.to(device), target.to(device)pred = model(X)loss = loss_fn(pred, target)test_acc += (pred.argmax(1) == target).type(torch.float).sum().item()test_loss += loss.item()test_acc /= sizetest_loss /= num_batchesreturn test_acc, test_loss

开始训练,一共进行了5轮epoch,最后在训练集准确率可达97.7%,测试集准确率可达98.1%

epochs = 5
train_loss = []
train_acc = []
test_loss = []
test_acc = []for epoch in range(epochs):model.train()epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, opt)model.eval() # 确保模型不会进行训练操作epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)train_acc.append(epoch_train_acc)train_loss.append(epoch_train_loss)test_acc.append(epoch_test_acc)test_loss.append(epoch_test_loss)print("epoch:%d, train_acc:%.1f%%, train_loss:%.3f, test_acc:%.1f%%, test_loss:%.3f"% (epoch + 1, epoch_train_acc * 100, epoch_train_loss, epoch_test_acc * 100, epoch_test_loss))
print("Done")

可视化展示

使用matplotlib进行训练、测试的可视化

plt.rcParams['font.sans-serif']    = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False      # 用来正常显示负号
plt.rcParams['figure.dpi']         = 100        #分辨率epochs_range = range(epochs)plt.figure(figsize=(12, 3))
plt.subplot(1, 2, 1)plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/69366.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【二等奖方案】大规模金融图数据中异常风险行为模式挖掘赛题「冀科数字」解题思路

第十届CCF大数据与计算智能大赛(2022 CCF BDCI)已圆满结束,大赛官方竞赛平台DataFountain(简称DF平台)正在陆续释出各赛题获奖队伍的方案思路,欢迎广大数据科学家交流讨论。 本方案为【大规模金融图数据中…

Android10 SystemUI系列(一)概述

一、前言 由于笔者之前负责过SystemUI,之前没有抽空把很多东西整理出来,趁着最近不太忙,就慢慢动手梳理一下,顺便把自己遇到的问题也整理一下,当然自己之前主要看的是android11 之后的源码。这次主要是Android10 的源码,当然原理大差不差,也算是自己沉淀一下了 二、Sy…

Github 下载指定文件夹(git sparse-checkout)

比如要下载这里的 data_utils 步骤 1、新建空文件夹,并进入新建的空文件夹。 2、git init 初始化 3、git remote add origin 添加远程仓库 4、git config core.sparsecheckout true 允许稀疏检出 5、git sparse-checkout set 设置需要拉取的文件夹(可…

面试问题记录一 --- C++(Qt方向)

以下是我于2023年6~7月间换工作时遇到的面试题目,有需要的小伙伴可以参考下。约100个题目。 1 C和C++的区别 1) 文件区别:C源文件后缀 .c;C++源文件后缀 .cpp 2) 返回值: C默认返回int型;C++ 若无返回值,必须指定为void 3) 参数列表:C默认接收多个…

zookeeper-3.6.4集群搭建

1、上传zookeeper安装包并解压 上传路径:/opt/software/ 解压路径:/opt/module/ 2、创建数据目录及日志目录 #数据目录:/data/zookeeper/data/ #3台机器创建存储目录: sudo mkdir -p /data/zookeeper/data#日志目录&#xff1a…

Docker Desktop 设置镜像环境变量

点击run 展开Optional settings container name :容器名称 Ports:根据你需要的端口进行输入,不输入则默认 后面这个 比如我这个 5432 Volumes:卷,也就是做持久化 需要docker 数据保存的地方 Environment variables…

Mysql中having语句与where语句的用法与区别

分析&回答 我们在写sql语句的时候,经常会使用where语句,很少会用到having,其实在mysql中having子句也是设定条件的语句与where有相似之处但也有区别。having子句在查询过程中慢于聚合语句(sum,min,max,avg,count)。而where子句在查询过程中则快于聚合语句(sum,min,max,avg…

解决C++ 遇笔试题输入[[1,2,3,...,],[5,6,...,],...,[3,1,2,...,]]问题

目录 0 引言1 思路2 测试结果3 完整代码4 总结 0 引言 现在面临找工作问题,做了几场笔试,遇到了一个比较棘手的题目就是题目输入形式如下: [ [3,1,1], [3,5,3], [3,2,1] ] 当时遇到这个问题还是比较慌的,主要是之前没有遇到这样的…

【STM32】锁存器

问题背景 在学习FSMC控制外部NOR存储器时,看到在NOR复用接口模式下,AD信号[15:0]是复用的。也就是说,若不使用锁存器:当NADV为低时,ADx(x0…15)上出现地址信号Ax,当NADV变高时,ADx上出现数据信号Dx。若使用…

9.3.3网络原理(网络层IP)

一.报文: 1.4位版本号:IPv4和IPv6(其它可能是实验室版本). 2.4位首部长度:和TCP一样,可变长,带选项,单位是4字节. 3.8位服务类型 4.16位总长度:IP报头 IP载荷 传输层是不知道载荷长度的,需要网络层来计算. IP报文 - IP报头 IP载荷 TCP报文 TCP载荷 IP载荷(TCP报文) …

Golang编写客户端SDK,并开源发布包到GitHub,供其他项目import使用

目录 编写客户端SDK,并开源发布包到GitHub1. 创建 GitHub 仓库2. 构建项目,编写代码Go 代码示例:项目目录结构展示: 3. 提交代码到 GitHub仓库4. 发布版本5. 现在其他人可以引用使用你的模块包了 编写客户端SDK,并开源…

Vue项目案例-头条新闻

目录 1.项目介绍 1.1项目功能 1.2数据接口 1.3设计思路 2.创建项目并安装依赖 2.1创建步骤 2.2工程目录结构 2.3配置文件代码 3.App主组件开发 3.1设计思路 3.2对应代码 4.共通组件开发 4.1设计思路 4.2对应代码 5.头条新闻组件开发 5.1设计思路 5.2对应代码 …

Xcode打包ipa文件,查看app包内文件

1、Xcode发布ipa文件前,在info中打开如下两个选项,即可在手机上查看app包名文件夹下的文件及数据。

postman9.12.汉化版(附有下载链接)

想用英文版本的可以直接点击下载最新版本 这里直接付上9.12.2版本的下载链接,如果大家要下载别的版本,可以直接修改链接里面的版本号即可 ,下面是汉化包下载 链接:https://pan.baidu.com/s/1izK3HfqlfXJdq6KIYeJ2zw?pwdpetk 提…

【数据结构】2015统考真题 6

题目描述 【2015统考真题】求下面的带权图的最小(代价)生成树时,可能是Kruskal算法第2次选中但不是Prim算法(从v4开始)第2次选中的边是(C) A. (V1, V3) B. (V1, V4) C. (V2, V3) D. (V3, V4) …

划分Vlan时需要注意的问题

网络部分2019年才开始学习的,在学习过程中配置了整个公司的网络,心里才有了一点把握,算是掌握了最基本的。 不会的就上网学,反正网络上什么知识都有,只要有需求就对照着学,很长时间没有学习网络了&#xff…

567. 字符串的排列

我写了首诗&#xff0c;把滑动窗口算法变成了默写题 | labuladong 的算法小抄 (gitee.io) windows放窗口里需要统计的元素 class Solution { public:bool checkInclusion(string s1, string s2) {int left 0;int right 0;int flag 0;map<char, int> need;for (int …

【计算机组成 课程笔记】5.1 处理器的设计步骤

课程链接&#xff1a; 计算机组成_北京大学_中国大学MOOC(慕课) 5 - 1 - 501-处理器的设计步骤&#xff08;14-49--&#xff09;_哔哩哔哩_bilibili 处理器&#xff0c;或者说是CPU&#xff0c;是现代计算机中最为复杂的一个部件。不过先不要劝退&#xff0c;要设计一个简单但是…

如何检测勒索软件攻击

什么是勒索软件 勒索软件又称勒索病毒&#xff0c;是一种特殊的恶意软件&#xff0c;又被归类为“阻断访问式攻击”&#xff08;denial-of-access attack&#xff09;&#xff0c;与其他病毒最大的不同在于攻击方法以及中毒方式。 攻击方法&#xff1a;攻击它采用技术手段限制…

若依 MyBatis改为MyBatis-Plus

主要内容&#xff1a;升级成mybatis-plus&#xff0c;代码生成也是mybatis-plus版本 跟着我一步一步来&#xff0c;就可完成升级&#xff01; 检查&#xff1a;启动程序&#xff0c;先保证若依能启动 第一步&#xff1a;添加依赖 这里需要在两个地方添加&#xff0c;一个是最…