用Pytorch实现线性回归模型

目录

  • 回顾
  • Pytorch实现
    • 步骤
    • 1. 准备数据
    • 2. 设计模型
      • class LinearModel
      • 代码
    • 3. 构造损失函数和优化器
    • 4. 训练过程
    • 5. 输出和测试
    • 完整代码
  • 练习

回顾

前面已经学习过线性模型相关的内容,实现线性模型的过程并没有使用到Pytorch。
这节课主要是利用Pytorch实现线性模型。
学习器训练:

  • 确定模型(函数)
  • 定义损失函数
  • 优化器优化(SGD)

之前用过Pytorch的Tensor进行Forward、Backward计算。
现在利用Pytorch框架来实现。

Pytorch实现

步骤

  1. 准备数据集
  2. 设计模型(计算预测值y_hat):从nn.Module模块继承
  3. 构造损失函数和优化器:使用PytorchAPI
  4. 训练过程:Forward、Backward、update

1. 准备数据

在PyTorch中计算图是通过mini-batch形式进行,所以X、Y都是多维的Tensor。
在这里插入图片描述

import torch
x_data = torch.Tensor([[1.0], [2.0], [3.0]])
y_data = torch.Tensor([[2.0], [4.0], [6.0]])

2. 设计模型

在之前讲解梯度下降算法时,我们需要自己计算出梯度,然后更新权重。
在这里插入图片描述
而使用Pytorch构造模型,重点时在构建计算图和损失函数上。
在这里插入图片描述

class LinearModel

通过构造一个 class LinearModel类来实现,所有的模型类都需要继承nn.Module,这是所有神经忘了模块的基础类。
class LinearModel这种定义的模型类必须包含两个部分:

  • init():构造函数,进行初始化。
    def __init__(self):super(LinearModel, self).__init__()#调用父类构造函数,不用管,照着写。# torch.nn.Linear(in_featuers, in_featuers)构造Linear类的对象,其实就是实现了一个线性单元self.linear = torch.nn.Linear(1, 1)

在这里插入图片描述

  • forward():进行前馈计算
    (backward没有被写,是因为在这种模型类里面会自动实现)

Class nn.Linear 实现了magic method call():它使类的实例可以像函数一样被调用。通常会调用forward()。

    def forward(self, x):y_pred = self.linear(x)#调用linear对象,输入x进行预测return y_pred

代码

class LinearModel(torch.nn.Module):def __init__(self):super(LinearModel, self).__init__()#调用父类构造函数,不用管,照着写。# torch.nn.Linear(in_featuers, in_featuers)构造Linear类的对象,其实就是实现了一个线性单元self.linear = torch.nn.Linear(1, 1)def forward(self, x):y_pred = self.linear(x)#调用linear对象,输入x进行预测return y_predmodel = LinearModel()#实例化LinearModel()

3. 构造损失函数和优化器

采用MSE作为损失函数

torch.nn.MSELoss(size_average,reduce)

  • size_average:是否求mini-batch的平均loss。
  • reduce:降维,不用管。

在这里插入图片描述SGD作为优化器torch.optim.SGD(params, lr):

  • params:参数
  • lr:学习率

在这里插入图片描述

criterion = torch.nn.MSELoss(size_average=False)#size_average:the losses are averaged over each loss element in the batch.
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)#params:model.parameters(): w、b

4. 训练过程

  1. 预测
  2. 计算loss
  3. 梯度清零
  4. Backward
  5. 参数更新
    简化:Forward–>Backward–>更新
#4. Training Cycle
for epoch in range(100):y_pred = model(x_data)#Forward:预测loss = criterion(y_pred, y_data)#Forward:计算lossprint(epoch, loss)optimizer.zero_grad()#梯度清零loss.backward()#backward:计算梯度optimizer.step()#通过step()函数进行参数更新

5. 输出和测试

