机器学习与深度学习-1-线性回归从零开始实现

机器学习与深度学习-1-线性回归从零开始实现

1 前言

​ 内容来源于沐神的《动手学习深度学习》课程,本篇博客对线性回归从零开始实现(即不调用封装好的库,如SGD优化器、MSE损失函数等)进行重述,并且修改了沐神的课堂示例代码以符合PEP8代码编写规范(如内参、外参等)。我先发布代码实现的文章,过后会把线性回归的数学推导发布。

2 问题背景–以房价预测为例

​ 通常,我们希望通过过去的房价历史数据去预测未来的房价走向。但实际上,房价与很多因素有关。因此为了以房价预测为例子引出线性回归问题,我们做以下假设:

  • 房价只与房屋的面积与房屋年限有关;
  • 房价与两个因素是线性关系

​ 基于以上假设,有如下式子:

p r i c e = ω a r e a ⋅ a r e a + ω a g e ⋅ a g e + b price = \omega_{area}\cdot area ~+~\omega_{age} \cdot age~+~b price=ωareaarea + ωageage + b
其中, p r i c e price price为回归模型预测的房价; ω a r e a \omega_{area} ωarea ω a g e \omega_{age} ωage为房屋面积与房屋年限的权重,也为模型训练的目标; b b b为偏置项。

3 编程实现

​ 首先,我们先调用一些基本的库:

import random  # 用于生成随机数
import torch  # 导入Pytorch库
from d2l import torch as d2l  # 可视化数据结果

其中,前两个库是基本库。random用于生成随机数,而torch是深度学习的基本框架–Pytorch的库。d2l这个库需要用pip或者conda自行下载。笔者的环境:Python版本:3.10;CUDA版本:11.8;torch版本:2.5.0;GPU:RTX 4060;CPU:i9-14900HX

​ 接着,我们生成数据集:

def create_dataset(input_W, input_b, num_examples):"""生成 y = xw + b + 噪声:param input_w:权重参数:param input_b:偏差参数:param num_examples:样本数"""input_x = torch.normal(0, 1, (num_examples, len(input_w)))  # 随机生成输入数据,服从正态分布output_y = torch,matmul(input_x, input_w) + input_b  # 计算输出数据output_y += torch.normal(0, 0.01, oytput_y.shape)  # 加入噪声"""返回数据集,其中通过reshape这个函数去重塑output_y的形状。:param -1:自动推断该向量的维度:param  1:将output_y变成一列所以input_x与output_y均为列向量"""return input_x, output_y.reshape((-1, 1))

沐神的代码中,内参与外参用的是同一个变量符号。为了防止因为内参与外参的优先级而导致的变量覆盖,本文将内参的变量符号修改以区别于外参。

​ 定义实际的权重向量与偏差量,并调用刚刚定义好的函数create_dataset生成一个示例数据集:

true_w = torch.tensor([2, -3.4])  # 定义真实的权重
true_b = 4.2  # 定义真实的偏差
features, labels = create_dataset(true_w, true_b, 1000)  # 生成数据集

​ 在控制台打印数据集并输出数据的散点图:

print("features", features[0], "\n label", label[0])  # 在控制台展示生成的数据序列d2l.set_figsize()  # 设置图窗的尺寸,自动调整一个舒适的尺寸
d2l.plt.scatter(features[:, 1].detach().numpy(), labels.detach().numpy(), 1)  # 
d2l.plt.show()

有几个函数用法要说明一下:

  • .detach():用于返回一个新的张量,与原张量共享相同的数据,但与计算题目无关联(计算梯度是一个很”贵“的事情);
  • .numpy():将张量格式转换为NumPy数组格式;
  • 1为散点的大小。

​ 接着我们需要根据批量的大小生成批量数据:

def data_iter(input_batch_size, input_features, input_labels):"""生成批量数据:param input_batch_size:批量大小:param input_features:输入数据:param input_labels:输出数据:return:返回批量数据"""num_examples = len(input_features)  # 样本数indices = list(range(num_examples))  # 样本的索引列表random.shuffle(indices)  # 样本的读取顺序随机for i in range(0, num_examples, input_batch_size):  # 按批量分割样本batch_indices = torch.tensor(indices[i + min(input_batch_size, num_examples)])  # 生成批量的索引,后面的min()保证索引值不会超出样本yield input_features[batch_indices], input_labels[batch_indices]  # 批量的输入输出数据,以迭代器的形式输出(可以节省内存)

