pytorch学习--第一个模型(线性模型)

目标

我们想通过随机初始化的参数 ω , b \omega ,b ω,b能在迭代过程中使预测值和目标值能无限接近
y = ω x + b y=\omega x+b y=ωx+b

定义数据

x = torch.rand([60, 1])*10
y = x*2 + torch.randn(60,1)

构建模型

利用pytorch中的nn.Module
想要构建模型时,继承这个类即可
一些重写nn.Module类时的注意事项
(1)一般把网络中具有可学习参数的层(如全连接层、卷积层等)放在构造函数__init__()中;
(2)一般把不具有可学习参数的层(如ReLU、dropout、BatchNormanation层)可放在构造函数中,也可不放在构造函数中,如果不放在构造函数__init__里面,则在forward方法里面可以使用nn.functional来代替
(3)forward方法是必须要重写的,它是实现模型的功能,实现各个层之间的连接关系的核心。

from torch import nn
class Lr(nn.Module):def __init__(self):super(Lr, self).__init__()  #继承父类init的参数self.linear = nn.Linear(1, 1) #只有线性层(全链接层)def forward(self, x):out = self.linear(x)#输出return out

输出的数量nn.Linear(in_features, out_features);nn.Linear(1, 1)这里的参数易知我们通过方程得到的最终是一列数

# 实例化模型
model = Lr()
# 传入数据,计算结果
predict = model(x)

优化器

1、优化器主要是在模型训练阶段对模型可学习参数进行更新, 常用优化器有 SGD,RMSprop,Adam等
2、优化器初始化时传入传入模型的可学习参数,以及其他超参数如 lr,momentum等
3、在训练过程中先调用 optimizer.zero_grad() 清空梯度,再调用 loss.backward() 反向传播,最后调用 optimizer.step()更新模型参数
4、参数可以使用model.parameters()来获取,获取模型中所有requires_grad=True的参数

optimizer = optim.SGD(model.parameters(), lr=1e-3) #1. 实例化,1e-3也可以写成0.001
optimizer.zero_grad() #2. 梯度置为0
loss.backward() #3. 计算梯度
optimizer.step()  #4. 更新参数的值

损失函数

torch中有很多损失函数

1、均方误差:nn.MSELoss(),常用于回归问题

2、交叉熵损失:nn.CrossEntropyLoss(),常用于分类问题

criterion = nn.MSELoss() # 实例化损失函数

训练模型

1、定义一个epoch,代表需要将所有数据训练epoch个轮次
2、数据传入模型,获取预测值
3、将预测值和目标值传入损失函数,计算损失
4、优化器的梯度归零,在每次更新参数中必须进行此步骤,否则梯度会一直累加
5、计算梯度,此步骤在4之后进行
6、更新梯度,参数随之更新
7、(可选)在训练过程中每隔一段时间打印下损失,观察收敛速度

#训练模型
for i in range(30000):out = model(x)  # 3.1 获取预测值loss = criterion(y, out)  # 3.2 计算损失optimizer.zero_grad()  # 3.3 梯度归零loss.backward()  # 3.4 计算梯度optimizer.step()  # 3.5 更新梯度if i % 300 == 0:print('Epoch[{}/{}], loss: {:.6f}'.format(i, 30000, loss.data))

模型测试

在模型的测试中,我们一般会使用测试集来评估训练得到的模型,这时候我们不需要梯度相关的操作,只需要将数据通过模型,得到损失、精确率等即可。测试中有以下需要注意:

model.eval()  # 设置模型为评估模式,即预测模式
predict = model(x)

绘图

predict = predict.data.numpy()
plt.scatter(x.data.numpy(), y.data.numpy(), c="r")
plt.plot(x.data.numpy(), predict)
plt.show()

在GPU上运行

判断GPU是否可用torch.cuda.is_available()

1、torch.device(“cuda:0” if torch.cuda.is_available() else “cpu”)

device(type=‘cuda’, index=0) #使用gpu
device(type=‘cpu’) #使用cpu

2、把模型参数和input数据转化为cuda的支持类型

model.to(device)
x_true.to(device)

3、在GPU上计算结果也为cuda的数据类型,需要转化为numpy或者torch的cpu的tensor类型

