【DLI】Generative AI with Diffusion Models通关秘籍

Generative AI with Diffusion Models,加载时间在20分钟左右,耐心等待。
在这里插入图片描述
在这里插入图片描述

6.2TODO

在这里插入图片描述

这里是在设置扩散模型的参数,代码里的FIXME部分需要根据上下文进行替换。以下是各个FIXME的替换说明:
1.a_bar 是 a 的累积乘积,在 PyTorch 里可以用 torch.cumprod 实现。
2.sqrt_a_bar、sqrt_one_minus_a_bar 和 sqrt_a_inv 都是对输入张量求平方根,可使用 torch.sqrt 实现。
3.pred_noise_coeff 中的 FIXME(1 - a_bar) 同样是求平方根,用 torch.sqrt 即可。
以下是替换后的代码:

nrows = 10
ncols = 15T = nrows * ncols
B_start = 0.0001
B_end = 0.02
B = torch.linspace(B_start, B_end, T).to(device)a = 1.0 - B
a_bar = torch.cumprod(a, dim=0)
sqrt_a_bar = torch.sqrt(a_bar)  # Mean Coefficient
sqrt_one_minus_a_bar = torch.sqrt(1 - a_bar)  # St. Dev. Coefficient# Reverse diffusion variables
sqrt_a_inv = torch.sqrt(1 / a)
pred_noise_coeff = (1 - a) / torch.sqrt(1 - a_bar)  # Predicted Noise Coefficient

在扩散模型里,正向扩散过程 q 函数是按照如下公式把原始图像 x_0 逐步添加噪声变成 x_t 的
在这里插入图片描述
FIXME 部分应该分别用 sqrt_a_bar_t 和 sqrt_one_minus_a_bar_t 来替换。
在这个 q 函数中,按照扩散模型的正向过程公式,把原始图像 x_0 和随机噪声 noise 按一定比例组合,从而得到加噪后的图像 x_t。

def q(x_0, t):t = t.int()noise = torch.randn_like(x_0)sqrt_a_bar_t = sqrt_a_bar[t, None, None, None]sqrt_one_minus_a_bar_t = sqrt_one_minus_a_bar[t, None, None, None]x_t = sqrt_a_bar_t * x_0 + sqrt_one_minus_a_bar_t * noisereturn x_t, noise

在反向扩散过程中,我们要根据当前的潜在图像,当前时间步 , 以及预测的噪声 来恢复上一个时间步的图像。在这里插入图片描述
在这个 reverse_q 函数中,我们根据反向扩散过程的公式,从当前的潜在图像和预测的噪声中恢复上一个时间步的图像。如果当前时间步为 0,则表示反向扩散过程完成。否则,我们会添加一些噪声以模拟扩散过程。下面是对代码中 FIXME 部分的分析与替换:

@torch.no_grad()
def reverse_q(x_t, t, e_t):t = t.int()pred_noise_coeff_t = pred_noise_coeff[t]sqrt_a_inv_t = sqrt_a_inv[t]u_t = sqrt_a_inv_t * (x_t - pred_noise_coeff_t * e_t)if t[0] == 0:  # All t values should be the samereturn u_t  # Reverse diffusion complete!else:B_t = B[t - 1]  # Apply noise from the previous timestepnew_noise = torch.randn_like(x_t)return u_t + torch.sqrt(B_t) * new_noise

在这里插入图片描述

6.3TODO

在这里插入图片描述

每个类的功能来添加正确模块名 依次改写FIXME 即可:

DownBlock进行下采样操作,包含卷积和池化相关的块
EmbedBlock将输入进行线性变换和激活
GELUConvBlock使用了卷积、组归一化和 GELU 激活函数,通常是一个卷积块
RearrangePoolBlock使用了 Rearrange 进行张量重排和卷积操作
ResidualConvBlock使用了两个卷积块并进行了残差连接
SinusoidalPositionEmbedBlock实现了正弦位置嵌入的功能
UpBlock上采样操作,包含转置卷积和卷积块

6.4TODO

