深度学习——回归实战

线性回归:

线性:自变量和应变量之间是线性关系,如:y = wx +b

回归:拟合一条曲线,使真实值和拟合值差距尽可能小

目标:求解参数w和b         所用算法:梯度下降算法

梯度下降:向着梯度方向(下降最快的方向)走一步,不停迭代

梯度下降过程:

训练循环(核心步骤):

  • 前向传播
    • 将训练集中的一批数据(一个批次)输入到模型中,通过模型的各层计算得到输出。
    • 这个过程中,数据按照模型的架构顺序依次通过各层,每层根据其权重和偏置对数据进行计算,如在全连接层中,计算是通过矩阵乘法和加法实现的。
  • 计算损失
    • 使用定义好的损失函数,计算模型预测输出与该批次数据真实标签之间的损失。
    • 损失值反映了当前模型在这一批次数据上的预测误差大小。
  • 反向传播
    • 根据计算得到的损失,通过链式法则从最后一层开始,逐层计算损失对每个模型参数(权重和偏置)的梯度。
    • 反向传播算法使得模型能够知道每个参数对损失的影响程度,从而为参数更新提供依据。
  • 更新参数
    • 使用选择的优化器,根据计算得到的梯度更新模型的参数。例如,在使用 SGD 优化器时,参数更新公式为:参数 = 参数 - 学习率 * 梯度。
    • 更新后的参数将用于下一次迭代的前向传播,通过不断重复这个过程,模型的参数逐渐调整,使得损失函数不断减小。
      import torch #深度学习框架
      import matplotlib.pyplot as plt #画图
      import random #随机def create_data(w, b, data_num):   #生成数据x = torch.normal(0, 1, (data_num, len(w)))   #平均数为0,方差为1,长度为data_num,宽度为len(w)y = torch.matmul(x, w) + b #通过矩阵乘法将输入数据x与权重w相乘,然后加上偏置项b,生成新的输出ynoise = torch.normal(0, 0.01, y.shape)   #噪声要加到y上y+= noise      #模拟真实数据的不确定性、防止模型过拟合return x, ynum = 500   # 生成的数据数量true_w = torch.tensor([8.1, 2, 2, 4])  # 真实的权重
      true_b = torch.tensor(1.1)   # 真实的偏置X, Y = create_data(true_w, true_b, num)  # 生成数据plt.scatter(X[:, 0], Y, 1)  #对x张量进行切片,选择所有行、第一列
      plt.show()def data_provider(data, label, batchsize):   #每次访问这个函数,就能提供一批数据,传入参数依次为:数据、标签、步长length = len(label)   # 获取标签数据的长度,由于数据和标签一一对应,以此代表整个数据集的长度indices = list(range(length))  #  创建一个从0到length - 1的索引列表,每个索引对应数据集中的一个样本,用于后续操作random.shuffle(indices)  # 对索引列表进行随机打乱,确保每次取数据批次时是随机顺序,避免数据顺序依赖,增强模型训练效果# 按照指定的批量大小batchsize遍历整个数据集,每次取出一个批次的数据范围for each in range(0, length, batchsize):get_indices = indices[each: each+batchsize]  # 从打乱后的索引列表中取出当前批次对应的索引范围get_data = data[get_indices]   # 根据取出的索引范围从数据张量data中获取当前批次的数据get_label = label[get_indices]  # 根据取出的索引范围从标签张量label中获取当前批次对应的标签yield get_data, get_label  # 使用yield关键字返回当前批次的数据和标签,使函数成为生成器,下次调用继续返回下一批次batchsize = 16  #步长设置为16
      # for batch_x, batch_y in data_provider(X, Y, batchsize):
      #     print(batch_x, batch_y)
      #     #break# 定义函数fun,用于根据输入数据x、权重w和偏置b进行线性变换计算,得到预测输出
      def fun(x, w, b):pred_y = torch.matmul(x, w) + b   # 使用torch.matmul对输入数据x和权重w进行矩阵乘法运算,然后加上偏置b,得到预测输出pred_yreturn pred_y# 定义函数maeloss,用于计算预测值pre_y和真实值y之间的平均绝对误差(MAE)损失
      def maeloss(pre_y, y):return torch.sum(abs(pre_y-y))/len(y)  # 先计算预测值和真实值之间差值的绝对值,再对所有差值的绝对值求和,最后除以数据数量len(y),得到平均绝对误差作为损失值并返回# 定义函数sgd,实现随机梯度下降算法,用于更新模型参数
      def sgd(paras, lr):          #随机梯度下降,更新参数# 使用torch.no_grad()上下文管理器,在这个范围内的操作不会进行梯度计算,因为参数更新阶段不需要对更新操作本身计算梯度with torch.no_grad():  #属于这句代码的部分,不计算梯度for para in paras:# 遍历要更新的参数列表paras中的每一个参数# 根据随机梯度下降算法规则,将当前参数para减去其梯度para.grad与学习率lr的乘积,实现参数更新(注意要用 -= 操作符进行原位更新)para -= para.grad* lr  #不能写成  para = para - para.grad*lrpara.grad.zero_()     #使用过的梯度,归0,避免下一次迭代时梯度累积,导致参数更新错误lr = 0.03  #学习率
      w_0 = torch.normal(0, 0.01, true_w.shape, requires_grad=True)   # 使用torch.normal函数按照正态分布来初始化权重参数w_0,其中均值为0,方差为0.01,形状与之前定义的真实权重true_w保持一致,并且设置requires_grad为True,意味着这个参数在后续的计算中需要跟踪计算梯度,以便进行自动求导来更新它
      b_0 = torch.tensor(0.01, requires_grad=True)  # 初始化偏置参数b_0,将其设置为值是0.01的标量张量,同时设置requires_grad为True,这样该参数就能参与到梯度计算以及后续的参数更新过程中
      print(w_0, b_0)epochs  = 50#训练多少轮# 按训练轮数epochs循环,每轮训练模型
      for epoch in range(epochs):data_loss = 0   # 初始化本轮累计损失为0for batch_x, batch_y in data_provider(X, Y, batchsize):   # 按批次获取数据,循环处理每批次pred_y = fun(batch_x, w_0, b_0)   # 用当前参数预测本批次数据,得预测值loss = maeloss(pred_y, batch_y)  # 计算预测值与真实值的损失loss.backward()  # 反向传播求梯度sgd([w_0, b_0], lr)  # 用sgd更新模型参数data_loss += loss  # 累加本批次损失到本轮累计损失print("epoch %03d: loss: %.6f"%(epoch, data_loss))  # 打印本轮轮数和累计损失,观察训练情况print("真实的函数值是", true_w, true_b)
      print("训练得到的参数值是", w_0, b_0)idx = 0  # 初始化一个索引变量idx为0,这个索引通常用于选择数据张量X中的某一列数据,后续可能用于可视化等相关操作
      plt.plot(X[:, idx].detach().numpy(), X[:, idx].detach().numpy()*w_0[idx].detach().numpy()+b_0.detach().numpy())  # 绘制拟合直线,取X第idx列数据转numpy作横坐标,按线性回归公式用训练参数算纵坐标来绘制
      plt.scatter(X[:, idx], Y, 1)  # 绘制散点图,以输入数据张量X的第idx列数据作为横坐标,以对应的输出数据Y作为纵坐标,展示原始数据的分布情况
      plt.show()
      Tips:
      1、yield get_data, get_label当函数执行到 yield 语句时,函数会暂停执行,将 yield 后面的值返回给调用者,但函数并没有结束。下次调用这个函数时,它会从上次暂停的地方继续执行,直到遇到下一个 yield 或者函数结束。这意味着在一个生成器函数中,可以通过多个 yield 语句多次返回不同的值。就像在 data_provider 函数中,每次调用会返回一个新的数据批次和标签批次,直到所有批次都返回完。
      2、para.grad.zero_() 如果我们不清零这个梯度,在第二次训练批次进行反向传播时,新计算出来的梯度会和第一次遗留下来的梯度相加。就好像你在走迷宫,第一次得到的指示(梯度)是向左走三步,但是你没记住这个指示,第二次又得到一个指示(新的梯度)是向右走两步,但是你把两次的指示混在一起,变成了向左走一步(假设梯度相加的情况),这样就会让你的方向(参数更新方向)变得混乱。
      3、plt.plot(X[:, idx].detach().numpy(), X[:, idx].detach().numpy()*w_0[idx].detach().numpy()+b_0.detach().numpy())①绘制一条直线,用于表示拟合的线性关系(在二维平面上展示线性回归拟合情况)。②首先从输入数据张量X中取出第idx列数据(通过X[:, idx]),并将其转换为numpy数组(使用detach().numpy()方法,目的是从计算图中分离出来并转为numpy格式方便绘图),作为横坐标。然后根据线性回归的公式y = w * x + b,计算对应的纵坐标,这里使用当前训练得到的权重w_0的第idx个元素(w_0[idx])和偏置b_0,同样转换为numpy数组后参与计算,以此绘制出拟合的直线。
      

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/65438.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

