pytorch基于ray和accelerate实现多GPU数据并行的模型加速训练

在pytorch的DDP原生代码使用的基础上,ray和accelerate两个库对于pytorch并行训练的代码使用做了更加友好的封装。

以下为极简的代码示例。

ray

ray.py

#coding=utf-8
import os
import sys
import time
import numpy as np
import torch
from torch import nn
import torch.utils.data as Data
import ray
from ray.train.torch import TorchTrainer
from ray.air.config import ScalingConfig
import onnxruntime# bellow code use AI model to simulate linear regression, formula is: y = x1 * w1 + x2 * w2 + b
# --- DDP RAY --- # # model structure
class LinearNet(nn.Module):def __init__(self, n_feature):super(LinearNet, self).__init__()self.linear = nn.Linear(n_feature, 1)def forward(self, x):y = self.linear(x)return y# whole train task
def train_task():print("--- train_task, pid: ", os.getpid())# device settingdevice = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")print("device:", device)device_ids = torch._utils._get_all_device_indices()print("device_ids:", device_ids)if len(device_ids) <= 0:print("invalid device_ids, exit")return# prepare datanum_inputs = 2num_examples = 1000true_w = [2, -3.5]true_b = 3.7features = torch.tensor(np.random.normal(0, 1, (num_examples, num_inputs)), dtype=torch.float)labels = true_w[0] * features[:, 0] + true_w[1] * features[:, 1] + true_b + torch.tensor(np.random.normal(0, 0.01, size=num_examples), dtype=torch.float)# load databatch_size = 10dataset = Data.TensorDataset(features, labels)data_iter = Data.DataLoader(dataset, batch_size, shuffle=True)for X, y in data_iter:print(X, y)breakdata_iter = ray.train.torch.prepare_data_loader(data_iter)# model define and initmodel = LinearNet(num_inputs)ddp_model = ray.train.torch.prepare_model(model)print(ddp_model)# cost functionloss = nn.MSELoss()# optimizeroptimizer = torch.optim.SGD(ddp_model.parameters(), lr=0.03)# trainnum_epochs = 6for epoch in range(1, num_epochs + 1):batch_count = 0sum_loss = 0.0for X, y in data_iter:output = ddp_model(X)l = loss(output, y.view(-1, 1))optimizer.zero_grad()l.backward()optimizer.step()batch_count += 1sum_loss += l.item()print('epoch %d, avg_loss: %f' % (epoch, sum_loss / batch_count))# save modelprint("save model, pid: ", os.getpid())torch.save(ddp_model.module.state_dict(), "ddp_ray_model.pt")def ray_launch_task():num_workers = 2scaling_config = ScalingConfig(num_workers=num_workers, use_gpu=True)trainer = TorchTrainer(train_loop_per_worker=train_task, scaling_config=scaling_config)results = trainer.fit()def predict_task():print("--- predict_task")# prepare datanum_inputs = 2num_examples = 20true_w = [2, -3.5]true_b = 3.7features = torch.tensor(np.random.normal(0, 1, (num_examples, num_inputs)), dtype=torch.float)labels = true_w[0] * features[:, 0] + true_w[1] * features[:, 1] + true_b + torch.tensor(np.random.normal(0, 0.01, size=num_examples), dtype=torch.float)model = LinearNet(num_inputs)model.load_state_dict(torch.load("ddp_ray_model.pt"))model.eval()x, y = features[6], labels[6]pred_y = model(x)print("x:", x)print("y:", y)print("pred_y:", y)if __name__ == "__main__":print("==== task begin ====")print("python version:", sys.version)print("torch version:", torch.__version__)print("model name:", LinearNet.__name__)ray_launch_task()# predict_task()print("==== task end ====")

accelerate

acc.py

