pytoch单卡改多卡ddp训练

参考:https://blog.csdn.net/weixin_44966641/article/details/121872773
单卡代码,启动代码 python train.py:

import torch
import torch.nn as nn
from torch.optim import SGD
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
import os
import argparseparser = argparse.ArgumentParser()
parser.add_argument('--gpu_id', type=str, default='0,2')
parser.add_argument('--batchSize', type=int, default=32)
parser.add_argument('--epochs', type=int, default=500)
parser.add_argument('--dataset-size', type=int, default=128)
parser.add_argument('--num-classes', type=int, default=10)
config = parser.parse_args()os.environ['CUDA_VISIBLE_DEVICES'] = config.gpu_id# 定义一个随机数据集,随机生成样本
class RandomDataset(Dataset):def __init__(self, dataset_size, image_size=32):images = torch.randn(dataset_size, 3, image_size, image_size)labels = torch.zeros(dataset_size, dtype=int)self.data = list(zip(images, labels))def __getitem__(self, index):return self.data[index]def __len__(self):return len(self.data)# 定义模型,简单的一层卷积加一层全连接softmax
class Model(nn.Module):def __init__(self, num_classes):super(Model, self).__init__()self.conv2d = nn.Conv2d(3, 16, 3)self.fc = nn.Linear(30*30*16, num_classes)self.softmax = nn.Softmax(dim=1)def forward(self, x):batch_size = x.shape[0]x = self.conv2d(x)x = x.reshape(batch_size, -1)x = self.fc(x)out = self.softmax(x)return out# 实例化模型、数据集、加载器和优化器
model = Model(config.num_classes)
dataset = RandomDataset(config.dataset_size)
loader = DataLoader(dataset, batch_size=config.batchSize, shuffle=True)
loss_func = nn.CrossEntropyLoss()if torch.cuda.is_available():model.cuda()
optimizer = SGD(model.parameters(), lr=0.1, momentum=0.9)# 若使用DP,仅需一行
# if torch.cuda.device_count > 1: model = nn.DataParallel(model)
# 我们不用DP,而将用DDP# 开始训练
for epoch in range(config.epochs):for step, (images, labels) in enumerate(loader):if torch.cuda.is_available(): images = images.cuda()labels = labels.cuda()preds = model(images)loss = loss_func(preds, labels)optimizer.zero_grad()loss.backward()optimizer.step()print(f'Step: {step}, Loss: {loss.item()}')print(f'Epoch {epoch} Finished !')

多卡ddp训练代代码,启动代码:

import torch
import torch.nn as nn
from torch.optim import SGD
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
import os
import argparse# 定义一个随机数据集
class RandomDataset(Dataset):def __init__(self, dataset_size, image_size=32):images = torch.randn(dataset_size, 3, image_size, image_size)labels = torch.zeros(dataset_size, dtype=int)self.data = list(zip(images, labels))def __getitem__(self, index):return self.data[index]def __len__(self):return len(self.data)# 定义模型
class Model(nn.Module):def __init__(self, num_classes):super(Model, self).__init__()self.conv2d = nn.Conv2d(3, 16, 3)self.fc = nn.Linear(30*30*16, num_classes)self.softmax = nn.Softmax(dim=1)def forward(self, x):batch_size = x.shape[0]x = self.conv2d(x)x = x.reshape(batch_size, -1)x = self.fc(x)out = self.softmax(x)return outparser = argparse.ArgumentParser()
parser.add_argument('--gpu_id', type=str, default='0,1,2,3')
parser.add_argument('--batchSize', type=int, default=64)
parser.add_argument('--epochs', type=int, default=500)
parser.add_argument('--dataset-size', type=int, default=1024)
parser.add_argument('--num-classes', type=int, default=10)
config = parser.parse_args()os.environ['CUDA_VISIBLE_DEVICES'] = config.gpu_id
torch.distributed.init_process_group(backend='nccl', init_method='env://')local_rank = torch.distributed.get_rank()
torch.cuda.set_device(local_rank)
device = torch.device("cuda", local_rank)# 实例化模型、数据集和加载器loader
model = Model(config.num_classes)dataset = RandomDataset(config.dataset_size)
sampler = DistributedSampler(dataset) # 这个sampler会自动分配数据到各个gpu上
loader = DataLoader(dataset, batch_size=config.batchSize, sampler=sampler)# loader = DataLoader(dataset, batch_size=config.batchSize, shuffle=True)
loss_func = nn.CrossEntropyLoss()if torch.cuda.is_available():model.cuda()
model = torch.nn.parallel.DistributedDataParallel(model)
optimizer = SGD(model.parameters(), lr=0.1, momentum=0.9)# 开始训练
for epoch in range(config.epochs):for step, (images, labels) in enumerate(loader):if torch.cuda.is_available(): images = images.cuda()labels = labels.cuda()preds = model(images)# print(f"data: {images.device}, model: {next(model.parameters()).device}")loss = loss_func(preds, labels)optimizer.zero_grad()loss.backward()optimizer.step()print(f'Step: {step}, Loss: {loss.item()}')print(f'Epoch {epoch} Finished !')

