使用Pytorch写简单线性回归

文章目录

  • Pytorch
    • 一、Pytorch 介绍
    • 二、概念
    • 三、应用于简单线性回归
  • 1.代码框架
  • 2.引用
  • 3.继续模型
    • (1)要定义一个模型,需要继承`nn.Module`:
    • (2)如果函数的参数不具体指定,那么就需要在`__init__`函数中添加未指定的变量:
  • 2.定义数据
  • 3.实例化模型
  • 4.损失函数
  • 5.优化器
  • 6.模型训练
  • 7.绘制数据

Pytorch

一、Pytorch 介绍

  PyTorch 是一个开源的深度学习框架,由 Facebook 的人工智能研究团队开发。它主要用于构建和训练深度学习模型,具有以下特点:
  动态计算图:PyTorch 使用动态计算图,这意味着可以在运行时动态地构建、修改和执行计算图,使得开发和调试更加灵活。
  易于使用:提供了简洁直观的 API,使得开发者可以快速上手,专注于模型的设计和实现。
  强大的 GPU 加速:支持在 GPU 上进行高效的并行计算,大大加快了训练和推理的速度。
  广泛的社区支持:拥有庞大的开发者社区,提供了丰富的教程、示例和第三方扩展。

二、概念

  张量(Tensor):是 PyTorch 中的基本数据结构,类似于多维数组,可以在 CPUGPU 上存储和操作数据。
  自动求导(Autograd):PyTorch 能够自动计算张量的梯度,这对于训练深度学习模型非常重要,因为它可以通过反向传播算法自动更新模型参数。
  模块(Module):是 PyTorch 中构建模型的基本单元,可以包含多个子模块和参数。
  优化器(Optimizer):用于优化模型参数,常见的优化算法如随机梯度下降(SGD)、Adam 等。
  损失函数(Loss Function):用于衡量模型预测与真实值之间的差异,常见的损失函数有均方误差(MSE)、交叉熵损失等。

三、应用于简单线性回归

  线性回归是一种简单的机器学习算法,用于预测一个连续的数值。下面是使用 PyTorch 实现简单线性回归的步骤:
  准备数据:
  生成一些随机的输入数据和对应的输出数据。例如,假设我们要拟合一个线性函数 y = 2x + 1,可以生成一些随机的 x 值,并计算出对应的 y 值。
  定义模型:
  使用 PyTorch 的模块类定义一个简单的线性回归模型。这个模型通常包含一个线性层,即一个全连接层,它将输入特征映射到输出。
  定义损失函数和优化器:
  选择一个合适的损失函数,如均方误差(MSE)损失。
  选择一个优化器,如随机梯度下降(SGD)优化器,并设置学习率等参数。
  训练模型:
  将数据分成小批次,每次输入一个批次的数据到模型中进行前向传播,计算损失。
  然后进行反向传播,计算梯度,并使用优化器更新模型参数。
  重复这个过程直到达到预定的训练次数或损失收敛。
  测试模型:
  使用训练好的模型对新的数据进行预测,评估模型的性能。

1.代码框架

在这里插入图片描述

2.引用

import torch        
from torch import nn
from torch import optim

3.继续模型

  继承模型主要都是在nn.Module类

(1)要定义一个模型,需要继承nn.Module

class EIModel(nn.Module):def __init__(self):super(EIModel,self).__init__()   #等价于super().__init__()  self.linear=nn.Linear(in_features=1,out_features=1)   #创建线性层def forward(self,inputs):logits=self.linear(inputs)return logits   

  注:forward()return切记要写上

(2)如果函数的参数不具体指定,那么就需要在__init__函数中添加未指定的变量:

class EIModel(nn.Module):def __init__(self,in_features,out_features):super(EIModel,self).__init__()self.linear=nn.Linear(in_features,out_features)  def forward(self,inputs):logits=self.linear(inputs)return logits

  注:这时在实例化模型时,函数内要指定参数:

