pytorch单机多卡训练 logger日志记录和wandb可视化

PyTorch 单机多卡训练示例

  • 1、工具:
  • 2、代码
  • 3、启动

1、工具:

  • wandb:云端保存训练记录,可实时刷新
  • logging:记录训练日志
  • argparse:设置全局参数

2、代码

import os
import time
import torch
import wandb
import argparse
import logging
from datetime import datetime
import torch.nn as nn
import torch.distributed as dist
from torch.utils.data import DataLoader, Dataset
import torch.optim.lr_scheduler as lr_scheduleros.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2"
os.environ["WANDB_MODE"] = "run"class MLP(nn.Module):def __init__(self) -> None:super().__init__()self.network = nn.Sequential(nn.Linear(5, 64),nn.SiLU(),nn.Linear(64, 32),nn.SiLU(),nn.Linear(32, 5),)def forward(self, x):x = self.network(x)return xclass RandomDataset(Dataset):def __init__(self, length):self.len = lengthself.data = torch.stack([torch.ones(5), torch.ones(5)*2,torch.ones(5)*3,torch.ones(5)*4,torch.ones(5)*5,torch.ones(5)*6,torch.ones(5)*7,torch.ones(5)*8,torch.ones(5)*9, torch.ones(5)*10,torch.ones(5)*11,torch.ones(5)*12,torch.ones(5)*13,torch.ones(5)*14,torch.ones(5)*15,torch.ones(5)*16]).to('cuda')self.label = torch.stack([torch.zeros(5), torch.zeros(5)*2,torch.zeros(5)*3,torch.zeros(5)*4,torch.zeros(5)*5,torch.zeros(5)*6,torch.zeros(5)*7,torch.zeros(5)*8,torch.zeros(5)*9, torch.zeros(5)*10,torch.zeros(5)*11,torch.zeros(5)*12,torch.zeros(5)*13,torch.zeros(5)*14,torch.zeros(5)*15,torch.zeros(5)*16]).to('cuda')def __getitem__(self, index):return [self.data[index], self.label[index]]def __len__(self):return self.lendef collate_batch(self, batch):"""实现一个自定义的batch拼接函数,这个函数将会被传递给DataLoader的collate_fn参数"""data = torch.stack([x[0] for x in batch])label = torch.stack([x[1] for x in batch])return [data, label]def create_logger(log_file=None, rank=0, log_level=logging.INFO):print("rank: ", rank)logger = logging.getLogger(__name__)logger.setLevel(log_level if rank == 0 else 'ERROR')  # 只有当rank=0时,才会输出info信息formatter = logging.Formatter('%(asctime)s  %(levelname)5s  %(message)s')console = logging.StreamHandler()console.setLevel(log_level if rank == 0 else 'ERROR')console.setFormatter(formatter)logger.addHandler(console)if log_file is not None:file_handler = logging.FileHandler(filename=log_file)file_handler.setLevel(log_level if rank == 0 else 'ERROR')file_handler.setFormatter(formatter)logger.addHandler(file_handler)logger.propagate = Falsereturn loggerdef parse_config():parser = argparse.ArgumentParser(description='arg parser')parser.add_argument('--batch_size', type=int, default=2, required=False, help='batch size for training')parser.add_argument('--train_epochs', type=int, default=100, required=False, help='number of epochs to train for')parser.add_argument('--train_lr', type=int, default=1e-3, required=False, help='training learning rate')parser.add_argument('--loader_num_workers', type=int, default=0, help='number of workers for dataloader')parser.add_argument('--training', type=bool, default=True, help='training or testing mode')parser.add_argument('--without_sync_bn', type=bool, default=False, help='whether to use Synchronization Batch Normaliation')parser.add_argument('--output_dir', type=bool, default="/home/", help='output directory for saving model and logs')args = parser.parse_args()return argsdef main():args = parse_config()device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")# 初始化分布式参数num_gpus = torch.cuda.device_count()local_rank = int(os.environ['LOCAL_RANK'])  # 每个卡上运行一个程序(系统自动分配local_rank)# local_rank = 0  # 在第一张卡运行多个程序torch.cuda.set_device(local_rank % num_gpus)  # 设置当前卡的device# 初始化分布式dist.init_process_group(backend="nccl",  # 指定后端通讯方式为nccl# init_method='tcp://localhost:23456',rank=local_rank,  # rank是指当前进程的编号world_size=num_gpus  # worla_size是指总共的进程数)rank = dist.get_rank()world_size = dist.get_world_size()os.makedirs(args.output_dir, exist_ok=True)# 配置日志logger_file = "/home/caihuaiguang/wjl/Waymo/test-project/log.txt"logger = create_logger(log_file=logger_file, rank=rank)logger.info("rank: %d, world_size: %d" % (rank, world_size))logger.info("local_rank: %d, num_gpus: %d" % (local_rank, num_gpus))wandb_logger = logging.getLogger("wandb")wandb_logger.setLevel(logging.WARNING)wandb_config = wandb.init(project="test-project", name=str(datetime.now().strftime('%Y.%m.%d-%H.%M.%S'))+f"rank-{rank}",dir=args.output_dir,)# 加载并切分数据集dataset = RandomDataset(16)if args.training:sampler = torch.utils.data.distributed.DistributedSampler(dataset)else:sampler = torch.utils.data.distributed.DistributedSampler(dataset, world_size, rank, shuffle=False)data_loader = DataLoader(dataset, batch_size=args.batch_size, sampler=sampler,collate_fn=dataset.collate_batch, num_workers=args.loader_num_workers)# 加载模型model = MLP().to(device)# 统计模型参数量total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)logger.info(f"Total number of parameters: {total_params}")model = nn.parallel.DistributedDataParallel(model, device_ids=[rank % torch.cuda.device_count()])if not args.without_sync_bn:  # 不同卡之间进行同步批量正则化model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)model.train()# 定义优化器optimizer = torch.optim.AdamW(model.parameters(), lr=args.train_lr, weight_decay=0)# 定义学习率调度器milestones = [10, 30, 70]  # 在这些 epoch 后降低学习率gamma = 0.5  # 学习率降低的乘数因子scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=gamma)# 训练模型for iter in range(args.train_epochs):sampler.set_epoch(iter)logger.info("-------------epoch %d start----------------" % iter)for _, batch in enumerate(data_loader):data, label = batch[0], batch[1]data=data.to(device)label=label.to(device)output = model(data)loss = (label - output).pow(2).sum()optimizer.zero_grad()loss.backward()logger.info("loss: %s" % str(loss))wandb_config.log({"loss": loss})optimizer.step()scheduler.step()logger.info("-------------epoch %d end----------------" % iter)time.sleep(10)if __name__ == "__main__":main()

