【深度学习】手动完成线性回归!

🍊,大家好,我是小森( ﹡ˆoˆ﹡ )! 易编橙·终身成长社群创始团队嘉宾橙似锦计划领衔成员阿里云专家博主腾讯云内容共创官CSDN人工智能领域优质创作者 。

易编橙:一个帮助编程小伙伴少走弯路的终身成长社群!


大家好!今天我们将一起踏上一场探索深度学习的奇妙之旅,而我们的起点,就是线性回归这一经典而基础的算法。我将带大家从零开始,手动实现线性回归!

Pytorch完成线性回归

向前计算

对于pytorch中的一个tensor,如果设置它的属性 .requires_gradTrue,那么它将会追踪对于该张量的所有操作。或者可以理解为,这个tensor是一个参数,后续会被计算梯度,更新该参数。

计算过程

假设有以下条件(1/4表示求均值,xi中有4个数),使用torch完成其向前计算的过程

如果x为参数,需要对其进行梯度的计算和更新

那么,在最开始随机设置x的值的过程中,需要设置他的requires_grad属性为True,其默认值为False

import torch
x = torch.ones(2, 2, requires_grad=True)  # 设置requires_grad=True用来追踪其计算历史
print(x)tensor([[1., 1.],[1., 1.]], requires_grad=True)y = x+2
print(y)tensor([[3., 3.],[3., 3.]], grad_fn=<AddBackward0>)z = y*y*3  
print(x)tensor([[27., 27.],[27., 27.]], grad_fn=<MulBackward0>) out = z.mean() # 均值
print(out)tensor(27., grad_fn=<MeanBackward0>)

💦从上述代码可以看出:

  1. x的requires_grad属性为True

  2. 之后的每次计算都会修改其grad_fn属性,用来记录做过的操作

requires_grad和grad_fn

a = torch.randn(2, 2)
a = ((a * 3) / (a - 1))
print(a.requires_grad)  # False
a.requires_grad_(True)  # 修改
print(a.requires_grad)  # True
b = (a * a).sum()
print(b.grad_fn) # <SumBackward0 object at 0x4e2b14345d21>
with torch.no_gard():c = (a * a).sum()  #tensor(151.6830),此时c没有gard_fnprint(c.requires_grad) #False

注意:

为了防止跟踪历史记录(和使用内存),可以将代码块包装在with torch.no_grad():中。在评估模型时特别有用,因为模型可能具有requires_grad = True的可训练的参数,但是我们不需要在此过程中对他们进行梯度计算。

在机器学习和深度学习中,模型有训练模式评估模式

训练模式前向传播、计算损失、反向传播

  • 在训练过程中,模型中的某些层,例如Dropout层会在训练时随机丢弃一部分神经元的输出,以防止过拟合。

评估模式:模型被用来评估其在新数据上的性能,而不需要进行参数的更新;例如,Dropout层在评估模式下会停止丢弃神经元,以确保模型输出的一致性。

梯度计算

对于上面的计算过程,我们可以使用backward方法来进行反向传播,计算梯度💫

out.backward(),此时便能够求出导数$\frac{d out}{dx}$,调用x.gard能够获取导数值:

tensor([[4.5000, 4.5000],[4.5000, 4.5000]])

\frac{d(O)}{d(x_i)} = \frac{3}{2}(x_i+2)

 在x_i等于1时其值为4.5


注意:在输出为一个标量的情况下,我们可以调用输出tensorbackword() 方法,但是在数据是一个向量的时候,调用backward()的时候还需要传入其他参数。  

很多时候我们的损失函数都是一个标量,所以这里就不再介绍损失为向量的情况。

loss.backward()就是根据损失函数,对参数(requires_grad=True)的去计算他的梯度,并且把它累加保存到x.gard,此时还并未更新其梯度

  1. tensor.data:

    • 在tensor的require_grad=False,tensor.data和tensor等价

    • require_grad=True时,tensor.data仅仅是获取tensor中的数据

  2. tensor.numpy():

    • require_grad=True不能够直接转换,需要使用tensor.detach().numpy()

线性回归实现

我们使用一个自定义的数据,来使用torch实现一个简单的线性回归;

假设我们的基础模型就是y = wx+b,其中w和b均为参数,我们使用y = 3x+0.8来构造数据x、y,所以最后通过模型应该能够得出w和b应该分别接近3和0.8。

