二、TorchRec中的分片

TorchRec中的分片


文章目录

  • TorchRec中的分片
  • 前言
  • 一、Planner
  • 二、EmbeddingTable 的分片
    • TorchRec 中所有可用的分片类型列表
  • 三、使用 TorchRec 分片模块进行分布式训练
    • TorchRec 在三个主要阶段处理此问题
  • 四、DistributedModelParallel(分布式模型并行)
  • 总结


前言

  • 我们来了解TorchRec架构中是如何分片的

一、Planner

  • TorchRec planner 帮助确定模型的最佳分片配置。

  • 它评估嵌入表分片的多种可能性,并优化性能。

  • planner 执行以下操作:

    • 评估硬件的内存约束。
    • 根据内存获取(例如嵌入查找)估算计算需求。
    • 解决特定于数据​​的因素。
    • 考虑其他硬件细节,例如带宽,以生成最佳分片计划。

二、EmbeddingTable 的分片

  • TorchRec sharder 为各种用例提供了多种分片策略,我们概述了一些分片策略及其工作原理,以及它们的优点和局限性。通常,我们建议使用 TorchRec planner 为您生成分片计划,因为它将为模型中的每个嵌入表找到最佳分片策略。
  • 每个分片策略都确定如何进行表拆分、是否应拆分表以及如何拆分、是否保留某些表的一个或几个副本等等。分片结果中的每个表片段,无论是一个嵌入表还是其中的一部分,都称为分片。
  • 可视化 TorchRec 中提供的不同分片方案下表分片的放置

不同分片方案下表分片的放置

TorchRec 中所有可用的分片类型列表

  • 表式 (TW):顾名思义,嵌入表作为一个整体保留并放置在一个 rank 上。
  • 列式 (CW):表沿 emb_dim 维度拆分,例如,emb_dim=256 拆分为 4 个分片:[64, 64, 64, 64]。
  • 行式 (RW):表沿 hash_size 维度拆分,通常在所有 rank 之间均匀拆分。
  • 表式-行式 (TWRW):表放置在一个主机上,在该主机上的 rank 之间进行行式拆分。
  • 网格分片 (GS):表是 CW 分片的,每个 CW 分片都以 TWRW 方式放置在主机上。
  • 数据并行 (DP):每个 rank 保留表的副本。

分片后,模块将转换为它们自身的分片版本,在 TorchRec 中称为 ShardedEmbeddingCollectionShardedEmbeddingBagCollection。这些模块处理输入数据的通信、嵌入查找和梯度。

三、使用 TorchRec 分片模块进行分布式训练

  • 有许多可用的分片策略,我们如何确定使用哪一个?
    • 每种分片方案都有相关的成本,这与模型大小和 GPU 数量相结合,决定了哪种分片策略最适合模型。
  • 在没有分片的情况下,每个 GPU 保留嵌入表的副本 (DP),主要成本是计算,其中每个 GPU 在前向传递中查找其内存中的嵌入向量,并在后向传递中更新梯度。
  • 使用分片时,会增加通信成本:
    • 每个 GPU 都需要向其他 GPU 请求嵌入向量查找,并通信计算出的梯度。这通常被称为 all2all 通信。
    • 在 TorchRec 中,对于给定 GPU 上的输入数据,我们确定数据每个部分的嵌入分片所在的位置,并将其发送到目标 GPU。
    • 然后,目标 GPU 将嵌入向量返回给原始 GPU。在后向传递中,梯度被发送回目标 GPU,并且分片会通过优化器进行相应的更新。
  • 如上所述,分片需要我们通信输入数据和嵌入查找。

TorchRec 在三个主要阶段处理此问题

我们将此称为分片嵌入模块前向传递,该传递用于 TorchRec 模型的训练和推理

  • 特征 All to All / 输入分布 (input_dist)

    • 将输入数据(以 KeyedJaggedTensor 的形式)通信到包含相关嵌入表分片的适当设备
  • 嵌入查找

    • 使用特征 all to all 交换后形成的新输入数据查找嵌入
  • 嵌入 All to All/输出分布 (output_dist)

    • 将嵌入查找数据通信回请求它的适当设备(根据设备接收到的输入数据)
  • 后向传递执行相同的操作,但顺序相反。

