AI学习指南数学工具篇-梯度下降在机器学习中的应用

AI学习指南数学工具篇-梯度下降在机器学习中的应用

线性回归模型中的梯度下降

线性回归是一种用于建立预测模型的基本统计方法。在线性回归中,我们试图通过输入特征的线性组合来预测输出变量的值。梯度下降是一种优化算法,在线性回归模型中,我们可以使用梯度下降来找到使得模型误差最小化的最优参数。

应用梯度下降来训练线性回归模型

假设我们有一个包含n个样本的数据集,每个样本有m个特征。线性回归模型的预测值可以表示为:

y ^ = w 0 + w 1 x 1 + w 2 x 2 + . . . + w m x m \hat{y} = w_0 + w_1x_1 + w_2x_2 + ... + w_mx_m y^=w0+w1x1+w2x2+...+wmxm

其中, w 0 w_0 w0是偏置项, w 1 , w 2 , . . . , w m w_1, w_2, ..., w_m w1,w2,...,wm是权重。

我们的目标是找到一组最优的权重和偏置项,使得预测值 y ^ \hat{y} y^与真实值 y y y的误差最小化。我们可以定义误差函数(损失函数)为均方误差(MSE):

J ( w ) = 1 n ∑ i = 1 n ( y ( i ) − y ^ ( i ) ) 2 J(w) = \frac{1}{n}\sum_{i=1}^{n}(y^{(i)} - \hat{y}^{(i)})^2 J(w)=n1i=1n(y(i)y^(i))2

其中, w = [ w 0 , w 1 , w 2 , . . . , w m ] w=[w_0, w_1, w_2, ..., w_m] w=[w0,w1,w2,...,wm]是模型的参数。

梯度下降的目标是通过不断迭代更新参数 w w w,使得误差函数 J ( w ) J(w) J(w)最小化。具体地,梯度下降的迭代过程如下:

  1. 初始化参数 w w w的数值
  2. 计算误差函数 J ( w ) J(w) J(w)关于每个参数的偏导数(梯度)
  3. 更新参数 w w w w = w − α ⋅ ∇ J ( w ) w = w - \alpha \cdot\nabla J(w) w=wαJ(w)

其中, α \alpha α是学习率, ∇ J ( w ) \nabla J(w) J(w)是误差函数 J ( w ) J(w) J(w)关于参数 w w w的梯度。

我们可以通过如下示例来说明梯度下降在线性回归模型中的应用:

import numpy as np# 生成随机数据集
np.random.seed(0)
X = 2 * np.random.rand(100, 1)
y = 4 + 3 * X + np.random.randn(100, 1)# 初始化参数
w = np.random.randn(2, 1)# 使用梯度下降进行训练
n_iterations = 1000
learning_rate = 0.1
m = 100for iteration in range(n_iterations):gradients = 2/m * X.T.dot(X.dot(w) - y)w = w - learning_rate * gradients

通过上述示例,我们可以看到如何使用梯度下降来训练线性回归模型,不断迭代更新参数 w w w,直到误差函数 J ( w ) J(w) J(w)收敛。

逻辑回归模型中的梯度下降

逻辑回归是一种用于建立分类模型的方法。在逻辑回归中,我们试图通过输入特征的线性组合来预测离散的输出类别。梯度下降同样可以应用在逻辑回归模型中,帮助找到最优参数。

梯度下降在逻辑回归中的应用

在逻辑回归中,我们使用逻辑函数(sigmoid函数)来执行分类:

σ ( z ) = 1 1 + e − z \sigma(z) = \frac{1}{1+e^{-z}} σ(z)=1+ez1

其中, z z z是输入特征的线性组合:

z = w 0 + w 1 x 1 + w 2 x 2 + . . . + w m x m z = w_0 + w_1x_1 + w_2x_2 + ... + w_mx_m z=w0+w1x1+w2x2+...+wmxm

我们的目标是通过梯度下降来优化参数 w w w,使得逻辑回归模型的预测值与真实类别的误差最小化。

梯度下降的迭代更新过程与线性回归类似,不同之处在于在逻辑回归中,我们使用交叉熵损失函数:

J ( w ) = − 1 m ∑ i = 1 m [ y ( i ) log ⁡ ( y ^ ( i ) ) + ( 1 − y ( i ) ) log ⁡ ( 1 − y ^ ( i ) ) ] J(w) = -\frac{1}{m}\sum_{i=1}^{m}[y^{(i)}\log(\hat{y}^{(i)}) + (1-y^{(i)})\log(1-\hat{y}^{(i)})] J(w)=m1i=1m[y(i)log(y^(i))+(1y(i))log(1y^(i))]

其中, y y y是真实的类别, y ^ \hat{y} y^是模型的预测类别。

我们可以通过如下示例来说明梯度下降在逻辑回归模型中的应用:

