[跑代码]BK-SDM: A Lightweight, Fast, and Cheap Version of Stable Diffusion

Installation(下载代码-装环境)

conda create -n bk-sdm python=3.8
conda activate bk-sdm
git clone https://github.com/Nota-NetsPresso/BK-SDM.git
cd BK-SDM
pip install -r requirements.txt
Note on the torch versions we've used
  • torch 1.13.1 for MS-COCO evaluation & DreamBooth finetuning on a single 24GB RTX3090
     

  • torch 2.0.1 for KD pretraining on a single 80GB A10
    火炬2.0.1在单个80GB A100上进行KD预训练

    • 如果A100上总批大小为256的预训练导致gpu内存不足,请检查torch版本并考虑升级到torch>2.0.0。
      我的版本也是torch2.0.1 单个A100(80G)理论上吃的下256batch

小的例子

PNDM采样器 50步去噪声

等效代码(仅修改SD-v1.4的U-Net,同时保留其文本编码器和图像解码器):

Distillation Pretraining

Our code was based on train_text_to_image.py of Diffusers 0.15.0.dev0. To access the latest version, use this link.
BK-SDM的diffusers版本0.15
我的diffusers版本比较高0.24.0

检测是否能够训练(先下载数据集get_laion_data.sh再运行代码kd_train_toy.sh)

1 一个玩具数据集(11K的img-txt对)下载到。

bash scripts/get_laion_data.sh preprocessed_11k

/data/laion_aes/preprocessed_11k (1.7GB in tar.gz;1.8GB数据文件夹)。
get_laion_data.sh

需要修改,实际就是下载这三个数据集,我自行下载

# https://netspresso-research-code-release.s3.us-east-2.amazonaws.com/data/improved_aesthetics_6.5plus/preprocessed_11k.tar.gz
# https://netspresso-research-code-release.s3.us-east-2.amazonaws.com/data/improved_aesthetics_6.5plus/preprocessed_212k.tar.gz
# https://netspresso-research-code-release.s3.us-east-2.amazonaws.com/data/improved_aesthetics_6.5plus/preprocessed_2256k.tar.gz

我修改后下载文件名 https://... .../preprocessed_11k.tar.gz直接粘贴到网址里面也可以下载
wget $S3_URL -0 $FILe_PATH
$S3_URL 就是这个网址
$FILe_PATH 就是下载路径./data/laion_aes/preprocessed_11k

DATA_TYPE=$"preprocessed_11k"  # {preprocessed_11k, preprocessed_212k, preprocessed_2256k}
FILE_NAME="${DATA_TYPE}.tar.gz"DATA_DIR="./data/laion_aes/"
FILE_UNZIP_DIR="${DATA_DIR}${DATA_TYPE}"
FILE_PATH="${DATA_DIR}${FILE_NAME}"if [ "$DATA_TYPE" = "preprocessed_11k" ] || [ "$DATA_TYPE" = "preprocessed_212k" ]; thenecho "-> preprocessed_11k or 212k"S3_URL="https://netspresso-research-code-release.s3.us-east-2.amazonaws.com/data/improved_aesthetics_6.5plus/${FILE_NAME}"
elif [ "$DATA_TYPE" = "preprocessed_2256k" ]; thenS3_URL="https://netspresso-research-code-release.s3.us-east-2.amazonaws.com/data/improved_aesthetics_6.25plus/${FILE_NAME}"
elseecho "Something wrong in data folder name"exit
fiwget $S3_URL -O $FILE_PATH
tar -xvzf $FILE_PATH -C $DATA_DIR
echo "downloaded to ${FILE_UNZIP_DIR}"

2 一个小脚本可以用来验证代码的可执行性,并找到与你的GPU匹配的批处理大小。
批量大小为8 (=4×2),训练BK-SDM-Base 20次迭代大约需要5分钟和22GB的GPU内存。

