分布式执行引擎ray入门--(3)Ray Train

Ray Train中包含4个部分

  1. Training function: 包含训练模型逻辑的函数

  2. Worker: 用来跑训练的

  3. Scaling configuration: 配置

  4. Trainer: 协调以上三个部分

Ray Train+PyTorch

这一块比较建议直接去官网看diff,官网色块标注的比较清晰,非常直观。

import os
import tempfileimport torch
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision.models import resnet18
from torchvision.datasets import FashionMNIST
from torchvision.transforms import ToTensor, Normalize, Composeimport ray.train.torchdef train_func(config):# Model, Loss, Optimizermodel = resnet18(num_classes=10)model.conv1 = torch.nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)# model.to("cuda")  # This is done by `prepare_model`# [1] Prepare model.model = ray.train.torch.prepare_model(model)criterion = CrossEntropyLoss()optimizer = Adam(model.parameters(), lr=0.001)# Datatransform = Compose([ToTensor(), Normalize((0.5,), (0.5,))])data_dir = os.path.join(tempfile.gettempdir(), "data")train_data = FashionMNIST(root=data_dir, train=True, download=True, transform=transform)train_loader = DataLoader(train_data, batch_size=128, shuffle=True)# [2] Prepare dataloader.train_loader = ray.train.torch.prepare_data_loader(train_loader)# Trainingfor epoch in range(10):for images, labels in train_loader:# This is done by `prepare_data_loader`!# images, labels = images.to("cuda"), labels.to("cuda")outputs = model(images)loss = criterion(outputs, labels)optimizer.zero_grad()loss.backward()optimizer.step()# [3] Report metrics and checkpoint.metrics = {"loss": loss.item(), "epoch": epoch}with tempfile.TemporaryDirectory() as temp_checkpoint_dir:torch.save(model.module.state_dict(),os.path.join(temp_checkpoint_dir, "model.pt"))ray.train.report(metrics,checkpoint=ray.train.Checkpoint.from_directory(temp_checkpoint_dir),)if ray.train.get_context().get_world_rank() == 0:print(metrics)# [4] Configure scaling and resource requirements.
scaling_config = ray.train.ScalingConfig(num_workers=2, use_gpu=True)# [5] Launch distributed training job.
trainer = ray.train.torch.TorchTrainer(train_func,scaling_config=scaling_config,# [5a] If running in a multi-node cluster, this is where you# should configure the run's persistent storage that is accessible# across all worker nodes.# run_config=ray.train.RunConfig(storage_path="s3://..."),
)
result = trainer.fit()# [6] Load the trained model.
with result.checkpoint.as_directory() as checkpoint_dir:model_state_dict = torch.load(os.path.join(checkpoint_dir, "model.pt"))model = resnet18(num_classes=10)model.conv1 = torch.nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)model.load_state_dict(model_state_dict)

模型 

  ray.train.torch.prepare_model() 

model = ray.train.torch.prepare_model(model)
相当于model.to(device_id or "cpu") +  DistributedDataParallel(model, device_ids=[device_id])

将model移动到合适的device上,同时实现分布式

数据

ray.train.torch.prepare_data_loader() 

报告 checkpoints 和 metrics

+import ray.train
+from ray.train import Checkpointdef train_func(config):...torch.save(model.state_dict(), f"{checkpoint_dir}/model.pth"))
+    metrics = {"loss": loss.item()} # Training/validation metrics.
+    checkpoint = Checkpoint.from_directory(checkpoint_dir) # Build a Ray Train checkpoint from a directory
+    ray.train.report(metrics=metrics, checkpoint=checkpoint)...
data_loader = ray.train.torch.prepare_data_loader(data_loader)

将batches移动到合适的device上,同时实现分布式sampler

配置 scale 和 GPUs

from ray.train import ScalingConfig
scaling_config = ScalingConfig(num_workers=2, use_gpu=True)

配置持久化存储

多节点分布式训练时必须指定,本地路径会有问题。

from ray.train import RunConfig# Local path (/some/local/path/unique_run_name)
run_config = RunConfig(storage_path="/some/local/path", name="unique_run_name")# Shared cloud storage URI (s3://bucket/unique_run_name)
run_config = RunConfig(storage_path="s3://bucket", name="unique_run_name")# Shared NFS path (/mnt/nfs/unique_run_name)
run_config = RunConfig(storage_path="/mnt/nfs", name="unique_run_name")

启动训练任务

from ray.train.torch import TorchTrainertrainer = TorchTrainer(train_func, scaling_config=scaling_config, run_config=run_config
)
result = trainer.fit()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/735473.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

MySQL-视图:视图概述、使用视图注意点、视图是否影响基本表

视图 一、视图概述二、使用视图注意点三、视图操作是否影响基本表 一、视图概述 在数据库管理系统中,视图(View)是一种虚拟表,它并不实际存储数据,而是基于一个或多个实际表的查询结果。视图提供了一种对数据库中数据…

RabbitMQ备份交换机

1. 备份交换机 备份交换机可以理解为 RabbitMQ 中交换机的“备胎”,当我们为某一个交换机声明一个对应的备份交换机时,就是为它创建一个备胎,当交换机接收到一条不可路由消息时,将会把这条消息转发到备份交换机中,由备…

reids设计与实现(一)——数据对象

文章目录 1. 前言2. redis 动态字符串2.1. 字符串的数据结构:2.2. 剖析,length;2.3. 剖析,free;2.3. 使用c字符串函数; 3. redis 链表4. 字典5. 跳跃表 1. 前言 reids作为最常用的缓存数据库,深…

【MATLAB】MATLAB学习笔记

