2024-11-27 学习人工智能的Day32 神经网络与反向传播

一、神经网络

神经网络神经网络(Neural Networks)是一种模拟人脑神经元网络结构的计算模型,用于处理复杂的模式识别、分类和预测等任务。

人工神经元是神经网络的基础构建单元,模仿了神武神经元的工作原理,核心功能是接收输入信号,经过加权求和和非线性激活函数处理后输出结果。

  1. 人工神经元
    • 输入:神经元接收多个输入,每个输入有一个对应的权重。
    • 加权求和:计算输入的加权和,添加偏置项。
    • 激活函数:对加权和结果应用激活函数,引入非线性能力。
  2. 神经网络结构
    • 输入层:接收数据,不进行计算。
    • 隐藏层:提取特征,通过多层堆叠实现复杂变换。
    • 输出层:生成预测结果或分类。
  3. 全连接网络
    • 每层神经元与上一层的所有神经元相连。
    • 常见问题:参数量大,容易过拟合;对高维数据局部特征捕捉能力差。

先说一下参数的初始化,对于线性回归模型来说,在进行正向传播时,不能缺失的值有构成前向传播方程的偏置和权重这两个值,而对于输入层来说,就是传入y = w*x+b中的x,而为了方便传播,在输入层中就会将传入的x引入这个方程来进行定义的线性变换,但是输出层只负责最简单的线性变化。

在这个线性变化过程中就有一个问题,这个方程中的w和b哪里来?

由于之前学过线性回归的时候我们知道,这个前向传播的方程是为了算出预测值来与实际数据进行损失计算然后通过反向传播来更新梯度值进行简单的梯度下降来逼近我们希望求得的权重和偏置,所以输出层的权重和偏置是由我们自定义的,而torch中就提供了很多种初始化的方法。接下来说几个例子。

全零初始化

因为会破坏对称性,一般治用于初始化偏置

import torch
import torch.nn as nndef test():linear = nn.Linear(in_feature=6,out_feature=4)nn.init.zeros_(linear.weight)print(linear.weight)if __name__ == "main":test()

然后还有全1初始化就没必要再写一遍了。

任意常数初始化

import torch
import torch.nn as nndef test002():# 2. 固定值参数初始化linear = nn.Linear(in_features=6, out_features=4)# 初始化权重参数nn.init.constant_(linear.weight, 0.63)# 打印权重参数print(linear.weight)passif __name__ == "__main__":test002()

防止了参数为0的情况出现,但是运用在权重上依旧避免不了破坏对称性的问题。

随机初始化

最基本的初始化方式,避免破坏对称性

import torch
import torch.nn as nndef test001():# 1. 均匀分布随机初始化linear = nn.Linear(in_features=6, out_features=4)# 初始化权重参数nn.init.uniform_(linear.weight)# 打印权重参数print(linear.weight)if __name__ == "__main__":test001()

Xavier初始化

也叫做Glorot初始化。

方法:根据输入和输出神经元的数量来选择权重的初始值。权重从以下分布中采样:
W ∼ U ( − 6 n i n + n o u t , 6 n i n + n o u t ) W\sim\mathrm{U}\left(-\frac{\sqrt{6}}{\sqrt{n_\mathrm{in}+n_\mathrm{out}}},\frac{\sqrt{6}}{\sqrt{n_\mathrm{in}+n_\mathrm{out}}}\right) WU(nin+nout 6 ,nin+nout 6 )
或者
W ∼ N ( 0 , 2 n i n + n o u t ) W\sim\mathrm{N}\left(0,\frac{2}{n_\mathrm{in}+n_\mathrm{out}}\right) WN(0,nin+nout2)

N ( 0 , std 2 ) N ( 0 , std 2 ) \mathcal{N}(0, \text{std}^2)\mathcal{N}(0, \text{std}^2) N(0,std2)N(0,std2)

其中 n in n_{\text{in}} nin 是当前层的输入神经元数量, n out n_{\text{out}} nout是输出神经元数量。

优点:平衡了输入和输出的方差,适合Sigmoid和 Tanh 激活函数。

应用场景:常用于浅层网络或使用Sigmoid、Tanh 激活函数的网络。

