pytorch学习--第一个模型(线性模型)

目标

我们想通过随机初始化的参数 ω , b \omega ,b ω,b能在迭代过程中使预测值和目标值能无限接近
y = ω x + b y=\omega x+b y=ωx+b

定义数据

x = torch.rand([60, 1])*10
y = x*2 + torch.randn(60,1)

构建模型

利用pytorch中的nn.Module
想要构建模型时,继承这个类即可
一些重写nn.Module类时的注意事项
(1)一般把网络中具有可学习参数的层(如全连接层、卷积层等)放在构造函数__init__()中;
(2)一般把不具有可学习参数的层(如ReLU、dropout、BatchNormanation层)可放在构造函数中,也可不放在构造函数中,如果不放在构造函数__init__里面,则在forward方法里面可以使用nn.functional来代替
(3)forward方法是必须要重写的,它是实现模型的功能,实现各个层之间的连接关系的核心。

from torch import nn
class Lr(nn.Module):def __init__(self):super(Lr, self).__init__()  #继承父类init的参数self.linear = nn.Linear(1, 1) #只有线性层(全链接层)def forward(self, x):out = self.linear(x)#输出return out

输出的数量nn.Linear(in_features, out_features);nn.Linear(1, 1)这里的参数易知我们通过方程得到的最终是一列数

# 实例化模型
model = Lr()
# 传入数据,计算结果
predict = model(x)

优化器

1、优化器主要是在模型训练阶段对模型可学习参数进行更新, 常用优化器有 SGD,RMSprop,Adam等
2、优化器初始化时传入传入模型的可学习参数,以及其他超参数如 lr,momentum等
3、在训练过程中先调用 optimizer.zero_grad() 清空梯度,再调用 loss.backward() 反向传播,最后调用 optimizer.step()更新模型参数
4、参数可以使用model.parameters()来获取,获取模型中所有requires_grad=True的参数

optimizer = optim.SGD(model.parameters(), lr=1e-3) #1. 实例化,1e-3也可以写成0.001
optimizer.zero_grad() #2. 梯度置为0
loss.backward() #3. 计算梯度
optimizer.step()  #4. 更新参数的值

损失函数

torch中有很多损失函数

1、均方误差:nn.MSELoss(),常用于回归问题

2、交叉熵损失:nn.CrossEntropyLoss(),常用于分类问题

criterion = nn.MSELoss() # 实例化损失函数

训练模型

1、定义一个epoch,代表需要将所有数据训练epoch个轮次
2、数据传入模型,获取预测值
3、将预测值和目标值传入损失函数,计算损失
4、优化器的梯度归零,在每次更新参数中必须进行此步骤,否则梯度会一直累加
5、计算梯度,此步骤在4之后进行
6、更新梯度,参数随之更新
7、(可选)在训练过程中每隔一段时间打印下损失,观察收敛速度

#训练模型
for i in range(30000):out = model(x)  # 3.1 获取预测值loss = criterion(y, out)  # 3.2 计算损失optimizer.zero_grad()  # 3.3 梯度归零loss.backward()  # 3.4 计算梯度optimizer.step()  # 3.5 更新梯度if i % 300 == 0:print('Epoch[{}/{}], loss: {:.6f}'.format(i, 30000, loss.data))

模型测试

在模型的测试中,我们一般会使用测试集来评估训练得到的模型,这时候我们不需要梯度相关的操作,只需要将数据通过模型,得到损失、精确率等即可。测试中有以下需要注意:

model.eval()  # 设置模型为评估模式,即预测模式
predict = model(x)

绘图

predict = predict.data.numpy()
plt.scatter(x.data.numpy(), y.data.numpy(), c="r")
plt.plot(x.data.numpy(), predict)
plt.show()

在GPU上运行

判断GPU是否可用torch.cuda.is_available()

1、torch.device(“cuda:0” if torch.cuda.is_available() else “cpu”)

device(type=‘cuda’, index=0) #使用gpu
device(type=‘cpu’) #使用cpu

2、把模型参数和input数据转化为cuda的支持类型

model.to(device)
x_true.to(device)

3、在GPU上计算结果也为cuda的数据类型,需要转化为numpy或者torch的cpu的tensor类型

