卷积神经网络 - 参数学习

本文我们通过两个简化的例子,展示如何从前向传播、损失计算,到反向传播推导梯度,再到参数更新,完整地描述卷积层的参数学习过程。

一、例子一

我们构造一个非常简单的卷积神经网络,其结构仅包含一个卷积层和一个输出(不使用激活函数,为了便于数学推导),损失函数采用均方误差(MSE)。

1. 设定问题

输入数据
假设输入为一幅小的灰度图像 X:

例如,令

卷积核
使用一个 2×2 的卷积核 W:

并设有偏置 b。

卷积操作
采用“valid”卷积(不填充),在这种情况下,由于输入和核大小都为2×2,卷积操作仅得到一个输出标量 O:

O=(w1⋅x11+w2⋅x12+w3⋅x21+w4⋅x22)+b.

我们为了简化,不使用激活函数(即线性激活),这样前向计算就很直观。

目标输出
设定目标值为 y(比如标签值),假设 y=10。

损失函数
我们使用均方误差(MSE):

2. 前向传播计算

代入示例数据:

  • 初始假设卷积核权重和偏置(假设初始值为):

计算输出 O:

损失:

3. 梯度推导(反向传播)

我们需要计算损失 L 关于每个参数的梯度,即

步骤1:计算损失对输出 O 的梯度

代入数据:O−y=4.5−10=−5.5.

步骤2:计算输出 O 关于各参数的梯度

步骤3:链式法则计算损失对各参数的梯度

根据链式法则:

代入数值:

4. 参数更新(梯度下降)

设定学习率 η,例如 η=0.01,则更新规则为:

更新后的参数:

更新后,新的卷积核参数为:

5. 训练过程总结

整个训练过程如下:

  1. 前向传播:对输入图像进行卷积计算,得到输出 O。
  2. 计算损失:利用损失函数(MSE)计算模型输出与目标值之间的误差 L。
  3. 反向传播:根据链式法则计算损失对各参数(卷积核权重和偏置)的梯度。
  4. 参数更新:使用梯度下降(或其他优化算法)更新参数,向降低损失的方向调整。
  5. 迭代训练:重复上述步骤,遍历整个训练数据集,直到损失收敛或达到设定的迭代次数。

这个例子虽然非常简单(只有一个卷积层,一个输出单元),但它清楚展示了如何从前向传播计算输出、如何利用损失函数计算误差、如何通过反向传播推导每个参数的梯度,并最终利用梯度下降更新参数。实际的卷积神经网络通常包含多个卷积层、池化层和全连接层,但每个部分的参数学习原理都是类似的——都是通过不断地计算梯度并更新参数,使得模型能够更好地拟合数据,从而实现准确的图像分类、检测等任务。

二、例子二

以下是一个具体的卷积神经网络(CNN)训练过程示例,详细展示前向传播、损失计算、反向传播、参数更新的每个步骤。我们以简单的数字识别任务为例,使用极简网络结构便于手动验证。

示例设定

1. 输入数据
  • 输入图像:2×2的灰度图像,表示数字“1”

  • 真实标签:y_true=[1,0](二分类,索引0为“非1”,索引1为“是1”)

2. 网络结构
  • 卷积层:1个2×2卷积核,无填充,步长=1

    • 初始权重:

    • 激活函数:ReLU(ReLU(x)=max⁡(0,x))

  • 全连接层:输入1个神经元 → 输出2个神经元

    • 初始权重:

    • 激活函数:Softmax(输出概率)

3. 超参数
  • 学习率:η=0.1

  • 损失函数:交叉熵损失(Cross-Entropy Loss)

(一)前向传播(Forward Pass)

1. 卷积层计算
  • 卷积操作
    输入 X 与卷积核 W_conv 逐元素相乘后求和,加偏置:

    Conv输出=(0.2×0.5+0.8×(−0.3)+0.6×0.2+0.1×0.4)+0.1=0.1−0.24+0.12+0.04+0.1=0.12
  • ReLU激活

    Aconv=ReLU(0.12)=0.12
2. 全连接层计算
  • 线性变换

  • Softmax激活

    预测结果:类别1(“是1”)的概率为0.628,类别0为0.372。

(二)损失计算(Loss Calculation)

  • 交叉熵损失

(三)反向传播(Backward Pass)

1. 全连接层梯度
  • Softmax梯度

  • 权重梯度

  • 偏置梯度

2. 卷积层梯度
  • ReLU梯度

  • 传递到卷积层的梯度

  • 卷积核权重梯度

  • 卷积层偏置梯度

