深度学习每周学习总结P1(pytorch手写数字识别)

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊 | 接辅导、项目定制

目录

    • 0. 总结
    • 1. 数据导入部分
    • 2. 模型构建部分
    • 3. 训练前的准备
    • 4. 定义训练函数
    • 5. 定义测试函数
    • 6. 训练过程

0. 总结

总结:

  • 数据导入部分:数据导入使用了torchvision自带的数据集,获取到数据后需要使用torch.utils.data中的DataLoader()加载数据

  • 模型构建部分:有两个部分一个初始化部分(init())列出了网络结构的所有层,比如卷积层池化层等。第二个部分是前向传播部分,定义了数据在各层的处理过程。

  • 训练前的准备:在这之前需要定义损失函数,学习率,以及根据学习率定义优化器(例如SGD随机梯度下降),用来在训练中更新参数,最小化损失函数。

  • 定义训练函数:函数的传入的参数有四个,分别是设置好的DataLoader(),定义好的模型,损失函数,优化器。函数内部初始化损失准确率为0,接着开始循环,使用DataLoader()获取一个批次的数据,对这个批次的数据带入模型得到预测值,然后使用损失函数计算得到损失值。接下来就是进行反向传播以及使用优化器优化参数,梯度清零放在反向传播之前或者是使用优化器优化之后都是可以的。将 optimizer.zero_grad() 放在了每个批次处理的开始,这是最标准和常见的做法。这样可以确保每次迭代处理一个新批次时,梯度是从零开始累加的。准确率是通过累计预测正确的数量得到的,处理每个批次的数据后都要不断累加正确的个数,最终的准确率是由预测正确的数量除以所有样本得数量得到的。损失值也是类似每次循环都累计损失值,最终的损失值是总的损失值除以训练批次得到的

  • 定义测试函数:函数传入的参数相比训练函数少了优化器,只需传入设置好的DataLoader(),定义好的模型,损失函数。此外除了处理批次数据时无需再设置梯度清零、返向传播以及优化器优化参数,其余部分均和训练函数保持一致。

  • 训练过程:定义训练次数,有几次就使用整个数据集进行几次训练,初始化四个空list分别存储每次训练及测试的准确率及损失。使用model.train()开启训练模式,调用训练函数得到准确率及损失。使用model.eval()将模型设置为评估模式,调用测试函数得到准确率及损失。接着就是将得到的训练及测试的准确率及损失存储到相应list中并合并打印出来,得到每一次整体训练后的准确率及损失。

  • 模型的保存,调取及使用,暂时没有看到这部分,但是训练好的模型肯定是会用到这步的,需要自己添加进去。在PyTorch中,通常使用 torch.save(model.state_dict(), ‘model.pth’) 保存模型的参数,使用 model.load_state_dict(torch.load(‘model.pth’)) 加载参数。

1. 数据导入部分

import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import torchvisiondevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")device
device(type='cpu')
print(torch.__version__) # 查看pytorch版本
1.9.0+cpu

# 导入数据
train_ds = torchvision.datasets.MNIST('data', train = True,transform = torchvision.transforms.ToTensor(),download = True
)
test_ds = torchvision.datasets.MNIST('data',train = False,transform = torchvision.transforms.ToTensor(),download = True
)
batch_size = 32# shuffle = True 意味着每次迭代数据集时,数据都会被随机打乱。
# 因此,当您从train_dl中获取一个批次的数据时,每次都可能得到不同的图像
train_dl = torch.utils.data.DataLoader(train_ds,batch_size=batch_size,shuffle = True
)
test_dl = torch.utils.data.DataLoader(test_ds,batch_size = batch_size
)
# 取一个批次查看数据格式
# 数据的shape为:[batch_size,channel,height,weight]
# 其中batch_size为自己设定,channel,height,weight分别对应图片的通道数,高度和宽度
imgs,labels = next(iter(train_dl)) # 由于数据加载器被设置为随机打乱数据(shuffle=True),因此每次调用next函数时,都会从数据集中随机选择一个批次的数据。
imgs.shape
torch.Size([32, 1, 28, 28])
import numpy as np#指定图片大小,图像大小为20宽,5高的绘图(单位为英寸inch)
plt.figure(figsize=(20,5))
for i,img in enumerate(imgs[:20]):# 维度缩减npimg = np.squeeze(img.numpy())plt.subplot(2,10,i+1) # 将整个figure分成2行10列,绘制第i+1个子图plt.imshow(npimg,cmap=plt.cm.binary)plt.axis('off') # 这行代码关闭了当前子图的坐标轴,使得图像没有任何坐标轴标签或刻度。

