【深度学习入门篇 ②】Pytorch完成线性回归!

🍊,大家好,我是小森( ﹡ˆoˆ﹡ )! 易编橙·终身成长社群创始团队嘉宾橙似锦计划领衔成员阿里云专家博主腾讯云内容共创官CSDN人工智能领域优质创作者 。

易编橙:一个帮助编程小伙伴少走弯路的终身成长社群!


上一部分我们自己通过torch的方法完成反向传播和参数更新,在Pytorch中预设了一些更加灵活简单的对象,让我们来构造模型、定义损失,优化损失等;那么接下来,我们一起来了解一下其中常用的API!

nn.Module

nn.Module 是 PyTorch 框架中用于构建所有神经网络模型的基类。在 PyTorch 中,几乎所有的神经网络模块(如层、卷积层、池化层、全连接层等)都继承自 nn.Module。这个类提供了构建复杂网络所需的基本功能,如参数管理、模块嵌套、模型的前向传播等。

当我们自定义网络的时候,有两个方法需要特别注意:

  1. __init__需要调用super方法,继承父类的属性和方法

  2. farward方法必须实现,用来定义我们的网络的向前计算的过程

用前面的y = wx+b的模型举例如下:

from torch import nn
class Lr(nn.Module):def __init__(self):super(Lr, self).__init__()  # 继承父类init的参数self.linear = nn.Linear(1, 1) def forward(self, x):out = self.linear(x)return out
  •  nn.Linear为torch预定义好的线性模型,也被称为全链接层,传入的参数为输入的数量,输出的数量(in_features, out_features),是不算(batch_size的列数)
  • nn.Module定义了__call__方法,实现的就是调用forward方法,即Lr的实例,能够直接被传入参数调用,实际上调用的是forward方法并传入参数
  • __init__方法里面的内容就是类创建的时候,跟着自动创建的部分。
  • 与之对应的就是__del__方法,在对象被销毁时执行一些清理操作。
# 实例化模型
model = Lr()
# 传入数据,计算结果
predict = model(x)

优化器类

优化器(optimizer),可以理解为torch为我们封装的用来进行更新参数的方法,比如常见的随机梯度下降(stochastic gradient descent,SGD)。

优化器类都是由torch.optim提供的,例如

  1. torch.optim.SGD(参数,学习率)

  2. torch.optim.Adam(参数,学习率)

注意:

  • 参数可以使用model.parameters()来获取,获取模型中所有requires_grad=True的参数


optimizer = optim.SGD(model.parameters(), lr=1e-3) # 实例化
optimizer.zero_grad() # 梯度置为0
loss.backward() #  计算梯度
optimizer.step()  #  更新参数的值

损失函数

  1. 均方误差:nn.MSELoss(),常用于回归问题

  2. 交叉熵损失:nn.CrossEntropyLoss(),常用于分类问题

model = Lr() # 实例化模型
criterion = nn.MSELoss() # 实例化损失函数
optimizer = optim.SGD(model.parameters(), lr=1e-3) #  实例化优化器类
for i in range(100):y_predict = model(x_true) # 预测值loss = criterion(y_true,y_predict) # 调用损失函数传入真实值和预测值,得到损失optimizer.zero_grad() loss.backward() # 计算梯度optimizer.step()  # 更新参数的值

线性回归代码!

import torch
from torch import nn
from torch import optim
import numpy as np
from matplotlib import pyplot as pltx = torch.rand([50,1])
y = x*3 + 0.8# 自定义线性回归模型
class Lr(nn.Module):def __init__(self):super(Lr,self).__init__()self.linear = nn.Linear(1,1)def forward(self, x):    # 模型的传播过程out = self.linear(x)return out# 实例化模型,loss,和优化器
model = Lr()
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=1e-3)
#训练模型
for i in range(30000):out = model(x) loss = criterion(y,out) optimizer.zero_grad() loss.backward() optimizer.step()  if (i+1) % 20 == 0:print('Epoch[{}/{}], loss: {:.6f}'.format(i,30000,loss.data))# 模型评估模式,之前说过的
model.eval() 
predict = model(x)
predict = predict.data.numpy()
plt.scatter(x.data.numpy(),y.data.numpy(),c="r")
plt.plot(x.data.numpy(),predict)
plt.show()

  • 可以看出经过30000次训练后(相当于看书,一遍遍的回归学习),基本就可以拟合预期直线了