model = EIModel(in_features=1,out_features=1)

2.定义数据

x_list=[0,1,2,3,4]
y_list=[2,3,4,5,8]x_numpy=np.array(x_list,dtype=np.float32)
x=torch.from_numpy(x_numpy.reshape(-1,1))
y_numpy=np.array(y_list,dtype=np.float32)
y=torch.from_numpy(y_numpy.reshape(-1,1))

3.实例化模型

model = EIModel()

  直接调用模型

import torchvision.models as models
models.resnet50()

  测试模型预测结果

outputs=model(x)
print(outputs)

  结果:

tensor([[-0.9462],[-1.4654],[-1.9846],[-2.5038],[-3.0230]], grad_fn=<AddmmBackward>)

4.损失函数

  nn.MSELoss()定义均方误差损失计算函数
(1)loss_f=nn.MSELoss()
(2)loss_f=nn.CrossEntropyLoss()

5.优化器

  torch.optim.SGD()是一个内置的优化器
  它的第一个参数是需要优化的变量,可以通过model.parameters()方法获取模型中所有变量
lr=0.0001定义学习率
  (1)opt=torch.optim.SGD(model.parameters(),lr=0.0001)
  (2)optimizer_ft=optim.Adam(params_to_update,lr=1e-2)
  Adam优点:可以自动调整学习效率

6.模型训练

  (1)因为pytorch会累积每次计算的梯度,所以需要将上一循环中的计算的梯度归零
将全部数据训练一遍称为一个epoch,这里训练了500epoch

for epoch in range(500):for x_index,y_index in zip(x,y): #同时对x和y迭代y_pred=model(x_index)        #等价于model.forward(inputs)loss=loss_f(y_pred,y_index)  #根据模型预测输出与实际值y_index计算损失opt.zero_grad()              #将累计的梯度清0loss.backward()              #反向传播损失,计算损失与模型参数之间的梯度opt.step()                   #根据计算得到梯度优化模型参数

  (2)将损失误差打印出来

for epoch in range(500):for x_index,y_index in zip(x,y):   y_pred=model(x_index)loss=loss_f(y_pred,y_index)opt.zero_grad()     #将累计的梯度清0loss.backward()     #反向传播损失,计算损失与模型参数之间的梯度opt.step()          #根据计算得到梯度优化模型参数if (epoch + 1) % 50 == 0:print(f'epoch:{epoch + 1}, loss = {loss.item():.4f}')

  结果:

epoch:50, loss = 12.1212
epoch:100, loss = 7.1772
epoch:150, loss = 4.4344
epoch:200, loss = 2.8781
epoch:250, loss = 1.9724
epoch:300, loss = 1.4308
epoch:350, loss = 1.0978
epoch:400, loss = 0.8877
epoch:450, loss = 0.7521
epoch:500, loss = 0.6629

  参数名称和值:
model.named_parameters()可以以生成器的形式返回模型参数的名称和值

print(list(model.named_parameters()))

  结果:

[('linear.weight', Parameter containing:tensor([[1.4773]], requires_grad=True)), 
('linear.bias', Parameter containing:tensor([1.2792], requires_grad=True))]

  单独查看权重/偏置:

print(model.linear.weight)
print(model.linear.bias)

7.绘制数据

  使用tensor.detach()方法获得具有相同内容但不需要跟踪运算的新张量,可以认为是获取张量的值

plt.scatter(x_list,y_list,label='scatter plot')
plt.plot(x,model(x).detach().numpy(),c='r',label='line plot')
plt.legend()
plt.show()

  结果:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/881461.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Qt C++设计模式->备忘录模式

备忘录模式&#xff08;Memento Pattern&#xff09;是一种行为型设计模式&#xff0c;用于在不破坏封装性的前提下&#xff0c;捕获并保存对象的内部状态&#xff0c;以便在将来的某个时刻可以恢复到之前的状态。备忘录模式的核心是状态的保存和恢复&#xff0c;常用于实现撤销…

