用Pytorch实现线性回归模型

目录

  • 回顾
  • Pytorch实现
    • 步骤
    • 1. 准备数据
    • 2. 设计模型
      • class LinearModel
      • 代码
    • 3. 构造损失函数和优化器
    • 4. 训练过程
    • 5. 输出和测试
    • 完整代码
  • 练习

回顾

前面已经学习过线性模型相关的内容,实现线性模型的过程并没有使用到Pytorch。
这节课主要是利用Pytorch实现线性模型。
学习器训练:

  • 确定模型(函数)
  • 定义损失函数
  • 优化器优化(SGD)

之前用过Pytorch的Tensor进行Forward、Backward计算。
现在利用Pytorch框架来实现。

Pytorch实现

步骤

  1. 准备数据集
  2. 设计模型(计算预测值y_hat):从nn.Module模块继承
  3. 构造损失函数和优化器:使用PytorchAPI
  4. 训练过程:Forward、Backward、update

1. 准备数据

在PyTorch中计算图是通过mini-batch形式进行,所以X、Y都是多维的Tensor。
在这里插入图片描述

import torch
x_data = torch.Tensor([[1.0], [2.0], [3.0]])
y_data = torch.Tensor([[2.0], [4.0], [6.0]])

2. 设计模型

在之前讲解梯度下降算法时,我们需要自己计算出梯度,然后更新权重。
在这里插入图片描述
而使用Pytorch构造模型,重点时在构建计算图和损失函数上。
在这里插入图片描述

class LinearModel

通过构造一个 class LinearModel类来实现,所有的模型类都需要继承nn.Module,这是所有神经忘了模块的基础类。
class LinearModel这种定义的模型类必须包含两个部分:

  • init():构造函数,进行初始化。
    def __init__(self):super(LinearModel, self).__init__()#调用父类构造函数,不用管,照着写。# torch.nn.Linear(in_featuers, in_featuers)构造Linear类的对象,其实就是实现了一个线性单元self.linear = torch.nn.Linear(1, 1)

在这里插入图片描述

  • forward():进行前馈计算
    (backward没有被写,是因为在这种模型类里面会自动实现)

Class nn.Linear 实现了magic method call():它使类的实例可以像函数一样被调用。通常会调用forward()。

    def forward(self, x):y_pred = self.linear(x)#调用linear对象,输入x进行预测return y_pred

代码

class LinearModel(torch.nn.Module):def __init__(self):super(LinearModel, self).__init__()#调用父类构造函数,不用管,照着写。# torch.nn.Linear(in_featuers, in_featuers)构造Linear类的对象,其实就是实现了一个线性单元self.linear = torch.nn.Linear(1, 1)def forward(self, x):y_pred = self.linear(x)#调用linear对象,输入x进行预测return y_predmodel = LinearModel()#实例化LinearModel()

3. 构造损失函数和优化器

采用MSE作为损失函数

torch.nn.MSELoss(size_average,reduce)

  • size_average:是否求mini-batch的平均loss。
  • reduce:降维,不用管。

在这里插入图片描述SGD作为优化器torch.optim.SGD(params, lr):

  • params:参数
  • lr:学习率

在这里插入图片描述

criterion = torch.nn.MSELoss(size_average=False)#size_average:the losses are averaged over each loss element in the batch.
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)#params:model.parameters(): w、b

4. 训练过程

  1. 预测
  2. 计算loss
  3. 梯度清零
  4. Backward
  5. 参数更新
    简化:Forward–>Backward–>更新
#4. Training Cycle
for epoch in range(100):y_pred = model(x_data)#Forward:预测loss = criterion(y_pred, y_data)#Forward:计算lossprint(epoch, loss)optimizer.zero_grad()#梯度清零loss.backward()#backward:计算梯度optimizer.step()#通过step()函数进行参数更新

5. 输出和测试

# Output weight and bias
print('w = ', model.linear.weight.item())
print('b = ', model.linear.bias.item())# Test Model
x_test = torch.Tensor([[4.0]])
y_test = model(x_test)
print('y_pred = ', y_test.data)

完整代码

import torch
#1. Prepare dataset
x_data = torch.Tensor([[1.0], [2.0], [3.0]])
y_data = torch.Tensor([[2.0], [4.0], [6.0]])#2. Design Model
class LinearModel(torch.nn.Module):def __init__(self):super(LinearModel, self).__init__()#调用父类构造函数,不用管,照着写。# torch.nn.Linear(in_featuers, in_featuers)构造Linear类的对象,其实就是实现了一个线性单元self.linear = torch.nn.Linear(1, 1)def forward(self, x):y_pred = self.linear(x)#调用linear对象,输入x进行预测return y_predmodel = LinearModel()#实例化LinearModel()# 3. Construct Loss and Optimize
criterion = torch.nn.MSELoss(size_average=False)#size_average:the losses are averaged over each loss element in the batch.
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)#params:model.parameters(): w、b#4. Training Cycle
for epoch in range(100):y_pred = model(x_data)#Forward:预测loss = criterion(y_pred, y_data)#Forward:计算lossprint(epoch, loss)optimizer.zero_grad()#梯度清零loss.backward()#backward:计算梯度optimizer.step()#通过step()函数进行参数更新# Output weight and bias
print('w = ', model.linear.weight.item())
print('b = ', model.linear.bias.item())# Test Model
x_test = torch.Tensor([[4.0]])
y_test = model(x_test)
print('y_pred = ', y_test.data)

