用Pytorch实现线性回归(Linear Regression with Pytorch)

使用pytorch写神经网络的第一步就是需要准备好数据集,设计模型(用于计算y_hat(y的预测值)),构造损失函数和优化器(使用PyTorch API),写训练周期(前馈(算loss)+反馈(算梯度)+更新(更新权重))

一:准备数据

现在使用mini-batch的方式,X和Y为3x1(可以变,但是x和y要相同)的矩阵形式。

从代码中也可以看出来,x和y都是3x1的矩阵。

二:设计模型(构造计算图)

此处使用了一个仿射模型(在pytorch中叫做线性单元)

在我们设计的例子中,我们需要设置权重w的数值,和偏置量b。

那w和b的形状(几x几的矩阵),是由y_hat和x来共同确定。

之后将y_hat和y放入loss函数中进行计算,得出loss的值(一定是一个标量)。

看下模型设计的代码:

#需要继承自module ,因为module中有很多方法我们需要使用
class LinearModel(torch.nn.Module):def __init__(self): #构造函数 在初始化对象时默认调用的函数super(LinearModel,self).__init__() #super调用父类的构造self.linear = torch.nn.Linear(1,1) #构造一个对象 linear Unit中的w和b(linear来自父类,可以自动反向传播)def forward(self,x): #前馈需要进行的计算 发现没有backword模块,因为Module中自动根据计算图实现backword过程y_pred = self.linear(x)return y_predmodel = LinearModel() #实例化 在之后既可以使用model(x)将x传入forword中的x,求得y_pred

其中torch.nn.Linear 的使用方法如下 

三:构造loss和optimizer

此处我们使用MSEloss,需要的参事时y_hat和y,就可以求出loss。

代码如下:

criterion = torch.nn.MSELoss(size_average=False)

我们使用SGD优化器(不会构建计算图),代码如下

optimizer = torch.optim.SGD(model.parameters(),lr=0.01)

四:训练过程

for epoch in range(100):y_pred = model(x_data)  #先计算出y_hatloss = criterion(y_pred,y_data) #再计算出lossprint(epoch,loss.item()) optimizer.zero_grad()#在反馈前将梯度清0loss.backward()#反馈optimizer.step()#更新

最后打印一些相关内容

# w b
print('w=',model.linear.weight.item())
print('b=',model.linear.weight.item())#Test Model
x_test = torch.Tensor([[4.0]])
y_test = model(x_test)
print('y_pred=',y_test.data)

发现当range为1000时,已经达到了我们的预期。

五:整体流程

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/46981.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Centos7 rpm 安装 Mysql 8.0.28

Centos7 rpm 安装 Mysql 8.0.28 一、检查系统是否已经安装了Mysql 如果安装了则卸载 [rootiZbp1byzaznzn9jncxr010Z /]# rpm -qa | grep mysql[rootiZbp1byzaznzn9jncxr010Z /]# rpm -qa | grep mariadb mariadb-libs-5.5.68-1.el7.x86_64如果安装了 mysql ,maria…

2-36 基于matlab的流行学习算法程序

基于matlab的流行学习算法程序。通过GUI的形式将MDS、PCA、ISOMAP、LLE、Hessian LLE、Laplacian、Dissusion MAP、LTSA八种算法。程序以可视化界面进行展示,可直接调用进行分析。多种案例举例说明八种方法优劣,并且可设置自己数据进行分析。程序已调通&…

鸿蒙语言基础类库:【@system.brightness (屏幕亮度)】

屏幕亮度 说明: 从API Version 7 开始,该接口不再维护,推荐使用新接口[ohos.brightness]。本模块首批接口从API version 3开始支持。后续版本的新增接口,采用上角标单独标记接口的起始版本。 导入模块 import brightness from sy…

【算法】LRU缓存

难度:中等 题目: 请你设计并实现一个满足 LRU (最近最少使用) 缓存 约束的数据结构。 实现 LRUCache 类: LRUCache(int capacity) 以 正整数 作为容量 capacity 初始化 LRU 缓存int get(int key) 如果关键字 key 存在于缓存中,…

多级表头固定列问题

父级的width,是需要固定的列的width的总和 参考: el-table 多级表头下对应列的固定

JAVA零基础学习1(CMD、JDK、环境变量、变量和键盘键入、IDEA)

JAVA零基础学习1(CMD、JDK、环境变量、变量和键盘键入、IDEA) CMD常见命令配置环境变量JDK的下载和安装变量变量的声明和初始化声明变量初始化变量 变量的类型变量的作用域变量命名规则示例代码 键盘键入使用 Scanner 类读取输入步骤示例代码 常用方法处…

HBuilder X3.4版本中使用uni-app自定义组件

HBuilder X3.4版本中使用uni-app自定义组件 这是我的小程序页面结构 方式一&#xff1a;导入components 1.创建componets文件&#xff0c;并编写你的组件页面 <template><view class"my-search-container"><!-- 使用 view 组件模拟 input 输入框的样…

无人机区域常见名词

