《动手学深度学习》 第二天 (线性回归)

3.2 线性回归的从零开始实现

只利用NDArray和autograd来实现一个线性回归的训练。

首先,导入本节中实验所需的包或模块,其中的matplotlib包可用于作图,且设置成嵌入显示。

%matplotlib inline
from IPython import display
from matplotlib import pyplot as plt
from mxnet import autograd, nd
import random
3.2.1 生成数据集

我们构造一个简单的人工训练数据集,它可以使我们能够直观比较学到的参数和真实的模型参数的区别。
设训练数据集样本数为1000,输入个数(特征数)为2。
给定随机生成的批量样本特征𝑋∈ℝ1000×2,
我们使用线性回归模型真实权重𝑤=[2,−3.4]⊤和偏差𝑏=4.2,以及一个随机噪声项𝜖来生成标签 𝑦=𝑋𝑤+𝑏+𝜖,
其中噪声项𝜖服从均值为0、标准差为0.01的正态分布。噪声代表了数据集中无意义的干扰。下面,让我们生成数据集。

num_inputs = 2
num_examples = 1000
true_w = [2, -3.4]
true_b = 4.2
features = nd.random.normal(scale=1, shape=(num_examples, num_inputs))
labels = true_w[0] * features[:, 0] + true_w[1] * features[:, 1] + true_b
labels += nd.random.normal(scale=0.01, shape=labels.shape)

注意,features的每一行是一个长度为2的向量,而labels的每一行是一个长度为1的向量(标量)。

features[0], labels[0]

输出:

([2.2122064 0.7740038]<NDArray 2 @cpu(0)>, [6.000587]<NDArray 1 @cpu(0)>)

通过生成第二个特征features[:, 1]和标签 labels 的散点图,可以更直观地观察两者间的线性关系。

def use_svg_display():# 用矢量图显示display.set_matplotlib_formats('svg')def set_figsize(figsize=(3.5, 2.5)):use_svg_display()# 设置图的尺寸plt.rcParams['figure.figsize'] = figsize
​
set_figsize()
plt.scatter(features[:, 1].asnumpy(), labels.asnumpy(), 1);  # 加分号只显示图

我们将上面的plt作图函数以及use_svg_display函数set_figsize函数定义在d2lzh包里。以后在作图时,我们将直接调用d2lzh.plt。由于plt在d2lzh包中是一个全局变量,我们在作图前只需要调用d2lzh.set_figsize()即可打印矢量图并设置图的尺寸。

3.2.2 读取数据集

在训练模型的时候,我们需要遍历数据集并不断读取小批量数据样本。这里我们定义一个函数:它每次返回batch_size(批量大小)个随机样本的特征和标签。


def data_iter(batch_size, features, labels):num_examples = len(features)indices = list(range(num_examples))random.shuffle(indices)  # 样本的读取顺序是随机的for i in range(0, num_examples, batch_size):j = nd.array(indices[i: min(i + batch_size, num_examples)])yield features.take(j), labels.take(j)  # take函数根据索引返回对应元素

让我们读取第一个小批量数据样本并打印。每个批量的特征形状为(10, 2),分别对应批量大小和输入个数;标签形状为批量大小。

batch_size = 10for X, y in data_iter(batch_size, features, labels):print(X, y)break

输出:

[[-1.0929538  -0.1200345 ][-1.2860336  -1.6586353 ][ 0.00389364  1.1413413 ][-0.51129895  0.46543437][ 0.8011116  -0.5865901 ][ 0.52092004  0.18693134][ 0.5604595   0.96975976][-0.6614866   0.09907386][-0.4813231   0.5334126 ][-0.21595766  2.066646  ]]<NDArray 10x2 @cpu(0)> [ 2.409014    7.265286    0.31805784  1.6139998   7.7808976   4.6176642.0270698   2.5347762   1.4169512  -3.246182  ]
<NDArray 10 @cpu(0)>
3.2.3 初始化模型参数

我们将权重初始化成均值为0、标准差为0.01的正态随机数,偏差则初始化成0。

w = nd.random.normal(scale=0.01, shape=(num_inputs, 1))
b = nd.zeros(shape=(1,))

之后的模型训练中,需要对这些参数求梯度来迭代参数的值,因此我们需要创建它们的梯度。

w.attach_grad()
b.attach_grad()
3.2.4 定义模型

下面是线性回归的矢量计算表达式的实现。我们使用dot函数做矩阵乘法。

def linreg(X, w, b):  # 本函数已保存在d2lzh包中方便以后使用return nd.dot(X, w) + b
3.2.5 定义损失函数

我们使用上一节描述的平方损失来定义线性回归的损失函数。在实现中,我们需要把真实值y变形成预测值y_hat的形状。以下函数返回的结果也将和y_hat的形状相同。

def squared_loss(y_hat, y):  return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2
3.2.6 定义优化算法

以下的sgd函数实现了小批量随机梯度下降算法。它通过不断迭代模型参数来优化损失函数。这里自动求梯度模块计算得来的梯度是一个批量样本的梯度和。我们将它除以批量大小来得到平均值。