在这个 get_context_mask 函数里,其目的是随机丢弃上下文信息。要实现随机丢弃,通常会使用 torch.bernoulli 函数。torch.bernoulli 函数会依据给定的概率来生成一个二进制掩码张量,其中每个元素为 1 的概率就是传入的概率值。
在这个函数中,我们希望以 drop_prob 的概率丢弃上下文,所以每个元素保留的概率是 1 - drop_prob。因此,FIXME 处应该填入 bernoulli。

def get_context_mask(c, drop_prob):c_hot = F.one_hot(c.to(torch.int64), num_classes=N_CLASSES).to(device)c_mask = torch.bernoulli(torch.ones_like(c_hot).float() * (1 - drop_prob)).to(device)return c_hot, c_mask

代码解释:
c_hot = F.one_hot(c.to(torch.int64), num_classes=N_CLASSES).to(device):将输入的 c 转换为独热编码向量,并且移动到指定的设备(如 GPU)上。
c_mask = torch.bernoulli(torch.ones_like(c_hot).float() * (1 - drop_prob)).to(device):生成一个与 c_hot 形状相同的二进制掩码张量,每个元素以 1 - drop_prob 的概率为 1,以 drop_prob 的概率为 0。
return c_hot, c_mask:返回独热编码向量和二进制掩码张量。
这样,你就可以使用这个函数来随机丢弃上下文信息了。

在这里插入图片描述

在扩散模型里,通常采用均方误差损失(Mean Squared Error Loss,MSE)来衡量预测噪声 noise_pred 和实际添加的噪声 noise 之间的差异。因为均方误差能够很好地衡量两个向量之间的平均平方误差,这对于扩散模型中预测噪声的准确性评估是很合适的。
在 PyTorch 中,nn.functional.mse_loss 函数可用于计算均方误差损失。所以 FIXME 处应填入 mse_loss。

def get_loss(model, x_0, t, *model_args):x_noisy, noise = q(x_0, t)noise_pred = model(x_noisy, t/T, *model_args)return F.mse_loss(noise, noise_pred)

代码解释
x_noisy, noise = q(x_0, t):调用 q 函数给原始图像 x_0 添加噪声,得到加噪后的图像 x_noisy 以及实际添加的噪声 noise。
noise_pred = model(x_noisy, t/T, *model_args):把加噪后的图像 x_noisy 和归一化后的时间步 t/T 输入到模型 model 中,得到模型预测的噪声 noise_pred。
return F.mse_loss(noise, noise_pred):使用 F.mse_loss 函数计算实际噪声 noise 和预测噪声 noise_pred 之间的均方误差损失并返回。
通过使用均方误差损失,模型能够学习到如何更准确地预测添加到图像中的噪声,从而在反向扩散过程中更好地恢复原始图像。

下一个 TODO

  1. c_drop_prob 的设置
    c_drop_prob 是上下文丢弃概率,一般在训练过程中会采用线性衰减策略,也就是在训练初期以较高概率丢弃上下文,随着训练的推进逐渐降低丢弃概率。在代码中,我们可以简单地将其设置为一个随着训练轮数逐渐降低的值。
  2. get_context_mask 函数的输入
    get_context_mask 函数需要一个上下文标签作为输入,在代码里这个标签应该从 batch 中获取。通常假设 batch 的第二个元素为上下文标签。

optimizer = Adam(model.parameters(), lr=0.001)
epochs = 5
preview_c = 0model.train()
for epoch in range(epochs):# 线性衰减上下文丢弃概率c_drop_prob = max(0.1, 1 - epoch / epochs)  #这里我调整了顺序for step, batch in enumerate(dataloader):optimizer.zero_grad()t = torch.randint(0, T, (BATCH_SIZE,), device=device).float()x = batch[0].to(device)# 假设 batch 的第二个元素是上下文标签c = batch[1].to(device)c_hot, c_mask = get_context_mask(c, c_drop_prob)loss = get_loss(model, x, t, c_hot, c_mask)loss.backward()optimizer.step()if epoch % 1 == 0 and step % 100 == 0:print(f"Epoch {epoch} | Step {step:03d} | Loss: {loss.item()} | C: {preview_c}")c_drop_prob = 0  # Do not drop context for previewc_hot, c_mask = get_context_mask(torch.Tensor([preview_c]).to(device), c_drop_prob)sample_images(model, IMG_CH, IMG_SIZE, ncols, c_hot, c_mask)preview_c = (preview_c + 1) % N_CLASSES