#coding=utf-8
import os
import sys
import time
import numpy as np
from accelerate import Accelerator
import torch
from torch import nn
import torch.utils.data as Data
import onnxruntime# bellow code use AI model to simulate linear regression, formula is: y = x1 * w1 + x2 * w2 + b
# --- accelerate --- # # model structure
class LinearNet(nn.Module):def __init__(self, n_feature):super(LinearNet, self).__init__()self.linear = nn.Linear(n_feature, 1)def forward(self, x):y = self.linear(x)return y# whole train task
def train_task():print("--- train_task, pid: ", os.getpid())# device settingdevice = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")print("device:", device)device_ids = torch._utils._get_all_device_indices()print("device_ids:", device_ids)if len(device_ids) <= 0:print("invalid device_ids, exit")return# prepare datanum_inputs = 2num_examples = 1000true_w = [2, -3.5]true_b = 3.7features = torch.tensor(np.random.normal(0, 1, (num_examples, num_inputs)), dtype=torch.float)labels = true_w[0] * features[:, 0] + true_w[1] * features[:, 1] + true_b + torch.tensor(np.random.normal(0, 0.01, size=num_examples), dtype=torch.float)# load databatch_size = 10dataset = Data.TensorDataset(features, labels)data_iter = Data.DataLoader(dataset, batch_size, shuffle=True)for X, y in data_iter:print(X, y)break# model define and initmodel = LinearNet(num_inputs)# cost functionloss = nn.MSELoss()# optimizeroptimizer = torch.optim.SGD(model.parameters(), lr=0.03)accelerator = Accelerator()model, optimizer, data_iter = accelerator.prepare(model, optimizer, data_iter) # automatically move model and data to gpu as config# trainnum_epochs = 3for epoch in range(1, num_epochs + 1):batch_count = 0sum_loss = 0.0for X, y in data_iter:output = model(X)l = loss(output, y.view(-1, 1))optimizer.zero_grad()accelerator.backward(l)optimizer.step()batch_count += 1sum_loss += l.item()print('epoch %d, avg_loss: %f' % (epoch, sum_loss / batch_count))# save modeltorch.save(model, "acc_model.pt")def predict_task():print("--- predict_task")# prepare datanum_inputs = 2num_examples = 20true_w = [2, -3.5]true_b = 3.7features = torch.tensor(np.random.normal(0, 1, (num_examples, num_inputs)), dtype=torch.float)labels = true_w[0] * features[:, 0] + true_w[1] * features[:, 1] + true_b + torch.tensor(np.random.normal(0, 0.01, size=num_examples), dtype=torch.float)model = torch.load("acc_model.pt")model.eval()x, y = features[6], labels[6]pred_y = model(x)print("x:", x)print("y:", y)print("pred_y:", y)if __name__ == "__main__":# launch method: use command line# for example# accelerate launch ACC.py print("python version:", sys.version)print("torch version:", torch.__version__)print("model name:", LinearNet.__name__)train_task()predict_task()print("==== task end ====")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/50760.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

使用 Amazon Redshift Serverless 和 Toucan 构建数据故事应用程序

这是由 Toucan 的解决方案工程师 Django Bouchez与亚马逊云科技共同撰写的特约文章。 带有控制面板、报告和分析的商业智能&#xff08;BI&#xff0c;Business Intelligence&#xff09;仍是最受欢迎的数据和分析使用场景之一。它为业务分析师和经理提供企业的过去状态和当前状…

MQ消息队列(主要介绍RabbitMQ)

消息队列概念&#xff1a;是在消息的传输过程中保存消息的容器。 作用&#xff1a;异步处理、应用解耦、流量控制..... RabbitMQ&#xff1a; SpringBoot继承RabbitMQ步骤&#xff1a; 1.加入依赖 <dependency><groupId>org.springframework.boot</groupId&g…

隧道HTTP具备的条件

作为一名专业的爬虫代理供应商&#xff0c;我们都知道使用代理是保证爬虫的高效性和稳定性的重要手段之一。而隧道代理则是近年来备受推崇的一种代理形式&#xff0c;它通过将请求通过隧道传输&#xff0c;可以有效地隐藏爬虫的真实IP地址&#xff0c;提高爬虫的反爬能力。 在…

ConfigMap(可变应用配置管理)

实验环境 实验环境&#xff1a; 1、win10,vmwrokstation虚机&#xff1b; 2、k8s集群&#xff1a;3台centos7.6 1810虚机&#xff0c;1个master节点,2个node节点k8s version&#xff1a;v1.22.2containerd://1.5.5实验软件(无) 1 基础知识 1.1 什么是ConfigMap(可变配置管理…

matlab工具箱Filter Designer设计butterworth带通滤波器

1、在matlab控制界面输入fdatool; 2、在显示的界面中选择合适的参数&#xff1b;本实验中采样频率是200&#xff0c;低通30hz&#xff0c;高通60hz,点击butterworth滤波器。 3、点击设计滤波器按钮后&#xff0c;在生成的界面点击红框按钮&#xff0c;可生成simulink模型到当前…

算法通关村十一关 | 位运算的规则

1.数字在计算机中的表示 机器数&#xff1a;一个数在计算机中的二进制表示形式&#xff0c;叫做这个数的机器数。机器数是自带符号的&#xff0c;在计算机用一个数的最高位存放符号&#xff0c;整数为0&#xff0c;负数为1。比如&#xff0c;十进制中的数3&#xff0c;计算机字…

Linux下彻底卸载jenkins

文章目录 1、停服务进程2、查找安装目录3、删掉相关目录4、确认已完全删除 1、停服务进程 查看jenkins服务是否在运行&#xff0c;如果在运行&#xff0c;停掉 ps -ef|grep jenkins kill -9 XXX2、查找安装目录 find / -name "jenkins*"3、删掉相关目录 # 删掉相…

一文全懂!带你了解芯片“流片”!

一、流片是什么&#xff1f; 流片(tape-out)是指通过一系列工艺步骤在流水线上制造芯片&#xff0c;是集成电路设计的最后环节&#xff0c;也就是送交制造。 流片即为"试生产"&#xff0c;简单来说就是设计完电路以后&#xff0c;先生产几片几十片&#xff0c;供测试…

