机器学习深度学习——线性回归的从零开始实现

虽然现在的深度学习框架几乎可以自动化实现下面的工作,但从零开始实现可以更了解工作原理,方便我们自定义模型、自定义层或自定义损失函数。

import random
import torch
from d2l import torch as d2l

线性回归的从零开始实现

  • 生成数据集
  • 读取数据集
  • 初始化模型参数
  • 定义模型
  • 定义损失函数
  • 定义优化算法
  • 训练

生成数据集

根据带有噪声的线性模型构造一个人造数据集。任务是使用这个数据集来恢复模型的参数。我们使用低维数据,可以更容易地进行可视化。
在下面代码中,我们生成一个包含1000个样本的数据集,每个样本包含从标准正态分布中采样的2个特征。我们的数据集是一个1000×2的矩阵X。
使用线性模型参数 w = [ 2 , − 3.4 ] T 、 b = 4.2 和噪声项 δ 生成数据集及标签: y = X w + b + δ 使用线性模型参数w=[2,-3.4]^T、b=4.2和噪声项\delta生成数据集及标签:\\ y=Xw+b+\delta 使用线性模型参数w=[2,3.4]Tb=4.2和噪声项δ生成数据集及标签:y=Xw+b+δ
其中,δ可以视为模型预测和标签时的潜在观测误差。在这里我们认为标准假设成立,即δ服从均值为0的正态分布。为简化问题,将标准差设为0.01。下面的代码生成合成数据集:

def synthetic_data(w, b, num_examples):  #@save"""生成y=Xw+b+δ"""# 生成均值为0,标准差为1(标准正态分布)且大小1000*2的数据集X = torch.normal(0, 1, (num_examples, len(w)))# 生成y函数,生成1000*1的矩阵y = torch.matmul(X, w) + b# 再加上服从均值为0的正态分布的δy += torch.normal(0, 0.01, y.shape)return X, y.reshape((-1, 1))true_w = torch.tensor([2, -3.4])
true_b = 4.2
features, labels = synthetic_data(true_w, true_b, 1000)

其中,features中每一行都包含一个二维数据样本,labels中每一行都包含一个一维标签值(一个标量)。

print('features:', features[0], '\nlabel:', labels[0])

结果:

features: tensor([-0.5829, -0.2094])
label: tensor([3.7491])

通过生成第二个特征features[:, 1]和labels的散点图,可以直观看出两者之间的线性关系:

d2l.set_figsize()
d2l.plt.scatter(features[:, 1].detach().numpy(), labels.detach().numpy(), 1)
d2l.plt.show()

在这里插入图片描述

读取数据集

训练模型时,要对数据集进行遍历,每次抽取一小批量样本,并使用它们来更新模型。因此,需要定义一个函数,该函数能打乱数据集中的样本并以小批量方式获取数据。
在下面代码中,定义一个data_iter函数,接收批量大小、特征矩阵和标签向量作为输入,生成大小为batch_size的小批量。每个小批量包含一组特征和标签。

def data_iter(batch_size, features, labels):num_examples = len(features)indices = list(range(num_examples))  # 0到999的顺序random.shuffle(indices)  # 这些样本是随机读取的,没有特定顺序for i in range(0, num_examples, batch_size):batch_indices = torch.tensor(indices[i: min(i+batch_size, num_examples)]  # 随机取样)yield features[batch_indices], labels[batch_indices]# yield返回一个可以用来迭代for循环的生成器,而不是直接return

通常我们会利用CPU并行运算的优势,处理合理大小的“小批量”。每个样本都可以并行进行模型计算,且每个样本损失函数的梯度也可以被并行计算。
可以直观感受一下小批量运算:读取第一个小批量数据样本并打印:

batch_size = 10
for X, y in data_iter(batch_size, features, labels):print(X, '\n', y)break

结果:

