【深度学习】—线性回归 线性回归的基本元素 线性模型 损失函数 解析解 随机梯度下降

【深度学习】— 线性回归 线性回归的基本元素 线性模型 损失函数 解析解 随机梯度下降

  • 线性回归
    • 线性回归的基本元素
  • 线性模型
  • 损失函数
  • 解析解
  • 随机梯度下降
    • 小批量随机梯度下降
        • 梯度下降算法的详细步骤
        • 解释公式

线性回归

回归(regression)是能为⼀个或多个⾃变量与因变量之间关系建模的⼀类⽅法。在⾃然科学和社会科学领域,回归经常⽤来表⽰输⼊和输出之间的关系。在机器学习领域中的⼤多数任务通常都与预测(prediction)有关。当我们想预测⼀个数值时,就会涉及到回归问题。常⻅的例⼦包括:预测价格(房屋、股票等)、预测住院时间(针对住院病⼈等)、预测需求(零售销量等)。但不是所有的预测都是回归问题。在后⾯的章节中,我们将介绍分类问题。分类问题的⽬标是预测数据属于⼀组类别中的哪⼀个。

线性回归的基本元素

线性回归(linear regression)可以追溯到19世纪初,它在回归的各种标准工具中最简单而且最流行。线性回归基于几个简单的假设:首先,假设自变量 x \mathbf{x} x 和因变量 y y y 之间的关系是线性的,即 y y y 可以表示为 x \mathbf{x} x 中元素的加权和,这里通常允许包含观测值的一些噪声;其次,我们假设任何噪声都比较正常,如噪声遵循正态分布。

为了解释线性回归,我们举一个实际的例子:我们希望根据房屋的面积(平方英尺)和房龄(年)来估算房屋价格(美元)。为了开发一个能预测房价的模型,我们需要收集一个真实的数据集。这个数据集包括了房屋的销售价格、面积和房龄。在机器学习的术语中,该数据集称为训练数据集(training data set)或训练集(training set)。每行数据(比如一次房屋交易相对应的数据)称为样本(sample),也可以称为数据点(data point)或数据样本(data instance)。我们把试图预测的目标(比如预测房屋价格)称为标签(label)或目标(target)。预测所依据的自变量(面积和房龄)称为特征(feature)或协变量(covariate)。

通常,我们使用 n n n 来表示数据集中的样本数。对索引为 i i i 的样本,其输入表示为 x ( i ) = [ x 1 ( i ) , x 2 ( i ) ] ⊤ \mathbf{x}^{(i)} = [x_1^{(i)}, x_2^{(i)}]^\top x(i)=[x1(i),x2(i)],其对应的标签是 y ( i ) y^{(i)} y(i)

线性模型

线性假设是指目标(房屋价格)可以表示为特征(面积和房龄)的加权和,如下面的式子:

price = w area ⋅ area + w age ⋅ age + b (3.1.1) \text{price} = w_{\text{area}} \cdot \text{area} + w_{\text{age}} \cdot \text{age} + b \tag{3.1.1} price=wareaarea+wageage+b(3.1.1)

在 (3.1.1) 中, w area w_{\text{area}} warea w age w_{\text{age}} wage 称为权重(weight),权重决定了每个特征对我们预测值的影响。 b b b 称为偏置(bias)、偏移量(offset)或截距(intercept)。偏置是指当所有特征都取值为 0 时,预测值应该为多少。即使现实中不会有任何房子的面积是 0 或房龄正好是 0 年,我们仍然需要偏置项。如果没有偏置项,我们模型的表达能力将受到限制。严格来说,(3.1.1) 是输入特征的一个仿射变换(affine transformation)。仿射变换的特点是通过加权和对特征进行线性变换(linear transformation),并通过偏置项来进行平移(translation)。

给定一个数据集,我们的目标是寻找模型的权重 w \mathbf{w} w 和偏置 b b b,使得根据模型做出的预测大体符合数据里的真实价格。输出的预测值由输入特征通过线性模型的仿射变换决定,仿射变换由所选权重和偏置确定。

而在机器学习领域,我们通常使用的是高维数据集,建模时采用线性代数表示法会比较方便。当我们的输入包含 d d d 个特征时,我们将预测结果 y ^ \hat{y} y^ (通常使用“尖角”符号表示 y y y 的估计值)表示为:

