从0到1,AI我来了- (4)AI图片识别的理论知识-II

上篇文章,我们理解了我们程序的神经网络设计,这篇我们继续,把训练迭代过程分析一下,完成这两篇文章,下面问题,应该能回答了。

  1. 一张图片,如何被计算机读懂?
  2. pytorch 封装的网络,什么是卷积层,为什么要多层?层与层如何衔接?如何设计?
  3. 什么是池化?为什么要池化?
  4. 什么是全链接层?它有什么作用?
  5. 神经网络模型的前向传播?这个步骤的作用是什么?
  6. 什么梯度下降?梯度下降的价值?
  7. 什么是激活函数?为什么要用激活函数?

一、上篇完成网络设计后,神经网络如何训练,自优化的?

回到程序【如需解读,请参阅从0到1,AI我来了- (2)解读程序-从AI手写数字识别开始】

for epoch in range(num_epochs):  running_loss = 0.0  correct = 0  total = 0  for images, labels in train_loader:  optimizer.zero_grad()  outputs = model(images)  loss = criterion(outputs, labels)  loss.backward()  optimizer.step()  running_loss += loss.item()  _, predicted = torch.max(outputs.data, 1)  total += labels.size(0)  correct += (predicted == labels).sum().item()  avg_loss = running_loss / len(train_loader)  accuracy = 100 * correct / total  # 记录到 TensorBoard  writer.add_scalar('Loss/train', avg_loss, epoch)  writer.add_scalar('Accuracy/train', accuracy, epoch)  # 记录到列表  loss_values.append(avg_loss)  accuracy_values.append(accuracy)  print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}, Accuracy: {accuracy:.2f}%')  writer.close() 

这里做了10次训练,每次迭代训练会经过:

         forward(前向传播)--》计算损失 --》backward(反向传播) --》更新权重

forward

        前篇提到了,定义网络模型时,有个函数forword。

forward通常是指一个模型的前向传播过程,它的主要作用是计算输入数据通过网络时的输出,即这里的输入数据图片到输出数字概率的过程。

  outputs = model(images)  

程序中,这句代码,实际就是会调用SimpleCNN 的forward的方法,输入是图片,输出的是概率结果。

计算损失Loss


criterion = nn.CrossEntropyLoss() 
...
loss = criterion(outputs, labels)  

torch.nn模块提供了多种常用的损失函数,这些损失函数可以用于不同类型的机器学习任务。附录一,有一些主要的损失函数,这里是个多分类问题,所以nn.CrossEntropyLoss 可以满足我们要求。

那CrossEntropyLoss 是如何计算损失的?

CrossEntropyLoss 结合了 softmax 函数和负对数似然损失(Negative Log Likelihood Loss)

大白话就是模型会跑出一些结果,比如[1.0,2.0,7.0],这里表示7.0的可能性最大,CrossEntropyLoss 会先把这些分数转化为概率,即把这些值变为0~1之间的值,且和为1。[1.0,2.0,7.0] 转化为[0.1,0.2,0.7]。

backward 反向传播

 loss.backward()  

非常关键的一行代码,简单理解就是,为了追求更小的损失,自动优化模型参数,不断迭代。

原理尝试解读一下,网上找了个图,解释一下:

损失函数L(w),随机指定权重,偏置(bias,随机标量),比如图中的W0哪个点,为了让损失更小(找到波谷),需要对w求导,即寻找W0点的切线。

如果切线斜率为负(如下图,w越大,Loss越小),说明应该增加w的值,也就是会自动根据学习率,更新w的值。

如果切线斜率为正(w越大,Loss 越大),说明梯度应该减少w的值。

更新权重

    optimizer.step()  

在计算出梯度后,我们需要根据这些梯度更新模型的参数。这里使用的优化器(optimizer)将会在调用 optimizer.step() 时生成参数的更新。

optimizer.step()  会更新哪些参数?

在调用 optimizer.step() 时,所有在优化器中注册的参数(即 model.parameters() 返回的那些)都会被更新。

参数类型,包括哪些?

在深度学习模型中,这些可学习的参数通常包括全连接层(nn.Linear)的权重和偏置、卷积层(nn.Conv2d)的权重和偏置、以及其他层的参数。

