动手学深度学习(二)线性神经网络

推荐课程:跟李沐学AI的个人空间-跟李沐学AI个人主页-哔哩哔哩视频

回归任务是指对连续变量进行预测的任务。

一、线性回归

线性回归模型是一种常用的统计学习方法,用于分析自变量与因变量之间的关系。它通过建立一个关于自变量和因变量的线性方程,来对未知数据进行预测。

1.1 线性模型

举个例子,房价预测模型

  • 假设1︰影响房价的关键因素是卧室个数,卫生间个数和居住面积,记为x1,x2,x3。
  • 假设2:成交价是关键因素的加权和,y = w_1x_1 + w_2x_2 + w_3x_3 + b

权重w和偏差b的实际值在后面决定。

  • 给定n维输入,x=[x_1,x_2, ....x_n]^T,向量x对应于单个数据样本的特征
  • 线性模型有一个n维权重和一个标量偏差,w =[w_1, w_2, ..., w_n]^Tb权重w决定了每个特征对预测值的影响。偏置b是指当所有的特征都取0时,预测值应为多少。
  • 输出是输入的加权和,\hat{y} = w_1x_1+w_2x_2+ ...+ w_nx_n + b。我们常用\hat{y}表示预测值

则,该房价预测模型为:\hat{y} = w^Tx+ b,这是一个线性预测模型。给定一个数据集(如x),我们的目标就是寻找模型的权重w和偏置b,使得根据模型做出的预测大体符合数据中真实价格y。也是就说最佳的权重w和偏置b有能力使得预测值\hat{y}逼近真实值y,找到最佳的权重w和偏置b这是我们的最终目的。

1.2 损失函数(衡量预估质量)

用于比较真实值和预估值的差异,即以特定规则计算真实值和预估值的差值,例如房屋售价和估价。

假设y是真实值,\hat{y}是预测值,平方差损失\ell(y,\hat{y})=(y-\hat{y})^2,我们以该函数作为损失函数。

设训练集有n个样本,则这n个样本的损失均值

             L(w, b)=\frac{1}{n}\sum_{n}^{i=1}\ell^i(y,\hat{y})=\frac{1}{n}\sum_{n}^{i=1}(y_i-\hat{y_i})^2=\frac{1}{n}\sum_{n}^{i=1}(y_i-w^Tx_i+b)^2

Q:那么损失函数,对我们找到最优的权重w和偏置b有什么帮助呢?

我们可以看到,最佳的预测值与真实值之间的损失值一定是尽可能小的,因此我们只要求得最小的损失值,那么得到这个损失值的权重w和偏置b一定是最优的。

Q:怎么求得最小的损失值呢?

如,平方差损失函数是一个凹函数,那么求解最小的损失值,我们只需要将该函数关于w的偏导数设为0,求导即可。求解得到的w就是最优的权重w。预测出的预估值\hat{y}也就最接近真实值。这类解称为解析解。

二、基础优化算法(梯度下降算法)

在绝大多数的情况下,损失函数是很复杂的(比如逻辑回归),根本无法得到参数估计值的表达式,也就无从获取没有显示解(解析解)

此需要一种对大多数函数都适用的方法,这就引出了“梯度下降算法”,这种方法几乎可以优化所有深度学习模型它通过不断地在损失函数递减的方向上更新参数来降低误差(原理)

2.1 梯度下降公式

首先,我们需要确定初始化模型的参数w_0,接下来重复迭代更新参数t=1、2、3、....、n,更新权重的公式为:

其中,\textup{w}_{t-1}为上一次更新权重的结果,\eta为学习率(这是一个超参数,决定了每次参数更新的步长),\frac{\partial \ell }{\partial \textup{w}_{t-1}}损失函数递增的方向(注意公式中为负)。

2.2 选择学习率

梯度下降的过程宛如一个人在走下山路,一步一步地接近谷底,学习率相当于这个人的步长

学习率的选取不易过大,也不宜过小。学习率选取过大会使得权重更新的过程一直在震荡,而不是真正的在下降。学习率选取过小,会使得权重更新的过程十分缓慢,影响效率。

2.3 小批量随机梯度下降

一个神经网络模型的训练可能需要几分钟至数个小时,我们可以采用小批量随机梯度下降的方式来加快这一过程。

在整个训练集上计算梯度太昂贵了,因此可以随机采用 b 个样本i_1,i_2,...,i_b来求取整个训练集的近似损失(原理)。求近似损失公式为:

 其中,b批量大小,另一个重要的超参数。

