使用Pytorch写简单线性回归

文章目录

  • Pytorch
    • 一、Pytorch 介绍
    • 二、概念
    • 三、应用于简单线性回归
  • 1.代码框架
  • 2.引用
  • 3.继续模型
    • (1)要定义一个模型,需要继承`nn.Module`:
    • (2)如果函数的参数不具体指定,那么就需要在`__init__`函数中添加未指定的变量:
  • 2.定义数据
  • 3.实例化模型
  • 4.损失函数
  • 5.优化器
  • 6.模型训练
  • 7.绘制数据

Pytorch

一、Pytorch 介绍

  PyTorch 是一个开源的深度学习框架,由 Facebook 的人工智能研究团队开发。它主要用于构建和训练深度学习模型,具有以下特点:
  动态计算图:PyTorch 使用动态计算图,这意味着可以在运行时动态地构建、修改和执行计算图,使得开发和调试更加灵活。
  易于使用:提供了简洁直观的 API,使得开发者可以快速上手,专注于模型的设计和实现。
  强大的 GPU 加速:支持在 GPU 上进行高效的并行计算,大大加快了训练和推理的速度。
  广泛的社区支持:拥有庞大的开发者社区,提供了丰富的教程、示例和第三方扩展。

二、概念

  张量(Tensor):是 PyTorch 中的基本数据结构,类似于多维数组,可以在 CPUGPU 上存储和操作数据。
  自动求导(Autograd):PyTorch 能够自动计算张量的梯度,这对于训练深度学习模型非常重要,因为它可以通过反向传播算法自动更新模型参数。
  模块(Module):是 PyTorch 中构建模型的基本单元,可以包含多个子模块和参数。
  优化器(Optimizer):用于优化模型参数,常见的优化算法如随机梯度下降(SGD)、Adam 等。
  损失函数(Loss Function):用于衡量模型预测与真实值之间的差异,常见的损失函数有均方误差(MSE)、交叉熵损失等。

三、应用于简单线性回归

  线性回归是一种简单的机器学习算法,用于预测一个连续的数值。下面是使用 PyTorch 实现简单线性回归的步骤:
  准备数据:
  生成一些随机的输入数据和对应的输出数据。例如,假设我们要拟合一个线性函数 y = 2x + 1,可以生成一些随机的 x 值,并计算出对应的 y 值。
  定义模型:
  使用 PyTorch 的模块类定义一个简单的线性回归模型。这个模型通常包含一个线性层,即一个全连接层,它将输入特征映射到输出。
  定义损失函数和优化器:
  选择一个合适的损失函数,如均方误差(MSE)损失。
  选择一个优化器,如随机梯度下降(SGD)优化器,并设置学习率等参数。
  训练模型:
  将数据分成小批次,每次输入一个批次的数据到模型中进行前向传播,计算损失。
  然后进行反向传播,计算梯度,并使用优化器更新模型参数。
  重复这个过程直到达到预定的训练次数或损失收敛。
  测试模型:
  使用训练好的模型对新的数据进行预测,评估模型的性能。

1.代码框架

在这里插入图片描述

2.引用

import torch        
from torch import nn
from torch import optim

3.继续模型

  继承模型主要都是在nn.Module类

(1)要定义一个模型,需要继承nn.Module

class EIModel(nn.Module):def __init__(self):super(EIModel,self).__init__()   #等价于super().__init__()  self.linear=nn.Linear(in_features=1,out_features=1)   #创建线性层def forward(self,inputs):logits=self.linear(inputs)return logits   

  注:forward()return切记要写上

(2)如果函数的参数不具体指定,那么就需要在__init__函数中添加未指定的变量:

class EIModel(nn.Module):def __init__(self,in_features,out_features):super(EIModel,self).__init__()self.linear=nn.Linear(in_features,out_features)  def forward(self,inputs):logits=self.linear(inputs)return logits

  注:这时在实例化模型时,函数内要指定参数:

model = EIModel(in_features=1,out_features=1)

2.定义数据

x_list=[0,1,2,3,4]
y_list=[2,3,4,5,8]x_numpy=np.array(x_list,dtype=np.float32)
x=torch.from_numpy(x_numpy.reshape(-1,1))
y_numpy=np.array(y_list,dtype=np.float32)
y=torch.from_numpy(y_numpy.reshape(-1,1))

3.实例化模型

model = EIModel()

  直接调用模型

import torchvision.models as models
models.resnet50()

  测试模型预测结果

outputs=model(x)
print(outputs)

  结果:

tensor([[-0.9462],[-1.4654],[-1.9846],[-2.5038],[-3.0230]], grad_fn=<AddmmBackward>)

4.损失函数

  nn.MSELoss()定义均方误差损失计算函数
(1)loss_f=nn.MSELoss()
(2)loss_f=nn.CrossEntropyLoss()

5.优化器

  torch.optim.SGD()是一个内置的优化器
  它的第一个参数是需要优化的变量,可以通过model.parameters()方法获取模型中所有变量
lr=0.0001定义学习率
  (1)opt=torch.optim.SGD(model.parameters(),lr=0.0001)
  (2)optimizer_ft=optim.Adam(params_to_update,lr=1e-2)
  Adam优点:可以自动调整学习效率

