深度学习与神经网络Pytorch版 3.2 线性回归从零开始实现 1.生成数据集

3.2 线性回归从零开始实现

目录

3.2 线性回归从零开始实现

一 ,简介

1. 原理

2. 步骤

3. 优缺点

4. 应用场景

二 ,代码展现

1. 生成数据集(完整代码)

2. 各个函数解析

2.1 torch.normal()函数

2.2 torch.matmul()函数

2.3 d2l.plt.scatter()函数

三 ,总结


一 ,简介

1. 原理

深度学习线性回归的原理是基于神经网络和线性回归的结合。它使用神经网络来构建一个复杂的非线性模型,同时保持线性回归的简单性和可解释性。

在深度学习线性回归中,通常使用全连接神经网络(Fully Connected Neural Network)作为基础结构。输入数据经过一系列的线性变换和非线性激活函数,最终输出预测结果。与传统的线性回归不同,深度学习线性回归可以自动学习特征之间的复杂交互和组合,而不需要手动选择或设计特征。

深度学习线性回归的训练过程与传统神经网络的训练过程类似,使用梯度下降算法优化模型的参数,以最小化预测误差(如均方误差)。在训练过程中,通过反向传播算法计算梯度,并使用优化器(如Adam、SGD等)更新权重和偏置项。

深度学习线性回归的优点是可以处理高维、复杂的非线性数据,并且具有自动特征选择和组合的能力。然而,与传统的线性回归相比,深度学习线性回归需要更多的参数和计算资源,并且可能更容易过拟合。因此,在选择是否使用深度学习线性回归时,需要根据具体问题和数据集的特点进行权衡。

2. 步骤

线性回归从零开始实现步骤包括以下内容:

  1. 导入必要的库:在Python中,需要导入numpy库来处理数据和计算,以及matplotlib库来绘制数据和结果。
  2. 生成数据集:根据实际问题,可以使用随机数生成器生成一组训练数据集,包括输入特征和对应的标签。也可以使用真实数据集进行训练和测试。
  3. 初始化模型参数:为模型权重和偏置项设置初始值,这些初始值可以是随机数或基于先验知识的值。
  4. 定义模型:根据线性回归模型的公式,可以使用numpy的矩阵运算来构建模型。模型可以表示为y = w * x + b,其中x是输入特征,y是对应的标签,w是权重,b是偏置项。
  5. 计算损失函数:损失函数用于衡量模型的预测值与真实值之间的差距。对于线性回归问题,常用的损失函数是均方误差(MSE)。
  6. 执行梯度下降算法:梯度下降算法用于更新模型的参数以最小化损失函数。在每一步迭代中,根据梯度下降公式计算参数的更新方向和步长,并更新参数的值。
  7. 训练:重复执行步骤5和6,直到达到预设的迭代次数或损失函数达到可接受的值。
  8. 评估模型:使用测试数据集评估模型的性能,计算模型的预测值与真实值之间的误差或准确率等指标。
  9. 优化和调整:根据评估结果对模型进行调整,例如调整参数、增加特征或使用正则化等方法来提高模型的性能。
  10. 应用模型进行预测:将新数据输入到模型中进行预测,得到预测结果。

以上是线性回归从零开始实现的基本步骤,具体实现细节可能会根据问题和数据集的不同而有所差异。

3. 优缺点

线性回归的优点:

  1. 简单易行:线性回归模型简单易懂,实现起来也相对容易。
  2. 计算效率高:由于模型简单,计算复杂度较低,因此在线性回归中,无论是训练还是预测,计算速度都比较快。
  3. 可解释性强:线性回归模型可以给出每个特征的权重,这有助于理解特征对目标变量的影响程度。
  4. 适合处理线性关系:线性回归适合处理因变量和自变量之间存在线性关系的情况。
  5. 模型稳定性好:线性回归模型相对稳定,对异常值和噪声的鲁棒性较好。

