Pytorch笔记之回归

文章目录

  • 前言
  • 一、导入库
  • 二、数据处理
  • 三、构建模型
  • 四、迭代训练
  • 五、结果预测
  • 总结


前言

以线性回归为例,记录Pytorch的基本使用方法。


一、导入库

import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.autograd import Variable # 定义求导变量
from torch import nn, optim # 定义网络模型和优化器

二、数据处理

将数据类型转为tensor,第一维度变为batch_size

# 构建数据
x = np.random.rand(100)
noise = np.random.normal(0, 0.01, x.shape)
y = 0.1 * x + 0.2 + noise
# 数据处理
x_data = torch.FloatTensor(x.reshape(-1, 1))
y_data = torch.FloatTensor(y.reshape(-1, 1))
inputs = Variable(x_data)
target = Variable(y_data)

三、构建模型

1、继承nn.Module,定义一个线性回归模型。在__init__中定义连接层,定义前向传播的方法
2、实例化模型,定义损失函数与优化器

# 继承模型
class LinearRegression(nn.Module):def __init__(self):super().__init__()self.fc = nn.Linear(1, 1)def forward(self, x):out = self.fc(x)return out
# 定义模型
print('模型参数')
model = LinearRegression()
mse_loss = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)
for name, param in model.named_parameters():print('{}:{}'.format(name, param))

四、迭代训练

1、梯度清零:optimizer.zero_grad()
2、反向传播计算梯度值:loss.backward()
3、执行参数更新:optimizer.step()
循环迭代,定期输出损失值

print('损失值')
for i in range(1001):out = model.forward(inputs)loss = mse_loss(out, target)optimizer.zero_grad()loss.backward()optimizer.step()if i % 200 == 0:print(i, loss.item())

五、结果预测

绘制样本的散点图与预测值的折线图

print('结果预测')
y_pred = model(x_data)
plt.plot(x, y, 'b.')
plt.plot(x, y_pred.data.numpy(), 'r-')
plt.show()


总结

使用Pytorch进行训练主要的三步:
(1)数据处理:将数据维度转换为(batch, *),数据类型转换为可训练的tensor;
(2)构建模型:继承nn.Module,定义连接层与运算方法,实例化,定义损失函数与优化器;
(3)迭代训练:循环迭代,依次执行梯度清零、梯度计算、参数更新。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/99874.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

web基础及http协议

web基础 全称 world wide web 全球广域网也就是万维网 web1.0 只能看 web2.0 页面交互:静态页面和动态页面 静态页面url:文本文件,可以修改,一般以html .htm保存的文本文件。网站的基础。静态页面和后台数据库没有任何交互不包含…

阿里云/腾讯云国际站:私服服务器:什么是游戏虚拟服务器及用途讲解?

游戏虚拟服务器是一种新兴的技术,它可以为玩家提供更好的游戏体验。私服服务器它可以将游戏服务器的负载分散到多台服务器上,从而提高游戏的流畅度和稳定性。此外,游戏虚拟服务器还可以提供更多的游戏功能,比如游戏聊天室、游戏排…

如何下载IEEE Journal/Conference/Magazine的LaTeX/Word模板

当你准备撰写一篇学术论文或会议论文时,使用IEEE(电气和电子工程师协会)的LaTeX或Word模板是一种非常有效的方式,它可以帮助你确保你的文稿符合IEEE出版的要求。无论你是一名研究生生或一名资深学者,本教程将向你介绍如…

一站式工单系统哪家好?一站式工单系统有什么特点?

伴随着高新科技的不断发展和行业竞争的加重,对于一站式工单系统这一类的公司服务系统软件有着越来越多的流程规定和可靠性的要求。一个比较完善的智能化一站式工单系统包含众多的流程,并适用更广泛性的企业信息化,接下来我们将一起看看一站式…

如何才能在Ubuntu系统部署RabbitMQ服务器并公网访问

在Ubuntu系统上部署RabbitMQ服务器并公网访问,可以按照以下步骤进行: 安装RabbitMQ服务器: 在终端中输入以下命令安装RabbitMQ服务器: sudo apt-get update sudo apt-get install rabbitmq-server启动RabbitMQ服务器: …

数字孪生与GIS数据为何高度互补?二者融合后能达到什么样的效果?

山海鲸可视化作为一款数字孪生软件,在GIS的融合方面处于业内领先水平,那么为什么一款数字孪生软件要花费巨大的精力,去实现GIS的融合,实现后又能达到什么样的效果呢?下面就让我们来一探究竟。 一、为什么数字孪生需要…

iview表格 异步修改列数据卡顿 滚动条失效

使用表格row-key属性 将row-key属性设置为true <Table ref"table" border :row-key"true" :columns"tableColumns" :loading"loading":data"tableData"></Table>

Java之UDP,TCP的详细解析

