PyTorch 的 10 条内部用法

alt

欢迎阅读这份有关 PyTorch 原理的简明指南[1]。无论您是初学者还是有一定经验,了解这些原则都可以让您的旅程更加顺利。让我们开始吧!

1. 张量:构建模块

PyTorch 中的张量是多维数组。它们与 NumPy 的 ndarray 类似,但可以在 GPU 上运行。

import torch

# Create a 2x3 tensor
tensor = torch.tensor([[123], [456]])
print(tensor)

2. 动态计算图

PyTorch 使用动态计算图,这意味着该图是在执行操作时即时构建的。这为在运行时修改图形提供了灵活性。

# Define two tensors
a = torch.tensor([2.], requires_grad=True)
b = torch.tensor([3.], requires_grad=True)

# Compute result
c = a * b
c.backward()

# Gradients
print(a.grad)  # Gradient w.r.t a

3.GPU加速

PyTorch 允许在 CPU 和 GPU 之间轻松切换。利用 .to(device) 获得最佳性能。

device = "cuda" if torch.cuda.is_available() else "cpu"
tensor = tensor.to(device)

4. Autograd:自动微分

PyTorch 的 autograd 为张量上的所有操作提供自动微分。设置 require_grad=True 来跟踪计算。

x = torch.tensor([2.], requires_grad=True)
y = x**2
y.backward()
print(x.grad)  # Gradient of y w.r.t x

5. 带有 nn.Module 的模块化神经网络

PyTorch 提供 nn.Module 类来定义神经网络架构。通过子类化创建自定义层。

import torch.nn as nn

class SimpleNN(nn.Module):

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(11)
        
    def forward(self, x):
        return self.fc(x)

6. 预定义层和损失函数

PyTorch 在 nn 模块中提供了各种预定义层、损失函数和优化算法。

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

7. 数据集和DataLoader

为了高效的数据处理和批处理,PyTorch 提供了 Dataset 和 DataLoader 类。

from torch.utils.data import Dataset, DataLoader

class CustomDataset(Dataset):
    # ... (methods to define)
    
data_loader = DataLoader(dataset, batch_size=32, shuffle=True)

8.模型训练循环

通常,PyTorch 中的训练遵循以下模式:前向传递、计算损失、后向传递和参数更新。

for epoch in range(epochs):
    for data, target in data_loader:
        optimizer.zero_grad()
        output = model(data)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()

9. 模型序列化

使用 torch.save() 和 torch.load() 保存和加载模型。

# Save
torch.save(model.state_dict(), 'model_weights.pth')

# Load
model.load_state_dict(torch.load('model_weights.pth'))

10. Eager Execution and JIT

虽然 PyTorch 默认情况下以 eager 模式运行,但它为生产就绪模型提供即时 (JIT) 编译。

scripted_model = torch.jit.script(model)
scripted_model.save("model_jit.pt")

Reference

[1]

Source: https://medium.com/@kasperjuunge/10-principles-of-pytorch-bbe4bf0c42cd

本文由 mdnice 多平台发布

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/222593.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

基于 Webpack5 Module Federation 的业务解耦实践

前言 本文中会提到很多目前数栈中使用的特定名词,统一做下解释描述 dt-common:每个子产品都会引入的公共包(类似 NPM 包) AppMenus:在子产品中快速进入到其他子产品的导航栏,统一维护在 dt-common 中,子产品从 dt-com…

c/c++ | 少量内存溢出不会影响程序的正常运行

常见 通过指针 获得 字符串“aaaaaaaaaaaaaaaaaaaaaaaaaaa” 数据,然后通过strcpy 拷贝到对象 char str1[5] 中,事实是造成了内存溢出,但是 ,最后 str1 打印输出的结果 是上面的 “aaaaaaaaaaaaaaaaaaaaaaaaaaa” 这是不正常的&am…

Draw.io or diagrams.net 使用方法

0 Preface/Foreword 在工作中,经常需要用到框图,流程图,时序图,等等,draw.io可以完成以上工作。 official website:draw.io 1 Usage 1.1 VS code插件 draw.io可以扩展到VS code工具中。

百度搜索展现服务重构:进步与优化

作者 | 瞭东 导读 本文将简单介绍搜索展现服务发展过程,以及当前其面临的三大挑战:研发难度高、架构能力欠缺、可复用性低,最后提出核心解决思路和具体落地方案,期望大家能有所收货和借鉴。 全文4736字,预计阅读时间12…

产品入门第四讲:Axure动态面板

📚📚 🏅我是默,一个在CSDN分享笔记的博主。📚📚 ​​​​​ 🌟在这里,我要推荐给大家我的专栏《Axure》。🎯🎯 🚀无论你是编程小白,还…

17.分割有效信息【2023.12.9】

1.问题描述 有时候我们需要截取字符串以获取有用的信息,比如对于字符串 “日期:2010-10-29”,我们需要截取后面的 10 个字符来获取日期,以便进行进一步分析。编写一个程序,输入一个字符串,然后输出截取后的…

