Pytorch笔记之回归

文章目录

  • 前言
  • 一、导入库
  • 二、数据处理
  • 三、构建模型
  • 四、迭代训练
  • 五、结果预测
  • 总结


前言

以线性回归为例,记录Pytorch的基本使用方法。


一、导入库

import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.autograd import Variable # 定义求导变量
from torch import nn, optim # 定义网络模型和优化器

二、数据处理

将数据类型转为tensor,第一维度变为batch_size

# 构建数据
x = np.random.rand(100)
noise = np.random.normal(0, 0.01, x.shape)
y = 0.1 * x + 0.2 + noise
# 数据处理
x_data = torch.FloatTensor(x.reshape(-1, 1))
y_data = torch.FloatTensor(y.reshape(-1, 1))
inputs = Variable(x_data)
target = Variable(y_data)

三、构建模型

1、继承nn.Module,定义一个线性回归模型。在__init__中定义连接层,定义前向传播的方法
2、实例化模型,定义损失函数与优化器

# 继承模型
class LinearRegression(nn.Module):def __init__(self):super().__init__()self.fc = nn.Linear(1, 1)def forward(self, x):out = self.fc(x)return out
# 定义模型
print('模型参数')
model = LinearRegression()
mse_loss = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)
for name, param in model.named_parameters():print('{}:{}'.format(name, param))

四、迭代训练

1、梯度清零:optimizer.zero_grad()
2、反向传播计算梯度值:loss.backward()
3、执行参数更新:optimizer.step()
循环迭代,定期输出损失值

print('损失值')
for i in range(1001):out = model.forward(inputs)loss = mse_loss(out, target)optimizer.zero_grad()loss.backward()optimizer.step()if i % 200 == 0:print(i, loss.item())

五、结果预测

绘制样本的散点图与预测值的折线图

print('结果预测')
y_pred = model(x_data)
plt.plot(x, y, 'b.')
plt.plot(x, y_pred.data.numpy(), 'r-')
plt.show()


总结

使用Pytorch进行训练主要的三步:
(1)数据处理:将数据维度转换为(batch, *),数据类型转换为可训练的tensor;
(2)构建模型:继承nn.Module,定义连接层与运算方法,实例化,定义损失函数与优化器;
(3)迭代训练:循环迭代,依次执行梯度清零、梯度计算、参数更新。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/99874.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

web基础及http协议

web基础 全称 world wide web 全球广域网也就是万维网 web1.0 只能看 web2.0 页面交互:静态页面和动态页面 静态页面url:文本文件,可以修改,一般以html .htm保存的文本文件。网站的基础。静态页面和后台数据库没有任何交互不包含…

如何下载IEEE Journal/Conference/Magazine的LaTeX/Word模板

当你准备撰写一篇学术论文或会议论文时,使用IEEE(电气和电子工程师协会)的LaTeX或Word模板是一种非常有效的方式,它可以帮助你确保你的文稿符合IEEE出版的要求。无论你是一名研究生生或一名资深学者,本教程将向你介绍如…

一站式工单系统哪家好?一站式工单系统有什么特点?

伴随着高新科技的不断发展和行业竞争的加重,对于一站式工单系统这一类的公司服务系统软件有着越来越多的流程规定和可靠性的要求。一个比较完善的智能化一站式工单系统包含众多的流程,并适用更广泛性的企业信息化,接下来我们将一起看看一站式…

数字孪生与GIS数据为何高度互补?二者融合后能达到什么样的效果?

山海鲸可视化作为一款数字孪生软件,在GIS的融合方面处于业内领先水平,那么为什么一款数字孪生软件要花费巨大的精力,去实现GIS的融合,实现后又能达到什么样的效果呢?下面就让我们来一探究竟。 一、为什么数字孪生需要…

iview表格 异步修改列数据卡顿 滚动条失效

使用表格row-key属性 将row-key属性设置为true <Table ref"table" border :row-key"true" :columns"tableColumns" :loading"loading":data"tableData"></Table>

【K8S系列】深入解析k8s 网络插件—kube-router

序言 做一件事并不难&#xff0c;难的是在于坚持。坚持一下也不难&#xff0c;难的是坚持到底。 文章标记颜色说明&#xff1a; 黄色&#xff1a;重要标题红色&#xff1a;用来标记结论绿色&#xff1a;用来标记论点蓝色&#xff1a;用来标记论点 在现代容器化应用程序的世界中…

Tomcat隔离web原理和热加载热部署

Tomcat 如何打破双亲委派机制 Tomcat 的自定义类加载器 WebAppClassLoader 打破了双亲委派机制&#xff0c;它首先自己尝试去加载某个类&#xff0c;如果找不到再代理给父类加载器&#xff0c;其目的是优先加载 Web 应用自己定义的类。具体实现就是重写 ClassLoader 的两个方法…

SpringBoot全局异常处理 | Java

⭐简单说两句⭐ 作者&#xff1a;后端小知识 CSDN个人主页&#xff1a;后端小知识 &#x1f50e;GZH&#xff1a;后端小知识 &#x1f389;欢迎关注&#x1f50e;点赞&#x1f44d;收藏⭐️留言&#x1f4dd; 文章目录 ✨SpringBoot全局异常处理 | Java&#x1f3a8;什么是全局…

