神经网络学习笔记——如何设计、实现并训练一个标准的前馈神经网络

1.从零设计并训练一个神经网络icon-default.png?t=O83Ahttps://www.bilibili.com/video/BV134421U77t/?spm_id_from=333.337.search-card.all.click&vd_source=0b1f472915ac9cb9cdccb8658d6c2e69

一、如何设计、实现并训练一个标准的前馈神经网络,用于手写数字图像的分类,重点讲解了神经网络的设计和实现、数据的准备和处理、模型的训练和测试流程。

- 以数字图像作为输入,神经网络计算并识别图像中的数字。

- 输入层包含784个神经元,隐藏层用于特征提取,输出层包含10个神经元。

- 输出层的输出输入到soft max层,将十维的向量转换为十个概率值。

二、神经网络的设计思路和实现方法,以及手写数字识别的数据处理流程和代码实现,包括图像预处理、构建数据集等。

- 神经网络设计思路:每个概率值对应一个数字

- 手写数字识别训练数据:使用mini数据集

  • 数据处理流程:图像预处理、读取数据文件夹、构建数据集

三、使用PyTorch进行图像分类的步骤,包括读取数据、构建数据集、小批量数据读取、模型训练等,以及涉及到的对象和损失函数等。

  • 1、读取数据、构建数据集
  • 2、模型的训练
  • 使用train loader进行小批量数据读入,创建模型、优化器和损失函数进行训练
  • 训练模型的循环迭代,外层代表整个数据集的遍历次数,内层使用小批量数据读取进行梯度下降算法。
  • 3、模型的测试 
  • 注:测试的时候,需要​编辑model.eval() 
import torch
import torch.nn as nn# 定义模型结构
class SimpleModel(nn.Module):def __init__(self):super(SimpleModel, self).__init__()self.fc = nn.Linear(10, 2)self.dropout = nn.Dropout(0.5)self.batch_norm = nn.BatchNorm1d(2)def forward(self, x):x = self.fc(x)x = self.dropout(x)x = self.batch_norm(x)return x# 初始化模型
model = SimpleModel()# 加载训练好的模型权重
model.load_state_dict(torch.load('model.pth'))# 将模型设置为评估模式
model.eval()# 测试数据
test_input = torch.randn(5, 10)# 禁用梯度计算
with torch.no_grad():output = model(test_input)print(output)

Q1:为什么训练集要分批次训练,跟每条数据单独训练(batch_size=1)有什么不一样的吗?

  • 较大的 batch_size,梯度更新会更加平滑和稳定,模型能够更好地学到数据的总体分布特征。
  • 最优的batch size跟训练集的大小有关,大数据集适合大batch,小数据集适合小batch,极端情况下batch_size=1也不是不可以。

 Q2:为什么loss会不断变小?

  • 梯度下降只包含了局部的损失函数信息,所以只能保证存在趋近局部最优的可能。

Loss 在训练过程中不断变小是因为优化算法(如梯度下降)的作用,但这个现象背后有多个原因和理论支持。逐步解析:

1. 梯度下降原理

梯度下降算法的核心思想是利用目标函数(即损失函数)的梯度来迭代地更新模型的参数。梯度本身指示了损失函数增长最快的方向,因此,通过向梯度的反方向更新参数,可以逐步减小损失值。

2. 局部最优与全局最优

  • 局部最优:在多维空间中,损失函数可能存在多个局部最小值。梯度下降算法只能保证找到其中一个局部最小值,而不一定是全局最小值。
  • 全局最优:对于凸函数,任何局部最小值也是全局最小值。但对于非凸函数(大多数深度学习模型的损失函数),找到全局最小值更加复杂。

3. 损失函数的性质

  • 凸性:如果损失函数是凸的,那么任何局部最小值也是全局最小值,梯度下降法最终能够找到这个全局最小值。
  • 非凸性:对于非凸函数,虽然存在多个局部最小值,但梯度下降法依然可以找到某个局部最小值,使得损失函数值减小。

