【PyTorch知识点汇总】

PyTorch是一个广泛使用的深度学习框架,它提供了许多功能强大的工具和函数,用于构建和训练神经网络。以下是一些PyTorch的常用知识点和示例说明:

  1. 张量(Tensors)

    • 创建张量:使用torch.tensor()​、torch.Tensor()​或特定创建函数如torch.zeros()​, torch.ones()​, torch.randn()​等创建不同类型的张量。

      import torch
      x = torch.tensor([1., 2., 3.])  # 创建一个浮点型张量
      zeros_tensor = torch.zeros((3, 4))  # 创建一个3x4的全零张量
      
    • 张量操作:类似NumPy,支持各种数学运算和索引操作,如加减乘除、矩阵乘法、广播机制、切片等。

      y = torch.tensor([4., 5., 6.])
      result = x + y  # 张量加法
      
    • 数据类型转换:通过.to()​方法可以改变张量的数据类型或者设备(CPU/GPU)。

      device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
      x_gpu = x.to(device)  # 将张量移动到GPU上
      
  2. 自动微分(Autograd)

    • 使用.requires_grad_()​标记张量以启用梯度计算:

      x.requires_grad_()
      y = x * 2
      z = y.sum()
      z.backward()  # 自动计算梯度
      print(x.grad)  # 输出x的梯度
      
  3. 神经网络模块(nn.Module)

    • 定义网络结构:继承自nn.Module​并实现__init__​和forward​方法。

      import torch.nn as nnclass SimpleNet(nn.Module):def __init__(self):super(SimpleNet, self).__init__()self.linear = nn.Linear(784, 10)  # 定义一个线性层def forward(self, x):out = self.linear(x)return out
      
    • 构建与训练模型:

      model = SimpleNet()
      criterion = nn.CrossEntropyLoss()
      optimizer = torch.optim.SGD(model.parameters(), lr=0.01)for epoch in range(num_epochs):# 前向传播output = model(inputs)loss = criterion(output, targets)# 反向传播及优化optimizer.zero_grad()loss.backward()optimizer.step()
      
  4. 数据加载器(DataLoader)

    • 使用torch.utils.data.DataLoader​来加载和批处理数据。

      from torch.utils.data import DataLoader, TensorDatasetdataset = TensorDataset(data_tensor, label_tensor)
      dataloader = DataLoader(dataset, batch_size=64, shuffle=True)
      for batch_data, batch_labels in dataloader:# 在每个迭代周期中,batch_data和batch_labels会是当前批次的张量数据pass
      
  5. 保存与加载模型

    • 使用torch.save()​和torch.load()​保存和加载模型参数或整个模型。

      torch.save(model.state_dict(), 'model.pth')  # 保存模型参数
      model.load_state_dict(torch.load('model.pth'))  # 加载模型参数# 或者保存整个模型
      torch.save(model, 'model_full.pth')  # 保存整个模型(包括其结构和参数)
      loaded_model = torch.load('model_full.pth', map_location=device)  # 加载整个模型
      
  6. 多GPU并行训练

    • 使用torch.nn.DataParallel​或torch.nn.parallel.DistributedDataParallel​进行多GPU训练。

      model = nn.DataParallel(SimpleNet())  # 如果有多块GPU可用,则将模型分布到多个GPU上
      
  7. 控制流(autograd with control flow)

    • PyTorch支持在动态图模式下使用Python的控制流语句(如if-else、for循环),并且能正确跟踪梯度。

动态计算图、混合精度训练、量化压缩、可视化工具

动态计算图(Dynamic Computation Graph)
在PyTorch中,计算图是在运行时构建的,这意味着你可以根据程序运行的状态实时改变网络结构或执行不同的计算路径。这是与静态计算图框架如TensorFlow的一个显著区别。

示例:

# 动态改变模型结构
class DynamicModel(nn.Module):def __init__(self):super(DynamicModel, self).__init__()self.linear1 = nn.Linear(10, 5)self.linear2 = nn.Linear(5, 3)def forward(self, x, use_second_layer=True):out = F.relu(self.linear1(x))if use_second_layer:out = self.linear2(out)  # 根据条件决定是否使用第二层return outmodel = DynamicModel()

混合精度训练(Mixed Precision Training)
混合精度训练利用了FP16和FP32数据类型的优势,通过将部分计算转移到半精度上以减少内存占用和加快计算速度,同时保持关键部分(如梯度更新)在全精度下进行,以维持数值稳定性。

使用torch.cuda.amp​模块实现自动混合精度训练:

