DreamBooth on Stable Diffusion

1.DreamBooth: Fine Tuning Text-to-Image Diffusion Models for Subject-Driven Generation
2.[Paper walkthrough] DreamBooth: Fine Tuning Text-to-Image Diffusion Models… [2208.12242]
3.sd_dreambooth_extension
4.stable-dreambooth

To fine-tune SD, 3-5 images of a specific subject are noised and fed into the model together with a prompt composed of a unique identifier and a class noun: the identifier binds to the specific subject, while the class noun inherits that class's features from the pre-trained SD. Naive fine-tuning, however, causes two problems:
(1) SD overfits to the specific subject, weakening the pre-trained model's ability to generate diverse instances of the class.
(2) SD's semantic understanding drifts: suppose the specific subject is a Corgi; after fine-tuning, when we input "dog" (the class noun), the model takes "dog" to mean only Corgis.
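
Concretely, the two prompts look like this (a sketch using the identifier "Dinger" that the training config later in this post uses; [V] is the placeholder it substitutes):

identifier = "Dinger"  # unique identifier, bound to the specific subject
class_noun = "dog"     # class noun, inheriting the pre-trained prior for dogs

instance_prompt = f"photo of a {identifier} {class_noun}"  # paired with the 3-5 subject photos
class_prompt = f"photo of a {class_noun}"                  # paired with the prior-preservation images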

How does the paper solve these two problems?
It introduces a Class-Specific Prior Preservation Loss, which acts as a regularization term correcting the fine-tuned SD.
The full loss therefore has two parts: the reconstruction loss, i.e., the ordinary fine-tuning loss, and the Class-Specific Prior Preservation Loss, which counteracts the problems fine-tuning introduces.
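
In the paper's notation the combined objective (its Eq. 1) is, with x the subject photos, x_pr the prior images generated by the frozen model, c and c_pr the corresponding text conditionings, and λ the weight of the prior-preservation term:

\mathbb{E}_{x, c, \epsilon, \epsilon', t}\left[\, w_t \left\| \hat{x}_\theta(\alpha_t x + \sigma_t \epsilon,\, c) - x \right\|_2^2 \;+\; \lambda\, w_{t'} \left\| \hat{x}_\theta(\alpha_{t'} x_{\mathrm{pr}} + \sigma_{t'} \epsilon',\, c_{\mathrm{pr}}) - x_{\mathrm{pr}} \right\|_2^2 \,\right]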

How does the correction work?
Before fine-tuning, the pre-trained SD is first used to generate a diverse set of images of the subject's class (dogs of all kinds). During fine-tuning these images are noised and passed through the model as well, and the outputs are penalized against the originally generated images; the pre-trained samples thus supervise the fine-tuned model and preserve the diversity of its outputs.
The fine-tuning prompt keeps the same class noun as the pre-trained prompt, effectively telling the fine-tuned SD that "dog" does not just mean a dog lying down; dogs also run, jump, and so on, just as in the pre-trained model's prior.


In the paper's comparison figure:
Row 1 shows the dataset used for fine-tuning.
Row 2 shows the result with only the reconstruction loss: given input photos of a dog lying down, the fine-tuned model outputs only dogs lying down.
Row 3 shows the result after adding the regularization term: the outputs include standing and sitting dogs, not just lying ones.

1.Prepare the Data
Before fine-tuning, we use the pre-trained SD model to generate 100 class images that will serve as the supervision signal during fine-tuning.


The class folder holds the 100 images just generated.

Code from: sample.py

from diffusers.pipelines import StableDiffusionPipeline
import torch
from torchvision import transforms
import os
from pathlib import Path

# Define the number of samples to generate and the batch size for processing
sample_nums = 100
batch_size = 16

# Define the prompt for image generation
prompt = "a photo of dog"

# Define the directory where generated images will be saved
save_dir = Path("./data/dogs/class")
save_dir.mkdir(parents=True, exist_ok=True)  # Create the directory if it doesn't exist

if __name__ == "__main__":
    # Define the model ID and device to use
    model_id = "CompVis/stable-diffusion-v1-4"
    device = "cuda"

    # Load the pre-trained Stable Diffusion model from Hugging Face and move it to the specified device
    model = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token=True, cache_dir="./.cache").to(device)

    # Create a list of prompts, each repeated to match the number of samples needed
    datasets = [prompt] * sample_nums
    # Split the list of prompts into batches of the specified batch size
    datasets = [datasets[x:x + batch_size] for x in range(0, sample_nums, batch_size)]

    # Initialize an identifier for naming the saved images
    id = 0
    # Iterate over each batch of prompts
    for text in datasets:
        print(f"Processing batch: {text}")
        # Generate images for the current batch of prompts without computing gradients
        with torch.no_grad():
            output = model(text, height=512, width=512, num_inference_steps=50)
        images = output.images  # Get images list from output object
        print(f"Generated {len(images)} images")
        # Save each generated image to the specified directory
        for image in images:
            if isinstance(image, torch.Tensor):
                image = transforms.ToPILImage()(image)  # Convert tensor to PIL image if needed
            image_path = save_dir / f"{id}.png"
            image.save(image_path)
            print(f"Saved image to {image_path}")
            id += 1
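
With a Hugging Face token already stored via huggingface-cli login (and assuming the GPU fits batches of 16 images at 512×512; reduce batch_size if not), running python sample.py fills ./data/dogs/class with 0.png through 99.png.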

2.Prepare Instance Images
We prepare a few photos of the specific subject (the pictures below show our family's Border Collie) to fine-tune on.

These subject photos are saved in the instance folder.
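
train.py below imports TrainDataset from a dataset.py that is not reproduced here. The following is a minimal sketch of what it plausibly looks like, inferred from how train.py unpacks each batch into (instance_imgs, instance_prompt, class_imgs, class_prompt); the folder names instance and class match the layout above, everything else is an assumption:

from pathlib import Path
from PIL import Image
import torch
from torchvision import transforms

class TrainDataset(torch.utils.data.Dataset):
    # Hypothetical reconstruction of the repo's dataset.py; the real code may differ.
    def __init__(self, data_path, instance_prompt, class_prompt, image_size):
        self.instance_images = sorted(Path(data_path, "instance").iterdir())
        self.class_images = sorted(Path(data_path, "class").iterdir())
        self.instance_prompt = instance_prompt
        self.class_prompt = class_prompt
        # Resize and scale to [-1, 1], the range the SD VAE expects
        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ])

    def __len__(self):
        # Iterate over the class images; the few instance images repeat
        return len(self.class_images)

    def __getitem__(self, idx):
        instance_img = Image.open(self.instance_images[idx % len(self.instance_images)]).convert("RGB")
        class_img = Image.open(self.class_images[idx]).convert("RGB")
        return (self.transform(instance_img), self.instance_prompt,
                self.transform(class_img), self.class_prompt)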

3.Fine-Tuning the Model:
(1)Load the pre-trained Stable Diffusion model.
(2)Use the DreamBooth training script to fine-tune the model with your reference images.
(3)During training, the model learns to associate the reference images with the specific textual description provided.

Code from: train.py
Unlike the DreamBooth paper, which computes two separately weighted losses — the Reconstruction Loss and the Class-Specific Prior Preservation Loss — this fine-tuning script feeds the instance and class images through the UNet together and computes a single MSE loss over the whole batch, which may be why its training results are not as good. A sketch of the paper-style weighted loss follows the script.

import dataclasses
import json
import os
from dataclasses import dataclass
from typing import List
from dataset import TrainDataset
import torch
import torch.nn.functional as F
from accelerate import Accelerator
from diffusers.schedulers import DDPMScheduler, LMSDiscreteScheduler
from diffusers.pipelines import StableDiffusionPipeline
from PIL import Image
from tqdm.auto import tqdm
from torchvision import transforms
from datasets import load_dataset
from pathlib import Path


@dataclass
class TrainingConfig:
    # Task specific parameters
    instance_prompt: str = "photo of a [V] dog"
    class_prompt: str = "photo of a dog"
    evaluate_prompt = ["photo of a [V] dog"] * 4 + ["photo of a [V] dog in a doghouse"] * 4 + ["photo of a [V] dog in a bucket"] * 4 + ["photo of a sleeping [V] dog"] * 4
    data_path: str = "./data/dogs"
    identifier: str = "Dinger"

    # Basic Training Parameters
    num_epochs: int = 1
    train_batch_size: int = 2
    learning_rate: float = 1e-5
    image_size: int = 512  # the generated image resolution
    gradient_accumulation_steps: int = 1

    # Hyperparameters for diffusion models
    num_train_timesteps: int = 1000
    train_guidance_scale: float = 1  # guidance scale at training
    sample_guidance_scale: float = 7.5  # guidance scale at inference

    # Practical Training Settings
    mixed_precision: str = 'fp16'  # `no` for float32, `fp16` for automatic mixed precision
    save_image_epochs: int = 1
    save_model_epochs: int = 1
    output_dir: str = './logs/dog_finetune'
    overwrite_output_dir: bool = True  # overwrite the old model when re-running the notebook
    seed: int = 42

    def __post_init__(self):
        self.instance_prompt = self.instance_prompt.replace("[V]", self.identifier)
        self.evaluate_prompt = [s.replace("[V]", self.identifier) for s in self.evaluate_prompt]


def pred(model, noisy_latent, time_steps, prompt, guidance_scale):
    batch_size = noisy_latent.shape[0]
    text_input = model.tokenizer(
        prompt,
        padding="max_length",
        max_length=model.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        text_embeddings = model.text_encoder(text_input.input_ids.to(model.device))[0]

    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0
    # get unconditional embeddings for classifier free guidance
    if do_classifier_free_guidance:
        max_length = text_input.input_ids.shape[-1]
        uncond_input = model.tokenizer(
            [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
        )
        uncond_embeddings = model.text_encoder(uncond_input.input_ids.to(model.device))[0]
        # For classifier free guidance, we need to do two forward passes.
        # Here we concatenate the unconditional and text embeddings into a single batch
        # to avoid doing two forward passes
        text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

    latent_model_input = torch.cat([noisy_latent] * 2) if do_classifier_free_guidance else noisy_latent
    time_steps = torch.cat([time_steps] * 2) if do_classifier_free_guidance else time_steps
    noise_pred = model.unet(latent_model_input, time_steps, encoder_hidden_states=text_embeddings)["sample"]

    # perform guidance
    if do_classifier_free_guidance:
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

    return noise_pred


def train_loop(config: TrainingConfig, model: StableDiffusionPipeline, noise_scheduler, optimizer, train_dataloader):
    # Initialize accelerator and tensorboard logging
    accelerator = Accelerator(
        mixed_precision=config.mixed_precision,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
    )
    if accelerator.is_main_process:
        accelerator.init_trackers("train_example")

    # Prepare everything
    # There is no specific order to remember, you just need to unpack the
    # objects in the same order you gave them to the prepare method.
    model, optimizer, train_dataloader = accelerator.prepare(model, optimizer, train_dataloader)

    global_step = 0

    # Now you train the model
    for epoch in range(config.num_epochs):
        progress_bar = tqdm(total=len(train_dataloader), disable=not accelerator.is_local_main_process)
        progress_bar.set_description(f"Epoch {epoch}")

        for step, batch in enumerate(train_dataloader):
            instance_imgs, instance_prompt, class_imgs, class_prompt = batch
            imgs = torch.cat((instance_imgs, class_imgs), dim=0)
            prompt = instance_prompt + class_prompt

            # Sample noise to add to the images
            bs = imgs.shape[0]

            # Sample a random timestep for each image
            timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bs,), device=accelerator.device).long()

            # Add noise to the clean images according to the noise magnitude at each timestep
            # (this is the forward diffusion process)
            with torch.no_grad():
                latents = model.vae.encode(imgs).latent_dist.sample() * 0.18215
                noise = torch.randn(latents.shape, device=accelerator.device)
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            with accelerator.accumulate(model):
                # Predict the noise residual
                noise_pred = pred(model, noisy_latents, timesteps, prompt, guidance_scale=config.train_guidance_scale)
                loss = F.mse_loss(noise_pred, noise)
                accelerator.backward(loss)
                accelerator.clip_grad_norm_(model.unet.parameters(), 1.0)
                optimizer.step()
                optimizer.zero_grad()

            # Clear GPU cache
            torch.cuda.empty_cache()

            progress_bar.update(1)
            logs = {"loss": loss.detach().item(), "step": global_step}
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)
            global_step += 1

        # After each epoch you optionally sample some demo images with evaluate() and save the model
        if accelerator.is_main_process:
            if epoch % config.save_image_epochs == 0 or epoch == config.num_epochs - 1:
                evaluate(config, epoch, model)
            if (epoch + 1) % config.save_model_epochs == 0 or epoch == config.num_epochs - 1:
                model.save_pretrained(config.output_dir)


def make_grid(images, rows, cols):
    w, h = images[0].size
    grid = Image.new('RGB', size=(cols * w, rows * h))
    for i, image in enumerate(images):
        grid.paste(image, box=(i % cols * w, i // cols * h))
    return grid


def evaluate(config: TrainingConfig, epoch, pipeline: StableDiffusionPipeline):
    # Sample some images from random noise (this is the backward diffusion process).
    # The default pipeline output type is `List[PIL.Image]`
    with torch.no_grad():
        with torch.autocast("cuda"):
            # note: older diffusers versions keyed the output as "sample"; newer ones expose .images
            images = pipeline(config.evaluate_prompt, num_inference_steps=50, width=config.image_size,
                              height=config.image_size, guidance_scale=config.sample_guidance_scale)["sample"]

    # Make a grid out of the images
    image_grid = make_grid(images, rows=4, cols=4)

    # Save the images
    test_dir = os.path.join(config.output_dir, "samples")
    os.makedirs(test_dir, exist_ok=True)
    image_grid.save(f"{test_dir}/{epoch:04d}.jpg")


def get_dataloader(config: TrainingConfig):
    dataset = TrainDataset(config.data_path, config.instance_prompt, config.class_prompt, config.image_size)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=config.train_batch_size, shuffle=True,
                                             drop_last=True, pin_memory=True)
    return dataloader


if __name__ == "__main__":
    config = TrainingConfig()
    output_dir = Path(config.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    with open(output_dir / "config.json", "w") as f:
        json.dump(dataclasses.asdict(config), f)

    model_id = "CompVis/stable-diffusion-v1-4"
    device = "cuda"
    try:
        model = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token=True, cache_dir="./.cache").to(device)
    except Exception as e:
        print(e)
        print("Run 'huggingface-cli login' to store auth token.")
        exit(1)

    train_dataloader = get_dataloader(config)
    optimizer = torch.optim.AdamW(model.unet.parameters(), lr=config.learning_rate)
    noise_scheduler = DDPMScheduler(num_train_timesteps=config.num_train_timesteps, beta_start=0.00085, beta_end=0.0120)
    train_loop(config, model, noise_scheduler, optimizer, train_dataloader)
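
For reference, here is a minimal sketch of the paper-style two-term loss — not this repo's code. prior_loss_weight plays the role of the paper's λ, and since train_loop builds each batch as torch.cat((instance_imgs, class_imgs), dim=0), the prediction and target can simply be split back in half and weighted separately:

import torch
import torch.nn.functional as F

def prior_preservation_loss(noise_pred, noise, prior_loss_weight=1.0):
    # First half of the batch holds the instance images, second half the class
    # images, matching torch.cat((instance_imgs, class_imgs), dim=0) in train_loop.
    noise_pred_inst, noise_pred_class = noise_pred.chunk(2, dim=0)
    noise_inst, noise_class = noise.chunk(2, dim=0)
    # Reconstruction loss on the subject plus the weighted prior-preservation loss
    return (F.mse_loss(noise_pred_inst, noise_inst)
            + prior_loss_weight * F.mse_loss(noise_pred_class, noise_class))

Replacing loss = F.mse_loss(noise_pred, noise) in train_loop with loss = prior_preservation_loss(noise_pred, noise) would recover the weighting the paper describes.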

4.Generate Images:
(1)After training, you can generate images using the fine-tuned model by providing text prompts that describe the desired output.
(2)The model should now generate images that reflect the characteristics of the reference images.

Code from: inference.py

from pathlib import Path
from diffusers.pipelines import StableDiffusionPipeline
import torch
from argparse import ArgumentParser
import json


def parse_args():
    parser = ArgumentParser()
    parser.add_argument("--prompt", required=True)
    parser.add_argument("--checkpoint_dir", required=True)
    parser.add_argument("--save_dir", default="outputs")
    parser.add_argument("--sample_nums", type=int, default=16)  # type=int so a CLI-supplied value is not left as a string
    parser.add_argument("-gs", "--guidance_scale", type=float, default=7.5)
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()

    # Load the training config so [V] can be replaced with the learned identifier
    with open(Path(args.checkpoint_dir) / 'config.json') as f:
        config = json.loads(f.read())
    args.prompt = args.prompt.replace("[V]", config["identifier"])

    device = "cuda"
    model = StableDiffusionPipeline.from_pretrained(args.checkpoint_dir).to(device)

    with torch.no_grad():
        with torch.autocast("cuda"):
            output = model([args.prompt] * args.sample_nums, height=512, width=512,
                           guidance_scale=args.guidance_scale, num_inference_steps=50)
            images = output['images']  # Update this based on the actual output key

    save_dir = Path(args.save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)
    for i, image in enumerate(images):
        image.save(save_dir / f'{i}.jpg')
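
A hypothetical invocation (prompt and paths are examples; [V] is replaced with the identifier stored in the checkpoint's config.json):

python inference.py --prompt "photo of a [V] dog in a bucket" --checkpoint_dir ./logs/dog_finetune --save_dir outputs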