在这里插入图片描述

2. 模型构建部分

# 模型构建
import torch.nn.functional as Fnum_classes = 10 # 图片的类别数class Model(nn.Module):def __init__(self):super().__init__() # super(Model,self).__init__() 的简化写法# 特征提取网络self.conv1 = nn.Conv2d(1,32,kernel_size = 3) # 第一层卷积,卷积核大小为3*3self.pool1 = nn.MaxPool2d(2) # 设置池化层,池化核大小为2*2self.conv2 = nn.Conv2d(32,64,kernel_size = 3) # 第二层卷积,卷积核大小为3*3self.pool2 = nn.MaxPool2d(2)# 分类网络self.fc1 = nn.Linear(1600,64)self.fc2 = nn.Linear(64,num_classes)# 前向传播def forward(self,x):x = self.pool1(F.relu(self.conv1(x)))x = self.pool2(F.relu(self.conv2(x)))x = torch.flatten(x,start_dim=1) # x.view(x.size(0), -1) 展平张量x = F.relu(self.fc1(x))x = self.fc2(x)return x
# !pip install torchinfo -i https://pypi.mirrors.ustc.edu.cn/simple/
# 查看模型结构
model = Model()
model
Model((conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))(pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))(pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(fc1): Linear(in_features=1600, out_features=64, bias=True)(fc2): Linear(in_features=64, out_features=10, bias=True)
)
# 打印模型
from torchinfo import summarymodel = Model().to(device) # 在指定的设备(device,可能是CPU或CUDA/GPU)上实例化原始Model。summary(model) # 使用torchinfo库中的summary函数来打印模型的摘要。
=================================================================
Layer (type:depth-idx)                   Param #
=================================================================
Model                                    --
├─Conv2d: 1-1                            320
├─MaxPool2d: 1-2                         --
├─Conv2d: 1-3                            18,496
├─MaxPool2d: 1-4                         --
├─Linear: 1-5                            102,464
├─Linear: 1-6                            650
=================================================================
Total params: 121,930
Trainable params: 121,930
Non-trainable params: 0
=================================================================

3. 训练前的准备

# 训练模型# 设置超参数
loss_fn = nn.CrossEntropyLoss() # 创建损失函数
learn_rate = 1e-2 # 学习率
opt = torch.optim.SGD(model.parameters(),lr=learn_rate)

4. 定义训练函数

# 训练循环
def train(dataloader,model,loss_fn,optimizer):size = len(dataloader.dataset) # 训练集的大小,一共60000张图片num_batches = len(dataloader) # 批次数目,1875(60000/32)train_loss,train_acc = 0,0 # 初始化训练损失和正确率for X,y in dataloader: # 获取图片及其标签X,y = X.to(device),y.to(device)# 计算预测误差pred = model(X) # 网络输出loss = loss_fn(pred,y) # 计算网络输出值和真实值之间的差距,targets为真实值,计算二者差值即为损失# 反向传播optimizer.zero_grad() # grad属性归零loss.backward() # 反向传播optimizer.step() # 每一步自动更新# 记录acc与losstrain_acc += (pred.argmax(1) == y).type(torch.float).sum().item()train_loss += loss.item()train_acc /= sizetrain_loss /= num_batchesreturn train_acc,train_loss

5. 定义测试函数

# 编写测试函数
def test(dataloader,model,loss_fn):size = len(dataloader.dataset) # 测试集的大小,一共10000张图片num_batches = len(dataloader) # 批次数目,313(10000/32=312.5)test_loss,test_acc = 0,0# 当不进行训练时,停止梯度更新,节省计算内存消耗with torch.no_grad():for imgs,target in dataloader:imgs,target = imgs.to(device),target.to(device)# 计算losstarget_pred = model(imgs)loss = loss_fn(target_pred,target)test_loss += loss.item()test_acc += (target_pred.argmax(1) == target).type(torch.float).sum().item()test_acc /= sizetest_loss /= num_batchesreturn test_acc,test_loss

6. 训练过程