import torch
from torch.cuda.amp import autocast, GradScalerscaler = GradScaler()for inputs, targets in dataloader:inputs = inputs.cuda()targets = targets.cuda()with autocast():outputs = model(inputs)loss = criterion(outputs, targets)scaler.scale(loss).backward()scaler.step(optimizer)scaler.update()optimizer.zero_grad()

量化压缩(Quantization)
量化是将模型的权重和激活从浮点数转换为低比特整数的过程,从而减小模型大小并加速推理。PyTorch提供了量化API来实现这一过程。

简化版量化示例:

import torch.quantization# 假设model是一个已经训练好的模型
model_fp32 = ...  # 初始化并训练模型# 首先对模型进行伪量化(模拟量化)
prepared_model = torch.quantization.prepare(model_fp32)
# 进行量化校准(收集统计数据)
quantized_model = torch.quantization.convert(prepared_model)# 现在quantized_model是一个量化后的模型,可以用于推理

可视化工具
PyTorch支持通过torchviz​库来进行计算图可视化,或者配合其他工具(如TensorBoard)展示模型结构、训练指标等。

对于简单的计算图可视化:

from torchviz import make_dotx = torch.randn(5, requires_grad=True)
y = x * 2
z = y ** 2
z.backward(torch.ones_like(z))dot_graph = make_dot(z)
dot_graph.view()  # 在Jupyter Notebook中显示图形

对于模型结构可视化,通常结合torchsummary​或直接使用TensorBoard配合torch.utils.tensorboard​接口:

from torchsummary import summarysummary(model, input_size=(1, 28, 28))  # 对于卷积神经网络,输入维度为(通道, 高, 宽)# 或者在TensorBoard中展示模型结构
from torch.utils.tensorboard import SummaryWriterwriter = SummaryWriter()
writer.add_graph(model, torch.rand((1, 28, 28)))  # 输入一个随机张量获取模型结构
writer.close()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/711304.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

面试经典150题——用最少数量的箭引爆气球

"The only person you are destined to become is the person you decide to be." - Ralph Waldo Emerson 1. 题目描述 2. 题目分析与解析 这个题目开始读题的时候是有点不好理解题意的,因此我先做个图让大家对于题意有更好更直观的理解再来分析题目。 …

如何使用Portainer创建Nginx容器并搭建web网站发布至公网可访问【内网穿透】

文章目录 前言1. 安装Portainer1.1 访问Portainer Web界面 2. 使用Portainer创建Nginx容器3. 将Web静态站点实现公网访问4. 配置Web站点公网访问地址4.1公网访问Web站点 5. 固定Web静态站点公网地址6. 固定公网地址访问Web静态站点 前言 Portainer是一个开源的Docker轻量级可视…

SQL 常见命令及规范

