Pytorch Lighting Hydra库的学习

MVsplat 使用了Hydra 库来进行参数的配置 :

在文件运行的最开始的地方, 使用装饰器 使用 Hydra 这个库,一般都是对于 Main 函数进行修饰的,需要读取代码中的 yaml 文件:

@hydra.main(version_base=None,config_path="../config",   ## config 文件的路径config_name="main",      ## 读取 main.yaml 文件
)

yaml 文件和 defaults 关键词搭配起来,可以去调用 其他的 yaml 配置文件。
Main.yaml 文件的内容如下:

defaults:- dataset: re10k   ## 表示 dataset 的配置文件在re10k.yaml 去读取- optional dataset/view_sampler_dataset_specific_config: ${dataset/view_sampler}_${dataset}- model/encoder: costvolume- model/decoder: splatting_cuda- loss: [mse]
Hydra 库 一般和 yaml 文件组合起来设置超参数

参考网址:https://zhuanlan.zhihu.com/p/662221581

启动命令:python -m src.main +experiment=re10k
这行命令会导致 程序的 config 最终会加入在 experiment 目录下的 读取 re10k.yaml 文件,作为配置文件

Pytorch Lighting 的学习

Youtube 小哥的教学视频: https://www.youtube.com/watch?v=XbIN9LaQycQ&list=PLhhyoLH6IjfyL740PTuXef4TstxAK6nGP&index=1

Lighting 的源代码库,查看API:

https://github.com/Lightning-AI/pytorch-lightning/blob/5aadfa62508ee20735083900273c8e3ff5867602/src/lightning/pytorch/core/module.py#L2

Overview:

1. 在继承 Lighting 的一个类里面,需要实现以下的函数:

训练的主函数:,最后只需要返回 loss 即可,之后的 Backward 操作 Lighting 会自己完成,并不需要用户编写。

def training_step(self, batch, batch_idx):retrun Loss

这里返回 一个 Loss 或者 预测的 dictionary , 像loss.backward() 等工程性质的代码,在Lighting 已经被自动计算好了。

在 下面使用Test 和 Validate 的时候 会自动不计算和保留程序的梯度。

model.eval() and torch.no_grad() are called automatically for validation.

测试的主函数:

def testing_step(self, batch, batch_idx):

配置优化器:

def configure_optimizers(self):

