使用亚马逊针对 PyTorch 和 MinIO 的 S3 连接器进行模型检查点处理

2023 年 11 月,Amazon 宣布推出适用于 PyTorch 的 S3 连接器。适用于 PyTorch 的 Amazon S3 连接器提供了专为 S3 对象存储构建的 PyTorch 数据集基元(数据集和数据加载器)的实现。它支持用于随机数据访问模式的地图样式数据集和用于流式处理顺序数据访问模式的可迭代样式数据集。适用于 PyTorch 的 S3 连接器还包括一个检查点接口,用于将检查点直接保存和加载到 S3 存储桶,而无需先保存到本地存储。如果您还没有准备好采用正式的 MLOps 工具,而只需要一种简单的方法来保存模型,那么这是一个非常好的选择。这就是我将在这篇文章中介绍的内容。S3 连接器的文档仅展示了如何将其与 Amazon S3 一起使用 - 我将在此处向您展示如何将其用于 MinIO。让我们先执行此作 - 让我们设置 S3 连接器,以便它从 MinIO 写入和读取检查点。

将 S3 连接器连接到 MinIO

将 S3 连接器连接到 MinIO 就像设置环境变量一样简单。之后,一切都会顺利进行。诀窍是以正确的方式设置正确的环境变量。

本文的代码下载使用 .env 文件来设置环境变量,如下所示。此文件还显示了我用于使用 MinIO Python SDK 直接连接到 MinIO 的环境变量。请注意,AWS_ENDPOINT_URL 需要 protocol,而 MinIO 变量不需要。

AWS_ACCESS_KEY_ID=admin
AWS_ENDPOINT_URL=http://172.31.128.1:9000
AWS_REGION=us-east-1
AWS_SECRET_ACCESS_KEY=password
MINIO_ENDPOINT=172.31.128.1:9000
MINIO_ACCESS_KEY=admin
MINIO_SECRET_KEY=password
MINIO_SECURE=false

写入和读取 Checkpoint

我从一个简单的例子开始。下面的代码段创建了一个 S3Checkpointing 对象,并使用其 writer() 方法将模型的状态字典发送到 MinIO。我还使用 Torchvision 创建了一个 ResNet-18(18 层)模型,用于演示目的。

import osfrom dotenv import load_dotenv
from s3torchconnector import S3Checkpoint
import torchvision
import torch# Load the credentials and connection information.
load_dotenv()model = torchvision.models.resnet18()
model_name = 'resnet18.pth'
bucket_name = 'checkpoints'checkpoint_uri = f's3://{bucket_name}/{model_name}'
s3_checkpoint = S3Checkpoint(os.environ['AWS_REGION'])# Save checkpoint to S3
with s3_checkpoint.writer(checkpoint_uri) as writer:torch.save(model.state_dict(), writer)

请注意,该区域有一个强制参数。从技术上讲,访问 MinIO 时没有必要,但如果为此变量选择错误的值,内部检查可能会失败。此外,您的存储桶必须存在,上述代码才能正常工作。如果 writer() 方法不存在,它将引发错误。不幸的是,无论出了什么问题,writer() 方法都会引发相同的错误。例如,如果您的存储桶不存在,您将收到如下所示的错误。如果 writer() 方法不喜欢您指定的区域,您也会收到相同的错误。希望未来的版本将提供更具描述性的错误消息。

S3Exception: Client error: Request canceled

将以前保存的模型读取到内存中的代码类似于写入 MinIO。使用 reader() 方法,而不是 writer() 方法。下面的代码显示了如何执行此作。

import osfrom dotenv import load_dotenv
from s3torchconnector import S3Checkpoint
import torchvision
import torch# Load the credentials and connection information.
load_dotenv()model_name = 'resnet18.pth'
bucket_name = 'checkpoints'checkpoint_uri = f's3://{bucket_name}/{model_name}'
s3_checkpoint = S3Checkpoint(os.environ['AWS_REGION'])# Load checkpoint from S3
with s3_checkpoint.reader(checkpoint_uri) as reader:state_dict = torch.load(reader, weights_only=True)model.load_state_dict(state_dict)

接下来,让我们看看模型训练期间检查点的一些实际注意事项。

在模型训练期间编写检查点

如果您使用大型数据集训练大型模型,请考虑在每个 epoch 后设置检查点。这些训练运行可能需要数小时甚至数天才能完成,因此在发生故障时能够从上次中断的地方继续非常重要。此外,我们假设您必须使用共享存储桶来保存来自多个团队的多个模型的模型检查点。MLOps 约定是按试验组织训练运行。例如,如果您正在研究具有四个隐藏层的架构,那么在寻找各种超参数的最佳值时,您将使用此架构进行多次运行。如果同事使用五层体系结构运行实验,则需要一种方法来防止名称冲突。这可以通过模拟如下所示的层次结构的对象路径来解决。