3、启动

启动训练代码,在终端输入

python -m torch.distributed.run --nproc_per_node=3 test.py

其中参数nproc_per_node=3表示使用3张 GPU 进行训练。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/735834.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

elementPlus的坑

记录由 element ui 到element plus的过程 el-form v-model与:model v-model就不用说了,这个:model类似于内置的API接口,用的时候这两个值一样就行 不一样的话会出现,如下奇怪的情况 能输入,但是只能文本框中只显示1个字符&#x…

jmeter快速使用

文章目录 前言一、安装jmeter二、插件安装三、添加常用监听器参考 前言 Apache JMeter may be used to test performance both on static and dynamic resources, Web dynamic applications. It can be used to simulate a heavy load on a server, group of servers, network…

Redis核心数据结构之整数集合

整数集合 概述 整数集合(intset)是集合键的底层实现之一,当一个集合只包含整数值元素,并且这个结合的元素数量不多时,Redis就会使用整数集合作为集合键的底层实现。 例子 举个例子,如果创建一个只包含五个元素的集合键&#x…

MySQL 8.0 架构 之 慢查询日志(Slow query log)(2)流程图:查询记录到慢查询日志中的条件

文章目录 MySQL 8.0 架构 之 慢查询日志(Slow query log)(2)流程图:查询记录到慢查询日志中的条件确定查询是否会记录在慢查询日志中的流程图参考 【声明】文章仅供学习交流,观点代表个人,与任何…

JavaScript数组方法常用方法大全

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 1. push()2. pop()3. unshift()4. shift()5. isArray()6. map()7. filter()8. every()9. some()10. splice()11. slice()12. indexOf()13. includes()14. concat()1…

RK3588 Android 12 系统内核开发+Native层脚本自启动+SELinux配置

前言 开发板型号:RK_EVB7_RK3588_LP4…_V11目标:在开发板上随开机自启动脚本,带起二进制程序,并完备一些其他系统功能。简介:本文自启动脚本run.sh唯一的作用就是拉起二进制程序demo;demo是简单的hello_wo…

Linux下阻塞IO驱动实验三的测试

一. 简介 前面一篇文章实现了驱动代码,以实现应用程序阻塞式访问设备,核心使用的Linux内核提供的阻塞IO机制:等待队列。文章地址如下: Linux下阻塞IO驱动实验实例三-CSDN博客 本文对驱动模块进行测试,测试按键功能是否正常,查看应用程序运行时CPU占用率是否接近0%,当…

【大厂AI课学习笔记NO.76】人工智能人才金字塔

人工智能领域,分为源头创新人才、产业研发人才、应用开发人才和实用技能人才。 人工智能领域的人才结构呈现多样化特点,主要可以分为源头创新人才、产业研发人才、应用开发人才和实用技能人才四大类。这四大类人才在人工智能领域的发展中各自扮演着不可或…

