从零开始学习线性回归:理论、实践与PyTorch实现

文章目录

  • 🥦介绍
  • 🥦基本知识
  • 🥦代码实现
  • 🥦完整代码
  • 🥦总结

🥦介绍

线性回归是统计学和机器学习中最简单而强大的算法之一,用于建模和预测连续性数值输出与输入特征之间的关系。本博客将深入探讨线性回归的理论基础、数学公式以及如何使用PyTorch实现一个简单的线性回归模型。

🥦基本知识

线性回归的数学基础
线性回归的核心思想是建立一个线性方程,它表示了自变量(输入特征)与因变量(输出)之间的关系。这个线性方程通常表示为:
y=b0+b1x1+b2x2+…+bpxp
y=b0​+b1​x1​+b2​x2​+…+bp​xp​

其中, y y y 是因变量, x 1 , x 2 , … , x p x_1, x_2, \ldots, x_p x1,x2,,xp 是自变量, b 0 , b 1 , b 2 , … , b p b_0, b_1, b_2, \ldots, b_p b0,b1,b2,,bp 是模型的参数, p p p 是特征的数量。我们的目标是找到最佳的参数值,以最小化模型的误差。

损失函数
为了找到最佳参数,我们需要定义一个损失函数来度量模型的性能。在线性回归中,最常用的损失函数是均方误差(MSE),它表示了模型预测值与实际值之间的平方差的平均值:
在这里插入图片描述
其中, n n n 是样本数量, y i y_i yi 是实际值, y ^ i \hat{y}_i y^i 是模型的预测值。

梯度下降优化
为了最小化损失函数,我们使用梯度下降算法。梯度下降通过计算损失函数相对于参数的梯度,并迭代地更新参数,以减小损失。更新规则如下:
在这里插入图片描述
其中, b j b_j bj 是第 j j j个参数, α \alpha α 是学习率, ∂ ∂ b j M S E \frac{\partial}{\partial b_j} MSE bjMSE 是损失函数对参数 b j b_j bj的偏导数。

🥦代码实现

如果你想知道实现线性回归的大体步骤,下图可以充分进行说明
在这里插入图片描述

  • 准备数据
  • 设计模型(计算) y ^ i \hat{y}_i y^i
  • 构造损失和优化器
  • 训练周期(前向,反向 ,更新)

本节还是以刘二大人的视频讲解为例,结尾会设置传送门

class LinearModel(torch.nn.Module):def __init__(self):super(LinearModel, self).__init__() # 调用父类的构造函数self.linear = torch.nn.Linear(1, 1)  # 参数详情下图展示def forward(self, x):y_pred = self.linear(x)   # x代表输入样本的张量return y_pred
model = LinearModel()

所以模型类都要继承Module,此类主要包含两个函数一个是构造函数(初始化对象时调用),另一个是前向计算

好奇的小伙伴会思考为何没有反向(backward),这是因为Module会帮你进行,但是如果后期自己有更高效的方法可以自行设置。
在这里插入图片描述

  • 第一个参数 in_features:这是输入特征的数量。在这里,表示我们的模型只有一个输入特征。如果你有多个输入特征,你可以将这个参数设置为输入特征的数量。

  • 第二个参数 out_features:这是输出特征的数量。这表示我们的模型将生成一个输出。在线性回归中,通常只有一个输出,因为我们试图预测一个连续的数值。

  • 第三个参数:意思是要不要偏置量。默认true

通常情况下特征代表列,比如我们有一个n×2的y和一个n×3的x,那么我们需要一个3×2的权重,有的书中会在两边做转置,但无论咋样目的都是为了让这个矩阵乘法成立

criterion = torch.nn.MSELoss(size_average=False)  # 使用均方误差损失 
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)  # 使用随机梯度下降优化器

在这里插入图片描述
在这里插入图片描述
model.parameters() 用于告诉优化器哪些参数需要在训练过程中进行更新,这包括模型的权重和偏置项等。在线性回归示例中,模型的参数包括权重和偏置项。

优化器的选择有许多大家可以都试试看看
在这里插入图片描述

之后就进行训练了

for epoch in range(1000):y_pred = model(x_data)loss = criterion(y_pred, y_data) print(epoch, loss.item())optimizer.zero_grad()   # 归零loss.backward()  # 反向optimizer.step()  # 更新
print('w = ', model.linear.weight.item()) 
print('b = ', model.linear.bias.item())
x_test = torch.Tensor([[4.0]]) 
y_test = model(x_test)
print('y_pred = ', y_test.data)

