Parts of the diffusers source code I still need to understand

1. Notes on code details for DreamBooth training

**When `class_labels = timesteps`, how does the model's forward pass work? Needs a deeper look.**
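A partial answer from reading `UNet2DConditionModel.forward`: when the UNet is built with `class_embed_type="timestep"`, the class labels go through the same sinusoidal projection as the timesteps and the result is added onto the time embedding. The sketch below paraphrases that logic (simplified; details vary by diffusers version):

```python
# Paraphrased from diffusers' UNet2DConditionModel.forward (simplified, version-dependent).
t_emb = self.time_proj(timesteps)        # sinusoidal projection of the timestep
emb = self.time_embedding(t_emb)         # MLP -> time embedding

if self.class_embedding is not None:
    if self.config.class_embed_type == "timestep":
        # class_labels are treated exactly like timesteps:
        # the same sinusoidal projection is applied first
        class_labels = self.time_proj(class_labels)
    class_emb = self.class_embedding(class_labels)
    emb = emb + class_emb                # class conditioning is simply added to the time embedding
```

So passing `class_labels = timesteps` (train_dreambooth.py does this for certain architectures when class-label conditioning is enabled; the exact flag may vary by version) means the timestep is embedded a second time and summed into `emb`.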


The class data is generated with `class_prompt`, not `instance_prompt`.
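A minimal sketch of that pre-generation step, assuming a standard SD 1.5 checkpoint and a hypothetical `class_data` output directory (the real script wraps this in a PromptDataset and shards the work across processes with accelerate):

```python
import torch
from diffusers import StableDiffusionPipeline

# Hypothetical, stripped-down version of the class-image pre-generation
# in train_dreambooth.py.
pipeline = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

class_prompt = "a photo of a dog"  # the generic class prompt, NOT the instance prompt
num_new_images = 100               # e.g. args.num_class_images minus existing files

for i in range(num_new_images):
    image = pipeline(class_prompt).images[0]
    image.save(f"class_data/{i}.jpg")  # hypothetical output directory
```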


```python
class DreamBoothDataset(Dataset):
    """
    A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
    It pre-processes the images and tokenizes the prompts.
    """

    def __init__(
        self,
        instance_data_root,
        instance_prompt,
        tokenizer,
        class_data_root=None,
        class_prompt=None,
        class_num=None,
        size=512,
        center_crop=False,
        encoder_hidden_states=None,
        class_prompt_encoder_hidden_states=None,
        tokenizer_max_length=None,
    ):
        self.size = size
        self.center_crop = center_crop
        self.tokenizer = tokenizer
        self.encoder_hidden_states = encoder_hidden_states
        self.class_prompt_encoder_hidden_states = class_prompt_encoder_hidden_states
        self.tokenizer_max_length = tokenizer_max_length

        self.instance_data_root = Path(instance_data_root)
        if not self.instance_data_root.exists():
            raise ValueError(f"Instance {self.instance_data_root} images root doesn't exists.")

        self.instance_images_path = list(Path(instance_data_root).iterdir())
        self.num_instance_images = len(self.instance_images_path)
        self.instance_prompt = instance_prompt
        self._length = self.num_instance_images

        if class_data_root is not None:
            self.class_data_root = Path(class_data_root)
            self.class_data_root.mkdir(parents=True, exist_ok=True)
            self.class_images_path = list(self.class_data_root.iterdir())
            if class_num is not None:
                self.num_class_images = min(len(self.class_images_path), class_num)
            else:
                self.num_class_images = len(self.class_images_path)
            # the dataset length is driven by the larger of the two image sets
            self._length = max(self.num_class_images, self.num_instance_images)
            self.class_prompt = class_prompt
        else:
            self.class_data_root = None

        self.image_transforms = transforms.Compose(
            [
                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

    def __len__(self):
        return self._length

    def __getitem__(self, index):
        example = {}
        instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])
        instance_image = exif_transpose(instance_image)

        if not instance_image.mode == "RGB":
            instance_image = instance_image.convert("RGB")
        example["instance_images"] = self.image_transforms(instance_image)

        if self.encoder_hidden_states is not None:
            example["instance_prompt_ids"] = self.encoder_hidden_states
        else:
            text_inputs = tokenize_prompt(
                self.tokenizer, self.instance_prompt, tokenizer_max_length=self.tokenizer_max_length
            )
            example["instance_prompt_ids"] = text_inputs.input_ids
            example["instance_attention_mask"] = text_inputs.attention_mask

        if self.class_data_root:
            class_image = Image.open(self.class_images_path[index % self.num_class_images])
            class_image = exif_transpose(class_image)

            if not class_image.mode == "RGB":
                class_image = class_image.convert("RGB")
            example["class_images"] = self.image_transforms(class_image)

            if self.class_prompt_encoder_hidden_states is not None:
                example["class_prompt_ids"] = self.class_prompt_encoder_hidden_states
            else:
                class_text_inputs = tokenize_prompt(
                    self.tokenizer, self.class_prompt, tokenizer_max_length=self.tokenizer_max_length
                )
                example["class_prompt_ids"] = class_text_inputs.input_ids
                example["class_attention_mask"] = class_text_inputs.attention_mask

        return example


def tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None):
    if tokenizer_max_length is not None:
        max_length = tokenizer_max_length
    else:
        max_length = tokenizer.model_max_length

    text_inputs = tokenizer(
        prompt,
        truncation=True,
        padding="max_length",
        max_length=max_length,
        return_tensors="pt",
    )

    return text_inputs


def collate_fn(examples, with_prior_preservation=False):
    has_attention_mask = "instance_attention_mask" in examples[0]

    input_ids = [example["instance_prompt_ids"] for example in examples]
    pixel_values = [example["instance_images"] for example in examples]

    if has_attention_mask:
        attention_mask = [example["instance_attention_mask"] for example in examples]

    # Concat class and instance examples for prior preservation.
    # We do this to avoid doing two forward passes.
    if with_prior_preservation:
        input_ids += [example["class_prompt_ids"] for example in examples]
        pixel_values += [example["class_images"] for example in examples]
        if has_attention_mask:
            attention_mask += [example["class_attention_mask"] for example in examples]

    pixel_values = torch.stack(pixel_values)
    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()

    input_ids = torch.cat(input_ids, dim=0)

    batch = {
        "input_ids": input_ids,
        "pixel_values": pixel_values,
    }

    if has_attention_mask:
        attention_mask = torch.cat(attention_mask, dim=0)
        batch["attention_mask"] = attention_mask

    return batch
```
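The concatenation in `collate_fn` pays off in the training step: one forward pass covers 2×batch examples, and the prediction is then chunked back into the instance and class halves. A paraphrased sketch of the prior-preservation loss from train_dreambooth.py (simplified; epsilon prediction assumed):

```python
import torch
import torch.nn.functional as F

# Paraphrased from the train_dreambooth.py training step (simplified).
# With prior preservation on, model_pred/target have 2*batch rows:
# instance examples first, class examples second (see collate_fn above).
model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
target, target_prior = torch.chunk(target, 2, dim=0)

# Instance loss: learn the new subject.
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

# Prior-preservation loss: keep the class prior from drifting.
prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")

loss = loss + args.prior_loss_weight * prior_loss
```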

How the Dataset and DataLoader are put together
To avoid overfitting and language drift, the model first has to generate samples from a generic (class) prompt.

The text encoder can be fine-tuned as well, but this demands more GPU memory.
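A rough sketch of what `--train_text_encoder` changes (paraphrased and simplified; the exact code differs by version): the text encoder's parameters join the optimizer, so its gradients and optimizer states must also fit in memory.

```python
import itertools
import torch

# Paraphrased from train_dreambooth.py (simplified).
if args.train_text_encoder:
    text_encoder.requires_grad_(True)
    params_to_optimize = itertools.chain(unet.parameters(), text_encoder.parameters())
else:
    text_encoder.requires_grad_(False)  # frozen: no gradients, no optimizer states
    params_to_optimize = unet.parameters()

optimizer = torch.optim.AdamW(
    params_to_optimize,
    lr=args.learning_rate,
    betas=(args.adam_beta1, args.adam_beta2),
    weight_decay=args.adam_weight_decay,
    eps=args.adam_epsilon,
)
```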

2. Notes on code details for text-to-image training

**1. The DataLoader is built as shown below, but why is there no attention_mask? The DreamBooth script had one.
2. Training or fine-tuning needs image-text pairs; if the text side is missing, BLIP can generate captions for the images, though the generated descriptions are not always reliable (see the captioning sketch right after this list).**
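For point 2, a minimal BLIP captioning sketch using transformers; the file path is hypothetical, and captions should be spot-checked before training on them:

```python
from PIL import Image
from transformers import BlipProcessor, BlipForConditionalGeneration

# Minimal BLIP captioning sketch (not part of the diffusers script).
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")

image = Image.open("train/0001.jpg").convert("RGB")  # hypothetical path
inputs = processor(images=image, return_tensors="pt")
out = model.generate(**inputs, max_new_tokens=30)
print(processor.decode(out[0], skip_special_tokens=True))
```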

```python
# Get the datasets: you can either provide your own training and evaluation files (see below)
# or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).

# In distributed training, the load_dataset function guarantees that only one local process can concurrently
# download the dataset.
if args.dataset_name is not None:
    # Downloading and loading a dataset from the hub.
    dataset = load_dataset(
        args.dataset_name,
        args.dataset_config_name,
        cache_dir=args.cache_dir,
        data_dir=args.train_data_dir,
    )
else:
    data_files = {}
    if args.train_data_dir is not None:
        data_files["train"] = os.path.join(args.train_data_dir, "**")
    dataset = load_dataset(
        "imagefolder",
        data_files=data_files,
        cache_dir=args.cache_dir,
    )
    # See more about loading custom images at
    # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder

# Preprocessing the datasets.
# We need to tokenize inputs and targets.
column_names = dataset["train"].column_names

# 6. Get the column names for input/target.
dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)
if args.image_column is None:
    image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
else:
    image_column = args.image_column
    if image_column not in column_names:
        raise ValueError(
            f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}"
        )
if args.caption_column is None:
    caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
else:
    caption_column = args.caption_column
    if caption_column not in column_names:
        raise ValueError(
            f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}"
        )

# Preprocessing the datasets.
# We need to tokenize input captions and transform the images.
def tokenize_captions(examples, is_train=True):
    captions = []
    for caption in examples[caption_column]:
        if isinstance(caption, str):
            captions.append(caption)
        elif isinstance(caption, (list, np.ndarray)):
            # take a random caption if there are multiple
            captions.append(random.choice(caption) if is_train else caption[0])
        else:
            raise ValueError(
                f"Caption column `{caption_column}` should contain either strings or lists of strings."
            )
    inputs = tokenizer(
        captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
    )
    return inputs.input_ids

# Preprocessing the datasets.
train_transforms = transforms.Compose(
    [
        transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
        transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
)

def preprocess_train(examples):
    images = [image.convert("RGB") for image in examples[image_column]]
    examples["pixel_values"] = [train_transforms(image) for image in images]
    examples["input_ids"] = tokenize_captions(examples)
    # each example now carries 4 keys: image, text, pixel_values, input_ids
    return examples

with accelerator.main_process_first():
    if args.max_train_samples is not None:
        dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
    # Set the training transforms
    train_dataset = dataset["train"].with_transform(preprocess_train)

def collate_fn(examples):
    pixel_values = torch.stack([example["pixel_values"] for example in examples])
    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
    input_ids = torch.stack([example["input_ids"] for example in examples])
    return {"pixel_values": pixel_values, "input_ids": input_ids}

# DataLoaders creation:
train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    shuffle=True,
    collate_fn=collate_fn,
    batch_size=args.train_batch_size,
    num_workers=args.dataloader_num_workers,
)
```

3. Training ControlNet

The DataLoader construction code is shown below:


1. A new conditioning_pixel_values image input is added for controllable generation (see the forward-pass sketch after the code).
2. The inputs still carry no attention_mask; same open question as in Section 2.


```python
def make_train_dataset(args, tokenizer, accelerator):
    # Get the datasets: you can either provide your own training and evaluation files (see below)
    # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).

    # In distributed training, the load_dataset function guarantees that only one local process can concurrently
    # download the dataset.
    if args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        dataset = load_dataset(
            args.dataset_name,
            args.dataset_config_name,
            cache_dir=args.cache_dir,
        )
    else:
        if args.train_data_dir is not None:
            dataset = load_dataset(
                args.train_data_dir,
                cache_dir=args.cache_dir,
            )
        # See more about loading custom images at
        # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script

    # Preprocessing the datasets.
    # We need to tokenize inputs and targets.
    column_names = dataset["train"].column_names

    # 6. Get the column names for input/target.
    if args.image_column is None:
        image_column = column_names[0]
        logger.info(f"image column defaulting to {image_column}")
    else:
        image_column = args.image_column
        if image_column not in column_names:
            raise ValueError(
                f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
            )

    if args.caption_column is None:
        caption_column = column_names[1]
        logger.info(f"caption column defaulting to {caption_column}")
    else:
        caption_column = args.caption_column
        if caption_column not in column_names:
            raise ValueError(
                f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
            )

    if args.conditioning_image_column is None:
        conditioning_image_column = column_names[2]
        logger.info(f"conditioning image column defaulting to {conditioning_image_column}")
    else:
        conditioning_image_column = args.conditioning_image_column
        if conditioning_image_column not in column_names:
            raise ValueError(
                f"`--conditioning_image_column` value '{args.conditioning_image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
            )

    def tokenize_captions(examples, is_train=True):
        captions = []
        for caption in examples[caption_column]:
            if random.random() < args.proportion_empty_prompts:
                captions.append("")
            elif isinstance(caption, str):
                captions.append(caption)
            elif isinstance(caption, (list, np.ndarray)):
                # take a random caption if there are multiple
                captions.append(random.choice(caption) if is_train else caption[0])
            else:
                raise ValueError(
                    f"Caption column `{caption_column}` should contain either strings or lists of strings."
                )
        inputs = tokenizer(
            captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
        )
        return inputs.input_ids

    image_transforms = transforms.Compose(
        [
            transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(args.resolution),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
    )

    conditioning_image_transforms = transforms.Compose(
        [
            transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(args.resolution),
            transforms.ToTensor(),
        ]
    )

    def preprocess_train(examples):
        images = [image.convert("RGB") for image in examples[image_column]]
        images = [image_transforms(image) for image in images]

        conditioning_images = [image.convert("RGB") for image in examples[conditioning_image_column]]
        conditioning_images = [conditioning_image_transforms(image) for image in conditioning_images]

        examples["pixel_values"] = images
        examples["conditioning_pixel_values"] = conditioning_images
        examples["input_ids"] = tokenize_captions(examples)

        return examples

    with accelerator.main_process_first():
        if args.max_train_samples is not None:
            dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
        # Set the training transforms
        train_dataset = dataset["train"].with_transform(preprocess_train)

    return train_dataset


def collate_fn(examples):
    pixel_values = torch.stack([example["pixel_values"] for example in examples])
    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()

    conditioning_pixel_values = torch.stack([example["conditioning_pixel_values"] for example in examples])
    conditioning_pixel_values = conditioning_pixel_values.to(memory_format=torch.contiguous_format).float()

    input_ids = torch.stack([example["input_ids"] for example in examples])

    return {
        "pixel_values": pixel_values,
        "conditioning_pixel_values": conditioning_pixel_values,
        "input_ids": input_ids,
    }
```
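A paraphrased sketch of how conditioning_pixel_values enter the forward pass in train_controlnet.py (simplified): the trainable ControlNet consumes the conditioning image at pixel resolution and emits residuals that are added into the frozen UNet.

```python
# Paraphrased from the train_controlnet.py training step (simplified).
controlnet_image = batch["conditioning_pixel_values"]  # pixel space, in [0, 1]

# The trainable ControlNet consumes the conditioning image and emits residuals.
down_block_res_samples, mid_block_res_sample = controlnet(
    noisy_latents,
    timesteps,
    encoder_hidden_states=encoder_hidden_states,
    controlnet_cond=controlnet_image,
    return_dict=False,
)

# The frozen UNet takes those residuals as extra inputs.
model_pred = unet(
    noisy_latents,
    timesteps,
    encoder_hidden_states=encoder_hidden_states,
    down_block_additional_residuals=down_block_res_samples,
    mid_block_additional_residual=mid_block_res_sample,
).sample
```

Note also that `conditioning_image_transforms` above deliberately omits `Normalize`, so the conditioning image stays in [0, 1] rather than being mapped to [-1, 1] like `pixel_values`.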
