basic_sr介绍

文章目录

  • pytorch基础知识和basicSR中用到的语法
    • 1.Sampler类与4种采样方式
    • 2.python dict的get方法使用
    • 3.prefetch_dataloader.py
    • 4. pytorch 并行和分布式训练
      • 4.1 选择要使用的cuda
      • 4.2 DataParallel使用方法
        • 常规使用方法
        • 保存和载入
      • 4.3 DistributedDataParallel
    • 5.wangdb 入门
      • 5.1 sign up(https://wandb.ai/site)
      • 5.2 安装和login
      • 5.3 demo
    • 5.model and train
      • 5.1 create model
      • 5.2 opt中设置
      • 5.2 SRModel 类

pytorch基础知识和basicSR中用到的语法

1.Sampler类与4种采样方式

一文弄懂Pytorch的DataLoader, DataSet, Sampler之间的关系
pytorch源码阅读(三)Sampler类与4种采样方式

下面代码是自定义的一个采样器:
ratio控制扩充数据集的倍数
num_replicas是进程数,一般是world_size
rank: 当前进程的rank

其实目的就是把数据集的索引划分为num_replicas组,供每个进程(process) 处理
至于ratio,是为了使每个epoch训练的数据增多,for saving time when restart the dataloader after each epoch

import math
import torch
from torch.utils.data.sampler import Samplerclass EnlargedSampler(Sampler):"""Sampler that restricts data loading to a subset of the dataset.Modified from torch.utils.data.distributed.DistributedSamplerSupport enlarging the dataset for iteration-based training, for savingtime when restart the dataloader after each epochArgs:dataset (torch.utils.data.Dataset): Dataset used for sampling.num_replicas (int | None): Number of processes participating inthe training. It is usually the world_size.rank (int | None): Rank of the current process within num_replicas.ratio (int): Enlarging ratio. Default: 1."""def __init__(self, dataset, num_replicas, rank, ratio=1):self.dataset = datasetself.num_replicas = num_replicasself.rank = rankself.epoch = 0self.num_samples = math.ceil(len(self.dataset) * ratio / self.num_replicas)self.total_size = self.num_samples * self.num_replicasdef __iter__(self):# deterministically shuffle based on epochg = torch.Generator()g.manual_seed(self.epoch)indices = torch.randperm(self.total_size, generator=g).tolist()dataset_size = len(self.dataset)indices = [v % dataset_size for v in indices]# subsampleindices = indices[self.rank:self.total_size:self.num_replicas]assert len(indices) == self.num_samplesreturn iter(indices)def __len__(self):return self.num_samplesdef set_epoch(self, epoch):self.epoch = epoch

测试一下:

import numpy as np
if __name__ == "__main__":data = np.arange(20).tolist()en_sample = EnlargedSampler(data, 2, 0)en_sample.set_epoch(1)for i in en_sample:print(i)print('\n------------------\n')en_sample = EnlargedSampler(data, 2, 1)en_sample.set_epoch(1) # 设置为同一个epoch .  rank=0或者1时生成的index是互补的# 或者不用设置,默认为0即可。for i in en_sample:print(i)

结果:
在这里插入图片描述

2.python dict的get方法使用

在这里插入图片描述

3.prefetch_dataloader.py

在这里插入图片描述

载入本批数据的时候,预先载入下一批数据。主要看next函数

import queue as Queue
import threading
import torch
from torch.utils.data import DataLoaderclass PrefetchGenerator(threading.Thread):"""A general prefetch generator.Reference: https://stackoverflow.com/questions/7323664/python-generator-pre-fetchArgs:generator: Python generator.num_prefetch_queue (int): Number of prefetch queue."""def __init__(self, generator, num_prefetch_queue):threading.Thread.__init__(self)self.queue = Queue.Queue(num_prefetch_queue)self.generator = generatorself.daemon = Trueself.start()def run(self):for item in self.generator:self.queue.put(item)self.queue.put(None)def __next__(self):next_item = self.queue.get()if next_item is None:raise StopIterationreturn next_itemdef __iter__(self):return selfclass PrefetchDataLoader(DataLoader):"""Prefetch version of dataloader.Reference: https://github.com/IgorSusmelj/pytorch-styleguide/issues/5#TODO:Need to test on single gpu and ddp (multi-gpu). There is a known issue inddp.Args:num_prefetch_queue (int): Number of prefetch queue.kwargs (dict): Other arguments for dataloader."""def __init__(self, num_prefetch_queue, **kwargs):self.num_prefetch_queue = num_prefetch_queuesuper(PrefetchDataLoader, self).__init__(**kwargs)def __iter__(self):return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue)class CPUPrefetcher():"""CPU prefetcher.Args:loader: Dataloader."""def __init__(self, loader):self.ori_loader = loaderself.loader = iter(loader)def next(self):try:return next(self.loader)except StopIteration:return Nonedef reset(self):self.loader = iter(self.ori_loader)class CUDAPrefetcher():"""CUDA prefetcher.Reference: https://github.com/NVIDIA/apex/issues/304#It may consume more GPU memory.Args:loader: Dataloader.opt (dict): Options."""def __init__(self, loader, opt):self.ori_loader = loaderself.loader = iter(loader)self.opt = optself.stream = torch.cuda.Stream()self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')self.preload()def preload(self):try:self.batch = next(self.loader)  # self.batch is a dictexcept StopIteration:self.batch = Nonereturn None# put tensors to gpuwith torch.cuda.stream(self.stream):for k, v in self.batch.items():if torch.is_tensor(v):self.batch[k] = self.batch[k].to(device=self.device, non_blocking=True)def next(self):torch.cuda.current_stream().wait_stream(self.stream) # 等待下一批处理完毕batch = self.batch # 赋值self.preload()     # 预先载入下一批return batchdef reset(self):self.loader = iter(self.ori_loader)self.preload()