🥦完整代码

x_data = torch.Tensor([[1.0], [2.0], [3.0]]) 
y_data = torch.Tensor([[2.0], [4.0], [6.0]])
class LinearModel(torch.nn.Module):def __init__(self):super(LinearModel, self).__init__() self.linear = torch.nn.Linear(1, 1)def forward(self, x):y_pred = self.linear(x) return y_pred
model = LinearModel()
criterion = torch.nn.MSELoss(size_average=False)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
for epoch in range(1000):y_pred = model(x_data)loss = criterion(y_pred, y_data) print(epoch, loss.item())optimizer.zero_grad() loss.backward()optimizer.step()
print('w = ', model.linear.weight.item()) 
print('b = ', model.linear.bias.item())
x_test = torch.Tensor([[4.0]]) 
y_test = model(x_test)
print('y_pred = ', y_test.data)
predicted = model(x_data).detach().numpy()
plt.scatter(x_data, y_data, label='Original data')
plt.plot(x_data, predicted, label='Fitted line', color='r')
plt.legend()
plt.show()

运行结果如下
在这里插入图片描述

在这里插入图片描述

🥦总结

在本篇博客中,我们使用PyTorch实现了一个简单的线性回归模型,并使用随机生成的数据对其进行了训练和可视化。线性回归是一个入门级的机器学习模型,但它为理解模型训练和预测的基本概念提供了一个很好的起点。

请添加图片描述

挑战与创造都是很痛苦的,但是很充实。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/95856.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

面试题: Spring中Bean的实例化和Bean的初始化有什么区别?

Spring中Bean的实例化和Bean的初始化有什么区别? 背景答案扩展知识什么是实例化什么是初始化 个人评价我的回答 背景 想换工作, 看了图灵周瑜老师的视频想记录一下, 算是学习结果的一个输出. 答案 Spring 在创建一个Bean对象时, 会先创建出一个Java对象, 会通过反射来执行…

Mac上protobuf环境构建-java

参考文献 getting-started 官网pb java介绍 maven protobuf插件 简单入门1 简单入门2 1. protoc编译器下载安装 https://github.com/protocolbuffers/protobuf/releases?page10 放入.zshrc中配置环境变量  ~/IdeaProjects/test2/ protoc --version libprotoc 3.12.1  …

Javascript文件上传

什么是文件上传 文件上传包含两部分, 一部分是选择文件,包含所有相关的界面交互。一部分是网络传输,通过一个网络请求,将文件的数据携带过去,传递到服务器中,剩下的,在服务器中如何存储&#xf…

基于SpringBoot的电影评论网站

目录 前言 一、技术栈 二、系统功能介绍 电影信息管理 电影评论回复 电影信息 用户注册 三、核心代码 1、登录模块 2、文件上传模块 3、代码封装 前言 随着信息技术在管理上越来越深入而广泛的应用,管理信息系统的实施在技术上已逐步成熟。本文介绍了电影评…

Redisson—分布式集合详述

7.1. 映射(Map) 基于Redis的Redisson的分布式映射结构的RMap Java对象实现了java.util.concurrent.ConcurrentMap接口和java.util.Map接口。与HashMap不同的是,RMap保持了元素的插入顺序。该对象的最大容量受Redis限制,最大元素数…

Pyhon-每日一练(1)

🌈write in front🌈 🧸大家好,我是Aileen🧸.希望你看完之后,能对你有所帮助,不足请指正!共同学习交流. 🆔本文由Aileen_0v0🧸 原创 CSDN首发🐒 如…

Vue中如何进行分布式路由配置与管理

Vue中的分布式路由配置与管理 随着现代Web应用程序的复杂性不断增加,分布式路由配置和管理成为了一个重要的主题。Vue.js作为一种流行的前端框架,提供了多种方法来管理Vue应用程序的路由。本文将深入探讨在Vue中如何进行分布式路由配置与管理&#xff0…

OpenWRT、Yocto 、Buildroot和Ubuntu有什么区别

OpenWRT: 用途:OpenWRT 是一个专注于路由器和嵌入式网络设备的Linux发行版。它提供了一个优化的Linux环境,旨在将网络设备变成功能丰富、高度可定制的路由器。 包管理器:OpenWRT 使用 opkg 包管理器,它是一个轻量级的…