y ^ = w 1 x 1 + . . . + w d x d + b (3.1.2) \hat{y} = w_1 x_1 + ... + w_d x_d + b \tag{3.1.2} y^=w1x1+...+wdxd+b(3.1.2)

将所有特征放到向量 x ∈ R d \mathbf{x} \in \mathbb{R}^d xRd 中,并将所有权重放到向量 w ∈ R d \mathbf{w} \in \mathbb{R}^d wRd 中,我们可以用点积形式来简洁地表达模型:

y ^ = w ⊤ x + b (3.1.3) \hat{y} = \mathbf{w}^\top \mathbf{x} + b \tag{3.1.3} y^=wx+b(3.1.3)

在 (3.1.3) 中,向量 x \mathbf{x} x 对应于单个数据样本的特征。用符号表示的矩阵 X ∈ R n × d \mathbf{X} \in \mathbb{R}^{n \times d} XRn×d 可以很方便地引用我们整个数据集的 n n n 个样本。其中, X \mathbf{X} X 的每一行是一个样本,每一列是一种特征。

对于特征集合 X \mathbf{X} X,预测值 y ^ ∈ R n \hat{\mathbf{y}} \in \mathbb{R}^n y^Rn 可以通过矩阵-向量乘法表示为:

y ^ = X w + b (3.1.4) \hat{\mathbf{y}} = \mathbf{X} \mathbf{w} + b \tag{3.1.4} y^=Xw+b(3.1.4)

这个过程中的求和将使用广播机制(广播机制在 2.1.3 节中有详细介绍)。给定训练数据特征 X \mathbf{X} X 和对应的已知标签 y \mathbf{y} y,线性回归的目标是找到一组权重向量 w \mathbf{w} w 和偏置 b b b:当给定从 X \mathbf{X} X 的同分布中取样的新样本特征时,这组权重向量和偏置能够使得新样本预测标签的误差尽可能小。

虽然我们相信给定 x \mathbf{x} x 预测 y y y 的最佳模型会是线性的,但我们很难找到一个有 n n n 个样本的真实数据集,其中对于所有的 1 ≤ i ≤ n 1 \le i \le n 1in y ( i ) y^{(i)} y(i) 完全等于 w ⊤ x ( i ) + b \mathbf{w}^\top \mathbf{x}^{(i)} + b wx(i)+b。无论我们使用什么手段来观察特征 X \mathbf{X} X 和标签 y \mathbf{y} y,都可能会出现少量的观测误差。因此,即使确信特征与标签的潜在关系是线性的,我们也会加入一个噪声项来考虑观测误差带来的影响。

在开始寻找最好的模型参数(model parameters) w \mathbf{w} w b b b 之前,我们还需要两个东西:(1)一种模型质量的度量方式;(2)一种能够更新模型以提高模型预测质量的方法。

损失函数

在开始考虑如何用模型拟合数据之前,我们需要确定一个衡量拟合程度的标准。损失函数(loss function)用于量化目标的真实值与预测值之间的差距,通常是一个非负数,数值越小表示模型预测越准确,完美预测时损失为 0。

在回归问题中,最常用的损失函数是平方误差函数。对于样本 i i i,假设预测值为 y ^ ( i ) \hat{y}^{(i)} y^(i),实际标签为 y ( i ) y^{(i)} y(i),则平方误差可以定义为:

l ( i ) ( w , b ) = 1 2 ( y ^ ( i ) − y ( i ) ) 2 (3.1.5) l^{(i)}(w, b) = \frac{1}{2} (\hat{y}^{(i)} - y^{(i)})^2 \tag{3.1.5} l(i)(w,b)=21(y^(i)y(i))2(3.1.5)

这里的 1 2 \frac{1}{2} 21 只是为了在求导时使公式更简洁,不影响结果的本质。
为⼀维情况下的回归问题绘制图像
在这里插入图片描述

为了衡量模型在整个数据集上的表现,我们计算所有 n n n 个样本的损失均值,称为平均损失:

L ( w , b ) = 1 n ∑ i = 1 n l ( i ) ( w , b ) = 1 n ∑ i = 1 n 1 2 ( w ⊤ x ( i ) + b − y ( i ) ) 2 (3.1.6) L(\mathbf{w}, b) = \frac{1}{n} \sum_{i=1}^{n} l^{(i)}(\mathbf{w}, b) = \frac{1}{n} \sum_{i=1}^{n} \frac{1}{2} ( \mathbf{w}^\top \mathbf{x}^{(i)} + b - y^{(i)})^2 \tag{3.1.6} L(w,b)=n1i=1nl(i)(w,b)=n1i=1n21(wx(i)+by(i))2(3.1.6)