import numpy as np
from sklearn.datasets import make_classification# 生成随机分类数据集
X, y = make_classification(n_samples=100, n_features=2, n_informative=2, n_redundant=0, n_clusters_per_class=1)# 初始化参数
w = np.random.randn(3, 1)# 添加偏置项
X_b = np.c_[np.ones((100, 1)), X]# 使用梯度下降进行训练
n_iterations = 1000
learning_rate = 0.1
m = 100for iteration in range(n_iterations):logits = X_b.dot(w)y_proba = 1 / (1 + np.exp(-logits))gradients = 1/m * X_b.T.dot(y_proba - y.reshape(-1, 1))w = w - learning_rate * gradients

通过上述示例,我们可以看到如何使用梯度下降来训练逻辑回归模型,不断迭代更新参数 w w w,直到误差函数 J ( w ) J(w) J(w)收敛。

深度学习中的梯度下降

在深度学习中,梯度下降同样是一种常用的优化算法,用于训练神经网络模型。

使用梯度下降来训练神经网络

在神经网络中,我们通常使用反向传播算法来计算每一层参数的梯度,然后使用梯度下降来更新参数。

具体地,梯度下降的迭代更新过程如下:

  1. 初始化神经网络的参数
  2. 前向传播计算预测值
  3. 反向传播计算每一层参数的梯度
  4. 更新参数: w = w − α ⋅ ∇ J ( w ) w = w - \alpha \cdot\nabla J(w) w=wαJ(w)

其中, α \alpha α是学习率, ∇ J ( w ) \nabla J(w) J(w)是神经网络损失函数 J ( w ) J(w) J(w)关于参数 w w w的梯度。

在深度学习中,梯度下降的变种有很多,如动量梯度下降、Nesterov加速梯度下降、Adagrad、RMSprop和Adam等。这些变种算法在梯度下降的基础上做了一些改进,提高了训练的速度和效果。

综上所述,梯度下降在机器学习中有着广泛的应用,无论是在线性回归、逻辑回归还是深度学习中,梯度下降都是一种重要的优化算法。通过不断迭代更新参数,梯度下降可以帮助我们找到最优的模型参数,使得模型的预测值与真实值之间的误差最小化,从而提高模型的预测性能。

希望通过本篇博客,读者对梯度下降在机器学习中的应用有了更深入的理解。如果对于梯度下降还有疑问,欢迎留言讨论。

参考文献:

  • [1] Hands-On Machine Learning with Scikit-Learn and TensorFlow. Aurélien Géron. O"Reilly Media. 2017.
  • [2] Deep Learning. Ian Goodfellow, Yoshua Bengio, and Aaron Courville. MIT Press. 2016.

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/15964.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

微信小程序源码-基于Java后端的网上商城系统毕业设计(附源码+演示录像+LW)

大家好!我是程序员一帆,感谢您阅读本文,欢迎一键三连哦。 💞当前专栏:微信小程序毕业设计 精彩专栏推荐👇🏻👇🏻👇🏻 🎀 Python毕业设…

题解 P1150

题解 P1150 因为k个烟蒂1根烟1个烟蒂 所以k-1个烟蒂1根烟 注意减掉最后一根烟的烟蒂 (因这题并没有借烟蒂换烟再还回这一说) 此解法为小学4~6年级水平 #include <bits/stdc.h>using namespace std;int main(){int n,k;cin>>n>>k;cout<<n(n-1)/(k-…

代码随想录——找树左下角的值(Leetcode513)

题目链接 层序遍历 思路&#xff1a;使用层序遍历&#xff0c;记录每一行 i 0 的元素&#xff0c;就可以找到树左下角的值 /*** Definition for a binary tree node.* public class TreeNode {* int val;* TreeNode left;* TreeNode right;* TreeNode() {}*…

北核论文完美复现:自适应t分布与动态边界策略改进的算术优化算法

声明&#xff1a;文章是从本人公众号中复制而来&#xff0c;因此&#xff0c;想最新最快了解各类智能优化算法及其改进的朋友&#xff0c;可关注我的公众号&#xff1a;强盛机器学习&#xff0c;不定期会有很多免费代码分享~ 目录 原始算术优化算法 改进点1&#xff1a;引入…

【Linux】Ubuntu系统挂载NAS文件夹

测试系统&#xff1a;Ubuntu24.02 1. 安装必要的软件包 sudo apt update sudo apt install cifs-utils 2. 创建挂载点 sudo mkdir -p /mnt/nas 3. 获取当前用户的 UID 和 GID id -u id -g 4. 挂载&#xff1a;设置用户名/密码/nas地址 sudo mount -t cifs -o username,…

【网络】socket套接字结合IO多路复用

引言 在多线程编程中&#xff0c;I/O 多路复用&#xff08;如 select、poll 或 epoll&#xff09;可以与多线程结合使用&#xff0c;以提高系统的并发处理能力和效率。结合多线程和 I/O 多路复用&#xff0c;可以实现高性能的网络服务器和客户端。以下是一些常见的多线程和 I/…

vue+css解决图片变形问题(flex-shrink: 0)

解决前 给图片添加 flex-shrink: 0;即可解决图片变形问题

基于springboot+vue的致远汽车租赁系统