Q:如何选择批量大小?

选择批量大小不能太小,也不能太大。批量大小选择过小,则每次计算量太小,不适合并行来最大利用计算资源。批量大小选择过大,内存消耗增加浪费计算,例如如果所有样本都是相同的。

三、线性回归的从零开始实现(代码实现)

3.1 生成数据集

首先,我们根据带有噪声的线性模型构造一个人造数据集,我们的目的是通过这个数据集来还原线性模型中正确的参数。

我们使用线性模型参数 \textup{w}=[2,-3.4]^Tb = 4.2​ 和噪声项 \varepsilon 生成数据集及其标签。

# 生成数据集
def synthetic_data(w, b, num_examples):"""生成 y=Xw + b + 噪声"""X = torch.normal(0, 1, (num_examples, len(w))) # 正态分布(均值为0,标准差为1)y = torch.matmul(X, w) + b # 矩阵相乘y += torch.normal(0, 0.01, y.shape) # 加入噪声项# 得到的y为行向量的形式,为了使其变为一列的形式需要进行reshapereturn X, y.reshape((-1, 1))

3.2 传输数据集

def data_iter(batch_size, features, labels):num_examples = len(features)indices = list(range(num_examples))# 这些样本是随机读出的,没有特定的顺序random.shuffle(indices)for i in range(0, num_examples, batch_size):batch_indices = torch.tensor(indices[i:min(i+batch_size,num_examples)])yield features[batch_indices],labels[batch_indices]

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/22595.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

分布式协议与算法——拜占庭将军问题

拜占庭将军问题 背景:以战国时期为背景 战国时期,齐、楚、燕、韩、赵、魏、秦七雄并立,后来秦国的势力不断强大起来,成了东方六国的共同威胁。于是,这六个国家决定联合,全力抗秦,免得被秦国各个…

JVM面试突击1

JVM面试突击 JDK,JRE以及JVM的关系 我们的编译器到底干了什么事? 仅仅是将我们的 .java 文件转换成了 .class 文件,实际上就是文件格式的转换,对等信息转换。 类加载机制是什么? 所谓类加载机制就是 虚拟机把Class文…

C语言阶段性测试题

大家好,我是深鱼~ 【前言】:本部分是C语言初阶学完阶段性测试题,最后一道编程题有一定的难度,需要多去揣摩,代码敲多了,自然就感觉不难了,加油,铁汁们!!&…

YOLO中Anchor生成介绍

Anchor生成机制 YOLOv1YOLOv2YOLOv4模型输出decode1.维度变换2.读取位置信息3.坐标变换4.构建网格5. 计算实际偏移量6.得到输出 YOLOv1 利用全连接层直接对边界框进行预测 YOLOv2 YOLOv2通过缩减网络,使用416x416的输入,模型下采样的总步长为32&#…

flutter开发实战-实现自定义按钮类似UIButton效果