单片机-串转并-74HC595芯片

1、74HC595芯片介绍 74HC595 是一个 8 位串行输入、并行输出的位移缓存器,其中并行输出为三态输出(即高电平、低电平和高阻抗)。 15 和 1 到 7 脚 QA--QH:并行数据输出 9 脚 QH 非:串行数据输出 10 脚 SCLK 非&#x…

探索AI在地质科研绘图中的应用:ChatGPT与Midjourney绘图流程与效果对比

文章目录 个人感受一、AI绘图流程1.1 Midjourney(1)环境配置(2)生成prompt(3)完善prompt(4)开始绘图(5)后处理 1.2 ChatGPT不合理的出图结果解决方案 二、主题…

【微服务】6、限流 熔断

线程隔离与容错处理 本视频主要讲解了在购物车业务中,因商品微服务响应慢导致的问题及解决方案,重点介绍了线程隔离后查询购物车业务不可用的情况,以及如何通过Fallback逻辑进行缓解,包括配置Feign调用为簇点资源、添加Fallback逻…

25年01月HarmonyOS应用基础认证最新题库

判断题 “一次开发,多端部署”指的是一个工程,一次开发上架,多端按需部署。为了实现这一目的,HarmonyOS提供了多端开发环境,多端开发能力以及多端分发机制。 答案:正确 《鸿蒙生态应用开发白皮书》全面阐释…

