代码解读:Diffusion Models中的长宽桶技术(Aspect Ratio Bucketing)

Diffusion Models专栏文章汇总:入门与实战

前言:自从SDXL提出了长宽桶技术之后,彻底解决了不同长宽比的图像输入问题,现在已经成为训练扩散模型必选的方案。这篇博客从代码详细解读如何在模型训练的时候运用长宽桶技术(Aspect Ratio Bucketing)。

目录

原理解读-原有训练的问题

长宽桶技术(Aspect Ratio Bucketing)

完整代码


原理解读-原有训练的问题

纵横比分桶训练可以极大地提高输出质量,现有图像生成模型的一个常见问题是,它们非常容易生成带有非自然作物的图像。这是因为这些模型被训练成生成方形图像。然而,大多数照片和艺术品都不是方形的。然而,该模型只能同时在相同大小的图像上工作,并且在训练过程中,通常的做法是同时在多个训练样本上操作,以优化所使用gpu的效率。作为妥协,选择正方形图像,在训练过程中,只裁剪出每个图像的中心,然后作为训练样例显示给图像生成模型。

例如,人类通常是没有脚或头的,剑只有一个刀刃,剑柄和剑尖在框架外。因为我们正在创建一个图像生成模型来配合我们的故事叙述体验,所以我们的模型能够产生适当的,未裁剪的角色是很重要的,并且生成的骑士不应该持有延伸到无限的金属状直线。

对裁剪图像进行训练的另一个问题是,它可能导致文本和图像之间的不匹配。例如,带有王冠标签的图像通常在中央裁剪后不再包含王冠,因此君主已经被斩首。我们发现使用随机作物代替中心作物只能略微改善这些问题。使用具有可变图像大小的稳定扩散是可能的,尽管可以注意到,远远超过512x512的原生分辨率往往会引入重复的图像元素,并且非常低的分辨率会产生无法识别的图像。

尽管如此,这向我们表明,在可变大小的图像上训练模型应该是可能的。在单个、可变大小的样本上进行训练是微不足道的,但也非常缓慢,而且由于使用小批量提供的缺乏正则化,更容易产生训练不稳定性。

长宽桶技术(Aspect Ratio Bucketing)

由于这个问题似乎没有现有的解决方案,我们已经为我们的数据集实现了自定义批生成代码,允许创建批处理,其中批处理中的每个项目具有相同的大小,但批处理的图像大小可能不同。

我们通过一种叫做宽高比桶的方法来做到这一点。另一种方法是使用固定的图像大小,缩放每个图像以适应这个固定的大小,并应用在训练期间被掩盖的填充。由于这会导致训练期间不必要的计算,我们没有选择遵循这种替代方法。

在下面,我们描述了我们自定义的宽高比桶的批量生成方案背后的原始想法。

首先,我们必须定义要将数据集的图像排序到哪个存储桶中。为此,我们定义的最大图像尺寸为512x768,最大尺寸为1024。由于最大图像大小为512x768,比512x512大,需要更多的VRAM,因此每个gpu的批处理大小必须降低,这可以通过梯度积累来补偿。

我们通过应用以下算法生成桶:

Set the width to 256.
While the width is less than or equal to 1024:Find the largest height such that height is less than or equal to 1024 and that width multiplied by height is less than or equal to 512 * 768.Add the resolution given by height and width as a bucket.Increase the width by 64.

同样的重复,宽度和高度互换。重复的桶将从列表中删除,并添加一个大小为512x512的桶。

接下来,我们将图像分配到相应的桶中。为此,我们首先将桶分辨率存储在NumPy数组中,并计算每个分辨率的长宽比。对于数据集中的每张图像,我们检索其分辨率并计算长宽比。图像宽高比从桶宽高比数组中减去,使我们能够根据宽高比差的绝对值有效地选择最接近的桶:

image_bucket = argmin(abs(bucket_aspects — image_aspect))

