神经网络基础-神经网络补充概念-23-神经网络的梯度下降法

概念

神经网络的梯度下降法是训练神经网络的核心优化算法之一。它通过调整神经网络的权重和偏差,以最小化损失函数,从而使神经网络能够逐渐逼近目标函数的最优值。

步骤

1损失函数(Loss Function):
首先,我们定义一个损失函数,用来衡量神经网络预测值与真实标签之间的差距。常见的损失函数包括均方误差(Mean Squared Error)和交叉熵(Cross-Entropy)等。

2初始化参数:
在训练之前,需要随机初始化神经网络的权重和偏差。

4前向传播:
通过前向传播计算神经网络的输出,根据输入数据、权重和偏差计算每一层的激活值和预测值。

5计算损失:
使用损失函数计算预测值与真实标签之间的差距。

6反向传播:
反向传播是梯度下降法的关键步骤。它从输出层开始,计算每一层的误差梯度,然后根据链式法则将梯度传递回每一层。这样,可以得到关于权重和偏差的梯度信息,指导参数的更新。

7更新参数:
使用梯度信息,按照一定的学习率(learning rate)更新神经网络的权重和偏差。通常采用如下更新规则:新权重 = 旧权重 - 学习率 × 梯度。

8重复迭代:
重复执行前向传播、计算损失、反向传播和参数更新步骤,直到损失函数收敛或达到预定的迭代次数。

9评估模型:
在训练过程中,可以周期性地评估模型在验证集上的性能,以防止过拟合并选择合适的模型。

python实现

import numpy as np# 定义 sigmoid 激活函数及其导数
def sigmoid(x):return 1 / (1 + np.exp(-x))def sigmoid_derivative(x):return x * (1 - x)# 设置随机种子以保证可重复性
np.random.seed(42)# 生成模拟数据
X = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
y = np.array([[0], [1], [1], [0]])# 初始化权重和偏差
input_size = 2
output_size = 1
hidden_size = 4weights_input_hidden = np.random.uniform(-1, 1, (input_size, hidden_size))
bias_hidden = np.zeros((1, hidden_size))weights_hidden_output = np.random.uniform(-1, 1, (hidden_size, output_size))
bias_output = np.zeros((1, output_size))# 设置学习率和迭代次数
learning_rate = 0.1
epochs = 10000# 训练神经网络
for epoch in range(epochs):# 前向传播hidden_input = np.dot(X, weights_input_hidden) + bias_hiddenhidden_output = sigmoid(hidden_input)final_input = np.dot(hidden_output, weights_hidden_output) + bias_outputfinal_output = sigmoid(final_input)# 计算损失loss = np.mean(0.5 * (y - final_output) ** 2)# 反向传播d_output = (y - final_output) * sigmoid_derivative(final_output)d_hidden = d_output.dot(weights_hidden_output.T) * sigmoid_derivative(hidden_output)# 更新权重和偏差weights_hidden_output += hidden_output.T.dot(d_output) * learning_ratebias_output += np.sum(d_output, axis=0, keepdims=True) * learning_rateweights_input_hidden += X.T.dot(d_hidden) * learning_ratebias_hidden += np.sum(d_hidden, axis=0, keepdims=True) * learning_rateif epoch % 1000 == 0:print(f'Epoch {epoch}, Loss: {loss}')# 打印训练后的权重和偏差
print('Final weights_input_hidden:', weights_input_hidden)
print('Final bias_hidden:', bias_hidden)
print('Final weights_hidden_output:', weights_hidden_output)
print('Final bias_output:', bias_output)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/40754.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Springboot多路数据源

1、多路数据源配置 (1)SpringBootMyBatis-PlusOracle实现多数据源配置 https://blog.csdn.net/weixin_44812604/article/details/127386828 (2)SpringBootMybatis搭建Oracle多数据源配置简述 https://blog.csdn.net/HJW_233/arti…

网络安全 Day29-运维安全项目-iptables防火墙

iptables防火墙 1. 防火墙概述2. 防火墙2.1 防火墙种类及使用说明2.2 必须熟悉的名词2.3 iptables 执行过程※※※※※2.4 表与链※※※※※2.4.1 简介2.4.2 每个表说明2.4.2.1 filter表 :star::star::star::star::star:2.4.2.2 nat表 2.5 环境准备及命令2.6 案例01&#xff1a…

神经网络基础-神经网络补充概念-31-参数与超参数

概念 参数(Parameters): 参数是模型内部学习的变量,它们通过训练过程自动调整以最小化损失函数。在神经网络中,参数通常是连接权重(weights)和偏置(biases),…

ChatGLM2-6B安装部署(详尽版)

1、环境部署 安装Anaconda3 安装GIT 安装GUDA 11.8 安装NVIDIA 图形化驱动 522.25版本,如果电脑本身是更高版本则不用更新 1.1、检查CUDA 运行cmd或者Anaconda,运行以下命令 nvidia-smi CUDA Version是版本信息,Dricer Version是图形化…

LeetCode 160.相交链表

文章目录 💡题目分析💡解题思路🚩步骤一:找尾节点🚩步骤二:判断尾节点是否相等🚩步骤三:找交点🍄思路1🍄思路2 🔔接口源码 题目链接👉…

Ubuntu下mysql安装及远程连接支持配置