tensor([[-1.0186, 1.8338],
[ 0.6455, 1.1226],
[-0.5020, 0.2105],
[ 1.3583, 0.6979],
[ 0.3024, -0.8929],
[ 0.4045, -0.4207],
[ 0.5201, -0.3263],
[ 0.6037, -0.1332],
[ 1.6171, 0.2449],
[-0.6540, 1.0338]])
tensor([[-4.0795],
[ 1.6835],
[ 2.5014],
[ 4.5346],
[ 7.8678],
[ 6.4298],
[ 6.3537],
[ 5.8528],
[ 6.6194],
[-0.6216]])

当我们进行迭代时,我们会连续地获得不同的小批量,直到遍历完整个数据集。但上面实现的迭代执行效率很低,可能会出问题。在深度学习框架中实现的内置迭代器效率要高得多,它可以处理存储在文件中的数据和数据流提供的数据

初始化模型参数

通过从均值为0、标准差为0.01的正态分布中采样随机数来初始化权重,并将偏置初始化为0:

w = torch.normal(0, 0.01, size=(2, 1), requires_grad=True)
b = torch.zeros(1, requires_grad=True)

在初始化参数后,我们的任务是更新这些参数,直到这些参数足够拟合我们的数据。每次更新都需要计算损失函数关于模型参数的梯度,有了这个梯度就可以向减小损失的方向来更新每个参数

定义模型

定义模型,就要将模型的输入和参数同模型的输出关联起来。
而要计算线性模型的输出,只需要计算输入特征X与模型权重w的矩阵-向量乘法后再加上偏置b。(Xw是一个向量,而b是标量)当我们用一个向量加上一个标量时,标量会加到向量的每个分量上(广播机制):

def linreg(X, w, b):  #@save"""线性回归模型"""return torch.matmul(X, w) + b

定义损失函数

要计算损失函数的梯度,自然要先定义损失函数,下面定义了平方损失函数:

def squared_loss(y_hat, y):  #@save"""均方损失"""return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2

定义优化算法

线性回归是有解析解的,但是其他模型基本没有,因此还是用随机梯度下降法来进行优化。
在每一步中,使用数据集中随机抽取的一个小批量,然后根据参数计算损失的梯度。接下来朝着减小损失的方向来更新参数。
下面就是随机梯度下降更新的函数,该函数接受模型参数集合、学习速率和批量大小作为输入。每一步更新的大小由学习率lr决定。因为我们计算的损失是一个批量样本的综合,因此用批量大小batch_size来规范步长,这样步长大小就不会取决于我们对批量大小的选择:

def sgd(params, lr, batch_size):  #@save"""小批量随机梯度下降"""with torch.no_grad():for param in params:param -= lr * param.grad / batch_sizeparam.grad.zero_()

其中,torch.no_grad()是上下文管理器,用来指定在其内部的代码块中不进行梯度计算。当不需要计算梯度时,使用该上下文管理器可以提高代码执行效率。

训练

在每次迭代中,我们读取一小批训练样本,并通过我们的模型来获得一组预测。计算完损失后,我们开始反向传播,存储每个参数的梯度。最后调用优化算法sgd来更新模型参数。
概括一下,就是执行下面的循环:
1、初始化参数
2、重复一下训练,直到完成:
计算梯度 g ← ∂ ( w , b ) 1 ∣ B ∣ ∑ i ∈ B l ( x ( i ) , y ( i ) , w , b ) 更新参数 ( w , b ) ← ( w , b ) − η g 计算梯度g←\partial_{(w,b)}\frac{1}{|B|}\sum_{i∈B}l(x^{(i)},y^{(i)},w,b)\\ 更新参数(w,b)←(w,b)-ηg 计算梯度g(w,b)B1iBl(x(i),y(i),w,b)更新参数(w,b)(w,b)ηg
在每个迭代周期中,我们使用deta_iter函数遍历整个数据集,并将训练数据集中所有样本都使用一次(假设样本数能够被批量大小整除)。这里的迭代周期个数num_epoches和学习率lr都是超参数,分别设为3和0.03。(超参数设置很麻烦,现在忽略细节)