最后,为了确保您在每个 epoch 中获得新版本的模型,请确保在用于保存检查点的存储桶上启用版本控制。下面的训练函数使用上述路径结构在每个 epoch 后对模型进行检查点作。(可以在本文的代码下载中找到此训练函数的更强大版本。

def train_model(model: nn.Module, loader: DataLoader, training_parameters: Dict[str, Any]) -> List[float]:if training_parameters['checkpoint']:checkpoint_uri = f's3://{training_parameters["checkpoint_bucket"]} \/{training_parameters["project_name"]} \/{training_parameters["experiment_name"]} \/{training_parameters["run_id"]} \/{training_parameters["model_name"]}'s3_checkpoint = S3Checkpoint(region=os.environ['AWS_REGION'])loss_func = nn.NLLLoss()optimizer = optim.SGD(model.parameters(), lr=training_parameters['lr'], momentum=training_parameters['momentum'])# Epoch loopcompute_time_by_epoch = []for epoch in range(training_parameters['epochs']):# Batch loopfor images, labels in loader:# Flatten MNIST images into a 784 long vector.# shape = [32, 784]images = images.view(images.shape[0], -1)# Training passoptimizer.zero_grad()output = model(images)loss = loss_func(output, labels)loss.backward()optimizer.step()# Save checkpoint to S3if training_parameters['checkpoint']:with s3_checkpoint.writer(checkpoint_uri) as writer:torch.save(model.state_dict(), writer)

请注意,模型名称不包含指示纪元的子字符串。如前所述,我使用了启用了版本控制的存储桶 - 换句话说,版本号表示纪元。这种方法的优点在于,您无需知道引用最新模型的 epoch 数。在上述训练代码运行了 10 个 epoch 后,我的检查点存储桶如下面的屏幕截图所示。

此培训演示可被视为 DIY MLOps 解决方案的开始。

结论

适用于 PyTorch 的 S3 连接器易于使用,工程师在使用时编写的数据访问代码行数会更少。在本文中,我展示了如何将其配置为使用环境变量连接到 MinIO。配置完成后,工程师可以分别使用 writer() 和 reader() 方法将检查点写入和读取 MinIO。在本文中,我展示了如何配置 S3 Connect 以连接到 MinIO。我还演示了 S3Checkpoint 类及其 reader() 和 writer() 方法的基本用法。最后,我展示了一种在实际训练函数中针对启用了版本的检查点存储桶使用这些检查点功能的方法。在这篇文章中,我没有介绍在分布式训练期间检查点所需的技术和工具,这可能有点棘手。分布式训练期间的检查点设置会有所不同,具体取决于您使用的框架(PyTorch、Ray 或 DeepSpeed 等)和您正在进行的分布式训练类型:数据并行(每个工作程序都有模型的完整副本)或模型并行(每个工作程序只有一个模型分片)。在以后的文章中,我将介绍其中的一些技术。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/69705.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

尚硅谷爬虫note004

一、urllib库 1. python自带,无需安装 # _*_ coding : utf-8 _*_ # Time : 2025/2/11 09:39 # Author : 20250206-里奥 # File : demo14_urllib # Project : PythonProject10-14#导入urllib.request import urllib.request#使用urllib获取百度首页源码 #1.定义一…

Spring 项目接入 DeepSeek,分享两种超简单的方式!

⭐自荐一个非常不错的开源 Java 面试指南:JavaGuide (Github 收获148k Star)。这是我在大三开始准备秋招面试的时候创建的,目前已经持续维护 6 年多了,累计提交了 5600 commit ,共有 550 多位贡献者共同参与…

日常知识点之面试后反思裸写string类

1:实现一个字符串类。 简单汇总 最简单的方案,使用一个字符串指针,以及实际字符串长度即可。 参考stl的实现,为了提升string的性能,实际上单纯的字符串指针和实际长度是不够了,如上,有优化方案…

phpipam1.7安装部署

0软件说明 phpipam是一个开源Web IP地址管理应用程序(IPAM) phpipam官网:https://www.phpipam.net/ 1安装环境 操作系统:Rocky Linux9.5x86_64 phpipam版本:1.7 php版本:8.0.30 数据库版本&#xff1a…

python卷积神经网络人脸识别示例实现详解

目录 一、准备 1)使用pytorch 2)安装pytorch 3)准备训练和测试资源 二、卷积神经网络的基本结构 三、代码实现 1)导入库 2)数据预处理 3)加载数据 4)构建一个卷积神经网络 5&#xff0…

【文本处理】如何在批量WORD和txt文本提取手机号码,固话号码,提取邮箱,删除中文,删除英文,提取车牌号等等一些文本提取固定格式的操作,基于WPF的解决方案

企业的应用场景 数据清洗:在进行数据导入或分析之前,往往需要对大量文本数据进行预处理,比如去除文本中的无关字符(中文、英文),只保留需要的联系信息(手机号码、固话号码、邮箱)。…

【Cocos TypeScript 零基础 15.1】