四、DistributedModelParallel(分布式模型并行)

  • 以上所有内容最终汇集成 TorchRec 用于分片和集成计划的主要入口点。
  • 在高层次上,DistributedModelParallel 执行以下操作:
    • 通过设置进程组和分配设备类型来初始化环境。
    • 如果没有提供 sharder,则使用默认的 sharder,默认 sharder 包括 EmbeddingBagCollectionSharder
    • 接收提供的分片计划,如果未提供,则生成一个。
    • 创建模块的分片版本,并用它们替换原始模块,例如,将 EmbeddingCollection 转换为 ShardedEmbeddingCollection
    • 默认情况下,使用 DistributedDataParallel 包装 DistributedModelParallel,使模块既是模型并行又是数据并行。

总结

  • 对TorchRec中的分块策略进行了解。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/75431.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

如何在 Spring Boot 项目中使用 MyBatis 进行批量操作以提升性能?

MyBatis 提供了 ExecutorType.BATCH 类型,允许将多个 SQL 语句进行组合,最后统一执行,从而减少数据库的访问频率,提升性能。 以下是如何在 Spring Boot 项目中使用 MyBatis 进行批量操作的关键点: 1. 配置 MyBatis 使…

Redis 字符串(String)详解

1. 什么是字符串类型 在 Redis 中,字符串(String) 是最基本的数据类型。它可以包含任何数据,比如文本、JSON、甚至二进制数据(如图片的 Base64 编码),最大长度为 512 MB。 字符串在 Redis 中不…

Elasticsearch 系列专题 - 第四篇:聚合分析

聚合(Aggregation)是 Elasticsearch 的强大功能之一,允许你对数据进行分组、统计和分析。本篇将从基础到高级逐步讲解聚合的使用,并结合实际案例展示其应用。 1. 聚合基础 1.1 什么是聚合(Aggregation)? 聚合是对文档集合的统计分析,类似于 SQL 中的 GROUP BY 和聚合…

YOLO学习笔记 | YOLOv8 全流程训练步骤详解(2025年4月更新)

===================================================== github:https://github.com/MichaelBeechan CSDN:https://blog.csdn.net/u011344545 ===================================================== 这里写自定义目录标题 一、数据准备1. 数据标注与格式转换2. 配置文件生…

context上下文(一)

创建一个基础的context 使用BackGround函数,BackGround函数原型如下: func Background() Context {return backgroundCtx{} } 作用:Background 函数用于创建一个空的 context.Context 对象。 context.Background() 函数用于获取一个空的 cont…

Java中常见的设计模式

Java中常见的设计模式 Java 中有 23 种经典设计模式,通常被分为三大类:创建型、结构型和行为型。每个设计模式都解决了不同类型的设计问题。以下是几种常见设计模式的总结,并附带了实际应用场景、示例代码和详细的注释说明。 一、创建型设计…

责任链设计模式(单例+多例)

目录 1. 单例责任链 2. 多例责任链 核心区别对比 实际应用场景 单例实现 多例实现 初始化 初始化责任链 执行测试方法 欢迎关注我的博客!26届java选手,一起加油💘💦👨‍🎓😄😂 最近在…

springboot 处理编码的格式为opus的音频数据解决方案【java8】

opus编码的格式概念: Opus是一个有损声音编码的格式,由Xiph.Org基金会开发,之后由IETF(互联网工程任务组)进行标准化,目标是希望用单一格式包含声音和语音,取代Speex和Vorbis,且适用…

vue项目引入tailwindcss

vue3项目引入tailwindcss vue3 vite tailwindcss3 版本 初始化项目 npm create vitelatest --template vue cd vue npm install npm run dev安装tailwindcss3 和 postcss 引入 npm install -D tailwindcss3 postcss autoprefixer // 初始化引用 npx tailwindcss init -p…