输出结果:

85 tensor(0.2294, grad_fn=)
86 tensor(0.2261, grad_fn=)
87 tensor(0.2228, grad_fn=)
88 tensor(0.2196, grad_fn=)
89 tensor(0.2165, grad_fn=)
90 tensor(0.2134, grad_fn=)
91 tensor(0.2103, grad_fn=)
92 tensor(0.2073, grad_fn=)
93 tensor(0.2043, grad_fn=)
94 tensor(0.2014, grad_fn=)
95 tensor(0.1985, grad_fn=)
96 tensor(0.1956, grad_fn=)
97 tensor(0.1928, grad_fn=)
98 tensor(0.1900, grad_fn=)
99 tensor(0.1873, grad_fn=)
w = 1.711882472038269
b = 0.654958963394165
y_pred = tensor([[7.5025]])

可以看到误差还比较大,可以增加训练轮次,训练1000次后的结果:

980 tensor(2.1981e-07, grad_fn=)
981 tensor(2.1671e-07, grad_fn=)
982 tensor(2.1329e-07, grad_fn=)
983 tensor(2.1032e-07, grad_fn=)
984 tensor(2.0737e-07, grad_fn=)
985 tensor(2.0420e-07, grad_fn=)
986 tensor(2.0143e-07, grad_fn=)
987 tensor(1.9854e-07, grad_fn=)
988 tensor(1.9565e-07, grad_fn=)
989 tensor(1.9260e-07, grad_fn=)
990 tensor(1.8995e-07, grad_fn=)
991 tensor(1.8728e-07, grad_fn=)
992 tensor(1.8464e-07, grad_fn=)
993 tensor(1.8188e-07, grad_fn=)
994 tensor(1.7924e-07, grad_fn=)
995 tensor(1.7669e-07, grad_fn=)
996 tensor(1.7435e-07, grad_fn=)
997 tensor(1.7181e-07, grad_fn=)
998 tensor(1.6931e-07, grad_fn=)
999 tensor(1.6700e-07, grad_fn=)
w = 1.9997280836105347
b = 0.0006181497010402381
y_pred = tensor([[7.9995]])

练习

用以下这些优化器替换SGD,得到训练结果并画出损失曲线图。
在这里插入图片描述
比如说:Adam的loss图:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/625084.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

(1)(1.13) SiK无线电高级配置(六)

文章目录 前言 15 使用FTDI转USB调试线配置SiK无线电设备 16 强制启动加载程序模式 17 名词解释 前言 本文提供 SiK 遥测无线电(SiK Telemetry Radio)的高级配置信息。它面向"高级用户"和希望更好地了解无线电如何运行的用户。 15 使用FTDI转USB调试线配置SiK无线…

vue3 锚点定位 点击滚动高亮

功能描述 点击导航跳到对应模块的起始位置,并且高亮点击的导航; 滚动到相应的模块时,对应的导航也自动高亮; 效果展示 注意事项 一定要明确哪个是要滚动的盒子;滚动的高度要减去导航栏的高度;当前在导航1…

【发票识别】支持pdf、ofd、图片格式(orc、信息提取)的发票

背景 为了能够满足识别各种发票的功能,特地开发了当前发票识别的功能,当前的功能支持pdf、ofd、图片格式的发票识别,使用到的技术包括文本提取匹配、ocr识别和信息提取等相关的技术,用到机器学习和深度学习的相关技术。 体验 体…

vue知识-06

es6导入导出语法 # 做项目:肯定要写模块--导入使用 # 如果包下有个 index.js 直接导到index.js上一次即可 默认导出和导入 : export default name // 只导出变量 export default add // 只导出函数 export default {name,add} // 导出对象 export defau…

【Linux】Git - 新手入门

文章目录 1. git 版本控制器 - 该如何理解?2. git / gitee / github 区别?3. Linux 中 git 的使用3.1 安装 git3.2 使用 github 新建远端仓库3.2.1 账号注册3.2.2 创建代码仓库3.2.3 克隆仓库到本地3.2.4 .gitignore 文件 3.3 使用 git 提交代码到 githu…

LeetCode 0082.删除排序链表中的重复元素 II:模拟

【LetMeFly】82.删除排序链表中的重复元素 II:模拟 力扣题目链接:https://leetcode.cn/problems/remove-duplicates-from-sorted-list-ii/ 给定一个已排序的链表的头 head , 删除原始链表中所有重复数字的节点,只留下不同的数字…

数据结构学习 jz30 包含 min 函数的栈