(四)参数更新(Parameter Update)

1. 全连接层参数
  • 权重更新

  • 偏置更新

2. 卷积层参数
  • 权重更新

  • 偏置更新

(五)更新后效果验证

1. 前向传播再计算
  • 新卷积输出

    Conv输出=(0.2×0.4933+0.8×(−0.3268)+0.6×0.1799+0.1×0.3967)+0.0665≈0.0987

    ReLU激活后:Aconv=0.0987

  • 新全连接输出

    Softmax概率:

2. 新损失值

New Loss=−log⁡(0.581)≈0.542(比原始损失0.465反而增大)

(六)结果分析

  1. 损失未下降的原因

    • 单步更新局限性:梯度下降可能需要多步迭代才能收敛。

    • 学习率过大:学习率 η=0.1 可能跳过最优解,可尝试更小值(如0.05)。

    • 网络容量限制:极简模型可能无法有效拟合复杂模式。

  2. 参数学习方向验证

(七)关键总结

  • 前向传播:数据从输入到输出的逐层变换。

  • 反向传播:通过链式法则计算梯度,明确参数调整方向。

  • 参数更新:沿负梯度方向微调参数,逐步逼近最优解。

  • 迭代优化:需多次迭代(Epoch)才能显著降低损失。

通过这个极简示例,可直观理解CNN参数学习的动态过程。实际训练中需结合批量数据、更复杂网络结构和优化策略(如Adam、学习率衰减)提升效果。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/73501.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

.NET三层架构详解

.NET三层架构详解 文章目录 .NET三层架构详解引言什么是三层架构表示层(Presentation Layer)业务逻辑层(Business Logic Layer,BLL)数据访问层(Data Access Layer,DAL) .NET三层架构…

Redis实战常用二、缓存的使用

一、什么是缓存 在实际开发中,系统需要"避震器",防止过高的数据访问猛冲系统,导致其操作线程无法及时处理信息而瘫痪. 这在实际开发中对企业讲,对产品口碑,用户评价都是致命的。所以企业非常重视缓存技术; 缓存(Cache):就是数据交换的缓冲区&…

STM32八股【2】-----ARM架构

1、架构包含哪几部分内容 寄存器处理模式流水线MMU指令集中断FPU总线架构 2、以STM32为例进行介绍 2.1 寄存器 寄存器名称作用R0-R3通用寄存器用于数据传递、计算及函数参数传递;R0 也用于存储函数返回值。R4-R12通用寄存器用于存储局部变量,减少频繁…

effective Java 学习笔记(第二弹)

effective Java 学习笔记(第一弹) 整理自《effective Java 中文第3版》 本篇笔记整理第3,4章的内容。 重写equals方法需要注意的地方 自反性:对于任何非空引用 x,x.equals(x) 必须返回 true。对称性:对于…

mac命令行快捷键

光标移动 Ctrl A: 将光标移动到行首。Ctrl E: 将光标移动到行尾。Option 左箭头: 向左移动一个单词。Option 右箭头: 向右移动一个单词。 删除和修改 Ctrl K: 删除从光标到行尾的所有内容。Ctrl U: 删除从光标到行首的所有内容。Ctrl W: 删除光标前的一个单词。Ctrl …

CentOS 7部署主域名服务器 DNS

1. 安装 BIND 服务和工具 yum install -y bind bind-utils 2. 配置 BIND 服务 vim /etc/named.conf 修改以下配置项: listen-on port 53 { any; }; # 监听所有接口allow-query { any; }; # 允许所有设备查询 3 . 添加你的域名区域配置 …

优化 SQL 语句方向和提升性能技巧

优化 SQL 语句是提升 MySQL 性能的关键步骤之一。通过优化 SQL 语句,可以减少查询时间、降低服务器负载、提高系统吞吐量。以下是优化 SQL 语句的方法、策略和技巧: 一、优化 SQL 语句的方法 1. 使用 EXPLAIN 分析查询 作用:查看 SQL 语句的执行计划,了解查询是如何执行的…

C++ 多线程简要讲解

std::thread是 C11 标准库中用于多线程编程的核心类,提供线程的创建、管理和同步功能。下面我们一一讲解。 一.构造函数 官网的构造函数如下: 1.默认构造函数和线程创建 thread() noexcept; 作用:创建一个 std::thread 对象,但…

Vscode HTML5新增元素及属性