import torch
import torch.nn as nndef test007():# Xavier初始化:正态分布linear = nn.Linear(in_features=6, out_features=4)nn.init.xavier_normal_(linear.weight)print(linear.weight)# Xavier初始化:均匀分布linear = nn.Linear(in_features=6, out_features=4)nn.init.xavier_uniform_(linear.weight)print(linear.weight)if __name__ == "__main__":test007()

二、激活函数

就我自己的理解来是,进入每个神经元时,就是一次线性变换再输出给下一个神经元的时候则是通过非线性变化的激活函数来提高数据的分离度,而如果不设置激活函数也是能传递的,只是他始终进行的线性的变化,而线性数据的一个缺点就是无法表示复杂的非线性的关系,如异或问题。而引入激活函数的优点还有

  1. 通过升维来提取更改层次的抽象特征
  2. 控制输出值的范围,稳定梯度传播
  3. 动态调整神经元的激活状态,增强表达能力
  4. 实现逼近任意复杂函数的能力

下面是常用的激活函数,

  1. 常见激活函数
    • Sigmoid:将输出映射到 (0,1)(0, 1)(0,1),适合二分类输出层;易出现梯度消失。
    • Tanh:将输出映射到 (−1,1)(-1, 1)(−1,1),适合隐藏层,梯度较大,缓解梯度消失问题。
    • ReLU:输出正值本身,负值为零,计算简单,适合深层网络;存在神经元死亡问题。
    • Leaky ReLU:改进 ReLU,为负值引入小斜率,避免神经元死亡。
    • Softmax:将输出映射为概率分布,适合多分类任务的输出层。
  2. 选择建议
    • 隐藏层优先选择 ReLU 或其变体。
    • 输出层根据任务选择:二分类用 Sigmoid,多分类用 Softmax,回归任务用线性激活。

由于激活函数的公式过多,就不说了。

三、损失函数

这里的损失函数与之前学习反向传播时的概念无差异,只是根据不同的任务推出了不同损失函数去更好的适配这些任务的回归。

  1. 线性回归损失

    • MAE:平均绝对误差,对异常值鲁棒。
    • MSE:均方误差,对大误差敏感。
    • Smooth L1:结合 MAE 和 MSE 的优点,平滑过渡。
  2. 分类损失

    • 交叉熵损失:

      • 多分类问题与 Softmax 搭配。
      • 衡量预测分布与真实分布的差异。
    • 二分类交叉熵:

      • 二分类问题与 Sigmoid 搭配。
      • 用于衡量模型对两个类别的预测效果。
  3. 总结

    • 回归任务:MSE、Smooth L1。
    • 二分类任务:Sigmoid + 二分类交叉熵。
    • 多分类任务:Softmax + 交叉熵。

反向传播的api实现

由于上面所有的都是为了反向传播而服务,所以单论以上的所有都是理论上的,只有在神经网络中实现反向传播才会真正的运用到上面所有的知识。

import torch
import torch.nn as nn
import torch.optim as optimclass Net(nn.Module):def __init__(self):super(Net, self).__init__()self.linear1 = nn.Linear(2, 2)self.linear2 = nn.Linear(2, 2)# 网络参数初始化self.linear1.weight.data = torch.tensor([[0.15, 0.20], [0.25, 0.30]])self.linear2.weight.data = torch.tensor([[0.40, 0.45], [0.50, 0.55]])self.linear1.bias.data = torch.tensor([0.35, 0.35])self.linear2.bias.data = torch.tensor([0.60, 0.60])def forward(self, x):x = self.linear1(x)x = torch.sigmoid(x)x = self.linear2(x)x = torch.sigmoid(x)return xif __name__ == "__main__":inputs = torch.tensor([[0.05, 0.10]])target = torch.tensor([[0.01, 0.99]])# 获得网络输出值net = Net()output = net(inputs)# 计算误差loss = torch.sum((output - target) ** 2) / 2# 优化方法optimizer = optim.SGD(net.parameters(), lr=0.5)# 梯度清零optimizer.zero_grad()# 反向传播loss.backward()# 打印(w1-w8)观察w5、w7、w1 的梯度值是否与手动计算一致print(net.linear1.weight.grad.data)print(net.linear2.weight.grad.data)#更新梯度optimizer.step()# 打印更新后的网络参数print(net.state_dict())

