神经网络学习笔记——如何设计、实现并训练一个标准的前馈神经网络

1.从零设计并训练一个神经网络icon-default.png?t=O83Ahttps://www.bilibili.com/video/BV134421U77t/?spm_id_from=333.337.search-card.all.click&vd_source=0b1f472915ac9cb9cdccb8658d6c2e69

一、如何设计、实现并训练一个标准的前馈神经网络,用于手写数字图像的分类,重点讲解了神经网络的设计和实现、数据的准备和处理、模型的训练和测试流程。

- 以数字图像作为输入,神经网络计算并识别图像中的数字。

- 输入层包含784个神经元,隐藏层用于特征提取,输出层包含10个神经元。

- 输出层的输出输入到soft max层,将十维的向量转换为十个概率值。

二、神经网络的设计思路和实现方法,以及手写数字识别的数据处理流程和代码实现,包括图像预处理、构建数据集等。

- 神经网络设计思路:每个概率值对应一个数字

- 手写数字识别训练数据:使用mini数据集

  • 数据处理流程:图像预处理、读取数据文件夹、构建数据集

三、使用PyTorch进行图像分类的步骤,包括读取数据、构建数据集、小批量数据读取、模型训练等,以及涉及到的对象和损失函数等。

  • 1、读取数据、构建数据集
  • 2、模型的训练
  • 使用train loader进行小批量数据读入,创建模型、优化器和损失函数进行训练
  • 训练模型的循环迭代,外层代表整个数据集的遍历次数,内层使用小批量数据读取进行梯度下降算法。
  • 3、模型的测试 
  • 注:测试的时候,需要​编辑model.eval() 
import torch
import torch.nn as nn# 定义模型结构
class SimpleModel(nn.Module):def __init__(self):super(SimpleModel, self).__init__()self.fc = nn.Linear(10, 2)self.dropout = nn.Dropout(0.5)self.batch_norm = nn.BatchNorm1d(2)def forward(self, x):x = self.fc(x)x = self.dropout(x)x = self.batch_norm(x)return x# 初始化模型
model = SimpleModel()# 加载训练好的模型权重
model.load_state_dict(torch.load('model.pth'))# 将模型设置为评估模式
model.eval()# 测试数据
test_input = torch.randn(5, 10)# 禁用梯度计算
with torch.no_grad():output = model(test_input)print(output)

Q1:为什么训练集要分批次训练,跟每条数据单独训练(batch_size=1)有什么不一样的吗?

  • 较大的 batch_size,梯度更新会更加平滑和稳定,模型能够更好地学到数据的总体分布特征。
  • 最优的batch size跟训练集的大小有关,大数据集适合大batch,小数据集适合小batch,极端情况下batch_size=1也不是不可以。

 Q2:为什么loss会不断变小?

  • 梯度下降只包含了局部的损失函数信息,所以只能保证存在趋近局部最优的可能。

Loss 在训练过程中不断变小是因为优化算法(如梯度下降)的作用,但这个现象背后有多个原因和理论支持。逐步解析:

1. 梯度下降原理

梯度下降算法的核心思想是利用目标函数(即损失函数)的梯度来迭代地更新模型的参数。梯度本身指示了损失函数增长最快的方向,因此,通过向梯度的反方向更新参数,可以逐步减小损失值。

2. 局部最优与全局最优

  • 局部最优:在多维空间中,损失函数可能存在多个局部最小值。梯度下降算法只能保证找到其中一个局部最小值,而不一定是全局最小值。
  • 全局最优:对于凸函数,任何局部最小值也是全局最小值。但对于非凸函数(大多数深度学习模型的损失函数),找到全局最小值更加复杂。

3. 损失函数的性质

  • 凸性:如果损失函数是凸的,那么任何局部最小值也是全局最小值,梯度下降法最终能够找到这个全局最小值。
  • 非凸性:对于非凸函数,虽然存在多个局部最小值,但梯度下降法依然可以找到某个局部最小值,使得损失函数值减小。

4. 学习率的作用

  • 学习率是梯度下降中一个关键的超参数,它决定了每一步参数更新的幅度。适当选择学习率可以保证算法的收敛性和稳定性。