图像的桶号与其数据集中的项目ID相关联。如果图像的宽高比非常极端,甚至与最适合的桶相差太大,则从数据集中修剪图像。

由于我们在多个GPU上进行训练,在每个epoch之前,我们对数据集进行了分片,以确保每个GPU在大小相等的不同子集上工作。为此,我们首先复制数据集中的项目id列表并对它们进行洗牌。如果这个复制的列表不能被gpu数量乘以批大小整除,则会对列表进行修剪,并删除最后的项以使其可整除。

然后,我们根据当前进程的全局排名选择1/world_size*bsz项id的不同子集。自定义批处理生成的其余部分将从这些过程中的任何一个过程中进行描述,并对数据集项id的子集进行操作。

对于当前的分片,每个bucket的列表是通过迭代打乱的数据集项目ID列表并将ID分配给分配给图像的bucket对应的列表来创建的。

处理完所有图像后,我们遍历每个bucket的列表。如果它的长度不能被批大小整除,则根据需要删除列表上的最后一个元素以使其可整除,并将它们添加到单独的捕获所有桶中。由于保证整个分片大小包含许多可被批大小整除的元素,因此保证生成一个长度可被批大小整除的所有bucket。

当请求批处理时,我们从加权分布中随机抽取一个桶。桶的权重设置为桶的大小除以所有剩余桶的大小。这确保了即使有大小差异很大的桶,自定义批生成在训练期间不会引入强烈的偏差,根据图像大小显示图像。如果在没有加权的情况下选择桶,那么小的桶将在训练过程中早期清空,只有最大的桶将在训练结束时保留。按大小对桶进行加权可以避免这种情况。

最后从所选的桶中取出一批项。取走的项目从桶中移除。如果桶现在为空,则在epoch的剩余时间内删除它。所选的项id和所选桶的分辨率现在被传递给图像加载函数。

完整代码