GPU上运行代码

当模型太大,或者参数太多的情况下,为了加快训练速度,经常会使用GPU来进行训练

此时我们的代码需要稍作调整:

1.判断GPU是否可用torch.cuda.is_available()

torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
device(type='cuda', index=0)  # 使用GPU

2.把模型参数和input数据转化为cuda的支持类型

model.to(device)
x.to(device)

3.在GPU上计算结果也为cuda的数据类型,需要转化为numpy或者torch的cpu的tensor类型

predict = predict.cpu().detach().numpy() 
  • predict.cpu() predict张量从可能的其他设备(如GPU)移动到CPU上
  • predict.detach() .detach()方法会返回一个新的张量,这个张量不再与原始计算图相关联,即它不会参与后续的梯度计算。
  • .numpy()方法将张量转换为NumPy数组。

 GPU代码:

import torch
from torch import nn
from torch import optim
import numpy as np
from matplotlib import pyplot as plt
import timex = torch.rand([50,1])
y = x*3 + 0.8class Lr(nn.Module):def __init__(self):super(Lr,self).__init__()self.linear = nn.Linear(1,1)def forward(self, x):out = self.linear(x)return outdevice = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
x,y = x.to(device),y.to(device)model = Lr().to(device)
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=1e-3)for i in range(300):out = model(x)loss = criterion(y,out)optimizer.zero_grad()loss.backward()optimizer.step()if (i+1) % 20 == 0:print('Epoch[{}/{}], loss: {:.6f}'.format(i,30000,loss.data))model.eval() 
predict = model(x)
predict = predict.cpu().detach().numpy() 
plt.scatter(x.cpu().data.numpy(),y.cpu().data.numpy(),c="r")
plt.plot(x.cpu().data.numpy(),predict,)
plt.show()

💯常见的优化算法

在大多数情况下,我们关注的是最小化损失函数,因为它衡量了模型预测与真实标签之间的差异。

梯度下降算法(batch gradient descent BGD)

每次迭代都需要把所有样本都送入,这样的好处是每次迭代都顾及了全部的样本,做的是全局最优化,但是有可能达到局部最优。

随机梯度下降法 (Stochastic gradient descent SGD)

针对梯度下降算法训练速度过慢的缺点,提出了随机梯度下降算法,随机梯度下降算法算法是从样本中随机抽出一组,训练后按梯度更新一次,然后再抽取一组,再更新一次,在样本量及其大的情况下,可能不用训练完所有的样本就可以获得一个损失值在可接受范围之内的模型了。

小批量梯度下降 (Mini-batch gradient descent MBGD)

SGD相对来说要快很多,但是也有存在问题,由于单个样本的训练可能会带来很多噪声,使得SGD并不是每次迭代都向着整体最优化方向,因此在刚开始训练时可能收敛得很快,但是训练一段时间后就会变得很慢。在此基础上又提出了小批量梯度下降法,它是每次从样本中随机抽取一小批进行训练,而不是一组,这样即保证了效果又保证的速度。

AdaGrad

AdaGrad算法就是将每一个参数的每一次迭代的梯度取平方累加后在开方,用全局学习率除以这个数,作为学习率的动态更新,从而达到自适应学习率的效果

Adam

Adam(Adaptive Moment Estimation)算法是将Momentum算法和RMSProp算法结合起来使用的一种算法,能够达到防止梯度的摆幅多大,同时还能够加开收敛速度。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/44806.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

vue3中antd上传图片组件及回显

实现效果&#xff1a; 调用后端接口后&#xff0c;后端返回的数据&#xff1a; 1.在项目components/base下新建UploadNew.vue文件&#xff08;上传图片公共组件&#xff09; <template><div class"clearfix"><a-uploadv-model:file-list"fileL…