5. 损失函数的优化目标

  • 训练过程中,优化的目标是最小化损失函数,这通常意味着模型的预测误差在减少
  • 随着训练的进行,模型逐渐学习到数据中的模式和结构,使得预测更加准确,从而损失值减小。

6. 泛化能力

  • 虽然训练过程中损失持续减小,但最终目标是提高模型在未知数据上的泛化能力
  • 为了防止过拟合,通常会采取正则化技术(如L1、L2正则化,Dropout等),以及早停(early stopping)策略。

7. 局部信息与全局搜索

  • 梯度下降利用的是局部信息(即当前位置的梯度),它提供了一种贪婪的搜索策略,每一步都朝着减少损失的方向前进。
  • 尽管只能保证趋近局部最优,但在实际应用中,通过合理的初始化、学习率调度和正则化策略,梯度下降往往能找到使损失足够小的参数配置。

结论

损失函数不断变小是因为梯度下降算法通过利用局部梯度信息来不断更新模型参数,使得模型逐渐学习到数据的内在规律,从而减少预测误差。虽然梯度下降只能保证找到局部最优解,但通过适当的策略和技巧,通常可以训练出性能良好的模型。


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/53818.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

九、外观模式

外观模式(Facade Pattern)是一种结构型设计模式,有叫门面模式,它为一个复杂子系统提供一个简单的接口,隐藏系统的复杂性。通过使用外观模式,客户端可以更方便地和复杂的系统进行交互,而无需直接…

【Android】【Bug】Activity全屏(保留底部按钮)被打断变成非全屏了

问题 在Activity里面设置全屏显示(保留底部按钮的全屏),刚开始的时候显示的也是全屏,但是在此页面进行一些操作之后,全屏变成非全屏了。 全屏设置方法 在 Activity 中的 onCreate 方法里,添加以下代码: Override p…

微信文件处理与命名机制分析(基于微信 8.0.50 版本)

微信文件处理与命名机制分析(基于微信 8.0.50 版本) 摘要 微信作为一款广泛使用的即时通讯工具,涉及图片、视频、文档等多种文件类型的传输与管理。本文基于微信 8.0.50 版本,探讨其对于图片、GIF、视频等文件的命名处理策略&am…

数据驱动的生态系统架构:打造智能化管理与业务增长的未来战略

在当今的数字化经济中,数据已成为企业最具战略价值的资产。通过数据的分析与应用,企业不仅能够提高业务效率,还能通过构建数据驱动的生态系统架构,实现跨行业协作与技术创新,最终提升全球竞争力。2024年生态系统架构可…

SpringBoot Jar 包加密防止反编译实战

今天给大家分享一个 SpringBoot 程序 Jar 包加密的方式,通过代码加密可以实现无法反编译。 应用场景就是当需要把公司的产品部署到友方公司或者其他公司时,可以防止客户直接反编译出来源码,大大提升代码的安全性。 版本 springboot 2.6.8j…

RuoYi 开源框架,集成了后端管理,后端java版 App 移动解决方案

文章目录 前言一、后端:二、后台管理三、App 移动总结 前言 后端: 后台管理: 使用的前端技术Vue、Element后端SpringBoot & Security完全分离的权限管理系统。 App 移动解决方案:采用uniapp框架 提示:以下是本篇文…

Java后端编程语言进阶篇

第一章 函数式接口 函数式接口是Java 8中引入的一个新特性,只包含一个抽象方法的接口。 函数式接口可以使用Lambda表达式来实现,从而实现函数式编程的特性。 使用 FunctionalInterface标识接口是函数式接口,编译器才会检查接口是否符合函数…

Qt 实现自定义截图工具

目录 Qt 实现自定义截图工具实现效果图PrintScreen 类介绍PrintScreen 类的主要特性 逐步实现第一步:类定义第二步:初始化截图窗口第三步:处理鼠标事件第四步:计算截图区域第五步:捕获和保存图像 完整代码PrintScreen.…

《程序猿之设计模式实战 · 池化思想》

📢 大家好,我是 【战神刘玉栋】,有10多年的研发经验,致力于前后端技术栈的知识沉淀和传播。 💗 🌻 CSDN入驻不久,希望大家多多支持,后续会继续提升文章质量,绝不滥竽充数…