import torch
import numpy as np
from matplotlib import pyplot as pltx = torch.rand([50])    # 相当于就是y = 3*x + 0.8 这条直线的x的数量
y = 3*x + 0.8# 初始权重w和b都是设置为1
w = torch.rand(1,requires_grad=True)
b = torch.rand(1,requires_grad=True)def loss_fn(y,y_predict):loss = (y_predict-y).pow(2).mean()for i in [w,b]:# 每次反向传播前把梯度设为0if i.grad is not None:i.grad.data.zero_()# [i.grad.data.zero_() for i in [w,b] if i.grad is not None]loss.backward()return loss.datadef optimize(learning_rate):# print(w.grad.data,w.data,b.data)w.data -= learning_rate* w.grad.datab.data -= learning_rate* b.grad.datafor i in range(3000):# 预测值y_predict = x*w + b# 计算损失,把参数的梯度置为0,进行反向传播 loss = loss_fn(y,y_predict)if i%500 == 0:print(i,loss)# 更新参数w和boptimize(0.01)# 绘制图形,观察训练结束的预测值和真实值
predict =  x*w + b  
# 使用训练后的w和b计算预测值plt.scatter(x.data.numpy(), y.data.numpy(),c = "r")
plt.plot(x.data.numpy(), predict.data.numpy())
plt.show()print("w",w)
print("b",b)

输出结果:

0 tensor(2.0233)
500 tensor(0.0692)
1000 tensor(0.0201)
1500 tensor(0.0059)
2000 tensor(0.0017)
2500 tensor(0.0005)
w tensor([2.9586], requires_grad=True)
b tensor([0.8253], requires_grad=True)

💯可以看到已经很接近我们所预期的值了! 

💥下期我们再来动手使用Pytorch的API来创建线性回归!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/44088.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

现代码头装卸系统:技术创新与效率提升

引言 码头装卸系统在全球贸易和物流链中扮演着至关重要的角色。随着全球化进程的加快&#xff0c;国际贸易量不断增加&#xff0c;港口作为货物进出主要枢纽&#xff0c;其装卸效率直接影响到整个物流链的运作效率和成本。一个高效、现代化的码头装卸系统不仅能提高港口的货物处…

JVM是如何创建一个对象的?

哈喽&#xff0c;大家好&#x1f389;&#xff0c;我是世杰。 本文我为大家介绍面试官经常考察的**「Java对象创建流程」** 照例在开头留一些面试考察内容~~ 面试连环call Java对象创建的流程是什么样?JVM执行new关键字时都有哪些操作?JVM在频繁创建对象时&#xff0c;如何…

JVM垃圾回收器详解

垃圾回收器 JDK 默认垃圾收集器&#xff08;使用 java -XX:PrintCommandLineFlags -version 命令查看&#xff09;&#xff1a; JDK 8&#xff1a;Parallel Scavenge&#xff08;新生代&#xff09; Parallel Old&#xff08;老年代&#xff09; JDK 9 ~ JDK20: G1 堆内存中…

CVE-2024-6387Open SSH漏洞彻底解决举措(含踩坑内容)

一、漏洞名称 OpenSSH 远程代码执行漏洞(CVE-2024-6387) 二、漏洞概述 Open SSH是基于SSH协议的安全网络通信工具&#xff0c;广泛应用于远程服务器管理、加密文件传输、端口转发、远程控制等多个领域。近日被爆出存在一个远程代码执行漏洞&#xff0c;由于Open SSH服务器端…

2024年夏季德旺杯数学素养水平测试

此为小高组的测试&#xff0c;不过德旺杯主要看获奖情况&#xff0c;选择学员入营

基于考研题库小程序V2.0实现倒计时功能板块和超时判错功能

V2.0 需求沟通 需求分析 计时模块 3.1.1、功能描述←计时模块用于做题过程中对每一题的作答进行30秒倒计时&#xff0c;超时直接判错&#xff0c;同时将总用时显示在界面上;记录每次做题的总用时。 3.1.2、接口描述←与判定模块的接口为超时判定&#xff0c;若单题用时超过 …

人工智能和机器学习 (复旦大学计算机科学与技术实践工作站)20240703(上午场)人工智能初步、mind+人脸识别

前言 在这个科技日新月异的时代&#xff0c;人工智能&#xff08;AI&#xff09;已经逐渐渗透到我们生活的方方面面&#xff0c;从智能家居到自动驾驶&#xff0c;无一不彰显着AI的强大潜力。而人脸识别技术作为AI领域的一项重要应用&#xff0c;更是以其高效、便捷的特点受到了…

萤石揽获2024葵花奖17项重磅大奖 登顶荣誉之巅

7月9日&#xff0c;第八届葵花奖智能家居评选颁奖盛典在中国建博会&#xff08;广州&#xff09;广交会展馆隆重举行。萤石共斩获横跨智能锁、智能家居摄像机、智能清洁、全屋智能以及物联网云平台等多个领域的17项大奖&#xff0c;创下行业最多记录&#xff0c;并问鼎金至尊奖…

记录|C#安装+HslCommunication安装