一、‌HTML5 语义化标签 HTML5 语义化标签&#xff08;Semantic Elements&#xff09;是一组 ‌具有明确含义的 HTML 元素‌&#xff0c;通过标签名称直接描述其内容或结构的功能&#xff0c;而非仅作为样式容器&#xff08;如 <div> 或 <span>&#xff09;。它们旨…

【PostgreSQL教程】PostgreSQL 特别篇之 语言接口Python

博主介绍:✌全网粉丝22W+,CSDN博客专家、Java领域优质创作者,掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域✌ 技术范围:SpringBoot、SpringCloud、Vue、SSM、HTML、Nodejs、Python、MySQL、PostgreSQL、大数据、物联网、机器学习等设计与开发。 感兴趣的可…

Three学习入门(四)

9-Three.js 贴图与材质学习指南 环境准备 <!DOCTYPE html> <html> <head><title>Three.js Texture Demo</title><style> body { margin: 0; } </style> </head> <body><script src"https://cdnjs.cloudflare.…

前端NVM安装

https://v0.dev/chat/settings 本地启动环境 1安装 nvm 2安装node nvm install v18.19.0 nvm install v20.9.0 nvm use 18 node -v 3安装 pnpm npm install -g pnpm 或者 npm i -g pnpm 4启动 代码 目录下 执行 pnpm i pnpm run dev 4.1到代码目录下 4.2直接cmd…

蓝桥杯算法精讲:二分查找实战与变种解析

适合人群&#xff1a;蓝桥杯备考生 | 算法竞赛入门者 | 二分查找进阶学习者 目录 一、二分查找核心要点 1. 算法思想 2. 适用条件 3. 算法模板 二、蓝桥杯真题实战 例题&#xff1a;分巧克力&#xff08;蓝桥杯2017省赛&#xff09; 三、二分查找变种与技巧 1. 查找左边…

cmd命令查看电脑的CPU、内存、存储量

目录 获取计算机硬件的相关信息的命令分别的功能结果展示结果说明获取计算机硬件的相关信息的命令 wmic cpu get name wmic memorychip get capacity wmic diskdrive get model,size,mediaType分别的功能 获取计算机中央处理器(CPU)的名称 获取计算机内存(RAM)芯片的容量…

SCI论文阅读指令(特征工程)

下面是一个SCI论文阅读特征工程V3.0&#xff0c;把指令输入大模型中&#xff0c;并上传PDF论文&#xff0c;就可以帮你快速阅读论文。 优先推荐kimi&#xff0c;当然DeepSeek、QwQ-32B等大语言模型也可以。测试了一下总结的还不错&#xff0c;很详细。 请仔细并深入地阅读所提…

如何监控 SQL Server

监控 SQL Server 对于维护数据库性能、确保数据可用性和最大限度地减少停机时间至关重要。随着企业越来越依赖数据驱动的决策&#xff0c;高效的SQL Server监控策略能显著提升组织生产力和用户满意度。 为什么要监控 SQL Server SQL Server 是许多关键应用程序的支柱&#xf…

python脚本处理excel文件

1.对比perl和python 分别尝试用perl和python处理excel文件&#xff0c;发现perl的比较复杂&#xff0c;比如说read excel就有很多方式 Spreadsheet::Read use Spreadsheet::ParseExcel 不同的method&#xff0c;对应的取sheet的cell方式也不一样。更复杂的是处理含有中文内…

3、pytest实现参数化

在 pytest 中&#xff0c;参数化&#xff08;parametrization&#xff09;是一种强大的功能&#xff0c;可以让你用不同的输入数据重复执行同一个测试函数。这种功能非常有用&#xff0c;可以帮助你显著减少重复代码并提高测试覆盖率。 参数化的主要作用是&#xff1a; 测试多…

Cursor:超强AI变成神器

是一个强大的 AI 编程助手&#xff0c;可以帮助开发者快速地编写、编辑和讨论代码&#xff0c;支持 Python、Java、C# 等多种编程语言&#xff0c;并且可以与 GitHub、Slack 等平台集成。 Cursor 是什么&#xff1f; 想象一下&#xff0c;你有一个能把你的创意变成现实的造梦 …

画秒杀系统流程图

秒杀系统流程图 秒杀系统关键点 高并发处理: 使用网关&#xff08;如 Nginx&#xff09;进行流量限流&#xff0c;避免过载。分布式锁或 Redis 原子操作控制并发。 活动状态检查: Redis 存储活动状态&#xff08;如 seckill:activity:1:status&#xff09;&#xff0c;快速…