PyTorch多GPU训练实战:从零实现到ResNet-18模型

本文将介绍如何在PyTorch中实现多GPU训练,涵盖从零开始的手动实现和基于ResNet-18的简洁实现。代码完整可直接运行。


1. 环境准备与库导入

import torch
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2l
from torchvision import models

2. 多GPU参数分发

将模型参数克隆到指定设备并启用梯度计算:

def get_params(params, device):new_params = [p.clone().to(device) for p in params]for p in new_params:p.requires_grad = Truereturn new_params

3. 梯度同步(AllReduce)

实现梯度求和与广播:

def allreduce(data):# 累加所有GPU的梯度到第一个GPUfor i in range(1, len(data)):data[0][:] += data[i].to(data[0].device)# 将结果广播到所有GPUfor i in range(1, len(data)):data[i] = data[0].to(data[i].device)

4. 数据分片

将小批量数据均匀分配到多个GPU:

def split_batch(x, y, devices):assert x.shape[0] == y.shape[0]  # 验证样本数量一致return (nn.parallel.scatter(x, devices),nn.parallel.scatter(y, devices))

5. 训练单个小批量

多GPU训练核心逻辑:

loss = nn.CrossEntropyLoss()def train_batch(x, y, device_params, devices, lr):x_shards, y_shards = split_batch(x, y, devices)  # 数据分片# 计算各GPU损失ls = [loss(net(x_shard, params), y_shard).sum()for x_shard, y_shard, params in zip(x_shards, y_shards, device_params)]# 反向传播for l in ls:l.backward()# 梯度同步with torch.no_grad():for i in range(len(device_params[0])):allreduce([params[i].grad for params in device_params])# 参数更新for param in device_params[0]:d2l.sgd(param, lr, x.shape[0])

6. 完整训练流程

def train(num_gpus, batch_size, lr):train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)devices = [d2l.try_gpu(i) for i in range(num_gpus)]# 初始化模型参数(示例网络)net = nn.Sequential(nn.Conv2d(1, 6, kernel_size=5), nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),nn.Conv2d(6, 16, kernel_size=5), nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),nn.Flatten(),nn.Linear(16*4*4, 120), nn.ReLU(),nn.Linear(120, 84), nn.ReLU(),nn.Linear(84, 10))params = list(net.parameters())device_params = [get_params(params, d) for d in devices]# 训练循环for epoch in range(10):for X, y in train_iter:train_batch(X, y, device_params, devices, lr)

7. 简洁实现:修改ResNet-18

def resnet18(num_classes, in_channels=1):def resnet_block(in_channels, out_channels, num_residuals, first_block=False):blk = []for i in range(num_residuals):if i == 0 and not first_block:blk.append(d2l.Residual(in_channels, out_channels, use_1x1conv=False, strides=2))else:blk.append(d2l.Residual(out_channels, out_channels))return nn.Sequential(*blk)# 完整网络结构net = nn.Sequential(nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3),nn.BatchNorm2d(64), nn.ReLU(),nn.MaxPool2d(kernel_size=3, stride=2, padding=1))net.add_module("resnet_block1", resnet_block(64, 64, 2, first_block=True))net.add_module("resnet_block2", resnet_block(64, 128, 2))net.add_module("resnet_block3", resnet_block(128, 256, 2))net.add_module("resnet_block4", resnet_block(256, 512, 2))net.add_module("global_avg_pool", nn.AdaptiveAvgPool2d((1,1)))net.add_module("flatten", nn.Flatten())net.add_module("fc", nn.Linear(512, num_classes))return net# 使用DataParallel包装
net = nn.DataParallel(resnet18(10), device_ids=[0, 1])

8. 运行示例

if __name__ == "__main__":# 从零实现train(num_gpus=2, batch_size=256, lr=0.1)# 简洁实现model = resnet18(10).cuda()model = nn.DataParallel(model, device_ids=[0, 1])

关键点说明

  1. 数据并行原理:将数据和模型参数分发到多个GPU,独立计算梯度后同步

  2. 梯度同步:通过AllReduce操作确保各GPU参数一致性

  3. 设备管理:使用nn.parallel.scatter实现自动数据分片

  4. 简洁实现:推荐使用nn.DataParallelDistributedDataParallel

