Pytorch线性回归

使用pytorch来重现线性模型的过程,构造神经网络module,构造损失函数loss,构造随机梯度下降的优化器sgd。

一 revise

首先确定我们的模型,我们希望完成的目标就是得到较小的loss,所以我们就需要一个标量值的loss。

那其实在上一部分的内容就提到了tensor,loss,backward的使用,其实这个就是我们利用pytorch给我们的功能了。

二 pytorch fashion

pytorch写神经网络的第一步就是要准备数据集(有构造数据集的工具),设计一个模型计算y_head,构造损失函数和优化器,写训练的周期(前馈反馈更新)。

2.1准备数据

我们之前准备数据,就是直接使用列表。

x_data = [1.0,2.0,3.0]
y_data = [2.0,4.0,6.0]

In pytorch , the computational graph is in mini-batch fashion,so X and Y are 3x1 Tensor.

但是现在我们希望它不在是一个列表或者一个向量,我们希望它可以成为一个3X1的矩阵。当然公式依然还是使用之前的y_head = w*x +b 。图中的y_pred 就是我们说的y_head。

大家可能注意到我们上面说计算图使用小批量的方式(mini batch),这其实就是把x的三个数据同时放在一起,同时进行计算。

现在我们有了3x1矩阵的x了,那我们就可以得到3x1矩阵的y。因为存在广播机制,看图中w和b,其实也会变成3x1的矩阵。

当然loss函数的公式也是和之前不变为(y_head - y)**2。此时loss和y也会变成3x1的矩阵。

说了这么多,无非就是要把数据处理成矩阵的形式那使用tensor的方式就是:

import torch
x_data = torch.Tensor([[1.0],[2.0],[3.0]])
y_data = torch.Tensor([[2.0],[4.0],[6.0]])

总结来说,使用mini batch构造数据集的时候,我们就是需要x和y都是矩阵的形式。

2.2 构造计算图

最开始计算梯度手工计算,后来我们构建计算图可以自动把梯度求出来,之后就可以进行优化了。  

因此我们在准备好数据的基础上,下一步就是构造出计算图。我还是使用y_head = x*w + b这样一个函数(放仿射模型)。我们通过这个构建出计算图的一部分叫线性单元。

那其实我们现在进行计算的都是矩阵,既然计算矩阵就需要确定w的大小和b的大小。就需要从z和x的维度来确定w和b的维度。

接下来我们需要把y_head放到loss函数中进行计算,上面我们也说到,loss必须是一个标量(只有是标量的情况下才可以backward)。现在y_head是一个矩阵形式,且y也是,得到的loss应该也是矩阵形式,此时我们就需要对loss进行适当的修改,通常会对loss内的数进行求和,使其变成一个标量。

class LinearModel(torch.nn.Module):def __init__(self): #初始化super(LinearModel,self).__init__()self.linear = torch.nn.Linear(1,1)def forward(self,x): #前馈y_pred = self.linear(x)return y_predmodel = LinearModel()

大家可以看见,我们在类中没有写反馈的函数,这是由于model构造出的对象会自动根据计算图去实现backward。

(1)构造函数的super不用变。

torch.nn.Linear 使pytorch中的一个类,这个类里面的对象包括了权重w和偏置b,就可以直接完成下面的整个运算。Linear也是源于Model,所以他也可以进行自动的反向传播。

class torch.nn.linea(in_features,out_features,bias=True) 所做的计算就是y=ax+b。其中in_features表示的就是输入的x是几维的,out_features表示的就是输出的是几维的。那我们在mini batch中矩阵的行表示的使各个样本的值,那此时不难猜出矩阵的列表示的就是feature。bias为True表示需要偏置量,默认为True。

下面两个计算公式都可以,注意w在矩阵乘法的位置。

(2)forward(self,x)

 y_pred = self.linear(x) 这一步其实就是在计算我们的y_head。其实看上面的定义也可以看出来。

(3)最后将模型进行实例化,供我们后面使用。

2.3构造损失函数和优化器

(1)损失函数

我们还是使用MSE损失函数,此时也还是需要构建计算图,整体过程就是在拿到y_head的前提下,对y进行计算,得到loss的值,此时loss还是一个矩阵的形式,最后还需要对其进行变成标量。

class torch.nn.MSELoss(size_average=True,reduce=True) 