flutter开发实战-实现自定义按钮类似UIButton效果 最近开发过程中需要实现一下UIButton效果的flutter按钮,这里使用的是监听手势点击事件。 一、GestureDetector GestureDetector属性定义 GestureDetector({super.key,this.child,this.onTapDown,this.onTapUp,t…

附件展示 点击下载

效果图 实现代码 <el-table-column prop"attachment" label"合同附件" width"250" show-overflow-tooltip><template slot-scope"scope"><div v-if"scope.row.cceedcAppendixInfoList &&scope.row.ccee…

路由的hash和history模式的区别

目录 ✅ 路由模式概述 一. 路由的hash和history模式的区别 1. hash模式 2. history模式 3. 两种模式对比 二. 如何获取页面的hash变化 ✅ 路由模式概述 单页应用是在移动互联时代诞生的&#xff0c;它的目标是不刷新整体页面&#xff0c;通过地址栏中的变化来决定内容区…

SQL 表别名 和 列别名

列表名 列表名之后 order by 可以用别名 也可以用原名&#xff0c; where 中不能用别名的 SQL语句执行顺序&#xff1a; from–>where–>group by -->having — >select --> order 第一步&#xff1a;from语句&#xff0c;选择要操作的表。 第二步&#xff1…

SpringBoot图片上传并对大小进行压缩(缩放比例)

前言 最近有个新需求&#xff0c;项目中对客户上传jpg图片的时候&#xff0c;每次都是校验大小必须≤30KB&#xff0c;但是客户实际使用的时候&#xff0c;总是会自己去进行压缩&#xff0c;压缩到30KB以内之后再上传&#xff0c;使用时间长了之后&#xff0c;客户总会觉得很麻…

react学习笔记——1. hello react

包含的包一共有4个&#xff0c;分别的作用如下&#xff1a; babel.min.js&#xff1a;可以进行ES6到ES5的语法转换&#xff1b;可以用于import&#xff1b;可以用于将jsx转换为js。注意&#xff0c;在开发的时候&#xff0c;这个转换&#xff08;jsx转换js&#xff09;不在线上…

Tcp的粘包和半包问题及解决方案

目录 粘包&#xff1a; 半包&#xff1a; 应用进程如何解读字节流&#xff1f;如何解决粘包和半包问题&#xff1f; ①&#xff1a;固定长度 ②&#xff1a;分隔符 ③&#xff1a;固定长度字段存储内容的长度信息 粘包&#xff1a; 一次接收到多个消息&#xff0c;粘包 应…

HBase概述

HBase 一 HBase简介与环境部署 1.1 HBase简介&在Hadoop生态中的地位 1.1.1 什么是HBase HBase是一个分布式的、面向列的开源数据库HBase是Google BigTable的开源实现HBase不同于一般的关系数据库, 适合非结构化数据存储 1.1.2 BigTable BigTable是Google设计的分布式…

mysql的update_time

CREATE TABLE users (id INT AUTO_INCREMENT PRIMARY KEY,name VARCHAR(50) NOT NULL,age INT,update_time TIMESTAMP NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 更新时间 );具体解释如下&#xff1a; DEFAULT CURRENT_TIMESTAMP: 这部分表示当插入…

【CI/CD】图解六种分支管理模型

图解六种分支管理模型 任何一家公司乃至于一个小组织&#xff0c;只要有写代码的地方&#xff0c;就有代码版本管理的主场&#xff0c;初入职场&#xff0c;总会遇到第一个拦路虎 git 管理流程&#xff0c;但是每一个企业似乎都有自己的 git 管理流程&#xff0c;倘若我们能掌握…

如何在不使用脚本和插件的情况下手动删除 3Ds Max 中的病毒?

如何加快3D项目的渲染速度&#xff1f; 3D项目渲染慢、渲染卡顿、渲染崩溃&#xff0c;本地硬件配置不够&#xff0c;想要加速渲染&#xff0c;在不增加额外的硬件成本投入的情况下&#xff0c;最好的解决方式是使用渲云云渲染&#xff0c;在云端批量渲染&#xff0c;批量出结…

ABAP 自定义搜索功能 demo1

ABAP 自定义搜索功能 demo1 效果&#xff1a; 双击选中行则为选中对应发票 实现 1定义 定义屏幕筛选参数 SELECTION-SCREEN BEGIN OF SCREEN 9020. SELECT-OPTIONS:s1_belnr FOR rbkp-belnr, s1_gjahr FOR rbkp-gjahr, s1_lifnr FOR rbkp-lifnr, s1_erfna FOR rbkp-erfnam, …

go入门实践二-tcp服务端

文章目录 前言接口与方法并发-协程项目管理bufio包使用其他代码 前言 上一篇&#xff0c;我们通过go语言的hello-world入门&#xff0c;搭建了go的编程环境&#xff0c;并对go语法有了简单的了解。本文实现一个go的tcp服务端。借用这个示例&#xff0c;展示接口、协程、bufio的…

php运算符的短路特性

php运算符的短路特性 1、逻辑运算符&#xff1a;逻辑与&#xff08;&&)和逻辑或&#xff08;||&#xff09;&#xff0c;存在着短路特性 PHP中有以下两个运算符具有短路的特性&#xff0c;他们是逻辑运算符的逻辑与&#xff08;&&)和逻辑或&#xff08;||&am…

线程概念linux

何为线程&#xff1a; 线程是程序中负责执行的单位&#xff0c;它可以被看作是进程的一部分&#xff0c;是进程的子任务。线程与进程的区别在于&#xff0c;进程是一个资源单位&#xff0c;而线程是进程的一部分&#xff0c;它只有栈这个独立的资源&#xff0c;其他资源如代码…

Java SpringBoot集成Activiti7工作流

Activiti7 Java SpringBoot集成Activiti7工作流介绍项目集成引入依赖YML配置文件配置类 启动项目生成表结构Activiti的数据库支持 Activiti数据表介绍项目Demo地址&#xff1a; Java SpringBoot集成Activiti7工作流 本文项目Demo地址附在文章后方 官网主页&#xff1a;http://a…