融合空域 是指有其他航空器同时运行的空域。 隔离空域 是指专门分配给无人机系统运行的空域&#xff0c;通过限制其他航空器的进入以规避碰撞风险。 人口稠密区 是指城镇、村庄、繁忙道路或大型露天集会场所等区域。 重点地区 是指军事重地、核电站和行政中心等关乎国家…

LintcCode 468 · 对称二叉树【简单 二叉树 递归 Java】

题目 题目链接&#xff1a; https://www.lintcode.com/problem/468/description?showListFetrue&page1&problemTypeId2&tagIds371&orderingid&pageSize50 思路 递归 Java代码 /*** Definition of TreeNode:* public class TreeNode {* public int…

厂家置换电费如何达到最大化收益

新能源行业知识体系-------主目录-----持续更新https://blog.csdn.net/grd_java/article/details/140004020 文章目录 一、电能电费二、同时刻不同厂家置换&#xff0c;不会影响最终电能电费结果三、风险防范补偿和回收机制四、我们的数据如何考虑补偿和回收五、如何利用补偿和…

蓝桥杯14小白月赛题解

直接输出pi/ti,for遍历 #include <iostream> using namespace std; #define int long long int a,b,c ; double t1.00; signed main() {cin>>a;int an0;for(int i1;i<a;i){cin>>b>>c;if(t>c*1.00/b){tc*1.00/b;ani;} }cout<<an<<e…

【车载开发系列】GIT教程---如何使用GUI来提交变更

【车载开发系列】GIT教程—如何使用GUI来提交变更 GIT教程---如何使用GUI来提交变更 【车载开发系列】GIT教程---如何使用GUI来提交变更一. 使用Git GUI的好处二. 使用GUI克隆出一个已有仓库 一. 使用Git GUI的好处 在软件开发中&#xff0c;Git通常用于管理和操作版本控制系统…

Python 从PDF中提取图片和图片信息(坐标、宽度和高度等)

目录 使用工具 Python从PDF的特定页面中提取图片 Python从PDF文档中提取图片 Python从PDF中提取图片的坐标、宽度和高度等信息 PDF文件作为一种广泛使用的电子文档格式&#xff0c;不仅包含文字信息&#xff0c;还可能包含各种图片、图表等视觉元素。在某些场景下&#xff…

【Linux】安装PHP扩展-Swoole

说明 本文档是在centos7.6的环境下&#xff0c;安装PHP7.4之后&#xff0c;安装对应的PHP扩展Swoole。 一、swoole简述 Swoole 是一个为 PHP 设计的高性能的异步并行网络通信引擎&#xff0c;它以扩展&#xff08;extension&#xff09;的形式存在&#xff0c;极大地提升了 …

zephyr设置BLE广播数据实例

目录 实例1&#xff1a;静态开启广播数据实例2&#xff1a;动态更改广播数据实例3&#xff1a;创建可连接的广播 实例1&#xff1a;静态开启广播数据 新建一个hello world的工程模板。 在prj.conf中开启蓝牙 CONFIG_BTy这个宏&#xff0c;默认会开启广播支持 ( BT_BROADCAS…

Spring Boot项目的404是如何发生的

问题 在日常开发中&#xff0c;假如我们访问一个Sping容器中并不存在的路径&#xff0c;通常会返回404的报错&#xff0c;具体原因是什么呢&#xff1f; 结论 错误的访问会调用两次DispatcherServlet&#xff1a;第一次调用无法找到对应路径时&#xff0c;会给Response设置一个…

TCP与UDP的理解

文章目录 UDP协议UDP协议的特点UDP的应用以及杂项 TCP协议TCP协议段格式解释和TCP过程详解确认应答机制 -- 序号和确认序号以及6位标志位中的ACK超时重传机制连接管理机制 与标志位SYN,FIN,ACK滑动窗口流量控制拥塞控制延迟应答捎带应答和面向字节流粘包问题TCP异常情况TCP特点…

5.4 软件工程-系统设计

系统设计 - 概述 设计软件系统总体结构 数据结构及数据库设计 编写概要设计文档、评审 详细设计的基本任务 真题

anaconda常用指令学习

在win系统下安装完成后&#xff0c;需要进行环境变量的配置 配置环境变量&#xff0c;把下面的4个目录全部加到PATH变量里面。 Anaconda安装目录的根目录\ Anaconda安装目录的根目录\Scripts Anaconda安装目录的根目录\Library\bin Anaconda安装目录的根目录\Library\mingw-w6…

[米联客-安路飞龙DR1-FPSOC] FPGA基础篇连载-15 SPI接收程序设计

软件版本&#xff1a;Anlogic -TD5.9.1-DR1_ES1.1 操作系统&#xff1a;WIN10 64bit 硬件平台&#xff1a;适用安路(Anlogic)FPGA 实验平台&#xff1a;米联客-MLK-L1-CZ06-DR1M90G开发板 板卡获取平台&#xff1a;https://milianke.tmall.com/ 登录“米联客”FPGA社区 ht…