启动代码

CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.run --nproc_per_node=4 --master_port 12355 train-tmp.py --batchSize 256 --epochs 5000

或者

 torchrun  --nproc_per_node=2 train-tmp.py --batchSize 64 --epochs 500

这里的卡设置要小于等于代码里可见的卡设置。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/884828.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【学习】使用webpack搭建react项目

前言 在日常工作中,我大多是在已有的项目基础上进行开发,而非从头构建项目。因此,我期望通过本次学习能填补我在项目初始化阶段知识的空白,与大家共同进步。在此过程中,我欢迎并感激任何指正或建议,无论是…

什么是人工智能体?

人工智能体(AI Agent)是指能够感知环境、做出决策并采取行动以实现特定目标的自主实体。以下是对人工智能体的具体介绍: 定义与核心概念 智能体的定义:智能体,英文名为Agent,是指具有智能的实体&#xff0…

【笔记】前后端互通中前端登录无响应

后来的前情提要 : 后端的ip地址在本地测试阶段应该设置为localhost 前端中写cors的配置 后端也要写cors的配置 且两者的url都要为localhost 前端写的baseUrl是指定对应的后端的ip地址以及端口号 很重要 在本地时后端的IP的地址也必须为本地的 F12的网页报错是&a…

【初阶数据结构篇】链式结构二叉树(续)

文章目录 须知 💬 欢迎讨论:如果你在学习过程中有任何问题或想法,欢迎在评论区留言,我们一起交流学习。你的支持是我继续创作的动力! 👍 点赞、收藏与分享:觉得这篇文章对你有帮助吗&#xff1…

跨平台Flutter 、ReactNative 开发原理

一、跨平台Flutter开发原理 Flutter是一个跨平台的应用程序开发框架,它允许你使用一组代码库来构建同时运行在Android和iOS上的应用程序。 1.1.Flutter的核心原理基于以下几点: Dart异步、Widget构建块灵活配置、自工化工具链、热重载、Skia图库、Dar…

二叉树 最大深度(递归)

给定一个二叉树 root ,返回其最大深度。 二叉树的 最大深度 是指从根节点到最远叶子节点的最长路径上的节点数。 示例 1: 输入:root [3,9,20,null,null,15,7] 输出:3示例 2: 输入:root [1,null,2] 输出…

python机器人Agent编程——实现一个本地大模型和爬虫结合的手机号归属地天气查询Agent

目录 一、前言二、准备工作三、Agent结构四、python模块实现4.1 实现手机号归属地查询工具4.2实现天气查询工具4.3定义创建Agent主体4.4创建聊天界面 五、小结PS.扩展阅读ps1.六自由度机器人相关文章资源ps2.四轴机器相关文章资源ps3.移动小车相关文章资源ps3.wifi小车控制相关…

python安装selenium,geckodriver,chromedriver

安装浏览器 找到浏览器的版本号 chrome 版本 130.0.6723.92(正式版本) (64 位) firfox 116.0.3 (64 位),但是后面运行的时候又自动更新到了 127.0.0.8923 安装selenium > pip install selenium > pip show …

基于SSM+uniapp的营养食谱系统+LW参考示例

