深入浅出 diffusion(3):pytorch 实现 diffusion 中的 U-Net

导入python包

import mathimport torch
import torch.nn as nn
import torch.nn.functional as F

 silu激活函数

class SiLU(nn.Module):  # SiLU激活函数@staticmethoddef forward(x):return x * torch.sigmoid(x)

归一化设置

def get_norm(norm, num_channels, num_groups):if norm == "in":return nn.InstanceNorm2d(num_channels, affine=True)elif norm == "bn":return nn.BatchNorm2d(num_channels)elif norm == "gn":return nn.GroupNorm(num_groups, num_channels)elif norm is None:return nn.Identity()else:raise ValueError("unknown normalization type")

 计算时间步长的位置嵌入,一半为sin,一半为cos

class PositionalEmbedding(nn.Module):def __init__(self, dim, scale=1.0):super().__init__()assert dim % 2 == 0self.dim = dimself.scale = scaledef forward(self, x):device      = x.devicehalf_dim    = self.dim // 2emb = math.log(10000) / half_dimemb = torch.exp(torch.arange(half_dim, device=device) * -emb)# x * self.scale和emb外积emb = torch.outer(x * self.scale, emb)emb = torch.cat((emb.sin(), emb.cos()), dim=-1)return emb

 上下采样层设置

class Downsample(nn.Module):def __init__(self, in_channels):super().__init__()self.downsample = nn.Conv2d(in_channels, in_channels, 3, stride=2, padding=1)def forward(self, x, time_emb, y):if x.shape[2] % 2 == 1:raise ValueError("downsampling tensor height should be even")if x.shape[3] % 2 == 1:raise ValueError("downsampling tensor width should be even")return self.downsample(x)class Upsample(nn.Module):def __init__(self, in_channels):super().__init__()self.upsample = nn.Sequential(nn.Upsample(scale_factor=2, mode="nearest"),nn.Conv2d(in_channels, in_channels, 3, padding=1),)def forward(self, x, time_emb, y):return self.upsample(x)

 使用Self-Attention注意力机制,做一个全局的Self-Attention

class AttentionBlock(nn.Module):def __init__(self, in_channels, norm="gn", num_groups=32):super().__init__()self.in_channels = in_channelsself.norm = get_norm(norm, in_channels, num_groups)self.to_qkv = nn.Conv2d(in_channels, in_channels * 3, 1)self.to_out = nn.Conv2d(in_channels, in_channels, 1)def forward(self, x):b, c, h, w  = x.shapeq, k, v     = torch.split(self.to_qkv(self.norm(x)), self.in_channels, dim=1)q = q.permute(0, 2, 3, 1).view(b, h * w, c)k = k.view(b, c, h * w)v = v.permute(0, 2, 3, 1).view(b, h * w, c)dot_products = torch.bmm(q, k) * (c ** (-0.5))assert dot_products.shape == (b, h * w, h * w)attention   = torch.softmax(dot_products, dim=-1)out         = torch.bmm(attention, v)assert out.shape == (b, h * w, c)out         = out.view(b, h, w, c).permute(0, 3, 1, 2)return self.to_out(out) + x

 用于特征提取的残差结构

class ResidualBlock(nn.Module):def __init__(self, in_channels, out_channels, dropout, time_emb_dim=None, num_classes=None, activation=F.relu,norm="gn", num_groups=32, use_attention=False,):super().__init__()self.activation = activationself.norm_1 = get_norm(norm, in_channels, num_groups)self.conv_1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)self.norm_2 = get_norm(norm, out_channels, num_groups)self.conv_2 = nn.Sequential(nn.Dropout(p=dropout), nn.Conv2d(out_channels, out_channels, 3, padding=1),)self.time_bias  = nn.Linear(time_emb_dim, out_channels) if time_emb_dim is not None else Noneself.class_bias = nn.Embedding(num_classes, out_channels) if num_classes is not None else Noneself.residual_connection    = nn.Conv2d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()self.attention              = nn.Identity() if not use_attention else AttentionBlock(out_channels, norm, num_groups)def forward(self, x, time_emb=None, y=None):out = self.activation(self.norm_1(x))# 第一个卷积out = self.conv_1(out)# 对时间time_emb做一个全连接,施加在通道上if self.time_bias is not None:if time_emb is None:raise ValueError("time conditioning was specified but time_emb is not passed")out += self.time_bias(self.activation(time_emb))[:, :, None, None]# 对种类y_emb做一个全连接,施加在通道上if self.class_bias is not None:if y is None:raise ValueError("class conditioning was specified but y is not passed")out += self.class_bias(y)[:, :, None, None]out = self.activation(self.norm_2(out))# 第二个卷积+残差边out = self.conv_2(out) + self.residual_connection(x)# 最后做个Attentionout = self.attention(out)return out

 U-Net模型设计