# 训练
epochs = 5
train_loss = []
train_acc = []
test_loss = []
test_acc = []for epoch in range(epochs):model.train()epoch_train_acc,epoch_train_loss = train(train_dl,model,loss_fn,opt)model.eval()epoch_test_acc,epoch_test_loss = test(test_dl,model,loss_fn)train_acc.append(epoch_train_acc)train_loss.append(epoch_train_loss)test_acc.append(epoch_test_acc)test_loss.append(epoch_test_loss)template = ('Epoch:{:2d},Train_acc:{:.1f}%,Train_loss:{:.3f},Test_acc:{:.1f}%,Test_loss:{:.3f}')print(template.format(epoch+1,epoch_train_acc*100,epoch_train_loss,epoch_test_acc*100,epoch_test_loss))print('Done')
C:\Users\chengyuanting\.conda\envs\pytorch_cpu\lib\site-packages\torch\nn\functional.py:718: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at  ..\c10/core/TensorImpl.h:1156.)return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)Epoch: 1,Train_acc:77.5%,Train_loss:0.753,Test_acc:0.2%,Test_loss:0.000
Epoch: 2,Train_acc:94.3%,Train_loss:0.190,Test_acc:0.1%,Test_loss:0.000
Epoch: 3,Train_acc:96.2%,Train_loss:0.125,Test_acc:0.2%,Test_loss:0.000
Epoch: 4,Train_acc:96.9%,Train_loss:0.098,Test_acc:0.2%,Test_loss:0.000
Epoch: 5,Train_acc:97.5%,Train_loss:0.081,Test_acc:0.2%,Test_loss:0.000
Done
# 结果可视化
import matplotlib.pyplot as plt
# 隐藏警告
import warnings
warnings.filterwarnings("ignore") # 忽略警告信息
plt.rcParams['axes.unicode_minus'] = False
plt.rcParams['figure.dpi'] = 100epochs_range = range(epochs)plt.figure(figsize=(12, 3))
plt.subplot(1, 2, 1)plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/748122.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Hive中的explode函数、posexplode函数与later view函数

1.概述 在离线数仓处理通过HQL业务数据时,经常会遇到行转列或者列转行之类的操作,就像concat_ws之类的函数被广泛使用,今天这个也是经常要使用的拓展方法。 2.explode函数 2.1 函数语法 -- explode(a) - separates the elements of array …

014 Linux_同步

​🌈个人主页:Fan_558 🔥 系列专栏:Linux 🌹关注我💪🏻带你学更多操作系统知识 文章目录 前言一、死锁(1)死锁概念 二、同步(1)同步概念&#xff…

计算机考研|双非一战135上岸,408经验分享+复盘

计算机专业的同学真的别想的太天真! 相比于其他专业,计算机专业的同学其实还是很有优势的 但是现在随着计算机专业的同学越来越多,找工作的困难程度以及学历自然而然被卷起来了 以前的算法岗基本要求在本科以上,现在基本都是非92研…

Spring Boot 多媒体(音频/视频)文件处理FFmpegFrameGrabber 方法(例子:获取视频总时长)

1.pom.xml 坐标 <dependency><groupId>org.bytedeco</groupId><artifactId>javacv-platform</artifactId><version>1.5.6</version></dependency> 2.FFmpegFrameGrabber类提供了多种方法来处理多媒体文件&#xff0c;以下是一…

主成分分析用于数据降维

主成分分析&#xff08;Principal Component Analysis&#xff0c;PCA&#xff09;是一种常用的数据降维算法。它通过线性变换将原始数据转化为一组互相不相关的主成分&#xff0c;这些主成分能够最大程度地保留原始数据的信息。 数据降维是为了减少数据集中的特征数量&#x…

家电工厂5G智能制造数字孪生可视化平台,推进家电工业数字化转型

家电5G智能制造工厂数字孪生可视化平台&#xff0c;推进家电工业数字化转型。随着科技的飞速发展&#xff0c;家电行业正迎来一场前所未有的数字化转型。在这场制造业数字化转型中&#xff0c;家电5G智能制造工厂数字孪生可视化平台扮演着至关重要的角色。本文将从数字孪生技术…

【数据库】数据库介绍

文章目录 一、数据库介绍二、SQL分类 一、数据库介绍 什么是数据库 存储数据用文件就可以了&#xff0c;为什么还要弄个数据库? 文件保存数据有以下几个缺点&#xff1a; 文件的安全性问题 文件不利于数据查询和管理 文件不利于存储海量数据 文件在程序中控制不方便 数据库存…

提交数据加快百度搜索引擎收录

百度站长工具做了更新&#xff0c;百度收录的地址分享如下&#xff0c;新站点提交后&#xff0c;可以加快百度收录。 普通收录_加快网站内容抓取&#xff0c;快速提交数据工具_站长工具_网站支持_百度搜索资源平台普通收录工具可实时向百度推送数据&#xff0c;创建并提交site…

git 安装、创建仓库、常用命令、克隆下载、上传项目、删除分支 -- 一篇文章总结