IP地址类型选择指南:动态IP、静态IP还是数据中心IP?

你是否曾经困惑于如何选择最适合业务需求的IP地址类型&#xff1f;面对动态IP、静态IP和数据中心IP这三种选择&#xff0c;你是否了解它们各自对你的跨境在线业务可能产生的深远影响&#xff1f; 在跨境电商领域&#xff0c;选择合适的IP类型对于业务的成功至关重要。动态IP、…

gitee开源商城diygw-mall

DIYGW可视化开源商城系统。所的界面布局显示都通过低代码可视化开发工具生成源码实现。支持集成微信小程序支付。 DIYGW可视化开源商城系统是一款基于thinkphp8 framework、 element plus admin、uniapp开发而成的前后端分离系统。 开源商城项目源码地址&#xff1a;diygw商城…

Java中String类的常见操作Api

目录 String类的常见操作 1).int indexOf (char 字符) 2).int lastIndexOf(char 字符) 3).int indexOf(String 字符串) 4).int lastIndexOf(String 字符串) 5).char charAt(int 索引) 6).Boolean endWith(String 字符串) 7).int length() 8).boolean equals(T 比较对象) 9).b…

区块链积分系统:重塑支付安全与商业创新的未来

在当今社会&#xff0c;数字化浪潮席卷全球&#xff0c;支付安全与风险管理议题日益凸显。随着交易频次与规模的不断扩大&#xff0c;传统支付体系正面临前所未有的效率、合规性和安全挑战。 区块链技术&#xff0c;凭借其去中心化、高透明度以及数据不可篡改的特性&#xff0c…

Linux !ko/5.17-BBRplus AMD64(X86_64)内核致命的 futex_wait 函数死锁问题。

!ko 表示系统内核&#xff08;system-kernel&#xff09; 致命&#xff1a; 在 CentOS&#xff08;RedHat&#xff09;、Ubuntu、Debian 等多个发行版本 Linux 操作系统上&#xff0c;若人们升级 5.17-BBRplus 版本内核&#xff0c;那么在应用程式频繁的 futex_wait&#xff0…

C++面试速通宝典——14

220. static关键字的作用 ‌‌‌‌  static关键字在编程中有多种作用&#xff1a; 在类的成员变量前使用&#xff0c;表示该变量属于类本身&#xff0c;而不是任何类的实例。在类的成员函数前使用&#xff0c;表示该函数不需要对象实例即可调用&#xff0c;且只能访问类的静…

SSH 公钥认证:从gitlab clone项目repo到本地

这篇文章的分割线以下文字内容由 ChatGPT 生成&#xff08;我稍微做了一些文字上的调整和截图的补充&#xff09;&#xff0c;我review并实践后觉得内容没有什么问题&#xff0c;由此和大家分享。 假如你想通过 git clone git10.12.5.19:your_project.git 命令将 git 服务器上…

【openwrt-21.02】T750 openwrt 出现nat46_ipv4_iput crash

Openwrt版本 NAME="OpenWrt" VERSION="21.02-SNAPSHOT" ID="openwrt" ID_LIKE="lede openwrt" PRETTY_NAME="OpenWrt 21.02-SNAPSHOT" VERSION_ID="21.02-snapshot" HOME_URL="https://openwrt.org/" …

简单的maven nexus私服学习

简单的maven nexus私服学习 1.需求 我们现在使用的maven私服是之前同事搭建的&#xff0c;是在公司的一台windows电脑上面&#xff0c;如果出问题会比较难搞&#xff0c;所以现在想将私服迁移到我们公司的测试服务器上&#xff0c;此处简单了解一下私服的一些配置记录一下&am…

多线程(二):Thread类常见的属性和方法