4. 学习率的作用

  • 学习率是梯度下降中一个关键的超参数,它决定了每一步参数更新的幅度。适当选择学习率可以保证算法的收敛性和稳定性。

5. 损失函数的优化目标

  • 训练过程中,优化的目标是最小化损失函数,这通常意味着模型的预测误差在减少
  • 随着训练的进行,模型逐渐学习到数据中的模式和结构,使得预测更加准确,从而损失值减小。

6. 泛化能力

  • 虽然训练过程中损失持续减小,但最终目标是提高模型在未知数据上的泛化能力
  • 为了防止过拟合,通常会采取正则化技术(如L1、L2正则化,Dropout等),以及早停(early stopping)策略。

7. 局部信息与全局搜索

  • 梯度下降利用的是局部信息(即当前位置的梯度),它提供了一种贪婪的搜索策略,每一步都朝着减少损失的方向前进。
  • 尽管只能保证趋近局部最优,但在实际应用中,通过合理的初始化、学习率调度和正则化策略,梯度下降往往能找到使损失足够小的参数配置。

结论

损失函数不断变小是因为梯度下降算法通过利用局部梯度信息来不断更新模型参数,使得模型逐渐学习到数据的内在规律,从而减少预测误差。虽然梯度下降只能保证找到局部最优解,但通过适当的策略和技巧,通常可以训练出性能良好的模型。


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/53818.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

九、外观模式

外观模式(Facade Pattern)是一种结构型设计模式,有叫门面模式,它为一个复杂子系统提供一个简单的接口,隐藏系统的复杂性。通过使用外观模式,客户端可以更方便地和复杂的系统进行交互,而无需直接…

数据驱动的生态系统架构:打造智能化管理与业务增长的未来战略

在当今的数字化经济中,数据已成为企业最具战略价值的资产。通过数据的分析与应用,企业不仅能够提高业务效率,还能通过构建数据驱动的生态系统架构,实现跨行业协作与技术创新,最终提升全球竞争力。2024年生态系统架构可…

SpringBoot Jar 包加密防止反编译实战

今天给大家分享一个 SpringBoot 程序 Jar 包加密的方式,通过代码加密可以实现无法反编译。 应用场景就是当需要把公司的产品部署到友方公司或者其他公司时,可以防止客户直接反编译出来源码,大大提升代码的安全性。 版本 springboot 2.6.8j…

Java后端编程语言进阶篇

第一章 函数式接口 函数式接口是Java 8中引入的一个新特性,只包含一个抽象方法的接口。 函数式接口可以使用Lambda表达式来实现,从而实现函数式编程的特性。 使用 FunctionalInterface标识接口是函数式接口,编译器才会检查接口是否符合函数…

Qt 实现自定义截图工具

目录 Qt 实现自定义截图工具实现效果图PrintScreen 类介绍PrintScreen 类的主要特性 逐步实现第一步:类定义第二步:初始化截图窗口第三步:处理鼠标事件第四步:计算截图区域第五步:捕获和保存图像 完整代码PrintScreen.…

《程序猿之设计模式实战 · 池化思想》

📢 大家好,我是 【战神刘玉栋】,有10多年的研发经验,致力于前后端技术栈的知识沉淀和传播。 💗 🌻 CSDN入驻不久,希望大家多多支持,后续会继续提升文章质量,绝不滥竽充数…

LSS如何做深度和语义预测

get_cam_feats() 先来看看代码: def get_cam_feats(self, x):"""Return B x N x D x H/downsample x W/downsample x C"""B, N

解锁编程潜力,从掌握GitHub开始

目录: 一、搜索开源项目 1、什么是Git 2、Github常用词含义 3、一个完整的项目界面 4、使用Github搜索项目 1)in关键词 2)star或fork数量去查找 3)awesome加强搜索 二、访问速度慢的解决 1、使用网易UU加速器 2、使用…

基于SSM的大学新生报到系统+LW参考示例