import numpy as np
import pickle
import timedef get_prng(seed):return np.random.RandomState(seed)class BucketManager:def __init__(self, bucket_file, valid_ids=None, max_size=(768,512), divisible=64, step_size=8, min_dim=256, base_res=(512,512), bsz=1, world_size=1, global_rank=0, max_ar_error=4, seed=42, dim_limit=1024, debug=False):with open(bucket_file, "rb") as fh:self.res_map = pickle.load(fh)if valid_ids is not None:new_res_map = {}valid_ids = set(valid_ids)for k, v in self.res_map.items():if k in valid_ids:new_res_map[k] = vself.res_map = new_res_mapself.max_size = max_sizeself.f = 8self.max_tokens = (max_size[0]/self.f) * (max_size[1]/self.f)self.div = divisibleself.min_dim = min_dimself.dim_limit = dim_limitself.base_res = base_resself.bsz = bszself.world_size = world_sizeself.global_rank = global_rankself.max_ar_error = max_ar_errorself.prng = get_prng(seed)epoch_seed = self.prng.tomaxint() % (2**32-1)self.epoch_prng = get_prng(epoch_seed) # separate prng for sharding use for increased thread resilienceself.epoch = Noneself.left_over = Noneself.batch_total = Noneself.batch_delivered = Noneself.debug = debugself.gen_buckets()self.assign_buckets()self.start_epoch()def gen_buckets(self):if self.debug:timer = time.perf_counter()resolutions = []aspects = []w = self.min_dimwhile (w/self.f) * (self.min_dim/self.f) <= self.max_tokens and w <= self.dim_limit:h = self.min_dimgot_base = Falsewhile (w/self.f) * ((h+self.div)/self.f) <= self.max_tokens and (h+self.div) <= self.dim_limit:if w == self.base_res[0] and h == self.base_res[1]:got_base = Trueh += self.divif (w != self.base_res[0] or h != self.base_res[1]) and got_base:resolutions.append(self.base_res)aspects.append(1)resolutions.append((w, h))aspects.append(float(w)/float(h))w += self.divh = self.min_dimwhile (h/self.f) * (self.min_dim/self.f) <= self.max_tokens and h <= self.dim_limit:w = self.min_dimgot_base = Falsewhile (h/self.f) * ((w+self.div)/self.f) <= self.max_tokens and (w+self.div) <= self.dim_limit:if w == self.base_res[0] and h == self.base_res[1]:got_base = Truew += self.divresolutions.append((w, h))aspects.append(float(w)/float(h))h += self.divres_map = {}for i, res in enumerate(resolutions):res_map[res] = aspects[i]self.resolutions = sorted(res_map.keys(), key=lambda x: x[0] * 4096 - x[1])self.aspects = np.array(list(map(lambda x: res_map[x], self.resolutions)))self.resolutions = np.array(self.resolutions)if self.debug:timer = time.perf_counter() - timerprint(f"resolutions:\n{self.resolutions}")print(f"aspects:\n{self.aspects}")print(f"gen_buckets: {timer:.5f}s")def assign_buckets(self):if self.debug:timer = time.perf_counter()self.buckets = {}self.aspect_errors = []skipped = 0skip_list = []for post_id in self.res_map.keys():w, h = self.res_map[post_id]aspect = float(w)/float(h)bucket_id = np.abs(self.aspects - aspect).argmin()if bucket_id not in self.buckets:self.buckets[bucket_id] = []error = abs(self.aspects[bucket_id] - aspect)if error < self.max_ar_error:self.buckets[bucket_id].append(post_id)if self.debug:self.aspect_errors.append(error)else:skipped += 1skip_list.append(post_id)for post_id in skip_list:del self.res_map[post_id]if self.debug:timer = time.perf_counter() - timerself.aspect_errors = np.array(self.aspect_errors)print(f"skipped images: {skipped}")print(f"aspect error: mean {self.aspect_errors.mean()}, median {np.median(self.aspect_errors)}, max {self.aspect_errors.max()}")for bucket_id in reversed(sorted(self.buckets.keys(), key=lambda b: len(self.buckets[b]))):print(f"bucket {bucket_id}: {self.resolutions[bucket_id]}, aspect {self.aspects[bucket_id]:.5f}, entries {len(self.buckets[bucket_id])}")print(f"assign_buckets: {timer:.5f}s")def start_epoch(self, world_size=None, global_rank=None):if self.debug:timer = time.perf_counter()if world_size is not None:self.world_size = world_sizeif global_rank is not None:self.global_rank = global_rank# select ids for this epoch/rankindex = np.array(sorted(list(self.res_map.keys())))index_len = index.shape[0]index = self.epoch_prng.permutation(index)index = index[:index_len - (index_len % (self.bsz * self.world_size))]#print("perm", self.global_rank, index[0:16])index = index[self.global_rank::self.world_size]self.batch_total = index.shape[0] // self.bszassert(index.shape[0] % self.bsz == 0)index = set(index)self.epoch = {}self.left_over = []self.batch_delivered = 0for bucket_id in sorted(self.buckets.keys()):if len(self.buckets[bucket_id]) > 0:self.epoch[bucket_id] = np.array([post_id for post_id in self.buckets[bucket_id] if post_id in index], dtype=np.int64)self.prng.shuffle(self.epoch[bucket_id])self.epoch[bucket_id] = list(self.epoch[bucket_id])overhang = len(self.epoch[bucket_id]) % self.bszif overhang != 0:self.left_over.extend(self.epoch[bucket_id][:overhang])self.epoch[bucket_id] = self.epoch[bucket_id][overhang:]if len(self.epoch[bucket_id]) == 0:del self.epoch[bucket_id]if self.debug:timer = time.perf_counter() - timercount = 0for bucket_id in self.epoch.keys():count += len(self.epoch[bucket_id])print(f"correct item count: {count == len(index)} ({count} of {len(index)})")print(f"start_epoch: {timer:.5f}s")def get_batch(self):if self.debug:timer = time.perf_counter()# check if no data left or no epoch initializedif self.epoch is None or self.left_over is None or (len(self.left_over) == 0 and not bool(self.epoch)) or self.batch_total == self.batch_delivered:self.start_epoch()found_batch = Falsebatch_data = Noneresolution = self.base_reswhile not found_batch:bucket_ids = list(self.epoch.keys())if len(self.left_over) >= self.bsz:bucket_probs = [len(self.left_over)] + [len(self.epoch[bucket_id]) for bucket_id in bucket_ids]bucket_ids = [-1] + bucket_idselse:bucket_probs = [len(self.epoch[bucket_id]) for bucket_id in bucket_ids]bucket_probs = np.array(bucket_probs, dtype=np.float32)bucket_lens = bucket_probsbucket_probs = bucket_probs / bucket_probs.sum()bucket_ids = np.array(bucket_ids, dtype=np.int64)if bool(self.epoch):chosen_id = int(self.prng.choice(bucket_ids, 1, p=bucket_probs)[0])else:chosen_id = -1if chosen_id == -1:# using leftover images that couldn't make it into a bucketed batch and returning them for use with basic square imageself.prng.shuffle(self.left_over)batch_data = self.left_over[:self.bsz]self.left_over = self.left_over[self.bsz:]found_batch = Trueelse:if len(self.epoch[chosen_id]) >= self.bsz:# return bucket batch and resolutionbatch_data = self.epoch[chosen_id][:self.bsz]self.epoch[chosen_id] = self.epoch[chosen_id][self.bsz:]resolution = tuple(self.resolutions[chosen_id])found_batch = Trueif len(self.epoch[chosen_id]) == 0:del self.epoch[chosen_id]else:# can't make a batch from this, not enough images. move them to leftovers and try againself.left_over.extend(self.epoch[chosen_id])del self.epoch[chosen_id]assert(found_batch or len(self.left_over) >= self.bsz or bool(self.epoch))if self.debug:timer = time.perf_counter() - timerprint(f"bucket probs: " + ", ".join(map(lambda x: f"{x:.2f}", list(bucket_probs*100))))print(f"chosen id: {chosen_id}")print(f"batch data: {batch_data}")print(f"resolution: {resolution}")print(f"get_batch: {timer:.5f}s")self.batch_delivered += 1return (batch_data, resolution)def generator(self):if self.batch_delivered >= self.batch_total:self.start_epoch()while self.batch_delivered < self.batch_total:yield self.get_batch()if __name__ == "__main__":# prepare a pickle with mapping of dataset IDs to resolutions called resolutions.pkl to use thiswith open("resolutions.pkl", "rb") as fh:ids = list(pickle.load(fh).keys())counts = np.zeros((len(ids),)).astype(np.int64)id_map = {}for i, post_id in enumerate(ids):id_map[post_id] = ibm = BucketManager("resolutions.pkl", debug=True, bsz=8, world_size=8, global_rank=3)print("got: " + str(bm.get_batch()))print("got: " + str(bm.get_batch()))print("got: " + str(bm.get_batch()))print("got: " + str(bm.get_batch()))print("got: " + str(bm.get_batch()))print("got: " + str(bm.get_batch()))print("got: " + str(bm.get_batch()))bm = BucketManager("resolutions.pkl", bsz=8, world_size=1, global_rank=0, valid_ids=ids[0:16])for _ in range(16):bm.get_batch()print("got from future epoch: " + str(bm.get_batch()))bms = []for rank in range(16):bm = BucketManager("resolutions.pkl", bsz=8, world_size=16, global_rank=rank)bms.append(bm)for epoch in range(5):print(f"epoch {epoch}")for i, bm in enumerate(bms):print(f"bm {i}")first = Truefor ids, res in bm.generator():if first and i == 0:#print(ids)first = Falsefor post_id in ids:counts[id_map[post_id]] += 1print(np.bincount(counts))

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/48471.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【机器学习】-- SVM核函数(超详细解读)