练习四&#xff1a;文件名重复 public class UUIDTest { public static void main(String[] args) { String str UUID.randomUUID().toString().replace("-", ""); System.out.println(str);//9f15b8c356c54f55bfcb0ee3023fce8a } } public class Client…

【K8S系列】深入解析k8s 网络插件—kube-router

序言 做一件事并不难&#xff0c;难的是在于坚持。坚持一下也不难&#xff0c;难的是坚持到底。 文章标记颜色说明&#xff1a; 黄色&#xff1a;重要标题红色&#xff1a;用来标记结论绿色&#xff1a;用来标记论点蓝色&#xff1a;用来标记论点 在现代容器化应用程序的世界中…

java最优建树算法

建树算法 树的数据结构 {"code": "1111","name": "","parentcode": "0000","children": null }, {"code": "2222","name": "","parentcode": &q…

面试算法19:最多删除一个字符得到回文

题目 给定一个字符串&#xff0c;请判断如果最多从字符串中删除一个字符能不能得到一个回文字符串。例如&#xff0c;如果输入字符串"abca"&#xff0c;由于删除字符’b’或’c’就能得到一个回文字符串&#xff0c;因此输出为true。 分析 本题还是从字符串的两端…

Tomcat隔离web原理和热加载热部署

Tomcat 如何打破双亲委派机制 Tomcat 的自定义类加载器 WebAppClassLoader 打破了双亲委派机制&#xff0c;它首先自己尝试去加载某个类&#xff0c;如果找不到再代理给父类加载器&#xff0c;其目的是优先加载 Web 应用自己定义的类。具体实现就是重写 ClassLoader 的两个方法…

SpringBoot全局异常处理 | Java

⭐简单说两句⭐ 作者&#xff1a;后端小知识 CSDN个人主页&#xff1a;后端小知识 &#x1f50e;GZH&#xff1a;后端小知识 &#x1f389;欢迎关注&#x1f50e;点赞&#x1f44d;收藏⭐️留言&#x1f4dd; 文章目录 ✨SpringBoot全局异常处理 | Java&#x1f3a8;什么是全局…

CAXA汇总尺寸避免混乱

把单图的尺寸名称全部重命名

Android 13 骁龙相机点击拍照流程分析(二)——点击拍照到存入相册

一.前言 本篇是在Android 13 骁龙相机点击拍照流程分析(一)——点击拍照到更新到左下角缩略图文章的基础上进行延申的,前面的预览、点击拍照的过程参考第一篇:Android 13 骁龙相机点击拍照流程分析(一)——点击拍照到更新到左下角缩略图-CSDN博客 二.生成图片并保存 从第…

KUKA机器人通过直接输入法设定负载数据和附加负载数据的具体操作

KUKA机器人通过直接输入法设定负载数据和附加负载数据的具体操作 设置背景色: 工具负载数据 工具负载的定义: 工具负载数据是指所有装在机器人法兰上的负载。它是另外装在机器人上并由机器人一起移动的质量。需要输入的值有质量、重心位置、质量转动惯量以及所属的主惯性轴。…

爱尔眼科角膜塑形镜验配超百万,全力做好“角塑镜把关人”

你知道吗?过去的2022年&#xff0c;我国儿童青少年总体近视率为53.6%&#xff0c;其中6岁儿童为14.5%&#xff0c;小学生为36%&#xff0c;初中生为71.6%&#xff0c;高中生为81%①。儿童青少年眼健康问题俨然成为全社会关心的热点与痛点&#xff0c;牵动着每一个人的神经。 好…

给手机上液冷?谈谈华为Mate 60系列手机专属黑科技—— “微泵液冷”手机壳

最近&#xff0c;有一个手机配件吸引了我的注意——华为的微泵液冷壳。 简单来说&#xff0c;就是在手机壳里装了无线充电微泵&#xff0c;为手机实现外置水冷的能力。让手机壳在“外观装饰”和“防摔保护”的功能性上额外加了一个“降温提性能”的作用。 接下来&#xff0c;本…

深度学习DAY3:神经网络训练常见算法概述

梯度下降法&#xff08;Gradient Descent&#xff09;&#xff1a; 这是最常见的神经网络训练方法之一。它通过计算损失函数对权重的梯度&#xff0c;并沿着梯度的反方向更新权重&#xff0c;从而逐步减小损失函数的值。梯度下降有多个变种&#xff0c;包括随机梯度下降&#…

软考-高级系统架构师复习知识点总结

软件架构概念 软件架构定义 系统的一个或多个结构。结构中包含软件的构件、构件的外部可见属性、构件之间的相互关系。 软件架构设计与生命周期 需求分析阶段:问题空间,根据需求模型构件SA模型、保证模型转换的可追踪性 设计阶段 SA模型的描述: SA的基本概念:构件、…