目录 见缝插针UI脚本针脚本球脚本心得_旋转心得_更改父节点心得_缓动动画成品展示图 见缝插针 本人只是看了老师的大纲,中途不明白不会的时候再去看的视频 所以代码可能与老师代码有出入 SIKI_学院_点击跳转 UI脚本 import { _decorator, Camera, color, Component, directo…

pdf.js默认显示侧边栏和默认手形工具

文章目录 默认显示侧边栏(切换侧栏)默认手形工具(手型工具) 大部分的都是在viewer.mjs中的const defaultOptions 变量设置默认值,可以使用数字也可以使用他们对应的变量枚举值 默认显示侧边栏(切换侧栏) 在viewer.mjs中找到defaultOptions,大概在732行,或则搜索sidebarViewOn…

Pdf手册阅读(1)--数字签名篇

原文阅读摘要 PDF支持的数字签名, 不仅仅是公私钥签名,还可以是指纹、手写、虹膜等生物识别签名。PDF签名的计算方式,可以基于字节范围进行计算,也可以基于Pdf 对象(pdf object)进行计算。 PDF文件可能包…

Zabbix-监控SSL证书有效期

背景 项目需要,需要监控所有的SSL证书的有效期,因此需要自定义一个监控项 实现 创建自定义脚本 在Zabbix的scripts目录(/etc/zabbix/scripts/)下创建一个新的shell脚本check_ssl.sh,内容如下 #!/bin/bash time$(echo | openssl s_client…

java安全中的类加载

java安全中的类加载 提前声明: 本文所涉及的内容仅供参考与教育目的,旨在普及网络安全相关知识。其内容不代表任何机构、组织或个人的权威建议,亦不构成具体的操作指南或法律依据。作者及发布平台对因使用本文信息直接或间接引发的任何风险、损失或法律纠…

只需三步!5分钟本地部署deep seek——MAC环境

MAC本地部署deep seek 第一步:下载Ollama第二步:下载deepseek-r1模型第三步:安装谷歌浏览器插件 第一步:下载Ollama 打开此网址:https://ollama.com/,点击下载即可,如果网络比较慢可使用文末百度网盘链接 注:Ollama是…

神经网络常见激活函数 9-CELU函数

文章目录 CELU函数导函数函数和导函数图像优缺点pytorch中的CELU函数tensorflow 中的CELU函数 CELU 连续可微指数线性单元:CELU(Continuously Differentiable Exponential Linear Unit),是一种连续可导的激活函数,结合了 ELU 和 …

w~自动驾驶~合集17

我自己的原文哦~ https://blog.51cto.com/whaosoft/13269720 #FastOcc 推理更快、部署友好Occ算法来啦! 在自动驾驶系统当中,感知任务是整个自驾系统中至关重要的组成部分。感知任务的主要目标是使自动驾驶车辆能够理解和感知周围的环境元素&#…

Visual Studio 进行单元测试【入门】

摘要:在软件开发中,单元测试是一种重要的实践,通过验证代码的正确性,帮助开发者提高代码质量。本文将介绍如何在VisualStudio中进行单元测试,包括创建测试项目、编写测试代码、运行测试以及查看结果。 1. 什么是单元测…

解决珠玑妙算游戏问题:C 语言实现

一、引言 珠玑妙算游戏(the game of master mind)是一个有趣的逻辑推理游戏。在编程领域,我们可以通过编写代码来模拟游戏中计算猜中与伪猜中次数的过程。本文将详细介绍如何使用 C 语言实现这一功能,并对核心代码进行解析。 二、…

查询语句来提取 detail 字段中包含 xxx 的 URL 里的 commodity/ 后面的数字串

您可以使用以下 SQL 查询语句来提取 detail 字段中包含 oss.kxlist.com 的 URL 里的 commodity/ 后面的数字串&#xff1a; <p><img style"max-width:100%;" src"https://oss.kxlist.com//8a989a0c55e4a7900155e7fd7971000b/commodity/20170925/20170…

2024BaseCTF_week4_web上

继续&#xff01;冲冲冲 目录 圣钥之战1.0 nodejs 原型 原型链 原型链污染 回到题目 flag直接读取不就行了&#xff1f; 圣钥之战1.0 from flask import Flask,request import jsonapp Flask(__name__)def merge(src, dst):for k, v in src.items():if hasattr(dst, __geti…

摄像头动捕:摄像头+AI精准捕捉动作

在科技蓬勃发展的当下&#xff0c;动作捕捉技术已从最初的小众应用逐渐走进大众视野&#xff0c;广泛渗透到众多领域。其中&#xff0c;摄像头动捕&#xff0c;也就是无穿戴动作捕捉系统&#xff0c;以其独特的技术优势和创新应用&#xff0c;正悄然改变着人们对动作捕捉的认知…

机器学习 - 词袋模型(Bag of Words)实现文本情感分类的详细示例

为了简单直观的理解模型训练&#xff0c;我这里搜集了两个简单的实现文本情感分类的例子&#xff0c;第一个例子基于朴素贝叶斯分类器&#xff0c;第二个例子基于逻辑回归&#xff0c;通过这两个例子&#xff0c;掌握词袋模型&#xff08;Bag of Words&#xff09;实现文本情感…