通俗易懂之线性回归时序预测PyTorch实践

线性回归(Linear Regression)是机器学习中最基本且广泛应用的算法之一。它不仅作为入门学习的经典案例,也是许多复杂模型的基础。本文将全面介绍线性回归的原理、应用,并通过一段PyTorch代码进行实践演示,帮助读者深入理解这一重要概念。

线性回归概述

线性回归是一种用于预测因变量(目标变量)与一个或多个自变量(特征变量)之间关系的统计方法。其目标是在数据点之间找到一条最佳拟合直线,使得预测值与实际值之间的误差最小。

基本形式

  • 简单线性回归:只有一个自变量。
  • 多元线性回归:包含多个自变量。

本文将聚焦于简单线性回归,即仅考虑一个自变量的情况。

线性回归的数学原理

模型表达式

简单线性回归的模型表达式为:

y = w x + b y = wx + b y=wx+b

其中:

  • y y y 是预测值。
  • x x x 是输入特征。
  • w w w 是权重(斜率)。
  • b b b 是偏置(截距)。

损失函数

为了衡量模型预测值与实际值之间的差异,通常使用均方误差(Mean Squared Error, MSE)作为损失函数:

Loss = 1 2 ∑ i = 1 N ( y i pred − y i ) 2 \text{Loss} = \frac{1}{2} \sum_{i=1}^{N} (y_i^{\text{pred}} - y_i)^2 Loss=21i=1N(yipredyi)2

优化算法

线性回归常用的优化算法是梯度下降(Gradient Descent)。通过计算损失函数关于参数 w w w b b b 的梯度,迭代更新参数以最小化损失。

更新规则如下:

w : = w − η ∂ Loss ∂ w w := w - \eta \frac{\partial \text{Loss}}{\partial w} w:=wηwLoss

b : = b − η ∂ Loss ∂ b b := b - \eta \frac{\partial \text{Loss}}{\partial b} b:=bηbLoss

其中 η \eta η 是学习率。

应用场景

线性回归在多个领域有广泛应用,包括但不限于:

  • 经济学:预测经济指标,如GDP、通货膨胀率等。
  • 工程学:估计物理量之间的关系,如材料强度与应力。
  • 医疗:预测疾病发展趋势,如体重增长与健康指标。
  • 金融:股价预测、风险评估等。

PyTorch实现线性回归

接下来,我们将通过一段PyTorch代码实践线性回归,从数据生成、模型训练到可视化展示,全面演示线性回归的实现过程。代码参考《深度学习框架PyTorch入门与实践》一书的实现,为了感受线性回归的计算过程,代码并未直接调用python中已有的线性回归库。

代码解析

首先,我们导入必要的库并设置随机种子以确保结果可复现。

import torch as t
import matplotlib.pyplot as plt
from IPython import displayt.manual_seed(1000)
数据生成函数

定义一个函数 get_fake_data 来生成假数据,这些数据遵循线性关系 y = 2 x + 3 y = 2x + 3 y=2x+3 并添加了一定的噪声。

def get_fake_data(batch_size=8):x = t.randn(batch_size, 1, dtype=float) * 20  # 随机生成x,范围扩大到[-20, 20]y = x * 2 + (1 + t.randn(batch_size, 1, dtype=float)) * 3  # y = 2x + 3 + 噪声return x, y

调用该函数生成一批数据并进行可视化。

x, y = get_fake_data()plt.figure()
plt.scatter(x, y)
plt.show()
参数初始化

随机初始化权重 w w w 和偏置 b b b,并设置学习率 l r lr lr

# 随机初始化参数
w = t.rand(1, 1, requires_grad=True, dtype=float)
b = t.zeros(1, 1, requires_grad=True, dtype=float)lr = 0.00001
训练过程

通过1000次迭代,使用梯度下降法优化参数 w w w b b b

for i in range(1000):x, y = get_fake_data()y_pred = x.mm(w) + b.expand_as(y)  # 预测值loss = 0.5 * (y_pred - y) ** 2  # 均方误差loss = loss.sum()loss.backward()  # 反向传播计算梯度# 更新参数w.data.sub_(lr * w.grad.data)b.data.sub_(lr * b.grad.data)# 梯度清零w.grad.data.zero_()b.grad.data.zero_()# 每100次迭代可视化一次结果if i % 100 == 0:display.clear_output(wait=True)x_plot = t.arange(0, 20, dtype=float).view(-1, 1)y_plot = x_plot.mm(w) + b.expand_as(x_plot)plt.plot(x_plot.data, y_plot.data, label='Fitting Line')x2, y2 = get_fake_data(batch_size=20)plt.scatter(x2, y2, color='red', label='Data Points')plt.xlim(0, 20)plt.ylim(0, 41)plt.legend()plt.show()plt.pause(0.5)