关键词:排序 题目:最小栈 方法一:在记录这个数的同时,记录目前的最小值。看了提示才写出来的。 方法二:辅助栈。辅助栈保持非严格递减。看了k神的答案。 方法一: 一开始没想到怎么存最小,看…

【野火i.MX6NULL开发板】Linux系统下的Hello World

0、前言 参考资料: 《野火 Linux 基础与应用开发实战指南基于 i.MX6ULL 系列》PDF 第25章 本章比较抽象,涉及理论知识,不明白,可以看看视频讲解: https://www.bilibili.com/video/BV1JK4y1t7io?p29&vd_sourcef…

拼多多今年的校招薪资。。。

拼多多校招情况分析 关于校招情况分析,我们写过了争议巨巨巨巨大的 华为、互联网宇宙厂 字节跳动 以及如今有点高攀不起的新能源车企 比亚迪。 群里收集过小伙伴的意见,除上述大厂以外,大家最感兴趣的还是市值刚超过阿里的砍厂:拼…

transbigdata笔记:其他方法

1 出租车相关 1.1 taxigps_to_od transbigdata.taxigps_to_od(data, col[VehicleNum, Stime, Lng, Lat, OpenStatus]) 输入出租车GPS数据,提取OD信息 data出租车GPS数据col[VehicleNum, Time, Lng, Lat, OpenStatus]五列 比如GPS数据长这样: oddata …

Maven《二》-- Maven的安装与配置(亲测成功版)

目录 🐶2.1 Maven的安装条件 🐶2.2 Maven安装步骤 1. 检查本地%JAVA_HOME% 2. 解压maven 3. 配置maven的环境变量 4. 校验maven是否配置成功 5. 配置本地仓库 🐶2.3 Idea配置本地Maven软件 🐶2.1 Maven的安装条件 各个工具…

为什么要找实习以及如何更好地度过实习期

前言 在职业发展的旅程中,实习是一个至关重要的阶段。不论是在大学生涯的尾声,还是在职场新人的起步阶段,寻找实习机会都是一项关键任务。然而,为什么要找实习?这个问题背后蕴含着更深层次的意义和价值。在这篇博客中…

java SSM物资采购管理系统myeclipse开发mysql数据库springMVC模式java编程计算机网页设计

一、源码特点 java SSM物资采购管理系统是一套完善的web设计系统(系统采用SSM框架进行设计开发,springspringMVCmybatis),对理解JSP java编程开发语言有帮助,系统具有完整的源代 码和数据库,系统主要采…

burp靶场-API testing

burp靶场 1.服务端主题 1.API测试 https://portswigger.net/web-security/api-testing#top 1.1 api探测api路径,数据格式,交互方法,参数是否必选: ## 使用Burp Scanner来爬取 API https://portswigger.net/burp/vulnerabilit…

虚幻UE 材质-材质图层、材质图层混合

学习材质图层和材质图层混合的使用,便于节点扫盲。 文章目录 前言一、材质图层混合二、使用步骤总结 前言 材质混合我们之前用Bridge的插件进行混合过 而此次我们的材质混合使用UE自带的材质图层和材质图层混合来实现 一、材质图层混合 材质图层混合是一种允许将…

Github镜像加速器-FastGit

简介 FastGit 是一个对于 GitHub.com 的镜像加速器。使用共享资源为 GitHub 加速。 FastGit中文指南 # 基本使用 关于 FastGit 的使用,本质上与 git 有关。常规的面向 GitHub 的 clone 命令可能如下: git clone https://github.com/author/repo使用 F…

烟火检测/周界入侵/视频智能识别AI智能分析网关V4如何配置ONVIF摄像机接入

AI边缘计算智能分析网关V4性能高、功耗低、检测速度快,易安装、易维护,硬件内置了近40种AI算法模型,支持对接入的视频图像进行人、车、物、行为等实时检测分析,上报识别结果,并能进行语音告警播放。算法可按需组合、按…

《C++大学教程》4.25星号正方形

题目: //while循环实现int main() {int n;cout << "请输入边长&#xff1a;";cin >> n;int i 1; while (i < n){ // 控制行数int j 1;while (j < n){ // 控制列数if (i 1 || i n || j 1 || j n){cout << "*";}else{cout <…

yarn包管理器在添加、更新、删除模块时,在项目中是如何体现的

技术很久不用&#xff0c;就变得生疏起来。对npm深受其害&#xff0c;决定对yarn再整理一遍。 yarn包管理器 介绍安装yarn帮助信息最常用命令 介绍 yarn官网&#xff1a;https://yarn.bootcss.com&#xff0c;学任何技术的最新知识&#xff0c;都可以通过其对应的网站了解。无…

浏览器打印无法显示单选框选中效果

上面是原代码&#xff0c;我点击打印&#xff0c;出现打印页面&#xff0c;但单选框并未勾选中&#xff0c;我在外部放了一模一样的代码是能勾选上的&#xff0c;于是我对打印页的input单选框进行分析&#xff0c;发现他丢失了checked属性。然后通过gpt分析原因。得知了default…