支持向量机&#xff08;SVM&#xff09;中的核函数是支持向量机能够处理非线性问题并在高维空间中学习复杂决策边界的关键。核函数在SVM中扮演着将输入特征映射到更高维空间的角色&#xff0c;使得原始特征空间中的非线性问题在高维空间中变得线性可分。 一、SVM是什么&#x…

时间卷积网络(TCN):序列建模的强大工具(附Pytorch网络模型代码)

这里写目录标题 1. 引言2. TCN的核心特性2.1 序列建模任务描述2.2 因果卷积2.3 扩张卷积2.4 残差连接 3. TCN的网络结构4. TCN vs RNN5. TCN的应用TCN的实现 1. 引言 引用自&#xff1a;Bai S, Kolter J Z, Koltun V. An empirical evaluation of generic convolutional and re…

Linux系统之部署扫雷小游戏(三)

Linux系统之部署扫雷小游戏(三) 一、小游戏介绍1.1 小游戏简介1.2 项目预览二、本次实践介绍2.1 本地环境规划2.2 本次实践介绍三、检查本地环境3.1 检查系统版本3.2 检查系统内核版本3.3 检查软件源四、安装Apache24.1 安装Apache2软件4.2 启动apache2服务4.3 查看apache2服…

大厂生产解决方案:泳道隔离机制