目录 1、run & start 2、Thread类常见的属性和方法 2.1 构造方法 2.2 属性 3、后台进程 & 前台进程 4、setDaemon 5、isAlive 6、终止一个线程 6.1 变量捕获 6.2 currentThread & isInterrupted & interrupt 1、run & start 在多线程&#xff08…

Java面试宝典-Java集合01

Java面试宝典-Java集合01 目录 Java面试宝典-Java集合01 1、Java中常用的集合有哪些&#xff1f; 2、Collection 和 Collections 有什么区别&#xff1f; 3、为什么集合类没有实现 Cloneable 和 Serializable 接口&#xff1f; 4、数组和集合有什么本质区别&#xff1f; 5、数组…

Java | Leetcode Java题解之第470题用Rand7()实现Rand10()

题目&#xff1a; 题解&#xff1a; class Solution extends SolBase {public int rand10() {int a, b, idx;while (true) {a rand7();b rand7();idx b (a - 1) * 7;if (idx < 40) {return 1 (idx - 1) % 10;}a idx - 40;b rand7();// get uniform dist from 1 - 63…

深入浅出理解七层网络协议

目录 深入浅出理解七层网络协议OSI 七层模型概述七层协议详解1. 物理层&#xff08;Physical Layer&#xff09;2. 数据链路层&#xff08;Data Link Layer&#xff09;3. 网络层&#xff08;Network Layer&#xff09;4. 传输层&#xff08;Transport Layer&#xff09;5. 会话…

蓝桥杯【物联网】零基础到国奖之路:十七. 扩展模块之单路ADC和NE555

蓝桥杯【物联网】零基础到国奖之路:十七. 扩展模块之单路ADC和NE555 第一节 硬件解读第二节 CubeMx配置第三节 代码1&#xff0c;脉冲部分代码2&#xff0c;ADC部分代码![在这里插入图片描述](https://i-blog.csdnimg.cn/direct/57531a4ee76d46daa227ae0a52993191.png) 第一节 …

React技术在Meta Connect 2024大会

每周跟踪AI热点新闻动向和震撼发展 想要探索生成式人工智能的前沿进展吗&#xff1f;订阅我们的简报&#xff0c;深入解析最新的技术突破、实际应用案例和未来的趋势。与全球数同行一同&#xff0c;从行业内部的深度分析和实用指南中受益。不要错过这个机会&#xff0c;成为AI领…

熵权法计算评价指标权重——使用Excel VBA实现

[ 熵权法 ] 信息是系统有序程度的一个度量&#xff0c;熵是系统无序程度的一个度量&#xff1b;根据信息熵的定义&#xff0c;对于某项指标&#xff0c;可以用熵值来判断某个指标的离散程度&#xff0c;其信息熵值越小&#xff0c;指标的离散程度越大&#xff0c; 该指标对综合…

数据库——表格之间的关系(表格之间的连接和处理)

数据库表格之间经常存在各种关系&#xff1a; 一对一、一对多、多对多 1.一对一 —— 丈夫表&#xff0c;妻子表为例 连接方式一&#xff1a;合并为一张表 这种方式对于一对一来说最优 连接方式二&#xff1a;在其中一张表内加入一个外键&#xff0c;连接另一张表 连…

ARM base instruction -- sdiv

有符号除法运算 Signed Divide divides a signed integer register value by another signed integer register value, and writes the result to the destination register. The condition flags are not affected. 将一个有符号整数寄存器值除以另一个有符号整数寄存器值&am…

使用AudioRelay+ VB-CABLE 实现手机无线麦克风及音响功能

我们有时会有这样的需求: 1、会议中,现场没有麦克风,有手机,有电脑,想直接用手机当用电脑的远程麦克风来使用 2、没有音响,但空间比较大、吵,电脑的声音不够大,要电脑的声音直接发到手机上播放. 这时 AudioRelay VB-CABLE 就可以满足&#xff0c;支持windows 以及macos 具体的…