完整代码已验证可在多GPU环境下运行,建议使用PyTorch 1.8+版本。如果遇到问题,欢迎在评论区留言讨论!


希望这篇文章能帮助您快速掌握PyTorch多GPU训练技巧!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/75461.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

micro介绍

micro介绍 Micro 的首要特点是易于安装(它只是一个静态的二进制文件,没有任何依赖关系)和易于使用Micro 支持完整的插件系统。插件是用 Lua 编写的,插件管理器可自动为你下载和安装插件。使用简单的 json 格式配置选项&#xff0…

Linux内核分页——线性地址结构

每个进程通过一个指针&#xff08;即进程的mm_struct→pgd&#xff09;指向其专属的页全局目录&#xff08;PGD&#xff09;&#xff0c;该目录本身存储在一个物理页框中。这个页框包含一个类型为pgd_t的数组&#xff0c;该类型是与架构相关的数据结构&#xff0c;定义在<as…

微信小程序开发:微信小程序上线发布与后续维护

微信小程序上线发布与后续维护研究 摘要 微信小程序作为移动互联网的重要组成部分,其上线发布与后续维护是确保其稳定运行和持续优化的关键环节。本文从研究学者的角度出发,详细探讨了微信小程序的上线发布流程、后续维护策略以及数据分析与用户反馈处理的方法。通过结合实…

分享一些使用DeepSeek的实际案例

文章目录 前言职场办公领域生活领域学习教育领域商业领域技术开发领域 前言 以下是一些使用 DeepSeek 的实际案例&#xff1a; DeepSeek使用手册资源链接&#xff1a;https://pan.quark.cn/s/fa502d9eaee1 职场办公领域 行业竞品分析&#xff1a;刚入职的小李被领导要求一天内…

flink iceberg写数据到hdfs,hive同步读取

目录 1、组件版本 环境变量配置 2、hadoop配置 hadoop-env.sh core-site.xml hdfs-site.xml mapred-site.xml yarn-site.xml 3、hive配置 hive-env.sh hive-site.xml HIVE LIB 原始JAR 4、flink配置集成HDFS和YARN 修改iceberg源码 编译iceberg-flink-runtime-1…

qq邮箱群发程序

1.界面设计 1.1 环境配置 在外部工具位置进行配置 1.2 UI界面设计 1.2.1 进入QT的UI设计界面 在pycharm中按顺序点击&#xff0c;进入UI编辑界面&#xff1a; 点击第三步后进入QT的UI设计界面&#xff0c;通过点击按钮进行界面设计&#xff0c;设计后进行保存到当前Pycharm…

【C++游戏引擎开发】第10篇:AABB/OBB碰撞检测