1.项目介绍 功能模块:用户管理、年龄类型管理、阶段食谱管理、体质类型管理、季节食谱管理、职业食谱管理等系统角色:管理员、普通用户技术栈:SSM,uniapp, Vue等测试环境:idea2024,HbuilderX&a…

python常用的第三方库下载方法

方法一:打开pycharm-打开项目-点击左侧图标查看已下载的第三方库-没有下载搜索后点击install即可直接安装--安装成功后会显示在installed列表 方法二:打开dos窗口输入命令“pip install requests“后按回车键,看到successfully既安装成功&…

vue项目安装组件失败解决方法

1.vue项目 npm install 失败 删除node_modules文件夹、package-lock.json 关掉安装对话框 重新打开对话框 npm install

qt QComboBox详解

QComboBox是一个下拉选择框控件,用于从多个选项中选择一个。通过掌握QComboBox 的用法,你将能够在 Qt 项目中轻松添加和管理组合框组件,实现复杂的数据选择和交互功能。 重要方法 addItem(const QString &text):将一个项目添…

window 利用Putty免密登录远程服务器

1 在本地电脑用putty-gen生成密钥 参考1 参考2 2 服务器端操作 将公钥上传至Linux服务器。 复制上述公钥到服务器端的authorized_keys文件 mkdir ~/.ssh vi ~/.ssh/authorized_keys在vi编辑器中,按下ShiftInsert键或者右键选择粘贴,即可将剪贴板中的文…

JAVA基础:多重循环、方法、递归 (习题笔记)

一&#xff0c;编码题 1.打印九九乘法表 import java.util.*;public class PanTi {public static void main(String[] args) {Scanner input new Scanner(System.in);for (int i 0; i < 9; i) {//i控制行数/* System.out.println("。\t。\t。\t。\t。\t。\t。\t。\…

微服务系列二:跨微服务请求优化,注册中心+OpenFeign

目录 前言 一、纯 RestTemplate 方案存在的缺陷 二、注册中心模式介绍 三、注册中心技术&#xff1a;Nacos 3.1 Docker部署Nacos 3.2 服务注册 3.3 服务发现 四、代码优化&#xff1a;OpenFeign工具 4.1 OpenFeign快速入门 4.2 连接池的必要性 4.3 抽取服务、最佳实…

国产数据库之Vastbase海量数据库 G100

海量数据库Vastbase是基于openGauss内核开发的企业级关系型数据库。其语法和Oracle数据库很像&#xff0c;基本是从Oracle数据库迁移到海量数据库&#xff0c;以下简单介绍入门的使用 1、建库操作 地址&#xff1a;x.x.x.x root/Qa2021 安装路径&#xff1a;/home/vastbase 创…

QStackedWidget使用实例

制作一个页面X&#xff0c;该页面X中有两个按钮a&#xff0c;b&#xff1b;和两个页面A&#xff0c;B。要求点击按钮a时&#xff0c;显示页面A。点击按钮b时&#xff0c;显示页面B。 要求按钮位于页面A或B的上方&#xff0c;并且依次由左向右间隔排列。 要求将页面A分为两个子页…

爬虫学习4

from threading import Thread#创建任务 def func(name):for i in range(100):print(name,i)if __name__ __main__:#创建线程t1 Thread(targetfunc,args("1"))t2 Thread(targetfunc, args("2"))t1.start()t2.start()print("我是诛仙剑")from …

oracle使用CTE递归分解字符串

oracle使用CTE递归分解字符串 背景 给定一个不定长度字符串 并且以&#xff0c;分割例如 ‘1&#xff0c;2&#xff0c;3&#xff0c;4’ 使用sql查询 返回1&#xff0c;2&#xff0c;3&#xff0c;4四行 如果‘1&#xff0c;2’ 则返回 1&#xff0c;2 两行 使用sql实现 实…

不要只知道deepl翻译,这里有10个专业好用的翻译工具等着你。

deepl翻译的优点还是有很多的&#xff0c;比如翻译的准确性很高&#xff0c;支持翻译的语言有很多&#xff0c;并且支持翻译文件和文本。但是现在翻译工具那么多&#xff0c;大家需要翻译的场景也有很多&#xff0c;怎么能只拥有一个翻译工具呢。所以在这里我帮助大家寻找了一波…