Exponential Moving Average (EMA) in Stable Diffusion

1. Moving Average in Stable Diffusion (SMA & EMA)

References:
1. Moving average
2. 移动平均值
3. How We Trained Stable Diffusion for Less than $50k (Part 3)

Moving Average
In statistics, a moving average is a calculation used to analyze data points by creating a series of averages over different selections of the full data set.


Given a sequence of numbers and a fixed subset size, the first element of the moving average is obtained by averaging the initial fixed subset of the sequence. The subset is then modified by "shifting forward"; that is, the first number of the series is excluded and the next value is included in the subset.

This explanation of moving averages is taken from the reference 移动平均值 above.

1.1 Simple Moving Average (SMA, an unweighted MA)
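For a window of the $k$ most recent values, the simple moving average is just their unweighted mean:

$$\text{SMA}_t = \frac{1}{k}\sum_{i=0}^{k-1} x_{t-i}$$

Every point in the window carries the same weight $1/k$; when the window shifts forward, the oldest value is dropped and the newest one is added.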


1.2 Exponential Moving Average (EMA, a weighted MA)

In the context of Stable Diffusion, the Exponential Moving Average (EMA) is a technique used during the training of machine learning models, particularly neural networks, to stabilize and improve the model’s performance.

The Exponential Moving Average is a method of averaging that gives more weight to recent data points, making it more responsive to recent changes compared to a simple moving average, which treats all data points equally.
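In its general form, with smoothing factor $\alpha \in (0, 1]$ applied to the newest observation $x_t$, the EMA is updated recursively:

$$\text{EMA}_t = \alpha \cdot x_t + (1 - \alpha) \cdot \text{EMA}_{t-1}, \qquad \text{EMA}_0 = x_0$$

A larger $\alpha$ makes the average track recent values more closely; the Stable Diffusion 2 update quoted below uses $\alpha = 0.0001$ on the model weights.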

1.2.1 EMA in Stable Diffusion

In the context of Stable Diffusion, EMA is applied to the model parameters during training to create a smoothed version of the model. This is particularly useful in machine learning because the training process can be noisy, with the model parameters oscillating as they converge towards an optimal solution. By maintaining an EMA of the model parameters, the training process can benefit from the following:

  1. Smoothing: EMA smooths out the parameter updates, reducing the impact of noise and making the training process more stable.
  2. Better Generalization: The EMA version of the model often generalizes better on unseen data compared to the model with the raw parameters. This is because EMA tends to favor parameter values that are more consistent over time.
  3. Preventing Overfitting: By averaging the parameters over time, EMA can help mitigate overfitting, especially in cases where the model might otherwise converge too quickly to a suboptimal solution.

The author's own understanding:
The loss function is a function of the parameters (weights and biases), so each loss value corresponds to one set of parameter values. When the loss oscillates, the model parameters are changing as well. During SD training, the MSE loss oscillates up and down as gradient descent proceeds, and the corresponding model parameters oscillate with it. EMA can be used to take a running middle value of these oscillating parameters; this averaged set of parameters better represents the average level of the model parameters across all time steps, which gives the model better generalization.

Stable Diffusion 2 uses Exponential Moving Averaging (EMA), which maintains an exponential moving average of the weights. At every time step, the EMA model is updated by taking 0.9999 times the current EMA model plus 0.0001 times the new weights after the latest forward and backward pass. By default, the EMA algorithm is applied after every gradient update for the entire training period. However, this can be slow due to the memory operations required to read and write all the weights at every step.
Applying EMA to every parameter at every time step is expensive, because all of the model's weights must be read and written at every step.
$$\text{EMA}_t = 0.0001 \cdot x_t + 0.9999 \cdot \text{EMA}_{t-1}$$
To reduce the cost of computing the EMA, it is applied only over the final stretch of training.
To avoid this costly procedure, we start with a key observation: since the old weights are decayed by a factor of 0.9999 at every batch, the early iterations of training only contribute minimally to the final average. This means we only need to take the exponential moving average of the final few steps. Concretely, we train for 1,400,000 batches and only apply EMA for the final 50,000 steps, which is about 3.5% of the training period. The weights from the first 1,350,000 iterations decay away by (0.9999)^50000, so their aggregate contribution would have a weight of less than 1% in the final model. Using this technique, we can avoid adding overhead for 96.5% of training and still achieve a nearly equivalent EMA model.
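A quick back-of-the-envelope check of that decay factor:

$$(0.9999)^{50{,}000} = e^{50{,}000 \cdot \ln(0.9999)} \approx e^{-5.0} \approx 0.0067,$$

so the combined weight of everything learned before the final 50,000 steps is indeed well under 1% of the final EMA model.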

1.2.2 Implementation in Stable Diffusion

During the training of a diffusion model, the EMA of the model’s weights is updated alongside the regular updates. Here’s a typical process:

  1. Initialize EMA Weights: At the start of training, initialize the EMA weights to be the same as the model’s initial weights.
  2. Update During Training: After each batch update, update the EMA weights using the formula mentioned above. This requires storing a separate set of weights for the EMA.
  3. Use for Inference: At the end of the training, use the EMA weights for inference instead of the raw model weights. This is because the EMA weights represent a more stable and potentially better-performing version of the model.
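A minimal sketch of these three steps in a generic PyTorch training loop follows (the function and variable names here are illustrative, not from the Stable Diffusion codebase); the module.py and train.py excerpts below implement the same idea with a dedicated EMA helper class:

import copy
import torch

def train_with_ema(model, data_loader, optimizer, loss_fn, beta=0.9999):
    # 1. Initialize the EMA weights as a frozen copy of the model's initial weights.
    ema_model = copy.deepcopy(model).eval().requires_grad_(False)

    for inputs, targets in data_loader:
        # Regular forward/backward pass and optimizer step.
        loss = loss_fn(model(inputs), targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # 2. After each update, blend the new weights into the EMA copy.
        with torch.no_grad():
            for p_ema, p in zip(ema_model.parameters(), model.parameters()):
                p_ema.mul_(beta).add_(p, alpha=1 - beta)

    # 3. Return the EMA copy, which is the version used for inference.
    return ema_model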

1.2.3 Practical Considerations

  1. Choosing $\alpha$: The smoothing factor $\alpha$ is a hyperparameter that needs to be chosen carefully. A common practice is to set $\alpha$ based on the number of iterations or epochs, such as $\alpha = \frac{2}{N+1}$, where $N$ is the number of iterations (see the short conversion example after this list).
  2. Performance Overhead: Maintaining EMA weights requires additional memory and computational overhead, but the benefits in terms of model stability and performance often outweigh these costs.
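As an illustration (this mapping is not from the sources above), the $\alpha = \frac{2}{N+1}$ rule of thumb can be converted into the decay factor beta = 1 - alpha that the EMA class below keeps on the old weights:

def ema_decay_from_iterations(n_iterations: int) -> float:
    # Convert the alpha = 2 / (N + 1) rule of thumb into the decay factor
    # beta = 1 - alpha applied to the old weights in the EMA class below.
    alpha = 2.0 / (n_iterations + 1)
    return 1.0 - alpha

# For example (window sizes chosen purely for illustration):
#   N = 19,999  ->  beta ≈ 0.9999 (the Stable Diffusion 2 decay)
#   N = 399     ->  beta ≈ 0.995  (the decay used in train.py below)
beta = ema_decay_from_iterations(19_999)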

module.py

class EMA:
    # Initializes the EMA object with a smoothing factor (beta) and a step counter (step).
    def __init__(self, beta):
        super().__init__()
        self.beta = beta  # Smoothing factor for the exponential moving average
        self.step = 0     # Step counter to keep track of the number of updates

    # Updates the moving average of the parameters of the EMA model (ma_model)
    # based on the current model (current_model).
    def update_model_average(self, ma_model, current_model):
        # Update the moving average (EMA) of model parameters
        for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
            old_weight, up_weight = ma_params.data, current_params.data
            # Update the moving average of the parameters
            ma_params.data = self.update_average(old_weight, up_weight)

    # Computes the exponentially weighted average of the old and new parameters.
    def update_average(self, old, new):
        # Compute the updated average
        if old is None:
            return new
        return old * self.beta + (1 - self.beta) * new

    # Either resets the EMA model parameters to match the current model parameters
    # if the step count is less than step_start_ema,
    # or updates the EMA model parameters based on the current model parameters.
    # It increments the step counter after each call.
    def step_ema(self, ema_model, model, step_start_ema=2000):
        # Update EMA model parameters or reset them based on the step count
        if self.step < step_start_ema:
            self.reset_parameters(ema_model, model)
        else:
            self.update_model_average(ema_model, model)
        self.step += 1  # Increment the step counter

    # Copies the current model's parameters to the EMA model to initialize the EMA model parameters.
    def reset_parameters(self, ema_model, model):
        # Initialize EMA model parameters to be the same as the current model's parameters
        ema_model.load_state_dict(model.state_dict())

train.py

def train(args):
    device = args.device  # Get the device to run the training on
    model = UNET().to(device)  # Initialize the model and move it to the device
    model.train()
    optimizer = optim.AdamW(model.parameters(), lr=args.lr)  # Set up the optimizer with AdamW
    mse = nn.MSELoss()  # Mean Squared Error loss function
    logger = SummaryWriter(os.path.join("runs", args.run_name))
    len_train = len(train_loader)

    # EMA: Exponential Moving Average
    ema = EMA(0.995)  # Exponential Moving Average with decay rate 0.995
    # At the start of training, initialize the EMA weights to be the same as the model's initial weights.
    ema_model = copy.deepcopy(model).eval().requires_grad_(False)  # Copy of the model for EMA, in eval mode with no gradients

    print('Start into the loop !')
    for epoch in range(args.epochs):
        logging.info(f"Starting epoch {epoch}:")  # Log the start of the epoch
        progress_bar = tqdm(train_loader)  # Progress bar for the dataloader
        optimizer.zero_grad()  # Explicitly zero the gradient buffers
        accumulation_steps = 4

        # Load all data into a batch
        for batch_idx, (images, captions) in enumerate(progress_bar):
            images = images.to(device)  # Move images to the device
            # The dataloader will add a batch size dimension to the tensor, but I've already added a batch size
            # to the VAE and CLIP input, so we're going to remove a batch size and just keep the batch size of the dataloader
            images = torch.squeeze(images, dim=1)
            captions = captions.to(device)  # Move captions to the device
            text_embeddings = torch.squeeze(captions, dim=1)  # Squeeze batch_size
            timesteps = ddpm_sampler.sample_timesteps(images.shape[0]).to(device)  # Sample random timesteps
            noisy_latent_images, noises = ddpm_sampler.add_noise(images, timesteps)  # Add noise to the images
            time_embeddings = timesteps_to_time_emb(timesteps)
            # x_t     (batch_size, channel, Height/8, Width/8)  (bs, 4, 256/8, 256/8)
            # caption (batch_size, seq_len, dim)                (bs, 77, 768)
            # t       (batch_size, channel)                     (bs, 1280)
            # (bs, 320, H/8, W/8)
            with torch.no_grad():
                last_decoder_noise = model(noisy_latent_images, text_embeddings, time_embeddings)
            # (bs, 4, H/8, W/8)
            final_output = diffusion.final.to(device)
            predicted_noise = final_output(last_decoder_noise).to(device)
            loss = mse(noises, predicted_noise)  # Compute the loss
            loss.backward()  # Backpropagate the loss
            if (batch_idx + 1) % accumulation_steps == 0:  # Wait for several backward passes
                optimizer.step()  # Now we can do an optimizer step
                optimizer.zero_grad()  # Reset gradients to zero

            # EMA: Exponential Moving Average
            ema.step_ema(ema_model, model)

            progress_bar.set_postfix(MSE=loss.item())  # Update the progress bar with the loss
            # Log the loss to TensorBoard
            logger.add_scalar("MSE", loss.item(), global_step=epoch * len_train + batch_idx)

        # Save the model checkpoint
        os.makedirs(os.path.join("models", args.run_name), exist_ok=True)
        torch.save(model.state_dict(), os.path.join("models", args.run_name, f"stable_diffusion.ckpt"))
        torch.save(optimizer.state_dict(), os.path.join("models", args.run_name, f"optim.pt"))  # Save the optimizer state
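Note that the loop above only checkpoints the raw model weights. To use the EMA weights for inference, as described in 1.2.2, the EMA copy should be checkpointed as well; a minimal addition at the same point in the epoch loop (the file name here is only a suggestion):

        # Also checkpoint the EMA model so its smoothed weights can be used for inference
        torch.save(ema_model.state_dict(), os.path.join("models", args.run_name, "ema_stable_diffusion.ckpt"))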