常见命令 1. 查看当前所有数据库 show databases; 2. 打开指定的库 use 库名 ; 3. 查看当前库的所有表 show tables; 4. 查看其他库的所有表 show tables from 库名 ; 5. 创建表 cerate table 表名 ( 列名 列类型, 列名 列类型, ..... …

基于YOLO家族最新模型YOLOv9开发构建自己的个性化目标检测系统从零构建模型完整训练、推理计算超详细教程【以自建数据酸枣病虫害检测为例】

在我前面的系列博文中,对于目标检测系列的任务写了很多超详细的教程,目的是能够读完文章即可实现自己完整地去开发构建自己的目标检测系统,感兴趣的话可以自行移步阅读: 《基于官方YOLOv4-u5【yolov5风格实现】开发构建目标检测模型超详细实战教程【以自建缺陷检测数据集为…

C# OpenVINO Crack Seg 裂缝分割 裂缝检测

目录 效果 模型信息 项目 代码 数据集 下载 C# OpenVINO Crack Seg 裂缝分割 裂缝检测 效果 模型信息 Model Properties ------------------------- date:2024-02-29T16:35:48.364242 author:Ultralytics task:segment version&…

去掉WordPress网页图片默认链接功能

既然是wordpress自动添加的,那么我们在上传图片到wordpress后台多媒体的时候,就可以手动改变链接指向或者删除掉,问题是每次都要这么做很麻烦,更别说有忘记的时候。一次性解决这个问题有两种方法,一种是No Image Link插…

【生成式AI】ChatGPT原理解析(1/3)- 对ChatGPT的常见误解

Hung-yi Lee 课件整理 文章目录 误解1误解2ChatGPT真正在做的事情-文字接龙 ChatGPT是在2022年12月7日上线的。 当时试用的感觉十分震撼。 误解1 我们想让chatGPT讲个笑话,可能会以为它是在一个笑话的集合里面随机地找一个笑话出来。 我们做一个测试就知道不是这样…

C# Post数据或文件到指定的服务器进行接收

目录 应用场景 实现原理 实现代码 PostAnyWhere类 ashx文件部署 小结 应用场景 不同的接口服务器处理不同的应用,我们会在实际应用中将A服务器的数据提交给B服务器进行数据接收并处理业务。 比如我们想要处理一个OFFICE文件,由用户上传到A服务器…

中国汽车电子行业发展现状分析及投资前景预测报告

全版价格:壹捌零零 报告版本:下单后会更新至最新版本 交货时间:1-2天 第一章 汽车电子相关概述 1.1 汽车的相关介绍 1.1.1 汽车的概念 我国国家最新标准《汽车和挂车类型的术语和定义》(GB/T3730.1—2001&…

基于springboot+vue的贸易行业crm系统

博主主页:猫头鹰源码 博主简介:Java领域优质创作者、CSDN博客专家、阿里云专家博主、公司架构师、全网粉丝5万、专注Java技术领域和毕业设计项目实战,欢迎高校老师\讲师\同行交流合作 ​主要内容:毕业设计(Javaweb项目|小程序|Pyt…

Flink分区相关

0、要点 Flink的分区列不会存数据,也就是两个列有一个分区列,则文件只会存另一个列的数据 1、CreateTable 根据SQL的执行流程,进入TableEnvironmentImpl.executeInternal,createTable分支 } else if (operation instanceof Crea…

Java-nio

一、NIO三大组件 NIO的三大组件分别是Channel,Buffer与Selector Java NIO系统的核心在于:通道(Channel)和缓冲区(Buffer)。通道表示打开到 IO 设备(例如:文件、套接字)的连接。若需要使用 NIO 系统,需要获取用于连接 IO 设备的通…

Spring的简单使用及内部实现原理

在现代的Java应用程序开发中,Spring Framework已经成为了不可或缺的工具之一。它提供了一种轻量级的、基于Java的解决方案,用于构建企业级应用程序和服务。本文将介绍Spring的简单使用方法,并深入探讨其内部实现原理。 首先,让我们…

mysql8.0使用MGR实现高可用

一、三节点MGR集群的安装部署 1. 安装准备 准备好下面三台服务器&#xff1a; IP端口角色192.168.150.213306mgr1192.168.150.223306mgr2192.168.150.233306mgr3 配置hosts解析 # cat >> /etc/hosts << EOF 192.168.150.21 mgr1 192.168.150.22 mgr2 192.168…

Windows环境下的调试器探究——硬件断点

与软件断点与内存断点不同&#xff0c;硬件断点不依赖被调试程序&#xff0c;而是依赖于CPU中的调试寄存器。 调试寄存器有7个&#xff0c;分别为Dr0~Dr7。 用户最多能够设置4个硬件断点&#xff0c;这是由于只有Dr0~Dr3用于存储线性地址。 其中&#xff0c;Dr4和Dr5是保留的…

java中容器继承体系

首先上图 源码解析 打开Collection接口源码&#xff0c;能够看到Collection接口是继承了Iterable接口。 public interface Collection<E> extends Iterable<E> { /** * ...... */ } 以下是Iterable接口源码及注释 /** * Implementing this inte…

makefileGDB使用

一、makefile 1、make && makefile makefile带来的好处就是——自动化编译&#xff0c;一旦写好&#xff0c;只需要一个make命令&#xff0c;整个工程完全自动编译&#xff0c;极大的提高了软件开发的效率 下面我们通过如下示例来进一步体会它们的作用&#xff1a; ①…

使用 Python 实现一个飞书/微信记账机器人,酷B了!

Python飞书文档机器人 今天的主题是&#xff1a;使用Python联动飞书文档机器人&#xff0c;实现一个专属的记账助手&#xff0c;这篇文章如果对你帮助极大&#xff0c;欢迎你分享给你的朋友、她、他&#xff0c;一起成长。 也欢迎大家留言&#xff0c;说说自己想看什么主题的…

代码随想录第天 78.子集 90.子集II

LeetCode 78 子集 题目描述 给你一个整数数组 nums &#xff0c;数组中的元素 互不相同 。返回该数组所有可能的子集&#xff08;幂集&#xff09;。 解集 不能 包含重复的子集。你可以按 任意顺序 返回解集。 示例 1&#xff1a; 输入&#xff1a;nums [1,2,3] 输出&…

LeetCode 2581.统计可能的树根数目:换根DP(树形DP)

【LetMeFly】2581.统计可能的树根数目&#xff1a;换根DP(树形DP) 力扣题目链接&#xff1a;https://leetcode.cn/problems/count-number-of-possible-root-nodes/ Alice 有一棵 n 个节点的树&#xff0c;节点编号为 0 到 n - 1 。树用一个长度为 n - 1 的二维整数数组 edges…