发挥储能系统领域优势,海博思创坚定不移推动能源消费革命

随着新发展理念的深入贯彻&#xff0c;我国正全面落实“双碳”目标任务&#xff0c;通过积极转变能源消费方式&#xff0c;大幅提升能源利用效率&#xff0c;实现了以年均约3.3%的能源消费增长支撑了年均超过6%的国民经济增长。这一成就的背后&#xff0c;是我国能源结构的持续…

龙蜥Anolis OS基于开源项目制作openssh 9.8p1 rpm包 —— 筑梦之路

环境信息 制作过程和centos 7几乎没有区别&#xff0c;此处就不再赘述。 CentOS 7基于开源项目制作openssh9.8p1 rpm二进制包修复安全漏洞CVE-2024-6387 —— 筑梦之路_cve-2024-6387修复-CSDN博客 制作成果展示 tree RPMS/ RPMS/ └── x86_64├── openssh-9.8p1-1.an7.…

springboot3整合SpringSecurity实现登录校验与权限认证(万字超详细讲解)

目录 身份认证&#xff1a; 1、创建一个spring boot项目&#xff0c;并导入一些初始依赖&#xff1a; 2、由于我们加入了spring-boot-starter-security的依赖&#xff0c;所以security就会自动生效了。这时直接编写一个controller控制器&#xff0c;并编写一个接口进行测试&…

【面试题】防火墙的部署模式有哪些?

防火墙的部署模式多种多样&#xff0c;每种模式都有其特定的应用场景和优缺点。以下是防火墙的主要部署模式&#xff1a; 一、按工作模式分类 路由模式 定义&#xff1a;当防火墙位于内部网络和外部网络之间时&#xff0c;需要将防火墙与内部网络、外部网络以及DMZ&#xff0…

事半功倍大法!财务数据API让企业工作智能化

在快速变化的商业环境中&#xff0c;财务管理的自动化已成为企业提升效率、降低成本和增强决策质量的关键。财务API&#xff0c;作为现代企业架构中不可或缺的一部分&#xff0c;提供了一种强大的工具&#xff0c;使得企业能够无缝集成各种财务服务和应用&#xff0c;实现数据的…

数据分析理论

数据分析的概念 数据分析是指通过恰当的统计方法和分析手段&#xff0c;对数据进行收集汇总&#xff0c;并进行加工处理。对处理过后的有效数据进行分析&#xff0c;发现存在的问题&#xff0c;制定可行的方案、从而帮助人们采取更科学的行动 数据分析4个层次 著名咨询公司Gart…

【WEB前端2024】3D智体编程:乔布斯3D纪念馆-第59-agent自动获取喵星人资讯并保存至云文件夹

【WEB前端2024】3D智体编程&#xff1a;乔布斯3D纪念馆-第59-agent自动获取喵星人资讯并保存至云文件夹 使用dtns.network德塔世界&#xff08;开源的智体世界引擎&#xff09;&#xff0c;策划和设计《乔布斯超大型的开源3D纪念馆》的系列教程。dtns.network是一款主要由Java…

Oracle执行一条SQL的内部过程

一、SQL语句根据其功能主要可以分为以下几大类&#xff1a; 1. 数据查询语言&#xff08;DQL, Data Query Language&#xff09; 功能&#xff1a;用于从数据库中检索数据&#xff0c;常用于查询表中的记录。基本结构&#xff1a;主要由SELECT子句、FROM子句、WHERE子句等组成…

stm32h743 阿波罗v2 NetXduo http server CubeIDE+CubeMX

在这边要设置mpu的大小&#xff0c;要用到http server&#xff0c;mpu得设置的大一些 我是这么设置的&#xff0c;做一个参考 同样&#xff0c;在FLASH.ld里面也要对应修改&#xff0c;SECTIONS里增加.tcp_sec和 .nx_data两个区&#xff0c;我们用ram_d2区域去做网络&#xff…

生产英特尔CPU处理器繁忙的一天

