Dreambooth on Stable Diffusion

1. DreamBooth: Fine Tuning Text-to-Image Diffusion Models for Subject-Driven Generation
2. [Paper walkthrough] DreamBooth: Fine Tuning Text-to-Image Diffusion Models…[2208.12242]
3. sd_dreambooth_extension
4. stable-dreambooth

To fine-tune SD, 3-5 images of a specific subject are noised and fed into the model together with a prompt composed of a unique identifier and a class noun: the identifier binds to the specific subject, while the class noun inherits the features of the corresponding class already learned by the pre-trained SD. Naive fine-tuning like this, however, causes two problems:
(1) SD overfits to the specific subject, weakening the diversity the pre-trained SD had when generating that class of subject.
(2) SD suffers semantic drift: suppose the specific subject is a Corgi; after fine-tuning, when we prompt the model with dog (the class noun), it takes dog to mean only Corgis.

How are these two problems solved?
The paper introduces the Class-Specific Prior Preservation Loss, which acts as a regularization term that corrects the fine-tuned SD.
The overall loss therefore has two parts: the Reconstruction Loss, i.e. the ordinary fine-tuning loss, and the Class-Specific Prior Preservation Loss, which counteracts the problems fine-tuning introduces.
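In the notation of the paper (Eq. 2 of arXiv:2208.12242), the combined objective is, roughly:

\mathbb{E}_{x, c, \epsilon, \epsilon', t}\Big[ w_t \big\| \hat{x}_\theta(\alpha_t x + \sigma_t \epsilon,\, c) - x \big\|_2^2 \;+\; \lambda\, w_{t'} \big\| \hat{x}_\theta(\alpha_{t'} x_{\text{pr}} + \sigma_{t'} \epsilon',\, c_{\text{pr}}) - x_{\text{pr}} \big\|_2^2 \Big]

where the first term is the Reconstruction Loss on the subject images x with prompt c, and the second is the Prior Preservation Loss on class images x_pr generated by the frozen pre-trained model from the class prompt c_pr, weighted by λ.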

How does the correction work?
Before fine-tuning SD, the pre-trained SD first generates a diverse set of images of the class (all kinds of dogs). These images are noised and fed into the model during fine-tuning, and a loss is computed between the model's outputs and the previously generated images. The pre-trained samples thus supervise the fine-tuned model and preserve the diversity of its outputs.
The fine-tuning prompt also contains the class noun from the pre-trained prompt, which in effect tells the fine-tuned SD that "dog" here does not only mean a dog lying down; like the dogs described in the pre-trained SD's prompts, it can also run, jump, and so on.
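Concretely, the prompt pair used by train.py below looks like this (the [V] placeholder is replaced with the chosen identifier at runtime):

instance_prompt = "photo of a [V] dog"  # unique identifier + class noun: binds to the subject
class_prompt    = "photo of a dog"      # class noun only: preserves the pre-trained prior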


[Figure omitted] Row 1: the dataset used for fine-tuning.
Row 2: results with only the reconstruction loss — given input images of a dog lying down, the fine-tuned model only outputs dogs lying down.
Row 3: results with the regularization term added — the outputs include standing and sitting dogs, not merely lying ones.

1. Prepare the data
Before fine-tuning, we use the pre-trained SD model to generate 100 images; they serve as the supervision signal while fine-tuning the model.


The class folder holds the 100 images just generated.

Code from: sample.py

from diffusers.pipelines import StableDiffusionPipeline
import torch
from torchvision import transforms
import os
from pathlib import Path

# Define the number of samples to generate and the batch size for processing
sample_nums = 100
batch_size = 16

# Define the prompt for image generation
prompt = "a photo of dog"

# Define the directory where generated images will be saved
save_dir = Path("./data/dogs/class")
save_dir.mkdir(parents=True, exist_ok=True)  # Create the directory if it doesn't exist

if __name__ == "__main__":
    # Define the model ID and device to use
    model_id = "CompVis/stable-diffusion-v1-4"
    device = "cuda"

    # Load the pre-trained Stable Diffusion model from Hugging Face and move it to the specified device
    model = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token=True, cache_dir="./.cache").to(device)

    # Create a list of prompts, each repeated to match the number of samples needed
    datasets = [prompt] * sample_nums
    # Split the list of prompts into batches of the specified batch size
    datasets = [datasets[x:x + batch_size] for x in range(0, sample_nums, batch_size)]

    # Initialize an identifier for naming the saved images
    id = 0
    # Iterate over each batch of prompts
    for text in datasets:
        print(f"Processing batch: {text}")
        # Generate images for the current batch of prompts without computing gradients
        with torch.no_grad():
            output = model(text, height=512, width=512, num_inference_steps=50)
        images = output.images  # Get the list of images from the pipeline output
        print(f"Generated {len(images)} images")
        # Save each generated image to the specified directory
        for image in images:
            if isinstance(image, torch.Tensor):
                image = transforms.ToPILImage()(image)  # Convert tensor to PIL image if needed
            image_path = save_dir / f"{id}.png"
            image.save(image_path)
            print(f"Saved image to {image_path}")
            id += 1
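Assuming the script is saved as sample.py, running python sample.py populates ./data/dogs/class with 0.png through 99.png; the first run also downloads the v1-4 weights into ./.cache.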

2. Prepare the instance images
We prepare a few photos of the specific subject (the picture below shows our Border Collie) for fine-tuning.

These subject photos are saved in the instance folder.
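Putting the two folders together, the layout that the TrainDataset in train.py presumably expects looks like this (a sketch; only data_path = "./data/dogs" is fixed by the config):

./data/dogs
├── class/      # 100 class images generated by the pre-trained SD
└── instance/   # a handful of photos of the specific subject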

3. Fine-Tuning the Model:
(1) Load the pre-trained Stable Diffusion model.
(2) Use the DreamBooth training script to fine-tune the model with your reference images.
(3) During training, the model learns to associate the reference images with the specific textual description provided.

Code from: train.py
Unlike the DreamBooth paper, which computes the Reconstruction Loss and the Class-Specific Prior Preservation Loss as two separately weighted terms, this fine-tuning script feeds the instance and class images through the UNet together and computes a single MSE loss over the whole batch. This simplification may be part of why the training results are not as good; a sketch of the paper-style alternative is shown below.
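For comparison, a minimal sketch of the paper-style weighted loss as it could replace the single F.mse_loss inside the training step below, assuming the batch stays ordered as [instance, class] with equal halves; prior_loss_weight (the paper's λ) is a name introduced here, not part of the repo:

# Hypothetical paper-style loss: split the prediction back into instance/class halves.
noise_pred_instance, noise_pred_class = noise_pred.chunk(2, dim=0)
noise_instance, noise_class = noise.chunk(2, dim=0)
reconstruction_loss = F.mse_loss(noise_pred_instance, noise_instance)
prior_loss = F.mse_loss(noise_pred_class, noise_class)       # Class-Specific Prior Preservation Loss
loss = reconstruction_loss + prior_loss_weight * prior_loss  # prior_loss_weight ≈ 1.0 (lambda)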

import dataclasses
import json
import os
from dataclasses import dataclass
from typing import List

from dataset import TrainDataset
import torch
import torch.nn.functional as F
from accelerate import Accelerator
from diffusers.schedulers import DDPMScheduler, LMSDiscreteScheduler
from diffusers.pipelines import StableDiffusionPipeline
from PIL import Image
from tqdm.auto import tqdm
from torchvision import transforms
from datasets import load_dataset
from pathlib import Path


@dataclass
class TrainingConfig:
    # Task-specific parameters
    instance_prompt: str = "photo of a [V] dog"
    class_prompt: str = "photo of a dog"
    evaluate_prompt = ["photo of a [V] dog"] * 4 + ["photo of a [V] dog in a doghouse"] * 4 + ["photo of a [V] dog in a bucket"] * 4 + ["photo of a sleeping [V] dog"] * 4
    data_path: str = "./data/dogs"
    identifier: str = "Dinger"

    # Basic training parameters
    num_epochs: int = 1
    train_batch_size: int = 2
    learning_rate: float = 1e-5
    image_size: int = 512  # the generated image resolution
    gradient_accumulation_steps: int = 1

    # Hyperparameters for diffusion models
    num_train_timesteps: int = 1000
    train_guidance_scale: float = 1  # guidance scale at training
    sample_guidance_scale: float = 7.5  # guidance scale at inference

    # Practical training settings
    mixed_precision: str = 'fp16'  # `no` for float32, `fp16` for automatic mixed precision
    save_image_epochs: int = 1
    save_model_epochs: int = 1
    output_dir: str = './logs/dog_finetune'
    overwrite_output_dir: bool = True  # overwrite the old model when re-running the notebook
    seed: int = 42

    def __post_init__(self):
        self.instance_prompt = self.instance_prompt.replace("[V]", self.identifier)
        self.evaluate_prompt = [s.replace("[V]", self.identifier) for s in self.evaluate_prompt]


def pred(model, noisy_latent, time_steps, prompt, guidance_scale):
    batch_size = noisy_latent.shape[0]
    text_input = model.tokenizer(
        prompt,
        padding="max_length",
        max_length=model.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        text_embeddings = model.text_encoder(text_input.input_ids.to(model.device))[0]

    # Here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier-free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0
    # Get unconditional embeddings for classifier-free guidance
    if do_classifier_free_guidance:
        max_length = text_input.input_ids.shape[-1]
        uncond_input = model.tokenizer(
            [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
        )
        uncond_embeddings = model.text_encoder(uncond_input.input_ids.to(model.device))[0]
        # For classifier-free guidance, we need to do two forward passes.
        # Here we concatenate the unconditional and text embeddings into a single batch
        # to avoid doing two forward passes.
        text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

    latent_model_input = torch.cat([noisy_latent] * 2) if do_classifier_free_guidance else noisy_latent
    time_steps = torch.cat([time_steps] * 2) if do_classifier_free_guidance else time_steps
    noise_pred = model.unet(latent_model_input, time_steps, encoder_hidden_states=text_embeddings)["sample"]

    # Perform guidance
    if do_classifier_free_guidance:
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

    return noise_pred


def train_loop(config: TrainingConfig, model: StableDiffusionPipeline, noise_scheduler, optimizer, train_dataloader):
    # Initialize accelerator and tensorboard logging
    accelerator = Accelerator(
        mixed_precision=config.mixed_precision,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
    )
    if accelerator.is_main_process:
        accelerator.init_trackers("train_example")

    # Prepare everything. There is no specific order to remember: you just need to
    # unpack the objects in the same order you gave them to the prepare method.
    model, optimizer, train_dataloader = accelerator.prepare(model, optimizer, train_dataloader)

    global_step = 0

    # Now you train the model
    for epoch in range(config.num_epochs):
        progress_bar = tqdm(total=len(train_dataloader), disable=not accelerator.is_local_main_process)
        progress_bar.set_description(f"Epoch {epoch}")

        for step, batch in enumerate(train_dataloader):
            instance_imgs, instance_prompt, class_imgs, class_prompt = batch
            imgs = torch.cat((instance_imgs, class_imgs), dim=0)
            prompt = instance_prompt + class_prompt

            # Sample noise to add to the images
            bs = imgs.shape[0]
            # Sample a random timestep for each image
            timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bs,), device=accelerator.device).long()

            # Add noise to the clean images according to the noise magnitude at each timestep
            # (this is the forward diffusion process)
            with torch.no_grad():
                latents = model.vae.encode(imgs).latent_dist.sample() * 0.18215
                noise = torch.randn(latents.shape, device=accelerator.device)
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            with accelerator.accumulate(model):
                # Predict the noise residual
                noise_pred = pred(model, noisy_latents, timesteps, prompt, guidance_scale=config.train_guidance_scale)
                loss = F.mse_loss(noise_pred, noise)
                accelerator.backward(loss)
                accelerator.clip_grad_norm_(model.unet.parameters(), 1.0)
                optimizer.step()
                optimizer.zero_grad()

            # Clear GPU cache
            torch.cuda.empty_cache()

            progress_bar.update(1)
            logs = {"loss": loss.detach().item(), "step": global_step}
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)
            global_step += 1

        # After each epoch you optionally sample some demo images with evaluate() and save the model
        if accelerator.is_main_process:
            if epoch % config.save_image_epochs == 0 or epoch == config.num_epochs - 1:
                evaluate(config, epoch, model)
            if (epoch + 1) % config.save_model_epochs == 0 or epoch == config.num_epochs - 1:
                model.save_pretrained(config.output_dir)


def make_grid(images, rows, cols):
    w, h = images[0].size
    grid = Image.new('RGB', size=(cols * w, rows * h))
    for i, image in enumerate(images):
        grid.paste(image, box=(i % cols * w, i // cols * h))
    return grid


def evaluate(config: TrainingConfig, epoch, pipeline: StableDiffusionPipeline):
    # Sample some images from random noise (this is the backward diffusion process).
    # The default pipeline output type is `List[PIL.Image]`
    with torch.no_grad():
        with torch.autocast("cuda"):
            images = pipeline(
                config.evaluate_prompt, num_inference_steps=50, width=config.image_size,
                height=config.image_size, guidance_scale=config.sample_guidance_scale
            )["sample"]

    # Make a grid out of the images
    image_grid = make_grid(images, rows=4, cols=4)

    # Save the images
    test_dir = os.path.join(config.output_dir, "samples")
    os.makedirs(test_dir, exist_ok=True)
    image_grid.save(f"{test_dir}/{epoch:04d}.jpg")


def get_dataloader(config: TrainingConfig):
    dataset = TrainDataset(config.data_path, config.instance_prompt, config.class_prompt, config.image_size)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=config.train_batch_size, shuffle=True,
                                             drop_last=True, pin_memory=True)
    return dataloader


if __name__ == "__main__":
    config = TrainingConfig()
    output_dir = Path(config.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    with open(output_dir / "config.json", "w") as f:
        json.dump(dataclasses.asdict(config), f)

    model_id = "CompVis/stable-diffusion-v1-4"
    device = "cuda"
    try:
        model = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token=True, cache_dir="./.cache").to(device)
    except Exception as e:
        print(e)
        print("Run 'huggingface-cli login' to store auth token.")
        exit(1)

    train_dataloader = get_dataloader(config)
    optimizer = torch.optim.AdamW(model.unet.parameters(), lr=config.learning_rate)
    noise_scheduler = DDPMScheduler(num_train_timesteps=config.num_train_timesteps, beta_start=0.00085, beta_end=0.0120)
    train_loop(config, model, noise_scheduler, optimizer, train_dataloader)
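The script can be launched directly with python train.py; since it builds an Accelerator, accelerate launch train.py is the more idiomatic entry point (assuming accelerate config has been run once). Checkpoints are written to ./logs/dog_finetune and 4×4 sample grids to ./logs/dog_finetune/samples.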

4. Generate Images:
(1) After training, you can generate images with the fine-tuned model by providing text prompts that describe the desired output.
(2) The model should now generate images that reflect the characteristics of the reference images.

Code from: inference.py

from pathlib import Path
from diffusers.pipelines import StableDiffusionPipeline
import torch
from argparse import ArgumentParser
import json


def parse_args():
    parser = ArgumentParser()
    parser.add_argument("--prompt", required=True)
    parser.add_argument("--checkpoint_dir", required=True)
    parser.add_argument("--save_dir", default="outputs")
    parser.add_argument("--sample_nums", type=int, default=16)  # type=int so CLI-supplied values parse correctly
    parser.add_argument("-gs", "--guidance_scale", type=float, default=7.5)
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    # Recover the identifier chosen at training time from the saved config
    with open(Path(args.checkpoint_dir) / 'config.json') as f:
        config = json.loads(f.read())
    args.prompt = args.prompt.replace("[V]", config["identifier"])

    device = "cuda"
    model = StableDiffusionPipeline.from_pretrained(args.checkpoint_dir).to(device)
    with torch.no_grad():
        with torch.autocast("cuda"):
            output = model([args.prompt] * args.sample_nums, height=512, width=512,
                           guidance_scale=args.guidance_scale, num_inference_steps=50)
            images = output['images']  # Update this based on the actual output key

    save_dir = Path(args.save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)
    for i, image in enumerate(images):
        image.save(save_dir / f'{i}.jpg')
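An example invocation (the [V] placeholder in the prompt is replaced with the identifier recorded in the checkpoint's config.json):

python inference.py --prompt "photo of a [V] dog in a bucket" --checkpoint_dir ./logs/dog_finetune --save_dir outputs -gs 7.5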