系列文章目录 1.基于SSM的洗衣房管理系统原生微信小程序LW参考示例 2.基于SpringBoot的宠物摄影网站管理系统LW参考示例 3.基于SpringBootVue的企业人事管理系统LW参考示例 4.基于SSM的高校实验室管理系统LW参考示例 5.基于SpringBoot的二手数码回收系统原生微信小程序LW参考示…

VSCode创建C++项目和编译多文件

前言 在刚安装好VSCode后,我简单尝试了仅main.cpp单文件编译代码,没有问题,但是当我尝试多文件编译时,就出现了无法识别cpp文件。 内容 创建项目 首先点击左上角“文件”;在菜单中选择“打开文件夹”;在…

【Elasticsearch系列二】安装 Kibana

💝💝💝欢迎来到我的博客,很高兴能够在这里和您见面!希望您在这里可以感受到一份轻松愉快的氛围,不仅可以获得有趣的内容和知识,也可以畅所欲言、分享您的想法和见解。 推荐:kwan 的首页,持续学…

ClickHouse 24.8 LTS 版本发布说明

本文字数:13885;估计阅读时间:35 分钟 作者:ClickHouse Team 本文在公众号【ClickHouseInc】首发 时间飞逝,又到了新版本发布的时刻! 发布概要 本次ClickHouse 24.8 版本包含了19个新功能🎁、18…

基于51单片机的16X16点阵显示屏proteus仿真

地址: https://pan.baidu.com/s/1JQ225NSKweqf1Zlad_f1Mw 提取码:1234 仿真图: 芯片/模块的特点: AT89C52/AT89C51简介: AT89C52/AT89C51是一款经典的8位单片机,是意法半导体(STMicroelectro…

管家婆云辉煌手机端怎么连接蓝牙打印机?

管家婆云辉煌手机端可以连接蓝牙打印机,这样手机可以发送打印任务到蓝牙打印机,完成打印任务。具体的设置步骤如下: 一、首先完成手机和蓝牙打印机配对,打开蓝牙打印机后。手机开启蓝牙和定位服务 点击手机设置,进入手…

分类预测|基于差分优化DE-支持向量机数据分类预测完整Matlab程序 DE-SVM

分类预测|基于差分优化DE-支持向量机数据分类预测完整Matlab程序 DE-SVM 文章目录 一、基本原理DE-SVM 分类预测原理和流程总结 二、实验结果三、核心代码四、代码获取五、总结 一、基本原理 DE-SVM 分类预测原理和流程 1. 差分进化优化算法(DE) 原理…

【深度学习】【图像分类】【OnnxRuntime】【Python】VggNet模型部署

【深度学习】【图像分类】【OnnxRuntime】【Python】VggNet模型部署 提示:博主取舍了很多大佬的博文并亲测有效,分享笔记邀大家共同学习讨论 文章目录 【深度学习】【图像分类】【OnnxRuntime】【Python】VggNet模型部署前言Windows平台搭建依赖环境模型转换--pytorch转onnxONN…

走进低代码表单开发(一):可视化表单数据源设计

在前文,我们已对勤研低代码平台的报表功能做了详细介绍。接下来,让我们深入探究低代码开发中最为常用的表单设计功能。一个完整的应用是由众多表单组合而成的,所以高效的表单设计在开发过程中起着至关重要的作用。让我们一同了解勤研低代码开…

[网络]http/https的简单认识

文章目录 一. 什么是http二. http协议工作过程三. http协议格式1. 抓包工具fiddler2. http请求报文3. http响应报文 一. 什么是http HTTP (全称为 “超⽂本传输协议”) 是⼀种应⽤⾮常⼴泛的 应⽤层协议 HTTP 诞⽣与1991年. ⽬前已经发展为最主流使⽤的⼀种应⽤层协议 HTTP 往…

FPGA实现串口升级及MultiBoot(四)MultiBoot简介

缩略词索引: K7:Kintex 7V7:Vertex 7A7:Artix 7 我们在正常升级的过程(只使用一个位流文件),假如:(1)因为干扰通信模块收到了一个错误位;(2)或者烧写进FLASH时…