​ 定义批量大小,并读取批量数据:

batch_size = 32  # 批量大小for x, y in data_iter(batch_size, features, labels):  # 读取批量大小print(x, "\n", y)  # 打印结果break

​ 实际上这个批量大小的设置是有讲究的,沐神在他的教学案例设置的是10,往后我会结合我的课题出一个深度学习模型设计的一般步骤,在那里会提到。

​ 接着我们要探索出回归效果最好的权重向量,首先进行一个初始化:

w = torch.normal(0, 0.01, size(2, 1), requires_grad=True)  # 初始化权重参数
b = torch.zeros(1, requires_grad=True)  # 初始化偏置

requires_grad这一项的意思是计算其梯度,这一步是优化过程的重要一步。

​ 定义线性回归模型:

def linreg(input_x, input_w, input_b):return torch.matmul(input_x, input_w) + input_b

​ 定义损失函数:

def squared_loss(y_hat, output_y):return (y_hat - output_y.reshape(y_hat.reshape)) ** 2 / 2

​ 定义小批量随机梯度下降优化算法:

def sgd(params, learning_rate, input_batch_size):with torch.no_grad():  # 不跟踪梯度for param in params:  # 遍历参数param -= learning_rate * param.grad / input_batch_size  # 更新参数param.grad.zeros_()  # 清空梯度,减少计算负担

​ 接下来是最重要的训练过程:

lr = 0.03  # 学习率
num_epochs = 50  # 迭代次数
net = linreg  # 线性模型
loss = squared_loss  # MSE损失for epoch in range(num_epochs):  # 训练模型for x, y in data_iter(batch_size, features, labels):  # 读取批量数据batch_loss = loss(net(x, w, b), y)  # 计算损失batch_loss.sum().backward()  # 反向传播sgd([w, b], lr, batch_size)  # 使用小批量随机梯度下降算法更新参数with torch.no_grad():  # 不跟踪梯度train_l = loss(new(features, w, b), labels)  # 训练集上的损失print(f"epoch {epoch + 1}", loss {float(train_l.mean()):f}")  # 打印训练集上的损失# 打印估计误差              
print(f"w的估计误差为:{true_w - w.reshape(true_w.shape)}")
print(f"b的估计误差为:{true_b - b}")

4 结果分析

​ 代码结果:

epoch 1, loss 2.518388
epoch 2, loss 0.396580
epoch 3, loss 0.062917
epoch 4, loss 0.010110
epoch 5, loss 0.001669
epoch 6, loss 0.000314
epoch 7, loss 0.000095
epoch 8, loss 0.000059
epoch 9, loss 0.000054
epoch 10, loss 0.000053
epoch 11, loss 0.000052
epoch 12, loss 0.000052
epoch 13, loss 0.000052
epoch 14, loss 0.000052
epoch 15, loss 0.000052
epoch 16, loss 0.000052
epoch 17, loss 0.000052
epoch 18, loss 0.000052
epoch 19, loss 0.000052
epoch 20, loss 0.000052
epoch 21, loss 0.000052
epoch 22, loss 0.000052
epoch 23, loss 0.000052
epoch 24, loss 0.000052
epoch 25, loss 0.000052
epoch 26, loss 0.000052
epoch 27, loss 0.000052
epoch 28, loss 0.000052
epoch 29, loss 0.000052
epoch 30, loss 0.000052
epoch 31, loss 0.000052
epoch 32, loss 0.000052
epoch 33, loss 0.000052
epoch 34, loss 0.000052
epoch 35, loss 0.000052
epoch 36, loss 0.000052
epoch 37, loss 0.000052
epoch 38, loss 0.000052
epoch 39, loss 0.000052
epoch 40, loss 0.000052
epoch 41, loss 0.000052
epoch 42, loss 0.000052
epoch 43, loss 0.000052
epoch 44, loss 0.000052
epoch 45, loss 0.000052
epoch 46, loss 0.000052
epoch 47, loss 0.000052
epoch 48, loss 0.000052
epoch 49, loss 0.000052
epoch 50, loss 0.000052
w的估计误差为:tensor([0.0001, 0.0002], grad_fn=<SubBackward0>)
b的估计误差为:tensor([0.0004], grad_fn=<RsubBackward1>)