predict = predict.cpu().detach().numpy()
detach()的效果和data的相似,但是detach()是深拷贝,data是取值,是浅拷贝

完整代码

import torch
from torch import nn
from torch import optim
import numpy as np
from matplotlib import pyplot as plt# 1. 定义数据
x = torch.rand([60, 1])*10
y = x*2 + torch.randn(60,1)# 2 .定义模型
class Lr(nn.Module):def __init__(self):super(Lr, self).__init__()self.linear = nn.Linear(1, 1)def forward(self, x):out = self.linear(x)return out# 2. 实例化模型,loss,和优化器
model = Lr()
# 损失函数
criterion = nn.MSELoss()
# 优化器
optimizer = optim.SGD(model.parameters(), lr=1e-3)
# 3. 训练模型
for i in range(30000):out = model(x)  # 3.1 获取预测值loss = criterion(y, out)  # 3.2 计算损失optimizer.zero_grad()  # 3.3 梯度归零loss.backward()  # 3.4 计算梯度optimizer.step()  # 3.5 更新梯度if i % 300 == 0:print('Epoch[{}/{}], loss: {:.6f}'.format(i, 30000, loss.data))# 4. 模型评估
model.eval()  # 设置模型为评估模式,即预测模式
predict = model(x)
predict = predict.data.numpy()
plt.scatter(x.data.numpy(), y.data.numpy(), c="r")
plt.plot(x.data.numpy(), predict)
plt.show()

在这里插入图片描述

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/3709.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

(四)「消息队列」之 RabbitMQ 路由(使用 .NET 客户端)

0、引言 先决条件 本教程假设 RabbitMQ 已安装并且正在 本地主机 的标准端口(5672)上运行。如果您使用了不同的主机、端口或凭证,则要求调整连接设置。 获取帮助 如果您在阅读本教程时遇到问题,可以通过邮件列表或者 RabbitMQ 社区…

Meta发布升级大模型LLaMA 2:开源可商用

论文地址:https://ai.meta.com/research/publications/llama-2-open-foundation-and-fine-tuned-chat-models/ Github地址:https://github.com/facebookresearch/llama LLaMA 2介绍 Meta之前发布自了半开源的大模型LLaMA,自从LLaMA发布以来…

C# WPF实现动画渐入暗黑明亮主题切换效果

C# WPF实现动画渐入暗黑明亮主题切换效果 效果图如下最近在Bilibili的桌面端看到一个黑白主题切换的效果感觉,挺有意思。于是我使用WPF尝试实现该效果。 主要的切换效果,基本实现不过还存在一些小瑕疵,比如字体等笔刷不能跟随动画进入进行切…

Docker简介

Docker简介 文章目录 Docker简介一、Docker1.什么是docker?2.容器引擎3.容器和虚拟机的区别4.namespace(命名空间)5.三大容器核心概念镜像容器仓库 二、Docker镜像操作1.搜索镜像2.获取镜像镜像加速下载 3.查看本地下载镜像4.获取镜像详细信息5.为本地镜…

SQL 上升的温度

197 上升的温度 SQL架构 表: Weather ---------------------- | Column Name | Type | ---------------------- | id | int | | recordDate | date | | temperature | int | ---------------------- id 是这个表的主键 该表包含特定日期的温度信息 编写一个 SQL …

事务@transactional执行产生重复数据

背景 系统设计之初,每次来新请求,业务层会先查询数据库,判断是否存在相同的id数据(id是唯一标识产品的),有则返回当前数据库查到的数据,根据数据决定下一步动作,没有则认为是初次请…

销售自动化如何提高团队生产力?从这5个方面发力

任何用于减少人工劳动和缩短销售流程相关任务时间的技术,都可定义为销售自动化。 对于忙碌的销售人员来说,流程自动化是真正的救星。它可以使他们的工作简化30%,让他们更专注于创收任务。这将显著提高团队的工作效率,并带来许多其…

滑动奇异频谱分析:数据驱动的非平稳信号分解工具(Matlab代码实现)

💥💥💞💞欢迎来到本博客❤️❤️💥💥 🏆博主优势:🌞🌞🌞博客内容尽量做到思维缜密,逻辑清晰,为了方便读者。 ⛳️座右铭&a…