综上,那7个问题,应该能基本解答出来了,不往底层算法钻的话,能理解就行了。

下篇,我们我们来实践一个本地智能知识库,把常见的Agent、RAG,向量数据库分析分析。

附录:

一、分类任务
  1. 交叉熵损失 (Cross Entropy Loss):

    • torch.nn.CrossEntropyLoss:用于多类分类问题
    • torch.nn.BCEWithLogitsLoss:用于二分类问题的 logits(未经过 sigmoid)和目标相结合的二元交叉熵损失。
  2. 负对数似然损失 (Negative Log Likelihood Loss):

    • torch.nn.NLLLoss:配合在 LogSoftmax 后使用,适用于多类分类。
  3. KL 散度损失 (KL Divergence Loss):

    • torch.nn.KLDivLoss:用于衡量两个概率分布之间的差异。

回归任务

  1. 均方误差损失 (Mean Squared Error Loss):

    • torch.nn.MSELoss:适用于回归问题,计算预测值和目标值之间的均方差。
  2. 平均绝对误差损失 (Mean Absolute Error Loss):

    • torch.nn.L1Loss:计算预测值和目标值之间的平均绝对差。
  3. Huber 损失:

    • torch.nn.HuberLoss:结合了均方误差和平均绝对误差的特点,对于目标值较大时更为稳健。

其他损失函数

  1. 对比损失 (Contrastive Loss):

    • 自定义损失,一般用于 Siamese 网络。
  2. Triplet Loss:

    • 自定义损失,适用于需保持距离的学习(如人脸识别)。
  3. Focal Loss:

    • 自定义损失,用于处理类别不平衡问题。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/876792.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

MATLAB禁忌蚁群算法求解充电电动车辆路径规划EVRP代码实例

MATLAB禁忌蚁群算法求解充电电动车辆路径规划EVRP代码实例 MATLAB禁忌蚁群算法求解充电电动车辆路径规划EVRP代码实例

DP 整数拆分不同的二叉搜索树 DAY21

整数拆分? 给定一个正整数 n ,将其拆分为 k 个 正整数 的和( k > 2 ),并使这些整数的乘积最大化。 返回 你可以获得的最大乘积。 示例 1: 输入: n 2 输出: 1 解释: 2 1 1, 1 1 1。示例 2: 输入: n 10 输…

全国区块链职业技能大赛样题第9套前端源码

后端源码地址:https://blog.csdn.net/Qhx20040819/article/details/140746050 前端源码地址:https://blog.csdn.net/Qhx20040819/article/details/140746216 智能合约+数据库表设计:https://blog.csdn.net/Qhx20040819/article/details/140746646 登录 ​ 用户管理

Unity Playables:下一代动画与音频序列

Unity的Playables API是一种灵活的系统,用于创建和控制动画、音频以及其他形式的连续媒体序列。它为开发者提供了一种全新的方法来处理游戏中的时间序列,包括动画、音频、特效等。本文将探讨Playables的基本概念、如何使用Playables API实现动画&#xf…

又一成就,Pencils Protocol单链 TVL 突破 3 亿美元

Pencils Protocol 是 Scroll 生态的原生项目,该项目以一站式收益聚合器和拍卖平台作为主要定位,在功能上,其集 Launchpad、资产统一聚合和分发、杠杆收益等功能于一体,旨在最大化用户的资产利用率。近日,Pencils Proto…

利用python自动化运维i脚本实现远程连接服务器并实现相应命令

目录 前言: 一.调用的python库介绍 二.在主机上安装好相应的库 2.1激活虚拟环境 三.代码实现以及解析 四.效果的实现 五.致谢 前言: 在当今快速发展的技术环境中,自动化运维已成为 IT 基础设施管理的关键组成部分。它不仅可以显著提…

大学生算法高等数学学习平台设计方案 (第一版)

目录 目标用户群体的精准定位 初阶探索者 进阶学习者 资深研究者 功能需求的深度拓展 个性化学习路径定制 概念图谱构建 公式推导展示 交互式问题解决系统 新功能和创新点的引入 虚拟教室环境 数学建模工具集成 算法可视化平台 学术论文资源库 技术实现的前瞻性…

PHP魔术常量