然而,线性回归也存在一些缺点:

  1. 假设限制:线性回归基于一些假设,如误差项的独立性、同方差性、无序列相关性和常数方差等。在实际应用中,这些假设可能不成立,导致模型误判。
  2. 欠拟合与过拟合:如果线性模型过于简单(即过于欠拟合),它可能无法捕获数据的复杂模式;而如果模型过于复杂(即过拟合),它可能会捕获到数据中的噪声和无关紧要的信息。
  3. 无法处理非线性关系:对于非线性关系的数据,线性回归可能无法给出很好的预测。
  4. 对异常值敏感:如果数据集中存在异常值,线性回归模型的预测结果可能会受到影响。
  5. 特征选择困难:对于特征之间的交互和特征选择,线性回归模型可能会遇到困难。

4. 应用场景

线性回归的应用场景包括但不限于:

  1. 预测:当因变量是连续变量,并且与其影响因素有线性关系时,可以用线性回归进行建模。例如,预测信用卡用户的生命周期价值,可以基于用户所在小区的平均收入、年龄、学历、收入等因素进行线性回归建模。
  2. 模型解释:当需要理解自变量与因变量之间的关系时,可以通过建立线性回归模型,例如决策树、线性回归等模型,以自变量作为输入变量,以因变量作为目标变量进行建模,以此了解黑盒模型的运作机制,并对其作出解释。
  3. 全量实验效果评估:全量实验评估是指当在时间点时,对全量用户加入干预策略,然后评估策略所带来的影响。进行评估时,核心是要剥离其他因素,对实验效果进行评估,线性回归就能解决这个问题。
  4. AB实验:在AB实验中,假定有两组无差异的用户群体和,以作为实验组对其施加策略干预,作为对照组不采取施加任何策略,来评估实验对观测变量的影响。可以通过t或z检验来得到结果,当然也可以建立线性回归模型 ,为是否为实验组的哑变量(当策略变多时,也可为分类变量),通过检验参数的显著性即可得到策略的效果。
  5. 预测疾病发生概率:医院可以根据患者的病历数据(如体检指标、药物复用情况、平时的饮食习惯等)来预测某种疾病发生的概率。
  6. 预测用户支付转化率:网站可以根据访问的历史数据(包括新用户的注册量、老用户的活跃度、网站内容的更新频率等)来预测用户的支付转化率。

以上只是部分应用场景,线性回归模型的应用非常广泛,具体应用取决于数据的特征和业务需求。

二 ,代码展现

1. 生成数据集(完整代码)

# 线性回归从零开始实现
# 生成数据集# 导入必要的库
import matplotlib.pyplot as plt
import random
import torch
from d2l import torch as d2l# 定义一个生成合成数据的函数
def synthetic_data(w, b, num_examples):    # 函数参数包括权重w、偏置b和数据点数量num_examples# 生成y=Xw+b+噪声满足线性关系y=Xw+b的数据,并添加噪声X = torch.normal(0, 1, (num_examples, len(w)))  # 创建一个形状为(num_examples, len(w))的张量X,元素值为从标准正态分布中抽取的随机数y = torch.matmul(X, w) + b  # 使用矩阵乘法计算y的值,y = X * w + by += torch.normal(0, 0.01, y.shape)  # 在y的值上添加从标准正态分布中抽取的随机噪声,噪声的标准差为0.01return X, y.reshape((-1, 1))  # 返回X和y。y被重新整形为(-1, 1)的形状,这是因为matplotlib在绘图时需要这样的形状# 定义真实的权重和偏置值
true_w = torch.tensor([2, -3.4])  # 真实的权重w为[2, -3.4]的张量
true_b = 4.2  # 真实的偏置b为4.2的标量# 使用上面定义的函数生成数据集
features, labels = synthetic_data(true_w, true_b, 1000)  # 生成1000个数据点作为训练或测试样本,特征为X,标签为y(即labels)print('features:', features[0],'\nlabel:', labels[0])
d2l.set_figsize()
d2l.plt.scatter(features[:, (1)].detach().numpy(), labels.detach().numpy(), 1) 
# 这行代码也是从d2l库中调用的。它使用散点图来可视化特征和标签。#
# features[:, (1)].detach().numpy()选取了所有数据点的第二个特征(索引为1,因为索引是从0开始的)并转换为NumPy数组。
# #.detach()是PyTorch中的方法,用于从计算图中分离张量,这样张量就不会追踪其历史计算,这在进行绘图等操作时是很有用的。
# labels.detach().numpy()将标签转换为NumPy数组。这里的1表示散点的大小。
plt.show()