1.安装 下载mysql-server(必须加sudo) sudo apt update sudo apt install mysql-server 查看mysql的状态 sudo service mysql status 通过如下命令开启mysql sudo service mysql start 2.配置 第一次安装mysql后,为root设置一个密码 …

Linux -- 进阶 Autofs应用 : 光驱自动挂载 操作详解

服务端自动挂载光驱 第一步 : 关闭安全软件,安装自动挂载软件 [rootserver ~]# setenforce 0 [rootserver ~]# systemctl stop firewalld [rootserver ~]# yum install autofs -y 第二步 : 修改 autofs 主配置文件, 计划挂载光…

C++之map的emplace与pair插入键值对用例(一百七十四)

简介: CSDN博客专家,专注Android/Linux系统,分享多mic语音方案、音视频、编解码等技术,与大家一起成长! 优质专栏:Audio工程师进阶系列【原创干货持续更新中……】🚀 人生格言: 人生…

213、仿真-基于51单片机智能电表电能表用电量电费报警Proteus仿真设计(程序+Proteus仿真+原理图+配套资料等)

毕设帮助、开题指导、技术解答(有偿)见文未 目录 一、硬件设计 二、设计功能 三、Proteus仿真图 四、原理图 五、程序源码 资料包括: 需要完整的资料可以点击下面的名片加下我,找我要资源压缩包的百度网盘下载地址及提取码。 方案选择 单片机的选…

uniapp tabbar 浏览器调试显示 真机不显示

解决方案,把tabBar里面的单位全改为px,rpx是不会显示的! 注意了,改完一定要重新运行,不然无效,坑爹 "tabBar": {"borderStyle": "black","selectedColor": &quo…

java-JVM内存区域JVM运行时内存

一. JVM 内存区域 JVM 内存区域主要分为线程私有区域【程序计数器、虚拟机栈、本地方法区】、线程共享区域【JAVA 堆、方法区】、直接内存。线程私有数据区域生命周期与线程相同, 依赖用户线程的启动/结束 而 创建/销毁(在 HotspotVM 内, 每个线程都与操作系统的本地线程直接映…

SwiftUI 动画进阶:实现行星绕圆周轨道运动

0. 概览 SwiftUI 动画对于优秀 App 可以说是布帛菽粟。利用美妙的动画我们不仅可以活跃界面元素,更可以单独打造出一整套生动有机的世界,激活无限可能。 如上图所示,我们用动画粗略实现了一个小太阳系:8大行星围绕太阳旋转,而卫星们围绕各个行星旋转。 在本篇博文中,您将…

vue3实现防抖、单页面引入、全局引入、全局挂载

文章目录 代码实现单页面引入全局引入使用 代码实现 const debounce (fn: any, delay: number) > {let timer: any undefined;return (item: any) > {if (timer) clearTimeout(timer);timer setTimeout(() > fn(item), delay);} };export default debounce;单页面…

Python + Selenium 处理浏览器Cookie

工作中遇到这么一个场景:自动化测试登录的时候需要输入动态验证码,由于某些原因,需要从一个已登录的机器上,复制cookie过来,到自动化这边绕过登录。 浏览器的F12里复制出来的cookie内容是文本格式的: uui…

【第二讲---初识SLAM】

SLAM简介 视觉SLAM,主要指的是利用相机完成建图和定位问题。如果传感器是激光,那么就称为激光SLAM。 定位(明白自身状态(即位置))建图(了解外在环境)。 视觉SLAM中使用的相机与常见…

VB+SQL银行设备管理系统设计与实现

摘要 随着银行卡的普及,很多地方安装了大量的存款机、取款机和POS机等银行自助设备。银行设备管理系统可以有效的记录银行设备的安装和使用情况,规范对自助设备的管理,从而为用户提供更加稳定和优质的服务。 本文介绍了银行设备管理系统的设计和开发过程,详细阐述了整个应…

Flink之Task解析

Flink之Task解析 对Flink的Task进行解析前,我们首先要清楚几个角色TaskManager、Slot、Task、Subtask、TaskChain分别是什么 角色注释TaskManager在Flink中TaskManager就是一个管理task的进程,每个节点只有一个TaskManagerSlotSlot就是TaskManager中的槽位,一个TaskManager中可…

数据结构单链表

单链表 1 链表的概念及结构 概念:链表是一种物理存储结构上非连续、非顺序的存储结构,数据元素的逻辑顺序是通过链表中的指针链 接次序实现的 。 在我们开始讲链表之前,我们是写了顺序表,顺序表就是类似一个数组的东西&#xff0…

上海虚拟展厅制作平台怎么选,蛙色3DVR 助力行业发展

引言: 在数字化时代,虚拟展厅成为了企业宣传的重要手段。而作为一家位于上海的实力平台,上海蛙色3DVR凭借其卓越的功能和创新的技术,成为了企业展示和宣传的首选。 一、虚拟展厅的优势 虚拟展厅的崛起是指随着科技的进步&#x…

36_windows环境debug Nginx 源码-使用 VSCode 和WSL

文章目录 配置 WSL编译 NginxVSCode 安装插件launch.json配置 WSL sudo apt-get -y install gcc cmake sudo apt-get -y install pcre sudo apt-get -y install libpcre3 libpcre3-dev sudo apt-get