师承AI世界新星|7天获新加坡南洋理工大学访学邀请函

能够拜师在“人工智能10大新星”名下,必定可以学习到前沿技术,受益良多,本案例中的C老师无疑就是这个幸运儿。我们只用了7天时间就取得了这位AI新星导师的邀请函,最终C老师顺利获批CSC,如愿出国。 C老师背景&#xff1…

Leetcode刷题4

⼆叉树、BFS、堆、Top K、⼆叉搜索树、模拟、图算法 一、二叉树 二叉树的前序中序后序 二叉树节点定义 为了方便演示,我们先定义一个二叉树节点类。 class TreeNode:def __init__(self, val0, leftNone, rightNone):self.val valself.left leftself.right r…

Android ViewGroup onDraw为什么没调用

ViewGroup,它本身并没有任何可画的东西,它是一个透明的控件,因些并不会触发onDraw,但是你现在给LinearLayout设置一个背景色,其实这个背景色不管你设置成什么颜色,系统会认为,这个LinearLayout上…

[回馈]ASP.NET Core MVC开发实战之商城系统(开篇)

在编程方面,从来都是实践出真知,书读百遍其义自见,所以实战是最好的提升自己编程能力的方式。 前一段时间,写了一些实战系列文章,如: ASP.NET MVC开发学生信息管理系统VueAntdvAsp.net WebApi开发学生信息…

R语言的水文、水环境模型优化技术及快速率定方法与多模型案例实践

在水利、环境、生态、机械以及航天等领域中,数学模型已经成为一种常用的技术手段。同时,为了提高模型的性能,减小模型误用带来的风险;模型的优化技术也被广泛用于模型的使用过程。模型参数的快速优化技术不但涉及到优化本身而且涉…

Python中的break和continue语句应用举例

Python中的break和continue语句应用举例 在进行Python编程时候,有时需要,对循环中断或跳过某部分语句,此时常会用到break语句或continue语句。本文将通过实际例子阐述这两个语句的用法。 1.break语句 break语句是实现在某个地方中断循环&a…

Java设计模式之行为型-迭代器模式(UML类图+案例分析)

目录 一、基础概念 二、UML类图 三、角色设计 四、案例分析 五、总结 一、基础概念 迭代器模式是一种常用的设计模式,它主要用于遍历集合对象,提供一种方法顺序访问一个聚合对象中的各个元素,而又不暴露该对象的内部表示。 举个简单的…

5分钟给你破解这套10万赞的生产教程,访谈乔布斯的AI对话数字人视频是怎么做的

本期是赤辰第16期AI项目拆解栏目; 底部准备了7月粉丝福利,看完可以领取; 上周给粉丝们讲解AI动图说话月涨粉20万的案例并给出保姆式教程,粉丝反馈很热烈,都觉得AI强大,有些学员给自己账号做视频&#xff…

大数据与视频技术的融合趋势将带来怎样的场景应用?

视频技术和AI技术的融合是一种新兴的技术趋势,它将改变视频行业的运作方式。视频技术和AI技术的融合主要包括以下几个方面: 1)人脸识别技术 人脸识别技术是AI技术的一个重要应用场景。它可以通过对视频中的人脸进行识别和分析,实…

3.9 Bootstrap 分页

文章目录 Bootstrap 分页分页(Pagination)默认的分页分页的状态分页的大小 翻页(Pager)默认的翻页对齐的链接翻页的状态 分页 Bootstrap 分页 本章将讲解 Bootstrap 支持的分页特性。分页(Pagination)&…

Unity平台如何实现RTSP转RTMP推送?

技术背景 Unity平台下,RTSP、RTMP播放和RTMP推送,甚至包括轻量级RTSP服务这块都不再赘述,今天探讨的一位开发者提到的问题,如果在Unity下,实现RTSP播放的同时,随时转RTMP推送出去? RTSP转RTMP…

浙大数据结构第四周之04-树6 Complete Binary Search Tree

题目详情: A Binary Search Tree (BST) is recursively defined as a binary tree which has the following properties: The left subtree of a node contains only nodes with keys less than the nodes key.The right subtree of a node contains only nodes w…