代码解释
c_drop_prob 的设置:运用线性衰减策略,在训练初期 c_drop_prob 为 0.9,随着训练的推进逐渐降低到 0.1。
get_context_mask 函数的输入:假设 batch 的第二个元素是上下文标签,将其传入 get_context_mask 函数。
训练过程:在每个训练步骤中,先将梯度清零,接着计算损失,再进行反向传播和参数更新。每训练 100 个步骤,就打印一次损失信息并进行一次样本生成。
通过这些修改,代码就能正常运行,从而开始训练模型。
在这里插入图片描述

6.5TODO

在扩散模型的采样过程中,为了给扩散过程添加权重,一般会根据给定的权重 w 对保留上下文的预测噪声 e_t_keep_c 和丢弃上下文的预测噪声 e_t_drop_c 进行加权组合。在这里插入图片描述
在代码中,FIXME 处应该根据上述公式进行计算,将 e_t_keep_c 和 e_t_drop_c 按照权重 w 进行组合。具体的代码如下:

def sample_w(model, c, w):input_size = (IMG_CH, IMG_SIZE, IMG_SIZE)n_samples = len(c)w = torch.tensor([w]).float()w = w[:, None, None, None].to(device)  # Make w broadcastablex_t = torch.randn(n_samples, *input_size).to(device)# One c for each wc = c.repeat(len(w), 1)# Double the batchc = c.repeat(2, 1)# Don't drop context at test timec_mask = torch.ones_like(c).to(device)c_mask[n_samples:] = 0.0x_t_store = []for i in range(0, T)[::-1]:# Duplicate t for each samplet = torch.tensor([i]).to(device)t = t.repeat(n_samples, 1, 1, 1)# Double the batchx_t = x_t.repeat(2, 1, 1, 1)t = t.repeat(2, 1, 1, 1)# Find weighted noisee_t = model(x_t, t, c, c_mask)e_t_keep_c = e_t[:n_samples]e_t_drop_c = e_t[n_samples:]e_t = w * e_t_keep_c + (1 - w) * e_t_drop_c# Deduplicate batch for reverse diffusionx_t = x_t[:n_samples]t = t[:n_samples]x_t = reverse_q(x_t, t, e_t)return x_t

## TODO

在扩散模型里,权重 w 可用于控制上下文信息在生成过程中的影响程度。w 值越接近 1,生成结果就越依赖上下文信息;w 值越接近 0,生成结果受上下文信息的影响就越小。若要让生成的数字能够被持续识别,你可以试着增大 w 的值,以此增强上下文信息对生成过程的影响。
下面是修改后的代码,你可以调整 w 的值来观察生成结果:

model.eval()
w = 5.0  # 可以尝试不同的值,通常大于 1 能增强上下文的影响
c = torch.arange(N_CLASSES).to(device)
c_drop_prob = 0 
c_hot, c_mask = get_context_mask(c, c_drop_prob)x_0 = sample_w(model, c_hot, w)
other_utils.to_image(make_grid(x_0.cpu(), nrow=N_CLASSES))

代码解释
w = 5.0:把 w 的值设为 5.0,你可以根据实际情况调整这个值。通常,当 w 大于 1 时,上下文信息的影响会得到增强,这样生成的数字可能会更易于识别。
x_0 = sample_w(model, c_hot, w):调用 sample_w 函数生成图像,将 w 作为参数传入。
other_utils.to_image(make_grid(x_0.cpu(), nrow=N_CLASSES)):把生成的图像转换为可视化的形式。
你可以多次运行这段代码,并且调整 w 的值,直到生成的数字能够被稳定识别。

至此结束。
在这里插入图片描述

完整代码都在图片里

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/75483.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