def sgd(params, lr, batch_size):  for param in params:param[:] = param - lr * param.grad / batch_size
3.2.7 训练模型

在训练中,我们将多次迭代模型参数。在每次迭代中,我们根据当前读取的小批量数据样本(特征X和标签y),通过调用反向函数backward计算小批量随机梯度,并调用优化算法sgd迭代模型参数。由于我们之前设批量大小batch_size为10,每个小批量的损失l的形状为(10, 1)。回忆一下“自动求梯度”一节。由于变量l并不是一个标量,运行l.backward()将对l中元素求和得到新的变量,再求该变量有关模型参数的梯度。

在一个迭代周期(epoch)中,我们将完整遍历一遍data_iter函数,并对训练数据集中所有样本都使用一次(假设样本数能够被批量大小整除)。这里的迭代周期个数num_epochs和学习率lr都是超参数,分别设3和0.03。在实践中,大多超参数都需要通过反复试错来不断调节。虽然迭代周期数设得越大模型可能越有效,但是训练时间可能过长。

lr = 0.03
num_epochs = 3
net = linreg
loss = squared_loss
​
for epoch in range(num_epochs):  # 训练模型一共需要num_epochs个迭代周期# 在每一个迭代周期中,会使用训练数据集中所有样本一次(假设样本数能够被批量大小整除)。X# 和y分别是小批量样本的特征和标签for X, y in data_iter(batch_size, features, labels):with autograd.record():l = loss(net(X, w, b), y)  # l是有关小批量X和y的损失l.backward()  # 小批量的损失对模型参数求梯度sgd([w, b], lr, batch_size)  # 使用小批量随机梯度下降迭代模型参数train_l = loss(net(features, w, b), labels)print('epoch %d, loss %f' % (epoch + 1, train_l.mean().asnumpy()))

输出:

epoch 1, loss 0.040552
epoch 2, loss 0.000158
epoch 3, loss 0.000051

训练完成后,我们可以比较学到的参数和用来生成训练集的真实参数。它们应该很接近。

true_w, w

输出

([2, -3.4], [[ 1.9997406][-3.4000957]]<NDArray 2x1 @cpu(0)>)
true_b, b

输出:

(4.2, [4.199303]<NDArray 1 @cpu(0)>)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/484002.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Gartner发布2021年新兴技术成熟度曲线

来源&#xff1a;Gartner中国8月24日&#xff0c;Gartner 公司最新发布了“2021年新兴技术成熟度曲线”&#xff08;Hype Cycle for Emerging Technologies&#xff0c;2021&#xff09;。其中&#xff0c;建立信任&#xff0c;加速增长以及塑造变革将是三大主要趋势&#xff0…

Vue语法学习第一课——插值

学习关于Vue的插值语法 ① 文本值 &#xff1a; "Mustache"语法&#xff0c;即双大括号 1 <span>Message:{{msg}}</span> 注&#xff1a;双大括号中的msg值改变&#xff0c;插入的内容也会随之改变&#xff0c;可通过v-once指令限制&#xff0c;但会影响…

计算方法之方程求根、线性方程组求解、插值方法、数值积分简介

提示:本文章主要通过介绍方程求根、线性方程组求解、插值方法、数值积分等相关方法的理论知识,并运用相关方法来解决一个实际的问题,文章中简单介绍了二分法、不动点迭代,牛顿法、Scant Method等方程求根方法,Gauss-Seidel迭代,Jacobi迭代,SOR迭代,Gauss消元法等方程组…

hadoop 重新格式化 NameNode

【问题描述】 在安装配置hadoop的过程中&#xff0c;很可能发生错误导致datanode或者namenode 启动失败&#xff0c;这时我们可以选择重新格式化 namenode。 一、删除data数据和log日志 二 、使用命令 bin/dfs namenode -format 重新格式化 【注意事项】 为什么不能一直格式…

人工智能“上位”会让程序员消失吗?

大脑以及二进制代码&#xff08;图&#xff1a;Canva&#xff09;来源&#xff1a;Forbes作者&#xff1a;Nisha Talagala编译整理&#xff1a;科技行者写代码已经成了许多工作的一项关键技能。一些国家和学校甚至认为&#xff0c;编程语言是一种可以接受的外语。而在各种熙熙攘…

分类的IP地址

现有物理地址再有IP地址IP地址的表示方法为点分十进制法IP地址的设计思想&#xff1a;网络部分 主机部分 分类的IP地址 特征&#xff1a;根据不同特征的IP地址&#xff0c;事先约定好网络号所占的位数和主机号所占的位数。 A类地址 全球一共有27-2 个A类网络&#xff0c;每…

人工智能之深度优先,广度优先,贪婪最佳优先搜索,A*搜索以及爬山法与遗传算法

项目场景: 1. 分别用宽度优先、深度优先、贪婪算法和A*算法求解“罗马利亚度假问题”。 2. 分别用爬山法和GA算法求解n皇后问题。 文章目录 项目场景:一、度假场景1.1 问题描述2.1 问题分析:1.3 解决方案:1.4 运行结果二、N皇后问题2.1 问题描述2.2 数据存储结构2.3 算法思…

