nanodiffusion代码逐行理解之diffusion

目录

  • 一、diffusion创建
  • 二、GaussianDiffusion定义
  • 三、代码理解
    • def __init__(self,model,img_size,img_channels,num_classes,betas, loss_type="l2", ema_decay=0.9999, ema_start=5000, ema_update_rate=1,):
    • def remove_noise(self, x, t, y, use_ema=True):
    • def sample(self, batch_size, device, y=None, use_ema=True):
    • def perturb_x(self, x, t, noise):
    • def get_losses(self, x, t, y):
    • def forward(self, x, y=None):
    • def generate_cosine_schedule(T, s=0.008)和def generate_linear_schedule(T, low, high):

一、diffusion创建

diffusion = GaussianDiffusion(model,args.img_size,args.img_channels,args.num_classes,betas,ema_decay=args.ema_decay,ema_update_rate=args.ema_update_rate,ema_start=2000,loss_type=args.loss_type,)

二、GaussianDiffusion定义

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as Ffrom functools import partial
from copy import deepcopyfrom ddpm.ema import EMA
from ddpm.utils import extractclass GaussianDiffusion(nn.Module):def __init__(self,model,img_size,img_channels,num_classes,betas,loss_type="l2",ema_decay=0.9999,ema_start=5000,ema_update_rate=1,):super().__init__()self.model = modelself.ema_model = deepcopy(model)self.ema = EMA(ema_decay)self.ema_decay = ema_decayself.ema_start = ema_startself.ema_update_rate = ema_update_rateself.step = 0self.img_size = img_sizeself.img_channels = img_channelsself.num_classes = num_classesif loss_type not in ["l1", "l2"]:raise ValueError("__init__() got unknown loss type")self.loss_type = loss_typeself.num_timesteps = len(betas)alphas = 1.0 - betasalphas_cumprod = np.cumprod(alphas)to_torch = partial(torch.tensor, dtype=torch.float32)self.register_buffer("betas", to_torch(betas))self.register_buffer("alphas", to_torch(alphas))self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod)))self.register_buffer("sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1 - alphas_cumprod)))self.register_buffer("reciprocal_sqrt_alphas", to_torch(np.sqrt(1 / alphas)))self.register_buffer("remove_noise_coeff", to_torch(betas / np.sqrt(1 - alphas_cumprod)))self.register_buffer("sigma", to_torch(np.sqrt(betas)))def update_ema(self):self.step += 1if self.step % self.ema_update_rate == 0:if self.step < self.ema_start:self.ema_model.load_state_dict(self.model.state_dict())else:self.ema.update_model_average(self.ema_model, self.model)@torch.no_grad()def remove_noise(self, x, t, y, use_ema=True):if use_ema:return ((x - extract(self.remove_noise_coeff, t, x.shape) * self.ema_model(x, t, y)) *extract(self.reciprocal_sqrt_alphas, t, x.shape))else:return ((x - extract(self.remove_noise_coeff, t, x.shape) * self.model(x, t, y)) *extract(self.reciprocal_sqrt_alphas, t, x.shape))@torch.no_grad()def sample(self, batch_size, device, y=None, use_ema=True):if y is not None and batch_size != len(y):raise ValueError("sample batch size different from length of given y")x = torch.randn(batch_size, self.img_channels, *self.img_size, device=device)for t in range(self.num_timesteps - 1, -1, -1):t_batch = torch.tensor([t], device=device).repeat(batch_size)x = self.remove_noise(x, t_batch, y, use_ema)if t > 0:x += extract(self.sigma, t_batch, x.shape) * torch.randn_like(x)return x.cpu().detach()@torch.no_grad()def sample_diffusion_sequence(self, batch_size, device, y=None, use_ema=True):if y is not None and batch_size != len(y):raise ValueError("sample batch size different from length of given y")x = torch.randn(batch_size, self.img_channels, *self.img_size, device=device)for t in range(self.num_timesteps - 1, -1, -1):t_batch = torch.tensor([t], device=device).repeat(batch_size)x = self.remove_noise(x, t_batch, y, use_ema)if t > 0:x += extract(self.sigma, t_batch, x.shape) * torch.randn_like(x)yield x.cpu().detach()def perturb_x(self, x, t, noise):return (extract(self.sqrt_alphas_cumprod, t, x.shape) * x +extract(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * noise)   def get_losses(self, x, t, y):noise = torch.randn_like(x)perturbed_x = self.perturb_x(x, t, noise)estimated_noise = self.model(perturbed_x, t, y)if self.loss_type == "l1":loss = F.l1_loss(estimated_noise, noise)elif self.loss_type == "l2":loss = F.mse_loss(estimated_noise, noise)return lossdef forward(self, x, y=None):b, c, h, w = x.shapedevice = x.deviceif h != self.img_size[0]:raise ValueError("image height does not match diffusion parameters")if w != self.img_size[0]:raise ValueError("image width does not match diffusion parameters")t = torch.randint(0, self.num_timesteps, (b,), device=device)return self.get_losses(x, t, y)def generate_cosine_schedule(T, s=0.008):def f(t, T):return (np.cos((t / T + s) / (1 + s) * np.pi / 2)) ** 2alphas = []f0 = f(0, T)for t in range(T + 1):alphas.append(f(t, T) / f0)betas = []for t in range(1, T + 1):betas.append(min(1 - alphas[t] / alphas[t - 1], 0.999))return np.array(betas)def generate_linear_schedule(T, low, high):return np.linspace(low, high, T)

三、代码理解

Input:
x: (N, img_channels, *img_size)
y: (N)
Output:
scalar loss tensor
Args:
model (nn.Module):估计高斯噪声的模型
img_size (tuple): (H, W)
img_channels (int): 图像通道数
betas (np.ndarray): diffusion betas 数组
loss_type (string): loss type, “l1” or “l2” 类型
ema_decay (float): model weights exponential moving average decay
ema_start (int): number of steps before EMA
ema_update_rate (int): number of steps before each EMA update
“”"

def init(self,model,img_size,img_channels,num_classes,betas, loss_type=“l2”, ema_decay=0.9999, ema_start=5000, ema_update_rate=1,):

在这里插入图片描述
np.cumprod返回数组沿指定轴的累计积。
a=[a1,a2,a3,a4,a5]
np.cumprod(a)=array([a1,a1a2,a1a2a3,a1a2a3a4,a1a2a3a4a5])。

def remove_noise(self, x, t, y, use_ema=True):

(x - extract(self.remove_noise_coeff, t, x.shape) * self.model(x, t, y)) *extract(self.reciprocal_sqrt_alphas, t, x.shape)

这个函数就是去除第t-1到第t步的噪声
在这里插入图片描述
在这个函数里面调用了extract函数。实现的功能:提取时间步t时对应的参数

def extract(a, t, x_shape):b, *_ = t.shapeout = a.gather(-1, t)return out.reshape(b, *((1,) * (len(x_shape) - 1)))

a: Tensor:(1000,)
t: Tensor:(128,)
x_shape: torch.Size([128, 1, 28, 28])
最终返回的是Tensor:(128,1,1,1)
模型定义在初始化函数中,模型调用定义在forward函数中。

def sample(self, batch_size, device, y=None, use_ema=True):

    def sample(self, batch_size, device, y=None, use_ema=True):if y is not None and batch_size != len(y):raise ValueError("sample batch size different from length of given y")x = torch.randn(batch_size, self.img_channels, *self.img_size, device=device)for t in range(self.num_timesteps - 1, -1, -1):t_batch = torch.tensor([t], device=device).repeat(batch_size)x = self.remove_noise(x, t_batch, y, use_ema)if t > 0:x += extract(self.sigma, t_batch, x.shape) * torch.randn_like(x)return x.cpu().detach()

def perturb_x(self, x, t, noise):

在图像中添加噪声

        return (extract(self.sqrt_alphas_cumprod, t, x.shape) * x +extract(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * noise)   

在这里插入图片描述

def get_losses(self, x, t, y):

计算添加噪声和估计噪声的损失

        noise = torch.randn_like(x)perturbed_x = self.perturb_x(x, t, noise)estimated_noise = self.model(perturbed_x, t, y)if self.loss_type == "l1":loss = F.l1_loss(estimated_noise, noise)elif self.loss_type == "l2":loss = F.mse_loss(estimated_noise, noise)return loss

def forward(self, x, y=None):

前向函数很简单,随机b个t,然后计算对应的噪声损失。

    def forward(self, x, y=None):b, c, h, w = x.shapedevice = x.deviceif h != self.img_size[0]:raise ValueError("image height does not match diffusion parameters")if w != self.img_size[0]:raise ValueError("image width does not match diffusion parameters")t = torch.randint(0, self.num_timesteps, (b,), device=device)return self.get_losses(x, t, y)

def generate_cosine_schedule(T, s=0.008)和def generate_linear_schedule(T, low, high):

这个函数就是两种不同的生成betas的方法。betas数组是从小到大排列的。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/41076.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

MySQL 集群

MySQL 集群有多种类型&#xff0c;每种类型都有其特定的用途和优势。以下是一些常见的 MySQL 集群解决方案&#xff1a; 1. MySQL Replication 描述&#xff1a;MySQL 复制是一种异步复制机制&#xff0c;允许将一个 MySQL 数据库的数据复制到一个或多个从服务器。 用途&…

一拖二快充线:生活充电新风尚,高效便捷解决双设备充电难题

一拖二快充线在生活应用领域的优势与双接充电的便携性问题 在现代快节奏的生活中&#xff0c;电子设备已成为我们不可或缺的日常伴侣。无论是智能手机、平板电脑还是笔记本电脑&#xff0c;它们在我们的工作、学习和娱乐中扮演着至关重要的角色。然而&#xff0c;随着设备数量…

产品经理系列1—如何实现一个电商系统

具体笔记如下&#xff0c;主要按获客—找货—下单—售后四个部分进行模块拆解

代码随想录算法训练Day58|LeetCode417-太平洋大西洋水流问题、LeetCode827-最大人工岛

太平洋大西洋水流问题 力扣417-太平洋大西洋水流问题 有一个 m n 的矩形岛屿&#xff0c;与 太平洋 和 大西洋 相邻。 “太平洋” 处于大陆的左边界和上边界&#xff0c;而 “大西洋” 处于大陆的右边界和下边界。 这个岛被分割成一个由若干方形单元格组成的网格。给定一个…

【Unity】unity学习扫盲知识点

1、建议检查下SystemInfo的引用。这个是什么 Unity的SystemInfo类提供了一种获取关于当前硬件和操作系统的信息的方法。这包括设备类型&#xff0c;操作系统&#xff0c;处理器&#xff0c;内存&#xff0c;显卡&#xff0c;支持的Unity特性等。使用SystemInfo类非常简单。它的…

Linux 查看磁盘是不是 ssd 的方法

lsblk 命令检查 $ lsblk -d -o name,rota如果 ROTA 值为 1&#xff0c;则磁盘类型为 HDD&#xff0c;如果 ROTA 值为 0&#xff0c;则磁盘类型为 SSD。可以在上面的屏幕截图中看到 sda 的 ROTA 值是 1&#xff0c;表示它是 HDD。 2. 检查磁盘是否旋转 $ cat /sys/block/sda/q…

机器学习之保存与加载

前言 模型的数据需要存储和加载&#xff0c;这节介绍存储和加载的方式方法。 存和加载模型权重 保存模型使用save_checkpoint接口&#xff0c;传入网络和指定的保存路径&#xff0c;要加载模型权重&#xff0c;需要先创建相同模型的实例&#xff0c;然后使用load_checkpoint…

Autosar Dcm配置-0x85服务配置及使用-基于ETAS软件

文章目录 前言Dcm配置DcmDsdDcmDsp代码实现总结前言 0x85服务用来控制DTC设置的开启和关闭。某OEM3.0架构强制支持0x85服务,本文介绍ETAS工具中的配置 Dcm配置 DcmDsd 配置0x85服务 此处配置只在扩展会话下支持(具体需要根据需求决定),两个子服务Disable为0x02,Enable…

冯诺依曼体系结构与操作系统(Linux)

文章目录 前言冯诺依曼体系结构&#xff08;硬件&#xff09;操作系统&#xff08;软件&#xff09;总结 前言 冯诺依曼体系结构&#xff08;硬件&#xff09; 上图就是冯诺依曼体系结构图&#xff0c;主要包括输入设备&#xff0c;输出设备&#xff0c;存储器&#xff0c;运算…

Go高级库存照片源码v5.3

GoStock – 免费和付费库存照片脚本这是一个免费和付费共享高质量库存照片的平台,用户可以上传照片与整个社区和访客分享,并可以通过 PayPal 接收捐款。此外,用户还可以点赞、评论、分享和收藏您最喜欢的照片。 下载 特征: 使用Laravel 10构建订阅系统Stripe 连接渐进式网页…

从零开始读RocketMq源码(一)生产者启动

目录 前言 获取源码 总概论 生产者实例 源码 A-01:设置生产者组名称 A-02:生产者服务启动 B-01&#xff1a;初始化状态 B-02&#xff1a;该方法再次对生产者组名称进行校验 B-03&#xff1a;判断是否为默认生产者组名称 B-04: 该方法是为了实例化MQClientInstance对…

白嫖A100-interLM大模型部署试用活动,亲测有效-2.Git

申明 以下部分内容来源于活动教学文档&#xff1a; Docs git 安装 是一个开源的分布式版本控制系统&#xff0c;被广泛用于软件协同开发。程序员的必备基础工具。 常用的 Git 操作 git init 初始化一个新的 Git 仓库&#xff0c;在当前目录创建一个 .git 隐藏文件夹来跟踪…

Windows系统下载安装ngnix

一 nginx下载安装 nginx是HTTP服务器和反向代理服务器&#xff0c;功能非常丰富&#xff0c;在nginx官网首页&#xff0c;点击download 在download页面下&#xff0c;可以选择Stable version稳定版本&#xff0c;点击下载 将下载完成的zip解压即可&#xff0c;然乎在nginx所在…

SpringBoot新手快速入门系列教程五:基于JPA的一个Mysql简单读写例子

现在我们来做一个简单的读写Mysql的项目 1&#xff0c;先新建一个项目&#xff0c;我们叫它“HelloJPA”并且添加依赖 2&#xff0c;引入以下依赖&#xff1a; Spring Boot DevTools (可选&#xff0c;但推荐&#xff0c;用于开发时热部署)Lombok&#xff08;可选&#xff0c…

三相感应电机的建模仿真(2)基于ABC相坐标系S-Fun的仿真模型

1. 概述 2. 三相感应电动机状态方程式 3. 基于S-Function的仿真模型建立 4. 瞬态分析实例 5. 总结 6. 参考文献 1. 概述 前面建立的三相感应电机在ABC相坐标系下的数学模型是一组周期性变系数微分方程&#xff08;其电感矩阵是转子位置角的函数&#xff0c;转子位置角随时…

qt 开发笔记堆栈布局的应用

1.概要 画面中有一处位置&#xff0c;有个按钮点击后&#xff0c;这片位置完全换成另一个画面&#xff0c;这中情况特别适合用堆栈布局。 //堆栈布局的应用 #include <QStackedLayout> QStackedLayout *layout new QStackedLayout(this); layout->setCurrentIndex(…

Unity Scrollview的Scrollbar控制方法

备忘&#xff1a;碰到用scrollview自带的scrollbar去控制滑动&#xff0c;结果发现用代码控制scrollbar.value无效&#xff0c;搜了一下都是说用scrollRect.verticalNormalizedPosition和scrollRect.horizontalNormalizedPosition来控制的。我寻思着有关联的scrollbar为什么用不…

Interview preparation--Https 工作流程

HTTP 传输的弊端 如上图&#xff0c;Http进行数据传输的时候是明文传输&#xff0c;导致任何人都有可能截获信息&#xff0c;篡改信息如果此时黑客冒充服务器&#xff0c;或者黑客窃取信息&#xff0c;则其可以返回任意信息给客户端&#xff0c;而且不被客户端察觉&#xff0c;…

2.3.2 主程序和外部IO交互 (文件映射方式)----C#调用范例

2.3.2 主程序和外部IO交互 &#xff08;文件映射方式&#xff09;----C#调用范例 效果显示 1 说明 1 .1 Test_IOServer是64bit 程序&#xff0c; BD_SharedIOServerd.dll 在 /Debug文件夹中 1 .2 Test_IOServer是32bit 程序&#xff0c; BD_SharedIOClientd.dll (32bit&#…

[FreeRTOS 内部实现] 事件组

文章目录 事件组结构体创建事件组事件组等待位事件组设置位 事件组结构体 // 路径&#xff1a;Source/event_groups.c typedef struct xEventGroupDefinition {EventBits_t uxEventBits;List_t xTasksWaitingForBits; } EventGroup_t;uxEventBits 中的每一位表示某个事件是否…