更多大厂面试内容可见 -> http://11come.cn 大厂生产解决方案&#xff1a;泳道隔离机制 背景 在公司中&#xff0c;由于项目多、开发人员多&#xff0c;一般会有多套测试环境&#xff08;可以理解为多个服务器&#xff09;&#xff0c;同一套服务会在多套测试环境中都部署…

如何解决微服务下引起的 分布式事务问题

一、什么是分布式事务&#xff1f; 虽然叫分布式事务&#xff0c;但不是一定是分布式部署的服务之间才会产生分布式事务。不是在同一个服务或同一个数据库架构下&#xff0c;产生的事务&#xff0c;也就是分布式事务。 跨数据源的分布式事务 跨服务的分布式事务 二、解决方…

配置服务器

参考博客 1. https://blog.csdn.net/qq_31278903/article/details/83146031 2. https://blog.csdn.net/u014374826/article/details/134093409 3. https://blog.csdn.net/weixin_42728126/article/details/88887350 4. https://blog.csdn.net/Dreamhai/article/details/109…

javac详解 idea maven内部编译原理 自制编译器

起因 不知道大家在开发中&#xff0c;有没有过下面这些疑问。有的话&#xff0c;今天就一次解答清楚。 如何使用javac命令编译一个项目&#xff1f;java或者javac的一些参数到底有什么用&#xff1f;idea或者maven是如何编译java项目的&#xff1f;&#xff08;你可能猜测底层…

【一刷《剑指Offer》】面试题 47:不用加减乘除做加法

力扣对应题目链接&#xff1a;LCR 190. 加密运算 - 力扣&#xff08;LeetCode&#xff09; 牛客对应题目链接&#xff1a;不用加减乘除做加法_牛客题霸_牛客网 (nowcoder.com) 一、《剑指Offer》对应内容 二、分析题目 sumdataA⊕dataB 非进位和&#xff1a;异或运…

Unity UGUI 之 Graphic Raycaster

本文仅作学习笔记与交流&#xff0c;不作任何商业用途 本文包括但不限于unity官方手册&#xff0c;唐老狮&#xff0c;麦扣教程知识&#xff0c;引用会标记&#xff0c;如有不足还请斧正 首先手册连接如下&#xff1a; Unity - Manual: Graphic Raycaster 笔记来源于&#xff…

无人车技术浪潮真的挡不住了~

正文 无人驾驶汽车其实也不算是新鲜玩意了&#xff0c;早在十年前大家都开始纷纷投入研发&#xff0c;在那时就已经蠢蠢欲动&#xff0c;像目前大部分智驾系统和辅助驾驶系统都是无人驾驶系统的一个中间过度版本&#xff0c;就像手机进入智能机时代的中间版本。 然而前段时间突…

SpringBoot 介绍和使用(详细)