操作系统之多级队列调度算法,银行家算法,动态分区式存储区管理

题目描述: 1.对于多级队列调度算法,主要介绍轮转法,短进程优先算法;银行家算法主要介绍进程的资源分配策略; 2.对于动态分区式存储区管理,主要介绍首先适应法,最佳适应法,最坏适应法等调度算法。 文章目录 题目描述:程序功能及设计思路1. 多级队列调度算法函数设计2. …

卫星对于物联网来说是一个非常好的选择

ALAMY来源&#xff1a;IEEE电气电子工程师对许多人来说&#xff0c;“物联网”一词可能会让人想起智能城市的努力&#xff0c;比如配备交通摄像头和空气质量传感器的路灯&#xff0c;或者在自己家里连接设备。一个很自然的问题是&#xff0c;为什么你从没想使用卫星连接任何这些…

利用子网掩码划分子网

分类IP地址的弊端 一个物理网络不能过大&#xff0c;否则网络性能很差&#xff0c;某个B类或A类IP网络无法全部用于单个物理网络分类IP地址分配不合理&#xff0c;利用率低分类IP地址设计的弊端 —— 不灵活&#xff0c;IP地址利用率不高 划分子网的思路 网络管理员将本应属于…

springcloud流程图

自己画的&#xff1a; 别人画的 别人画的2 转载于:https://www.cnblogs.com/dzhou/p/10504215.html

编译原理之LR语法分析器,自动机

本博客主要介绍LR语法分析器的代码实现,包含完整的测试数据与源代码。 文章目录 1. 主要内容:2.实验过程2.1 实验数据2.2 源代码1. 主要内容: LR语法分析器理论:https://blog.csdn.net/qq_40294512/article/details/92621241 2.实验过程 2.1 实验数据 G.txt数据文件 E-&…

java 搭建 web服务器 socket实现

【写在前面】 云计算的第n个java作业&#xff0c;开始一直不懂为什么老师一直让我们写java web的小demo&#xff0c;不应该是hadoop啥的直接上框架嘛。后来慢慢了解到&#xff0c;其实java web 的一些内容确实是云计算的基础。这个demo是用java socket 来搭建一个web服务器&…

【趋势】未来十年计算机体系结构的历史和趋势

来源&#xff1a;机器之心先分享我对这篇文章的总结&#xff0c;或者我得到的启发&#xff1a;1、DSA&#xff08;Domain-Specific Architectures&#xff0c;特定领域的体系结构&#xff09;将成为未来十年甚至更长时间&#xff0c;计算机体系结构的趋势。登纳德缩放定律结束、…

BZOJ3064 CPU监控

题目链接&#xff1a;戳我 比较神仙的一个题&#xff08;至少对于我这个小蒟蒻来说。。。&#xff09;下面尽可能详细地解释一下吧。。。学习来源&#xff1a;这位神仙的题解 其实就是对于操作的转换。我们设(x,y)为操作的参数&#xff0c;设当前数为a&#xff0c;操作为max(ax…

java socket 实现增删改查 + 在线答题小案例

实现效果 &#xff08;1&#xff09; 在client端可以实现对数据库的操作&#xff08;Select&#xff0c;Insert&#xff0c;Update&#xff0c;Delete&#xff09; &#xff08;2&#xff09;数据库中创建一个考试表和学生表&#xff0c;考试表中问题是四项选择题&#xff08;…

90后斯坦福博士论文登Science封面!AI算法准确预测RNA三维结构

来源&#xff1a;Science编辑&#xff1a;yaxin、su「我们对大部分RNA的结构几乎一无所知。」半个世纪以来&#xff0c;确定生物分子的三维结构一直困惑着科学家&#xff0c;也是生物学的重大挑战之一。难就难在&#xff0c;RNA折叠成复杂三维结构的形状很难通过实验或计算来确…

Event Recommendation Engine Challenge分步解析第五步

一、请知晓 本文是基于&#xff1a; Event Recommendation Engine Challenge分步解析第一步 Event Recommendation Engine Challenge分步解析第二步 Event Recommendation Engine Challenge分步解析第三步 Event Recommendation Engine Challenge分步解析第四步 需要读者先阅读…

计算机网络之RIP协议与OSPF协议模拟、UDP与TCP编程,Wireshark抓包分析

通过Python模拟RIP协议,OSPF协议,并模拟UDP和TCP编程,并通过Wireshark抓包工具,对所发送的报文进行捕获分析。 文章目录 一、RIP协议的模拟与编程二、OSPF协议的模拟与编程三、UDP编程四、TCP套接字编程五、Wireshark 数据分析六、总结一、RIP协议的模拟与编程 1.1 题目 …

虚拟机 NAT模式与桥接模式的区别

同个人网站 https://www.serendipper-x.cn/&#xff0c;欢迎访问 &#xff01; NAT模式&#xff1a;相当于宿主机再构建一个局域网&#xff0c;虚拟机无法和本局域网中的其他真实主机进行通讯。只需要宿主机器能访问互联网&#xff0c;那么虚拟机就能上网&#xff0c;不需要再…