batch_size = 10
lr = 0.03
num_epochs = 3
net = linreg
loss = squared_lossfor epoch in range(num_epochs):for X, y in data_iter(batch_size, features, labels):l = loss(net(X, w, b), y)  # X和y的小批量损失# 因为l形状是(batch_size,1),不是标量# l中所有元素加起来再计算关于[w,b]的梯度l.sum().backward()sgd([w, b], lr, batch_size)  # 使用参数的梯度更新参数with torch.no_grad():train_l = loss(net(features, w, b), labels)print(f'epoch {epoch + 1}, loss {float(train_l.mean()):f}')

结果:

epoch 1, loss 0.040672
epoch 2, loss 0.000146
epoch 3, loss 0.000047

事实上,真实参数和通过训练得到的参数很接近:

print(f'w的估计误差: {true_w - w.reshape(true_w.shape)}')
print(f'b的估计误差: {true_b - b}')

结果:

w的估计误差: tensor([ 0.0006, -0.0002], grad_fn=)
b的估计误差: tensor([0.0004], grad_fn=)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/11971.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

windows默认编码格式修改

1.命令提示符界面输入 chcp 936 对应 GBK 65001 对应 UTF-8 2.临时更改编码格式 chcp 936(或65001) 3.永久更改编码格式 依次开控制面板->时钟和区域->区域->管理->更改系统区域设置,然后按下图所示,勾选使用UTF-8语言支持。然后重启电脑。此…

防止连点..

1.连点js文件 let timer; letflag /*** 节流原理:在一定时间内,只能触发一次** param {Function} func 要执行的回调函数* param {Number} wait 延时的时间* param {Boolean} immediate 是否立即执行* return null*/ function throttle(func, wait 500…

【数字IC基础】竞争与冒险

竞争-冒险 1. 基本概念2. 冒险的分类3. 静态冒险产生的判断4. 毛刺的消除使用同步电路使用格雷码增加滤波电容增加冗余项,消除逻辑冒险引入选通脉冲 1. 基本概念 示例一: 如上图所示的这个电路,使用了两个逻辑门,一个非门和一个与…

Windows 找不到文件‘chrome‘。请确定文件名是否正确后,再试一次

爱像时间,永恒不变而又短暂;爱像流水,浩瀚壮阔却又普普通通。 Windows 找不到文件chrome。请确定文件名是否正确后,再试一次 如果 Windows 提示找不到文件 "chrome",可能是由于以下几种原因导致的&#xff1…

机器学习深度学习——模型选择、欠拟合和过拟合

👨‍🎓作者简介:一位即将上大四,正专攻机器学习的保研er 🌌上期文章:机器学习&&深度学习——多层感知机的简洁实现 📚订阅专栏:机器学习&&深度学习 希望文章对你们有…

【GitOps系列】使用 ArgoCD 快速打造GitOps工作流

文章目录 ArgoCD简介ArgoCD安装访问ArgoCDGitOps 工作流总览创建 ArgoCD 应用检查 ArgoCD 同步状态访问应用 连接 GitOps 工作流体验 GitOps 工作流生产建议1)修改默认密码2)配置 Ingress 和 TLS3)使用 Webhook 触发 ArgoCD4)将源…

DoIP学习笔记系列:(二)VN5620 DoIP测试配置实践笔记

文章目录 1. 添加.cdd2. CAPL中调用接口发送DoIP请求3. “Ethernet Packet Builder”的妙用4. CANoe也可以做交互界面在进行测试前,先检查车载以太网硬件连线是否正确,需要注意连接两端的Master、Slave,100M、1000M等基本情况,在配置VN5620的时候就可以灵活处理了。成功安装…

数学建模-MATLAB三维作图

导出图片用无压缩tif会更清晰 帮助文档:doc 函数名 matlab代码导出为PDF 新建实时脚本或右键文件转换为实时脚本实时编辑器-全部运行-内嵌显示保存为PDF

【TypeScript】接口类型 Interfaces 的使用理解

