Fabric实现多GPU运行

官方的将pytorch转换为fabric简单分为五个步骤:

步骤 1:

在训练代码的开头创建 Fabric 对象

from lightning.fabric import Fabricfabric = Fabric()

步骤 2:

如果打算使用多个设备(例如多 GPU),就调用 launch()

fabric.launch()

 步骤 3:

在每个模型和优化器对上调用 setup() ,在所有数据加载器上调用 setup_dataloaders()

model, optimizer = fabric.setup(model, optimizer)
dataloader = fabric.setup_dataloaders(dataloader)

 步骤 4:

删除所有 .to 和 .cuda 调用,因为 Fabric 将自动处理

- model.to(device)    # 删除
- batch.to(device)    # 删除

步骤 5:

将 loss.backward() 替换为 fabric.backward(loss) 

- loss.backward()
+ fabric.backward(loss)

结合起来:

将所有步骤结合起来,这就是代码将如何更改:

  import torchfrom lightning.pytorch.demos import WikiText2, Transformer
+ import lightning as L    # 新增- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")    # 删除
+ fabric = L.Fabric(accelerator="cuda", devices=8, strategy="ddp")    # 新增
+ fabric.launch()    # 新增dataset = WikiText2()dataloader = torch.utils.data.DataLoader(dataset)model = Transformer(vocab_size=dataset.vocab_size)optimizer = torch.optim.SGD(model.parameters(), lr=0.1)- model = model.to(device)    # 删除
+ model, optimizer = fabric.setup(model, optimizer)    # 新增
+ dataloader = fabric.setup_dataloaders(dataloader)    # 新增model.train()for epoch in range(20):for batch in dataloader:input, target = batch
-         input, target = input.to(device), target.to(device)    # 删除optimizer.zero_grad()output = model(input, target)loss = torch.nn.functional.nll_loss(output, target.view(-1))
-         loss.backward()    # 删除
+         fabric.backward(loss)    # 新增optimizer.step()

=======================================================================

记录一下自己代码的修改过程 

训练的是DECA的修改版 (En生bs和lm(Wgan*0.01+dinov2))

main_train.py中

导入lighting和Fabric并实例化,实例化的适合也可以加上【precision='32'】,float32位精度 

from lightning import Fabric
import lightning as Lfabric = Fabric(accelerator="cuda",devices=None, strategy="ddp",precision='32')
fabric.launch()# 这里的devices=None:这样就取决于命令行中CUDA_VISIBLE_DEVICES=的gpu名称
# precision='32'是使用32位精度

其他参数可用内容:

fabric = Fabric()fabric = Fabric(devices=2/4/8)fabric = Fabric(devices=1/2/4/8/"auto", strategy="ddp"/"fspd"/"deepspeed"/"auto")

deca.py中

1.导包+初始化fabric

这个import fabric就是上面实例化出来的fabric,我在trainer中又实例化了一下

2.去除原本的DP或者DDP,因为会冲突

注释了上面的DataParallel,使用fabric.setup对模型进行fabric操作

 trainer.py中

主要修改训练函数

1.我在这里又实例化了一遍

from lightning import Fabric
fabric = Fabric(accelerator="cuda",devices=None, strategy="ddp",precision='32')
fabric.launch()

2.Trainer类中初始化时进行了添加

3.主要修改:training_step

 验证的话 也是这么修改

4.数据dataloader处理

 5.fit.py,关于tensorboard报错

每个卡都有损失,tensorboard好像全局损失什么的,会产生冲突

然后最后loss和backward修改

 然后就可以启动了CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 nohup python -u main_train.py --cfg configs/pretrain.yml > train.log 2>&1

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/837566.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

高级个人主页

高级个人主页 效果图部分代码领取源码下期更新预报 效果图 部分代码 <!DOCTYPE html> <html lang"en"><head><meta charset"utf-8" name"viewport" content"widthdevice-width, initial-scale1, maximum-scale1, use…

ESP32重要库示例详解(四):获取NTP时间之time库

在物联网项目中&#xff0c;时间同步和管理是至关重要的功能之一&#xff0c;特别是在需要执行定时任务或记录事件时间戳的场景下。Arduino平台通过其内置的<time.h>库提供了强大的时间处理能力&#xff0c;使得开发者能够方便地与网络时间协议&#xff08;NTP&#xff0…

PDF文件转换为CAD的方法

有时候我们收到一个PDF格式的设计图纸&#xff0c;但还需要进行编辑或修改时&#xff0c;就必须先将PDF文件转换回CAD格式。分享两个将PDF转换回CAD的方法&#xff0c;一个用到在线网站&#xff0c;一个用到PC软件&#xff0c;大家根据情况选择就可以了。 ☞在线CAD网站转换 …

css超出部分省略(单行、多行,多种方法实现)