如何在本地部署魔搭上千问Qwen2.5-VL-32B-Instruct-AWQ模型在显卡1上面运行推理,并开启api服务

环境: 云服务器Ubuntu NVIDIA H20 96GB Qwen2.5-VL-32B Qwen2.5-VL-72B 问题描述: 如何在本地部署魔搭上千问Qwen2.5-VL-32B-Instruct-AWQ模型在显卡1上面运行推理,并开启api服务 解决方案: 1.环境准备 硬件要求 显卡1(显存需≥48GB,推荐≥64GB)CUDA 11.7或更高…

基于方法分类的无监督图像去雾论文

在之前的博客中,我从研究动机的角度对无监督图像去雾论文进行了分类,而现在我打算根据论文中提出的方法进行新的分类。 1. 基于对比学习的方法 2022年 论文《UCL-Dehaze: Towards Real-world Image Dehazing via Unsupervised Contrastive Learning》&a…

4月3号.

JDK7前时间相关类: 时间的相关知识: Data时间类: //1.创建对象表示一个时间 Date d1 new Date(); //System.out.println(d1);//2.创建对象表示一个指定的时间 Date d2 new Date(0L); System.out.println(d2);//3.setTime修改时间 //1000毫秒1秒 d2.setTime(1000L); System.o…

数据结构与算法:子数组最大累加和问题及扩展