在训练模型时,我们的目标是寻找一组参数 ( w ∗ , b ∗ ) (\mathbf{w}^*, b^*) (w,b),使得所有训练样本上的总损失达到最小,即:

w ∗ , b ∗ = arg ⁡ min ⁡ w , b L ( w , b ) (3.1.7) \mathbf{w}^*, b^* = \arg \min_{\mathbf{w}, b} L(\mathbf{w}, b) \tag{3.1.7} w,b=argw,bminL(w,b)(3.1.7)

解析解

线性回归是一个非常简单的优化问题,与大部分复杂模型不同,它的解可以用一个明确的公式表达出来,这类解被称为解析解(analytical solution)。

首先,我们将偏置 b b b 合并到参数 w \mathbf{w} w 中,方法是将所有参数包含在一个矩阵中并附加一列常数 1。这样,预测问题就变成了最小化 ∥ y − X w ∥ 2 \|\mathbf{y} - \mathbf{X}\mathbf{w}\|^2 yXw2。由于损失函数在整个平面上只有一个临界点,这个临界点即对应于整个区域的最小损失。

通过对 w \mathbf{w} w 求导并设导数为 0,可以得到线性回归的解析解:

w ∗ = ( X ⊤ X ) − 1 X ⊤ y (3.1.8) \mathbf{w}^* = (\mathbf{X}^\top \mathbf{X})^{-1} \mathbf{X}^\top \mathbf{y} \tag{3.1.8} w=(XX)1Xy(3.1.8)

像线性回归这样的简单问题存在解析解,但并不是所有的问题都能找到解析解。尽管解析解在数学上易于分析,但由于它对问题的要求很严格,因此无法广泛应用于深度学习。

随机梯度下降

即使在无法得到解析解的情况下,我们仍然可以有效地训练模型。实际上,在许多任务中,那些难以优化的模型往往表现得更好,因此,学会如何训练这些难以优化的模型非常重要。

梯度下降(gradient descent)它几乎可以优化所有深度学习模型。梯度下降通过不断沿着损失函数递减的方向更新参数来减少误差。

小批量随机梯度下降

梯度下降最直接的方法是计算整个数据集上损失函数关于模型参数的导数(即梯度),但这种方式在实践中可能非常慢,因为每次更新参数之前都需要遍历整个数据集。为此,我们通常采用一种称为小批量随机梯度下降(minibatch stochastic gradient descent)的变体。

在每次迭代中,我们首先随机抽取一个小批量 B B B,它由固定数量的训练样本组成。然后,计算小批量样本的平均损失关于模型参数的导数。最后,我们将梯度乘以一个预先设定的正数 η \eta η,并从当前参数值中减去这个结果。用数学公式表示该更新过程:

( w , b ) ← ( w , b ) − η 1 ∣ B ∣ ∑ i ∈ B ∂ ∂ ( w , b ) l ( i ) ( w , b ) (3.1.9) (\mathbf{w}, b) \leftarrow (\mathbf{w}, b) - \eta \frac{1}{|B|} \sum_{i \in B} \frac{\partial}{\partial (\mathbf{w}, b)} l^{(i)}(\mathbf{w}, b) \tag{3.1.9} (w,b)(w,b)ηB1iB(w,b)l(i)(w,b)(3.1.9)

梯度下降算法的详细步骤

具体来说,梯度下降算法的步骤如下:

  1. 初始化参数:随机初始化模型参数 w w w b b b
  2. 从数据集中随机抽取小批量样本:从训练数据中随机抽取小批量样本 B B B
  3. 计算梯度并更新参数:计算损失函数在小批量上的梯度,并沿着负梯度方向更新参数。
  4. 重复步骤 2 和 3,直到达到预定的迭代次数或满足其他停止条件。

针对平方损失和仿射变换,我们可以明确写出参数更新的公式:

w ← w − η 1 ∣ B ∣ ∑ i ∈ B x ( i ) ( w ⊤ x ( i ) + b − y ( i ) ) \mathbf{w} \leftarrow \mathbf{w} - \eta \frac{1}{|B|} \sum_{i \in B} \mathbf{x}^{(i)} \left( \mathbf{w}^\top \mathbf{x}^{(i)} + b - y^{(i)} \right) wwηB1iBx(i)(wx(i)+by(i))

