深度学习中的并行策略概述:2 Data Parallelism

深度学习中的并行策略概述:2 Data Parallelism
在这里插入图片描述
数据并行(Data Parallelism)的核心在于将模型的数据处理过程并行化。具体来说,面对大规模数据批次时,将其拆分为较小的子批次,并在多个计算设备上同时进行处理。每个设备负责处理一个子批次,实现并行计算。处理完成后,将各个设备上的计算结果汇总,以便对模型进行统一更新。由于其在深度学习中的普遍应用,数据并行成为了一种广泛支持的并行计算策略,并在主流框架中得到了良好的实现。

以下代码展示了如何在PyTorch中使用nn.DataParallel和DistributedDataParallel实现数据并行,以加速模型的训练过程。

使用nn.DataParallel实现数据并行

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader# 假设我们有一个简单的数据集类
class SimpleDataset(Dataset):def __init__(self, data, target):self.data = dataself.target = targetdef __len__(self):return len(self.data)def __getitem__(self, idx):return self.data[idx], self.target[idx]# 假设我们有一个简单的神经网络模型
class SimpleModel(nn.Module):def __init__(self, input_dim):super(SimpleModel, self).__init__()self.fc = nn.Linear(input_dim, 1)def forward(self, x):return torch.sigmoid(self.fc(x))# 假设我们有一些数据
n_sample = 100
n_dim = 10
batch_size = 10
X = torch.randn(n_sample, n_dim)
Y = torch.randint(0, 2, (n_sample,)).float()
dataset = SimpleDataset(X, Y)
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)# 初始化模型
device_ids = [0, 1, 2]  # 指定使用的GPU编号
model = SimpleModel(n_dim).to(device_ids[0])
model = nn.DataParallel(model, device_ids=device_ids)# 定义优化器和损失函数
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = nn.BCELoss()# 训练模型
for epoch in range(10):for batch_idx, (inputs, targets) in enumerate(data_loader):inputs, targets = inputs.to('cuda'), targets.to('cuda')outputs = model(inputs)loss = criterion(outputs, targets.unsqueeze(1))optimizer.zero_grad()loss.backward()optimizer.step()print(f'Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item()}')

使用DistributedDataParallel实现数据并行

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP# 假设我们有一个简单的数据集类
class SimpleDataset(Dataset):def __init__(self, data, target):self.data = dataself.target = targetdef __len__(self):return len(self.data)def __getitem__(self, idx):return self.data[idx], self.target[idx]# 假设我们有一个简单的神经网络模型
class SimpleModel(nn.Module):def __init__(self, input_dim):super(SimpleModel, self).__init__()self.fc = nn.Linear(input_dim, 1)def forward(self, x):return torch.sigmoid(self.fc(x))# 初始化进程组
def init_process(rank, world_size, backend='nccl'):dist.init_process_group(backend, rank=rank, world_size=world_size)# 训练函数
def train(rank, world_size):init_process(rank, world_size)torch.cuda.set_device(rank)model = SimpleModel(10).to(rank)model = DDP(model, device_ids=[rank])dataset = SimpleDataset(torch.randn(100, 10), torch.randint(0, 2, (100,)).float())sampler = torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=world_size, rank=rank)data_loader = DataLoader(dataset, batch_size=10, sampler=sampler)optimizer = optim.SGD(model.parameters(), lr=0.01)criterion = nn.BCELoss()for epoch in range(10):for inputs, targets in data_loader:inputs, targets = inputs.to(rank), targets.to(rank)optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, targets.unsqueeze(1))loss.backward()optimizer.step()if __name__ == "__main__":world_size = 4torch.multiprocessing.spawn(train, args=(world_size,), nprocs=world_size, join=True)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/65036.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

中国农业科学院深圳农业基因组研究所合成生物学研究中心-随笔06

更新读研择校贴,生物合成行业领先的单位一览SWHC002 中国科学院合成生物学重点实验室介绍-随笔05-CSDN博客 中国农业科学院深圳农业基因组研究所(基因组所)合成生物学研究中心 https://www.agis.org.cn/bsgk/yjsjj/index.htm #官网 htt…