class UNet(nn.Module):def __init__(self, img_channels, base_channels=128, channel_mults=(1, 2, 2, 2),num_res_blocks=2, time_emb_dim=128 * 4, time_emb_scale=1.0, num_classes=None, activation=F.silu,dropout=0.1, attention_resolutions=(1,), norm="gn", num_groups=32, initial_pad=0,):super().__init__()# 使用到的激活函数,一般为SILUself.activation = activation# 是否对输入进行paddingself.initial_pad = initial_pad# 需要去区分的类别数self.num_classes = num_classes# 对时间轴输入的全连接层self.time_mlp = nn.Sequential(PositionalEmbedding(base_channels, time_emb_scale),nn.Linear(base_channels, time_emb_dim),nn.SiLU(),nn.Linear(time_emb_dim, time_emb_dim),) if time_emb_dim is not None else None# 对输入图片的第一个卷积self.init_conv  = nn.Conv2d(img_channels, base_channels, 3, padding=1)# self.downs用于存储下采样用到的层,首先利用ResidualBlock提取特征# 然后利用Downsample降低特征图的高宽self.downs      = nn.ModuleList()self.ups        = nn.ModuleList()# channels指的是每一个模块处理后的通道数# now_channels是一个中间变量,代表中间的通道数channels        = [base_channels]now_channels    = base_channelsfor i, mult in enumerate(channel_mults):out_channels = base_channels * multfor _ in range(num_res_blocks):self.downs.append(ResidualBlock(now_channels, out_channels, dropout,time_emb_dim=time_emb_dim, num_classes=num_classes, activation=activation,norm=norm, num_groups=num_groups, use_attention=i in attention_resolutions,))now_channels = out_channelschannels.append(now_channels)if i != len(channel_mults) - 1:self.downs.append(Downsample(now_channels))channels.append(now_channels)# 可以看作是特征整合,中间的一个特征提取模块self.mid = nn.ModuleList([ResidualBlock(now_channels, now_channels, dropout,time_emb_dim=time_emb_dim, num_classes=num_classes, activation=activation,norm=norm, num_groups=num_groups, use_attention=True,),ResidualBlock(now_channels, now_channels, dropout,time_emb_dim=time_emb_dim, num_classes=num_classes, activation=activation, norm=norm, num_groups=num_groups, use_attention=False,),])# 进行上采样,进行特征融合for i, mult in reversed(list(enumerate(channel_mults))):out_channels = base_channels * multfor _ in range(num_res_blocks + 1):self.ups.append(ResidualBlock(channels.pop() + now_channels, out_channels, dropout, time_emb_dim=time_emb_dim, num_classes=num_classes, activation=activation, norm=norm, num_groups=num_groups, use_attention=i in attention_resolutions,))now_channels = out_channelsif i != 0:self.ups.append(Upsample(now_channels))assert len(channels) == 0self.out_norm = get_norm(norm, base_channels, num_groups)self.out_conv = nn.Conv2d(base_channels, img_channels, 3, padding=1)def forward(self, x, time=None, y=None):# 是否对输入进行paddingip = self.initial_padif ip != 0:x = F.pad(x, (ip,) * 4)# 对时间轴输入的全连接层if self.time_mlp is not None:if time is None:raise ValueError("time conditioning was specified but tim is not passed")time_emb = self.time_mlp(time)else:time_emb = Noneif self.num_classes is not None and y is None:raise ValueError("class conditioning was specified but y is not passed")# 对输入图片的第一个卷积x = self.init_conv(x)# skips用于存放下采样的中间层skips = [x]for layer in self.downs:x = layer(x, time_emb, y)skips.append(x)# 特征整合与提取for layer in self.mid:x = layer(x, time_emb, y)# 上采样并进行特征融合for layer in self.ups:if isinstance(layer, ResidualBlock):x = torch.cat([x, skips.pop()], dim=1)x = layer(x, time_emb, y)# 上采样并进行特征融合x = self.activation(self.out_norm(x))x = self.out_conv(x)if self.initial_pad != 0:return x[:, :, ip:-ip, ip:-ip]else:return x