4. pytorch 并行和分布式训练

4.1 选择要使用的cuda

当我们的服务器上有多个GPU,我们应该指明我们使用的GPU是哪一块,如果我们不设置的话,tensor.cuda()方法会默认将tensor保存到第一块GPU上,等价于tensor.cuda(0),这将会导致爆出out of memory的错误。我们可以通过以下两种方式继续设置。

  1. 在文件最开始部分
    #设置在文件最开始部分
    import os
    os.environ["CUDA_VISIBLE_DEVICE"] = "0,1,2" # 设置默认的显卡
    
  2. 在命令行运行的时候设置
     CUDA_VISBLE_DEVICE=0,1 python train.py # 使用0,1两块GPU
    

4.2 DataParallel使用方法

常规使用方法
   model = UNetSeeInDark()model._initialize_weights()gpus = [0123]model = nn.DataParallel(model, device_ids=gpus)device = torch.device('cuda:0')model = model.to(device)# 如果不使用并行,只需要注释掉 model = nn.DataParallel(model, device_ids=gpus)# 如果要更改要使用的gpu, 更改gpus,和device中的torch.device('cuda:0')中的number即可
保存和载入

保存可以使用

# 因为model被DP wrap了,得先取出模型
save_model_path = os.path.join(save_model_dir, f'checkpoint_{epoch:05d}.pth')
# torch.save(model.state_dict(), save_model_path)
torch.save(model.module.state_dict(), save_model_path)

然后载入模型:

model_copy.load_state_dict(torch.load(m_path, map_location=device))

如果没有提出model.module进行保存
在载入的时候可能需要如下方式:

checkpoint = torch.load(m_path)
model_copy.load_state_dict({k.replace('module.', ''): v for k, v in checkpoint.items()})

4.3 DistributedDataParallel

首先DataParallel是单进程多线程的方法,并且仅能工作在单机多卡的情况。而DistributedDataParallel方法是多进程,多线程的,并且适用与单机多卡和多机多卡的情况。即使在在单机多卡的情况下DistributedDataParallell也比DataParallel的速度更快。
目前还未深入理解:
深入理解Pytorch中的分布式训练
pytorch分布式训练
Pytorch中多GPU并行计算教程
PyTorch 并行训练极简 Demo

5.wangdb 入门

直接参看:https://docs.wandb.ai/quickstart
最详细的介绍和入门

5.1 sign up(https://wandb.ai/site)

在这里插入图片描述

5.2 安装和login

pip install wandb
wandb.login() 然后复制API key

5.3 demo

import wandb
import random# start a new wandb run to track this script
wandb.init(# set the wandb project where this run will be loggedproject="my-awesome-project",# track hyperparameters and run metadataconfig={"learning_rate": 0.02,"architecture": "CNN","dataset": "CIFAR-100","epochs": 10,}
)# simulate training
epochs = 10
offset = random.random() / 5
for epoch in range(2, epochs):acc = 1 - 2 ** -epoch - random.random() / epoch - offsetloss = 2 ** -epoch + random.random() / epoch + offset# log metrics to wandbwandb.log({"acc": acc, "loss": loss})# [optional] finish the wandb run, necessary in notebooks5b1bb8a27da51a7375b4b52c24a82fe1807877f1
wandb.finish()

运行之后:

wandb: Currently logged in as: wangty537. Use `wandb login --relogin` to force relogin
wandb: Tracking run with wandb version 0.15.10
wandb: Run data is saved locally in D:\code\denoise\noise-synthesis-main\wandb\run-20230921_103737-j9ezjcqo
wandb: Run `wandb offline` to turn off syncing.
wandb: Syncing run wobbly-jazz-1
wandb:  View project at https://wandb.ai/wangty537/my-awesome-project
wandb:  View run at https://wandb.ai/wangty537/my-awesome-project/runs/j9ezjcqo
wandb: Waiting for W&B process to finish... (success).
wandb: 
wandb: Run history:
wandb:  acc ▁▆▇██▇▇█
wandb: loss █▄█▁▅▁▄▁
wandb: 
wandb: Run summary:
wandb:  acc 0.88762
wandb: loss 0.12236
wandb: 
wandb:  View run wobbly-jazz-1 at: https://wandb.ai/wangty537/my-awesome-project/runs/j9ezjcqo
wandb: Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)
wandb: Find logs at: .\wandb\run-20230921_103737-j9ezjcqo\logs