记录线索 前言一、C#安装1.社区版下载2.VS2022界面设置 二、HslCommunication安装1.前提2.安装3.相关文件【重点】 更新记录 前言 初心是为了下次到新的电脑上安装VS2022做C#上机位项目时能快速安装成功。 一、C#安装 1.社区版下载 Step1. 直接点击VS2022&#xff0c;跳转下…

华为机试HJ106字符逆序

华为机试HJ106字符逆序 题目&#xff1a; 想法&#xff1a; 将输入的字符串倒叙输出即可 input_str input()print(input_str[::-1])

二十年大数据到 AI,图灵奖得主眼中的数据库因果循环

最近&#xff0c;MIT 教授 Michael Stonebraker 和 CMU 教授 Andrew Pavlo (Andy) 教授联合发表了一篇数据库论文。Michael Stonebraker 80 高龄&#xff0c;是数据库行业唯一在世的图灵奖得主&#xff0c;Andy 则是业界少壮派里的最大 KOL。 一老一少&#xff0c;当今数据库届…

MVC架构

MVC架构 MVC架构在软件开发中通常指的是一种设计模式&#xff0c;它将应用程序分为三个主要组成部分&#xff1a;模型&#xff08;Model&#xff09;、视图&#xff08;View&#xff09;和控制器&#xff08;Controller&#xff09;。这种分层结构有助于组织代码&#xff0c;使…

钡铼技术有限公司S270用于智慧物流中心货物追踪与调度

钡铼技术有限公司的第四代S270是一款专为智慧物流中心设计的工业级4G远程遥测终端RTU&#xff0c;其强大的功能和灵活性使其成为货物追踪与调度的理想选择。 技术规格和功能特点 钡铼S270支持多种通信协议&#xff0c;包括短信和MQTT&#xff0c;这使得它能够与各种云平台如华…

图论---匈牙利算法求二分图最大匹配的实现

开始编程前分析设计思路和程序的整体的框架&#xff0c;以及作为数学问题的性质&#xff1a; 程序流程图&#xff1a; 数学原理&#xff1a; 求解二分图最大匹配问题的算法&#xff0c;寻找一个边的子集&#xff0c;使得每个左部点都与右部点相连&#xff0c;并且没有两条边共享…

【STM32学习】cubemx配置,串口的使用,串口发送接收函数使用,以及串口重定义、使用printf发送

1、串口的基本配置 选择USART1&#xff0c;选择异步通信&#xff0c;设置波特率 选择后&#xff0c;会在右边点亮串口 串口引脚是用来与其他设备通信的&#xff0c;如在程序中打印发送信息&#xff0c;电脑上打开串口助手&#xff0c;就会收到信息。 串口的发送接收&#xff0…

Java - JDK17语法新增特性(如果想知道Java - JDK17语法新增常见的特性的知识点,那么只看这一篇就足够了!)

前言&#xff1a;Java在2021年发布了最新的长期支持版本&#xff1a;JDK 17。这个版本引入了许多新的语法特性&#xff0c;提升了开发效率和代码可读性。本文将简要介绍一些常见的新特性&#xff0c;帮助开发者快速掌握并应用于实际开发中。 ✨✨✨这里是秋刀鱼不做梦的BLOG ✨…

相机光学(三十)——N5-N7-N8中性灰

GTI可提供N5/N7/N8中性灰涂料&#xff0c;用于不同的看色环境&#xff0c;N5/N7/N8代表深中浅不同的灰色程度&#xff0c;在成像、工业、印刷行业中&#xff0c;分别对周围观察环境有一定的要求&#xff0c;也出台了相应的标准文件&#xff0c;客户可以根据实际使用环境进行选择…

QT开发积累——qt中的注释和多行注释的几种方式,函数方法注释生成

目录 引出qt中的注释和多行注释方法的注释生成 总结日积月累&#xff0c;开发集锦方法参数加const和不加const的区别方法加static和不加static的区别Qt遍历list提高效率显示函数的调用使用&与不使用&qt方法的参数中使用&与不使用&除法的一个坑 项目创建相关新建…

交通气象站:保障道路安全的智慧之眼

随着社会的快速发展&#xff0c;交通运输日益繁忙&#xff0c;道路安全成为公众关注的焦点。在这个背景下&#xff0c;交通气象站作为保障道路安全的重要设施&#xff0c;正发挥着越来越重要的作用。它们不仅为交通管理部门提供及时、准确的气象信息&#xff0c;也为广大驾驶员…

高阶面试-dubbo的学习

SPI机制 SPI&#xff0c;service provider interface&#xff0c;服务发现机制&#xff0c;其实就是把接口实现类的全限定名配置在文件里面&#xff0c;然后通过加载器ServiceLoader去读取配置加载实现类&#xff0c;比如说数据库驱动&#xff0c;我们把mysql的jar包放到项目的…