深度学习中的并行策略概述:2 Data Parallelism

深度学习中的并行策略概述:2 Data Parallelism
在这里插入图片描述
数据并行(Data Parallelism)的核心在于将模型的数据处理过程并行化。具体来说,面对大规模数据批次时,将其拆分为较小的子批次,并在多个计算设备上同时进行处理。每个设备负责处理一个子批次,实现并行计算。处理完成后,将各个设备上的计算结果汇总,以便对模型进行统一更新。由于其在深度学习中的普遍应用,数据并行成为了一种广泛支持的并行计算策略,并在主流框架中得到了良好的实现。

以下代码展示了如何在PyTorch中使用nn.DataParallel和DistributedDataParallel实现数据并行,以加速模型的训练过程。

使用nn.DataParallel实现数据并行

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader# 假设我们有一个简单的数据集类
class SimpleDataset(Dataset):def __init__(self, data, target):self.data = dataself.target = targetdef __len__(self):return len(self.data)def __getitem__(self, idx):return self.data[idx], self.target[idx]# 假设我们有一个简单的神经网络模型
class SimpleModel(nn.Module):def __init__(self, input_dim):super(SimpleModel, self).__init__()self.fc = nn.Linear(input_dim, 1)def forward(self, x):return torch.sigmoid(self.fc(x))# 假设我们有一些数据
n_sample = 100
n_dim = 10
batch_size = 10
X = torch.randn(n_sample, n_dim)
Y = torch.randint(0, 2, (n_sample,)).float()
dataset = SimpleDataset(X, Y)
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)# 初始化模型
device_ids = [0, 1, 2]  # 指定使用的GPU编号
model = SimpleModel(n_dim).to(device_ids[0])
model = nn.DataParallel(model, device_ids=device_ids)# 定义优化器和损失函数
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = nn.BCELoss()# 训练模型
for epoch in range(10):for batch_idx, (inputs, targets) in enumerate(data_loader):inputs, targets = inputs.to('cuda'), targets.to('cuda')outputs = model(inputs)loss = criterion(outputs, targets.unsqueeze(1))optimizer.zero_grad()loss.backward()optimizer.step()print(f'Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item()}')

使用DistributedDataParallel实现数据并行

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP# 假设我们有一个简单的数据集类
class SimpleDataset(Dataset):def __init__(self, data, target):self.data = dataself.target = targetdef __len__(self):return len(self.data)def __getitem__(self, idx):return self.data[idx], self.target[idx]# 假设我们有一个简单的神经网络模型
class SimpleModel(nn.Module):def __init__(self, input_dim):super(SimpleModel, self).__init__()self.fc = nn.Linear(input_dim, 1)def forward(self, x):return torch.sigmoid(self.fc(x))# 初始化进程组
def init_process(rank, world_size, backend='nccl'):dist.init_process_group(backend, rank=rank, world_size=world_size)# 训练函数
def train(rank, world_size):init_process(rank, world_size)torch.cuda.set_device(rank)model = SimpleModel(10).to(rank)model = DDP(model, device_ids=[rank])dataset = SimpleDataset(torch.randn(100, 10), torch.randint(0, 2, (100,)).float())sampler = torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=world_size, rank=rank)data_loader = DataLoader(dataset, batch_size=10, sampler=sampler)optimizer = optim.SGD(model.parameters(), lr=0.01)criterion = nn.BCELoss()for epoch in range(10):for inputs, targets in data_loader:inputs, targets = inputs.to(rank), targets.to(rank)optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, targets.unsqueeze(1))loss.backward()optimizer.step()if __name__ == "__main__":world_size = 4torch.multiprocessing.spawn(train, args=(world_size,), nprocs=world_size, join=True)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/65036.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

分布式专题(10)之ShardingSphere分库分表实战指南

一、ShardingSphere产品介绍 Apache ShardingSphere 是一款分布式的数据库生态系统, 可以将任意数据库转换为分布式数据库,并通过数据分片、弹性伸缩、加密等能力对原有数据库进行增强。Apache ShardingSphere 设计哲学为 Database Plus,旨在…

帧缓存的分配