其中size_average是是否求均值。reduce是是否进行降维,也就是是否进行求和。

现在这个size_average被废除了,大家可以看下列代码实现求均值的True or False。

#criterion = torch.nn.MSELoss(size_average=True)
#改为:
criterion = torch.nn.MSELoss(reduction='mean')#criterion = torch.nn.MSELoss(size_average=False)
#改为:
criterion = torch.nn.BCELoss(reduction='sum')

因此对于criterion这个对象需要的参数是(y_head,y)

(2)优化器

optimizer = torch.optim.SGD(model.parameters(),lr=0.01)

优化器是不会建立计算图的。

model.parameters() 这个指的是权重。简单来说就是model中其实没有定义权重,model里存在我们上面定义的成员linear,linear中包含两个权重w和b。现在就需要告诉优化器哪些tensor是需要优化的,哪些是用于梯度下降的。因此在使用SGD这个模型时,想找权重,直接model.parameters。可以把model中所有的参数全部找到。

lr时学习率,一般会给一个固定的值。可以在不同的部分使用不同的学习率。

2.4训练过程

for epoch in range(100):y_pred = model(x_data)  #先计算出y_headloss = criterion(y_pred,y_data) #再计算出lossprint(epoch,loss.item()) optimizer.zero_grad()#在反馈前将梯度清0loss.backward()#反馈optimizer.step()#更新

整个过程其实就先算y_head,再计算loss,随后backward,更新。

最后就是打印一下权重信息。

# w b
print('w=',model.linear.weight.item())
print('b=',model.linear.weight.item())#Test Model
x_test = torch.Tensor([[4.0]])
y_test = model(x_test)
print('y_pred=',y_test.data)

对于我们给出的x_data和y_data,其实最终最好的情况是w=2 b=0。

分别是100次epoch和1000次epoch的结果,看得出来1000次更接近于我们想要的值。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/20551.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

YOLOv10(2):网络结构及其检测模型代码部分阅读

YOLOv10(1):初探,训练自己的数据-CSDN博客 目录 1. 写在前面 2. 局部模块 (1)SCDown (2)C2fCIB (3)PSA(partial self-attention) 3. 代码解读 &#x…

手把手教大家如何使用Kaggle平台的免费GPU资源跑深度学习模型

如果手头没有GPU资源是没法很好进行学习和实操各种深度学习模型的,所幸有一些平台提供了GPU资源供广大兴趣爱好者进行免费使用。 一、免费GPU资源的平台 1. Google Colab 地址:https://colab.research.google.com/ 简介:Google Colab&…

ssm_mysql_高校自习室预约系统(源码)

博主介绍:✌程序员徐师兄、8年大厂程序员经历。全网粉丝15w、csdn博客专家、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ 🍅文末获取源码联系🍅 👇🏻 精彩专栏推荐订阅👇…

IDEA中,MybatisPlus整合Spring项目的基础用法

一、本文涉及的知识点【重点】 IDEA中使用MybatisPlus生成代码,并使用。 Spring整合了Mybatis框架后,开发变得方便了很多,然而,Mapper、Service和XML文件,在Spring开发中常常会重复地使用,每一次的创建、修…

权限修饰符和代码块

一.权限修饰符 1.权限修饰符:是用来控制一个成员能够被访问的范围的。 2.可以修饰成员变量,方法,构造方法,内部类。 3.例子: public class Student {priviate String name;prviate int age;} 二.权限修饰符的分类 有四种作用范围大小…

详解寄存器模型reg_model的auto_predict

什么是reg_model镜像值? DUT的配置寄存器的值是实际值,reg_model有镜像值、期望值的概念。 镜像值:存放我们认为此时DUT里寄存器的实际值。 期望值:存放我们期望DUT寄存器被赋予的值。 什么是auto predict? 那么怎么更新reg…

安卓ANR检测、分析、优化面面谈

前言 一个引发讨论的楔子,以下三种现象有什么区别: App停止运行App暂无响应App闪退 答案: 产生原因不同:停止运行是UNCheckExceptionError暂无响应是ANRDialog闪退是CheckExceptionError 本文讨论的主题是ANR的定义、分类、复现…

内核注入DLL,支持注入PPL

这是我的个人项目,目前功能: 内核注入DLL到进程,支持注入PPL进程,可绕过任意代码卫士保护,签名校验。内核调用应用层任意函数,支持常见的调用约定。 后续可能会增加: 代码注入 Rookit和Anti-…