可视化与训练过程

训练过程中,每隔100次迭代,会清除之前的输出,绘制当前拟合的直线与新生成的数据点。随着训练的进行,拟合线将逐渐接近真实的线性关系 y = 2 x + 3 y = 2x + 3 y=2x+3

以下是训练过程中的可视化效果示例:

在这里插入图片描述

注:实际运行代码时,图像会动态更新,展示拟合过程。

代码关键点解析

  1. 数据生成

    • 使用 torch.randn 生成标准正态分布的随机数,并通过线性变换获取 xy
    • 添加噪声使模型更贴近真实场景。
  2. 参数初始化

    • w 随机初始化,b 初始化为零。
    • requires_grad=True 表示在反向传播时需要计算梯度。
  3. 前向传播

    • 计算预测值 y_pred = x.mm(w) + b.expand_as(y)
    • 使用矩阵乘法 mm 实现线性变换。
  4. 损失计算

    • 采用均方误差损失函数。
    • loss.backward() 计算损失函数相对于参数的梯度。
  5. 参数更新

    • 使用学习率 lr 按梯度方向更新参数。
    • data.sub_ 进行原地更新,避免梯度计算图的干扰。
  6. 梯度清零

    • 每次参数更新后,需要清零梯度 w.grad.data.zero_()b.grad.data.zero_(),以防止梯度累积。
  7. 可视化

    • 使用 matplotlib 绘制拟合线和数据点。
    • display.clear_output(wait=True) 清除之前的图像,避免图形堆积。
    • plt.pause(0.5) 控制图像更新速度。

总结

本文从线性回归的基本概念出发,详细介绍了其数学原理和应用场景,并通过一段PyTorch代码演示了线性回归模型的实现过程。从数据生成、参数初始化、模型训练到结果可视化,全面展示了线性回归的实际应用。通过这种实例讲解,读者不仅能够理解线性回归的理论基础,还能掌握其在深度学习框架中的具体实现方法。

线性回归作为机器学习的基础模型,虽然简单,但其思想却深刻影响着更加复杂的算法和模型。在实际应用中,理解并掌握线性回归对于进一步学习和开发更加复杂的机器学习模型具有重要意义。


如果这篇文章对你有一点点的帮助,欢迎点赞、关注、收藏、转发、评论哦!
我也会在微信公众号“智识小站”坚持分享更多内容,以期记录成长、普及技术、造福后来者!

在这里插入图片描述

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/66590.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

MATLAB深度学习实战文字识别

文章目录 前言视频演示效果1.DB文字定位环境配置安装教程与资源说明1.1 DB概述1.2 DB算法原理1.2.1 整体框架1.2.2 特征提取网络Resnet1.2.3 自适应阈值1.2.4 文字区域标注生成1.2.5 DB文字定位模型训练 2.CRNN文字识别2.1 CRNN概述2.2 CRNN原理2.2.1 CRNN网络架构实现2.2.2 CN…

和为0的四元组-蛮力枚举(C语言实现)

目录 一、问题描述 二、蛮力枚举思路 1.初始化: 2.遍历所有可能的四元组: 3.检查和: 4.避免重复: 5.更新计数器: 三、代码实现 四、运行结果 五、 算法复杂度分析 一、问题描述 给定一个整数数组 nums&…

SpringBoot日常:集成Kafka

文章目录 1、pom.xml文件2、application.yml3、生产者配置类4、消费者配置类5、消息订阅6、生产者发送消息7、测试发送消息 本章内容主要介绍如何在springboot项目对kafka进行整合,最终能达到的效果就是能够在项目中通过配置相关的kafka配置,就能进行消息…

【实用技能】如何使用 .NET C# 中的 Azure Key Vault 中的 PFX 证书对 PDF 文档进行签名

TX Text Control 是一款功能类似于 MS Word 的文字处理控件,包括文档创建、编辑、打印、邮件合并、格式转换、拆分合并、导入导出、批量生成等功能。广泛应用于企业文档管理,网站内容发布,电子病历中病案模板创建、病历书写、修改历史、连续打…

33.3K 的Freqtrade:开启加密货币自动化交易之旅