参考链接:GitCode - 开发者的代码家园icon-default.png?t=N7T8https://gitcode.com/bubbliiiing/ddpm-pytorch/tree/master?utm_source=csdn_github_accelerator&isLogin=1

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/651306.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

TensorRT英伟达官方示例解析(一)

系列文章目录 TensorRT英伟达官方示例解析(一) TensorRT英伟达官方示例解析(二) TensorRT英伟达官方示例解析(三) 文章目录 系列文章目录前言一、参考资料二、配置系统环境三、00-MNISTData四、01-SimpleD…

银行数据仓库体系实践(4)--数据抽取和加载

1、ETL和ELT ETL是Extract、Transfrom、Load即抽取、转换、加载三个英文单词首字母的集合: E:抽取,从源系统(Souce)获取数据; T:转换,将源系统获取的数据进行处理加工,比如数据格式转化、数据精…

【labVIEW】学习记录

【labVIEW】学习记录 一、简介二、安装及激活三、使用 回到目录 一、简介 labVIEW(Laboratory Virtual Instrument Engineering Workbench)是一款由美国国家仪器公司(National Instruments)开发的可视化编程环境和开发平台。LabV…

3.3 实验三:以太网链路聚合实验

HCIA-Datacom实验指导手册:3.3 实验三:以太网链路聚合实验 一、实验介绍:二、实验拓扑:三、实验目的:四、配置步骤:步骤 1 掌握使用手动模式配置链路聚合的方法步骤 2 掌握使用静态 LACP 模式配置链路聚合的…

【JavaEE进阶】 数据库连接池与MySQL企业开发规范

文章目录 🌴数据库连接池🎋数据库连接池的使用🎄MySQL企业开发规范⭕总结🌴数据库连接池 数据库连接池负责分配、管理和释放数据库连接,它允许应⽤程序重复使⽤⼀个现有的数据库连接,⽽不是再重新建⽴⼀个. 没有使⽤数据库连接池的情况:每次执⾏SQL语句,要先创建⼀…

Linux系统——函数与数组

目录 一、函数 1.函数的定义 2.使用函数 3.定义函数的方法 4.函数举例 4.1判断操作系统 4.2判断ip地址 5.查看函数列表 6.删除函数 7.函数返回值——Return 8.函数的作用范围 9.函数传参 10.函数递归 10.1病毒 10.2阶乘 10.2.1 用for循环 10.2.2函数阶乘 10.…

Python实战项目Excel拆分与合并——合并篇

在实际工作中,我们经常会遇到各种表格的拆分与合并的情况。如果只是少量表,手动操作还算可行,但是如果是几十上百张表,最好使用Python编程进行自动化处理。下面介绍两种拆分案例场景,如何用Pandas实现Excel文件的合并。…

ELK日志解决方案

ELK日志解决方案 ELK套件日志系统应该是Elasticsearch使用最广泛的场景之一了,Elasticsearch支持海量数据的存储和查询,特别适合日志搜索场景。广泛使用的ELK套件(Elasticsearch、Logstash、Kibana)是日志系统最经典的案例,使用Logstash和Be…

Scikit-learn (sklearn)速通 -【莫凡Python学习笔记】

视频教程链接:【莫烦Python】Scikit-learn (sklearn) 优雅地学会机器学习 视频教程代码 scikit-learn官网 莫烦官网学习链接 本人matplotlib、numpy、pandas笔记 1 为什么学习 Scikit learn 也简称 sklearn, 是机器学习领域当中最知名的 python 模块之一. Sk…