开发语言&#xff1a;Java框架&#xff1a;springbootJDK版本&#xff1a;JDK1.8服务器&#xff1a;tomcat7数据库&#xff1a;mysql 5.7&#xff08;一定要5.7版本&#xff09;数据库工具&#xff1a;Navicat11开发软件&#xff1a;eclipse/myeclipse/ideaMaven包&#xff1a;…

东方通TongWeb结合Spring-Boot使用

一、概述 信创需要; 原状:原来的服务使用springboot框架,自带的web容器是tomcat,打成jar包启动; 需求:使用东方通tongweb来替换tomcat容器; 二、替换步骤 2.1 准备 获取到TongWeb7.0.E.6_P7嵌入版 这个文件,文件内容有相关对应的依赖包,可以根据需要来安装到本地…

上5个B端系统的设计规范,让你的开发比着葫芦画瓢。

B端系统设计规范在企业级系统开发中起着重要的作用&#xff0c;具体包括以下几个方面&#xff1a; 统一风格和布局&#xff1a;设计规范能够统一系统的风格和布局&#xff0c;使不同功能模块的界面看起来一致&#xff0c;提升用户的使用体验和学习成本。通过统一的设计规范&am…

Day42 最后一块石头的重量Ⅱ + 目标和 + 一和零

1049 最后一块石头的重量Ⅱ 题目链接&#xff1a;1049.最后一块石头的重量Ⅱ 有一堆石头&#xff0c;用整数数组 stones 表示。其中 stones[i] 表示第 i 块石头的重量。 每一回合&#xff0c;从中选出任意两块石头&#xff0c;然后将它们一起粉碎。假设石头的重量分别为 x 和…

力扣 5. 最长回文子串 python AC

动态规划 class Solution:def longestPalindrome(self, s):size len(s)maxl 1start 0dp [[False] * size for _ in range(size)]for i in range(size):dp[i][i] Truefor L in range(2, size 1):for i in range(size):j L i - 1if j > size:breakif s[i] s[j]:if L…

AI大模型领域新闻跟踪

杨值麟 月之暗面杨植麟&#xff1a;大模型开发是“承包森林”月之暗面集结最强创投&#xff0c;“清华师姐”是最强“助攻”月之暗面杨植麟&#xff1a;互联网研发是“种树”&#xff0c;大模型研发是“承包森林”月之暗面杨植麟复盘大模型创业这一年&#xff1a;向延绵而未知…

Web课外练习9

<!DOCTYPE html> <html> <head><meta charset"utf-8"><title>邮购商品业务</title><!-- 引入vue.js --><script src"./js/vue.global.js" type"text/javascript"></script><link rel&…

B2117 整理药名

整理药名 题目描述 医生在书写药品名的时候经常不注意大小写&#xff0c;格式比较混乱。现要求你写一个程序将医生书写混乱的药品名整理成统一规范的格式&#xff0c;即药品名的第一个字符如果是字母要大写&#xff0c;其他字母小写。 如将 ASPIRIN 、 aspirin 整理成 Aspir…

python获取cookie的方式

通过js获取cookie&#xff0c;避免反复登录操作。 经验证在JD上没有用&#xff0c;cookie应该无痕或者加密了&#xff0c;只能用单浏览器不关的模式来实现&#xff0c;但是代码留着&#xff0c;其他网站可能有用。 def cookie_set():driver webdriver.Chrome(optionschrome_…

一千题,No.0024(数素数)

令 Pi​ 表示第 i 个素数。现任给两个正整数 M≤N≤104&#xff0c;请输出 PM​ 到 PN​ 的所有素数。 输入格式&#xff1a; 输入在一行中给出 M 和 N&#xff0c;其间以空格分隔。 输出格式&#xff1a; 输出从 PM​ 到 PN​ 的所有素数&#xff0c;每 10 个数字占 1 行&…

Kubernetes 文档 / 概念 / 工作负载 / 工作负载管理 / Job

Kubernetes 文档 / 概念 / 工作负载 / 工作负载管理 / Job 此文档从 Kubernetes 官网摘录 中文地址 英文地址 Job 会创建一个或者多个 Pod&#xff0c;并将继续重试 Pod 的执行&#xff0c;直到指定数量的 Pod 成功终止。 随着 Pod 成功结束&#xff0c;Job 跟踪记录成功完成的…

原哥花了1个多月的时间终于开发了一款基于android studio的原生商城app

大概讲一下这个app实现的功能和前后端技术架构。 功能简介 广告展示商品展示跳转淘宝联盟优惠卷购买发布朋友圈宝妈知识资讯商品搜索朋友圈展示/点赞/评论登陆注册版本升级我的个人资料商品和资讯收藏我的朋友圈意见反馈 安卓端技术选型 Arouter组件化daggerrxjavaretrofit…

基于开源二兄弟MediaPipe+Rerun实现人体姿势跟踪可视化

概述 本文中&#xff0c;我们将探索一个利用开源框架MediaPipe的功能以二维和三维方式跟踪人体姿势的使用情形。使这一探索更有趣味的是由开源可视化工具Rerun提供的可视化展示&#xff0c;该工具能够提供人类动作姿势的整体视图。 您将一步步跟随作者使用MediaPipe在2D和3D环…