Google ADK(Agent Development Kit)简要示例说明

一、环境准备与依赖安装 1.1 系统 硬件: GPU NVIDIA 3070加速模型推理,内存64GB软件: Python 3.11Docker 28.04(用于容器化部署)Kubernetes 1.25(可选,用于集群管理) 1.2 安装 A…

批量给文件编排序号,支持数字序号及时间日期序号编排文件

当我们需要对文件进行编号的时候,我们可以通过这个工具来帮我们完成,它可以支持从 001 到 100 甚至更多的数字序号编号。也可以支持按照日期、时间等方式对文件进行编号操作。这是一种操作简单,处理起来也非常的高效文件编排序号的方法。 工作…

【系统架构】AI时代下,系统架构师如何修炼

在AI时代,系统架构师的角色正经历深刻变革,需在技术深度、工具应用与思维模式上全面升级。以下结合行业趋势与实践建议,总结系统架构师的修炼路径: 一、掌握AI工具,重构工作流 自动化文档与设计 利用生成式AI(如DeepSeek、ChatGPT)完成70%的需求文档、接口设计及架构图生…

图像颜色空间对比(Opencv)

1. 颜色转换 import cv2 import matplotlib.pyplot as plotimg cv2.imread("tmp.jpg") img_r cv2.cvtColor(img, cv2.COLOR_BGR2RGB) img_g cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) img_h cv2.cvtColor(img, cv2.COLOR_BGR2HSV) img_l cv2.cvtColor(img, cv2.C…

JDBC驱动autosave缺陷的修复与配置指南

opengauss-jdbc-6.0.0.jar和opengauss-jdbc-6.0.0-og.jar版本修复了:autosavealways时,事务嵌套太深,导致栈溢出问题。如果使用的版本低于opengauss-jdbc-6.0.0版本,需要通过替换jdbc驱动和修改url参数来解决autosave缺陷。以下是…

K8S-证书过期更新

K8S证书过期问题 K8S证书过期处理方法 Unable to connect to the server: x509: certificate has expired or is not yet valid 1、查看证书有效期: # kubeadm certs check-expiration2、备份证书 # cp -rp /etc/kubernetes /etc/kubernetes.bak3、直接重建证书 …

2025 年网络安全终极指南

我们生活在一个科技已成为日常生活不可分割的一部分的时代。对数字世界的依赖性日益增强的也带来了更大的网络风险。 网络安全并不是IT专家的专属特权,而是所有用户的共同责任。通过简单的行动,我们可以保护我们的数据、隐私和财务,降低成为…

Python的那些事第四十九篇:基于Python的智能客服系统设计与实现

基于Python的智能客服系统设计与实现 摘要 随着人工智能技术的飞速发展,智能客服系统逐渐成为企业提升客户服务质量和效率的关键工具。本文详细介绍了基于Python的智能客服系统的设计与实现方案,涵盖了系统架构、核心功能、技术选型及优化建议,旨在为企业构建高效、智能的客…

第Y1周:调用YOLOv5官方权重进行检测

🍨 本文为🔗365天深度学习训练营 中的学习记录博客🍖 原作者:K同学啊 文章目录 1、前言2、下载源码3、运行代码 1、前言 YOLOv5分为YOLOv5s、YOLOv5m、YOLOv5l、YOLOv5x四个版本,这里以YOLOv5s为例。 2、下载源码 安…

Python小程序 - 文件处理3:正则表达式

正则表达式:文本年鉴表。遗留的问题很多。。。用AI再想想 需求:读入txt文件,过滤文件有关年记录 0)读入txt文件 1)以“。”,中文句号,为界区分一句,最小统计单位 2)年格…

【antd + vue】Tree 树形控件:默认展开所有树节点 、点击文字可以“选中/取消选中”节点

一、defaultExpandAll 默认展开所有树节点 1、需求:默认展开所有树节点 2、问题: v-if"data.length"判断的层级不够,只判断到了物理那一层,所以只展开到那一层。 3、原因分析: 默认展开所有树节点, 如果是…