“ 如何更高效、智能地进行交易成为众多投资者关注的焦点。” Freqtrade 是一款用 Python 编写的免费开源加密货币交易机器人。它就像一位不知疲倦的智能交易助手,能够连接到众多主流加密货币交易所,如 Binance、Bitmart、Bybit 等(支…

Mac M2基于MySQL 8.4.3搭建(伪)主从集群

前置准备工作 安装MySQL 8.4.3 参考博主之前的文档,在本地Mac安装好MySQL:Mac M2 Pro安装MySQL 8.4.3安装目录:/usr/local/mysql,安装好的MySQL都处于运行状态,需要先停止MySQL服务最快的方式:系统设置 …

事务的回滚与失效行为

创建一张测试表 AccountMapper public interface AccountMapper {Update("update account set balance #{balance} where username #{username}")int updateUserBalance(Param("username") String username, Param("balance") Integer bal…

【C语言】_字符数组与常量字符串

目录 1. 常量字符串的不可变性 2. 关于常量字符串的打印 3. 关于字符数组与常量字符串的内存分布 1. 常量字符串的不可变性 char arr[10] "abcdef";// 字符数组char* p2 arr;char* p3 "abcdef"; // 常量字符串 尝试对常量字符串进行修改&#xff…

【GUI-pyqt5】QCommandLinkButton类

1. 描述 命令链接的Windows Vista引入的新控件他的用途类似于单选按钮的用途,因为他用于在一组互斥选项之间进行选择命令链接按钮不应单独使用,而应作为向导和对话框中单选按钮替代选项外观通常类似于平面按钮的外观,但除了普通按钮文本外&a…

69.基于SpringBoot + Vue实现的前后端分离-家乡特色推荐系统(项目 + 论文PPT)

项目介绍 在Internet高速发展的今天,我们生活的各个领域都涉及到计算机的应用,其中包括家乡特色推荐的网络应用,在外国家乡特色推荐系统已经是很普遍的方式,不过国内的管理网站可能还处于起步阶段。家乡特色推荐系统采用java技术&…

HCIE-day10-ISIS

ISIS ISIS(Intermediate System-to-Intermediate System)中间系统到中间系统,属于IGP(内部网关协议);是一种链路状态协议,使用最短路径优先SPF算法进行路由计算,与ospf协议有很多相…

图像处理|膨胀操作

在图像处理领域,形态学操作是一种基于图像形状的操作,用于分析和处理图像中对象的几何结构。**膨胀操作(Dilation)**是形态学操作的一种,它能够扩展图像中白色区域(前景)或减少黑色区域&#xf…

【机器学习】量子机器学习:当量子计算遇上人工智能,颠覆即将来临?

我的个人主页 我的领域:人工智能篇,希望能帮助到大家!!!👍点赞 收藏❤ 在当今科技飞速发展的时代,量子计算与人工智能宛如两颗璀璨的星辰,各自在不同的苍穹闪耀,正以前…

Sprint Boot教程之五十:Spring Boot JpaRepository 示例

Spring Boot JpaRepository 示例 Spring Boot建立在 Spring 之上,包含 Spring 的所有功能。由于其快速的生产就绪环境,使开发人员能够直接专注于逻辑,而不必费力配置和设置,因此如今它正成为开发人员的最爱。Spring Boot 是一个基…

腾讯云AI代码助手编程挑战赛-桌面壁纸随机更换

作品简介 用于更换壁纸缓缓心情,或者选择困难症,每一个图片都想用来做壁纸,并且节约了手工时间,所以根据这个需求来创建的这款应用工具,使用的是腾讯云AI代码助手来生成的所有代码,使用方便,快…

说说你对作用域链的理解

一、作用域 作用域,即变量(变量作用域又称上下文)和函数生效(能被访问)的区域或集合 换句话说,作用域决定了代码区块中变量和其他资源的可见性 举个例子 function myFunction() {let inVariable "…

SpringBootWeb 登录认证(day12)

登录功能 基本信息 请求参数 参数格式:application/json 请求数据样例: 响应数据 参数格式:application/json 响应数据样例: Slf4j RestController public class LoginController {Autowiredpriva…

ASP.NET Core 实现微服务 - Consul 配置中心

这一次我们继续介绍微服务相关组件配置中心的使用方法。本来打算介绍下携程开源的重型配置中心框架 apollo 但是体系实在是太过于庞大,还是让我爱不起来。因为前面我们已经介绍了使用Consul 做为服务注册发现的组件 ,那么干脆继续使用 Consul 来作为配置…

DeviceNet转Profinet网关如何革新污水处理行业!

DeviceNet转Profinet网关如何革新污水处理行业?在污水处理行业中,随着环保法规的日益严格和处理技术的不断进步,工业自动化技术的应用越来越广泛。特别是在提高生产效率、降低运营成本以及确保处理质量方面,自动化技术发挥着不可替…

(四)结合代码初步理解帧缓存(Frame Buffer)概念

帧缓存(Framebuffer)是图形渲染管线中的一个非常重要的概念,它用于存储渲染过程中产生的像素数据,并最终输出到显示器上。简单来说,帧缓存就是计算机图形中的“临时画布”,它储存渲染操作生成的图像数据&am…