从结果可以看出,模型训练到第十轮左右就收敛了,且训练误差很小,证明训练的效果很好。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/885182.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

VMware虚拟机可以被外部机器访问吗?

如何设置让同局域网内其他机器访问本地虚拟机服务&#xff08;这里以访问我本地虚拟机ELasticSearch服务为例&#xff09; 选中虚拟机 - 虚拟机 - 设置 虚拟机网络设置&#xff1a; 选中网络适配器&#xff0c;修改网络模式为NAT模式 编辑 - 虚拟机网络编辑器 更改设置 …

【论文复现】自动化细胞核分割与特征分析

本文所涉及所有资源均在这里可获取。 作者主页&#xff1a; 七七的个人主页 文章收录专栏&#xff1a; 论文复现 欢迎大家点赞 &#x1f44d; 收藏 ⭐ 加关注哦&#xff01;&#x1f496;&#x1f496; 自动化细胞核分割与特征分析 引言效果展示HoverNet概述HoverNet原理分析整…

【NOIP普及组】质因数分解

【NOIP普及组】质因数分解 C语言代码C代码Java代码Python代码 &#x1f490;The Begin&#x1f490;点点关注&#xff0c;收藏不迷路&#x1f490; 已知正整数 n 是两个不同的质数的乘积&#xff0c;试求出较大的那个质数。 输入 输入只有一行&#xff0c;包含一个正整数…

2024软件测试面试热点问题

&#x1f345; 点击文末小卡片 &#xff0c;免费获取软件测试全套资料&#xff0c;资料在手&#xff0c;涨薪更快 大厂面试热点问题 1、测试人员需要何时参加需求分析&#xff1f; 如果条件循序 原则上来说 是越早介入需求分析越好 因为测试人员对需求理解越深刻 对测试工…

qt QTextStream详解

1、概述 QTextStream类是Qt框架中用于处理文本输入输出的类。它提供了一种方便的方式&#xff0c;可以从各种QIODevice&#xff08;如QFile、QBuffer、QTcpSocket等&#xff09;中读取文本数据&#xff0c;或者将文本数据写入这些设备中。QTextStream能够自动处理字符编码的转…

Webpack性能优化指南:从构建到部署的全方位策略

文章目录 1、webpack的优化-OneOf2、webpack的优化-Include/Exclude3、webpack优化-SourceMap4、webpack的优化-Babel缓存5、wenbpack的优化-resolve配置6、构建结果分析 webpack优化在现代前端开发中&#xff0c;Webpack已成为模块打包器的事实标准&#xff0c;它通过将项目中…

[ DOS 命令基础 4 ] DOS 命令命令详解-端口进程相关命令

&#x1f36c; 博主介绍 &#x1f468;‍&#x1f393; 博主介绍&#xff1a;大家好&#xff0c;我是 _PowerShell &#xff0c;很高兴认识大家~ ✨主攻领域&#xff1a;【渗透领域】【数据通信】 【通讯安全】 【web安全】【面试分析】 &#x1f389;点赞➕评论➕收藏 养成习…

飞书API-获取tenant_access_token

1.在飞书工作台创建应用&#xff0c;跳到开发者后台&#xff0c;选创建企业自建应用 2.设置并发布应用 必须要发布应用才可以开始使用了&#xff01;&#xff01;&#xff01; 3.调用获取token的API 参考链接&#xff1a; 开发文档 - 飞书开放平台https://open.feishu.cn/do…

linux 安装anaconda3

1.下载 使用repo镜像网址下载对应安装包 右击获取下载地址&#xff0c;使用终端下载 wget https://repo.anaconda.com/archive/Anaconda3-2024.02-1-Linux-x86_64.sh2.安装 使用以下命令可直接指定位置 bash Anaconda3-2024.02-1-Linux-x86_64.sh -b -p /home/anaconda3也…

LabVIEW编程过程中为什么会出现bug?

