现代卷积网络实战系列2:PyTorch构建训练函数、LeNet网络

🌈🌈🌈现代卷积网络实战系列 总目录

本篇文章的代码运行界面均在Pycharm中进行
本篇文章配套的代码资源已经上传

1、MNIST数据集处理、加载、网络初始化、测试函数
2、训练函数、PyTorch构建LeNet网络
3、PyTorch从零构建AlexNet训练MNIST数据集
4、PyTorch从零构建VGGNet训练MNIST数据集
5、PyTorch从零构建GoogLeNet训练MNIST数据集
6、PyTorch从零构建ResNet训练MNIST数据集

4、训练函数

4.1 调用训练函数

train(epochs, net, train_loader, device, optimizer, test_loader, true_value)

因为每一个epoch训练结束后,我们需要测试一下这个网络的性能,所有会在训练函数中频繁调用测试函数,所有测试函数中所有需要的参数,训练函数都需要
这七个参数,是训练一个神经网络所需要的最少参数

4.2 训练函数

训练函数中,所有训练集进行多次迭代,而每次迭代又会将数据分成多个批次进行迭代

def train(epochs, net, train_loader, device, optimizer, test_loader, true_value):for epoch in range(1, epochs + 1):net.train()all_train_loss = []for batch_idx, (data, target) in enumerate(train_loader):data = data.to(device)target = target.to(device)optimizer.zero_grad()output = net(data)loss = F.cross_entropy(output, target)loss.backward()optimizer.step()cur_train_loss = loss.item()all_train_loss.append(cur_train_loss)train_loss = np.round(np.mean(all_train_loss) * 1000, 2)print('\nepoch step:', epoch)print('training loss: ', train_loss)test(net, test_loader, device, true_value, epoch)print("\nTraining finished")
  1. 定义训练函数
  2. 安装epochs迭代数据
  3. 进入pytorch的训练模式
  4. all_train_loss 存放训练集5万张图片的损失值
  5. 按照batch取数据
  6. 数据进入GPU
  7. 标签进入GPU
  8. 梯度清零
  9. 当前batch进入网络后得到输出
  10. 根据输出得到当前损失
  11. 反向传播
  12. 梯度下降
  13. 获取损失的损失值(PyTorch框架中的数据)
  14. 把当前batch的损失加入all_train_loss数组中,结束batch的迭代
  15. 将5张图片的损失计算出来并且进行求平均,这里乘以1000是因为我觉得计算出的损失太小了,所以乘以1000,方便看损失的变化,保留两位有效数字
  16. 打印当前epoch
  17. 打印损失
  18. 调用测试函数,测试当前训练的网络的性能,结束epoch的迭代
  19. 打印训练完成

5、LeNet

向传播来优化学习策略,而是采用的无监督学习的方案,这其实限制了Neocognitron模型。反向传播算法于1974年哈佛大学的 Paul Werbos 提出,并由LeCun于1989将反向传播算法引入了卷积神经网络并且用于手写数字识别任务上,这个就是LeNet-1,通过几年的迭代,LeNet在1998的手写体数字识别任务上取得了很大的成功,这个版本的LeNet就是著名的LeNet-5。为什么LeNet-5这么被广泛使用呢?因为LeNet-5在美国被大规模用于自动对银行支票上的手写数字进行分类。在LeNet之前,字符识别主要是通过手工特征工程来完成特征提取,然后利用机器学习模型来学习手工特征进行分类。因此,特征工程就是一个很大的问题,究竟什么样的特征是需要的特征呢?LeNet-5可以自己学习图像的特征,这就意味着,网络模型自己学习特征成为可能,手工提取特征将成为过去式。卷积还可以被看作是“滑动平均”的推广。

5.1 网络结构

LeNet可以说是首次提出卷积神经网络的模型
主要包含下面的网络层:

  1. 5*5的二维卷积
  2. sigmoid激活函数(这里使用了relu)
  3. 5*5的二维卷积
  4. sigmoid激活函数
  5. 数据一维化
  6. 全连接层
  7. 全连接层
  8. softmax分类器

将网络结构打印出来:

LeNet(
-------(conv1): Conv2d(1, 10, kernel_size=(5, 5), stride=(1, 1))
-------(conv2): Conv2d(10, 20, kernel_size=(5, 5), stride=(1, 1))
-------(conv2_drop): Dropout2d(p=0.5, inplace=False)
-------(fc1): Linear(in_features=320, out_features=50, bias=True)
-------(fc2): Linear(in_features=50, out_features=10, bias=True)
)

