Diffusion Models专栏文章汇总:入门与实战
前言:自从SDXL提出了长宽桶技术之后,彻底解决了不同长宽比的图像输入问题,现在已经成为训练扩散模型必选的方案。这篇博客从代码详细解读如何在模型训练的时候运用长宽桶技术(Aspect Ratio Bucketing)。
目录
原理解读-原有训练的问题
长宽桶技术(Aspect Ratio Bucketing)
完整代码
原理解读-原有训练的问题
纵横比分桶训练可以极大地提高输出质量,现有图像生成模型的一个常见问题是,它们非常容易生成带有非自然作物的图像。这是因为这些模型被训练成生成方形图像。然而,大多数照片和艺术品都不是方形的。然而,该模型只能同时在相同大小的图像上工作,并且在训练过程中,通常的做法是同时在多个训练样本上操作,以优化所使用gpu的效率。作为妥协,选择正方形图像,在训练过程中,只裁剪出每个图像的中心,然后作为训练样例显示给图像生成模型。
例如,人类通常是没有脚或头的,剑只有一个刀刃,剑柄和剑尖在框架外。因为我们正在创建一个图像生成模型来配合我们的故事叙述体验,所以我们的模型能够产生适当的,未裁剪的角色是很重要的,并且生成的骑士不应该持有延伸到无限的金属状直线。
对裁剪图像进行训练的另一个问题是,它可能导致文本和图像之间的不匹配。例如,带有王冠标签的图像通常在中央裁剪后不再包含王冠,因此君主已经被斩首。我们发现使用随机作物代替中心作物只能略微改善这些问题。使用具有可变图像大小的稳定扩散是可能的,尽管可以注意到,远远超过512x512的原生分辨率往往会引入重复的图像元素,并且非常低的分辨率会产生无法识别的图像。
尽管如此,这向我们表明,在可变大小的图像上训练模型应该是可能的。在单个、可变大小的样本上进行训练是微不足道的,但也非常缓慢,而且由于使用小批量提供的缺乏正则化,更容易产生训练不稳定性。
长宽桶技术(Aspect Ratio Bucketing)
由于这个问题似乎没有现有的解决方案,我们已经为我们的数据集实现了自定义批生成代码,允许创建批处理,其中批处理中的每个项目具有相同的大小,但批处理的图像大小可能不同。
我们通过一种叫做宽高比桶的方法来做到这一点。另一种方法是使用固定的图像大小,缩放每个图像以适应这个固定的大小,并应用在训练期间被掩盖的填充。由于这会导致训练期间不必要的计算,我们没有选择遵循这种替代方法。
在下面,我们描述了我们自定义的宽高比桶的批量生成方案背后的原始想法。
首先,我们必须定义要将数据集的图像排序到哪个存储桶中。为此,我们定义的最大图像尺寸为512x768,最大尺寸为1024。由于最大图像大小为512x768,比512x512大,需要更多的VRAM,因此每个gpu的批处理大小必须降低,这可以通过梯度积累来补偿。
我们通过应用以下算法生成桶:
Set the width to 256.
While the width is less than or equal to 1024:Find the largest height such that height is less than or equal to 1024 and that width multiplied by height is less than or equal to 512 * 768.Add the resolution given by height and width as a bucket.Increase the width by 64.
同样的重复,宽度和高度互换。重复的桶将从列表中删除,并添加一个大小为512x512的桶。
接下来,我们将图像分配到相应的桶中。为此,我们首先将桶分辨率存储在NumPy数组中,并计算每个分辨率的长宽比。对于数据集中的每张图像,我们检索其分辨率并计算长宽比。图像宽高比从桶宽高比数组中减去,使我们能够根据宽高比差的绝对值有效地选择最接近的桶:
image_bucket = argmin(abs(bucket_aspects — image_aspect))
图像的桶号与其数据集中的项目ID相关联。如果图像的宽高比非常极端,甚至与最适合的桶相差太大,则从数据集中修剪图像。
由于我们在多个GPU上进行训练,在每个epoch之前,我们对数据集进行了分片,以确保每个GPU在大小相等的不同子集上工作。为此,我们首先复制数据集中的项目id列表并对它们进行洗牌。如果这个复制的列表不能被gpu数量乘以批大小整除,则会对列表进行修剪,并删除最后的项以使其可整除。
然后,我们根据当前进程的全局排名选择1/world_size*bsz项id的不同子集。自定义批处理生成的其余部分将从这些过程中的任何一个过程中进行描述,并对数据集项id的子集进行操作。
对于当前的分片,每个bucket的列表是通过迭代打乱的数据集项目ID列表并将ID分配给分配给图像的bucket对应的列表来创建的。
处理完所有图像后,我们遍历每个bucket的列表。如果它的长度不能被批大小整除,则根据需要删除列表上的最后一个元素以使其可整除,并将它们添加到单独的捕获所有桶中。由于保证整个分片大小包含许多可被批大小整除的元素,因此保证生成一个长度可被批大小整除的所有bucket。
当请求批处理时,我们从加权分布中随机抽取一个桶。桶的权重设置为桶的大小除以所有剩余桶的大小。这确保了即使有大小差异很大的桶,自定义批生成在训练期间不会引入强烈的偏差,根据图像大小显示图像。如果在没有加权的情况下选择桶,那么小的桶将在训练过程中早期清空,只有最大的桶将在训练结束时保留。按大小对桶进行加权可以避免这种情况。
最后从所选的桶中取出一批项。取走的项目从桶中移除。如果桶现在为空,则在epoch的剩余时间内删除它。所选的项id和所选桶的分辨率现在被传递给图像加载函数。
完整代码
import numpy as np
import pickle
import timedef get_prng(seed):return np.random.RandomState(seed)class BucketManager:def __init__(self, bucket_file, valid_ids=None, max_size=(768,512), divisible=64, step_size=8, min_dim=256, base_res=(512,512), bsz=1, world_size=1, global_rank=0, max_ar_error=4, seed=42, dim_limit=1024, debug=False):with open(bucket_file, "rb") as fh:self.res_map = pickle.load(fh)if valid_ids is not None:new_res_map = {}valid_ids = set(valid_ids)for k, v in self.res_map.items():if k in valid_ids:new_res_map[k] = vself.res_map = new_res_mapself.max_size = max_sizeself.f = 8self.max_tokens = (max_size[0]/self.f) * (max_size[1]/self.f)self.div = divisibleself.min_dim = min_dimself.dim_limit = dim_limitself.base_res = base_resself.bsz = bszself.world_size = world_sizeself.global_rank = global_rankself.max_ar_error = max_ar_errorself.prng = get_prng(seed)epoch_seed = self.prng.tomaxint() % (2**32-1)self.epoch_prng = get_prng(epoch_seed) # separate prng for sharding use for increased thread resilienceself.epoch = Noneself.left_over = Noneself.batch_total = Noneself.batch_delivered = Noneself.debug = debugself.gen_buckets()self.assign_buckets()self.start_epoch()def gen_buckets(self):if self.debug:timer = time.perf_counter()resolutions = []aspects = []w = self.min_dimwhile (w/self.f) * (self.min_dim/self.f) <= self.max_tokens and w <= self.dim_limit:h = self.min_dimgot_base = Falsewhile (w/self.f) * ((h+self.div)/self.f) <= self.max_tokens and (h+self.div) <= self.dim_limit:if w == self.base_res[0] and h == self.base_res[1]:got_base = Trueh += self.divif (w != self.base_res[0] or h != self.base_res[1]) and got_base:resolutions.append(self.base_res)aspects.append(1)resolutions.append((w, h))aspects.append(float(w)/float(h))w += self.divh = self.min_dimwhile (h/self.f) * (self.min_dim/self.f) <= self.max_tokens and h <= self.dim_limit:w = self.min_dimgot_base = Falsewhile (h/self.f) * ((w+self.div)/self.f) <= self.max_tokens and (w+self.div) <= self.dim_limit:if w == self.base_res[0] and h == self.base_res[1]:got_base = Truew += self.divresolutions.append((w, h))aspects.append(float(w)/float(h))h += self.divres_map = {}for i, res in enumerate(resolutions):res_map[res] = aspects[i]self.resolutions = sorted(res_map.keys(), key=lambda x: x[0] * 4096 - x[1])self.aspects = np.array(list(map(lambda x: res_map[x], self.resolutions)))self.resolutions = np.array(self.resolutions)if self.debug:timer = time.perf_counter() - timerprint(f"resolutions:\n{self.resolutions}")print(f"aspects:\n{self.aspects}")print(f"gen_buckets: {timer:.5f}s")def assign_buckets(self):if self.debug:timer = time.perf_counter()self.buckets = {}self.aspect_errors = []skipped = 0skip_list = []for post_id in self.res_map.keys():w, h = self.res_map[post_id]aspect = float(w)/float(h)bucket_id = np.abs(self.aspects - aspect).argmin()if bucket_id not in self.buckets:self.buckets[bucket_id] = []error = abs(self.aspects[bucket_id] - aspect)if error < self.max_ar_error:self.buckets[bucket_id].append(post_id)if self.debug:self.aspect_errors.append(error)else:skipped += 1skip_list.append(post_id)for post_id in skip_list:del self.res_map[post_id]if self.debug:timer = time.perf_counter() - timerself.aspect_errors = np.array(self.aspect_errors)print(f"skipped images: {skipped}")print(f"aspect error: mean {self.aspect_errors.mean()}, median {np.median(self.aspect_errors)}, max {self.aspect_errors.max()}")for bucket_id in reversed(sorted(self.buckets.keys(), key=lambda b: len(self.buckets[b]))):print(f"bucket {bucket_id}: {self.resolutions[bucket_id]}, aspect {self.aspects[bucket_id]:.5f}, entries {len(self.buckets[bucket_id])}")print(f"assign_buckets: {timer:.5f}s")def start_epoch(self, world_size=None, global_rank=None):if self.debug:timer = time.perf_counter()if world_size is not None:self.world_size = world_sizeif global_rank is not None:self.global_rank = global_rank# select ids for this epoch/rankindex = np.array(sorted(list(self.res_map.keys())))index_len = index.shape[0]index = self.epoch_prng.permutation(index)index = index[:index_len - (index_len % (self.bsz * self.world_size))]#print("perm", self.global_rank, index[0:16])index = index[self.global_rank::self.world_size]self.batch_total = index.shape[0] // self.bszassert(index.shape[0] % self.bsz == 0)index = set(index)self.epoch = {}self.left_over = []self.batch_delivered = 0for bucket_id in sorted(self.buckets.keys()):if len(self.buckets[bucket_id]) > 0:self.epoch[bucket_id] = np.array([post_id for post_id in self.buckets[bucket_id] if post_id in index], dtype=np.int64)self.prng.shuffle(self.epoch[bucket_id])self.epoch[bucket_id] = list(self.epoch[bucket_id])overhang = len(self.epoch[bucket_id]) % self.bszif overhang != 0:self.left_over.extend(self.epoch[bucket_id][:overhang])self.epoch[bucket_id] = self.epoch[bucket_id][overhang:]if len(self.epoch[bucket_id]) == 0:del self.epoch[bucket_id]if self.debug:timer = time.perf_counter() - timercount = 0for bucket_id in self.epoch.keys():count += len(self.epoch[bucket_id])print(f"correct item count: {count == len(index)} ({count} of {len(index)})")print(f"start_epoch: {timer:.5f}s")def get_batch(self):if self.debug:timer = time.perf_counter()# check if no data left or no epoch initializedif self.epoch is None or self.left_over is None or (len(self.left_over) == 0 and not bool(self.epoch)) or self.batch_total == self.batch_delivered:self.start_epoch()found_batch = Falsebatch_data = Noneresolution = self.base_reswhile not found_batch:bucket_ids = list(self.epoch.keys())if len(self.left_over) >= self.bsz:bucket_probs = [len(self.left_over)] + [len(self.epoch[bucket_id]) for bucket_id in bucket_ids]bucket_ids = [-1] + bucket_idselse:bucket_probs = [len(self.epoch[bucket_id]) for bucket_id in bucket_ids]bucket_probs = np.array(bucket_probs, dtype=np.float32)bucket_lens = bucket_probsbucket_probs = bucket_probs / bucket_probs.sum()bucket_ids = np.array(bucket_ids, dtype=np.int64)if bool(self.epoch):chosen_id = int(self.prng.choice(bucket_ids, 1, p=bucket_probs)[0])else:chosen_id = -1if chosen_id == -1:# using leftover images that couldn't make it into a bucketed batch and returning them for use with basic square imageself.prng.shuffle(self.left_over)batch_data = self.left_over[:self.bsz]self.left_over = self.left_over[self.bsz:]found_batch = Trueelse:if len(self.epoch[chosen_id]) >= self.bsz:# return bucket batch and resolutionbatch_data = self.epoch[chosen_id][:self.bsz]self.epoch[chosen_id] = self.epoch[chosen_id][self.bsz:]resolution = tuple(self.resolutions[chosen_id])found_batch = Trueif len(self.epoch[chosen_id]) == 0:del self.epoch[chosen_id]else:# can't make a batch from this, not enough images. move them to leftovers and try againself.left_over.extend(self.epoch[chosen_id])del self.epoch[chosen_id]assert(found_batch or len(self.left_over) >= self.bsz or bool(self.epoch))if self.debug:timer = time.perf_counter() - timerprint(f"bucket probs: " + ", ".join(map(lambda x: f"{x:.2f}", list(bucket_probs*100))))print(f"chosen id: {chosen_id}")print(f"batch data: {batch_data}")print(f"resolution: {resolution}")print(f"get_batch: {timer:.5f}s")self.batch_delivered += 1return (batch_data, resolution)def generator(self):if self.batch_delivered >= self.batch_total:self.start_epoch()while self.batch_delivered < self.batch_total:yield self.get_batch()if __name__ == "__main__":# prepare a pickle with mapping of dataset IDs to resolutions called resolutions.pkl to use thiswith open("resolutions.pkl", "rb") as fh:ids = list(pickle.load(fh).keys())counts = np.zeros((len(ids),)).astype(np.int64)id_map = {}for i, post_id in enumerate(ids):id_map[post_id] = ibm = BucketManager("resolutions.pkl", debug=True, bsz=8, world_size=8, global_rank=3)print("got: " + str(bm.get_batch()))print("got: " + str(bm.get_batch()))print("got: " + str(bm.get_batch()))print("got: " + str(bm.get_batch()))print("got: " + str(bm.get_batch()))print("got: " + str(bm.get_batch()))print("got: " + str(bm.get_batch()))bm = BucketManager("resolutions.pkl", bsz=8, world_size=1, global_rank=0, valid_ids=ids[0:16])for _ in range(16):bm.get_batch()print("got from future epoch: " + str(bm.get_batch()))bms = []for rank in range(16):bm = BucketManager("resolutions.pkl", bsz=8, world_size=16, global_rank=rank)bms.append(bm)for epoch in range(5):print(f"epoch {epoch}")for i, bm in enumerate(bms):print(f"bm {i}")first = Truefor ids, res in bm.generator():if first and i == 0:#print(ids)first = Falsefor post_id in ids:counts[id_map[post_id]] += 1print(np.bincount(counts))