在LabVIEW编程过程中&#xff0c;Bug的产生往往源自多方面原因。以下从具体的案例角度分析一些常见的Bug成因和调试方法&#xff0c;以便更好地理解和预防这些问题。 ​ 1. 数据流错误 案例&#xff1a;在一个LabVIEW程序中&#xff0c;多个计算节点依赖相同的输入数据&#…

【自用】fastapi 学习记录 --请求和参数部分

fastai个人学习笔记 一、模块化结构框架 设置了默认请求头shop之后就无需再app0x里接口函数前全部写上/shop/xxx&#xff0c;或者/user/xxx&#xff0c;他会同意添加~如果都写了就会出现以下的情况&#xff08;重复shop&#xff09;&#xff1a; 二、请求与响应 关于参数&a…

若依入门案例

若依&#xff08;RuoYi&#xff09;框架是一个基于Java的开源企业级快速开发框架&#xff0c;主要用于构建信息管理系统。它结合了多种前端和后端技术&#xff0c;提供了高效的开发工具&#xff0c;并具备以下主要功能&#xff1a; 一、后端功能 技术选型&#xff1a;若依后端…

【Web前端】OOP编程范式

面向对象编程&#xff08;Object-Oriented Programming&#xff0c;简称 OOP&#xff09;是一种程序设计思想&#xff0c;它通过将程序视为一组相互作用的对象来设计程序。OOP 提出了一些重要的基本概念&#xff0c;包括类与实例、继承和封装。面向对象编程将系统视为由多个对象…

Mac解决 zsh: command not found: ll

Mac解决 zsh: command not found: ll 文章目录 Mac解决 zsh: command not found: ll解决方法 解决方法 1.打开bash_profile 配置文件vim ~/.bash_profile2.在文件中添加配置&#xff1a;alias llls -alF键盘按下 I 键进入编辑模式3. alias llls -alF添加完配置后&#xff0c;按…

JavaAPI(1)

Java的API&#xff08;1&#xff09; 一、Math的API 是一个帮助我们进行数学计算的工具类私有化构造方法&#xff0c;所有的方法都是静态的&#xff08;可以直接通过类名.调用&#xff09; 平方根&#xff1a;Math.sqrt()立方根&#xff1a;Math.cbrt() 示例&#xff1a; p…

UI界面设计入门:打造卓越用户体验

互联网的迅猛发展催生了众多相关职业&#xff0c;其中UI界面设计师成为互联网行业的关键角色之一。UI界面设计无处不在&#xff0c;影响着网站、应用程序以及其他数字平台上的按钮、菜单布局、色彩搭配和字体排版等。UI设计不仅仅是字体、色彩和导航栏的组合&#xff0c;它的意…

std::back_inserter

std::back_inserter 是 C 标准库中的一个函数模板&#xff0c;它用于创建一个插入迭代器&#xff08;insert iterator&#xff09;&#xff0c;这个迭代器可以在容器末尾插入新元素。它定义在 <iterator> 头文件中。 函数原型 template <typename Container> bac…

在 Mac 和 Windows 系统中快速部署 OceanBase

OceanBase 是一款分布式数据库&#xff0c;具备出色的性能和高扩展性&#xff0c;可以为企业用户构建稳定可靠、灵活扩展性能的数据库服务。本文以开发者们普遍熟悉的Windows 或 Mac 环境为例&#xff0c;介绍如何快速上手并体验OceanBase。 一、环境准备 1. 硬件准备 OceanB…

如何有效销售和应用低代码软件?探索其市场机会与策略

随着技术的进步&#xff0c;企业对自动化和数字化的需求日益增加。低代码开发平台应运而生&#xff0c;成为企业实现快速应用程序开发的重要工具。然而&#xff0c;在市场上推广和应用低代码软件并非易事&#xff0c;需要深入了解客户需求&#xff0c;提供定制化的解决方案&…

在函数内部定义函数

【小白从小学Python、C、Java】 【计算机等级考试500强双证书】 【Python-数据分析】 在函数内部定义函数 在以下代码片段中&#xff0c;输出的结果是什么&#xff1f; def outer_function(x): def inner_function(y): return x y return inner_function add_five outer_func…