predict = predict.cpu().detach().numpy()
detach()的效果和data的相似,但是detach()是深拷贝,data是取值,是浅拷贝

完整代码

import torch
from torch import nn
from torch import optim
import numpy as np
from matplotlib import pyplot as plt# 1. 定义数据
x = torch.rand([60, 1])*10
y = x*2 + torch.randn(60,1)# 2 .定义模型
class Lr(nn.Module):def __init__(self):super(Lr, self).__init__()self.linear = nn.Linear(1, 1)def forward(self, x):out = self.linear(x)return out# 2. 实例化模型,loss,和优化器
model = Lr()
# 损失函数
criterion = nn.MSELoss()
# 优化器
optimizer = optim.SGD(model.parameters(), lr=1e-3)
# 3. 训练模型
for i in range(30000):out = model(x)  # 3.1 获取预测值loss = criterion(y, out)  # 3.2 计算损失optimizer.zero_grad()  # 3.3 梯度归零loss.backward()  # 3.4 计算梯度optimizer.step()  # 3.5 更新梯度if i % 300 == 0:print('Epoch[{}/{}], loss: {:.6f}'.format(i, 30000, loss.data))# 4. 模型评估
model.eval()  # 设置模型为评估模式,即预测模式
predict = model(x)
predict = predict.data.numpy()
plt.scatter(x.data.numpy(), y.data.numpy(), c="r")
plt.plot(x.data.numpy(), predict)
plt.show()

在这里插入图片描述

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/3709.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

(四)「消息队列」之 RabbitMQ 路由(使用 .NET 客户端)

0、引言 先决条件 本教程假设 RabbitMQ 已安装并且正在 本地主机 的标准端口(5672)上运行。如果您使用了不同的主机、端口或凭证,则要求调整连接设置。 获取帮助 如果您在阅读本教程时遇到问题,可以通过邮件列表或者 RabbitMQ 社区…

Meta发布升级大模型LLaMA 2:开源可商用

论文地址:https://ai.meta.com/research/publications/llama-2-open-foundation-and-fine-tuned-chat-models/ Github地址:https://github.com/facebookresearch/llama LLaMA 2介绍 Meta之前发布自了半开源的大模型LLaMA,自从LLaMA发布以来…

30、spring是什么

spring是什么? 1、轻量级的开源的J2EE框架。它是一个容器框架,用来装javabean (java对象),中间层框架(万能胶)可以起一个连接作用,比如说把Struts和hibernate粘合在一起运用,可以让我们的企业开发更快、更简洁 2、Spring是一个…

win11安装.net framework 3.5

背景 安装专业软件时需要.net framework 3.54错误代码:0x80072f8f控制面板安装失败microsoft安装失败DISM安装依旧失败 解决方法 下载win11 iso文件 下载网址 note:不能挂v p niso文件装载后,找到*\sources\sxs路径复制文件夹到新建路径,这…

C# WPF实现动画渐入暗黑明亮主题切换效果

C# WPF实现动画渐入暗黑明亮主题切换效果 效果图如下最近在Bilibili的桌面端看到一个黑白主题切换的效果感觉,挺有意思。于是我使用WPF尝试实现该效果。 主要的切换效果,基本实现不过还存在一些小瑕疵,比如字体等笔刷不能跟随动画进入进行切…

Docker简介

Docker简介 文章目录 Docker简介一、Docker1.什么是docker?2.容器引擎3.容器和虚拟机的区别4.namespace(命名空间)5.三大容器核心概念镜像容器仓库 二、Docker镜像操作1.搜索镜像2.获取镜像镜像加速下载 3.查看本地下载镜像4.获取镜像详细信息5.为本地镜…

代码随想录第二十二天

代码随想录第二十二天 Leetcode 235. 二叉搜索树的最近公共祖先Leetcode 701. 二叉搜索树中的插入操作Leetcode 450. 删除二叉搜索树中的节点 Leetcode 235. 二叉搜索树的最近公共祖先 题目链接: 二叉搜索树的最近公共祖先 自己的思路:乍一看和二叉树的最近公共祖先类似&#…

SQL 上升的温度