分布式专题(10)之ShardingSphere分库分表实战指南

一、ShardingSphere产品介绍 Apache ShardingSphere 是一款分布式的数据库生态系统, 可以将任意数据库转换为分布式数据库,并通过数据分片、弹性伸缩、加密等能力对原有数据库进行增强。Apache ShardingSphere 设计哲学为 Database Plus,旨在…

QT 控件定义为智能指针引发的bug

问题描述&#xff1a; std::unique_ptr<QStackedLayout> m_stacked_layout; 如上为定义&#xff1b; 调用&#xff1a; Line13ABClient::Line13ABClient(QWidget *parent) : BaseWidget(parent) { // 成员变量初始化 m_get_ready false; m_tittle_wnd…

帧缓存的分配

帧缓存实际上就是一块内存。在 Android 系统中分配与回收帧缓存&#xff0c;使用的是一个叫 ION 的内核模块&#xff0c;App 使用 ioctl 系统调用后&#xff0c;会在内核内存中分配一块符合要求的内存&#xff0c;用户态会拿到一个 fd&#xff08;有的地方也称之为 handle&…

vue3+vite一个IP对站点名称的前端curd更新-会议系统优化

vue3-tailwind-todo https://github.com/kgrg/vue3-tailwind-todo 基于这个项目,把ip到sta的映射做了前端管理. 核心代码是存储和获得的接口,需要flask提供. def redis2ipdic():global ipdicipdic.clear()tmdiccl.hgetall(IPDIC_KEY)for k in tmdic.keys():ipdic[k.decode() …

Elasticsearch-脚本查询

脚本查询 概念 Scripting是Elasticsearch支持的一种专门用于复杂场景下支持自定义编程的强大的脚本功能&#xff0c;ES支持多种脚本语言&#xff0c;如painless&#xff0c;其语法类似于Java,也有注释、关键字、类型、变量、函数等&#xff0c;其就要相对于其他脚本高出几倍的性…

golang LeetCode 热题 100(动态规划)-更新中

爬楼梯 假设你正在爬楼梯。需要 n 阶你才能到达楼顶。 每次你可以爬 1 或 2 个台阶。你有多少种不同的方法可以爬到楼顶呢&#xff1f; 示例 1&#xff1a;输入&#xff1a;n 2 输出&#xff1a;2 解释&#xff1a;有两种方法可以爬到楼顶。 1. 1 阶 1 阶 2. 2 阶 示例 2&…

EasyExcel 模板+公式填充