E. 矩阵第k大

看到这句话,其中任意两个数都不能在同一行或者同一列 经典的网络流/匈牙利 由于小白看不懂网络流 (其实是我不会) ,不妨就讲讲匈牙利 匈牙利算法 前置知识: 二分图 匈牙利(是个人)算法是二分…

纵向导航栏使用navbar-nav-scroll溢出截断问题

项目场景: 组件:Bootstrap-4.6.2、JQuery 3.7.1 测试浏览器:Firefox126.0.1、Microsoft Edge125.0.2535.67 IDE:eclipes2024-03.R 在编写CRM的工作台主页面时,由于该页面使用的是较旧的技术,所以打算使用…

ChatGPT-4o 有何特别之处?

文章目录 多模态输入,多模态输出之前的模型和现在模型对比 大家已经知道,OpenAI 在 GPT-4 发布一年多后终于推出了一个新模型。它仍然是 GPT-4 的一个变体,但具有前所未见的多模态功能。 有趣的是,它包括实时视频处理等强大功能&…

基础9 探索图形化编程的奥秘:从物联网到工业自动化

办公室内,明媚的阳光透过窗户洒落,为每张办公桌披上了一层金色的光辉。同事们各自忙碌着,键盘敲击声、文件翻页声和低声讨论交织在一起,营造出一种忙碌而有序的氛围。空气中氤氲着淡淡的咖啡香气和纸张的清新味道,令人…

fastjson 泛型转换问题(详解)

系列文章目录 附属文章一:fastjson TypeReference 泛型类型(详解) 文章目录 系列文章目录前言一、代码演示1. 不存在泛型转换2. 存在泛型转换3. 存在泛型集合转换 二、原因分析三、解决方案1. 方案1:重新执行泛型的 json 转换2. …

数据可视化每周挑战——中国高校数据分析

最近要高考了,这里祝大家金榜题名,旗开得胜。 这是数据集,如果有需要的,可以私信我。 import pandas as pd import numpy as np import matplotlib.pyplot as plt from pyecharts.charts import Line from pyecharts.charts impo…

图像处理ASIC设计方法 笔记26 非均匀性校正SOC如何设计

在红外成像技术领域,非均匀性校正是一个至关重要的环节,它直接影响到成像系统的性能和目标检测识别的准确性。非均匀性是指红外焦平面阵列(IRFPA)中各个像元对同一辐射强度的响应不一致的现象,这种不一致性可能是由于制造过程中的缺陷、材料的不均匀性或者像元间的热电特性…

simCSE句子向量表示(1)-使用transformers API

SimCSE SimCSE: Simple Contrastive Learning of Sentence Embeddings. Gao, T., Yao, X., & Chen, D. (2021). SimCSE: Simple Contrastive Learning of Sentence Embeddings. arXiv preprint arXiv:2104.08821. 1、huggingface官网下载模型 官网手动下载:pri…

集合操作进阶:关于移除列表元素的那点事

介绍 日常开发中,难免会对集合中的元素进行移除操作,如果对这方面不熟悉的话,就可能遇到 ConcurrentModificationException,那么,如何优雅地进行元素删除?以及其它方式为什么不行? 数据初始化…

国内类似ChatGPT的大模型应用有哪些?发展情况如何了

第一部分:几个容易混淆的概念 很多人,包括很多粉丝的科技博主,经常把ChatGPT和预训练大模型混为一谈,因此有必要先做一个澄清。预训练大语言模型属于预训练大模型的一类,而ChatGPT、文心一言又是预训练大语言模型的一个…

node基础-持续更新

node基础 1.node模块2.node环境搭建3.fs模块4.ES模块和CommonJS模块4.1 更改后缀名4.2 package.json配置支持es模块4.3 变量别名4.4 CommonJS模块 5.打造自己的脚手架工具5.1创建自定义全局指令5.2 使用commander处理--help参数5.3 处理自定义指令5.4 逻辑代码模块化拆分5.5 命…

iPad里的图片如何导出 iPad的照片如何管理

我们的设备中充满了各种重要的照片和视频,特别是iPad,作为苹果公司的一款强大的平板电脑,它不仅能够捕捉生活中的精彩瞬间,还可以存储和展示我们珍贵的回忆。然而,随着照片数量的不断增加,有效地管理和导出…