MATLAB入门 基础操作变量命名数据类型逻辑和流程控制循环结构分支结构 绘图基本操作二维平面绘图绘图参数三位立体绘图图像窗口的分割 本文参考B站视频:BV13D4y1Q7RS 由于我对于C语言很熟悉,很多语法是会参考C来学 基础操作 清屏%% 清空环境变量及命令 …

图腾柱PFC工作原理:一张图

视屏链接: PFC工作原理

docker学习笔记——Dockerfile

Dockerfile是一个镜像描述文件,通过Dockerfile文件可以构建一个属于自己的镜像。 如何通过Dockerfile构建自己的镜像: 在指定位置创建一个Dockerfile文件,在文件中编写Dockerfile相关语法。 构建镜像,docker build -t aa:1.0 .(指…

【每日一题】2834. 找出美丽数组的最小和-2024.3.8

题目: 2834. 找出美丽数组的最小和 给你两个正整数:n 和 target 。 如果数组 nums 满足下述条件,则称其为 美丽数组 。 nums.length n.nums 由两两互不相同的正整数组成。在范围 [0, n-1] 内,不存在 两个 不同 下标 i 和 j &…

阿里云实现两个VPC网络资源互通

背景 由于实际项目预算有限,两套环境虽然分别属于不同的专有网络即不同的VPC,但是希望借助一台运维机器实现对两个环境的监控和日常的运维操作 网络架构 如下是需要实现的外网架构图,其中希望实现UAT环境的一台windows的堡垒机可以访问生产…

第G3周:CGAN入门|生成手势图像

🍨 本文为🔗365天深度学习训练营 中的学习记录博客🍖 原作者:K同学啊 | 接辅导、项目定制 一、前置知识 CGAN(条件生成对抗网络)的原理是在原始GAN的基础上,为生成器和判别器提供 额外的条件信息…

【linux】04 :linix实用操作

1.常用快捷键 ctrlc表示强制停止。linux某些程序的运行,如果想强制停止,可以使用;命令输入错误,也可以通过ctrlc,退出当前输入,重新输入。 ctrld表示退出登录,比如退出root以回到普通用户,或者…

Stable Diffusion 模型下载:ZavyChromaXL(现实、魔幻)

本文收录于《AI绘画从入门到精通》专栏,专栏总目录:点这里。 文章目录 模型介绍生成案例案例一案例二案例三案例四案例五案例六案例七案例八 下载地址 模型介绍 作者述:该模型系列应该是用于 SDXL 的 ZavyMix SD1.5 模型的延续。主要重点是获…

背包问题算法

背包问题算法 0-1背包问题二维数组一维数组 完全背包问题二维数组一维数组 多重背包问题一维数组 0-1背包问题 问题:背包的容量为9,有重量分别为[2, 4, 6, 9]的四个物品,价值分别为[3, 4, 5, 6],求背包能装的物品的最大价值是多少…

Orange3数据预处理(预处理器组件)

1.组件介绍 Orange3 提供了一系列的数据预处理工具,这些工具可以帮助用户在数据分析之前准备好数据。以下是您请求的预处理组件的详细解释: Discretize Continuous Variables(离散化连续变量): 这个组件将连续变量转…

个人网站展示(静态)

大学期间做了一个个人博客网站,纯H5编码的网站,利用php搭建了一个留言模块。 有需要源码的同学,可以联系我~ 首页: IT杂记模块 文人墨客模块 劳有所获模块 生活日志模块 关于我 一个推崇全栈开发的前端开发人员 微信: itrzzh …

elasticsearch篇

1.初识elasticsearch 1.1.了解ES 1.1.1.elasticsearch的作用 elasticsearch是一款非常强大的开源搜索引擎,具备非常多强大功能,可以帮助我们从海量数据中快速找到需要的内容 例如: 在电商网站搜索商品 在百度搜索答案 在打车软件搜索附近…

代码随想录算法训练营Day39 || leetCode 762.不同路径 || 63. 不同路径 II

62.不同路径 每一位的结果等于上方与左侧结果和 class Solution { public:int uniquePaths(int m, int n) {vector<vector<int>> dp(m,vector(n,0));for (int i 0; i < m; i) dp[i][0] 1;for (int j 0; j < n; j) dp[0][j] 1;for (int i 1; i < m; …

使用docker部署redis集群

编写脚本 批量创建目录文件&#xff0c;编写配置文件 [rootlocalhost ~]# cat redis.sh #/bin/bash for port in $(seq 1 6); do mkdir -p /mydata/redis/node-${port}/conf touch /mydata/redis/node-${port}/conf/redis.conf cat << EOF >>/mydata/redis/node-…

记录西门子:IO隔离SCL编程

在PLC变量中创建IO输入输出 在PLC类型中创建输入和输出&#xff0c;并将PLC变量的输入输出名称复制过来 创建一个FC块或者FB块 创建一个DB块 MAIN主程序中&#xff1a;

【UVM_phase objection_2024.03.08

phase 棕色&#xff1a;function phase 不消耗仿真时间 绿色&#xff1a;task phase 消耗仿真时间 run_phase与右边的phase并行执行&#xff0c;右边的phase&#xff08;run_time phase&#xff09;依次执行&#xff1a; List itemreset_phase对DUT进行复位&#xff0c;初始…

24 深度卷积神经网络 AlexNet【李沐动手学深度学习v2课程笔记】(备注:含AlexNet和LeNet对比)

目录 1. 深度学习机器学习的发展 1.1 核方法 1.2 几何学 1.3 特征工程 opencv 1.4 Hardware 2. AlexNet 3. 代码 1. 深度学习机器学习的发展 1.1 核方法 2001 Learning with Kernels 核方法 &#xff08;机器学习&#xff09; 特征提取、选择核函数来计算相似性、凸优…