197 上升的温度 SQL架构 表: Weather ---------------------- | Column Name | Type | ---------------------- | id | int | | recordDate | date | | temperature | int | ---------------------- id 是这个表的主键 该表包含特定日期的温度信息 编写一个 SQL …

事务@transactional执行产生重复数据

背景 系统设计之初,每次来新请求,业务层会先查询数据库,判断是否存在相同的id数据(id是唯一标识产品的),有则返回当前数据库查到的数据,根据数据决定下一步动作,没有则认为是初次请…

销售自动化如何提高团队生产力?从这5个方面发力

任何用于减少人工劳动和缩短销售流程相关任务时间的技术,都可定义为销售自动化。 对于忙碌的销售人员来说,流程自动化是真正的救星。它可以使他们的工作简化30%,让他们更专注于创收任务。这将显著提高团队的工作效率,并带来许多其…

Kubernetes - kubeadm部署(二)

kubeadm 1. docker 和 containerd1.1 命令区分1.2 常用命令1.3 关于crictl 2. 示例2.1 部署nginx 先从一个报错开始吧 [rootnode-129 kubernetes]# crictl ps WARN[0000] runtime connect using default endpoints: [unix:///var/run/dockershim.sock unix:///run/containerd/…

滑动奇异频谱分析:数据驱动的非平稳信号分解工具(Matlab代码实现)

💥💥💞💞欢迎来到本博客❤️❤️💥💥 🏆博主优势:🌞🌞🌞博客内容尽量做到思维缜密,逻辑清晰,为了方便读者。 ⛳️座右铭&a…

django中的model的一些笔记

class WarehouseRecordDetail(models.Model):warehouse_record models.ForeignKey(WarehouseRecord, verbose_nameu"出入库记录",related_namewarehouse_record_detail_ref_warehouse_record, on_deletemodels.PROTECT)warehouse_stock_record models.ForeignKey(W…

师承AI世界新星|7天获新加坡南洋理工大学访学邀请函

能够拜师在“人工智能10大新星”名下,必定可以学习到前沿技术,受益良多,本案例中的C老师无疑就是这个幸运儿。我们只用了7天时间就取得了这位AI新星导师的邀请函,最终C老师顺利获批CSC,如愿出国。 C老师背景&#xff1…

Leetcode刷题4

⼆叉树、BFS、堆、Top K、⼆叉搜索树、模拟、图算法 一、二叉树 二叉树的前序中序后序 二叉树节点定义 为了方便演示,我们先定义一个二叉树节点类。 class TreeNode:def __init__(self, val0, leftNone, rightNone):self.val valself.left leftself.right r…

数据结构C

数据结构 表 线性表 线性表是 具有相同数据类型的 n个数据元素 的有限序列(有次序),其中n为表长,当n0时 线性表是一个空表。 若用L命名线性表,则其一般表示为 L {a1,a2,…,ai,ai1,…,an} 几个概念: …

Android ViewGroup onDraw为什么没调用

ViewGroup,它本身并没有任何可画的东西,它是一个透明的控件,因些并不会触发onDraw,但是你现在给LinearLayout设置一个背景色,其实这个背景色不管你设置成什么颜色,系统会认为,这个LinearLayout上…

[回馈]ASP.NET Core MVC开发实战之商城系统(开篇)

在编程方面,从来都是实践出真知,书读百遍其义自见,所以实战是最好的提升自己编程能力的方式。 前一段时间,写了一些实战系列文章,如: ASP.NET MVC开发学生信息管理系统VueAntdvAsp.net WebApi开发学生信息…

R语言的水文、水环境模型优化技术及快速率定方法与多模型案例实践

在水利、环境、生态、机械以及航天等领域中,数学模型已经成为一种常用的技术手段。同时,为了提高模型的性能,减小模型误用带来的风险;模型的优化技术也被广泛用于模型的使用过程。模型参数的快速优化技术不但涉及到优化本身而且涉…

Python中的break和continue语句应用举例

Python中的break和continue语句应用举例 在进行Python编程时候,有时需要,对循环中断或跳过某部分语句,此时常会用到break语句或continue语句。本文将通过实际例子阐述这两个语句的用法。 1.break语句 break语句是实现在某个地方中断循环&a…