U-Net for text-to-image

1. U-Net for text-to-image

Notes based on:
1. hkproj/pytorch-stable-diffusion
2. Understanding U-Net: A Comprehensive Tutorial
3. Deep Dive into Self-Attention by Hand
4. Towards Understanding Cross and Self-Attention in Stable Diffusion for Text-Guided Image Editing, arXiv:2403.03431v1 [cs.CV], 6 Mar 2024


Encoder

The encoder is responsible for capturing high-level features and reducing the spatial dimensions of the input image.
It consists of repeated blocks of convolutional layers followed by max-pooling layers, effectively downsampling the input.

Bottleneck

At the center of the U-Net is a bottleneck layer that captures the most critical features while maintaining spatial information.

Decoder

The decoder is responsible for upsampling the low-resolution feature maps to match the original input size.
It consists of repeated blocks of transposed convolutions (upsampling) followed by concatenation with corresponding feature maps from the contracting path.
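As a point of reference only (not part of the Stable Diffusion code below), here is a minimal sketch of one classic U-Net encoder step and the matching decoder step: a double convolution plus max-pooling on the way down, and a transposed convolution plus skip concatenation on the way up. The Stable Diffusion U-Net implemented in the rest of this section replaces these with strided convolutions, residual blocks, and nearest-neighbour upsampling; all names below are illustrative.

import torch
from torch import nn

class DoubleConv(nn.Module):
    # Two 3x3 convolutions, each followed by a ReLU (the basic building block of the classic U-Net)
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), nn.ReLU(inplace=True),
        )
    def forward(self, x):
        return self.block(x)

# Encoder step: extract features, then halve the spatial resolution
enc = DoubleConv(3, 64)
pool = nn.MaxPool2d(2)
x = torch.randn(1, 3, 256, 256)
skip = enc(x)                          # (1, 64, 256, 256), kept for the skip connection
down = pool(skip)                      # (1, 64, 128, 128)

# Decoder step: upsample, concatenate the skip from the contracting path, then refine
up = nn.ConvTranspose2d(64, 64, kernel_size=2, stride=2)
dec = DoubleConv(128, 64)
y = up(down)                           # (1, 64, 256, 256)
y = dec(torch.cat([y, skip], dim=1))   # (1, 64, 256, 256)
print(y.shape)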

1.1 TimeEmbedding

Each forward pass of the U-Net is given a timestep t. Each timestep corresponds to a different amount of noise in the image: the larger t is, the more noise the image contains; the smaller t is, the less noise it contains.

import torch
from torch import nn
from torch.nn import functional as F
from attention import SelfAttention, CrossAttention

class TimeEmbedding(nn.Module):
    def __init__(self, n_embd):
        super().__init__()
        # First linear layer to expand the embedding size: n_embd -> 4 * n_embd
        self.linear_1 = nn.Linear(n_embd, 4 * n_embd)
        # Second linear layer for further processing
        self.linear_2 = nn.Linear(4 * n_embd, 4 * n_embd)

    def forward(self, x):
        # x: (1, 320)
        # Expand embedding size: (1, 320) -> (1, 1280)
        x = self.linear_1(x)
        # Apply SiLU activation: (1, 1280) -> (1, 1280)
        x = F.silu(x)
        # Further processing: (1, 1280) -> (1, 1280)
        x = self.linear_2(x)
        return x
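The (1, 320) input to TimeEmbedding is produced outside the U-Net by a sinusoidal timestep embedding in the sampling pipeline. The sketch below shows the idea for the 320-dimensional case; the function name and constants are assumptions patterned on the reference repository, not part of the module above.

import torch

def get_time_embedding(timestep, dim=320):
    # Sinusoidal timestep embedding (transformer-style); assumed helper for illustration.
    # Half of the dimensions use cos, half use sin, with geometrically spaced frequencies.
    half = dim // 2                                                               # 160
    freqs = torch.pow(10000, -torch.arange(half, dtype=torch.float32) / half)    # (160,)
    x = torch.tensor([timestep], dtype=torch.float32)[:, None] * freqs[None]     # (1, 160)
    return torch.cat([torch.cos(x), torch.sin(x)], dim=-1)                       # (1, 320)

t_emb = get_time_embedding(999)   # (1, 320), fed to TimeEmbedding -> (1, 1280)
print(t_emb.shape)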

1.2 ResnetBlock (ResNet + Time Embedding)

ResNetBlocks enable the model to learn richer and more complex feature representations by allowing multiple layers to focus on refining features without the risk of vanishing gradients.

The figure below is from WeThinkIn (Zhihu).

Convolutional Layer: Applies a convolution operation to extract features.
Normalization: often Batch Normalization or Group Normalization (GroupNorm in this implementation) to stabilize and accelerate training.
Activation Function: Typically SiLU to introduce non-linearity.
Second Convolutional Layer: Another convolution to further process the features.
Normalization and Activation: Additional normalization and activation.
Residual Connection: Adds the input of the block to the output of the block.

import torch
from torch import nn
from torch.nn import functional as F
from attention import SelfAttention, CrossAttention

class UNET_ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, n_time=1280):
        super().__init__()
        # GN
        self.groupnorm_feature = nn.GroupNorm(32, in_channels)
        # Conv
        self.conv_feature = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        # Linear projection for the time embedding
        self.linear_time = nn.Linear(n_time, out_channels)
        # GN of the second GSC block, applied to the merged result
        self.groupnorm_merged = nn.GroupNorm(32, out_channels)
        # Final Conv of the second GSC block
        self.conv_merged = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        # If the block's input and output channels match, skip directly; otherwise project with a 1x1 conv
        if in_channels == out_channels:
            self.residual_layer = nn.Identity()  # skip connection
        else:
            self.residual_layer = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)

    def forward(self, feature, time):
        # feature: (Batch_Size, In_Channels, Height, Width)
        # time: (1, 1280)

        ## (1) Apply GSC (GN + SiLU + Conv) to the latent feature
        residue = feature
        # GN: (Batch_Size, In_Channels, Height, Width) -> (Batch_Size, In_Channels, Height, Width)
        feature = self.groupnorm_feature(feature)
        # SiLU: (Batch_Size, In_Channels, Height, Width) -> (Batch_Size, In_Channels, Height, Width)
        feature = F.silu(feature)
        # Conv: (Batch_Size, In_Channels, Height, Width) -> (Batch_Size, Out_Channels, Height, Width)
        feature = self.conv_feature(feature)

        ## (2) Apply SiLU + Linear to the time embedding
        # (1, 1280) -> (1, 1280)
        time = F.silu(time)
        # (1, 1280) -> (1, Out_Channels)
        time = self.linear_time(time)

        ## Merge (1) and (2)
        # Add width and height dimensions to time.
        # (Batch_Size, Out_Channels, Height, Width) + (1, Out_Channels, 1, 1) -> (Batch_Size, Out_Channels, Height, Width)
        merged = feature + time.unsqueeze(-1).unsqueeze(-1)

        ## Apply a second GSC (GN + SiLU + Conv) to the merged result
        # (Batch_Size, Out_Channels, Height, Width) -> (Batch_Size, Out_Channels, Height, Width)
        merged = self.groupnorm_merged(merged)
        merged = F.silu(merged)
        merged = self.conv_merged(merged)

        ## Skip connection: add the original latent feature (projected if needed) to the merged result
        # (Batch_Size, Out_Channels, Height, Width) + (Batch_Size, Out_Channels, Height, Width) -> (Batch_Size, Out_Channels, Height, Width)
        return merged + self.residual_layer(residue)
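A quick shape check of the block above (tensor sizes chosen arbitrarily; assumes torch and the class are in scope):

block = UNET_ResidualBlock(320, 640)
feature = torch.randn(1, 320, 32, 32)   # a latent feature map
time = torch.randn(1, 1280)             # the processed time embedding from TimeEmbedding
out = block(feature, time)
print(out.shape)  # torch.Size([1, 640, 32, 32]) -- channels change, so the 1x1 conv skip path is used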

1.3 SelfAttention

The self-attention mechanism lets the model capture dependencies between pixels at different positions, so it can better understand the semantics of the image.

The self-attention map clearly expresses the outline information of objects.
Self-attention maps play a crucial role in preserving the geometric and shape details of the source image during the transformation to the target image.


Self-Attention: the core of self-attention is to learn, for every position in the input sequence, a weight distribution over all positions, so that the model knows which positions matter most when processing the current position. Self-attention refers specifically to attention computed within a single sequence: every position computes attention weights against all other positions of the same sequence.

The figure below reflects the author's own understanding (corrections are welcome in the comments).

Multi-Head Attention: to let the model attend to information from different positions simultaneously, the Transformer introduces multi-head attention. The basic idea is to split the input representation into several subspaces (heads), compute attention weights independently within each subspace, and then concatenate the results of all heads. The benefit is that the model can capture different contextual information in different representation subspaces. Adapted from: Self-Attention 和 Multi-Head Attention 的区别——附最通俗理解!!

The figure below reflects the author's own understanding (corrections are welcome in the comments).

The figure below is from: Towards Understanding Cross and Self-Attention in Stable Diffusion for Text-Guided Image Editing, arXiv:2403.03431v1 [cs.CV], 6 Mar 2024.

The first component of the horse's self-attention map clearly expresses the outline information of the horse.



Parameters

def __init__(self, n_heads, d_embed, in_proj_bias=True, out_proj_bias=True)
(1) n_heads defines how many heads we split our embeddings into.
(2) d_embed defines the size of each embedding.
(3) in_proj_bias=True and out_proj_bias=True determine whether biases are added to the input and output projection layers, respectively.

def forward(self, x, causal_mask=False) # x: (Batch_Size, Seq_Len, Dim)
(1) Batch_Size refers to the number of samples processed together in one forward and backward pass of the model.
For instance, if you have 64 images in one batch, then your Batch_Size is 64.
(2) Seq_Len stands for Sequence Length, which, in the context of images, typically refers to the number of patches the image is divided into.
For an image of size H×W (Height×Width) and patch size P×P, the sequence length (Seq_Len) is (H×W)/(P×P). For example, a 224×224 image divided into 16×16 patches results in 196 patches.
(3) Dim refers to the dimension of the embeddings or feature vectors for each token (patch).
For example, if each 16x16 patch is embedded into a vector of dimension 768, then Dim is 768.

in_proj is a single matrix that combines the three weight matrices (Wq, Wk, Wv).
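A small sketch (toy sizes; the names here are for illustration only) showing that one fused Linear of width 3 * d_embed is equivalent to applying W^Q, W^K, W^V separately:

import torch
from torch import nn

d = 8
x = torch.randn(2, 5, d)                      # (Batch_Size, Seq_Len, Dim)
in_proj = nn.Linear(d, 3 * d, bias=False)     # fused Wq, Wk, Wv
q, k, v = in_proj(x).chunk(3, dim=-1)         # three tensors of shape (2, 5, d)

# Splitting the fused weight along its output dimension recovers the three matrices
Wq, Wk, Wv = in_proj.weight.chunk(3, dim=0)
assert torch.allclose(q, x @ Wq.T, atol=1e-6)
assert torch.allclose(v, x @ Wv.T, atol=1e-6)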

import torch
from torch import nn
from torch.nn import functional as F
import math

class SelfAttention(nn.Module):
    def __init__(self, n_heads, d_embed, in_proj_bias=True, out_proj_bias=True):
        super().__init__()
        # This combines the Wq, Wk and Wv matrices into one matrix
        # d_embed: the dimension of the input embeddings.
        # 3 * d_embed: the output dimension is three times the input dimension,
        # so the query, key, and value vectors are produced in a single step.
        self.in_proj = nn.Linear(d_embed, 3 * d_embed, bias=in_proj_bias)
        # This one represents the Wo matrix
        self.out_proj = nn.Linear(d_embed, d_embed, bias=out_proj_bias)
        self.n_heads = n_heads
        self.d_head = d_embed // n_heads  # dimension of each head

    def forward(self, x, causal_mask=False):
        # x: (Batch_Size, Seq_Len, Dim)
        input_shape = x.shape
        # Unpack input dimensions
        batch_size, sequence_length, d_embed = input_shape
        # Shape used to split the embeddings into heads:
        # (Batch_Size, Seq_Len, H, Dim / H)
        interim_shape = (batch_size, sequence_length, self.n_heads, self.d_head)

        # Project the input tensor into query, key, and value matrices
        # (Batch_Size, Seq_Len, Dim) -> (Batch_Size, Seq_Len, Dim * 3) -> 3 tensors of shape (Batch_Size, Seq_Len, Dim)
        # chunk(3, dim=-1) splits the result into 3 equal parts along the last dimension
        q, k, v = self.in_proj(x).chunk(3, dim=-1)

        # Reshape and transpose to separate the heads
        # (Batch_Size, Seq_Len, Dim) -> (Batch_Size, Seq_Len, H, Dim / H) -> (Batch_Size, H, Seq_Len, Dim / H)
        q = q.view(interim_shape).transpose(1, 2)
        k = k.view(interim_shape).transpose(1, 2)
        v = v.view(interim_shape).transpose(1, 2)

        # Q·K^T: compute the attention scores
        # (Batch_Size, H, Seq_Len, Dim / H) @ (Batch_Size, H, Dim / H, Seq_Len) -> (Batch_Size, H, Seq_Len, Seq_Len)
        weight = q @ k.transpose(-1, -2)  # @ is matrix multiplication

        if causal_mask:
            # Mask where the upper triangle (above the principal diagonal) is True
            mask = torch.ones_like(weight, dtype=torch.bool).triu(1)
            # Fill the upper triangle with -inf so it vanishes after the softmax
            weight.masked_fill_(mask, -torch.inf)

        # (Q·K^T) / sqrt(d_k): divide by sqrt(Dim / H)
        # (Batch_Size, H, Seq_Len, Seq_Len) -> (Batch_Size, H, Seq_Len, Seq_Len)
        weight /= math.sqrt(self.d_head)

        # softmax(Q·K^T / sqrt(d_k))
        # (Batch_Size, H, Seq_Len, Seq_Len) -> (Batch_Size, H, Seq_Len, Seq_Len)
        weight = F.softmax(weight, dim=-1)

        # softmax(Q·K^T / sqrt(d_k)) · V
        # (Batch_Size, H, Seq_Len, Seq_Len) @ (Batch_Size, H, Seq_Len, Dim / H) -> (Batch_Size, H, Seq_Len, Dim / H)
        output = weight @ v

        # Transpose and reshape to combine the heads
        # (Batch_Size, H, Seq_Len, Dim / H) -> (Batch_Size, Seq_Len, H, Dim / H)
        output = output.transpose(1, 2)
        # (Batch_Size, Seq_Len, H, Dim / H) -> (Batch_Size, Seq_Len, Dim)
        output = output.reshape(input_shape)

        # softmax(Q·K^T / sqrt(d_k)) · V · W^O
        # (Batch_Size, Seq_Len, Dim) -> (Batch_Size, Seq_Len, Dim)
        output = self.out_proj(output)

        # (Batch_Size, Seq_Len, Dim)
        return output
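A quick usage sketch (sizes chosen to match a 64×64 latent with 320 features; purely illustrative, assuming torch and the class above are in scope):

attn = SelfAttention(n_heads=8, d_embed=320)
x = torch.randn(1, 64 * 64, 320)   # 4096 latent "tokens", 320 features each
out = attn(x)                      # causal_mask defaults to False for images
print(out.shape)                   # torch.Size([1, 4096, 320])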

1.4 CrossAttention

Cross-attention is a multi-head attention mechanism that builds an association between two different input sequences and passes information from one sequence to the other.
In Stable Diffusion, the cross-attention module links the input text to the generated image and controls how textual and visual information are fused. Roughly speaking, it lets the U-Net associate a particular region of the noise (latent) matrix with a particular piece of information in the text.

The cross-attention map is not only a weight measure of the conditional prompt at the corresponding positions in the generated image but also contains the semantic features of the conditional token.
The cross-attention map enables the diffusion model to locate/align the tokens of the prompt in the image area.


The figure below reflects the author's own understanding (corrections are welcome in the comments).

The figure below is from: Towards Understanding Cross and Self-Attention in Stable Diffusion for Text-Guided Image Editing, arXiv:2403.03431v1 [cs.CV], 6 Mar 2024.



The figure below is from WeThinkIn (Zhihu).

class CrossAttention(nn.Module):
    def __init__(self, n_heads, d_embed, d_cross, in_proj_bias=True, out_proj_bias=True):
        super().__init__()
        # W^Q
        self.q_proj = nn.Linear(d_embed, d_embed, bias=in_proj_bias)
        # W^K
        self.k_proj = nn.Linear(d_cross, d_embed, bias=in_proj_bias)
        # W^V
        self.v_proj = nn.Linear(d_cross, d_embed, bias=in_proj_bias)
        # W^O
        self.out_proj = nn.Linear(d_embed, d_embed, bias=out_proj_bias)
        self.n_heads = n_heads            # number of heads
        self.d_head = d_embed // n_heads  # dimension of each head

    def forward(self, x, y):
        # x (latent):  (Batch_Size, Seq_Len_Q, Dim_Q)
        # y (context): (Batch_Size, Seq_Len_KV, Dim_KV) = (Batch_Size, 77, 768), i.e. matrix C of size Seq_Len_KV x Dim_KV
        input_shape = x.shape
        batch_size, sequence_length, d_embed = input_shape
        # Divide each embedding of Q into multiple heads such that d_head * n_heads = Dim_Q
        interim_shape = (batch_size, -1, self.n_heads, self.d_head)

        # Q = X·W^Q
        # (Batch_Size, Seq_Len_Q, Dim_Q) -> (Batch_Size, Seq_Len_Q, Dim_Q)
        q = self.q_proj(x)
        # K = Y·W^K
        # (Batch_Size, Seq_Len_KV, Dim_KV) -> (Batch_Size, Seq_Len_KV, Dim_Q)
        k = self.k_proj(y)
        # V = Y·W^V
        # (Batch_Size, Seq_Len_KV, Dim_KV) -> (Batch_Size, Seq_Len_KV, Dim_Q)
        v = self.v_proj(y)

        # Reshape into heads and swap dimensions 1 and 2
        # (Batch_Size, Seq_Len_Q, Dim_Q) -> (Batch_Size, Seq_Len_Q, H, Dim_Q / H) -> (Batch_Size, H, Seq_Len_Q, Dim_Q / H)
        q = q.view(interim_shape).transpose(1, 2)
        # (Batch_Size, Seq_Len_KV, Dim_Q) -> (Batch_Size, Seq_Len_KV, H, Dim_Q / H) -> (Batch_Size, H, Seq_Len_KV, Dim_Q / H)
        k = k.view(interim_shape).transpose(1, 2)
        v = v.view(interim_shape).transpose(1, 2)

        # Q·K^T
        # (Batch_Size, H, Seq_Len_Q, Dim_Q / H) @ (Batch_Size, H, Dim_Q / H, Seq_Len_KV) -> (Batch_Size, H, Seq_Len_Q, Seq_Len_KV)
        weight = q @ k.transpose(-1, -2)

        # (Q·K^T) / sqrt(d_k)
        # (Batch_Size, H, Seq_Len_Q, Seq_Len_KV)
        weight /= math.sqrt(self.d_head)

        # softmax(Q·K^T / sqrt(d_k))
        # (Batch_Size, H, Seq_Len_Q, Seq_Len_KV)
        weight = F.softmax(weight, dim=-1)

        # softmax(Q·K^T / sqrt(d_k)) · V
        # (Batch_Size, H, Seq_Len_Q, Seq_Len_KV) @ (Batch_Size, H, Seq_Len_KV, Dim_Q / H) -> (Batch_Size, H, Seq_Len_Q, Dim_Q / H)
        output = weight @ v

        # Swap dimensions 1 and 2 and make the tensor contiguous in memory
        # (after a transpose the tensor may not be stored contiguously, and view requires contiguous storage)
        # (Batch_Size, H, Seq_Len_Q, Dim_Q / H) -> (Batch_Size, Seq_Len_Q, H, Dim_Q / H)
        output = output.transpose(1, 2).contiguous()
        # (Batch_Size, Seq_Len_Q, H, Dim_Q / H) -> (Batch_Size, Seq_Len_Q, Dim_Q)
        output = output.view(input_shape)

        # softmax(Q·K^T / sqrt(d_k)) · V · W^O
        # (Batch_Size, Seq_Len_Q, Dim_Q) -> (Batch_Size, Seq_Len_Q, Dim_Q)
        output = self.out_proj(output)

        # (Batch_Size, Seq_Len_Q, Dim_Q)
        return output
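A quick usage sketch (illustrative sizes, assuming torch and the class above are in scope). The queries come from the latent, the keys and values from the 77-token text embedding; the output keeps the query sequence length:

xattn = CrossAttention(n_heads=8, d_embed=320, d_cross=768)
x = torch.randn(1, 64 * 64, 320)   # latent tokens (queries)
y = torch.randn(1, 77, 768)        # CLIP text embedding (keys and values)
out = xattn(x, y)
print(out.shape)                   # torch.Size([1, 4096, 320])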

1.5 AttentionBlock (SelfAttention+CrossAttention)

The figure below is adapted from WeThinkIn (Zhihu).

class UNET_AttentionBlock(nn.Module):
    def __init__(self, n_head: int, n_embd: int, d_context=768):
        super().__init__()
        channels = n_head * n_embd
        # GN
        self.groupnorm = nn.GroupNorm(32, channels, eps=1e-6)
        # Conv
        self.conv_input = nn.Conv2d(channels, channels, kernel_size=1, padding=0)
        # LN
        self.layernorm_1 = nn.LayerNorm(channels)
        # SelfAttention
        self.attention_1 = SelfAttention(n_head, channels, in_proj_bias=False)
        # LN
        self.layernorm_2 = nn.LayerNorm(channels)
        # CrossAttention
        self.attention_2 = CrossAttention(n_head, channels, d_context, in_proj_bias=False)
        # LN
        self.layernorm_3 = nn.LayerNorm(channels)
        # GeGLU
        self.linear_geglu_1 = nn.Linear(channels, 4 * channels * 2)
        self.linear_geglu_2 = nn.Linear(4 * channels, channels)
        # Conv
        self.conv_output = nn.Conv2d(channels, channels, kernel_size=1, padding=0)

    def forward(self, x, context):
        # x: (Batch_Size, Features, Height, Width)
        # context (text embedding): (Batch_Size, Seq_Len, Dim)
        residue_long = x

        # GN
        # (Batch_Size, Features, Height, Width) -> (Batch_Size, Features, Height, Width)
        x = self.groupnorm(x)

        # Conv
        # (Batch_Size, Features, Height, Width) -> (Batch_Size, Features, Height, Width)
        x = self.conv_input(x)

        n, c, h, w = x.shape
        # (Batch_Size, Features, Height, Width) -> (Batch_Size, Features, Height * Width)
        x = x.view((n, c, h * w))
        # (Batch_Size, Features, Height * Width) -> (Batch_Size, Height * Width, Features)
        x = x.transpose(-1, -2)

        # Basic Transformer Block
        # Normalization + Self-Attention with skip connection
        # (Batch_Size, Height * Width, Features)
        residue_short = x

        ## LN_1
        # (Batch_Size, Height * Width, Features) -> (Batch_Size, Height * Width, Features)
        x = self.layernorm_1(x)
        ## Self-Attention
        x = self.attention_1(x)
        ## Skip connection
        x += residue_short

        residue_short = x

        # Normalization + Cross-Attention with skip connection
        ## LN_2
        x = self.layernorm_2(x)
        ## Cross-Attention with the text context
        x = self.attention_2(x, context)
        ## Skip connection
        x += residue_short

        residue_short = x

        # Normalization + FFN with GeGLU and skip connection
        ## LN_3
        x = self.layernorm_3(x)

        ## Feed-forward with GeGLU
        # GeGLU as implemented in the original code:
        # https://github.com/CompVis/stable-diffusion/blob/21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/ldm/modules/attention.py#L37C10-L37C10
        # (Batch_Size, Height * Width, Features) -> two tensors of shape (Batch_Size, Height * Width, Features * 4)
        x, gate = self.linear_geglu_1(x).chunk(2, dim=-1)
        # Element-wise product: (Batch_Size, Height * Width, Features * 4)
        x = x * F.gelu(gate)
        # (Batch_Size, Height * Width, Features * 4) -> (Batch_Size, Height * Width, Features)
        x = self.linear_geglu_2(x)

        ## Skip connection
        x += residue_short

        # (Batch_Size, Height * Width, Features) -> (Batch_Size, Features, Height * Width)
        x = x.transpose(-1, -2)
        # (Batch_Size, Features, Height * Width) -> (Batch_Size, Features, Height, Width)
        x = x.view((n, c, h, w))

        # Final Conv + skip connection between the block's input and output
        # (Batch_Size, Features, Height, Width) + (Batch_Size, Features, Height, Width) -> (Batch_Size, Features, Height, Width)
        return self.conv_output(x) + residue_long
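A quick shape check (illustrative sizes, assuming torch and the classes above are in scope). The block keeps the spatial shape of the feature map:

block = UNET_AttentionBlock(8, 40)     # channels = 8 * 40 = 320
x = torch.randn(1, 320, 64, 64)        # latent feature map
context = torch.randn(1, 77, 768)      # text embedding
print(block(x, context).shape)         # torch.Size([1, 320, 64, 64])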

1.6 Upsample

class Upsample(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x):
        # (Batch_Size, Features, Height, Width) -> (Batch_Size, Features, Height * 2, Width * 2)
        # Upsample with nearest-neighbour interpolation
        x = F.interpolate(x, scale_factor=2, mode='nearest')
        return self.conv(x)

1.7 SwitchSequential

# Define a custom sequential container class named SwitchSequential
# Inherits from nn.Sequential, which is a container module from PyTorch
class SwitchSequential(nn.Sequential):
    # Define the forward method, which specifies how the input data flows through the layers
    # x: the input tensor
    # context: the text embedding, consumed by the attention blocks
    # time: the timestep embedding, consumed by the residual blocks
    def forward(self, x, context, time):
        for layer in self:
            # If the current layer is a UNET_AttentionBlock,
            # pass the input tensor and the context through it
            if isinstance(layer, UNET_AttentionBlock):
                x = layer(x, context)
            # If the current layer is a UNET_ResidualBlock,
            # pass the input tensor and the time embedding through it
            elif isinstance(layer, UNET_ResidualBlock):
                x = layer(x, time)
            # For all other layer types, simply pass the input tensor through
            else:
                x = layer(x)
        return x
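A quick usage sketch (illustrative sizes, assuming torch and the classes above are in scope). SwitchSequential routes context to the attention block and time to the residual block automatically:

block = SwitchSequential(UNET_ResidualBlock(320, 640), UNET_AttentionBlock(8, 80))
x = torch.randn(1, 320, 32, 32)
context = torch.randn(1, 77, 768)
time = torch.randn(1, 1280)
print(block(x, context, time).shape)   # torch.Size([1, 640, 32, 32])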

1.8 UNET

class UNET(nn.Module):
    def __init__(self):
        super().__init__()

        # Encoder
        self.encoders = nn.ModuleList([
            # (Batch_Size, 4, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8)
            SwitchSequential(nn.Conv2d(4, 320, kernel_size=3, padding=1)),

            # CrossAttentionDownBlock2d_1 (320 channels)
            ## ResnetBlock + AttentionBlock
            # (Batch_Size, 320, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8)
            SwitchSequential(UNET_ResidualBlock(320, 320), UNET_AttentionBlock(8, 40)),
            ## ResnetBlock + AttentionBlock
            # (Batch_Size, 320, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8)
            SwitchSequential(UNET_ResidualBlock(320, 320), UNET_AttentionBlock(8, 40)),
            ## Downsample2D
            # (Batch_Size, 320, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 16, Width / 16)
            SwitchSequential(nn.Conv2d(320, 320, kernel_size=3, stride=2, padding=1)),

            # CrossAttentionDownBlock2d_2 (640 channels)
            ## ResnetBlock + AttentionBlock
            # (Batch_Size, 320, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 16, Width / 16)
            SwitchSequential(UNET_ResidualBlock(320, 640), UNET_AttentionBlock(8, 80)),
            ## ResnetBlock + AttentionBlock
            # (Batch_Size, 640, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 16, Width / 16)
            SwitchSequential(UNET_ResidualBlock(640, 640), UNET_AttentionBlock(8, 80)),
            ## Downsample2D
            # (Batch_Size, 640, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 32, Width / 32)
            SwitchSequential(nn.Conv2d(640, 640, kernel_size=3, stride=2, padding=1)),

            # CrossAttentionDownBlock2d_3 (1280 channels)
            ## ResnetBlock + AttentionBlock
            # (Batch_Size, 640, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 32, Width / 32)
            SwitchSequential(UNET_ResidualBlock(640, 1280), UNET_AttentionBlock(8, 160)),
            ## ResnetBlock + AttentionBlock
            # (Batch_Size, 1280, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 32, Width / 32)
            SwitchSequential(UNET_ResidualBlock(1280, 1280), UNET_AttentionBlock(8, 160)),
            ## Downsample2D
            # (Batch_Size, 1280, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 64, Width / 64)
            SwitchSequential(nn.Conv2d(1280, 1280, kernel_size=3, stride=2, padding=1)),

            # DownBlock2D (ResnetBlock + ResnetBlock, no attention)
            # (Batch_Size, 1280, Height / 64, Width / 64) -> (Batch_Size, 1280, Height / 64, Width / 64)
            SwitchSequential(UNET_ResidualBlock(1280, 1280)),
            # (Batch_Size, 1280, Height / 64, Width / 64) -> (Batch_Size, 1280, Height / 64, Width / 64)
            SwitchSequential(UNET_ResidualBlock(1280, 1280)),
        ])

        # Bottleneck
        self.bottleneck = SwitchSequential(
            # (Batch_Size, 1280, Height / 64, Width / 64) -> (Batch_Size, 1280, Height / 64, Width / 64)
            UNET_ResidualBlock(1280, 1280),
            # (Batch_Size, 1280, Height / 64, Width / 64) -> (Batch_Size, 1280, Height / 64, Width / 64)
            UNET_AttentionBlock(8, 160),
            # (Batch_Size, 1280, Height / 64, Width / 64) -> (Batch_Size, 1280, Height / 64, Width / 64)
            UNET_ResidualBlock(1280, 1280),
        )

        # Decoder
        self.decoders = nn.ModuleList([
            # UpBlock2D
            ## ResnetBlock + ResnetBlock
            # (Batch_Size, 2560, Height / 64, Width / 64) -> (Batch_Size, 1280, Height / 64, Width / 64)
            SwitchSequential(UNET_ResidualBlock(2560, 1280)),
            # (Batch_Size, 2560, Height / 64, Width / 64) -> (Batch_Size, 1280, Height / 64, Width / 64)
            SwitchSequential(UNET_ResidualBlock(2560, 1280)),

            # CrossAttentionUpBlock2d_3 (1280 channels)
            ## ResnetBlock + Upsample
            # (Batch_Size, 2560, Height / 64, Width / 64) -> (Batch_Size, 1280, Height / 64, Width / 64) -> (Batch_Size, 1280, Height / 32, Width / 32)
            SwitchSequential(UNET_ResidualBlock(2560, 1280), Upsample(1280)),
            ## ResnetBlock + AttentionBlock
            # (Batch_Size, 2560, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 32, Width / 32)
            SwitchSequential(UNET_ResidualBlock(2560, 1280), UNET_AttentionBlock(8, 160)),
            ## ResnetBlock + AttentionBlock
            # (Batch_Size, 2560, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 32, Width / 32)
            SwitchSequential(UNET_ResidualBlock(2560, 1280), UNET_AttentionBlock(8, 160)),
            ## ResnetBlock + AttentionBlock + Upsample
            # (Batch_Size, 1920, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 16, Width / 16)
            SwitchSequential(UNET_ResidualBlock(1920, 1280), UNET_AttentionBlock(8, 160), Upsample(1280)),

            # CrossAttentionUpBlock2d_2 (640 channels)
            ## ResnetBlock + AttentionBlock
            # (Batch_Size, 1920, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 16, Width / 16)
            SwitchSequential(UNET_ResidualBlock(1920, 640), UNET_AttentionBlock(8, 80)),
            ## ResnetBlock + AttentionBlock
            # (Batch_Size, 1280, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 16, Width / 16)
            SwitchSequential(UNET_ResidualBlock(1280, 640), UNET_AttentionBlock(8, 80)),
            ## ResnetBlock + AttentionBlock + Upsample
            # (Batch_Size, 960, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 8, Width / 8)
            SwitchSequential(UNET_ResidualBlock(960, 640), UNET_AttentionBlock(8, 80), Upsample(640)),

            # CrossAttentionUpBlock2d_1 (320 channels)
            ## ResnetBlock + AttentionBlock
            # (Batch_Size, 960, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8)
            SwitchSequential(UNET_ResidualBlock(960, 320), UNET_AttentionBlock(8, 40)),
            ## ResnetBlock + AttentionBlock
            # (Batch_Size, 640, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8)
            SwitchSequential(UNET_ResidualBlock(640, 320), UNET_AttentionBlock(8, 40)),
            ## ResnetBlock + AttentionBlock
            # (Batch_Size, 640, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8)
            SwitchSequential(UNET_ResidualBlock(640, 320), UNET_AttentionBlock(8, 40)),
        ])

    def forward(self, x, context, time):
        # x: (Batch_Size, 4, Height / 8, Width / 8)
        # context: (Batch_Size, Seq_Len, Dim)
        # time: (1, 1280)
        skip_connections = []

        # Encoder
        for layers in self.encoders:
            x = layers(x, context, time)
            # Store each encoder output so the decoder can concatenate it back in (skip connection)
            skip_connections.append(x)

        # Bottleneck
        x = self.bottleneck(x, context, time)

        # Decoder
        for layers in self.decoders:
            # Since we always concatenate with the corresponding skip connection from the encoder,
            # the number of channels increases before the tensor is sent to the decoder's layer
            x = torch.cat((x, skip_connections.pop()), dim=1)
            x = layers(x, context, time)

        return x
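A quick sanity check on the decoder's input channels: each decoder stage first concatenates its current feature map with the most recent encoder output popped from skip_connections, so 2560 = 1280 + 1280 for the five stages that consume the 1280-channel skips, 1920 = 1280 + 640 where a 640-channel skip is popped, 1280 = 640 + 640, 960 = 640 + 320, and 640 = 320 + 320 for the final stages at the Height / 8 resolution.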

1.9 OutputLayer

class UNET_OutputLayer(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.groupnorm = nn.GroupNorm(32, in_channels)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x):
        # x: (Batch_Size, 320, Height / 8, Width / 8)
        # (Batch_Size, 320, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8)
        x = self.groupnorm(x)
        # (Batch_Size, 320, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8)
        x = F.silu(x)
        # (Batch_Size, 320, Height / 8, Width / 8) -> (Batch_Size, 4, Height / 8, Width / 8)
        x = self.conv(x)
        # (Batch_Size, 4, Height / 8, Width / 8)
        return x

1.10 Diffusion

class Diffusion(nn.Module):
    def __init__(self):
        super().__init__()
        self.time_embedding = TimeEmbedding(320)
        self.unet = UNET()
        self.final = UNET_OutputLayer(320, 4)

    def forward(self, latent, context, time):
        # latent: (Batch_Size, 4, Height / 8, Width / 8)
        # context: (Batch_Size, Seq_Len, Dim)
        # time: (1, 320)

        # (1, 320) -> (1, 1280)
        time = self.time_embedding(time)
        # (Batch_Size, 4, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8)
        output = self.unet(latent, context, time)
        # (Batch_Size, 320, Height / 8, Width / 8) -> (Batch_Size, 4, Height / 8, Width / 8)
        output = self.final(output)
        # (Batch_Size, 4, Height / 8, Width / 8)
        return output
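A hypothetical end-to-end shape check for a 512×512 image (so the latent is 64×64); this simply wires the pieces above together with random tensors and assumes all classes and torch are in scope:

model = Diffusion()
latent = torch.randn(1, 4, 64, 64)     # (Batch_Size, 4, Height / 8, Width / 8)
context = torch.randn(1, 77, 768)      # CLIP text embedding
time = torch.randn(1, 320)             # timestep embedding before TimeEmbedding expands it to (1, 1280)
out = model(latent, context, time)
print(out.shape)                       # torch.Size([1, 4, 64, 64]) -- the predicted noise in latent space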