5.2 PyTorch构建LeNet

class LeNet(nn.Module):def __init__(self, num_classes):super(LeNet, self).__init__()self.conv1 = nn.Conv2d(1, 10, kernel_size=5)self.conv2 = nn.Conv2d(10, 20, kernel_size=5)self.conv2_drop = nn.Dropout2d()self.fc1 = nn.Linear(320, 50)self.fc2 = nn.Linear(50, num_classes)def forward(self, x):x = F.relu(F.max_pool2d(self.conv1(x), 2))x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))x = x.view(-1, 320)x = F.relu(self.fc1(x))x = F.dropout(x, training=self.training)x = self.fc2(x)return F.log_softmax(x, dim=1)

这个时候已经是一个完整的项目了,看看10个epoch训练过程的打印:

D:\conda\envs\pytorch\python.exe A:\0_MNIST\train.py

Reading data…
train_data: (60000, 28, 28) train_label (60000,)
test_data: (10000, 28, 28) test_label (10000,)

Initialize neural network
test loss: 2301.68
test accuracy: 11.3 %

epoch step: 1
training loss: 634.74
test loss: 158.03
test accuracy: 95.29 %

epoch step: 2
training loss: 324.04
test loss: 107.62
test accuracy: 96.55 %

epoch step: 3
training loss: 271.25
test loss: 88.43
test accuracy: 97.04 %

epoch step: 4
training loss: 236.69
test loss: 70.94
test accuracy: 97.61 %

epoch step: 5
training loss: 211.05
test loss: 69.69
test accuracy: 97.72 %

epoch step: 6
training loss: 199.28
test loss: 62.04
test accuracy: 97.98 %

epoch step: 7
training loss: 187.11
test loss: 59.65
test accuracy: 97.98 %

epoch step: 8
training loss: 178.79
test loss: 53.89
test accuracy: 98.2 %

epoch step: 9
training loss: 168.75
test loss: 51.83
test accuracy: 98.43 %

epoch step: 10
training loss: 160.83
test loss: 50.35
test accuracy: 98.4 %

Training finished
进程已结束,退出代码为 0

可以看出基本上只要一个epoch就可以得到很好的训练效果了,后续的epoch中的提升比较小

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/87338.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

ndoe.js、npm相关笔记

1、npm 全局安装 npm config get prefix 获取 npm 全局安装路径如果全局插件不能正常使用,看环境变量是否已经配置。没有配置则把全局安装路径配置到环境变量的path中

npm 命令

目录 初始化 搜索 安装 删除 更新 换源 查看 其他 补充 1.初始化 npm init #初始化一个package.json文件 npm init -y | npm init --yes 2.搜索 npm s jquery | npm search jquery 3.安装 npm install npm -g #更新到最新版本 npm i uniq | npm ins…

MS933NA适用于 1MP/60fps 摄像头、37.5MHz100MHz、10 位/12 位的串化器

MS933NA 是 10 位 /12 位串化器,支持 37.5MHz  100MHz 时钟, MS933NA 广泛应用于车载摄像、医疗设备、管道探测等领域。 主要特点 ◼ 支持输入 37.5MHz 到 100MHz 的图像时钟 ◼ 单个差分对互连 ◼ 可编程数据有效负载 10 位 /12 …

django 实现:闭包表—树状结构

闭包表—树状结构数据的数据库表设计 闭包表模型 闭包表(Closure Table)是一种通过空间换时间的模型,它是用一个专门的关系表(其实这也是我们推荐的归一化方式)来记录树上节点之间的层级关系以及距离。 场景 我们 …

什么是关系模型? 关系模型的基本概念

关系模型由IBM公司研究员Edgar Frank Codd于1970年发表的论文中提出,经过多年的发展,已经成为目前最常用、最重要的模型之一。 在关系模型中有一些基本的概念,具体如下。 (1)关系(Relation)。关系一词与数学领域有关,它是集合基…

Xcode14.3.1打包报错Command PhaseScriptExecution failed with a nonzero exit code

真机运行编译正常,一打包就报错 rsync error: some files could not be transferred (code 23) at /AppleInternal/Library/BuildRoots/d9889869-120b-11ee-b796-7a03568b17ac/Library/Caches/com.apple.xbs/Sources/rsync/rsync/main.c(996) [sender2.6.9] Command PhaseScrip…

优化类问题概述

数学建模系列文章: 以下是个人在准备数模国赛时候的一些模型算法和代码整理,有空会不断更新内容: 评价模型(一)层次分析法(AHP),熵权法,TOPSIS分析 及其对应 PYTHON 实现代码和例题…