ELK实战(最详细)

一、什么是ELK ELK是三个产品的简称:ElasticSearch(简称ES) 、Logstash 、Kibana 。其中: ElasticSearch:是一个开源分布式搜索引擎Logstash :是一个数据收集引擎,支持日志搜集、分析、过滤,支持大量数据…

Dubbo-笔记随记一

一、实战 1 . Springboot整合 1.1 服务提供者 1.1.1 依赖 <dependency><groupId>org.apache.dubbo</groupId><artifactId>dubbo-spring-boot-starter</artifactId><version>3.2.10</version></dependency><dependency&g…

ETCD渗透利用指南

目录 未指定使用put操作报错 未指定操作版本使用get报错 首先etcd分为两个版本v2和v3&#xff0c;不同的API结果无论是访问URL还是使用etcdctl进行通信&#xff0c;都会导致问题&#xff0c;例如使用etcdctl和v3进行通信&#xff0c;如果没有实名ETCDCTL_API3指定API版本会直接…

使用VUE3创建个人静态主页

使用VUE3创建个人静态主页 &#x1f31f; 前言&#x1f60e;体验&#x1f528; 具体实现✨ 核心功能&#x1f3d7;️ 项目结构&#x1f680; 用这个项目部署 Git Page &#x1f4d6; 参考 &#x1f31f; 前言 作为开发者或者内容创作者&#xff0c;我们经常需要创建静态网页&a…

llm大模型学习

llm大模型 混合专家模型&#xff08;MoE&#xff09;MoE结构路由router专家expertSwitch Transformer的典型MOE模型最后MoE总结 混合专家模型&#xff08;MoE&#xff09; 模型规模是提升LLM大语言模型性能的关键因素&#xff0c;但也会增加计算成本。Mixture of Experts (MoE…

Linux入门攻坚——43、keepalived入门-1

Linux Cluster&#xff08;Linux集群的类型&#xff09;&#xff1a;LB、HA、HPC&#xff0c;分别是负载均衡集群、高可用性集群、高性能集群。 LB&#xff1a;lvs&#xff0c;nginx HA&#xff1a;keepalived&#xff0c;heartbeat&#xff0c;corosync&#xff0c;cman HP&am…