6.模型训练

  (1)因为pytorch会累积每次计算的梯度,所以需要将上一循环中的计算的梯度归零
将全部数据训练一遍称为一个epoch,这里训练了500epoch

for epoch in range(500):for x_index,y_index in zip(x,y): #同时对x和y迭代y_pred=model(x_index)        #等价于model.forward(inputs)loss=loss_f(y_pred,y_index)  #根据模型预测输出与实际值y_index计算损失opt.zero_grad()              #将累计的梯度清0loss.backward()              #反向传播损失,计算损失与模型参数之间的梯度opt.step()                   #根据计算得到梯度优化模型参数

  (2)将损失误差打印出来

for epoch in range(500):for x_index,y_index in zip(x,y):   y_pred=model(x_index)loss=loss_f(y_pred,y_index)opt.zero_grad()     #将累计的梯度清0loss.backward()     #反向传播损失,计算损失与模型参数之间的梯度opt.step()          #根据计算得到梯度优化模型参数if (epoch + 1) % 50 == 0:print(f'epoch:{epoch + 1}, loss = {loss.item():.4f}')

  结果:

epoch:50, loss = 12.1212
epoch:100, loss = 7.1772
epoch:150, loss = 4.4344
epoch:200, loss = 2.8781
epoch:250, loss = 1.9724
epoch:300, loss = 1.4308
epoch:350, loss = 1.0978
epoch:400, loss = 0.8877
epoch:450, loss = 0.7521
epoch:500, loss = 0.6629

  参数名称和值:
model.named_parameters()可以以生成器的形式返回模型参数的名称和值

print(list(model.named_parameters()))

  结果:

[('linear.weight', Parameter containing:tensor([[1.4773]], requires_grad=True)), 
('linear.bias', Parameter containing:tensor([1.2792], requires_grad=True))]

  单独查看权重/偏置:

print(model.linear.weight)
print(model.linear.bias)

7.绘制数据

  使用tensor.detach()方法获得具有相同内容但不需要跟踪运算的新张量,可以认为是获取张量的值

plt.scatter(x_list,y_list,label='scatter plot')
plt.plot(x,model(x).detach().numpy(),c='r',label='line plot')
plt.legend()
plt.show()

  结果:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/881461.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

IP地址类型选择指南:动态IP、静态IP还是数据中心IP?

你是否曾经困惑于如何选择最适合业务需求的IP地址类型&#xff1f;面对动态IP、静态IP和数据中心IP这三种选择&#xff0c;你是否了解它们各自对你的跨境在线业务可能产生的深远影响&#xff1f; 在跨境电商领域&#xff0c;选择合适的IP类型对于业务的成功至关重要。动态IP、…

gitee开源商城diygw-mall

DIYGW可视化开源商城系统。所的界面布局显示都通过低代码可视化开发工具生成源码实现。支持集成微信小程序支付。 DIYGW可视化开源商城系统是一款基于thinkphp8 framework、 element plus admin、uniapp开发而成的前后端分离系统。 开源商城项目源码地址&#xff1a;diygw商城…

Java中String类的常见操作Api

目录 String类的常见操作 1).int indexOf (char 字符) 2).int lastIndexOf(char 字符) 3).int indexOf(String 字符串) 4).int lastIndexOf(String 字符串) 5).char charAt(int 索引) 6).Boolean endWith(String 字符串) 7).int length() 8).boolean equals(T 比较对象) 9).b…

区块链积分系统:重塑支付安全与商业创新的未来

在当今社会&#xff0c;数字化浪潮席卷全球&#xff0c;支付安全与风险管理议题日益凸显。随着交易频次与规模的不断扩大&#xff0c;传统支付体系正面临前所未有的效率、合规性和安全挑战。 区块链技术&#xff0c;凭借其去中心化、高透明度以及数据不可篡改的特性&#xff0c…

SSH 公钥认证:从gitlab clone项目repo到本地

这篇文章的分割线以下文字内容由 ChatGPT 生成&#xff08;我稍微做了一些文字上的调整和截图的补充&#xff09;&#xff0c;我review并实践后觉得内容没有什么问题&#xff0c;由此和大家分享。 假如你想通过 git clone git10.12.5.19:your_project.git 命令将 git 服务器上…

简单的maven nexus私服学习

简单的maven nexus私服学习 1.需求 我们现在使用的maven私服是之前同事搭建的&#xff0c;是在公司的一台windows电脑上面&#xff0c;如果出问题会比较难搞&#xff0c;所以现在想将私服迁移到我们公司的测试服务器上&#xff0c;此处简单了解一下私服的一些配置记录一下&am…

多线程(二):Thread类常见的属性和方法

目录 1、run & start 2、Thread类常见的属性和方法 2.1 构造方法 2.2 属性 3、后台进程 & 前台进程 4、setDaemon 5、isAlive 6、终止一个线程 6.1 变量捕获 6.2 currentThread & isInterrupted & interrupt 1、run & start 在多线程&#xff08…

Java面试宝典-Java集合01