burp靶场--WebSockets安全漏洞

burp靶场–WebSockets安全漏洞 https://portswigger.net/web-security/websockets/what-are-websockets ### 什么是 WebSocket? WebSocket是一种通过 HTTP 发起的双向、全双工通信协议。它们通常在现代 Web 应用程序中用于流数据和其他异步流量。 在本节中&#x…

ChatGPT 官方中文页面上线

根据页面显示,OpenAI 现已推出 ChatGPT 的多语言功能 Alpha 版测试,允许用户选择不同语言的界面进行交互。 如下图所示,ChatGPT 会检测系统当前所使用的语言,并提示用户进行语言切换。 用户也可通过设置页面选择其他语言。目前&a…

Linux之系统安全与应用

Linux系统提供了多种机制来确保用户账号的正当,安全使用。 系统安全措施 一. 清理系统账号 1.1 将用户设置为无法登录 Linux系统中除手动创建的各种账号外,还包括随系统或程序安装过程而生成的其他大量账号。除了超级用户root以外,其他的…

免费开源的微信小程序源码、小游戏源码精选70套!

微信小程序已经成为我们日常的一部分了,也基本是每个程序员都会涉及的内容,今天给大家分享从网络收集的70个小程序源码。其中这些源码包含:小游戏到商城小程序,再到实用的工具小程序,以及那些令人惊叹的防各大站点的小…

【Linux】文件描述符 | 重定向 | C文件指针与fd的关系 | 用户级缓冲区

文章目录 一、文件描述符1. 理解:Linux下一切皆文件2. 文件描述符(fd)的概念3. 文件描述符的分配规则4. 进程创建时默认打开的 0 & 1 & 2 号文件 二、重定向1. 重定向的本质2. 使用dup2系统调用函数3. bash下的三种重定向4. 三种重定…

全面理解“张量”概念

1. 多重视角看“张量” 张量(Tensor)是一个多维数组的概念,在不同的学科领域中有不同的应用和解释: 物理学中的张量: 在物理学中,张量是一个几何对象,用来表示在不同坐标系下变换具有特定规律的…

(N-141)基于springboot,vue网上拍卖平台

开发工具:IDEA 服务器:Tomcat9.0, jdk1.8 项目构建:maven 数据库:mysql5.7 系统分前后台,项目采用前后端分离 前端技术:vueelementUI 服务端技术:springbootmybatis-plusredi…

GNS3连接Vmware虚拟机

1 安装配置Gns3、Vmware 安装过程略,最终版本号: Gns3:2.2.44.1 Vmware:17.0 建议保持一致,特别是Gns3,功能虽然强大的,但bug问题感觉也不少 2 虚拟机配置 新建两台Ubuntu 22.04虚拟机&#…

【JavaScript权威指南第七版】读书笔记速度

JavaScript权威指南第七版 序正文前言:图中笔记重点知识第1章 JavaScript简介第一章总结 第2章 词法结构注释字面量标识符和保留字Unicode可选的分号第二章总结 第3章 类型、值和变量【重要】原始类型特殊类型第三章总结 第4章 表达式与操作符表达式操作符条件式调用…

【JAVA面试精选篇-初生牛犊不怕虎】

文章目录 🌽 简介🧺 线程池🌄 Redis⏰ JVM🚛 数据结构🍎 Mysql🍡 结语🌽 简介 海阔凭鱼跃,天高任鸟飞! 学习不要盲目,让大脑舒服的方式吸收知识!!! 本人马上离开济南,回泰安发展,为了积极准备面试,目前在梳理一些知识点,同时希望能够帮助到需要的人… …

Rabbitmq调用FeignClient接口失败

文章目录 一、框架及逻辑介绍1.背景服务介绍2.问题逻辑介绍 二、代码1.A服务2.B服务3.C服务 三、解决思路1.确认B调用C服务接口是否能正常调通2.确认B服务是否能正常调用A服务3.确认消息能否正常消费4.总结 四、修改代码验证1.B服务异步调用C服务接口——失败2.将消费消息放到C…