机器学习 - PyTorch使用流程

通常的 PyTorch Workflow 是这样的. But the workflow steps can be repeated and changed depending on the problem you’re working on.

  1. Get data ready (turn into tensors)
  2. Build or pick a pretrained model to suit your problem
    2.1 Pick a loss function & optimizer
    2.2 Build a training loop
  3. Fit the model to the data and make a prediction
  4. Evaluate the model
  5. Improve through experimentation
  6. Save and reload your trained model
TopicContents
Getting data readyData can be almost anything but to get started we’re going to create a simple straight line
Build a modelCreate a model to learn patterns in the data, and choose a loss function, optimizer and build a training loop
Fitting the model to data (training)Got the data and a model, now let’s the model (try to) find patterns in the (training) data.
Making predictions and evaluating a model (inference)The model’s found patterns in the data, let’s compare its findings to the actual (testing) data.
Saving and loading a modelYou may want to use your model elsewhere, or come back to it later
Putting it all togetherLet’s take all of the above and combine it.

或者也可以是这几个步骤:

  1. 数据准备:首先准备好数据集,包括训练集,验证集和测试集。PyTorch提供了一系列工具和类来加载,预处理和组织数据,例如:torch.utils.data.Datasettorch.utils.data.DataLoader
  2. 模型定义:定义神经网络模型的结构,包括网络层的组织结构,激活函数等。可以使用PyTorch提供的torch.nn.Module类来创建模型。
  3. 损失函数定义:根据任务的性质选择合适的损失函数,用于衡量模型预测与真实标签之间的差异。PyTorch提供了各种损失函数,例如交叉熵损失函数,均方误差损失函数等。
  4. 优化器选择:选择合适的优化算法来更新模型参数,使得损失函数最小化。常见的优化算法包括随机梯度下降 (SGD),Adam, RMSprop等。PyTorch提供了torch.optim模块来实现各种优化算法。
  5. 模型训练:使用准备好的数据集,模型,损失函数和优化器来进行模型训练。训练过程通常包括多个周期 (epochs),每个周期包括数据集的多个批次 (batches)。在每个批次中,依次执行以下步骤:
    • 前向传播 (Forward Pass): 将输入数据传递给模型,计算模型的输出。
    • 计算损失值:使用损失函数计算模型输出与真实标签之间的损失之。
    • 反向传播 (Backward Pass): 根据损失值计算模型参数的梯度。
    • 参数更新:使用优化器根据参数的梯度更新模型参数。
  6. 模型评估:使用验证集或测试集评估训练好的模型的性能。通常会计算模型在验证集或测试集上的准确率,精确率,召回率等指标。
  7. 模型保存和部署:将训练好的模型保存为文件,并在需要时加载模型进行预测。PyTorch提供了·torch.save()torch.load() 函数来保存和加载模型。模型也可以通过TorchScript进行序列化,以便于在其他平台上进行部署。

看到这了,给个赞呗~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/757403.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

第1章 初识 Python 背记手册

1、print()—输出 print()函数的基本用法如下: print("输出的内容")其中,输出内容可以是数字和字符串(使用引号括起来),此类内容将直接输出,也 可 以是包含运算符的表达式,此类内容…

彻底学会系列:一、机器学习之梯度下降(2)

1 梯度具体是怎么下降的? ∂ J ( θ ) ∂ θ \frac{\partial J (\theta )}{\partial \theta} ∂θ∂J(θ)​(损失函数:用来衡量模型预测值与真实值之间差异的函数) 对损失函数求导,与学习率相乘,按梯度反方…

马斯克AI大模型Grok开源了!

2024年3月18日,马斯克的AI创企xAI兑现承诺,正式发布了此前备受期待大模型Grok-1。 代码和模型权重已上线GitHub: https://github.com/xai-org/grok-1 截止目前,Grok已经在GitHub上获得了35.2k颗Star,还在不断上升中。 Grok官方博…

yolov9目标检测可视化图形界面GUI源码

该系统是由微智启软件工作室基于yolov9pyside6开发的目标检测可视化界面系统 运行环境: window python3.8 安装依赖后,运行源码目录下的wzq.py启动 程序提供了ui源文件,可以拖动到Qt编辑器修改样式,然后通过pyside6把ui转成python…

【11】工程化

一、为什么需要模块化 当前端工程到达一定规模后,就会出现下面的问题: 全局变量污染 依赖混乱 上面的问题,共同导致了代码文件难以细分 模块化就是为了解决上面两个问题出现的 模块化出现后,我们就可以把臃肿的代码细分到各个小文件中,便于后期维护管理 前端模块化标准…

Cookie、Session、Token详解及基于JWT的Token实现的用户登陆身份认证

目录 前置知识 Cookie 什么是Cookie Cookie的作用 Cookie的声命周期 Session 什么是Session 服务集群下Session存在的问题 集群模式下Session无法共享问题的解决 Cookie和Session的对比 Token 什么是Token 为什么产生Token 基于JWT的Token认证机制 Token的优势 …