Spring Boot中的@Controller使用教程

一 Controller使用方法,如下所示: Controller是SpringBoot里最基本的组件,他的作用是把用户提交来的请求通过对URL的匹配,分配个不同的接收器,再进行处理,然后向用户返回结果。下面通过本文给大家介绍Spr…

linux内核分析:虚拟化

https://www.cnblogs.com/wujuntian/p/16294898.html 三种虚拟化方式 1. 对于虚拟机内核来讲,只要将标志位设为虚拟机状态,我们就可以直接在 CPU 上执行大部分的指令,不需要虚拟化软件在中间转述,除非遇到特别敏感的指令,才需要将标志位设为物理机内核态运行,这样大大…

minio分布式文件存储

基本介绍 什么是 MinIO MinIO 是一款基于 Go 语言的高性能、可扩展、云原生支持、操作简单、开源的分布式对象存储产品。基于 Apache License v2.0 开源协议,虽然轻量,却拥有着不错的性能。它兼容亚马逊S3云存储服务接口。可以很简单的和其他应…

Java方法:重复使用的操作可以写成方法哦

👑专栏内容:Java⛪个人主页:子夜的星的主页💕座右铭:前路未远,步履不停 目录 一、方法的概念1、什么是方法?2、方法的定义3、方法调用的过程 二、方法重载1、重载的概念2、方法签名 在日常生活中…

使用 python 检测泛洪攻击的案例

使用 python 检测泛洪攻击的案例 本案例只使用python标准库通过执行命令来监控异常请求, 并封锁IP, 不涉及其他第三方库工具. import os import time from collections import Counter# 1、update 命令, 采集CPU的平均负载 def get_cpu_load():"""uptime 命令…

计算机网络 快速了解网络层次、常用协议、常见物理设备。 掌握程序员必备网络基础知识!!!

文章目录 0 引言1 基础知识的定义1.1 计算机网络层次1.2 网络供应商 ISP1.3 猫、路由器、交换机1.4 IP协议1.5 TCP、UDP协议1.6 HTTP、HTTPS、FTP协议1.7 Web、Web浏览器、Web服务器1.8 以太网和WLAN1.9 Socket (网络套接字) 2 总结 0 引言 在学习的过程…

【Java-LangChain:使用 ChatGPT API 搭建系统-2】语言模型,提问范式与 Token

第二章 语言模型,提问范式与 Token 在本章中,我们将和您分享大型语言模型(LLM)的工作原理、训练方式以及分词器(tokenizer)等细节对 LLM 输出的影响。我们还将介绍 LLM 的提问范式(chat format…

KNN(下):数据分析 | 数据挖掘 | 十大算法之一

⭐️⭐️⭐️⭐️⭐️欢迎来到我的博客⭐️⭐️⭐️⭐️⭐️ 🐴作者:秋无之地 🐴简介:CSDN爬虫、后端、大数据领域创作者。目前从事python爬虫、后端和大数据等相关工作,主要擅长领域有:爬虫、后端、大数据开发、数据分析等。 🐴欢迎小伙伴们点赞👍🏻、收藏⭐️、…

postgresql-物化视图

postgresql-物化视图 物化视图创建物化视图刷新物化视图修改物化视图删除物化视图 物化视图 创建物化视图 postgresql使用create materialized view 语句创建视图 create materialized view if not exists name as query [with [NO] data];-- 创建一个包含员工统计信息的物化…

java遇到的问题

java遇到的问题 Tomcat与JDK版本问题 当使用Tomcat10的版本用于springmvc借用浏览器调试时,使用JDK8浏览器会报异常。 需要JDK17(可以配置多个JDK环境,切换使用)才可以使用,配置为JAVA_HOME路径,否则&a…

【AI视野·今日Robot 机器人论文速览 第四十七期】Wed, 4 Oct 2023

AI视野今日CS.Robotics 机器人学论文速览 Wed, 4 Oct 2023 Totally 40 papers 👉上期速览✈更多精彩请移步主页 Interesting: 📚基于神经网络的多模态触觉感知, classification, position, posture, and force of the grasped object多模态形象的解耦(f…

PostgreSQL进阶

PostgreSQL进阶:发挥强大功能的数据库技巧 PostgreSQL,作为一款强大且高度可定制的关系型数据库管理系统(RDBMS),提供了众多高级功能和功能,使其成为开发人员和数据库管理员的首选。在这篇博客中&#xff0…