pytorch(四)用pytorch实现线性回归

文章目录

    • 代码过程
    • 准备数据
    • 设计模型
    • 设计构造函数与优化器
    • 训练过程
    • 训练代码和结果
    • pytorch中的Linear层的底层原理(个人喜欢,不用看)
      • 普通矩阵乘法实现
      • Linear层实现
    • 回调机制

代码过程

训练过程:

  1. 准备数据集
  2. 设计模型(用来计算 y ^ \hat y y^
  3. 构造损失函数和优化器(API)
  4. 训练周期(前馈、反馈、更新)

准备数据

这里的输入输出数据均表示为3×1的,也就是维度均为1

# 行表示实例数量,列表示维度feature
import torch
x_data=torch.Tensor([[1.0],[2.0],[3.0]])
y_data=torch.Tensor([[2.0],[4.0],[6.0]])

设计模型

模型继承Module类,并且必须要实现 init 和 forward 两个方法,其中self.linear=torch.nn.Linear(1,1)表示实例化Linear类,这个类是可调用的,其__call__函数调用了 forward 方法

class LinearModel(torch.nn.Module):def __init__(self):super(LinearModel,self).__init__()# weight 和 bias 1 1 self.linear=torch.nn.Linear(1,1)def forward(self,x):# callabley_pred=self.linear(x)return y_pred# callable
model=LinearModel()

pytorch中的linear类是在某一个数据上应用线性转换,其公式表达为 y = x w T + b y=xw^T+b y=xwT+b

class torch.nn.Linear(in_features,out_features,bias=True) :其中in_features和out_features分别表示输入和输出的数据的维度(列的数量),bias表示偏置,默认是true,该类有两个参数

  • weight:可学习参数,值从均匀分布 U ( − k , k ) U(-\sqrt k,\sqrt k) U(k ,k )中获取,其中 k = 1 i n _ f e a t u r e s k=\frac{1}{in\_features} k=in_features1
  • bias:shape和输出的维度一样,也是从分布 U ( − k , k ) U(-\sqrt k,\sqrt k) U(k ,k )中初始化的
    在这里插入图片描述

设计构造函数与优化器

# 构造损失函数和优化器
criterion=torch.nn.MSELoss(size_average=False)# w和b--->parameters
opyimizer=torch.optim.SGD(model.parameters(),lr=0.01)

在这里插入图片描述
在这里插入图片描述

训练过程

# 训练过程
for epoch in range(100):y_pred=model(x_data)loss=criterion(y_pred,y_data)# loss标量,自动调用__str__()print(epoch,loss)optimizer.zero_grad()# backwardloss.backward()# updateoptimizer.step()

训练代码和结果

# 行表示实例数量,列表示维度feature
import torch
x_data=torch.Tensor([[1.0],[2.0],[3.0]])
y_data=torch.Tensor([[2.0],[4.0],[6.0]])class LinearModel(torch.nn.Module):def __init__(self):super(LinearModel,self).__init__()# weight 和 bias 1 1 self.linear=torch.nn.Linear(1,1)def forward(self,x):# callabley_pred=self.linear(x)return y_pred# callable
model=LinearModel()# 构造损失函数和优化器
criterion=torch.nn.MSELoss(size_average=False)
optimizer=torch.optim.SGD(model.parameters(),lr=0.01)# 训练过程
for epoch in range(100):y_pred=model(x_data)loss=criterion(y_pred,y_data)# loss标量,自动调用__str__()print(epoch,loss)optimizer.zero_grad()# backwardloss.backward()# updateoptimizer.step()# 打印信息
print('w=',model.linear.weight.item())
print('b=',model.linear.bias.item())x_test=torch.Tensor([4.0])
y_test=model(x_test)
print('y_pred=',y_test.data)

在这里插入图片描述


pytorch中的Linear层的底层原理(个人喜欢,不用看)

我们在课本使用到的线性函数的基本公式表达为 y = x w T + b y=xw^T+b y=xwT+b,但是在Linear层中,当输入特征被Linear层接收是,它会接收后转置,然后乘以权重矩阵,得到的是输出特征的转置,换句话说可以在底层使用Linear,它实际上做的是 y T = w x T + b y^T=wx^T+b yT=wxT+b。可以使用下面的案例进行验证:

在这里插入图片描述

普通矩阵乘法实现

很明显,上面的图标表示一个 3×4 的矩阵乘以 4×1 的矩阵,得到一个 3×1 的输出矩阵,使用普通矩阵的乘法实现如下。

import torchin_features=torch.tensor([1,2,3,4],dtype=torch.float32)
weight_matrix=torch.tensor([[1,2,3,4],[2,3,4,5],[3,4,5,6]
],dtype=torch.float32)weight_matrix.matmul(in_features)# 矩阵乘法

实现截图:
在这里插入图片描述

Linear层实现

# 这里还是使用上面使用过的数据
import torch
in_features=torch.tensor([1,2,3,4],dtype=torch.float32)
weight_matrix=torch.tensor([[1,2,3,4],[2,3,4,5],[3,4,5,6]
],dtype=torch.float32)print(weight_matrix.matmul(in_features))# 矩阵乘法fc = torch.nn.Linear(in_features=4, out_features=3, bias=False)
# 这里是随机一个权重矩阵
print('fc.weight',fc.weight)
fc(in_features)

输出结果:
在这里插入图片描述

print('fc.weight',fc.weight)# 使用上面的权重矩阵进行计算
fc.weight = torch.nn.Parameter(weight_matrix)
print('fc.weight',fc.weight)
fc(in_features)

结果截图:
在这里插入图片描述

可以看到上面截图与下面的截图的区别,一开始随机一个权重的时候,进行运算,使用到前面提及到的权重矩阵后,Linear层进行运算之后,得到与使用普通矩阵乘法一样的结果,相同的结果说明,Linear底层的实现与上面的矩阵乘法的逻辑是一致的

以上的论证可以说明,Linear的底层实现其实是 y T = w x T + b y^T=wx^T+b yT=wxT+b,而不是 y = x w T + b y=xw^T+b y=xwT+b,可能会有人好奇,为什么书本上都是写的后者而不是写前者,其实本质上二者都一样,前者的转置就是后者。

回调机制

在pytorch学习(一)线性模型中,第一个代码中,我们没有通过pytorch实现线性模型的时候,我们会显式调用forward函数,计算前馈的值,我们是这样写的y_pred_val=forward(x_val),但是在使用pytorch之后,我们是这样写的y_pred=model(x_data),直接实例化一个对象,然后通过对象直接计算预测值(前馈值),但是并没有使用到forward函数。这是因为pytorch模块类中实现了python中一个特殊的函数,也就是回调函数

如果一个类实现了回调方法,那么只要对象实例被调用,这个特殊的方法也会被调用。我们不直接调用forward()方法,而是调用对象实例。在对象实例被调用之后,在底层调用了__ call __方法,然后调用了forward()方法。这适用于所有的PyTorch神经网络模块。

以上仅代表小白个人学习观点,如有错误欢迎批评指正。

参考

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/718935.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

国图公考:山东事业编考试即将开始

山东事业编考试时间为2024年3月10日-9.00-11.30分 考试科目为公基写作 准考证打印时间为2024年3月5日9.00-3月10日9.30分 准考证打印入口:山东考试信息网 综合类笔试在全省十六市均设置考点,参加考试的考生可凭借准考证和本人身份证参加笔试

Python爬虫实战(基础篇)—13获取《人民网》【最新】【国内】【国际】写入Word(附完整代码)

文章目录 专栏导读背景测试代码分析请求网址请求参数代码测试数据分析利用lxml+xpath进一步分析将获取链接再获取文章内容测试代码写入word完整代码总结专栏导读 🔥🔥本文已收录于《Python基础篇爬虫》 🉑🉑本专栏专门针对于有爬虫基础准备的一套基础教学,轻松掌握Py…

第 2 个 Java Web 应用工程(JSP JavaBean DB)(含源码)(图文版)

JavaBean 是一种符合特定约定的 Java 类,通常用于在 Java 应用程序中封装数据以及提供对数据的访问和修改方法。 本文示例:建立一个 Tomcat 工程,编写一个 JSP 页面,调用 JavaBean 访问数据库并显示到页面上,发布到 T…

音视频数字化(视频线缆与接口)

目录 1、DVI接口 2、DP接口 之前的文章【音视频数字化(线缆与接口)】提到了部分视频线缆,今天再补充几个。 视频模拟信号连接从莲花头的“复合”线开始,经历了S端子、色差分量接口,通过亮度、色度尽量分离的办法提高画面质量,到VGA已经到了模拟的顶峰,实现了RGB的独立…

一文读懂Penpad 以 Fair Launch 方式推出的首个资产 PEN

随着 2 月 28 日比特币重新站上 6 万美元的高峰后,标志着加密市场正在进入新一轮牛市周期。在 ETF 的促进作用下,加密市场不断有新的资金流入,加密货币总市值不断攀升。Layer2 市场率先做出了反应,有数据显示,当前以太…

2020PAT--冬

The Closest Fibonacci Number The Fibonacci sequence Fn​ is defined by Fn2​Fn1​Fn​ for n≥0, with F0​0 and F1​1. The closest Fibonacci number is defined as the Fibonacci number with the smallest absolute difference with the given integer N. Your job…

Spring初始(相关基础知识和概述)

Spring初始(相关基础知识和概述) 一、Spring相关基础知识(引入Spring)1.开闭原则OCP2.依赖倒置原则DIP3.控制反转IoC 二、Spring概述1.Spring 8大模块2.Spring特点2.Spring的常用jar文件 一、Spring相关基础知识(引入S…

除微信视频号下载器还有哪些可以应用可以下载视频?

市面上有很多视频号下载器,但犹豫部分视频号下载器逐步失效,就有很多小伙伴问还有哪些可以应用可以下载视频? 视频下载助手 除视频号视频下载器以外,还有【视频号下载助手】简称:视频下载助手 比如说,抖音…

spring cloud 之 Netflix Eureka

1、Eureka 简介 Eureka是Spring Cloud Netflix 微服务套件中的一个服务发现组件,本质上是一个基于REST的服务,主要用于AWS云来定位服务以实现中间层服务的负载均衡和故障转移,它的设计理念就是“注册中心”。 你可以认为它是一个存储服务地址信息的大本…

18个惊艳的可视化大屏(第14辑):能源行业应用

能源行业涉及能源生产、转化、储存、输送和使用的各个领域和环节,包括石油和天然气行业、煤炭行业、核能行业、可再生能源行业和能源服务行业,本期贝格前端工场带来能源行业可视化大屏界面供大家欣赏。 能源行业的组成 能源行业是指涉及能源生产、转化、…

数字化转型导师坚鹏:金融机构数字化运营

金融机构数字化运营 课程背景: 很多金融机构存在以下问题: 不清楚数字化运营对金融机构发展有什么影响? 不知道如何提升金融机构数字化运营能力? 不知道金融机构如何开展数字化运营工作? 课程特色:…

盘点全网哪些超乎想象的高科技工具?有哪些免费开源的最新AI智能工具?短视频自媒体运营套装?

盘点全网哪些超乎想象的高科技工具?有哪些免费开源的最新AI智能工具?短视频自媒体运营套装? 自媒体主要用来干什么? 可以通过短视频吸引更多的观众和粉丝,提升自媒体账号的影响力和知名度。 短视频形式更加生动、直观…

使用C++界面框架ImGUI开发一个简单程序

简介 ImGui 是一个用于C的用户界面库,跨平台、无依赖,支持OpenGL、DirectX等多种渲染API,是一种即时UI(Immediate Mode User Interface)库,保留模式与即时模式的区别参考保留模式与即时模式。ImGui渲染非常…

关于企业数字化转型:再认识、再思考、再出发

近年来,随着国家数字化政策不断出台、新兴技术不断进步、企业内生需求持续释放,数字化转型逐步成为企业实现高质量发展的必由之路,成为企业实现可持续发展乃至弯道超车的重要途径。本文重点分析当下阻碍企业数字化转型的难点,提出…

SPC 之 I-MR 控制图

概述 1924 年,美国的休哈特博士应用统计数学理论将 3Sigma 原理运用于生产过程中,并发表了 著名的“控制图法”,对产品特性和过程变量进行控制,开启了统计过程控制新时代。 什么是控制图 控制图指示过程何时不受控制&#xff…

通过 Jenkins 经典 UI 创建一个基本流水线

通过 Jenkins 经典 UI 创建一个基本流水线 点击左上的 新建任务。 在 输入一个任务名称字段,填写你新建的流水线项目的名称。 点击 流水线,然后点击页面底部的 确定 打开流水线配置页 点击菜单的流水线 选项卡让页面向下滚动到 流水线 部分 在 流水线 …

微信小程序开发学习笔记《19》uni-app框架-配置小程序分包与轮播图跳转

微信小程序开发学习笔记《19》uni-app框架-配置小程序分包与轮播图跳转 博主正在学习微信小程序开发,希望记录自己学习过程同时与广大网友共同学习讨论。建议仔细阅读uni-app对应官方文档 一、配置小程序分包 分包可以减少小程序首次启动时的加载时间 为此&#…

YOLOV5学习

【目标检测】yolov5模型详解-CSDN博客

如何使用生成式人工智能探索视频博客的魅力?

视频博客,尤其是关于旅游的视频博客,为观众提供了一种全新的探索世界的方式。通过图像和声音的结合,观众可以身临其境地体验到旅行的乐趣和发现的喜悦。而对于内容创作者来说,旅游视频博客不仅能分享他们的旅行故事,还…

模拟算法题练习(一)(扫雷,灌溉,回文日期)

目录 模拟算法介绍: (一、扫雷) (二、灌溉) (三、回文日期) 有一说一这题大佬的题解是真的强 模拟算法介绍: 模拟算法通过模拟实际情况来解决问题,一般容易理解但是实…