U-Net for text-to-image

1. U-Net for text-to-image

Notes based on:
1. hkproj/pytorch-stable-diffusion
2. Understanding U-Net: A Comprehensive Tutorial
3. Deep Dive into Self-Attention by Hand
4. Towards Understanding Cross and Self-Attention in Stable Diffusion for Text-Guided Image Editing, arXiv:2403.03431v1 [cs.CV], 6 Mar 2024


Encoder

The encoder is responsible for capturing high-level features and reducing the spatial dimensions of the input image.
It consists of repeated blocks of convolutional layers followed by max-pooling layers, effectively downsampling the input.

Bottleneck

At the center of the U-Net is a bottleneck layer that captures the most critical features while maintaining spatial information.

Decoder

The decoder is responsible for upsampling the low-resolution feature maps to match the original input size.
It consists of repeated blocks of transposed convolutions (upsampling) followed by concatenation with corresponding feature maps from the contracting path.
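The Stable Diffusion U-Net built in the sections below keeps this contract-and-expand-with-skips shape, but replaces max-pooling with strided convolutions, uses nearest-neighbour upsampling followed by a convolution instead of transposed convolutions, and inserts residual and attention blocks at every level. For orientation only, here is a minimal, self-contained sketch of one level of a textbook U-Net (a generic illustration, not the Stable Diffusion variant):

import torch
from torch import nn

class TinyUNetLevel(nn.Module):
    # One encoder level, a bottleneck, and one decoder level of a textbook U-Net.
    def __init__(self, in_ch=1, base=64):
        super().__init__()
        self.enc = nn.Sequential(nn.Conv2d(in_ch, base, 3, padding=1), nn.ReLU())
        self.pool = nn.MaxPool2d(2)                                  # downsample H, W by 2
        self.bottleneck = nn.Sequential(nn.Conv2d(base, base * 2, 3, padding=1), nn.ReLU())
        self.up = nn.ConvTranspose2d(base * 2, base, 2, stride=2)    # upsample back to H, W
        self.dec = nn.Sequential(nn.Conv2d(base * 2, base, 3, padding=1), nn.ReLU())

    def forward(self, x):
        e = self.enc(x)                                   # (B, base, H, W)
        b = self.bottleneck(self.pool(e))                 # (B, 2*base, H/2, W/2)
        u = self.up(b)                                    # (B, base, H, W)
        return self.dec(torch.cat([u, e], dim=1))         # skip connection by concatenation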

1.1 TimeEmbedding

At every denoising step the U-Net is given a timestep t. Each timestep corresponds to a different noise level in the image: the larger t is, the noisier the image; the smaller t is, the cleaner it is.

import torch
from torch import nn
from torch.nn import functional as F
from attention import SelfAttention, CrossAttention

class TimeEmbedding(nn.Module):
    def __init__(self, n_embd):
        super().__init__()
        # First linear layer to expand the embedding size
        self.linear_1 = nn.Linear(n_embd, 4 * n_embd)
        # Second linear layer for further processing
        self.linear_2 = nn.Linear(4 * n_embd, 4 * n_embd)

    def forward(self, x):
        # x: (1, 320)
        # Expand embedding size: (1, 320) -> (1, 1280)
        x = self.linear_1(x)
        # Apply SiLU activation: (1, 1280) -> (1, 1280)
        x = F.silu(x)
        # Further processing: (1, 1280) -> (1, 1280)
        x = self.linear_2(x)
        return x
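For context, the (1, 320) input above is normally a sinusoidal encoding of the scalar timestep. The helper below follows the referenced hkproj/pytorch-stable-diffusion notes; the 160/10000 constants come from that code and should be treated as an assumption here. A quick shape check follows.

# Sketch: build the (1, 320) sinusoidal timestep vector and pass it through TimeEmbedding.
def get_time_embedding(timestep: int) -> torch.Tensor:
    freqs = torch.pow(10000, -torch.arange(start=0, end=160, dtype=torch.float32) / 160)  # (160,)
    x = torch.tensor([timestep], dtype=torch.float32)[:, None] * freqs[None]              # (1, 160)
    return torch.cat([torch.cos(x), torch.sin(x)], dim=-1)                                # (1, 320)

time_embedding = TimeEmbedding(n_embd=320)
t = get_time_embedding(timestep=500)      # (1, 320)
print(time_embedding(t).shape)            # torch.Size([1, 1280])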

1.2 ResnetBlock (ResNet + Time Embedding)

ResNetBlocks enable the model to learn richer and more complex feature representations by allowing multiple layers to focus on refining features without the risk of vanishing gradients.

The figure below is from the Zhihu account WeThinkIn.

Convolutional Layer: Applies a convolution operation to extract features.
Normalization: here Group Normalization (other variants use Batch or Layer Normalization) to stabilize and accelerate training.
Activation Function: Typically SiLU to introduce non-linearity.
Second Convolutional Layer: Another convolution to further process the features.
Normalization and Activation: Additional normalization and activation.
Residual Connection: Adds the input of the block to the output of the block.

import torch
from torch import nn
from torch.nn import functional as F
from attention import SelfAttention, CrossAttention

class UNET_ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, n_time=1280):
        super().__init__()
        # GroupNorm
        self.groupnorm_feature = nn.GroupNorm(32, in_channels)
        # Conv
        self.conv_feature = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        # Linear projection of the time embedding
        self.linear_time = nn.Linear(n_time, out_channels)
        # GroupNorm of the second GSC (GN + SiLU + Conv), applied to the merged result
        self.groupnorm_merged = nn.GroupNorm(32, out_channels)
        # Final Conv of the second GSC
        self.conv_merged = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        # If the input and output channel counts match, skip directly; otherwise use a 1x1 conv
        if in_channels == out_channels:
            self.residual_layer = nn.Identity()  # skip connection
        else:
            self.residual_layer = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)

    def forward(self, feature, time):
        # feature: (Batch_Size, In_Channels, Height, Width)
        # time: (1, 1280)

        # (1) Apply GSC (GN + SiLU + Conv) to the latent feature
        residue = feature
        # GN: (Batch_Size, In_Channels, Height, Width) -> (Batch_Size, In_Channels, Height, Width)
        feature = self.groupnorm_feature(feature)
        # SiLU: (Batch_Size, In_Channels, Height, Width) -> (Batch_Size, In_Channels, Height, Width)
        feature = F.silu(feature)
        # Conv: (Batch_Size, In_Channels, Height, Width) -> (Batch_Size, Out_Channels, Height, Width)
        feature = self.conv_feature(feature)

        # (2) Apply SiLU + Linear to the time embedding
        # (1, 1280) -> (1, 1280)
        time = F.silu(time)
        # (1, 1280) -> (1, Out_Channels)
        time = self.linear_time(time)

        # Merge (1) and (2): add width and height dimensions to time
        # (Batch_Size, Out_Channels, Height, Width) + (1, Out_Channels, 1, 1) -> (Batch_Size, Out_Channels, Height, Width)
        merged = feature + time.unsqueeze(-1).unsqueeze(-1)

        # Apply a second GSC (GN + SiLU + Conv) to the merged result
        # (Batch_Size, Out_Channels, Height, Width) -> (Batch_Size, Out_Channels, Height, Width)
        merged = self.groupnorm_merged(merged)
        merged = F.silu(merged)
        merged = self.conv_merged(merged)

        # Skip connection: add the (possibly 1x1-projected) original latent feature to the merged result
        # (Batch_Size, Out_Channels, Height, Width) + (Batch_Size, Out_Channels, Height, Width)
        return merged + self.residual_layer(residue)
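A quick shape check for the residual block, with random placeholders for the latent feature and the time embedding:

block = UNET_ResidualBlock(in_channels=320, out_channels=640)
feature = torch.randn(1, 320, 64, 64)    # (Batch_Size, In_Channels, Height / 8, Width / 8) for a 512x512 image
time = torch.randn(1, 1280)              # output of TimeEmbedding
print(block(feature, time).shape)        # torch.Size([1, 640, 64, 64])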

1.3 SelfAttention

The self-attention mechanism lets the model capture dependencies between pixels at different positions, so that it can better understand the semantics of the image.

The self-attention map clearly expresses the outline information of objects.
Self-attention maps play a crucial role in preserving the geometric and shape details of the source image during the transformation to the target image.


Self-Attention: the core of self-attention is to learn a weight distribution over the input sequence for every position, so that when processing the current position the model knows which other positions carry the most relevant information. Self-attention refers specifically to attention computed within a single sequence: every position computes attention weights against every other position in the same sequence.

The figure below reflects the author's own understanding (corrections are welcome in the comments).

Multi-Head Attention: to let the model attend to information from different positions simultaneously, the Transformer introduces multi-head attention. The input representation is split into several subspaces (heads), attention weights are computed independently within each head, and the per-head results are concatenated at the end. This allows the model to capture different contextual information in different representation subspaces. (Adapted from: "Self-Attention 和 Multi-Head Attention 的区别——附最通俗理解!!")

The figure below reflects the author's own understanding (corrections are welcome in the comments).

The figure below is from: Towards Understanding Cross and Self-Attention in Stable Diffusion for Text-Guided Image Editing, arXiv:2403.03431v1 [cs.CV], 6 Mar 2024.

The first component of the horse's self-attention map clearly expresses the outline information of the horse.



Parameters

def __init__(self, n_heads, d_embed, in_proj_bias=True, out_proj_bias=True)
(1) n_heads defines how many heads we split our embeddings into.
(2) d_embed defines the size of each embedding.
(3) in_proj_bias=True and out_proj_bias=True determine whether biases are added to the input and output projection layers, respectively.

def forward(self, x, causal_mask=False) # x: (Batch_Size, Seq_Len, Dim)
(1) Batch_Size refers to the number of samples processed together in one forward and backward pass of the model.
For instance, if you have 64 images in one batch, then your Batch_Size is 64.
(2) Seq_Len stands for Sequence Length, which, in the context of images, typically refers to the number of patches the image is divided into.
For an image of size H×W (Height×Width) and patch size P×P, the sequence length (Seq_Len) would be (H×W)/(P×P). For example, a 224×224 image divided into 16×16 patches results in 196 patches. (In the Stable Diffusion U-Net there is no explicit patchify step; see the reshape sketch after this list.)
(3) Dim refers to the dimension of the embeddings or feature vectors for each token (patch).
For example, if each 16x16 patch is embedded into a vector of dimension 768, then Dim is 768.
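In the Stable Diffusion U-Net the attention block simply flattens the feature map, so every spatial position becomes one token. A minimal sketch of that reshape (the 320×64×64 shape assumes a 512×512 input image; the tensor is a random placeholder):

import torch

feature_map = torch.randn(1, 320, 64, 64)                  # (Batch_Size, Features, Height / 8, Width / 8)
n, c, h, w = feature_map.shape
tokens = feature_map.view(n, c, h * w).transpose(-1, -2)   # flatten spatial dims, then swap to (B, Seq_Len, Dim)
print(tokens.shape)                                        # torch.Size([1, 4096, 320]) -> Seq_Len = 4096, Dim = 320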

in_proj() is a single matrix that combines the three weight matrices W_q, W_k and W_v.
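A small self-contained check (illustrative only; the dimensions are arbitrary) that one Linear(d_embed, 3 * d_embed) followed by chunk(3, dim=-1) is equivalent to applying the three projections separately:

import torch
from torch import nn

d_embed = 8
x = torch.randn(2, 5, d_embed)                        # (Batch_Size, Seq_Len, Dim)
in_proj = nn.Linear(d_embed, 3 * d_embed, bias=False)

# The combined weight stacks Wq, Wk and Wv along the output dimension.
wq, wk, wv = in_proj.weight.chunk(3, dim=0)
q, k, v = in_proj(x).chunk(3, dim=-1)

print(torch.allclose(q, x @ wq.T),                    # True
      torch.allclose(k, x @ wk.T),                    # True
      torch.allclose(v, x @ wv.T))                    # True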

import torch
from torch import nn
from torch.nn import functional as F
import math

class SelfAttention(nn.Module):
    def __init__(self, n_heads, d_embed, in_proj_bias=True, out_proj_bias=True):
        super().__init__()
        # This combines the Wq, Wk and Wv matrices into one matrix.
        # d_embed: the dimension of the input embeddings.
        # 3 * d_embed: the output dimension is three times the input dimension,
        # so the query, key and value vectors are produced in a single step.
        self.in_proj = nn.Linear(d_embed, 3 * d_embed, bias=in_proj_bias)
        # This one represents the Wo matrix
        self.out_proj = nn.Linear(d_embed, d_embed, bias=out_proj_bias)
        self.n_heads = n_heads
        self.d_head = d_embed // n_heads  # dimension of each head

    def forward(self, x, causal_mask=False):
        # x: (Batch_Size, Seq_Len, Dim)
        input_shape = x.shape
        # Unpack input dimensions
        batch_size, sequence_length, d_embed = input_shape
        # Shape used to split the embeddings into heads:
        # (Batch_Size, Seq_Len, H, Dim / H)
        interim_shape = (batch_size, sequence_length, self.n_heads, self.d_head)

        # Project the input into query, key and value matrices.
        # (Batch_Size, Seq_Len, Dim) -> (Batch_Size, Seq_Len, Dim * 3) -> 3 tensors of shape (Batch_Size, Seq_Len, Dim)
        # chunk(3, dim=-1) splits the result into 3 equal parts along the last dimension.
        q, k, v = self.in_proj(x).chunk(3, dim=-1)

        # Reshape and transpose to separate the heads.
        # (Batch_Size, Seq_Len, Dim) -> (Batch_Size, Seq_Len, H, Dim / H) -> (Batch_Size, H, Seq_Len, Dim / H)
        q = q.view(interim_shape).transpose(1, 2)
        k = k.view(interim_shape).transpose(1, 2)
        v = v.view(interim_shape).transpose(1, 2)

        # Q·K^T
        # (Batch_Size, H, Seq_Len, Dim / H) @ (Batch_Size, H, Dim / H, Seq_Len) -> (Batch_Size, H, Seq_Len, Seq_Len)
        weight = q @ k.transpose(-1, -2)  # @ is matrix multiplication

        if causal_mask:
            # Mask where the upper triangle (above the principal diagonal) is 1
            mask = torch.ones_like(weight, dtype=torch.bool).triu(1)
            # Fill the upper triangle with -inf (causal attention)
            weight.masked_fill_(mask, -torch.inf)

        # (Q·K^T) / sqrt(d_k), where d_k = Dim / H
        # (Batch_Size, H, Seq_Len, Seq_Len) -> (Batch_Size, H, Seq_Len, Seq_Len)
        weight /= math.sqrt(self.d_head)

        # softmax(Q·K^T / sqrt(d_k))
        weight = F.softmax(weight, dim=-1)

        # softmax(Q·K^T / sqrt(d_k)) · V
        # (Batch_Size, H, Seq_Len, Seq_Len) @ (Batch_Size, H, Seq_Len, Dim / H) -> (Batch_Size, H, Seq_Len, Dim / H)
        output = weight @ v

        # Transpose and reshape to merge the heads back together.
        # (Batch_Size, H, Seq_Len, Dim / H) -> (Batch_Size, Seq_Len, H, Dim / H) -> (Batch_Size, Seq_Len, Dim)
        output = output.transpose(1, 2)
        output = output.reshape(input_shape)

        # softmax(Q·K^T / sqrt(d_k)) · V · W^O
        # (Batch_Size, Seq_Len, Dim) -> (Batch_Size, Seq_Len, Dim)
        output = self.out_proj(output)
        return output
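A quick usage sketch with the shapes this attention sees at the U-Net's highest resolution (dummy tensors; assumes the SelfAttention class above):

attn = SelfAttention(n_heads=8, d_embed=320)
x = torch.randn(1, 4096, 320)        # (Batch_Size, Seq_Len, Dim) = (1, 64 * 64, 320)
print(attn(x).shape)                 # torch.Size([1, 4096, 320])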

1.4 CrossAttention

Cross-attention is a multi-head attention mechanism that builds associations between two different input sequences and passes information from one sequence to the other.
In Stable Diffusion, the cross-attention module links the input text with the generated image and controls how text information and image information are fused. Informally, it tells the U-Net which region of the noise matrix corresponds to which piece of information in the text.

The cross-attention map is not only a weight measure of the conditional prompt at the corresponding positions in the generated image, but also contains the semantic features of the conditional token.
The cross-attention map enables the diffusion model to locate/align the tokens of the prompt in the image area.


The figure below reflects the author's own understanding (corrections are welcome in the comments).

The figure below is from: Towards Understanding Cross and Self-Attention in Stable Diffusion for Text-Guided Image Editing, arXiv:2403.03431v1 [cs.CV], 6 Mar 2024.



The figure below is from the Zhihu account WeThinkIn.

class CrossAttention(nn.Module):
    def __init__(self, n_heads, d_embed, d_cross, in_proj_bias=True, out_proj_bias=True):
        super().__init__()
        # W^Q
        self.q_proj = nn.Linear(d_embed, d_embed, bias=in_proj_bias)
        # W^K
        self.k_proj = nn.Linear(d_cross, d_embed, bias=in_proj_bias)
        # W^V
        self.v_proj = nn.Linear(d_cross, d_embed, bias=in_proj_bias)
        # W^O
        self.out_proj = nn.Linear(d_embed, d_embed, bias=out_proj_bias)
        self.n_heads = n_heads               # number of heads
        self.d_head = d_embed // n_heads     # dimension of each head

    def forward(self, x, y):
        # x (latent):  (Batch_Size, Seq_Len_Q, Dim_Q)
        # y (context): (Batch_Size, Seq_Len_KV, Dim_KV) = (Batch_Size, 77, 768)
        input_shape = x.shape
        batch_size, sequence_length, d_embed = input_shape
        # Divide each embedding of Q into multiple heads such that d_head * n_heads = Dim_Q
        interim_shape = (batch_size, -1, self.n_heads, self.d_head)

        # Q = X·W^Q: (Batch_Size, Seq_Len_Q, Dim_Q) -> (Batch_Size, Seq_Len_Q, Dim_Q)
        q = self.q_proj(x)
        # K = Y·W^K: (Batch_Size, Seq_Len_KV, Dim_KV) -> (Batch_Size, Seq_Len_KV, Dim_Q)
        k = self.k_proj(y)
        # V = Y·W^V: (Batch_Size, Seq_Len_KV, Dim_KV) -> (Batch_Size, Seq_Len_KV, Dim_Q)
        v = self.v_proj(y)

        # Reshape and transpose (swap dimensions 1 and 2) to separate the heads.
        # (Batch_Size, Seq_Len_Q, Dim_Q) -> (Batch_Size, Seq_Len_Q, H, Dim_Q / H) -> (Batch_Size, H, Seq_Len_Q, Dim_Q / H)
        q = q.view(interim_shape).transpose(1, 2)
        # (Batch_Size, Seq_Len_KV, Dim_Q) -> (Batch_Size, Seq_Len_KV, H, Dim_Q / H) -> (Batch_Size, H, Seq_Len_KV, Dim_Q / H)
        k = k.view(interim_shape).transpose(1, 2)
        v = v.view(interim_shape).transpose(1, 2)

        # Q·K^T
        # (Batch_Size, H, Seq_Len_Q, Dim_Q / H) @ (Batch_Size, H, Dim_Q / H, Seq_Len_KV) -> (Batch_Size, H, Seq_Len_Q, Seq_Len_KV)
        weight = q @ k.transpose(-1, -2)

        # (Q·K^T) / sqrt(d_k)
        weight /= math.sqrt(self.d_head)

        # softmax(Q·K^T / sqrt(d_k))
        weight = F.softmax(weight, dim=-1)

        # softmax(Q·K^T / sqrt(d_k)) · V
        # (Batch_Size, H, Seq_Len_Q, Seq_Len_KV) @ (Batch_Size, H, Seq_Len_KV, Dim_Q / H) -> (Batch_Size, H, Seq_Len_Q, Dim_Q / H)
        output = weight @ v

        # Swap dimensions 1 and 2, then make the tensor contiguous in memory so it can be reshaped.
        # (Batch_Size, H, Seq_Len_Q, Dim_Q / H) -> (Batch_Size, Seq_Len_Q, H, Dim_Q / H)
        output = output.transpose(1, 2).contiguous()
        # (Batch_Size, Seq_Len_Q, H, Dim_Q / H) -> (Batch_Size, Seq_Len_Q, Dim_Q)
        output = output.view(input_shape)

        # softmax(Q·K^T / sqrt(d_k)) · V · W^O
        # (Batch_Size, Seq_Len_Q, Dim_Q) -> (Batch_Size, Seq_Len_Q, Dim_Q)
        output = self.out_proj(output)
        return output
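A usage sketch of cross-attention between the flattened latent (queries) and the text embedding (keys/values); the tensors are random placeholders:

cross = CrossAttention(n_heads=8, d_embed=320, d_cross=768)
latent = torch.randn(1, 4096, 320)    # (Batch_Size, Seq_Len_Q, Dim_Q) -- flattened 64x64 feature map
context = torch.randn(1, 77, 768)     # (Batch_Size, Seq_Len_KV, Dim_KV) -- CLIP text embedding
print(cross(latent, context).shape)   # torch.Size([1, 4096, 320])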

1.5 AttentionBlock (SelfAttention+CrossAttention)

The figure below is adapted from the Zhihu account WeThinkIn.

class UNET_AttentionBlock(nn.Module):
    def __init__(self, n_head: int, n_embd: int, d_context=768):
        super().__init__()
        channels = n_head * n_embd
        # GroupNorm
        self.groupnorm = nn.GroupNorm(32, channels, eps=1e-6)
        # Conv
        self.conv_input = nn.Conv2d(channels, channels, kernel_size=1, padding=0)
        # LayerNorm
        self.layernorm_1 = nn.LayerNorm(channels)
        # SelfAttention
        self.attention_1 = SelfAttention(n_head, channels, in_proj_bias=False)
        # LayerNorm
        self.layernorm_2 = nn.LayerNorm(channels)
        # CrossAttention
        self.attention_2 = CrossAttention(n_head, channels, d_context, in_proj_bias=False)
        # LayerNorm
        self.layernorm_3 = nn.LayerNorm(channels)
        # GeGLU
        self.linear_geglu_1 = nn.Linear(channels, 4 * channels * 2)
        self.linear_geglu_2 = nn.Linear(4 * channels, channels)
        # Conv
        self.conv_output = nn.Conv2d(channels, channels, kernel_size=1, padding=0)

    def forward(self, x, context):
        # x: (Batch_Size, Features, Height, Width)
        # context (text embedding): (Batch_Size, Seq_Len, Dim)
        residue_long = x

        # GroupNorm: (Batch_Size, Features, Height, Width) -> (Batch_Size, Features, Height, Width)
        x = self.groupnorm(x)
        # Conv: (Batch_Size, Features, Height, Width) -> (Batch_Size, Features, Height, Width)
        x = self.conv_input(x)

        n, c, h, w = x.shape
        # (Batch_Size, Features, Height, Width) -> (Batch_Size, Features, Height * Width)
        x = x.view((n, c, h * w))
        # (Batch_Size, Features, Height * Width) -> (Batch_Size, Height * Width, Features)
        x = x.transpose(-1, -2)

        # Basic Transformer Block
        # Normalization + self-attention with skip connection
        residue_short = x
        # LN_1: (Batch_Size, Height * Width, Features) -> (Batch_Size, Height * Width, Features)
        x = self.layernorm_1(x)
        # Self-attention
        x = self.attention_1(x)
        # Skip connection
        x += residue_short

        # Normalization + cross-attention with skip connection
        residue_short = x
        # LN_2
        x = self.layernorm_2(x)
        # Cross-attention with the text context
        x = self.attention_2(x, context)
        # Skip connection
        x += residue_short

        # Normalization + feed-forward (GeGLU) with skip connection
        residue_short = x
        # LN_3
        x = self.layernorm_3(x)

        # GeGLU as implemented in the original code:
        # https://github.com/CompVis/stable-diffusion/blob/21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/ldm/modules/attention.py#L37C10-L37C10
        # (Batch_Size, Height * Width, Features) -> two tensors of shape (Batch_Size, Height * Width, Features * 4)
        x, gate = self.linear_geglu_1(x).chunk(2, dim=-1)
        # Element-wise product: (Batch_Size, Height * Width, Features * 4)
        x = x * F.gelu(gate)
        # (Batch_Size, Height * Width, Features * 4) -> (Batch_Size, Height * Width, Features)
        x = self.linear_geglu_2(x)
        # Skip connection
        x += residue_short

        # (Batch_Size, Height * Width, Features) -> (Batch_Size, Features, Height * Width)
        x = x.transpose(-1, -2)
        # (Batch_Size, Features, Height * Width) -> (Batch_Size, Features, Height, Width)
        x = x.view((n, c, h, w))

        # Conv + final skip connection between the initial input and the output of the block
        return self.conv_output(x) + residue_long
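A shape check for the attention block (dummy tensors); note that n_head * n_embd must equal the channel count of the incoming feature map:

block = UNET_AttentionBlock(n_head=8, n_embd=40)   # channels = 8 * 40 = 320
x = torch.randn(1, 320, 64, 64)                    # (Batch_Size, Features, Height, Width)
context = torch.randn(1, 77, 768)                  # CLIP text embedding
print(block(x, context).shape)                     # torch.Size([1, 320, 64, 64])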

1.6 Upsample

class Upsample(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x):
        # (Batch_Size, Features, Height, Width) -> (Batch_Size, Features, Height * 2, Width * 2)
        # Nearest-neighbour upsampling, followed by a 3x3 convolution
        x = F.interpolate(x, scale_factor=2, mode='nearest')
        return self.conv(x)
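A shape check (dummy tensor): the channel count is unchanged while height and width double.

up = Upsample(channels=640)
x = torch.randn(1, 640, 32, 32)
print(up(x).shape)                   # torch.Size([1, 640, 64, 64])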

1.7 SwitchSequential

# A custom sequential container that inherits from nn.Sequential.
# It routes the extra arguments (context, time) only to the layers that need them.
class SwitchSequential(nn.Sequential):
    def forward(self, x, context, time):
        # x: the input tensor
        # context: the text embedding, used by the attention blocks
        # time: the time embedding, used by the residual blocks
        for layer in self:
            if isinstance(layer, UNET_AttentionBlock):
                # Attention blocks take the input tensor and the text context
                x = layer(x, context)
            elif isinstance(layer, UNET_ResidualBlock):
                # Residual blocks take the input tensor and the time embedding
                x = layer(x, time)
            else:
                # All other layers just take the input tensor
                x = layer(x)
        return x
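A usage sketch showing how one SwitchSequential stage routes (x, context, time) to the layers that need each argument (dummy tensors):

stage = SwitchSequential(UNET_ResidualBlock(320, 320), UNET_AttentionBlock(8, 40))
x = torch.randn(1, 320, 64, 64)
context = torch.randn(1, 77, 768)
time = torch.randn(1, 1280)
print(stage(x, context, time).shape)  # torch.Size([1, 320, 64, 64])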

1.8 Unet

class UNET(nn.Module):
    def __init__(self):
        super().__init__()

        # Encoder
        self.encoders = nn.ModuleList([
            # (Batch_Size, 4, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8)
            SwitchSequential(nn.Conv2d(4, 320, kernel_size=3, padding=1)),

            # CrossAttentionDownBlock2d_1 (320 channels)
            ## ResnetBlock + AttentionBlock
            # (Batch_Size, 320, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8)
            SwitchSequential(UNET_ResidualBlock(320, 320), UNET_AttentionBlock(8, 40)),
            ## ResnetBlock + AttentionBlock
            # (Batch_Size, 320, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8)
            SwitchSequential(UNET_ResidualBlock(320, 320), UNET_AttentionBlock(8, 40)),
            ## Downsample2D
            # (Batch_Size, 320, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 16, Width / 16)
            SwitchSequential(nn.Conv2d(320, 320, kernel_size=3, stride=2, padding=1)),

            # CrossAttentionDownBlock2d_2 (640 channels)
            ## ResnetBlock + AttentionBlock
            # (Batch_Size, 320, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 16, Width / 16)
            SwitchSequential(UNET_ResidualBlock(320, 640), UNET_AttentionBlock(8, 80)),
            ## ResnetBlock + AttentionBlock
            # (Batch_Size, 640, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 16, Width / 16)
            SwitchSequential(UNET_ResidualBlock(640, 640), UNET_AttentionBlock(8, 80)),
            ## Downsample2D
            # (Batch_Size, 640, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 32, Width / 32)
            SwitchSequential(nn.Conv2d(640, 640, kernel_size=3, stride=2, padding=1)),

            # CrossAttentionDownBlock2d_3 (1280 channels)
            ## ResnetBlock + AttentionBlock
            # (Batch_Size, 640, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 32, Width / 32)
            SwitchSequential(UNET_ResidualBlock(640, 1280), UNET_AttentionBlock(8, 160)),
            ## ResnetBlock + AttentionBlock
            # (Batch_Size, 1280, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 32, Width / 32)
            SwitchSequential(UNET_ResidualBlock(1280, 1280), UNET_AttentionBlock(8, 160)),
            ## Downsample2D
            # (Batch_Size, 1280, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 64, Width / 64)
            SwitchSequential(nn.Conv2d(1280, 1280, kernel_size=3, stride=2, padding=1)),

            # DownBlock2D (ResnetBlock + ResnetBlock, no attention)
            # (Batch_Size, 1280, Height / 64, Width / 64) -> (Batch_Size, 1280, Height / 64, Width / 64)
            SwitchSequential(UNET_ResidualBlock(1280, 1280)),
            # (Batch_Size, 1280, Height / 64, Width / 64) -> (Batch_Size, 1280, Height / 64, Width / 64)
            SwitchSequential(UNET_ResidualBlock(1280, 1280)),
        ])

        # Bottleneck
        self.bottleneck = SwitchSequential(
            # (Batch_Size, 1280, Height / 64, Width / 64) -> (Batch_Size, 1280, Height / 64, Width / 64)
            UNET_ResidualBlock(1280, 1280),
            UNET_AttentionBlock(8, 160),
            UNET_ResidualBlock(1280, 1280),
        )

        # Decoder
        self.decoders = nn.ModuleList([
            # UpBlock2D (ResnetBlock + ResnetBlock + ResnetBlock with Upsample)
            # (Batch_Size, 2560, Height / 64, Width / 64) -> (Batch_Size, 1280, Height / 64, Width / 64)
            SwitchSequential(UNET_ResidualBlock(2560, 1280)),
            # (Batch_Size, 2560, Height / 64, Width / 64) -> (Batch_Size, 1280, Height / 64, Width / 64)
            SwitchSequential(UNET_ResidualBlock(2560, 1280)),
            # (Batch_Size, 2560, Height / 64, Width / 64) -> (Batch_Size, 1280, Height / 64, Width / 64) -> (Batch_Size, 1280, Height / 32, Width / 32)
            SwitchSequential(UNET_ResidualBlock(2560, 1280), Upsample(1280)),

            # CrossAttentionUpBlock2d_3 (1280 channels)
            ## ResnetBlock + AttentionBlock
            # (Batch_Size, 2560, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 32, Width / 32)
            SwitchSequential(UNET_ResidualBlock(2560, 1280), UNET_AttentionBlock(8, 160)),
            ## ResnetBlock + AttentionBlock
            # (Batch_Size, 2560, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 32, Width / 32)
            SwitchSequential(UNET_ResidualBlock(2560, 1280), UNET_AttentionBlock(8, 160)),
            ## ResnetBlock + AttentionBlock + Upsample
            # (Batch_Size, 1920, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 16, Width / 16)
            SwitchSequential(UNET_ResidualBlock(1920, 1280), UNET_AttentionBlock(8, 160), Upsample(1280)),

            # CrossAttentionUpBlock2d_2 (640 channels)
            ## ResnetBlock + AttentionBlock
            # (Batch_Size, 1920, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 16, Width / 16)
            SwitchSequential(UNET_ResidualBlock(1920, 640), UNET_AttentionBlock(8, 80)),
            ## ResnetBlock + AttentionBlock
            # (Batch_Size, 1280, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 16, Width / 16)
            SwitchSequential(UNET_ResidualBlock(1280, 640), UNET_AttentionBlock(8, 80)),
            ## ResnetBlock + AttentionBlock + Upsample
            # (Batch_Size, 960, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 8, Width / 8)
            SwitchSequential(UNET_ResidualBlock(960, 640), UNET_AttentionBlock(8, 80), Upsample(640)),

            # CrossAttentionUpBlock2d_1 (320 channels)
            ## ResnetBlock + AttentionBlock
            # (Batch_Size, 960, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8)
            SwitchSequential(UNET_ResidualBlock(960, 320), UNET_AttentionBlock(8, 40)),
            ## ResnetBlock + AttentionBlock
            # (Batch_Size, 640, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8)
            SwitchSequential(UNET_ResidualBlock(640, 320), UNET_AttentionBlock(8, 40)),
            ## ResnetBlock + AttentionBlock
            # (Batch_Size, 640, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8)
            SwitchSequential(UNET_ResidualBlock(640, 320), UNET_AttentionBlock(8, 40)),
        ])

    def forward(self, x, context, time):
        # x: (Batch_Size, 4, Height / 8, Width / 8)
        # context: (Batch_Size, Seq_Len, Dim)
        # time: (1, 1280)
        skip_connections = []

        # Encoder
        for layers in self.encoders:
            x = layers(x, context, time)
            # Store every encoder output so the decoder can concatenate it as a skip connection
            skip_connections.append(x)

        # Bottleneck
        x = self.bottleneck(x, context, time)

        # Decoder
        for layers in self.decoders:
            # Always concatenate with the matching skip connection from the encoder,
            # so the number of channels increases before entering the decoder layer.
            x = torch.cat((x, skip_connections.pop()), dim=1)
            x = layers(x, context, time)

        return x
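A full forward-pass sketch on dummy tensors for a 512×512 image (64×64 latent). The U-Net itself ends with 320 channels; the reduction back to 4 channels is done by UNET_OutputLayer in the next section:

unet = UNET()
latent = torch.randn(1, 4, 64, 64)          # (Batch_Size, 4, Height / 8, Width / 8)
context = torch.randn(1, 77, 768)           # CLIP text embedding
time = torch.randn(1, 1280)                 # output of TimeEmbedding
print(unet(latent, context, time).shape)    # torch.Size([1, 320, 64, 64])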

1.9 OutputLayer

class UNET_OutputLayer(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.groupnorm = nn.GroupNorm(32, in_channels)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x):
        # x: (Batch_Size, 320, Height / 8, Width / 8)
        # GroupNorm: (Batch_Size, 320, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8)
        x = self.groupnorm(x)
        # SiLU
        x = F.silu(x)
        # Conv: (Batch_Size, 320, Height / 8, Width / 8) -> (Batch_Size, 4, Height / 8, Width / 8)
        x = self.conv(x)
        # (Batch_Size, 4, Height / 8, Width / 8)
        return x

1.10 Diffusion

class Diffusion(nn.Module):
    def __init__(self):
        super().__init__()
        self.time_embedding = TimeEmbedding(320)
        self.unet = UNET()
        self.final = UNET_OutputLayer(320, 4)

    def forward(self, latent, context, time):
        # latent: (Batch_Size, 4, Height / 8, Width / 8)
        # context: (Batch_Size, Seq_Len, Dim)
        # time: (1, 320)

        # (1, 320) -> (1, 1280)
        time = self.time_embedding(time)
        # (Batch, 4, Height / 8, Width / 8) -> (Batch, 320, Height / 8, Width / 8)
        output = self.unet(latent, context, time)
        # (Batch, 320, Height / 8, Width / 8) -> (Batch, 4, Height / 8, Width / 8)
        output = self.final(output)
        # (Batch, 4, Height / 8, Width / 8)
        return output
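An end-to-end sketch of one noise-prediction step (dummy inputs; get_time_embedding is the sinusoidal helper sketched in section 1.1):

model = Diffusion()
latent = torch.randn(1, 4, 64, 64)          # noisy latent (e.g. from the VAE encoder plus noise)
context = torch.randn(1, 77, 768)           # CLIP text embedding of the prompt
t = get_time_embedding(timestep=500)        # (1, 320)
print(model(latent, context, t).shape)      # torch.Size([1, 4, 64, 64]) -> predicted noise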