然后可以在 https://wandb.ai/home 查看相关信息
在这里插入图片描述

https://docs.wandb.ai/quickstart 还介绍了更多高阶应用。

5.model and train

5.1 create model

利用注册机制

# create model
model = build_model(opt)
def build_model(opt):"""Build model from options.Args:opt (dict): Configuration. It must contain:model_type (str): Model type."""opt = deepcopy(opt)model = MODEL_REGISTRY.get(opt['model_type'])(opt)logger = get_root_logger()logger.info(f'Model [{model.__class__.__name__}] is created.')return model

5.2 opt中设置

model_type: SRModel
scale: 2

5.2 SRModel 类

BaseModel是基类

@MODEL_REGISTRY.register()
class SRModel(BaseModel):xxx

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/109704.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

详解js数组操作——filter()方法

引言 在JavaScript中,我们经常需要对数组进行筛选,以便根据特定的条件获取所需的元素。而JavaScript的filter()方法就是一个非常有用的工具,它可以帮助我们轻松地筛选数组中的元素。本文将介绍如何使用filter()方法,以及一些实用…

react+antd+Table实现表格初始化勾选某条数据,分页切换保留上一页勾选的数据

加上rowKey这个属性 <Table rowKey{record > record.id} // 加上rowKey这个属性rowSelection{rowSelection}columns{columns}dataSource{tableList}pagination{paginationProps} />

众佰诚:抖音小店的体验分什么时候更新

随着移动互联网的发展&#xff0c;越来越多的电商平台开始涌现&#xff0c;其中抖音小店作为一种新型的电商模式&#xff0c;受到了许多用户的欢迎。然而&#xff0c;对于抖音小店的体验分更新时间&#xff0c;很多用户并不是很清楚。本文将对此进行详细的解答。 首先&#xff…

SimpleCG图像操作基础

上一篇我们介绍了程序的交互功能&#xff0c;就可以编写一些简单的游戏了&#xff0c;例如贪吃蛇、扫雷、俄罗斯方块、五子棋等&#xff0c;都可以使用图形函数直接绘制&#xff0c;在后续文章中将逐一展示。不过编写画面丰富游戏离不开图像&#xff0c;所以本篇我们介绍一下基…

智能合同和TikTok:揭示加密技术的前景

在当今数字化时代&#xff0c;智能合同和加密技术都成为了技术和商业世界中的热门话题。它们代表了一个崭新的未来&#xff0c;有着潜在的巨大影响。 然而&#xff0c;你或许从未想过将这两者联系在一起&#xff0c;直到今天。本文将探讨智能合同和TikTok之间的联系&#xff0…

代码随想录算法训练营Day56|动态规划14