b ← b − η 1 ∣ B ∣ ∑ i ∈ B ( w ⊤ x ( i ) + b − y ( i ) ) (3.1.10) b \leftarrow b - \eta \frac{1}{|B|} \sum_{i \in B} \left( \mathbf{w}^\top \mathbf{x}^{(i)} + b - y^{(i)} \right) \tag{3.1.10} bbηB1iB(wx(i)+by(i))(3.1.10)

解释公式
  • 在上面的公式中, w \mathbf{w} w 是权重向量, x ( i ) \mathbf{x}^{(i)} x(i) 是样本 i i i 的特征向量。
  • η \eta η 是学习率,控制每次参数更新的步长。如果 η \eta η 太大,参数可能会来回振荡;如果 η \eta η 太小,收敛速度会非常慢。
  • ∣ B ∣ |B| B 表示小批量的大小,这决定了每次计算梯度时使用的样本数量。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/54146.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

计算机前沿技术-人工智能算法-大语言模型-最新研究进展-2024-10-01

计算机前沿技术-人工智能算法-大语言模型-最新研究进展-2024-10-01 目录 文章目录 计算机前沿技术-人工智能算法-大语言模型-最新研究进展-2024-10-01目录1. Beyond Text-to-Text: An Overview of Multimodal and Generative Artificial Intelligence for Education Using Topi…

QT-MySQL QSqlDatabase: QMYSQL driver not loaded

文章目录 问题解决操作:自己尝试编译,各种错误层出不穷: 解决问题检查总结: 问题 使用Qt连接mysql数据库,遇到了一个问题,就是QT5.14.1版本在连接MySQL数据库时候,提示驱动加载失败&#xff0c…

麒麟操作系统部分目录介绍

图形系统目录 文字系统目录 (1)/bin:存放普通用户可以使用的命令文件。 (2)/boot:包含内核和其它系统程序启动时使用的文件。 (3)/dev:设备文件所在目录。在操作系统中…

数据结构 ——— 单链表oj题:返回链表的中间节点

目录 题目要求 手搓简易单链表 代码实现 题目要求 给你单链表的头节点 head ,请你找出并返回链表的中间节点 如果有两个中间节点,则返回第二个中间节点 要求算法的时间复杂度为:O(N) 手搓简易单链表 代码演示: // 单链表中…

Java Web应用升级故障案例解析

在一次Java Web应用程序的优化升级过程中,从Tomcat 7.0.109版本升级至8.5.93版本后,尽管在预发布环境中验证无误,但在灰度环境中却发现了一个令人困惑的问题:新日志记录神秘“失踪”。本文深入探讨了这一问题的排查与解决过程&…

【湖南步联科技身份证】 身份证读取与酒店收银系统源码整合———未来之窗行业应用跨平台架构

一、html5 <!DOCTYPE html> <html><head><meta http-equiv"Content-Type" content"text/html; charsetutf-8" /><script type"text/javascript" src"http://51.onelink.ynwlzc.net/o2o/tpl/Merchant/static/js…

onload_tcpdump命令抓包报错Onload stack [7,] already has tcpdump process

最近碰到Onload 不支持同时运行多个 tcpdump 进程的报错&#xff0c;实际上使用了ps查询当时系统中并没有tcpdump相关进程存在。需要重启服务器本机使用onload加速的相关进程后才能使用onload_tcpdump正常抓包&#xff0c;很奇怪&#xff0c;之前确实没遇到这样的问题&#xff…

Golang | Leetcode Golang题解之第450题删除二叉搜索树的节点

题目&#xff1a; 题解&#xff1a; func deleteNode(root *TreeNode, key int) *TreeNode {var cur, curParent *TreeNode root, nilfor cur ! nil && cur.Val ! key {curParent curif cur.Val > key {cur cur.Left} else {cur cur.Right}}if cur nil {retur…

金镐开源组织成立,增加最新KIT技术,望能为开源添一把火

国内做开源的很多&#xff0c;知名的若依、芋道源码、Pig、Guns等&#xff0c;可谓是百花齐放&#xff0c;虽然比不上Apache&#xff0c;但也大大提高了国内的生产力。经过多年的发展&#xff0c;一些开源项目逐渐也都开始商业化。基于这样的背景&#xff0c;我拉拢了三个技术人…