Java面试宝典-Java集合01 目录 Java面试宝典-Java集合01 1、Java中常用的集合有哪些&#xff1f; 2、Collection 和 Collections 有什么区别&#xff1f; 3、为什么集合类没有实现 Cloneable 和 Serializable 接口&#xff1f; 4、数组和集合有什么本质区别&#xff1f; 5、数组…

Java | Leetcode Java题解之第470题用Rand7()实现Rand10()

题目&#xff1a; 题解&#xff1a; class Solution extends SolBase {public int rand10() {int a, b, idx;while (true) {a rand7();b rand7();idx b (a - 1) * 7;if (idx < 40) {return 1 (idx - 1) % 10;}a idx - 40;b rand7();// get uniform dist from 1 - 63…

蓝桥杯【物联网】零基础到国奖之路:十七. 扩展模块之单路ADC和NE555

蓝桥杯【物联网】零基础到国奖之路:十七. 扩展模块之单路ADC和NE555 第一节 硬件解读第二节 CubeMx配置第三节 代码1&#xff0c;脉冲部分代码2&#xff0c;ADC部分代码![在这里插入图片描述](https://i-blog.csdnimg.cn/direct/57531a4ee76d46daa227ae0a52993191.png) 第一节 …

React技术在Meta Connect 2024大会

每周跟踪AI热点新闻动向和震撼发展 想要探索生成式人工智能的前沿进展吗&#xff1f;订阅我们的简报&#xff0c;深入解析最新的技术突破、实际应用案例和未来的趋势。与全球数同行一同&#xff0c;从行业内部的深度分析和实用指南中受益。不要错过这个机会&#xff0c;成为AI领…

熵权法计算评价指标权重——使用Excel VBA实现

[ 熵权法 ] 信息是系统有序程度的一个度量&#xff0c;熵是系统无序程度的一个度量&#xff1b;根据信息熵的定义&#xff0c;对于某项指标&#xff0c;可以用熵值来判断某个指标的离散程度&#xff0c;其信息熵值越小&#xff0c;指标的离散程度越大&#xff0c; 该指标对综合…

数据库——表格之间的关系(表格之间的连接和处理)

数据库表格之间经常存在各种关系&#xff1a; 一对一、一对多、多对多 1.一对一 —— 丈夫表&#xff0c;妻子表为例 连接方式一&#xff1a;合并为一张表 这种方式对于一对一来说最优 连接方式二&#xff1a;在其中一张表内加入一个外键&#xff0c;连接另一张表 连…

ARM base instruction -- sdiv

有符号除法运算 Signed Divide divides a signed integer register value by another signed integer register value, and writes the result to the destination register. The condition flags are not affected. 将一个有符号整数寄存器值除以另一个有符号整数寄存器值&am…

Java中的switch分支结构

switch分支结构 switch分支结构1.基本语法2.说明3.流程图4.案例5.注意事项6.练习7.switch和if的比较 switch分支结构 1.基本语法 switch&#xff08;表达式&#xff09;{case 常量1: //当...语句块1;break;case 常量2: 语句块2;break;...case 常量n: 语句块n;break;defaul…

路径跟踪之导航向量场——二维导航向量场

今天带来一期轨迹跟踪算法的讲解&#xff0c;首先讲解二维平面中的导航向量场[1]。该方法具有轻量化、计算简便、收敛性强等多项优点。该方法根据期望的轨迹函数&#xff0c;计算全局位置的期望飞行向量&#xff0c;将期望飞行向量转为偏光角&#xff0c;输入底层控制器&#x…

prometheus client_java实现进程的CPU、内存、IO、流量的可观测

文章目录 1、获取进程信息的方法1.1、通过读取/proc目录获取进程相关信息1.2、通过Linux命令获取进程信息1.2.1、top&#xff08;CPU/内存&#xff09;命令1.2.2、iotop&#xff08;磁盘IO&#xff09;命令1.2.3、nethogs&#xff08;流量&#xff09;命令 2、使用prometheus c…

AAA Mysql与redis的主从复制原理

一 &#xff1a;Mysql主从复制 重要的两个日志文件&#xff1a;bin log 和 relay log bin log&#xff1a;二进制日志&#xff08;binnary log&#xff09;以事件形式记录了对MySQL数据库执行更改的所有操作。 relay log&#xff1a;用来保存从节点I/O线程接受的bin log日志…

用凡尔码系统进行隐患排查二维码的制作

隐患排查是企业安全管理的重要环节&#xff0c;通过定期或不定期地对生产设备、作业场所、作业人员等进行检查&#xff0c;发现并消除安全隐患&#xff0c;预防事故的发生。隐患排查的效率和质量直接影响到企业的安全生产水平和经济效益。 传统的隐患排查方法主要依靠纸质进行…

PostgreSQL学习笔记七:常规SQL操作

PostgreSQL 支持标准的 SQL 语句&#xff0c;同时也扩展了一些特有的功能。以下是一些常规的 SQL 语句示例&#xff0c;这些示例涵盖了数据定义、数据操作和数据查询的基本操作&#xff1a; 数据定义语言 (DDL 创建数据库&#xff1a; CREATE DATABASE mydatabase;创建表&#…