# Output weight and bias
print('w = ', model.linear.weight.item())
print('b = ', model.linear.bias.item())# Test Model
x_test = torch.Tensor([[4.0]])
y_test = model(x_test)
print('y_pred = ', y_test.data)

完整代码

import torch
#1. Prepare dataset
x_data = torch.Tensor([[1.0], [2.0], [3.0]])
y_data = torch.Tensor([[2.0], [4.0], [6.0]])#2. Design Model
class LinearModel(torch.nn.Module):def __init__(self):super(LinearModel, self).__init__()#调用父类构造函数,不用管,照着写。# torch.nn.Linear(in_featuers, in_featuers)构造Linear类的对象,其实就是实现了一个线性单元self.linear = torch.nn.Linear(1, 1)def forward(self, x):y_pred = self.linear(x)#调用linear对象,输入x进行预测return y_predmodel = LinearModel()#实例化LinearModel()# 3. Construct Loss and Optimize
criterion = torch.nn.MSELoss(size_average=False)#size_average:the losses are averaged over each loss element in the batch.
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)#params:model.parameters(): w、b#4. Training Cycle
for epoch in range(100):y_pred = model(x_data)#Forward:预测loss = criterion(y_pred, y_data)#Forward:计算lossprint(epoch, loss)optimizer.zero_grad()#梯度清零loss.backward()#backward:计算梯度optimizer.step()#通过step()函数进行参数更新# Output weight and bias
print('w = ', model.linear.weight.item())
print('b = ', model.linear.bias.item())# Test Model
x_test = torch.Tensor([[4.0]])
y_test = model(x_test)
print('y_pred = ', y_test.data)

输出结果:

85 tensor(0.2294, grad_fn=)
86 tensor(0.2261, grad_fn=)
87 tensor(0.2228, grad_fn=)
88 tensor(0.2196, grad_fn=)
89 tensor(0.2165, grad_fn=)
90 tensor(0.2134, grad_fn=)
91 tensor(0.2103, grad_fn=)
92 tensor(0.2073, grad_fn=)
93 tensor(0.2043, grad_fn=)
94 tensor(0.2014, grad_fn=)
95 tensor(0.1985, grad_fn=)
96 tensor(0.1956, grad_fn=)
97 tensor(0.1928, grad_fn=)
98 tensor(0.1900, grad_fn=)
99 tensor(0.1873, grad_fn=)
w = 1.711882472038269
b = 0.654958963394165
y_pred = tensor([[7.5025]])

可以看到误差还比较大,可以增加训练轮次,训练1000次后的结果:

980 tensor(2.1981e-07, grad_fn=)
981 tensor(2.1671e-07, grad_fn=)
982 tensor(2.1329e-07, grad_fn=)
983 tensor(2.1032e-07, grad_fn=)
984 tensor(2.0737e-07, grad_fn=)
985 tensor(2.0420e-07, grad_fn=)
986 tensor(2.0143e-07, grad_fn=)
987 tensor(1.9854e-07, grad_fn=)
988 tensor(1.9565e-07, grad_fn=)
989 tensor(1.9260e-07, grad_fn=)
990 tensor(1.8995e-07, grad_fn=)
991 tensor(1.8728e-07, grad_fn=)
992 tensor(1.8464e-07, grad_fn=)
993 tensor(1.8188e-07, grad_fn=)
994 tensor(1.7924e-07, grad_fn=)
995 tensor(1.7669e-07, grad_fn=)
996 tensor(1.7435e-07, grad_fn=)
997 tensor(1.7181e-07, grad_fn=)
998 tensor(1.6931e-07, grad_fn=)
999 tensor(1.6700e-07, grad_fn=)
w = 1.9997280836105347
b = 0.0006181497010402381
y_pred = tensor([[7.9995]])

练习

用以下这些优化器替换SGD,得到训练结果并画出损失曲线图。
在这里插入图片描述
比如说:Adam的loss图:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/625084.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Pandas实战100例 | 案例 51: 日期时间过滤