2. 各个函数解析

2.1 torch.normal()函数
normal(mean, std, *, generator=None, out=None)

参数说明

  • mean (Tensor): 每个输出元素的均值。它是一个张量,其中包含各个分布的均值。
  • std (Tensor): 每个输出元素的标准差。它也是一个张量,其中包含各个分布的标准差。
  • *: 表示后面的参数是关键字参数。
  • generator: 可选参数,一个伪随机数生成器。
  • out: 可选参数,输出张量。

注意事项

  1. meanstd的形状不必匹配,但它们的元素总数必须相同。如果形状不匹配,将使用mean的形状作为返回输出张量的形状。
  2. 如果std是一个CUDA tensor,该函数将同步其设备与CPU。
2.2 torch.matmul()函数
matmul(input, other, *, out=None) -> Tensor

参数说明:

  • input (Tensor): 输入张量。
  • other (Tensor): 另一个张量。
  • *: 表示后面的参数是关键字参数。
  • out (Tensor, optional): 可选参数,输出张量。

行为取决于张量的维度如下:

  • 如果两个张量都是一维的,返回点积(标量)。
  • 如果两个参数都是二维的,返回矩阵-矩阵乘积。
  • 如果第一个参数是一维的,而第二个参数是二维的,为了矩阵乘法,向其维度添加一个1。矩阵乘法之后,添加的维度被移除。
  • 如果第一个参数是二维的,而第二个参数是一维的,返回矩阵-向量乘积。
  • 如果两个参数都至少是一维的,并且至少有一个参数是N维的(其中N>2),则返回批处理矩阵乘法。如果第一个参数是一维的,为了批处理矩阵乘法,向其维度添加一个1,然后在批处理矩阵乘法之后移除它。如果第二个参数是一维的,为了批处理矩阵乘法,向其维度添加一个1,然后在批处理矩阵乘法之后移除它。非矩阵(即批处理)维度是广播的(因此必须可广播)。例如,如果input是一个(j × 1 × n × n)张量,而other是一个(k × n × n)张量,则out将是一个(j × k × n × n)张量。
2.3 d2l.plt.scatter()函数
scatter(x, y, s=None, c=None, marker=None, cmap=None, norm=None, vmin=None, vmax=None, alpha=None, linewidths=None, *, edgecolors=None, plotnonfinite=False, data=None, **kwargs)

参数说明:

  • x, y:这些是您要在散点图中表示的数据点的x和y坐标。

  • s:散点的面积,以像素为单位。这通常用于根据数据点的值进行大小调整。

  • c:用于颜色映射的单个值或数组,通常表示颜色或数据点的值。

  • marker:散点的形状。例如,'o'表示圆形,'.'表示点,','表示像素等。

  • cmap:颜色映射对象或名称。这决定了如何根据c参数的值映射颜色。

  • norm:用于映射到给定范围的归一化对象。这通常与cmap一起使用,以控制颜色映射的范围。

  • vmin, vmax:这些参数指定了归一化对象的下限和上限。它们与norm一起使用来控制颜色映射的范围。

  • alpha:散点的透明度。值范围从0(完全透明)到1(完全不透明)。

  • linewidths:用于绘制边框线的宽度。当不为None时,这会使散点变为带边框的圆圈。

  • edgecolors:用于边框线的颜色。这可以是单一的颜色或颜色数组,与数据点一一对应。

  • plotnonfinite:如果为True,则非有限数值的数据点将被绘制。默认为False。

  • data:提供给所有数据的原始数据的字典。这通常在传递给函数的数据不是直接参数时使用。

  • kwargs:其他关键字参数将传递给collections.PathCollection的构造函数,允许您自定义散点图的其他方面。例如,您可以指定label来在图例中标识这些点等。