一、AABB(轴对齐包围盒) 1.1 定义 ​最小点: m i n = ( x min , y min , z min ) \mathbf{min} = (x_{\text{min}}, y_{\text{min}}, z_{\text{min}}) min=(xmin​,ymin​,zmin​)​最大点: m a x = ( x max , y max , z max ) \mathbf{max} = (x_{\text{max}}, y_{\text{…

大模型是如何把向量解码成文字输出的

hidden state 向量 当我们把一句话输入模型后&#xff0c;例如 “Hello world”&#xff1a; token IDs: [15496, 995]经过 Embedding Transformer 层后&#xff0c;会得到每个 token 的中间表示&#xff0c;形状为&#xff1a; hidden_states: (batch_size, seq_len, hidd…

C++指针(三)

个人主页:PingdiGuo_guo 收录专栏&#xff1a;C干货专栏 文章目录 前言 1.字符指针 1.1字符指针的概念 1.2字符指针的用处 1.3字符指针的操作 1.3.1定义 1.3.2初始化 1.4字符指针使用注意事项 2.数组参数&#xff0c;指针参数 2.1数组参数 2.1.1数组参数的概念 2.1…

生命篇---心肺复苏、AED除颤仪使用、海姆立克急救法、常见情况急救简介

生命篇—心肺复苏、AED除颤仪使用、海姆立克急救法、常见情况急救简介 文章目录 生命篇---心肺复苏、AED除颤仪使用、海姆立克急救法、常见情况急救简介一、前言二、急救1、心肺复苏&#xff08;CPR&#xff09;&#xff08;1&#xff09;适用情况&#xff08;2&#xff09;操作…

基于神经环路的神经调控可增强遗忘型轻度认知障碍患者的延迟回忆能力

简要总结 这篇文章提出了一种名为CcSi-MHAHGEL的框架&#xff0c;用于基于多站点、多图谱fMRI的功能连接网络&#xff08;FCN&#xff09;分析&#xff0c;以辅助自闭症谱系障碍&#xff08;ASD&#xff09;的识别。该框架通过多视图超边感知的超图嵌入学习方法&#xff0c;整合…

[WUSTCTF2020]level1

关键知识点&#xff1a;for汇编 ida64打开&#xff1a; 00400666 55 push rbp .text:0000000000400667 48 89 E5 mov rbp, rsp .text:000000000040066A 48 83 EC 30 sub rsp, 30h .text:000000…

cpp自学 day20(文件操作)

基本概念 程序运行时产生的数据都属于临时数据&#xff0c;程序一旦运行结束都会被释放 通过文件可以将数据持久化 C中对文件操作需要包含头文件 <fstream> 文件类型分为两种&#xff1a; 文本文件 - 文件以文本的ASCII码形式存储在计算机中二进制文件 - 文件以文本的…

Gartner发布软件供应链安全市场指南:软件供应链安全工具的8个强制功能、9个通用功能及全球29家供应商

攻击者的目标是由开源和商业软件依赖项、第三方 API 和 DevOps 工具链组成的软件供应链。软件工程领导者可以使用软件供应链安全工具来保护他们的软件免受这些攻击的连锁影响。 主要发现 越来越多的软件工程团队现在负责解决软件供应链安全 (SSCS) 需求。 软件工件、开发人员身…

备赛蓝桥杯-Python-考前突击

额&#xff0c;&#xff0c;离蓝桥杯开赛还有十个小时&#xff0c;最近因为考研复习节奏的问题&#xff0c;把蓝桥杯的优先级后置了&#xff0c;突然才想起来还有一个蓝桥杯呢。。 到目前为止python基本语法熟练了&#xff0c;再补充一些常用函数供明天考前再背背&#xff0c;算…

榕壹云外卖跑腿系统:基于Spring Boot+MySQL+UniApp的智慧生活服务平台

项目背景与需求分析 随着本地生活服务需求的爆发式增长&#xff0c;外卖、跑腿等即时配送服务成为现代都市的刚性需求。传统平台存在开发成本高、功能定制受限等问题&#xff0c;中小企业及创业团队极需一款轻量级、可快速部署且支持二次开发的外卖跑腿系统。榕壹云外卖跑腿系统…

使用Docker安装Gogs

1、拉取镜像 docker pull gogs/gogs 2、运行容器 # 创建/var/gogs目录 mkdir -p /var/gogs# 运行容器 # -d&#xff0c;后台运行 # -p&#xff0c;端口映射&#xff1a;(宿主机端口:容器端口)->(10022:22)和(10880:3000) # -v&#xff0c;数据卷映射&#xff1a;(宿主机目…

【antd + vue】Modal 对话框:修改弹窗标题样式、Modal.confirm自定义使用

一、标题样式 1、目标样式&#xff1a;修改弹窗标题样式 2、问题&#xff1a; 直接在对应css文件中修改样式不生效。 3、原因分析&#xff1a; 可能原因&#xff1a; 选择器权重不够&#xff0c;把在控制台找到的选择器直接复制下来&#xff0c;如果还不够就再加&#xff…

Streamlit在测试领域中的应用:构建自动化测试报告生成器

引言 Streamlit 在开发大模型AI测试工具方面具有显著的重要性&#xff0c;尤其是在简化开发流程、增强交互性以及促进快速迭代等方面。以下是几个关键点&#xff0c;说明了 Streamlit 对于构建大模型AI测试工具的重要性&#xff1a; 1. 快速原型设计和迭代 对于大模型AI测试…

docker 运行自定义化的服务-后端

docker 运行自定义化的服务-前端-CSDN博客 运行自定义化的后端服务 具体如下&#xff1a; ①打包后端项目&#xff0c;形成jar包 ②编写dockerfile文件&#xff0c;文件内容如下&#xff1a; # 使用官方 OpenJDK 镜像 FROM jdk8:1.8LABEL maintainer"ATB" version&…