前言 子数组最大累加和问题看似简单,但能延伸出的题目非常多,千题千面,而且会和其他算法结合出现。 一、最大子数组和 class Solution { public:int maxSubArray(vector<int>& nums) {int n=nums.size();vector<int>dp(n);//i位置往左能延伸出的最大累加…

MIT6.828 Lab3-2 Print a page table (easy)

实验内容 实现一个函数来打印页表的内容&#xff0c;帮助我们更好地理解 xv6 的三级页表结构。 修改内容 kernel/defs.h中添加函数声明&#xff0c;方便其它函数调用 void vmprint(pagetable_t);// lab3-2 Print a page tablekernel/vm.c中添加函数具体定义 采用…

2025高频面试设计模型总结篇

文章目录 设计模型概念单例模式工厂模式策略模式责任链模式 设计模型概念 设计模式是前人总结的软件设计经验和解决问题的最佳方案&#xff0c;它们为我们提供了一套可复用、易维护、可扩展的设计思路。 &#xff08;1&#xff09;定义&#xff1a; 设计模式是一套经过验证的…

Java基础:面向对象进阶(二)

01-static static修饰成员方法 static注意事项&#xff08;3种&#xff09; static应用知识&#xff1a;代码块 static应用知识&#xff1a;单列模式 02-面向对象三大特征之二&#xff1a;继承 什么是继承&#xff1f; 使用继承有啥好处? 权限修饰符 单继承、Object类 方法重…

Spring框架如何做EhCache缓存?

在Spring框架中&#xff0c;缓存是一种常见的优化手段&#xff0c;用于减少对数据库或其他资源的访问次数&#xff0c;从而提高应用性能。Spring提供了强大的缓存抽象&#xff0c;支持多种缓存实现&#xff08;如EhCache、Redis、Caffeine等&#xff09;&#xff0c;并可以通过…

NVIDIA显卡

NVIDIA显卡作为全球GPU技术的标杆&#xff0c;其产品线覆盖消费级、专业级、数据中心、移动计算等多个领域&#xff0c;技术迭代贯穿架构创新、AI加速、光线追踪等核心方向。以下从技术演进、产品矩阵、核心技术、生态布局四个维度展开深度解析&#xff1a; 一、技术演进&…

【BUG】生产环境死锁问题定位排查解决全过程

目录 生产环境死锁问题定位排查解决过程0. 表面现象1. 问题分析&#xff08;1&#xff09;数据库连接池资源耗尽&#xff08;2&#xff09;数据库锁竞争(3) 代码实现问题 2. 分析解决(0) 分析过程&#xff08;1&#xff09;优化数据库连接池配置&#xff08;2&#xff09;优化数…

【计算机网络应用层】

文章目录 计算机网络应用层详解一、前言二、应用层的功能三、常见的应用层协议1. HTTP/HTTPS&#xff08;超文本传输协议&#xff09;2. DNS&#xff08;域名系统&#xff09;3. FTP&#xff08;文件传输协议&#xff09;4. SMTP/POP3/IMAP&#xff08;电子邮件协议&#xff09…

Linux 虚拟化方案

一、Linux 虚拟化技术分类 1. 全虚拟化 (Full Virtualization) 特点&#xff1a;Guest OS 无需修改&#xff0c;完全模拟硬件 代表技术&#xff1a; KVM (Kernel-based Virtual Machine)&#xff1a;主流方案&#xff0c;集成到 Linux 内核 QEMU&#xff1a;硬件模拟器&…

树莓派 5 换清华源

首先备份原设置 cp /etc/apt/sources.list ~/sources.list.bak cp /etc/apt/sources.list.d/raspi.list ~/raspi.list.bak修改配置 /etc/apt/sources.list 文件替换内容如下&#xff08;原内容删除&#xff09; deb https://mirrors.tuna.tsinghua.edu.cn/debian/ bookworm …

WGAN原理及实现(pytorch版)

WGAN原理及实现 一、WGAN原理1.1 原始GAN的缺陷1.2 Wasserstein距离的引入1.3 Kantorovich-Rubinstein对偶1.4 WGAN的优化目标1.4 数学推导步骤1.5 权重裁剪 vs 梯度惩罚1.6 优势1.7 总结 二、WGAN实现2.1 导包2.2 数据加载和处理2.3 构建生成器2.4 构建判别器2.5 训练和保存模…

Unity网络开发基础 (3) Socket入门 TCP同步连接 与 简单封装练习

本文章不作任何商业用途 仅作学习与交流 教程来自Unity唐老狮 关于练习题部分是我观看教程之后自己实现 所以和老师写法可能不太一样 唐老师说掌握其基本思路即可,因为前端程序一般不需要去写后端逻辑 1.认识Socket的重要API Socket是什么 Socket&#xff08;套接字&#xff0…

【linux】一文掌握 ssh和scp 指令的详细用法(ssh和scp 备忘速查)

文章目录 入门连接执行SCP配置位置SCP 选项配置示例ProxyJumpssh-copy-id SSH keygenssh-keygen产生钥匙类型known_hosts密钥格式 此快速参考备忘单提供了使用 SSH 的各种方法。 参考&#xff1a; OpenSSH 配置文件示例 (cyberciti.biz)ssh_config (linux.die.net) 入门 连…

真实笔试题

文章目录 线程题树的深度遍历 线程题 实现一个类支持100个线程同时向一个银行账户中存入一元钱.需通过同步机制消除竞态条件,当所有线程执行完成后,账户余额必须精确等于100元 package com.itheima.thread;public class ShowMeBug {private double balance; // 账户余额priva…

2.2 路径问题专题:LeetCode 63. 不同路径 II

动态规划解决LeetCode 63题&#xff1a;不同路径 II&#xff08;含障碍物&#xff09; 1. 题目链接 LeetCode 63. 不同路径 II 2. 题目描述 一个机器人位于 m x n 网格的左上角&#xff0c;每次只能向右或向下移动一步。网格中可能存在障碍物&#xff08;标记为 1&#xff…

2874. 有序三元组中的最大值 II

给你一个下标从 0 开始的整数数组 。nums 请你从所有满足 的下标三元组 中&#xff0c;找出并返回下标三元组的最大值。 如果所有满足条件的三元组的值都是负数&#xff0c;则返回 。i < j < k(i, j, k)0 下标三元组 的值等于 。(i, j, k)(nums[i] - nums[j]) * nums[k…

【论文笔记】Llama 3 技术报告

Llama 3中的顶级模型是一个拥有4050亿参数的密集Transformer模型&#xff0c;并且它的上下文窗口长度可以达到128,000个tokens。这意味着它能够处理非常长的文本&#xff0c;记住和理解更多的信息。Llama 3.1的论文长达92页&#xff0c;详细描述了模型的开发阶段、优化策略、模…