【SQL语句】SQL编写规范

简介 本文编写原因主要来于XC迁移过程中修改SQL语句时&#xff0c;发现大部分修改均源自于项目SQL编写不规范&#xff0c;以此文档做以总结。 注&#xff1a;此文档覆盖不甚全面&#xff0c;大体只围绕迁移遇到的修改而展开。 正文 1、【字段引号】 列名、表名如无特殊情况…

使用BeeWare实现iOS调用Python

1、准备工作 1.1、安装Python 1.2、设置虚拟环境 我们现在将创建一个虚拟环境——一个“沙盒”&#xff0c;如果我们将软件包安装到虚拟环境中&#xff0c;我们计算机上的任何其他Python项目将不会受到影响。如果我们把虚拟环境搞得一团糟&#xff0c;我们将能够简单地删除它…

C++入门:内联函数,auto,范围for循环,nullptr

目录 1.内联函数 1.1 概念 1.2 特性 1.3 内联函数与宏的区别 2.auto关键字(C11) 2.1 auto简介 2.2 auto的使用细则 2.3 auto不能推导的场景 3.基于范围的for循环(C11) 3.1 范围for的语法 3.2 范围for的使用方法 4.指针空值nullptr(C11) 4.1 C98中的指针空值 1.内联…

面试问题记录

1.多线程&#xff0c;线程池 1.如何创建线程 实现 Runnable 接口&#xff0c;重写run方法&#xff1b;实现 Callable 接口&#xff0c;重写call方法&#xff1b;继承 Thread 类&#xff0c;重写run方法。 2.基础线程机制 Executors&#xff1a;可以创建四种类型的线程池&am…

15.树与二叉树基础

目录 一. 树&#xff0c;基本术语 二. 二叉树 &#xff08;1&#xff09;二叉树 &#xff08;2&#xff09;满二叉树 &#xff08;3&#xff09;完全二叉树 三. 二叉树的性质 四. 二叉树的存储结构 &#xff08;1&#xff09;顺序存储结构 &#xff08;2&#xff09;链…

CSerialPort教程4.3.x (3) - CSerialPort在MFC中的使用

CSerialPort教程4.3.x (3) - CSerialPort在MFC中的使用 环境&#xff1a; 系统&#xff1a;windows 10 64位 编译器&#xff1a;Visual Studio 2008前言 CSerialPort项目是一个基于C/C的轻量级开源跨平台串口类库&#xff0c;可以轻松实现跨平台多操作系统的串口读写&#x…

C#__自定义类传输数据和前台线程和后台线程

// 前台线程和后台线程 // 默认情况下&#xff0c;用Thread类创建的线程是前台线程。线程池中的线程总是后台线程。 // 用Thread类创建线程的时候&#xff0c;可以设置IsBackground属性&#xff0c;表示一个后台线程。 // 前台线程在主函数运行结束后依旧执行&#xff0c;后台线…

golang的继承

golang中并没有继承以及oop&#xff0c;但是我们可以通过struct嵌套来完成这个操作。 定义struct 以下定义了一个Person结构体&#xff0c;这个结构体有Eat方法以及三个属性 type Person struct {Name stringAge uint16Phone string }func (recv *Person) Eat() {fmt.Prin…

01.Django入门

1.创建项目 1.1基于终端创建Django项目 打开终端进入文件路径&#xff08;打算将项目放在哪个目录&#xff0c;就进入哪个目录&#xff09; E:\learning\python\Django 执行命令创建项目 F:\Anaconda3\envs\pythonWeb\Scripts\django-admin.exe&#xff08;Django-admin.exe所…

微信支付-使用微信JSSDK完成微信支付

前言 使用微信JSSDK完成微信支付 一、安装weixin-js-sdk npm install weixin-js-sdk二、引入 var jweixin require(jweixin-module);三、使用 调用接口 一般调用成功会返回debug&#xff0c;appId&#xff0c;timestamp&#xff0c;nonceStr&#xff0c;signature等参数注…

RK3588平台开发系列讲解(AI 篇)RKNN-Toolkit2 API 介绍

文章目录 一、RKNN 初始化及对象释放二、RKNN 模型配置沉淀、分享、成长,让自己和他人都能有所收获!😄 📢本篇章主要讲解 RKNN-Toolkit2 API 详细说明。 一、RKNN 初始化及对象释放 在使用 RKNN Toolkit2 的所有 API 接口时,都需要先调用 RKNN()方法初始化 RKNN 对象,…

HAProxy 调度算法介绍

HAProxy 调度算法介绍 HAProxy 的调度算法比较多&#xff0c;在没有设置 mode 或者其它选项时&#xff0c;HAProxy 默认对 后端服务器使用 roundrobin 算法来分配请求处理。对后端服务器指明使用的算法 时使用balance关键字&#xff0c;该关键字可在listen和backend中出现。在…