2. 配置训练器Trainer :

   trainer = Trainer(max_epochs=-1,  ## 设置为 -1 表示可以无限训练accelerator="gpu",logger=logger,devices="auto",strategy="ddp" if torch.cuda.device_count() > 1 else "auto",callbacks=callbacks,check_val_every_n_epoch=None, ## 我们是暗战 step 来计算,而不是 epoch  val_check_interval=500, ## 500个step 运行一次 validationenable_progress_bar=cfg.mode == "test",gradient_clip_val=cfg.trainer.gradient_clip_val, ## 梯度裁剪。 防止出现梯度消失或者爆炸。max_steps=cfg.trainer.max_steps,  ## 指定了 最大的 stepsnum_sanity_val_steps=cfg.trainer.num_sanity_val_steps, ## 训练前先进行 validate, 保证代码没有出错)
在 Pytorch 中使用 Tensorboard Logger:
  • 先在主函数里面定义 TensorboardLogger, 并且添加到 Trainer 当中:
logger = TensorBoardLogger(save_dir=cfg_dict.output,version=cfg.descriptor)
trainer = Trainer(max_epochs=-1,accelerator="gpu",logger=logger,  ## 使用 Tensorboard 的 Loggerdevices="auto")
  • 先在 training_step 当中 使用我们定义的 Logger:
self.logger.experiment.add_image()
self.log('PSNR', psnr, prog_bar=True, on_step=True, on_epoch=False)

3. Metrics :

Video 里面说可以在 **回调函数 training_step ** 去计算某一些指标.

def training_step(self, batch, batch_idx):

4. DataModule

Lighting 的 Dataset 和 Pytorch 的 Dataset 的定义方式是很相近的。 都是需要先 自己定义一个 Dataset, 然后根据自己定义的 Dataset 去实现 对应的 Dataloader
在 DataModule 里面需要实现3个 DataLoader

class DataModule(LightningDataModule):def prepare_data(self):  ## 最开始运行的 函数,一般也可以用于读取数据self.dataset = passdef train_dataloader(self):return DataLoader(self.datset)def val_dataloader(self):return DataLoader(self.datset)def test_dataloader(self, dataset_cfg=None):	

prepare_data: 会首先调用这个函数去 准备 数据集,比如说生成 **Dataset. ** MVSNeRF 的代码就是在 默认的 prepare_data 里面去 生成了 数据集 self.train_datatset

   def prepare_data(self):dataset = dataset_dict[self.args.dataset_name]train_dir, val_dir = self.args.datadir , self.args.datadirself.train_dataset = dataset(root_dir=train_dir, split='train', max_len=-1 , downSample=args.imgScale_train)self.val_dataset   = dataset(root_dir=val_dir, split='val', max_len=10 , downSample=args.imgScale_test)#

但是所有的 关于 Dataset 的 参数设定,最后都需要 体现在 DataLoader 的参数当中,或者 Datalodaer 的参数之前。

4. Device

Pytorch Lighting 会自动分布device, 因此代码里不需要显式调用 .cuda() 或者 device.

Remove any .cuda() or .to(device) Calls

装饰器 rank_zero_only

这个 命令表示,这个函数只会在 GPU:0 上进行运行,而不会在多GPU 训练的时候进入到其他的 GPU。

@rank_zero_onlydef validation_step(self, batch, batch_idx):batch: BatchedExample = self.data_shim(batch)if self.global_rank == 0:print(f"validation step {self.global_step};")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/27259.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Android面试八股文】你了解 pthread 吗?创建一个新线程会占用多少内存?

文章目录 一、你了解pthread吗?二、创建一个新线程会占用多少内存?三、在实际开发中,你会如何优化线程的内存使用呢?一、你了解pthread吗? 面试官: 你了解 pthread 吗? 求职者: 是的,我了解 pthread。pthread 是 POSIX threads 的缩写,是一种在 UNIX 和类 UNIX 系统…

SpringBoot中获取当前请求的request和response

在Spring Boot中,你可以以多种方式获取当前请求的HttpServletRequest和HttpServletResponse对象。以下是几种常见的写法示例: 1. 在方法参数中声明 最常见和推荐的方式是在控制器方法的参数中直接声明HttpServletRequest和HttpServletResponse对象。Sp…

java打印99乘法表

public class NineNineMulTable{public static void main(String[] args){for(int i 1; i < 9; i ){for(int j 1; j < i; j ){System.out.print(j " * " i " " i * j "\t");//再次先输出j在输出i是打印出来是1*2&#xff0c;2*2}S…

宝藏速成秘籍(7)堆排序法

一、前言 1.1、概念 堆排序&#xff08;Heapsort&#xff09;是指利用堆这种数据结构所设计的一种排序算法 。堆是一个近似 完全二叉树 的结构&#xff0c;并同时满足堆积的性质&#xff1a;即子结点的键值或索引总是小于&#xff08;或者大于&#xff09;它的父节点。 1.2、排…

模板方法模式(大话设计模式)C/C++版本

模板方法模式 C #include <iostream> using namespace std;class TestPaper { public:void TestQ1(){cout << "杨过得到&#xff0c;后来给了郭靖&#xff0c;炼成倚天剑&#xff0c;屠龙刀的玄铁可能是[ ]\na.球磨铸铁 b.马口贴 c.高速合金钢 d.碳素纤维&q…

Linux——ansible剧本

剧本&#xff08;playbook&#xff09; 现在&#xff0c;可以写各种临时命令 但如果&#xff0c;想把所有步骤&#xff0c;集合到一起&#xff0c;写到同一个文件里 让ansible自动按顺序执行 就必须要写“剧本” 剧本里面&#xff0c;也可以写临时命令&#xff0c;但是剧本…

C++中bool类型的使用细节

C中bool类型的使用细节 ANSIISO C标准添加了一种名叫bool的新类型(对 C来说是新的)。它的名称来源于英国数学家 George Boole&#xff0c;是他开发了逻辑律的数学表示法。在计算中&#xff0c;布尔变量的值可以是true或false。过去&#xff0c;C和C一样&#xff0c;也没有布尔…

Kafka 负载均衡挑战及解决思路

本文转载自 Agoda Engineering&#xff0c;介绍了在实际应用中&#xff0c;如何应对 Kafka 负载均衡所遇到的各种挑战&#xff0c;并提出相应的解决思路。本文简要阐述了 Kafka 的并行性机制、常用的分区策略以及在实际操作中遇到的异构硬件、不均匀工作负载等问题。通过深入分…

重生之 SpringBoot3 入门保姆级学习(19、场景整合 CentOS7 Docker 的安装)

重生之 SpringBoot3 入门保姆级学习&#xff08;19、场景整合 CentOS7 Docker 的安装&#xff09; 6、场景整合6.1 Docker 6、场景整合 6.1 Docker 官网 https://docs.docker.com/查看自己的 CentOS配置 cat /etc/os-releaseStep 1: 安装必要的一些系统工具 sudo yum insta…

继承-进阶-易错点

子类同名方法隐藏父类方法 即使调用不匹配也不会再去父类寻找&#xff0c;而是直接报错 //下面代码输出结果&#xff1a;( )&#xfeff;class A { public:void f(){ cout<<"A::f()"<<endl; }int a; };class B : public A { public:void f(int a){c…

【Android面试八股文】Android开发中怎样判断当前线程是否是主线程?

文章目录 1. 使用 `Looper.getMainLooper()`2. 使用 `Handler`3. 使用 `Activity` 或 `View` 的方法4. 使用 Thread 类的 isMainThread 方法示例代码在Android开发中,判断当前线程是否是主线程(也称为UI线程)非常重要,因为只有主线程才能更新UI。 以下是几种常用的方法来判…

Qt6的获取调色板颜色和Qt5不一样了

Qt5中是[static] QRgb QColorDialog::getRgba(QRgb initial 0xffffffff, bool *ok nullptr, QWidget *parent nullptr) 而Qt6更加直接了&#xff0c;[static] QColor QColorDialog::getColor(const QColor &initial Qt::white, QWidget *parent nullptr, const QStri…

Excel使用技巧(一)

一. 快速调整数据位置 已经录入数据的表格&#xff0c;要调整某一列的位置怎么办&#xff1f; 只要选中要调整的数据区域&#xff0c;然后按住Shift键不放&#xff0c;光标放到绿色边框位置后&#xff0c;按下鼠标左键不放拖动即可&#xff1a; 二. 取消合并单元格并恢复数据…

电商项目-day03

文章目录 退出登录流程首先判断前端后端逻辑 登录验证的思路和ThreadLocal讲解 退出登录流程 首先判断前端 首先定义退出请求 // 退出登录 export const Logout () > {return request({url: ${api_name}/logout,method: get,})}const api_name “admin/system/index”; …

「C系列」C 字符串及操作字符串的函数

文章目录 一、C 字符串1. 声明和初始化字符串2. 访问字符串中的字符3. 字符串的长度4. 字符串的复制和连接5. 字符串的比较6. 字符串的查找 二、C 操作字符串的函数三、相关链接 一、C 字符串 在C语言中&#xff0c;字符串是由字符&#xff08;包括字母、数字、标点符号等&…

深入理解 JVM 的几种常见垃圾回收算法

在线工具站 推荐一个程序员在线工具站&#xff1a;程序员常用工具&#xff08;http://cxytools.com&#xff09;&#xff0c;有时间戳、JSON格式化、文本对比、HASH生成、UUID生成等常用工具&#xff0c;效率加倍嘎嘎好用。 程序员资料站 推荐一个程序员编程资料站&#xff1a;…

Django DeleteView视图

Django 的 DeleteView 是一个基于类的视图&#xff0c;用于处理对象的删除操作。 1&#xff0c;添加视图函数 Test/app3/views.py from django.shortcuts import render# Create your views here. from .models import Bookfrom django.views.generic import ListView class B…

信息科学与工程学院第五届大学生程序设计竞赛——热身赛

A:X星人的地盘 题目描述 一天&#xff0c;X星人和Y星人在一张矩形地图上玩抢地盘的游戏。 X星人每抢到一块地&#xff0c;在地图对应的位置标记一个“X”&#xff1b;Y星人每抢到一块地&#xff0c;在地图对应的位置标记一个“Y”&#xff1b;如果某一块地无法确定其归属则标记…

2024050901-重学 Java 设计模式《实战访问者模式》

重学 Java 设计模式&#xff1a;实战访问者模式「模拟家长与校长&#xff0c;对学生和老师的不同视角信息的访问场景」 一、前言 能力&#xff0c;是你前行的最大保障 年龄会不断的增长&#xff0c;但是什么才能让你不慌张。一定是能力&#xff0c;即使是在一个看似还很安稳…

Web后端开发的学习

REST风格 GET:查询用户POST:新增用户POT:修改用户DELETE:删除用户 前后端交互统一的响应结果 记录日志 SLf4j 注解&#xff1a; PathVariable&#xff1a;获取路径的参数ResponseBody :方法的返回值直接作为 HTTP 响应的正文返回,将响应的实体类转为json发送给前端Request…