三 ,总结

 这段代码的主要目的是生成数据集,并使用散点图可视化其特征和标签。通过这种方式,可以直观地观察到数据分布和特征之间的关系。此外,代码还演示了如何使用PyTorch进行矩阵运算和NumPy数组转换,以及如何使用d2l库中的函数进行绘图操作。

        之后我会更新,线性回归的读取数据集,初始化模型参数,定义模型,定义模型,定义损失函数,定义优化算法,训练等步骤。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/656956.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

18.通过telepresence调试部署在Kubernetes上的微服务

Telepresence简介 在微服务架构中,本地开发和调试往往是一项具有挑战性的任务。Telepresence 是一种强大的工具,使得开发者本地机器上开发微服务时能够与运行在 Kubernetes 集群中的其他服务无缝交互。本文将深入探讨 Telepresence 的架构、运行原理,并通过实际的案例演示其…

在mgre环境下配置ospf

实验规则如上图所示: 首先规划IP 配置缺省路由,使得公网全网可达 此处在r1上进行配置: 由此可以实现,公网已经全网可达: 其次,再分配 全连的MGRE网段,全连的MGRE网段每个路由器都是中心站点&…

基于springboot招聘信息管理系统源码和论文

在Internet高速发展的今天,我们生活的各个领域都涉及到计算机的应用,其中包括招聘信息管理系统的网络应用,在外国招聘信息管理系统已经是很普遍的方式,不过国内的线上管理系统可能还处于起步阶段。招聘信息管理系统具有招聘信息管…

配置nginx作为静态文件托管服务器

下载nginx windows上是个压缩包 解压后, 使用命令行输入 nginx 进行启动 nginx -s stop 进行停止 nginx -s status 查看状态 可以配置一下环境变量 主要是配置文件, windows的nginx配置文件在 conf文件夹下 在http标签下 添加如下配置 其他地方不用更改,保持原样即可, 以…

git diff查看比对两次不同时间点提交的异同

git diff查看比对两次不同时间点提交的异同 用 git diff命令: git diff commit-id-1 commit-id-2 不同commit-id在不同的时间点提交产生,因为也可以认为git diff是比对两个不同时间点的代码异同。 git diff比较不同commit版本的代码文件异同_git diff c…

2024年航海制造工程与海洋工程国际会议(ICNMEME2024)

一、【会议简介】 2024年航海制造工程与海洋工程国际会议(ICNMEME2024)旨在将研究人员、工程师、科学家和行业专业人士聚集在一个开放论坛上,展示他们在导航制造工程与海洋工程领域的激励研究和知识转移理念。然而,我们也认识到,工程师的未来…

代码随想录算法训练营第二十天 |654.最大二叉树,617.合并二叉树,700.二叉搜索树种的搜索,98.验证二叉搜索树(待补充)

654.最大二叉树 1、题目链接:力扣(LeetCode)官网 - 全球极客挚爱的技术成长平台 2、文章讲解:代码随想录 3、题目: 给定一个不含重复元素的整数数组。一个以此数组构建的最大二叉树定义如下: 二叉树的…

petalinux2022.2启动文件编译配置

安装必要运行库: sudo apt-get install iproute2 gawk python3 python sudo apt-get install build-essential gcc git make net-tools libncurses5-dev tftpd sudo apt-get install zlib1g-dev libssl-dev flex bison libselinux1 gnupg wget git-core diffstat sudo apt-ge…

react实现滚动到顶部组件