早晨&#xff1a;准备与检查 7:00 AM - 起床与准备 工厂员工们早早起床&#xff0c;快速洗漱并享用早餐。为了在一天的工作中保持高效&#xff0c;他们会进行一些晨间锻炼&#xff0c;保持头脑清醒和身体活力。 8:00 AM - 到达工厂 员工们到达英特尔的半导体制造工厂&#…

电脑拼图软件有哪些?盘点7种简单好用电脑拼图软件

如今我们无时无刻不使用着社交媒体&#xff0c;图片已经成为我们日常生活中不可或缺的一部分。无论是社交媒体分享、工作汇报还是个人创作&#xff0c;拼图软件都扮演着至关重要的角色。今天&#xff0c;就让我们一起来盘点7款电脑拼图软件&#xff0c;帮助你轻松找到最适合自己…

AI应用行业落地100例 | 全国首个司法审判垂直领域AI大模型落地深圳法院

《AI应用行业落地100例》专题汇集了人工智能技术在金融、医疗、教育、制造等多个关键行业中的100个实际应用案例&#xff0c;深入剖析了AI如何助力行业创新、提升效率&#xff0c;并预测了技术发展趋势&#xff0c;旨在为行业决策者和创新者提供宝贵的洞察和启发。 随着人工智能…

研究突破:无矩阵乘法的LLMs 计算!

通过在推理过程中使用优化的内核&#xff0c;内存消耗可以比未优化模型减少超过10倍。&#x1f92f; 该论文总结道&#xff0c;有可能创建第一个可扩展的无矩阵乘法LLM&#xff0c;在数十亿参数规模上实现与最先进的Transformer相媲美的性能。 另一篇最新论文《语言模型物理学…

14 学习总结:指针 · 其二 · 数组

目录 一、数组名的理解 &#xff08;一&#xff09;【数组名】与【&数组名[0]】 &#xff08;二&#xff09;区别于 【 sizeof(数组名) 】 和 【 &数组名 】 &#xff08;三&#xff09;总结 二、使用指针访问数组 三、一维数组传参的本质 四、冒泡排序 五、二…

PlugLink的技术架构实例解析(附源码)

在探讨PlugLink这一开源应用的实际应用与技术细节时&#xff0c;我们可以从其构建的几个核心方面入手&#xff0c;结合当前AI编程的发展趋势&#xff0c;为您提供既有实例又有深度解析的内容。 PlugLink的技术架构实例解析 前端技术选型 —— layui框架&#xff1a; PlugLi…

Windows桌面上透明的记事本怎么设置

作为一名经常需要记录灵感的作家&#xff0c;我的Windows桌面总是布满了各种文件和窗口。在这样的环境下&#xff0c;一个传统的记事本应用往往会显得突兀&#xff0c;遮挡住我急需查看的资料。于是&#xff0c;我开始寻找一种既能满足记录需求&#xff0c;又能保持桌面整洁美观…

画了一个简陋的曼德勃罗集

原文画了一个简陋的曼德勃罗集 - 知乎 (zhihu.com) 前两天看妈咪叔科普曼德勃罗集的视频&#xff1a; 【分形与混沌2】最有魅力的几何图形——曼德勃罗集与朱利亚集 天使与魔鬼共存_哔哩哔哩 (゜-゜)つロ 干杯~-bilibili​www.bilibili.com/video/av79113074​编辑 虽然看过…

Dify中的工具

Dify中的工具分为内置工具&#xff08;硬编码&#xff09;和第三方工具&#xff08;OpenAPI Swagger/ChatGPT Plugin&#xff09;。工具可被Workflow&#xff08;工作流&#xff09;和Agent使用&#xff0c;当然Workflow也可被发布为工具&#xff0c;这样Workflow&#xff08;工…

java线程锁synchronized的几种情况

一、对象方法锁 1、成员方法加锁 同一个对象成员方法有3个synchronized修饰的方法&#xff0c;通过睡眠模拟业务操作 public class CaseOne {public synchronized void m1(){try { TimeUnit.SECONDS.sleep(3); } catch (InterruptedException e) { e.printStackTrace()…