Android下使用OpenOCD

目录 1. 准备工作 2. 运行bootstrap 3. 运行Configure 4. 编译make 4.1 错误1 4.2 错误2 4.3 错误3 4.4 错误4 4.5 错误5 4.6 错误6 4.7 错误7 5. 安装 主要是使用NDK编译OpenOCD源码。最好先在Ubuntu中编译通过OpenOCD。 1. 准备工作 Ubuntu下下载NDK和OpenOCD&…

linux安全配置规范

一、 概述 1.1 适用范围 本配置规范适用于凝思操作系统,主要涉及LINUX操作系统安全配置方面的基本要求,用于指导LINUX操作系统安全加固工作,落实信息安全等级保护等保三级系统操作系统安全配置,为主机安全配置核查提供依据。…

Python刘诗诗

写在前面 刘诗诗在电视剧《一念关山》中饰演了女主角任如意,这是一个极具魅力的女性角色,她既是一位有着高超武艺和智慧的女侠士,也曾经是安国朱衣卫前左使,身怀绝技且性格坚韧不屈。剧中,任如意因不满于朱衣卫的暴行…

P1948 [USACO08JAN] Telephone Lines S

Here 典中之典!! 解题思路 可选k条边代价为0如何决策? 将到当前位置选择了几条代价为0的边放入状态,即若当前状态选的边数小于,则可以进行决策,是否选择当前边,若选,则&#xff0c…

基于智慧灯杆的智慧城市解决方案(2)

功能规划 智慧照明功能 智慧路灯的基本功能仍然是道路照明, 因此对照明功能的智慧化提升是最基本的一项要求。 对道路照明管理进行智慧化提升, 实施智慧照明, 必然将成为智慧城市中道路照明发展的主要方向之一。 智慧照明是集计算机网络技术、 通信技术、 控制技术、 数据…

uniapp:小程序数字键盘功能样式实现

代码如下&#xff1a; <template><view><view><view class"money-input"><view class"input-container" click"toggleBox"><view class"input-wrapper"><view class"input-iconone"…

C++ 队列

目录 队列的应用场景 1、429. N 叉树的层序遍历 2、 103. 二叉树的锯齿形层序遍历 3、662. 二叉树最大宽度 4、515. 在每个树行中找最大值 队列的应用场景 广度优先搜索&#xff08;BFS&#xff09;&#xff1a;队列是广度优先搜索算法的核心数据结构。在BFS中&#xff…

C语言:深入补码计算原理

C语言&#xff1a;深入补码计算原理 有符号整数存储原码、反码、补码转换规则数据与内存的关系 补码原理 有符号整数存储 原码、反码、补码 有符号整数的2进制表示方法有三种&#xff0c;即原码、反码和补码 三种表示方法均有符号位和数值位两部分&#xff0c;符号位用0表示“…

Linux:kubernetes(k8s)lable和selecto标签和选择器的使用(11)

通过标签是可以让我们的容器和容器之间相互认识&#xff0c;简单来说一边打了标签&#xff0c;一边使用选择器去选择就可以快速的让他们之间耦合 定义标签有两种办法&#xff0c;一个是文件中&#xff0c;一个是命令行里 我们在前几章编进文件的时候里面都有lable比如 这个就是…

rk3399使用阿里推理引擎MNN使用cpu和gpu进行benchmark,OpenCL效果不佳?

视频讲解 rk3399使用阿里推理引擎MNN使用cpu和gpu进行benchmark&#xff0c;OpenCL效果不佳&#xff1f; 背景 MNN是阿里开源的推理引擎&#xff0c;今天测试一下在rk3399平台上的benchmark怎么样&#xff1f; alibaba/MNN: MNN is a blazing fast, lightweight deep learning…

keycloak18.0.0==前后端分离项目中使用,前端react后端springboot

配置keycloak 启动keycloak18 新建一个realm,名字叫test1 新建两个client&#xff0c;一个用于前端&#xff0c;一个用于后端 第一个 react http://localhost:8081/auth/realms/test1/react/ 第二个 backend-service 在两个client下分别创建role testRole backend-servic…

可免费使用的AI平台汇总 + 常用赋能科研的AI工具推荐

赋能科研&#xff0c;AI工具助你飞跃学术巅峰&#xff01;(推荐收藏) 文章目录 赋能科研&#xff0c;AI工具助你飞跃学术巅峰&#xff01;(推荐收藏)一、可免费使用的AI平台汇总1. ChatGPT2. New Bing3. Slack4. POE5. Vercel6. 其他平台7. 特定功能平台8. 学术资源平台9. 中文…