一、git安装 1、git安装地址&#xff1a;https://git-scm.com/downloads 2、选择操作系统 3、安装自己系统对应的操作位数 4、等待下载完&#xff0c;一路next安装就可以了 5、安装完成后&#xff0c;在任意文件夹点击右键&#xff0c;看到下图说明安装成功 二、创建仓库 1…

云原生(二)、Docker基础

Docker Docker 是一种开源的容器化平台&#xff0c;用于开发、部署和运行应用程序。它允许开发者将应用程序及其所有依赖项打包到一个可移植的容器中&#xff0c;这个容器可以在任何支持 Docker 的环境中运行&#xff0c;无论是开发人员的个人笔记本电脑、测试环境、生产服务器…

ROS Kinetic通信编程:话题、服务、动作编程

文章目录 一、话题编程二、服务编程三、动作编程 接上篇&#xff0c;继续学习ROS通信编程基础 一、话题编程 步骤&#xff1a; 创建发布者 初始化ROS节点向ROS Master注册节点信息&#xff0c;包括发布的话题名和话题中的消息类型按照一定频率循环发布消息 创建订阅者 初始化…

stm32-编码器测速

一、编码器简介 编码电机 旋转编码器 A,B相分别接通道一和二的引脚&#xff0c;VCC&#xff0c;GND接单片机VCC&#xff0c;GND 二、正交编码器工作原理 以前的代码是通过触发外部中断&#xff0c;然后在中断函数里手动进行计次。使用编码器接口的好处就是节约软件资源。对于频…

从0开始回顾MySQL---事务四大特性

事务概述 事务是一个最小的工作单元。在数据库当中&#xff0c;事务表示一件完整的事儿。一个业务的完成可能需要多条DML语句共同配合才能完成&#xff0c;例如转账业务&#xff0c;需要执行两条DML语句&#xff0c;先更新张三账户的余额&#xff0c;再更新李四账户的余额&…

实现elasticsearch和数据库的数据同步

1. 数据同步 elasticsearch中的酒店数据来自于mysql数据库&#xff0c;因此mysql数据发生改变时&#xff0c;elasticsearch也必须跟着改变&#xff0c;这个就是elasticsearch与mysql之间的数据同步。 1.1. 思路分析 常见的数据同步方案有三种&#xff1a; 同步调用 异步通知…

面试题手撕篇

参考博客 开始之前&#xff0c;理解递归 手写 浅拷贝 function shallow(target){if(target instanceof Array){return [...resObj]}else{return Object.assign({},target);} }手写深拷贝 const _sampleDeepClone target > {// 补全代码return JSON.parse(JSON.stringify…

EtherCAT开源主站 IGH 介绍及主站伺服控制过程

目录 前言 IGH EtherCAT主站介绍 主要特点和功能 使用场景 SOEM 主站介绍 SOEM 的特点和功能 SOEM 的使用场景 IGH 主站 和 SOEM对比 1. 功能和复杂性 2. 资源消耗和移植性 3. 使用场景 EtherCAT 通信原理 EtherCAT主站控制伺服过程 位置规划模式 原点复归模式…

八股文打卡day32——数据库(9)

面试题&#xff1a;MySQL日志文件有哪几种&#xff1f; 我的回答&#xff1a; 1.undo log&#xff0c;撤销日志&#xff0c;是InnoDB存储引擎层面生成的日志&#xff0c;主要用于数据库事务的回滚和MVCC&#xff08;多版本并发控制&#xff09;。可以帮助数据库回滚到事务开启…

Flinkcdc通过catalog同步mysql数据到hologres的ods中

Flinkcdc通过catalog同步mysql数据到hologres的ods中大致分为以下几步: 配置Flink CDC 的MySQL catalog:CREATE CATALOG mysqlsource WITH (type = mysql,hostname = xxxx,port = xxxx,username = xxxx<

AcWing 4964.子矩阵

首先就是运用了暴力的思路&#xff0c;能够过个70%的数据&#xff0c;剩下的直接时间超时了&#xff0c;没办法优化了。 讲一下暴力的思路&#xff1a; 其实就是模拟而已&#xff0c;也就是看作想要找的矩阵为一个小窗口&#xff0c;然后不断移动的事而已。 #include<ios…

服务器远程端口故障应该如何解决并且避免

在信息化快速发展的今天&#xff0c;服务器作为支撑各类业务运行的核心设备&#xff0c;其稳定性对于企业的运营至关重要。然而&#xff0c;即使是最高端的服务器也难免会出现问题&#xff0c;其中“服务器远程端口故障”就是一个相对常见但又十分棘手的问题。德迅云安全今天就…