PHP 中的魔术常量(Magic Constants)是一组特殊的预定义常量,它们在脚本的任何时候都可用,并且它们的值会根据它们使用的上下文动态变化。这些常量在开发过程中非常有用,尤其是在需要根据当前环境或脚本位置动态改变行为…

Lua编程

文章目录 概述lua数据类型元表注意 闭包表现 实现 lua/c 接口编程skynet中调用层次虚拟栈C闭包注册表userdatalightuserdata 小结 概述 这次是skynet,需要一些lua/c相关的。写一篇博客,记录下。希望有所收获。 lua数据类型 boolean , number , string…

大模型算法面试题(十五)

本系列收纳各种大模型面试题及答案。 1、大模型LLM进行SFT如何对样本进行优化 大模型LLM(Language Model,语言模型)进行SFT(Structured Fine-Tuning,结构化微调)时,对样本的优化是提升模型性能…

Linux源码阅读笔记16-文件系统关联及字符设备操作

文件系统关联 设备文件都是由标准函数处理,类似普通文件。设备文件也是通过虚拟文件系统来管理的,和普通文件都是通过完全相同的接口访问的。 inode中设备文件的成员数据 虚拟文件系统每个文件都关联到一个inode,用于管理文件的属性。源码如…

【Go - context 速览,场景与用法】

作用 context字面意思上下文,用于关联管理上下文,具体有如下几个作用 取消信号传递:可以用来传递取消信号,让一个正在执行的函数知道它应该提前终止。超时控制:可以设定一个超时时间,自动取消超过执行时间…

Swift学习入门,新手小白看过来

😄作者简介: 小曾同学.com,一个致力于测试开发的博主⛽️,主要职责:测试开发、CI/CD 如果文章知识点有错误的地方,还请大家指正,让我们一起学习,一起进步。 😊 座右铭:不…

文本分类动转静预测错误分析和挖掘稀疏数据和建立新数据集.ipynb

import os import paddle from paddlenlp.transformers import AutoModelForSequenceClassification params_pathcheckpoint/text_classes/ output_pathoutput/text_class model AutoModelForSequenceClassification.from_pretrained(params_path) model.eval() # 转换为具…

(十三)Spring教程——依赖注入之工厂方法注入

1.工厂方法注入 工厂方法是在应用中被经常使用的设计模式,它也是控制反转和单例设计思想的主要实现方法。由于Spring IoC容器以框架的方式提供工厂方法的功能,并以透明的方式开放给开发者,所以很少需要手工编写基于工厂方法的类。正是因为工厂…

如何从网站获取表格数据

1.手动复制粘贴 最简单的方法是直接在网页上手动选择表格内容,然后复制粘贴到Excel或其他表格处理软件中。这种方法适用于表格较小且不经常更新的情况。 2.使用浏览器插件 有许多浏览器插件可以帮助从网页中提取表格数据,例如: -TableCapt…

SSRF过滤攻击

SSRF绕过: 靶场地址:重庆橙子科技SSRF靶场 这个是毫无过滤的直接读取,但是一般网站会设置有对SSRF的过滤,比如将IP地址过滤。 下面是常用的绕过方式: 1.环回地址绕过 http://127.0.0.1/flag.php http://017700…

相机怎么选(不推荐,只分析)

title: 相机怎么选 tags: [相机, 单反相机] categories: [其他, 相机] 最近准备购买,相机怎么选,我去搜索了许多文章,整理了一篇小白挑选技术篇,供大家参考。 分类 胶片相机 需要装入胶卷才能使用的相机,拍照后可直…

永磁同步电机无速度算法--非线性磁链观测器

非线性磁链观测器顾名思义观测器的状态变量为磁链值,观测的磁链值收敛于电机实际磁链值,观测器收敛。非线性是由于观测器存在sin和cos项,所以是非线性观测器 一、原理介绍 表贴式永磁同步电机αβ轴电压方程: 将公式变换 定义状态变量X: 定…

easy-ui nowrap

​​easy-ui​​ 是一个基于 jQuery 的前端框架,用于构建现代化的 Web 应用程序。它提供了丰富的组件和功能,简化了 Web 应用的开发。 ​​nowrap​​ 是 ​​easy-ui​​ 中的一个属性,用于控制表格列(或其他容器)中…