这个程序中没有用到上面说到的这些初始化,而是自定义传参初始化。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/887603.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

代码随想录打卡DAY21

算法记录第21天 [二叉树] 1.LeetCode 538. 把二叉搜索树转换为累加树 题目描述: 给出二叉 搜索 树的根节点,该树的节点值各不相同,请你将其转换为累加树(Greater Sum Tree),使每个节点 node 的新值等于原…

[在线实验]-ActiveMQ Docker镜像的下载与部署

镜像下载 下载ActiveMQ的Docker镜像文件。通常,这些文件会以.tar格式提供,例如activemq.tar。 docker的activemq镜像资源-CSDN文库 加载镜像 下载完成后,您可以使用以下命令将镜像文件加载到Docker中: docker load --input a…

VTK中对于相机camera的设置

1. 相机的核心属性 在 VTK 中,vtkCamera 的核心属性有默认值。如果你不设置这些属性,相机会使用默认值来渲染场景。 Position(默认值:(0, 0, 1)): 默认情况下,相机位于 Z 轴正方向的 (0, 0, 1)…

学习日志017--python的几种排序算法

冒泡排序 def bubble_sort(alist):i 0while i<len(alist):j0while j<len(alist)-1:if alist[j]>alist[j1]:alist[j],alist[j1] alist[j1],alist[j]j1i1l [2,4,6,8,0,1,3,5,7,9] bubble_sort(l) print(l) 选择排序 def select_sort(alist):i 0while i<len(al…

超高流量多级缓存架构设计!

文章内容已经收录在《面试进阶之路》&#xff0c;从原理出发&#xff0c;直击面试难点&#xff0c;实现更高维度的降维打击&#xff01; 文章目录 电商-多级缓存架构设计多级缓存架构介绍多级缓存请求流程负载均衡算法的选择轮询负载均衡一致性哈希负载均衡算法选择 应用层 Ngi…

红黑树的概念以及基本模拟

目录 一、概念和规则&#xff1a; 1、思考为什么最长路径不超过最短路径的二倍&#xff1f; 2、红黑树的效率&#xff1f; 二、红黑树的代码实现 1、红黑树的节点结构 2、红黑树的插入 1、大致过程&#xff1a; 2、维护的三种情况&#xff1a; 1、情况一&#xff1a;变…

IP反向追踪技术,了解一下?

DOSS&#xff08;拒绝服务&#xff09;攻击是现在比较常见的网络攻击手段。想象一下&#xff0c;有某个恶意分子想要搞垮某个网站&#xff0c;他就会使用DOSS攻击。这种攻击常常使用的方式是IP欺骗。他会伪装成正常的IP地址&#xff0c;让网络服务器以为有很多平常的请求&#…

【C++习题】15.滑动窗口_串联所有单词的子串

文章目录 题目链接&#xff1a;题目描述&#xff1a;解法C 算法代码&#xff1a;图解 题目链接&#xff1a; 30. 串联所有单词的子串 题目描述&#xff1a; 解法 滑动窗口哈希表 这题和第14题不同的是&#xff1a; 哈希表不同&#xff1a;hash<string,int>left与right指…

论文笔记(五十七)Diffusion Model Predictive Control

Diffusion Model Predictive Control 文章概括摘要1. Introduction2. Related work3. 方法3.1 模型预测控制3.2. 模型学习3.3. 规划&#xff08;Planning&#xff09;3.4. 适应 4. 实验&#xff08;Experiments&#xff09;4.1. 对于固定奖励&#xff0c;D-MPC 可与其他离线 RL…

oracle 创建只可以查询权限用户+sqldeveloper如何看到对应表

声明 申明部分是从其他csdn用户哪里复制的&#xff0c;只是自己操作后发现无法达到我最后的预期&#xff0c;所以关闭忘记是看的那篇了&#xff0c;如果有侵权请见谅&#xff0c;联系我删除谢谢。 好了&#xff0c;故事的开始是我最近删投产表了。没错职业黑点&#xff0c;清…

比特币libsecp256k1中safegcd算法形式化验证完成

1. 引言 比特币和其他链&#xff08;如 Liquid&#xff09;的安全性取决于 ECDSA 和 Schnorr 签名等数字签名算法的使用。Bitcoin Core 和 Liquid 都使用名为 libsecp256k1 的 C 库来提供这些数字签名算法&#xff0c;该库以其所运行的椭圆曲线命名。这些算法利用一种称为modu…

15分钟做完一个小程序,腾讯这个工具有点东西

我记得很久之前&#xff0c;我们都在讲什么低代码/无代码平台&#xff0c;这个概念很久了&#xff0c;但是&#xff0c;一直没有很好的落地&#xff0c;整体的效果也不算好。 自从去年 ChatGPT 这类大模型大火以来&#xff0c;各大科技公司也都推出了很多 AI 代码助手&#xff…

Kafka知识体系

一、认识Kafka 1. kafka适用场景 消息系统&#xff1a;kafka不仅具备传统的系统解耦、流量削峰、缓冲、异步通信、可扩展性、可恢复性等功能&#xff0c;还有其他消息系统难以实现的消息顺序消费及消息回溯功能。 存储系统&#xff1a;kafka把消息持久化到磁盘上&#xff0c…

JVM调优篇之JVM基础入门AND字节码文件解读

目录 Java程序编译class文件内容常量池附录-访问标识表附录-常量池类型列表 Java程序编译 Java文件通过编译成class文件后&#xff0c;通过JVM虚拟机解释字节码文件转为操作系统执行的二进制码运行。 规范 Java虚拟机有自己的一套规范&#xff0c;遵循这套规范&#xff0c;任…

【Petri网导论学习笔记】Petri网导论入门学习(十一) —— 3.3 变迁发生序列与Petri网语言

目录 3.3 变迁发生序列与Petri网语言定义 3.4定义 3.5定义 3.6定理 3.5例 3.9定义 3.7例 3.10定理 3.6定理 3.7 有界Petri网泵引理推论 3.5定义 3.9定理 3.8定义 3.10定义 3.11定义 3.12定理 3.93.3 变迁发生序列与Petri网语言 对于 Petri 网进行分析的另一种方法是考察网系统…

Flink--API 之Transformation-转换算子的使用解析

目录 一、常用转换算子详解 &#xff08;一&#xff09;map 算子 &#xff08;二&#xff09;flatMap 算子 &#xff08;三&#xff09;filter 算子 &#xff08;四&#xff09;keyBy 算子 元组类型 POJO &#xff08;五&#xff09;reduce 算子 二、合并与连接操作 …

Top 10 Tools to Level Up Your Prompt Engineering Skills

此文章文字是转载翻译&#xff0c;图片是自已用AI 重新生成的。文字内容来自 https://www.aifire.co/p/top-10-ai-prompt-engineering-tools 供记录学习使用。 Introduction to AI Prompt Engineering AI Prompt Engineering 简介 1&#xff0c;Prompt Engineering 提示工程…

Rust语言俄罗斯方块(漂亮的界面案例+详细的代码解说+完美运行)

tetris-demo A Tetris example written in Rust using Piston in under 500 lines of code 项目地址: https://gitcode.com/gh_mirrors/te/tetris-demo 项目介绍 "Tetris Example in Rust, v2" 是一个用Rust语言编写的俄罗斯方块游戏示例。这个项目不仅是一个简单…

Spring Boot 与 Spring Cloud Alibaba 版本兼容对照

版本选择要点 Spring Boot 3.x 与 Spring Cloud Alibaba 2022.0.x Spring Boot 3.x 基于 Jakarta EE&#xff0c;javax.* 更换为 jakarta.*。 需要使用 Spring Cloud 2022.0.x 和 Spring Cloud Alibaba 2022.0.x。 Alibaba 2022.0.x 对 Spring Boot 3.x 的支持在其发行说明中…

(免费送源码)计算机毕业设计原创定制:Java+ssm+JSP+Ajax SSM棕榈校园论坛的开发

摘要 随着计算机科学技术的高速发展,计算机成了人们日常生活的必需品&#xff0c;从而也带动了一系列与此相关产业&#xff0c;是人们的生活发生了翻天覆地的变化&#xff0c;而网络化的出现也在改变着人们传统的生活方式&#xff0c;包括工作&#xff0c;学习&#xff0c;社交…