使用 CellWriteHandler 的实现类来实现公式写入 Data NoArgsConstructor public class CustomCellWriteHandler implements CellWriteHandler {private int maxRowNum 2000;// 动态传入列表数量public CustomCellWriteHandler(int maxRowNum) {this.maxRowNum maxRowNum;}Ov…

【每日学点鸿蒙知识】Charles抓包、lock文件处理、WebView组件、NFC相关、CallMethod失败等

1、HarmonyOS系统中如何使用Charles抓包&#xff1f; 在HarmonyOS操作系统中&#xff0c;使用Charles进行抓包的步骤如下&#xff1a; 在Charles中设置代理。 首先&#xff0c;在Charles的菜单栏上选择“Proxy”→“Proxy Settings”&#xff0c;然后填入代理端口&#xff0…

抓取手机HCI日志

荣耀手机 1、打开开发者模式 2、开启HCI、ADB调试 3、开启AP LOG 拨号界面输入*##2846579##* 4、蓝牙配对 5、抓取log adb pull /data/log/bt ./

WebAPI编程(第一天,第二天)

WebAPI编程&#xff08;第一天&#xff0c;第二天&#xff09; day01 - Web APIs 1.1. Web API介绍 1.1.1 API的概念1.1.2 Web API的概念1.1.3 API 和 Web API 总结 1.2. DOM 介绍 1.2.1 什么是DOM1.2.2. DOM树 1.3. 获取元素 1.3.1. 根据ID获取1.3.2. 根据标签名获取元素1.3.…

windows下Redis的使用

Redis简介&#xff1a; Redis 是一个开源的使用 ANSI C 语言编写、遵守 BSD 协议、支持网络、可基于内存、分布式、可选持久性的键值对(Key-Value)存储数据库&#xff0c;并提供多种语言的 API。 Redis通常被称为数据结构服务器&#xff0c;因为值&#xff08;value&#xff…

【贪吃蛇小游戏 - JavaIDEA】基于Java实现的贪吃蛇小游戏导入IDEA教程

有问题请留言或私信 步骤 下载项目源码&#xff1a;项目源码 解压项目源码到本地 打开IDEA 左上角&#xff1a;文件 → 新建 → 来自现有源代码的项目 找到解压在本地的项目源代码文件&#xff0c;点击确定 选择“从现有项目创建项目”。点击“下一步” 点击下一步&a…

一个简单封装的的nodejs缓存对象

我们在日常编码中&#xff0c;经常会用到缓存&#xff0c;而一个有效的缓存管理&#xff0c;也是大家必不可少的工具。而nodejs没有内置专用的缓存对象&#xff0c;并且由于js的作用域链的原因&#xff0c;很多变量使用起来容易出错&#xff0c;如果用一个通用的缓存管理起来&a…

RTOS下的任务管理

2.3 RTOS下的任务管理(***) RTOS的任务管理主要是进行哪些功能&#xff1f; RTOS的任务管理的多任务管理是怎样进行与实现的&#xff1f; 任务管理中FreeRTOS如何给每个任务分配CPU时间&#xff1f; 文章目录 2.3 RTOS下的任务管理(***)2.3.0 任务概述2.3.1任务的创建与删除2.3…

GitLab 停止为中国区用户提供 GitLab.com 账号服务

GitLab 通知中国区用户将停止提供 GitLab.com 账号服务&#xff0c;建议现有用户迁移到极狐。 中国 IP 地址现在访问 GitLab.com 会跳转到 about.gitlab.com&#xff0c;推荐用户访问极狐。 Gundaz Aghayev 写道&#xff1a;GitLab 在发送中国地区用户的电子邮件通知中称&…

深度学习——神经网络中前向传播、反向传播与梯度计算原理

一、前向传播 1.1 概念 神经网络的前向传播&#xff08;Forward Propagation&#xff09;就像是一个数据处理的流水线。从输入层开始&#xff0c;按照网络的层次结构&#xff0c;每一层的神经元接收上一层神经元的输出作为自己的输入&#xff0c;经过线性变换&#xff08;加权…

【初阶数据结构与算法】八大排序算法之归并排序与非比较排序(计数排序)

文章目录 一、归并排序二、非比较排序之计数排序三、归并排序和计数排序的性能测试 一、归并排序 归并排序&#xff08;MERGE-SORT&#xff09;是建⽴在归并操作上的⼀种有效的排序算法,该算法是采⽤分治法&#xff08;Divide andConquer&#xff09;的⼀个⾮常典型的应⽤   …

window安装TradingView

目录 下载安装包 修改文件后缀&#xff0c;解压 将K线换成国内涨红跌绿样式 下载安装包 https://www.tradingview.com/desktop/ 下载完成后是.msix格式文件 &#xff08;我在win10和win11的系统中尝试运行msix都没有成功&#xff0c;所以放弃直接双击运行msix&#xff…

STM32HAL库中RTC闹钟设置时分秒,年月日

在STM32的HAL库中&#xff0c;RTC&#xff08;实时时钟&#xff09;模块提供了多种功能来管理时间和日期&#xff0c;包括设置闹钟。对于RTC闹钟功能&#xff0c;确实主要集中在时、分、秒的配置上&#xff0c;但年、月、日也可以通过RTC日期寄存器进行设置&#xff0c;并且可以…