bash scripts/kd_train_toy.sh
MODEL_NAME="CompVis/stable-diffusion-v1-4"
TRAIN_DATA_DIR="./data/laion_aes/preprocessed_11k" # please adjust it if needed
UNET_CONFIG_PATH="./src/unet_config"UNET_NAME="bk_small" # option: ["bk_base", "bk_small", "bk_tiny"]
OUTPUT_DIR="./results/toy_"$UNET_NAME # please adjust it if neededBATCH_SIZE=2
GRAD_ACCUMULATION=4StartTime=$(date +%s)CUDA_VISIBLE_DEVICES=1 accelerate launch src/kd_train_text_to_image.py \--pretrained_model_name_or_path $MODEL_NAME \--train_data_dir $TRAIN_DATA_DIR\--use_ema \--resolution 512 --center_crop --random_flip \--train_batch_size $BATCH_SIZE \--gradient_checkpointing \--mixed_precision="fp16" \--learning_rate 5e-05 \--max_grad_norm 1 \--lr_scheduler="constant" --lr_warmup_steps=0 \--report_to="all" \--max_train_steps=20 \--seed 1234 \--gradient_accumulation_steps $GRAD_ACCUMULATION \--checkpointing_steps 5 \--valid_steps 5 \--lambda_sd 1.0 --lambda_kd_output 1.0 --lambda_kd_feat 1.0 \--use_copy_weight_from_teacher \--unet_config_path $UNET_CONFIG_PATH --unet_config_name $UNET_NAME \--output_dir $OUTPUT_DIREndTime=$(date +%s)
echo "** KD training takes $(($EndTime - $StartTime)) seconds."

单GPU训练BK-SDM{Base, Small, Tiny}-0.22M数据训练
 

bash scripts/get_laion_data.sh preprocessed_212k
bash scripts/kd_train.sh

1 下载数据集preprocessed_212k
2 训练kd_train.sh
(256batch 训练BD-SM-Base 50K轮次需要300hours/53G单卡)
(64batch 训练BD-SM-Base 50K轮次需要60hours/28G单卡) 不理解?
 

单GPU训练BK-SDM{Base, Small, Tiny}-2.3M数据训练

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/186326.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

windows远程桌面登录,提示:“出现身份验证错误,要求的函数不受支持”

问题: windows登录远程桌面,提示:“出现身份验证错误,要求的函数不受支持”,如下图: 问题原因: windows系统更新,微软系统补丁的更新将 CredSSP 身份验证协议的默认设置进行了调…

Windows + docker + python + vscode : 使用容器docker搭建python开发环境,无需本地安装python开发组件

下载docker for Windows docker window下载 如果没有翻墙工具,可以该网盘中的docker 链接:https://pan.baidu.com/s/11zLy3e5kusZR-4m_Fq_cqg?pwdesmv 提取码:esmv 安装docker docker的安装会重启电脑,不要惊讶,且…

Unity 注释的方法