QRunnable与外界互传对象

1.概述 QRunnable与外界互通讯是有两种方法的 使用多继承。让我们的自定义线程类同时继承于QRunnable和QObject,这样就可以使用信号和槽,但是多线程使用比较麻烦,特别是继承于自定义的类时,容易出现接口混乱,所以在项…

数据通信——应用层(域名系统)

引言 TCP到此就告一段落,这也意味着传输层结束了,紧随其后的就是TCP/IP五层架构的应用层。操作系统、编程语言、用户的可视化界面等等都要通过应用层来体现。应用层和我们息息相关,我们使用电子设备娱乐或办公时,接触到的就是应用…

package.json属性

添加链接描述 一、必须属性 name 定义项目的名称,不能以".“和”_"开头,不能包含大写字母version 定义项目的版本号,格式为:大版本号.次版本号.修订号 二、描述信息 description 项目描述keywords 项目关键词author …

【刷题笔记9.24】LeetCode:二叉树最大深度

LeetCode:二叉树最大深度 1、题目描述: 给定一个二叉树 root ,返回其最大深度。 二叉树的 最大深度 是指从根节点到最远叶子节点的最长路径上的节点数。 二、思路与算法 如果我们知道了左子树和右子树的最大深度 lll 和 rrr,…

力扣:109. 有序链表转换二叉搜索树(Python3)

题目: 给定一个单链表的头节点 head ,其中的元素 按升序排序 ,将其转换为高度平衡的二叉搜索树。 本题中,一个高度平衡二叉树是指一个二叉树每个节点 的左右两个子树的高度差不超过 1。 来源:力扣(LeetCod…

uni-app使用HBuilder X编辑器本地打包apk步骤说明

1.下载安装Android Studio 下载地址官方地址:Android Studio 下载文件归档 | Android 开发者 | Android Developers 安装Android SDK和Google USB Driver即可,后者主要是为了后期使用USB设置的,如果不需要可以不点。 2.下载uni-app提供…

ICMP差错包

ICMP报文分类 Type Code 描述 查询/差错 0-Echo响应 0 Echo响应报文 查询 3-目的不可达 0 目标网络不可达报文 差错 1 目标主机不可达报文 差错 2 目标协议不可达报文 差错 3 目标端口不可达报文 差错 4 要求分段并设置DF flag标志报文 差错 5 源路由…

Mac磁盘空间满了怎么办?Mac如何清理磁盘空间

你是不是发现你的Mac电脑存储越来越满,甚至操作系统本身就占了100多G的空间?这不仅影响了电脑的性能,而且也让你无法存储更多的重要文件和软件。别担心,今天这篇文章将告诉你如何清除多余的文件,让你的Mac重获新生。 一…

测试工程师通常用哪个单元测试库来测试Java程序?

测试工程师在测试Java程序时通常使用各种不同的单元测试库,具体选择取决于项目的需求和团队的偏好。我们先来看一些常用的Java单元测试库,以及它们的一些特点: 1.JUnit: 描述: JUnit 是Java中最广泛使用的单元测试库之一,它支持J…

gateway之过滤器(Filter)详解

文章目录 什么是过滤器过滤器的种类局部过滤器代码示例全局过滤器代码示例 总结 什么是过滤器 在Spring Cloud中,过滤器(Filter)是一种关键的组件,用于在微服务架构中处理和转换传入请求以及传出响应。过滤器位于服务网关或代理中…

Android AMS——AMS初始化(五)

Android AMS 也是一个系统服务,这里我们主要看一下 ActivityManagerService 的启动流程。 一、AMS启动流程 ActivityManagerService 既然是系统服务,那么肯定是通过 SystemServer 启动的,所以我们首先看一下 SystemServer 服务中启动 ActivityManagerService 相关代码。 S…

Angular:通过路由切换页面后,ngOnInit()不会被触发的问题

描述: 我在在使用angular 9版本,出现这样一个问题:我通过路由进入页面时候,会执行ngOnInit,切换到其他页面再切回,此时这个页面的ngInit不会主动执行 原因: 在Angular中,当一个组…

【力扣-每日一题】213. 打家劫舍 II

class Solution { public:int getMax(int n,vector<int> &nums){int a0,bnums[n],c0;for(int in1;i<nums.size()n-1;i){ //sizen-1,为0时&#xff0c;第一个可以偷&#xff0c;最后一个不能偷size-1&#xff1b;n为1时&#xff0c;最后一个可偷&#xff0c;计算…