WeightedRandomSamplerDDP: a weighted random sampler for DDP

First, let's look at PyTorch's WeightedRandomSampler:

import torch
from typing import Iterator, Sequence

from torch import Tensor
from torch.utils.data import Sampler


class WeightedRandomSampler(Sampler[int]):
    r"""Samples elements from ``[0,..,len(weights)-1]`` with given probabilities (weights).

    Args:
        weights (sequence): a sequence of weights, not necessarily summing up to one
        num_samples (int): number of samples to draw
        replacement (bool): if ``True``, samples are drawn with replacement.
            If not, they are drawn without replacement, which means that when a
            sample index is drawn for a row, it cannot be drawn again for that row.
        generator (Generator): Generator used in sampling.

    Example:
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> list(WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True))
        [4, 4, 1, 4, 5]
        >>> list(WeightedRandomSampler([0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False))
        [0, 1, 4, 3, 2]
    """

    weights: Tensor
    num_samples: int
    replacement: bool

    def __init__(self, weights: Sequence[float], num_samples: int,
                 replacement: bool = True, generator=None) -> None:
        if not isinstance(num_samples, int) or isinstance(num_samples, bool) or \
                num_samples <= 0:
            raise ValueError(f"num_samples should be a positive integer value, "
                             f"but got num_samples={num_samples}")
        if not isinstance(replacement, bool):
            raise ValueError(f"replacement should be a boolean value, "
                             f"but got replacement={replacement}")

        weights_tensor = torch.as_tensor(weights, dtype=torch.double)
        if len(weights_tensor.shape) != 1:
            raise ValueError("weights should be a 1d sequence but given "
                             f"weights have shape {tuple(weights_tensor.shape)}")

        self.weights = weights_tensor
        self.num_samples = num_samples
        self.replacement = replacement
        self.generator = generator

    def __iter__(self) -> Iterator[int]:
        rand_tensor = torch.multinomial(self.weights, self.num_samples,
                                        self.replacement, generator=self.generator)
        yield from iter(rand_tensor.tolist())

    def __len__(self) -> int:
        return self.num_samples
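As a quick illustration, here is a minimal usage sketch. The weights are the toy values from the docstring example; the seeded generator is only there to make the run reproducible:

```python
import torch
from torch.utils.data import WeightedRandomSampler

# Toy weights: index 4 is heavily favored.
weights = [0.1, 0.9, 0.4, 0.7, 3.0, 0.6]
g = torch.Generator()
g.manual_seed(0)

sampler = WeightedRandomSampler(weights, num_samples=5,
                                replacement=True, generator=g)
drawn = list(sampler)  # 5 indices from [0, 5], drawn in proportion to the weights

print(drawn)
print(len(sampler))  # 5
```

In practice the sampler is passed to a `DataLoader` via its `sampler=` argument (with `shuffle` left off, since a sampler and `shuffle=True` are mutually exclusive).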

Next, the distributed sampler, DistributedSampler:

import math
from typing import Iterator, Optional, TypeVar

import torch
import torch.distributed as dist
from torch.utils.data import Dataset, Sampler

T_co = TypeVar('T_co', covariant=True)


class DistributedSampler(Sampler[T_co]):
    r"""Sampler that restricts data loading to a subset of the dataset.

    It is especially useful in conjunction with
    :class:`torch.nn.parallel.DistributedDataParallel`. In such a case, each
    process can pass a :class:`~torch.utils.data.DistributedSampler` instance as a
    :class:`~torch.utils.data.DataLoader` sampler, and load a subset of the
    original dataset that is exclusive to it.

    .. note::
        Dataset is assumed to be of constant size and that any instance of it
        always returns the same elements in the same order.

    Args:
        dataset: Dataset used for sampling.
        num_replicas (int, optional): Number of processes participating in
            distributed training. By default, :attr:`world_size` is retrieved
            from the current distributed group.
        rank (int, optional): Rank of the current process within :attr:`num_replicas`.
            By default, :attr:`rank` is retrieved from the current distributed
            group.
        shuffle (bool, optional): If ``True`` (default), sampler will shuffle the
            indices.
        seed (int, optional): random seed used to shuffle the sampler if
            :attr:`shuffle=True`. This number should be identical across all
            processes in the distributed group. Default: ``0``.
        drop_last (bool, optional): if ``True``, then the sampler will drop the
            tail of the data to make it evenly divisible across the number of
            replicas. If ``False``, the sampler will add extra indices to make
            the data evenly divisible across the replicas. Default: ``False``.

    .. warning::
        In distributed mode, calling the :meth:`set_epoch` method at the
        beginning of each epoch **before** creating the :class:`DataLoader`
        iterator is necessary to make shuffling work properly across multiple
        epochs. Otherwise, the same ordering will be always used.

    Example::

        >>> # xdoctest: +SKIP
        >>> sampler = DistributedSampler(dataset) if is_distributed else None
        >>> loader = DataLoader(dataset, shuffle=(sampler is None),
        ...                     sampler=sampler)
        >>> for epoch in range(start_epoch, n_epochs):
        ...     if is_distributed:
        ...         sampler.set_epoch(epoch)
        ...     train(loader)
    """

    def __init__(self, dataset: Dataset, num_replicas: Optional[int] = None,
                 rank: Optional[int] = None, shuffle: bool = True,
                 seed: int = 0, drop_last: bool = False) -> None:
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        if rank >= num_replicas or rank < 0:
            raise ValueError(
                f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]")
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.drop_last = drop_last
        # If the dataset length is evenly divisible by # of replicas, then there
        # is no need to drop any data, since the dataset will be split equally.
        if self.drop_last and len(self.dataset) % self.num_replicas != 0:  # type: ignore[arg-type]
            # Split to nearest available length that is evenly divisible.
            # This is to ensure each rank receives the same amount of data when
            # using this Sampler.
            self.num_samples = math.ceil(
                (len(self.dataset) - self.num_replicas) / self.num_replicas  # type: ignore[arg-type]
            )
        else:
            self.num_samples = math.ceil(len(self.dataset) / self.num_replicas)  # type: ignore[arg-type]
        self.total_size = self.num_samples * self.num_replicas
        self.shuffle = shuffle
        self.seed = seed

    def __iter__(self) -> Iterator[T_co]:
        if self.shuffle:
            # deterministically shuffle based on epoch and seed
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()  # type: ignore[arg-type]
        else:
            indices = list(range(len(self.dataset)))  # type: ignore[arg-type]

        if not self.drop_last:
            # add extra samples to make it evenly divisible
            padding_size = self.total_size - len(indices)
            if padding_size <= len(indices):
                indices += indices[:padding_size]
            else:
                indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
        else:
            # remove tail of data to make it evenly divisible.
            indices = indices[:self.total_size]
        assert len(indices) == self.total_size

        # subsample
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self) -> int:
        return self.num_samples

    def set_epoch(self, epoch: int) -> None:
        r"""Set the epoch for this sampler.

        When :attr:`shuffle=True`, this ensures all replicas use a different
        random ordering for each epoch. Otherwise, the next iteration of this
        sampler will yield the same ordering.

        Args:
            epoch (int): Epoch number.
        """
        self.epoch = epoch
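The core of the partitioning is the padding step plus the strided subsample `indices[rank:total_size:num_replicas]`. A stdlib-only sketch with a hypothetical dataset of length 10 split across 3 replicas (no shuffling, `drop_last=False`) shows how every rank ends up with the same number of indices:

```python
import math

# Hypothetical sizes: 10 samples, 3 replicas, drop_last=False.
dataset_len, num_replicas = 10, 3
num_samples = math.ceil(dataset_len / num_replicas)  # 4 indices per rank
total_size = num_samples * num_replicas              # 12 after padding

indices = list(range(dataset_len))
indices += indices[: total_size - len(indices)]      # pad by repeating the head

# Each rank takes every num_replicas-th index, starting at its own rank.
shards = [indices[rank:total_size:num_replicas] for rank in range(num_replicas)]
print(shards)  # [[0, 3, 6, 9], [1, 4, 7, 0], [2, 5, 8, 1]]
```

Note that with `drop_last=False` the padded indices (here 0 and 1) are seen twice in an epoch; that is the price of keeping every rank's shard the same length.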

Finally, WeightedRandomSamplerDDP, which combines the two:

import math
from typing import Optional, Sequence

import torch


class WeightedRandomSamplerDDP(torch.utils.data.distributed.DistributedSampler):
    r"""Samples elements from ``[0,..,len(weights)-1]`` with given probabilities (weights).

    Args:
        data_set: Dataset used for sampling.
        weights (sequence): a sequence of weights, not necessarily summing up to one
        num_replicas (int, optional): Number of processes participating in
            distributed training. By default, :attr:`world_size` is retrieved
            from the current distributed group.
        rank (int, optional): Rank of the current process within :attr:`num_replicas`.
            By default, :attr:`rank` is retrieved from the current distributed
            group.
        num_samples (int): number of samples to draw
        replacement (bool): if ``True``, samples are drawn with replacement.
            If not, they are drawn without replacement, which means that when a
            sample index is drawn for a row, it cannot be drawn again for that row.
        generator (Generator): Generator used in sampling.
    """

    weights: torch.Tensor
    num_samples: int
    replacement: bool

    def __init__(self, data_set, weights: Sequence[float], num_samples: int,
                 num_replicas: Optional[int] = None, rank: Optional[int] = None,
                 shuffle: bool = True, seed: int = 0, drop_last: bool = True,
                 replacement: bool = True, generator=None) -> None:
        super().__init__(data_set, num_replicas, rank, shuffle, seed, drop_last)
        self.weights = torch.as_tensor(weights, dtype=torch.double)
        if self.drop_last and len(self.dataset) % self.num_replicas != 0:  # type: ignore[arg-type]
            # Split to nearest available length that is evenly divisible.
            # This is to ensure each rank receives the same amount of data when
            # using this Sampler.
            self.num_samples = math.ceil(
                (len(self.dataset) - self.num_replicas) / self.num_replicas  # type: ignore[arg-type]
            )
        else:
            self.num_samples = math.ceil(len(self.dataset) / self.num_replicas)  # type: ignore[arg-type]
        self.replacement = replacement
        self.generator = generator
        # Each rank keeps an interleaved shard of the weights; num_samples is
        # already the per-rank count computed above, so no further division.
        self.weights = self.weights[self.rank::self.num_replicas]

    def __iter__(self):
        rand_tensor = torch.multinomial(self.weights, self.num_samples,
                                        self.replacement, generator=self.generator)
        # Map shard-local indices back to global dataset indices: shard position
        # i on this rank corresponds to dataset index rank + i * num_replicas.
        rand_tensor = self.rank + rand_tensor * self.num_replicas
        return iter(rand_tensor.tolist())

    def __len__(self):
        return self.num_samples
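The sharding and index-mapping idea can be checked without starting any processes. In this stdlib-only sketch, `random.choices` stands in for `torch.multinomial(..., replacement=True)` and the two "ranks" are just loop iterations:

```python
import random

# Two simulated ranks; random.choices stands in for
# torch.multinomial(..., replacement=True).
weights = [0.1, 0.9, 0.4, 0.7, 3.0, 0.6]
num_replicas, num_samples = 2, 4

results = {}
for rank in range(num_replicas):
    shard = weights[rank::num_replicas]  # rank 0 keeps weights 0,2,4; rank 1 keeps 1,3,5
    rng = random.Random(0)
    local = rng.choices(range(len(shard)), weights=shard, k=num_samples)
    # Map shard-local draws back to global dataset indices.
    results[rank] = [rank + i * num_replicas for i in local]
    print(rank, results[rank])
```

Because rank `r` only ever emits indices congruent to `r` modulo `num_replicas`, the ranks sample from disjoint halves of the dataset, which is exactly the non-overlap guarantee DistributedSampler provides, now with weighted draws.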