1、单行注释:使用双斜线(//)开始注释,后面跟注释内容。通常注释一个属性或者方法,如: //速度 public float Speed;//打印输出 private void DoSomething() {Debug.Log("运行了我"); } …

构建智能预约体验:深度解析预约系统源码的代码精髓

随着数字化时代的发展,预约系统在各行业中扮演着越来越重要的角色。本文将深入研究预约系统源码,通过代码示例分析其技术要点,为开发者提供实用的指导,助力构建智能、高效的预约体验。 技术栈综述 预约系统源码采用了现代化的技…

JAVEE初阶 多线程基础(四)

线程安全 一.线程安全存在的问题二.锁三.关于锁的理解四.关于锁操作混淆的理解4.1两个线程是否对同一对象加锁 一.线程安全存在的问题 为什么这里的count不是一百万呢?这就是线程所存在的不安全的问题,由于线程是抢占式执行,同时执行count,操作本质是三个指令 1.load 读取内存…

带大家做一个,易上手的家常炒鸡蛋

想做这道菜 先准备五个鸡蛋 然后将鸡蛋打到碗里面 然后 加小半勺盐 这个看个人喜好 放多少都没问题 不要太咸就好 将鸡蛋搅拌均匀 起锅烧油 油温热了之后 放三个干辣椒进去炒 干辣椒烧黑后 捞出来 味道就留在油里了 然后 倒入鸡蛋液 翻炒 注意翻炒 不要粘锅底 或者 一面糊…

南开大学与字节跳动研究人员推出开源AI工具ChatAnything:用文本描述生成虚拟角色

南开大学与字节跳动研究人员合作推出了一项引人注目的研究,发布了一种名为ChatAnything的全新AI框架。该框架专注于通过在线方式生成基于大型语言模型(LLM)的角色的拟人化形象,从而创造具有定制视觉外观、个性和语调的人物。 简答…

深度解析 Spring Security 自定义异常失效问题:源码剖析与解决方案

🚀 作者主页: 有来技术 🔥 开源项目: youlai-mall 🍃 vue3-element-admin 🍃 youlai-boot 🌺 仓库主页: Gitee 💫 Github 💫 GitCode 💖 欢迎点赞…

设计模式之装饰模式(2)--有意思的想法

目录 背景概述概念角色 基本代码分析❀❀花样重难点聚合关系认贼作父和认孙做父客户端的优化及好处继承到设计模式的演变过程 总结 背景 这是我第二次写装饰模式,这一次是在上一次的基础上进一步探究装饰模式,这一次有了很多新的感受和想法,也…

BUUCTF john-in-the-middle 1

BUUCTF:https://buuoj.cn/challenges 题目描述: 注意:得到的 flag 请包上 flag{} 提交 密文: 下载附件,解压得到john-in-the-middle.pcap文件。 解题思路: 1、双击文件,打开wireshark。 看到很多http流…

基于springboot实现的在线考试系统

一、系统架构 前端:html | js | css | jquery | bootstrap 后端:springboot | springdata-jpa 环境:jdk1.7 | mysql | maven 二、 代码及数据库 三、功能介绍 01. 登录页 02. 管理员端-课程管理 03. 管理员端-班级管理 04. 管理员端-老师管理…

AT89S52单片机智能寻迹小车自动红外避障趋光检测发声发光设计

wx供重浩:创享日记 对话框发送:寻迹 获取完整说明报告源程序数据 小车具有以下几个功能:自动避障功能;寻迹功能(按路面的黑色轨道行驶);趋光功能(寻找前方的点光源并行驶到位&…

C++ ini配置文件的简单读取使用

ini文件就是简单的section 下面有对应的键值对 std::map<std::string, std::map<std::string, std::string>>MyIni::readIniFile() {std::ifstream file(filename);if (!file.is_open()) {std::cerr << "Error: Unable to open file " << …

以STM32CubeMX创建DSP库工程方法一

以STM32CubeMX创建DSP库工程方法 略过时钟树的分配和UART的创建等&#xff0c;直接进入主题生成工程文件 它们中的文件功能如下&#xff1a; 1&#xff09;BasicMathFunctions 基本数学函数&#xff1a;提供浮点数的各种基本运算函数&#xff0c;如向量加减乘除等运算。 2&…

【MATLAB】EWT分解+FFT+HHT组合算法

有意向获取代码&#xff0c;请转文末观看代码获取方式~也可转原文链接获取~ 1 基本定义 EWTFFTHHT组合算法是一种广泛应用于信号处理领域的算法&#xff0c;它结合了经验小波变换&#xff08;Empirical Wavelet Transform&#xff0c;EWT&#xff09;、快速傅里叶变换&#x…

SpringBoot查询指定范围内的坐标点

使用Redis geo实现 redis geo是基于Sorted Set来实现的 Redis 3.2 版本新增了geo相关命令&#xff0c;用于存储和操作地理位置信息。提供的命令包括添加、计算位置之间距离、根据中心点坐标和距离范围来查询地理位置集合等&#xff0c;说明如下: geoadd&#xff1a;添加地理…

DCDC前馈电容与RC串并联电路

一、RC串并联电路特性分析 1、RC串联电路 RC 串联的转折频率&#xff1a; f01/&#xff08;2πR1C1&#xff09;&#xff0c;当输入信号频率大于 f0 时&#xff0c;整个 RC 串联电路总的阻抗基本不变了&#xff0c;其大小等于 R1。 2、RC并联电路 RC 并联电路的转折频率&…

02、Tensorflow实现手写数字识别(数字0-9)

02、Tensorflow实现手写数字识别&#xff08;数字0-9&#xff09; 01、Tensorflow实现二元手写数字识别&#xff08;二分类问题&#xff09; 02、Tensorflow实现手写数字识别&#xff08;数字0-9&#xff09; 开始学习机器学习啦&#xff0c;已经把吴恩达的课全部刷完了&…

zookeeper集群和kafka集群

&#xff08;一&#xff09;kafka 1、kafka3.0之前依赖于zookeeper 2、kafka3.0之后不依赖zookeeper&#xff0c;元数据由kafka节点自己管理 &#xff08;二&#xff09;zookeeper 1、zookeeper是一个开源的、分布式的架构&#xff0c;提供协调服务&#xff08;Apache项目&…

【Openstack Train安装】二、NTP安装

网络时间协议&#xff1a;Network Time Protocol&#xff08;NTP&#xff09;是用来使计算机时间同步化的一种协议&#xff0c;它可以使计算机对其服务器或时钟源&#xff08;如石英钟&#xff0c;GPS等等)做同步化&#xff0c;它可以提供高精准度的时间校正&#xff08;LAN上与…