案例 51: 日期时间过滤 知识点讲解 当你的 DataFrame 包含 datetime 类型的列时,你可以基于日期时间条件过滤数据。这在处理时间序列数据时特别有用。 日期时间过滤: 使用布尔索引,可以根据日期时间条件过滤数据。 示例代码 # 准备数据和示例代码的…

SVD和EVD的关系

文章目录 SVD和EVD基本概念具体计算中的关系 SVD和EVD基本概念 奇异值分解(Singular Value Decomposition,SVD)和特征值分解(Eigenvalue Decomposition,EVD)是矩阵分解的两种常见方法,它们在线…

(1)(1.13) SiK无线电高级配置(六)

文章目录 前言 15 使用FTDI转USB调试线配置SiK无线电设备 16 强制启动加载程序模式 17 名词解释 前言 本文提供 SiK 遥测无线电(SiK Telemetry Radio)的高级配置信息。它面向"高级用户"和希望更好地了解无线电如何运行的用户。 15 使用FTDI转USB调试线配置SiK无线…

Oracle12c创建表空间及用户

Oracle12c创建表空间及用户 1. 表空间相关内容 表空间数据文件容量与DB_BLOCK_SIZE有关,在初始建库时,DB_BLOCK_SIZE要根据实际需要,设置为 4K、8K、16K、32K、64K等几种大小,ORACLE的物理文件最大只允许4194304个数据块&#xf…

Go 语言运算符详解:加法、算术、赋值、比较、逻辑和位运算符全面解析

运算符用于对变量和值执行操作。 加号运算符()将两个值相加,如下面的示例所示: 示例代码: package mainimport ("fmt" )func main() {var a 15 25fmt.Println(a) }尽管加号运算符通常用于将两个值相加&a…

vue3 锚点定位 点击滚动高亮

功能描述 点击导航跳到对应模块的起始位置,并且高亮点击的导航; 滚动到相应的模块时,对应的导航也自动高亮; 效果展示 注意事项 一定要明确哪个是要滚动的盒子;滚动的高度要减去导航栏的高度;当前在导航1…

C++ 树与图的深度优先遍历 || 模版题:树的重心

树和无向图都可以看成有向图&#xff08;无向图在添加边的时候添加双向的&#xff09; 下面是模版&#xff0c;实际使用要根据情况改&#xff1a; #include <iostream> #include <cstring> using namespace std;const int N 10010, M N * 2;int n; int h[N], e[…

【VTKExamples::PolyData】第二期 曲率

很高兴在雪易的CSDN遇见你 VTK技术爱好者 QQ:870202403 前言 本文分享VTK中的曲率计算及显示,希望对各位小伙伴有所帮助! 感谢各位小伙伴的点赞+关注,小易会继续努力分享,一起进步! 你的点赞就是我的动力(^U^)ノ~YO 目录 前言 1. Curvatures样例 2. vtkCurv…

【发票识别】支持pdf、ofd、图片格式(orc、信息提取)的发票

背景 为了能够满足识别各种发票的功能&#xff0c;特地开发了当前发票识别的功能&#xff0c;当前的功能支持pdf、ofd、图片格式的发票识别&#xff0c;使用到的技术包括文本提取匹配、ocr识别和信息提取等相关的技术&#xff0c;用到机器学习和深度学习的相关技术。 体验 体…

vue知识-06

es6导入导出语法 # 做项目&#xff1a;肯定要写模块--导入使用 # 如果包下有个 index.js 直接导到index.js上一次即可 默认导出和导入 : export default name // 只导出变量 export default add // 只导出函数 export default {name,add} // 导出对象 export defau…

Cesium中设置弹窗随轨迹动画对象移动

1.这是要移动的弹窗&#xff0c;隐藏显示逻辑、样式、展示内容自己写&#xff0c;主要就是动态设置弹窗的style&#xff0c;floatLeft和floatTop都是Vue中的data双向绑定数据&#xff1b; <div id"box" v-show"hasMove" :style"{ left: floatLeft…