YOLOv8/YOLOv11改进 添加CBAM、GAM、SimAM、EMA、CAA、ECA、CA等多种注意力机制

目录 前言 CBAM GAM SimAM EMA CAA ECA CA 添加方法 YAML文件添加 使用改进训练 前言 本篇文章将为大家介绍Ultralytics/YOLOv8/YOLOv11中常用注意力机制的添加&#xff0c;可以满足一些简单的涨点需求。本文仅写方法&#xff0c;原理不多讲解&#xff0c;需要可跳…

【C语言】_指针与数组

目录 1. 数组名的含义 1.1 数组名与数组首元素的地址的联系 1.3 数组名与首元素地址相异的情况 2. 使用指针访问数组 3. 一维数组传参的本质 3.1 代码示例1&#xff1a;函数体内计算sz&#xff08;sz不作实参传递&#xff09; 3.2 代码示例2&#xff1a;sz作为实参传递 3…

解决“KEIL5软件模拟仿真无法打印浮点数”之问题

在没有外部硬件支持时&#xff0c;我们会使用KEIL5软件模拟仿真&#xff0c;这是是仿真必须要掌握的技巧。 1、点击“Project”&#xff0c;然后点击“Options for target 项目名字”&#xff0c;点击“Device”,选择CPU型号。 2、点击“OK” 3、点击“Target”,勾选“Use Mi…

donet (MVC)webAPI 的接受json 的操作

直接用对象来进行接收&#xff0c;这个方法还不错的。 public class BangdingWeiguiJiluController : ApiController{/// <summary>/// Json数据录入错误信息/// </summary>/// <param name"WeiguiInfos"></param>/// <returns></r…

设计模式与游戏完美开发(3)

更多内容可以浏览本人博客&#xff1a;https://azureblog.cn/ &#x1f60a; 该文章主体内容来自《设计模式与游戏完美开发》—蔡升达 第二篇 基础系统 第五章 获取游戏服务的唯一对象——单例模式&#xff08;Singleton&#xff09; 游戏实现中的唯一对象 在游戏开发过程中…

pygame飞机大战

飞机大战 1.main类2.配置类3.游戏主类4.游戏资源类5.资源下载6.游戏效果 1.main类 启动游戏。 from MainWindow import MainWindow if __name__ __main__:appMainWindow()app.run()2.配置类 该类主要存放游戏的各种设置参数。 #窗口尺寸 #窗口尺寸 import random import p…

如何让用户在网页中填写PDF表格?

在网页中让用户直接填写PDF表格&#xff0c;可以大大简化填写、打印、扫描和提交表单的流程。通过使用复选框、按钮和列表等交互元素&#xff0c;PDF表格不仅让填写过程更高效&#xff0c;还能方便地在电脑或移动设备上访问和提交数据。 以下是在浏览器中显示可填写PDF表单的四…

ThinkPHP 8高效构建Web应用-获取请求对象

【图书介绍】《ThinkPHP 8高效构建Web应用》-CSDN博客 《2025新书 ThinkPHP 8高效构建Web应用 编程与应用开发丛书 夏磊 清华大学出版社教材书籍 9787302678236 ThinkPHP 8高效构建Web应用》【摘要 书评 试读】- 京东图书 使用VS Code开发ThinkPHP项目-CSDN博客 编程与应用开…

23.行号没有了怎么办 滚动条没有了怎么办 C#例子

新建了一个C#项目&#xff0c;发现行号没有了。 想把行号调出来&#xff0c;打开项目&#xff0c;选择工具>选项> 如下图&#xff0c;在文本编辑器的C#里有一个行号&#xff0c;打开就可以了 滚动条在这里&#xff1a;

30天开发操作系统 第 12 天 -- 定时器

前言 定时器(Timer)对于操作系统非常重要。它在原理上却很简单&#xff0c;只是每隔一段时间(比如0.01秒)就发送一个中断信号给CPU。幸亏有了定时器&#xff0c;CPU才不用辛苦地去计量时间。……如果没有定时器会怎么样呢?让我们想象一下吧。 假如CPU看不到定时器而仍想计量时…