新建ScrollToTop.js import React, { useState, useEffect } from react; import ./ScrollToTop.css;function ScrollToTop() {const [isVisible, setIsVisible] useState(true);// Show button when page is scorlled upto given distanceconst toggleVisibility () > {…

处理Servlet生命周期事件

处理Servlet生命周期事件 接收关于 Servlet生命周期事件通知的类称为事件侦听器。这些侦听器实现Servlet API中定义的一个或多个servlet事件侦听器接口。侦听器类的逻辑分类如下: servlet请求侦听器Servlet上下文侦听器HTTP会话侦听器1. servlet请求侦听器 servlet请求侦听器…

专业138总分420+中国科学技术大学843信号与系统考研经验中科大电子信息通信

**今年中科大专业课843信号与系统138分,总分420顺利上岸,梦圆中科大,也是报了高考失利的遗憾,总结一下自己的复习经历,希望可以给大家提供参考。**首先,中科大843包括信号与系统,和数字信号处理…

网络隔离场景下访问 Pod 网络

接着上文 VPC网络架构下的网络上数据采集 介绍 考虑一个监控系统,它的数据采集 Agent 是以 daemonset 形式运行在物理机上的,它需要采集 Pod 的各种监控信息。现在很流行的一个监控信息是通过 Prometheus 提供指标信息。 一般来说,daemonset …

线性代数------矩阵的运算和逆矩阵

矩阵VS行列式 矩阵是一个数表,而行列式是一个具体的数; 矩阵是使用大写字母表示,行列式是使用类似绝对值的两个竖杠; 矩阵的行数可以不等于列数,但是行列式的行数等于列数; 1.矩阵的数乘就是矩阵的每个…

记录springboot bug

mybatis bug mapper 自动生成xml 产生错误 首先我这个bug十分奇怪,不管是报错,还是解决方法 首先,我还原我bug的过程 我首先要在 ordersMapper生成一个方法 本来是这样的方法 Mapper public interface OrdersMapper extends BaseMapper<Orders> {List<GoodsSales…

C语言——深入理解指针2

目录 1. 野指针1.1 野指针成因1.1.1 指针未初始化1.1.2 指针越界访问1.1.3 指针指向的空间释放 1.2 如何规避野指针1.2.1 指针初始化1.2.2 小心指针越界1.2.3 指针变量不再使用时&#xff0c;及时置NULL&#xff0c;指针使用之前检查有效性1.2.4 避免返回局部变量的地址 2. ass…

Linux:进度条的创建

目录 使用工具的简单介绍&#xff1a; \r &#xff1a; fflush &#xff1a; 倒计时的创建&#xff1a; 倒计时的工作原理&#xff1a; 进度条的创建&#xff1a; 不同场景下、打印任意长度的进度条&#xff1a; main .c procbor.c 测试效果&#xff1a; 使用工具…

YOLOv8实例分割实战:TensorRT加速部署

课程链接&#xff1a;https://edu.csdn.net/course/detail/39273 PyTorch版的YOLOv8支持高性能实时实例分割方法。 TensorRT是针对英伟达GPU的加速工具。 本课程讲述如何使用TensorRT对YOLOv8实例分割进行加速和部署&#xff0c;实测推理速度提高3倍以上。  采用改进后的t…

设计模式第2篇|策略模式

&#x1f680; 作者简介&#xff1a;程序员小豪&#xff0c;全栈工程师&#xff0c;热爱编程&#xff0c;曾就职于蔚来、腾讯&#xff0c;现就职于某互联网大厂&#xff0c;技术栈&#xff1a;Vue、React、Python、Java &#x1f388; 本文收录于小豪的前端系列专栏&#xff0c…

Vertica单点更改服务器ip

需求 服务器网段调整&#xff0c;将ip&#xff1a;192.168.40.190收回&#xff0c;使用ip&#xff1a;192.168.40.200 默认情况下&#xff0c;节点 IP 地址和导出 IP 地址配置相同的 IP 地址。导出地址是网络上有权访问其他 DBMS 系统的节点的 IP 地址。使用导出地址从 DBMS …

解锁Web3:数字未来的大门

随着科技的不断推进&#xff0c;我们正站在数字时代的新门槛上。Web3&#xff0c;作为互联网的下一个演进阶段&#xff0c;正在逐渐揭开数字未来的面纱。本文将深入探讨Web3的本质、对社会的影响以及在数字时代中所扮演的关键角色。 什么是Web3&#xff1f; Web3是互联网发展的…