第112讲:Mycat实践指南:字符串Hash算法分片下的水平分表详解

文章目录 1.字符串Hash算法分片的概念1.1.字符串Hash算法的概念1.2.字符串Hash算法是如何将数据路由到分片节点的 2.使用字符串Hash算法分片对某张表进行水平拆分2.1.在所有的分片节点中创建表结构2.2.配置Mycat实现字符串Hash算法分片的水平分表2.2.1.配置Schema配置文件2.2.2…

Redis Pub/Sub: 实时消息传递的完美解决方案

Redis发布订阅(Pub/Sub)是一种消息传递模式,允许消息的发送者(发布者)将消息发送给多个接收者(订阅者)。在Redis中,发布者和订阅者之间通过频道(Channel)进行…

算法刷题day33

目录 引言一、动态网格二、画图三、扫雷 引言 这几天一直再写关于搜索的问题,我发现搜索不仅仅局限于网格中的那种搜索,还有状态的变换,也可以抽象成一个点,去找最小变换次数,这也是一种搜索,所以说还是得…

SpringData JPA 快速入门案例详解

SpringData JPA JPA 简介: JPA(Java Persistence API)是 Java 持久层规范,定义了一些列 ORM 接口,它本身是不能直接使用的,因为接口需要实现才能使用,Hibernate 框架就是实现 JPA 规范的框架。…

colab中数据集保存到drive与取出的方法

from google.colab import drive drive.mount(/content/drive) 一、下载数据集 from datasets import load_dataset max_length 32 # Maximum length of the captions in tokens coco_dataset_ratio 50 # 50% of the COCO2014 dataset# Load the COCO2014 dataset for tr…

浅谈MVVM、MVC、MVP的区别

MVC、MVP 和 MVVM 是三种常见的软件架构设计模式,主要通过分离关注点的方式来组织代码结构,优化开发效率。 在开发单页面应用时,往往一个路由页面对应了一个脚本文件,所有的页面逻辑都在一个脚本文件里。页面的渲染、数据的获取&…

计算机毕业设计-基于python的旅游信息爬取以及数据分析

概要 随着计算机网络技术的发展,近年来,新的编程语言层出不穷,python语言就是近些年来最为火爆的一门语言,python语言,相对于其他高级语言而言,python有着更加便捷实用的模块以及库,具有语法简单…

使用原生nodejs搭建一个简易的web服务器demo

简易demo var http require(http); var url require("url"); const app http.createServer(function (request, response) {var urlObj url.parse(request.url,true);console.log(request.url);// 内容类型: text/plain。并用charsetUTF-8解决输出中文乱码respon…

S2-066漏洞分析与复现(CVE-2023-50164)

Foreword 自struts2官方纰漏S2-066漏洞已经有一段时间,期间断断续续地写,直到最近才完成,o(╥﹏╥)o。羞愧地回顾一下官方通告: 2023.12.9发布,编号CVE-2023-50164,主要影响版本是 2.5.0-2.5.32 以及 6.0…

QT6实现创建与操作sqlite数据库三种方式方式对比(二)

一.概述 Qt访问Sqlite数据库的三种方式(即使用三种类库去访问),QSqlQuery、QSqlQueryModel、QSqlTableModel,对于这三种类库,可看为一个比一个上层,也就是封装的更厉害,甚至第三种QSqlTableModel,根本就不…

Spring Security AuthenticatedVoter 错误访问控制漏洞复现(CVE-2024-22257)

免责声明 由于传播、利用本CSDN所提供的信息而造成的任何直接或者间接的后果及损失,均由使用者本人负责,作者不为此承担任何责任,一旦造成后果请自行承担! 一、产品介绍 Spring Security 是基于Spring应用程序的认证和访问控制框架。 二、漏洞描述 Spring Security在处理…

JJJ:改善ubuntu网速慢的方法

Ubuntu 系统默认的软件下载源由于服务器的原因, 在国内的下载速度往往比较慢,这时我 们可以将 Ubuntu 系统的软件下载源更改为国内软件源,譬如阿里源、中科大源、清华源等等, 下载速度相比 Ubuntu 官方软件源会快很多!…

[AIGC] 在Spring Boot中指定请求体格式

在使用Spring Boot开发Web应用的时候,我们经常会遇到需要接收并处理HTTP请求的情况。一个HTTP请求通常包括一个请求行、若干请求头和一个请求体。请求体在POST和PUT请求中特别重要,因为它通常用于向服务器传递数据。 文章目录 创建并使用一个Java Bean指…

【技术栈】Redis 企业级解决方案

​ SueWakeup 个人主页:SueWakeup ​​​​​​​ 系列专栏:学习技术栈 ​​​​​​​ ​​​​​​​ ​​​​​​​ ​​​​​​​ ​​​​​​​ ​​​​​​​ ​​​​​​​ 个性签名&…