使用SpringBoot之前,我们需要了解Maven,并配置国内源(为什么要配置这些,下面会详细介绍),下面我们将创建一个SpringBoot项目"输出Hello World"介绍. 1.环境准备 ⾃检Idea版本: 社区版: 2021.1 -2022.1.4 专业版: ⽆要求 如果个⼈电脑安装的idea不在这个范围, 需要…

LeetCode 热题 HOT 100 (001/100)【宇宙最简单版】

【链表】 No. 0160 相交链表 【简单】&#x1f449;力扣对应题目指路 希望对你有帮助呀&#xff01;&#xff01;&#x1f49c;&#x1f49c; 如有更好理解的思路&#xff0c;欢迎大家留言补充 ~ 一起加油叭 &#x1f4a6; 欢迎关注、订阅专栏 【力扣详解】谢谢你的支持&#x…

搜维尔科技:【产品推荐】Euleria Health Riablo 运动功能训练与评估系统

Euleria Health Riablo 运动功能训练与评估系统 Riablo提供一种创新的康复解决方案&#xff0c;将康复和训练变得可激励、可衡量和可控制。Riablo通过激活本体感觉&#xff0c;并通过视听反馈促进神经肌肉的训练。 得益于其技术先进和易用性&#xff0c;Riablo是骨科、运动医…

jmeter部署

一、windows环境下部署 1、安装jdk并配置jdk的环境变量 (1) 安装jdk jdk下载完成后双击安装包&#xff1a;无限点击"下一步"直到完成&#xff0c;默认路径即可。 (2) jdk安装完成后配置jdk的环境变量 找到环境变量中的系统变量&#xff1a;此电脑 --> 右键属性 …

C语言:温度转换

1.题目&#xff1a;实现摄氏度&#xff08;Celsius&#xff09;和华氏度&#xff08;Fahrenheit&#xff09;之间的转换。 输入一个华氏温度&#xff0c;输出摄氏温度&#xff0c;结果保留两位小数。 2.思路&#xff1a;&#xff08;这是固定公式&#xff0c;其中 F 是华氏度&a…

【C语言】详解结构体(下)(位段)

文章目录 前言1. 位段的含义2. 位段的声明3. 位段的内存分配&#xff08;重点&#xff09;3.1 存储方向的问题3.2 剩余空间利用的问题 4. 位段的跨平台问题5. 位段的应用6. 总结 前言 相信大部分的读者在学校或者在自学时结构体的知识时&#xff0c;可能很少会听到甚至就根本没…

STM32实战篇:按键(外部输入信号)触发中断

功能要求 将两个按键分别与引脚PA0、PA1相连接&#xff0c;通过按键按下&#xff0c;能够触发中断响应程序&#xff08;不需明确功能&#xff09;。 代码流程如下&#xff1a; 实现代码 #include "stm32f10x.h" // Device headerint main() {//开…

JUC并发编程01-基础概念

概念 进程 进程可以视为程序的一个实例&#xff0c;进程就是用来加载指令、管理内存、管理I0 线程 一个进程内可以有多个线程&#xff0c;一个线程就是一个指令流。 在Java中&#xff0c;线程作为最小调度单位&#xff0c;进程作为资源分配的最小单位&#xff0c;可以说进程…

Mysql数据库第二次作业

(1)显示所有职工的基本信息。 mysql> select * from t_worker; (2)查询所有职工所属部门的部门号&#xff0c;不显示重复的部门号。 mysql> select distinct department_id from t_worker; (3)求出所有职工的人数。 mysql> select count(1) from t_worker; (4)列…

Figma 中文版指南:获取和安装汉化插件

Figma是一种主流的在线团队合作设计工具&#xff0c;也是一种基于 Web 端的设计工具。在当今的设计时代&#xff0c;Figma 的使用满足了每个人的设计需求&#xff0c;不仅可以实现在线编辑&#xff0c;还可以方便日常管理&#xff0c;有效提高工作效率。然而&#xff0c;相信很…