【Spark精讲】Spark Shuffle详解

目录 Shuffle概述 Shuffle执行流程 总体流程 中间文件 ShuffledRDD生成 Stage划分 Task划分 Map端写入(Shuffle Write) Reduce端读取(Shuffle Read) Spark Shuffle演变 SortShuffleManager运行机制 普通运行机制 bypass 运行机制 Tungsten Sort Shuffle 运行机制…

Spark环境搭建和使用方法

目录 一、安装Spark (一)基础环境 (二)安装Python3版本 (三)下载安装Spark (四)配置相关文件 二、在pyspark中运行代码 (一)pyspark命令 &#xff08…

AR眼镜光学方案_AR眼镜整机硬件定制

增强现实(Augmented Reality,AR)技术通过将计算机生成的虚拟物体或其他信息叠加到真实世界中,实现对现实的增强。AR眼镜作为实现AR技术的重要设备,具备虚实结合、实时交互的特点。为了实现透视效果,AR眼镜需要同时显示真实的外部世…

基于vue实现的疫情数据可视化分析及预测系统-计算机毕业设计推荐django

目 录 摘 要 I ABSTRACT II 目 录 II 第1章 绪论 1 1.1背景及意义 1 1.2 国内外研究概况 1 1.3 研究的内容 1 第2章 相关技术 3 2.1 nodejs简介 4 2.2 express框架介绍 6 2.4 MySQL数据库 4 第3章 系统分析 5 3.1 需求分析 5 3.2 系统可行性分析 5 3.2.1技术可行性:…

mac python安装grpcio以及xcode升级权限问题记录

问题1: ERROR: Could not build wheels fol grpcio, which is required to install pyproject.toml-based projects pip3 install --no-cache-dir --force-reinstall -Iv grpcio1.41.0 # (我这里是降级安装的) 问题2: fatal error: ‘stdio.h’ file not found 25 | #include …

案例课7——百度智能客服

1.公司介绍 百度智能客服是百度智能云推出的将AI技术赋能企业客服业务的一揽子解决方案。该方案基于百度世界先进的语音技术、自然语言理解技术、知识图谱等构建完备的一体化产品方案,结合各行业头部客户丰富的运营经验,持续深耕机场服务、电力调度等场…

智能化配电房

智能化配电房是一种集成了先进技术和智能化设备的配电房,通过智能采集终端与通信设备,实时将电气参数、运行信息和环境数据传送至智慧电力物联网平台—电易云,对配电室进行数字化升级,对运维工作数字化升级,建设电力系…

使用Qt制作网易云播放器的歌曲排行界面

!!!直接上图!!! !!!直接上图!!! !!!直接上图!!! 网易云排行榜…

论文润色降重哪个平台好 papergpt

大家好,今天来聊聊论文润色降重哪个平台好,希望能给大家提供一点参考。 以下是针对论文重复率高的情况,提供一些修改建议和技巧: 标题:论文润色降重哪个平台好――专业、高效、可靠的学术支持 一、引言 在学术研究中&…

机器学习 | 机器学习基础知识

一、机器学习是什么 计算机从数据中学习规律并改善自身进行预测的过程。 二、数据集 1、最常用的公开数据集 2、结构化数据与非结构化数据 三、任务地图 1、分类任务 Classification 已知样本特征判断样本类别二分类、多分类、多标签分类 二分类:垃圾邮件分类、图像…

风速预测(二)基于Pytorch的EMD-LSTM模型

目录 前言 1 风速数据EMD分解与可视化 1.1 导入数据 1.2 EMD分解 2 数据集制作与预处理 2.1 先划分数据集,按照8:2划分训练集和测试集 2.2 设置滑动窗口大小为7,制作数据集 3 基于Pytorch的EMD-LSTM模型预测 3.1 数据加载&#xff0…

QT QIFW Linux下制作软件安装包

一、概述 和windows的操作步骤差不多,我们需要下装linux下的安装程序,然后修改config.xml、installscript.qs和package.xml文件。 QT QIFW Windows下制作安装包(一)-CSDN博客 一、下装QIFW 下装地址:/official_releases/qt-installer-fra…

SpringBoot中日志的使用log4j2

SpringBoot中日志的使用log4j2 1、log4j2介绍 Apache Log4j2 是对 Log4j 的升级,它比其前身 Log4j 1.x 提供了重大改进,并提供了 Logback 中可用的许多改 进,同时修复了 Logback 架构中的一些问题,主要有: 异常处理…

Google DeepMind发布Imagen 2文字到图像生成模型;微软在 HuggingFace 上发布了 Phi-2 的模型

🦉 AI新闻 🚀 Google DeepMind发布Imagen 2文字到图像生成模型 摘要:谷歌的Imagen 2是一种先进的文本到图像技术,可以生成与用户提示紧密对齐的高质量、逼真的图像。它通过使用训练数据的自然分布来生成更逼真的图像&#xff0c…