HTML <p class"text">这是一行测试数据,这是一行测试数据,这是二行测试数据,这是一行测试数据,这是三行测试数据,这是四行测试数据</p>1.单行 .text{width: 200px;border: 1px solid #000000;white-space: nowrap; /* 控制元素不换行 */overflow: hi…

Django图书馆综合项目-学习(2)

接下来我们来实现一下图书管理系统的一些相关功能 1.在书籍的book_index.html中有一个"查看所有书毂"的超链接按钮&#xff0c;点击进入书籍列表book_list.html页面. 这边我们使用之前创建的命名空间去创建超连接 这里的book 是在根路由创建的namespacelist是在bo…

6. RedHat认证-基于公钥的认证方式

6. RedHat认证-基于公钥的认证方式 主要学习客户端访问服务端的时候&#xff0c;免密登录这一方式 注意: 免密登录只是基于公钥认证的一个附带属性(基于公钥认证的方式更加安全&#xff0c;防止黑客暴力破解) 第一步&#xff1a;将客户端生成的秘钥传送到服务器 在客户端通过…

2024中国(厦门)国际医用消毒及感控设备展览会

2024中国&#xff08;厦门&#xff09;国际医用消毒及感控设备展览会 2024 China (Xiamen) International Medical Disinfection And Infection Control Exhibition 致力于打造医用消毒及感控设备产业采购一站式平台 时 间&#xff1a;2024年11月1-3日 November 1-3, 2024 …

一文扫盲(13):电商管理系统的功能模块和设计要点

电商管理系统是一种用于管理和运营电子商务平台的软件系统。它提供了一系列功能模块&#xff0c;帮助企业进行商品管理、订单管理、会员管理、营销推广、数据分析等工作。本文将从以下四个方面介绍电商管理系统。 一、什么是电商管理系统 电商管理系统是一种集成了各种功能模块…

免费的集成组件有哪些?

集成组件是指将多个软件或系统进行整合&#xff0c;以实现更高效、更可靠的数据处理和管理。在数据管理和分析领域&#xff0c;集成组件是不可或缺的工具之一。 在当今高度信息化的时代&#xff0c;集成组件在各行各业的应用中扮演着举足轻重的角色。集成组件能够将不同来源的…

企业安全必备利器:专业级加密软件介绍

随着信息技术的迅猛发展&#xff0c;数据安全问题日益凸显&#xff0c;专业级加密软件应运而生&#xff0c;成为保护数据安全的重要工具。本文将对专业级加密软件进行概述&#xff0c;分析其特点、应用场景及分享。 一、专业级加密软件概述 专业级加密软件是指那些采用高级加密…

三分钟了解计算机网络核心概念-数据链路层和物理层

计算机网络数据链路层和物理层 节点&#xff1a;一般指链路层协议中的设备。 链路&#xff1a;一般把沿着通信路径连接相邻节点的通信信道称为链路。 MAC 协议&#xff1a;媒体访问控制协议&#xff0c;它规定了帧在链路上传输的规则。 奇偶校验位&#xff1a;一种差错检测方…

uniapp怎么使用jsx

安装vitejs/plugin-vue-jsx npm install vitejs/plugin-vue-jsx -Dvite.config.js配置 import { defineConfig } from "vite"; import uni from "dcloudio/vite-plugin-uni"; import vueJsx from vitejs/plugin-vue-jsxexport default defineConfig({plu…

upload-labs靶场通关详解(1-15)

1.pass-01 查看源代码 是js&#xff0c;属于前端校验 可以通过禁用js来上传文件 2.pass-02 根据提示是MIME绕过 MIME&#xff1a;是设定某种扩展名的文件 用一种应用程序来打开的方式类型&#xff0c;当该扩展名文件被访问的时候&#xff0c;浏览器会自动使用指定应用程序来…

冯喜运:5.14黄金价格空头延续反弹空,原油走势分析实时操作

【黄金消息面分析】&#xff1a;周二&#xff08;5月14日&#xff09;亚洲时段&#xff0c;现货黄金窄幅震荡&#xff0c;目前交投于2342美元/盎司。金价周一因获利了结下跌1%&#xff0c;收报2336.10美元/盎司&#xff0c;投资者等待本周的关键通胀数据为今年美国降息提供更多…

使用Subtitle edit合成双语字幕

有的时候从网上下载的字幕有单独的中文版和英语版&#xff0c;但是没有中英文一起的双语字幕&#xff1a; 后缀为chs的是中文简体后缀为cht的是中文繁体后缀为eng的是英文 如果我们在电脑端上可以直接用potplayer添加副字幕来实现双语&#xff0c;但是如果是别的播放器&#…

多线程·线程状态

目录 1.等待一个线程 join 2.休眠当前线程 3.线程的所有状态 4.线程的状态转换 1.等待一个线程 join 有些场景&#xff0c;我们需要控制线程的执行顺序&#xff0c;这时候就需要用到 join 了 比如&#xff1a;把大象装进冰箱要几步&#xff1f; 第一步&#xff1a;打开冰…

【数据结构陈越版笔记】第1章 概论

我最近准备以陈姥姥的数据结构教材为蓝本重新学一下数据结构&#xff0c;写一下读书笔记 第1章 概论 1.1 引子 概论中首先描述了&#xff0c;数据结构的定义没有具体的定义&#xff0c;初学者可以不用管这个定义的问题&#xff0c;但是我理解的和维基百科的说法是一样的“数…

全面了解 Swagger 导出功能的使用方式

Swagger 是一个强大的平台&#xff0c;专门用于开发、构建和记录 RESTful Web 接口。通过其提供的交互式用户界面&#xff0c;开发人员能够轻松且迅速地创建和测试 API。Swagger 还允许用户以多种格式&#xff0c;包括 JSON 和 Markdown&#xff0c;导出 API 文档。选择 JSON 格…

人工神经网络(科普)

人工神经网络&#xff08;Artificial Neural Network&#xff0c;即ANN &#xff09;&#xff0c;是20世纪80 年代以来人工智能领域兴起的研究热点。它从信息处理角度对人脑神经元网络进行抽象&#xff0c; 建立某种简单模型&#xff0c;按不同的连接方式组成不同的网络。在工程…

MySQL中的索引失效问题

索引失效的情况 这是正常查询情况&#xff0c;满足最左前缀&#xff0c;先查有先度高的索引。 1. 注意这里最后一种情况&#xff0c;这里和上面只查询 name 小米科技 的命中情况一样。说明索引部分丢失&#xff01; 2. 这里第二条sql中的&#xff0c;status > 1 就是范围查…