导语: 什么是 类型接口? 在面向对象语言中,接口(Interfaces)是一个很重要的概念,它是对行为的抽象,而具体如何行动需要由类(classes)去实现(implement&#x…

JVM-类加载

1.了解冯诺依曼计算机结构 1.1计算机处理数据过程 (1)提取阶段:由输入设备把原始数据或信息输入给计算机存储器存起来 (2)解码阶段:根据CPU的指令集架构(ISA)定义将数值解译为指令 (3)执行阶段:再由控制器把需要处理或计算的数据调入运算器 (4)最终阶段:由输出设备把最后运…

区间预测 | MATLAB实现基于QRF随机森林分位数回归时间序列区间预测模型

区间预测 | MATLAB实现基于QRF随机森林分位数回归时间序列区间预测模型 目录 区间预测 | MATLAB实现基于QRF随机森林分位数回归时间序列区间预测模型效果一览基本介绍程序设计参考资料 效果一览 基本介绍 1.Matlab实现基于QRF随机森林分位数回归时间序列区间预测模型&#xff1…

Dooring-Saas低代码技术详解

hello, 大家好, 我是徐小夕, 今天和大家分享一下基于 H5-Dooring零代码 开发的全新零代码搭建平台 Dooring-Saas 的技术架构和设计实现思路. 背景介绍 3年前我上线了第一版自研零代码引擎 H5-Dooring, 至今已迭代了 300 多个版本, 主要目的是快速且批量化的生产业务/营销过程中…

empty module导致的lvs问题

write_verilog时-exclude empty_modules即可 这里也分享一下ICC2 write lvs netlist的命令 write_verilog -exclude {scalar_wire_declarations leaf_module_declarations empty_modules well_tap_cells filler_cells supply_statements} -hierarchy all -force_no_referenc…

手风琴案例(jQuery)

案例效果 代码实现 jQuery代码(两种方法) 方法一:hover版 $(function () {$(".king li").hover(function() {$(this).addClass("current").siblings().removeClass("current");}, function() {$(".king…

单机部署NGINX

​ 1、找到合适的nginx资源包,可以去官网下载 这里用的是 nginx-1.24.0.tar.gz 2、上传下载下来的nginx软件包,并解压 tar zxvf nginx-1.24.0.tar.gz cd nginx-1.24.0/ 3、安装nginx 编译 ./configure --prefix/usr/local/nginx --with-http_ssl…

哈希表的简单模拟实现

文章目录 底层结构哈希冲突闭散列定义哈希节点定义哈希表**哈希表什么情况下进行扩容?如何扩容?**Insert()函数Find()函数二次探测HashFunc()仿函数Erase()函数全部的代码 开散列定义哈希节点定义哈希表Insert()函数Find()函数Erase()函数总代码 初识哈希…

UG\NX二次开发 获取2D制图视图中可见的对象,并获取类型

文章作者:里海 来源网站:https://blog.csdn.net/WangPaiFeiXingYuan 简介: 使用UF_VIEW_ask_visible_objects获取2D制图视图中可见的对象,并获取类型。 下面是将一个六面体以不同的视图投影,获取视图对象和类型的效果。 效果: 1个部件事例,1个体,4条边 1个部件事…

springboot创建并配置环境(一) - 创建环境

文章目录 一、介绍二、启动环境Environment的分析三、进入源码四、创建环境1. 如何确定应用类型2. 测试 一、介绍 在springboot的启动流程中,启动环境Environment是可以说是除了应用上下文ApplicationContext之外最重要的一个组件了,而且启动环境为应用…

MySQL主从复制

1.理解MySQL主从复制原理。 1.1MySQL支持的复制类型 MySQL支持以下几种常见的复制类型: 基于语句的复制(Statement-based Replication,SBR):基于语句的复制是MySQL最早支持的复制方式,它通过复制和执行S…

CSS3 实现边框圆角渐变色渐变文字效果

.boder-txt {width: 80px;height: 30px; line-height: 30px;padding: 5px;text-align: center;border-radius: 10px;border: 6rpx solid transparent;background-clip: padding-box, border-box;background-origin: padding-box, border-box;/*第一个linear-gradient表示内填充…