【重学 MySQL】三十九、Having 的使用

【重学 MySQL】三十九、Having 的使用 基本语法示例示例 1&#xff1a;使用 HAVING 过滤分组示例 2&#xff1a;HAVING 与 WHERE 的结合使用 注意点WHERE 与 HAVING 的对比基本定义与用途主要区别示例对比总结 在 MySQL 中&#xff0c;HAVING 子句主要用于对 GROUP BY 语句产生…

使用powershell的脚本报错:因为在此系统中禁止执行脚本

1.添加powershell功能环境&#xff1a; 2.启动powershell的执行策略 因为在此系统中禁止执行脚本。 set-executionpolicy unrestricted

【计算机视觉】ch1-Introduction

相机模型与成像 1. 世界坐标系 (World Coordinate System) 世界坐标系是指物体在真实世界中的位置和方向的表示方式。在计算机视觉和图像处理领域&#xff0c;世界坐标系通常是一个全局坐标系统&#xff0c;描述了摄像机拍摄到的物体在实际三维空间中的位置。它是所有其他坐标…

刷题day11 栈与队列下【逆波兰表达式求值】【滑动窗口最大值】【前 K 个高频元素】

⚡刷题计划day11 栈与队列继续&#xff0c;可以点个免费的赞哦~ 往期可看专栏&#xff0c;关注不迷路&#xff0c; 您的支持是我的最大动力&#x1f339;~ 目录 ⚡刷题计划day11 栈与队列继续&#xff0c;可以点个免费的赞哦~ 往期可看专栏&#xff0c;关注不迷路&#xf…

无心剑七绝《华夏中兴》

七绝华夏中兴 长空万里尽春声 治世群英喜纵横 一代雄才华夏梦 中兴日月照前程 2024年10月1日 平水韵八庚平韵 无心剑的七绝《华夏中兴》通过对自然景观和国家景象的描绘&#xff0c;展现了一种恢弘的气势和对未来的美好愿景。 意境开阔&#xff1a;首句“长空万里尽春声”以广阔…

MATLAB数字水印系统

课题介绍 本课题为基于MATLAB的小波变换dwt和离散余弦dct的多方法对比数字水印系统。带GUI交互界面。有一个主界面GUI&#xff0c;可以调用dwt方法的子界面和dct方法的子界面。流程包括&#xff0c;读取宿主图像和水印图像&#xff0c;嵌入&#xff0c;多种方法的攻击&#xf…

sysbench 命令:跨平台的基准测试工具

一、命令简介 sysbench 是一个跨平台的基准测试工具&#xff0c;用于评估系统性能&#xff0c;包括 CPU、内存、文件 I/O、数据库等性能。 ‍ 比较同类测试工具 bench.sh 在上文 bench.sh&#xff1a;Linux 服务器基准测试中介绍了 bench.sh 一键测试脚本&#xff0c;它对…

GAMES101(17~18节,物理材质模型)

材质 BRDF 材质&#xff1a;决定了光线与物体不同的作用方式 BRDF定义了物体材质,包含漫反射和镜面部分 BSDF &#xff08;scattering散射&#xff09; BRDF&#xff08;reflect反射&#xff09; BTDF 光线打击到物体上会向四面八方散射 反射 光线打击到物体上反射出去…

IIS开启后https访问出错net::ERR_CERT_INVALID

安装ArcGIS server和portal等&#xff0c;按照说明上&#xff0c;先开启iis&#xff0c;在安装server、datastore、portal、webadapter等&#xff0c;遇到一些问题&#xff1a; 问题1 访问http正常&#xff0c;访问https出错&#xff1a; 解决方案 从这里找到解决方案&…

【Android 源码分析】Activity生命周期之onPause

忽然有一天&#xff0c;我想要做一件事&#xff1a;去代码中去验证那些曾经被“灌输”的理论。                                                                                  – 服装…

通信工程学习:什么是CSMA/CA载波监听多路访问/冲突避免

CSMA/CA&#xff1a;载波监听多路访问/冲突避免 CSMA/CA&#xff08;Carrier Sense Multiple Access/Collision Avoidance&#xff09;&#xff0c;即载波监听多路访问/冲突避免&#xff0c;是一种用于数据传输时避免各站点之间冲突的算法&#xff0c;尤其适用于无线局域网&…