“xi” 和 “dbscan” 在OPTICS聚类中是什么意思

在 OPTICS(Ordering Points To Identify the Clustering Structure) 聚类算法中,xi 和 dbscan 是两种不同的聚类提取方法,它们用于从OPTICS算法生成的排序数据中提取最终的聚类结构。具体解释如下: dbscan 方法: 该方法…

LSS如何做深度和语义预测

get_cam_feats() 先来看看代码: def get_cam_feats(self, x):"""Return B x N x D x H/downsample x W/downsample x C"""B, N

PHP函数如何传递数组参数

php 函数可以使用数组参数传递大量数据。语法:参数类型前加上方括号 ([])。例如:myfunction(array $arr)。实战案例:计算数组元素平均值。注意:数组参数默认为引用传递,类型提示可提高代码可读性,数组解构可…

解锁编程潜力,从掌握GitHub开始

目录: 一、搜索开源项目 1、什么是Git 2、Github常用词含义 3、一个完整的项目界面 4、使用Github搜索项目 1)in关键词 2)star或fork数量去查找 3)awesome加强搜索 二、访问速度慢的解决 1、使用网易UU加速器 2、使用…

OpenSSL工具验证RSA证书

openssl x509 是一个用于处理 X.509 证书的命令行工具。常用的 openssl x509 命令&#xff1a; -in <file>&#xff1a;指定输入文件。-out <file>&#xff1a;指定输出文件。-noout&#xff1a;不输出证书信息。-text&#xff1a;以文本格式输出证书信息。-pubke…

基于SSM的大学新生报到系统+LW参考示例

系列文章目录 1.基于SSM的洗衣房管理系统原生微信小程序LW参考示例 2.基于SpringBoot的宠物摄影网站管理系统LW参考示例 3.基于SpringBootVue的企业人事管理系统LW参考示例 4.基于SSM的高校实验室管理系统LW参考示例 5.基于SpringBoot的二手数码回收系统原生微信小程序LW参考示…

关于RabbitMQ消息丢失的解决方案

RabbitMQ如何保证消息的可靠性传输 一、消息丢失的原因 1. 生产者端 网络问题&#xff1a; 原因&#xff1a;生产者与RabbitMQ服务器之间的网络连接不稳定或中断&#xff0c;导致消息在传输过程中丢失。解决方案&#xff1a;确保网络连接稳定&#xff0c;监控网络状态&#x…

springboot后端开发-常见注解及其用途

文章目录 1. 组件注解2. 依赖注入注解3. 配置类注解4. 测试注解5. 控制器注解6. 安全和认证注解7. 切面相关注解8. API文档相关注解(需引入swagger)9. 其他注解 在Spring Boot框架中&#xff0c;有许多常用的注解用来简化开发过程中的依赖注入、组件扫描、配置、安全控制等方面…

VSCode创建C++项目和编译多文件

前言 在刚安装好VSCode后&#xff0c;我简单尝试了仅main.cpp单文件编译代码&#xff0c;没有问题&#xff0c;但是当我尝试多文件编译时&#xff0c;就出现了无法识别cpp文件。 内容 创建项目 首先点击左上角“文件”&#xff1b;在菜单中选择“打开文件夹”&#xff1b;在…

软件测试工程师面试整理-数据库与SQL

在软件测试过程中,数据库和SQL的知识是非常重要的,尤其是在涉及数据密集型应用或需要验证数据准确性的场景中。测试人员需要掌握SQL语句,以便查询、插入、更新和删除数据,并验证数据库操作的正确性。 1. 数据库基础知识 ● 关系型数据库:大多数应用使用关系型数据库(如My…

Qt什么时候触发paintEvent事件

‌paintEvent事件可以在以下几种情况下被触发‌&#xff1a; ‌窗口初始化和显示‌&#xff1a;当窗口首次被创建、显示&#xff0c;或者窗口被覆盖、最小化后再恢复时&#xff0c;paintEvent会被触发以绘制窗口的内容。‌部件大小或位置变化‌&#xff1a;如果窗口或部件的大…