【软件测试学习笔记2】用例设计方法

1.能对穷举场景设计测试点&#xff08;等价法&#xff09; 等价类&#xff1a; 说明&#xff1a;在所有测试数据中&#xff0c;具有某种共同特征的数据集合进行划分 分类&#xff1a;有效等价类&#xff1a;满足需求的数据集合 无效等价类&#xff1a;不满足需求的数据集合 步…

【Linux】Git - 新手入门

文章目录 1. git 版本控制器 - 该如何理解&#xff1f;2. git / gitee / github 区别&#xff1f;3. Linux 中 git 的使用3.1 安装 git3.2 使用 github 新建远端仓库3.2.1 账号注册3.2.2 创建代码仓库3.2.3 克隆仓库到本地3.2.4 .gitignore 文件 3.3 使用 git 提交代码到 githu…

刷题02 数组easy

1752.数组是否能经排序和轮转得到 如果整个数组非递减&#xff0c;返回true&#xff0c;如果只有两个子数列非递减&#xff0c;并且两个子序列之间是有序的&#xff0c;返回true。 先找到第一个不满足非递减的位置i&#xff0c;如果inumsize&#xff0c;说明整个数组非递减。否…

LeetCode 0082.删除排序链表中的重复元素 II:模拟

【LetMeFly】82.删除排序链表中的重复元素 II&#xff1a;模拟 力扣题目链接&#xff1a;https://leetcode.cn/problems/remove-duplicates-from-sorted-list-ii/ 给定一个已排序的链表的头 head &#xff0c; 删除原始链表中所有重复数字的节点&#xff0c;只留下不同的数字…

1【Linux】入门 (权限的理解||umask||粘滞位||cc++程序的翻译过程||解释性语言和编译性语言的区别)

无废话&#xff0c;全干货 一.权限 1.基本权限 读&#xff08;r&#xff09;&#xff1a; Read 对文件而言&#xff0c;具有读取文件内容的权限&#xff1b;对目录来说&#xff0c;具有浏览该目录信息的权限。 写&#xff08;w&#xff09;&#xff1a; Write 对文件而言&…

线程池执行流程详解,主要介绍其参数,以及执行过程中核心线程,队列,最大线程等的执行添加顺序,以及为什么这么做

线程池执行流程主要涉及到以下几个关键组件&#xff1a;核心线程数、任务队列、最大线程数以及拒绝策略等。以下是线程池的工作流程详解&#xff1a; 创建线程池&#xff1a; 当我们创建一个线程池时&#xff0c;需要指定几个核心参数&#xff0c;如corePoolSize(核心线程数)、…

数据结构学习 jz30 包含 min 函数的栈

关键词&#xff1a;排序 题目&#xff1a;最小栈 方法一&#xff1a;在记录这个数的同时&#xff0c;记录目前的最小值。看了提示才写出来的。 方法二&#xff1a;辅助栈。辅助栈保持非严格递减。看了k神的答案。 方法一&#xff1a; 一开始没想到怎么存最小&#xff0c;看…

【野火i.MX6NULL开发板】Linux系统下的Hello World

0、前言 参考资料&#xff1a; 《野火 Linux 基础与应用开发实战指南基于 i.MX6ULL 系列》PDF 第25章 本章比较抽象&#xff0c;涉及理论知识&#xff0c;不明白&#xff0c;可以看看视频讲解&#xff1a; https://www.bilibili.com/video/BV1JK4y1t7io?p29&vd_sourcef…

C++中JSON与string格式互转

1、JSON-》string 操作步骤&#xff1a; 1、在C中新建一个json对象并赋值&#xff0c;然后将其转给char *data。 2、在使用 #include <json.h> 头文件时&#xff0c;通常是使用第三方库 jsoncpp。由于它不是标准库的一部分&#xff0c;所以需要从官网http://jsoncpp.sou…