帧缓存实际上就是一块内存。在 Android 系统中分配与回收帧缓存,使用的是一个叫 ION 的内核模块,App 使用 ioctl 系统调用后,会在内核内存中分配一块符合要求的内存,用户态会拿到一个 fd(有的地方也称之为 handle&…

vue3+vite一个IP对站点名称的前端curd更新-会议系统优化

vue3-tailwind-todo https://github.com/kgrg/vue3-tailwind-todo 基于这个项目,把ip到sta的映射做了前端管理. 核心代码是存储和获得的接口,需要flask提供. def redis2ipdic():global ipdicipdic.clear()tmdiccl.hgetall(IPDIC_KEY)for k in tmdic.keys():ipdic[k.decode() …

Elasticsearch-脚本查询

脚本查询 概念 Scripting是Elasticsearch支持的一种专门用于复杂场景下支持自定义编程的强大的脚本功能,ES支持多种脚本语言,如painless,其语法类似于Java,也有注释、关键字、类型、变量、函数等,其就要相对于其他脚本高出几倍的性…

golang LeetCode 热题 100(动态规划)-更新中

爬楼梯 假设你正在爬楼梯。需要 n 阶你才能到达楼顶。 每次你可以爬 1 或 2 个台阶。你有多少种不同的方法可以爬到楼顶呢? 示例 1:输入:n 2 输出:2 解释:有两种方法可以爬到楼顶。 1. 1 阶 1 阶 2. 2 阶 示例 2&…

【每日学点鸿蒙知识】Charles抓包、lock文件处理、WebView组件、NFC相关、CallMethod失败等

1、HarmonyOS系统中如何使用Charles抓包? 在HarmonyOS操作系统中,使用Charles进行抓包的步骤如下: 在Charles中设置代理。 首先,在Charles的菜单栏上选择“Proxy”→“Proxy Settings”,然后填入代理端口&#xff0…

抓取手机HCI日志

荣耀手机 1、打开开发者模式 2、开启HCI、ADB调试 3、开启AP LOG 拨号界面输入*##2846579##* 4、蓝牙配对 5、抓取log adb pull /data/log/bt ./

WebAPI编程(第一天,第二天)

WebAPI编程(第一天,第二天) day01 - Web APIs 1.1. Web API介绍 1.1.1 API的概念1.1.2 Web API的概念1.1.3 API 和 Web API 总结 1.2. DOM 介绍 1.2.1 什么是DOM1.2.2. DOM树 1.3. 获取元素 1.3.1. 根据ID获取1.3.2. 根据标签名获取元素1.3.…

windows下Redis的使用

Redis简介: Redis 是一个开源的使用 ANSI C 语言编写、遵守 BSD 协议、支持网络、可基于内存、分布式、可选持久性的键值对(Key-Value)存储数据库,并提供多种语言的 API。 Redis通常被称为数据结构服务器,因为值(value&#xff…

【贪吃蛇小游戏 - JavaIDEA】基于Java实现的贪吃蛇小游戏导入IDEA教程

有问题请留言或私信 步骤 下载项目源码:项目源码 解压项目源码到本地 打开IDEA 左上角:文件 → 新建 → 来自现有源代码的项目 找到解压在本地的项目源代码文件,点击确定 选择“从现有项目创建项目”。点击“下一步” 点击下一步&a…

RTOS下的任务管理

2.3 RTOS下的任务管理(***) RTOS的任务管理主要是进行哪些功能? RTOS的任务管理的多任务管理是怎样进行与实现的? 任务管理中FreeRTOS如何给每个任务分配CPU时间? 文章目录 2.3 RTOS下的任务管理(***)2.3.0 任务概述2.3.1任务的创建与删除2.3…

深度学习——神经网络中前向传播、反向传播与梯度计算原理

一、前向传播 1.1 概念 神经网络的前向传播(Forward Propagation)就像是一个数据处理的流水线。从输入层开始,按照网络的层次结构,每一层的神经元接收上一层神经元的输出作为自己的输入,经过线性变换(加权…

【初阶数据结构与算法】八大排序算法之归并排序与非比较排序(计数排序)

文章目录 一、归并排序二、非比较排序之计数排序三、归并排序和计数排序的性能测试 一、归并排序 归并排序(MERGE-SORT)是建⽴在归并操作上的⼀种有效的排序算法,该算法是采⽤分治法(Divide andConquer)的⼀个⾮常典型的应⽤   …

window安装TradingView

目录 下载安装包 修改文件后缀,解压 将K线换成国内涨红跌绿样式 下载安装包 https://www.tradingview.com/desktop/ 下载完成后是.msix格式文件 (我在win10和win11的系统中尝试运行msix都没有成功,所以放弃直接双击运行msix&#xff…

FPGA多路MIPI转FPD-Link视频缩放拼接显示,基于IMX327+FPD953架构,提供2套工程源码和技术支持

目录 1、前言工程概述免责声明 2、相关方案推荐本博主所有FPGA工程项目-->汇总目录我这里已有的 MIPI 编解码方案我这里已有的FPGA图像缩放方案本博已有的已有的FPGA视频拼接叠加融合方案 3、本 MIPI CSI-RX IP 介绍4、详细设计方案设计原理框图IMX327 及其配置FPD-Link视频…

React+Vite从零搭建项目及配置详解

相信很多React初学者第一次搭建自己的项目,搭建时会无从下手,本篇适合快速实现功能,熟悉React项目搭建流程。 目录 一、创建项目react-item 二、调整项目目录结构 三、使用scss预处理器 四、组件库Ant Design 五、配置基础路由 六、配置…

Unity复刻胡闹厨房复盘 模块一 新输入系统订阅链与重绑定

本文仅作学习交流,不做任何商业用途 郑重感谢siki老师的汉化教程与代码猴的免费教程以及搬运烤肉的小伙伴 版本:Unity6 模板:3D 核心 渲染管线:URP ------------------------------…

从零开始的编程-java篇1.6.1 万变不离其宗,hello word

前言: 通过实践而发现真理,又通过实践而证实真理和发展真理。从感性认识而能动地发展到理性认识,又从理性认识而能动地指导革命实践,改造主观世界和客观世界。实践、认识、再实践、再认识,这种形式,循环往…

【漏洞复现】CVE-2021-45788 SQL Injection

漏洞信息 NVD - cve-2021-45788 Time-based SQL Injection vulnerabilities were found in Metersphere v1.15.4 via the “orders” parameter. Authenticated users can control the parameters in the “order by” statement, which causing SQL injection. API: /test…

Mac系统下 idea运行maven项目中存在的问题BeanDefinitionStoreException

1.在进行 注解XML 方式整合三层架构事出现此问题 org.springframework.beans.factory.BeanDefinitionStoreException: Failed to read candidate component class: file [/Volumes/PS2000/Java/SpringProject/micro-shop/spring-annotation-practice-03/target/classes/com/ja…