代码随想录算法训练营Day56|动态规划14 文章目录 代码随想录算法训练营Day56|动态规划14一、1143.最长公共子序列二、 1035.不相交的线三、53. 最大子序和 动态规划 一、1143.最长公共子序列 class Solution {public int longestCommonSubsequence(String text1, String text2…

sql聚合函数嵌套问题 aggregate function cannot contain aggregate parameters

在需求的应用场景&#xff0c;需要对create_time字段求最小值并求和&#xff0c;刚开始理所当然写成像下面这样&#xff1a; SUM(COALESCE (CASE WHEN MIN(crl.create_time) BETWEEN date_add(date_sub(current_date(), 1), -1 * (open_case_day_num % 6)) AND current_date()…

辉视IP对讲与SIP视频对讲:革新的通信技术与应用领域的开启

辉视IP对讲与辉视SIP视频对讲系统&#xff0c;不仅在技术上实现了一次革新&#xff0c;更在应用领域上开启了新的篇章。它们不仅仅是一种通信工具&#xff0c;更是一种集成了先进技术和多种功能的高效解决方案&#xff0c;为各领域提供了一种安全、便捷、高效的通信体验。 辉视…

5. 函数式接口

5.1 概述 只有一个抽象方法的接口我们称之为函数接口。 JDK的函数式接口都加上了 FunctionalInterface 注解进行标识。但是无论是否加上该注解只要接口中只有一个抽象方法&#xff0c;都是函数式接口。 在Java中&#xff0c;抽象方法是一种没有方法体&#xff08;实现代码&a…

【AOP系列】6.缓存处理

在Java中&#xff0c;我们可以使用Spring AOP&#xff08;面向切面编程&#xff09;和自定义注解来做缓存处理。以下是一个简单的示例&#xff1a; 首先&#xff0c;我们创建一个自定义注解&#xff0c;用于标记需要进行缓存处理的方法&#xff1a; import java.lang.annotat…

联想G50笔记本直接使用F键功能(F1~F12)需要在BIOS设置关闭热键功能可以这样操作!

如果开启启用热键模式按F1就会出现FnF1的效果&#xff0c;不喜欢此方式按键的用户可以进入BIOS设置界面停用热键模式即可。 停用热键模式方法如下&#xff1a; 1、重新启动笔记本电脑&#xff0c;当笔记本电脑屏幕出现Lenovo标识的时候&#xff0c;立即按FnF2进入BIOS设置界面…

表单规定输入域的选项列表(html5新元素)

datalist datalist 元素规定输入域的选项列表。 datalist属性规定 form 或 input 域应该拥有自动完成功能。当用户在自动完成域中开始输入时&#xff0c;浏览器应该在该域中显示填写的选项&#xff1a; 使用 input元素的列表属性与datalist元素绑定. 还有一定的搜索能力&…

CVE-2020-9483 apache skywalking SQL注入漏洞

漏洞概述 当使用H2 / MySQL / TiDB作为Apache SkyWalking存储时&#xff0c;通过GraphQL协议查询元数据时&#xff0c;存在SQL注入漏洞&#xff0c;该漏洞允许访问未指定的数据。 Apache SkyWalking 6.0.0到6.6.0、7.0.0 H2 / MySQL / TiDB存储实现不使用适当的方法来设置SQL参…

GPIO基本原理

名词解释 高低电平&#xff1a;GPIO引脚电平范围&#xff1a;0V~3.3V&#xff08;部分引脚可容忍5V&#xff09;数据0就是0V&#xff0c;代表低电平&#xff1b;数据1就是3.3V&#xff0c;代表高电平&#xff1b; STM32是32位的单片机&#xff0c;所以内部寄存器也都是32位的…

FilterRegistrationBean能不能排除指定url

文章目录 什么是FilterRegistrationBean举个栗子但是如果我想要排除某些uri方法总结FilterRegistrationBean只能设置指定的url进行过滤,而不能指定排除uri,只能使用OncePerRequestFilter的shouldNotFilter方法,排除uri 什么是FilterRegistrationBean FilterRegistrationBean是…

用于细胞定位的指数距离变换图--Exponential Distance Transform Maps for Cell Localization

论文&#xff1a;Exponential Distance Transform Maps for Cell Localization Paper Link&#xff1a; Exponential Distance Transform Maps for Cell Localization Code&#xff08;有EDT Map的生成方式&#xff09;&#xff1a; https://github.com/Boli-trainee/MHFAN 核…

深入了解Golang:基本语法与核心特性解析

1. 引言 Golang&#xff08;Go&#xff09;是谷歌开发的一门开源编程语言&#xff0c;于2007年首次公开亮相&#xff0c;随后在2012年正式发布。Golang以其简洁、高效和可靠的设计而备受开发者青睐。作为一门编译型语言&#xff0c;Golang具有静态类型和垃圾回收功能&#xff…

网络编程 - TCP协议

一&#xff0c;TCP基本概念 TCP的特性&#xff1a; TCP是有连接的&#xff1a;TCP想要通信&#xff0c;就需要先建立连接&#xff0c;之后才能通信 TCP是可靠传输&#xff1a;网络上进行通信&#xff0c;A给B发消息&#xff0c;这个消息是不可能做到100%送达的&#xff0c;所以…

树模型(三)决策树

决策树是什么&#xff1f;决策树(decision tree)是一种基本的分类与回归方法。 长方形代表判断模块 (decision block)&#xff0c;椭圆形成代表终止模块(terminating block)&#xff0c;表示已经得出结论&#xff0c;可以终止运行。从判断模块引出的左右箭头称作为分支(branch)…

【大数据 - Doris 实践】数据表的基本使用(三):数据模型

数据表的基本使用&#xff08;三&#xff09;&#xff1a;数据模型 1.Aggregate 模型1.1 例一&#xff1a;导入数据聚合1.2 例二&#xff1a;保留明细数据1.3 例三&#xff1a;导入数据与已有数据聚合 2.Uniq 模型3.Duplicate 模型4.数据模型的选择建议5.聚合模型的局限性 Dori…