KUKA机器人通过直接输入法设定负载数据和附加负载数据的具体操作

KUKA机器人通过直接输入法设定负载数据和附加负载数据的具体操作 设置背景色: 工具负载数据 工具负载的定义: 工具负载数据是指所有装在机器人法兰上的负载。它是另外装在机器人上并由机器人一起移动的质量。需要输入的值有质量、重心位置、质量转动惯量以及所属的主惯性轴。…

爱尔眼科角膜塑形镜验配超百万,全力做好“角塑镜把关人”

你知道吗?过去的2022年&#xff0c;我国儿童青少年总体近视率为53.6%&#xff0c;其中6岁儿童为14.5%&#xff0c;小学生为36%&#xff0c;初中生为71.6%&#xff0c;高中生为81%①。儿童青少年眼健康问题俨然成为全社会关心的热点与痛点&#xff0c;牵动着每一个人的神经。 好…

给手机上液冷?谈谈华为Mate 60系列手机专属黑科技—— “微泵液冷”手机壳

最近&#xff0c;有一个手机配件吸引了我的注意——华为的微泵液冷壳。 简单来说&#xff0c;就是在手机壳里装了无线充电微泵&#xff0c;为手机实现外置水冷的能力。让手机壳在“外观装饰”和“防摔保护”的功能性上额外加了一个“降温提性能”的作用。 接下来&#xff0c;本…

MySQL之MHA集群

MHA概述 什么是 MHA MHA&#xff08;MasterHigh Availability&#xff09;是一套优秀的MySQL高可用环境下故障切换和主从复制的软件。 MHA 的出现就是解决MySQL 单点故障的问题。 MySQL故障切换过程中&#xff0c;MHA能做到0-30秒内自动完成故障切换操作。 MHA能在故障切换…

iPhone手机上使用的定时提醒APP是哪个

在日常喧闹的生活和工作中&#xff0c;琐碎的任务会像喷泉一样突涌而至&#xff0c;如不及时规划&#xff0c;我们将陷入手足无措的境地。而想要让各项工作任务按时完成&#xff0c;我们可以借助一些比较好用的时间提醒软件来督促各项任务。 就拿常用的iPhone手机来讲&#xf…

Flutter的Platform介绍-跨平台开发,如何根据不同平台创建不同UI和行为

文章目录 Flutter跨平台概念介绍跨平台开发平台相关性Platform ChannelPlatform-specific UIPlatform Widgets 如何判断当前是什么平台实例 Platform 类介绍获取当前平台的名称检查当前平台其他属性 利用flutter设计跨Android和IOS平台应用的技巧1. 遵循平台的设计准则2. 使用平…

回归算法全解析!一文读懂机器学习中的回归模型

目录 一、引言回归问题的重要性文章目的和结构概览 二、回归基础什么是回归问题例子&#xff1a; 回归与分类的区别例子&#xff1a; 回归问题的应用场景例子&#xff1a; 三、常见回归算法3.1 线性回归数学原理代码实现输出例子&#xff1a; 3.2 多项式回归数学原理代码实现输…

Ubuntu磁盘满了,导致黑屏

前言 &#xff08;1&#xff09;最近要玩Milk-V Duo&#xff0c;配置环境过程中&#xff0c;发现磁盘小了。于是退出虚拟机&#xff0c;扩大Ubuntu大小&#xff0c;重新开机&#xff0c;发现无法进入Ubuntu界面。 &#xff08;2&#xff09;查了很久&#xff0c;后面发现是磁盘…

这短短 6 行代码你能数出几个bug?

前言&#xff1a;本文仅仅只是分享笔者一年前见到的诡异代码&#xff0c;大家可以看看乐子&#xff0c;随便数一数一共有多少个bug&#xff0c;这数bug多少还是要点水平的 在初学编程的时候&#xff0c;写的第一个代码大多都是 hello world&#xff0c;可是就算是 hello world…

Windows安装Node.js

1、Node.js介绍 ①、Node.js简介 Node.js是一个开源的、跨平台的JavaScript运行环境&#xff0c;它允许开发者使用JavaScript语言来构建高性能的网络应用程序和服务器端应用。Node.js的核心特点包括&#xff1a; 1. 事件驱动: Node.js采用了事件驱动的编程模型&#xff0c;通…

04训练——基于YOLO V8的自定义数据集训练——在windows环境下使用pycharm做训练-1总体步骤

在上文中,笔者介绍了使用google公司提供的免费GPU资源colab来对大量的自定义数据集进行模型训练。该方法虽然简单好用,但是存在以下几方面的短板问题: 一是需要通过虚拟服务器做为跳板机来访问,总体操作起来非常繁杂。 二是需要将大量的数据上传缓慢,管理和使用非常不友…

FutureTask和CompletableFuture的模拟使用

模拟了查询耗时操作&#xff0c;并使用FutureTask和CompletableFuture分别获取计算结果&#xff0c;统计执行时长 package org.alllearn.futurtask;import com.google.common.base.Stopwatch; import com.google.common.collect.Lists; import lombok.AllArgsConstructor; imp…