文章目录
- 1 残差结构回顾
- 2 LDM结构中的残差结构设计
- 2.1 组归一化GroupNorm层
- 2.2 激活函数层
- 2.3 卷积层
- 2.4 dropout层
- 3 代码实现
1 残差结构回顾
残差结构应该是非常重要的基础块之一了,你肯定会在各种各样的网络模型结构里看到残差结构,他是非常强大的~
残差结构往往可以抽象为如下的结构
为什么要设置残差模块?
因为在深度学习中,网络层级很多复杂,非常重要的一部分是前后的梯度信息流动,否则很容易发生梯度爆炸或梯度消失
想象一下我们开车去一个遥远的目的地。我们可以选择直接开到目的地,也可以选择在途中设置几个“中转站”,但是中转站多了有可能会丢失最终目的地的信息,中途有好玩的就被吸引了,所以我们有时候要直接开车前往目的地,避免一些信息的丢失或遗忘。
2 LDM结构中的残差结构设计
第一阶段的残差块是主要的部分
负责不断将送进来的图像进行特征提取~,是多个ResnetBlock前后堆叠起来形成的
- 其中中间部分的ResnetBlock是不改变特征图的通道和尺寸大小
- 最后一个ResnetBlock将通道数加倍
2.1 组归一化GroupNorm层
不是用的传统的批归一化BN,而是用的组归一化GN
是因为BN在Batch_size比较小的时候,表现很差,而我们在图像生成的实际任务中,由于分辨率比较大,所以Batch_size往往比较小
这时候GN的效果会更好
GN实现方式就是按照通道去分组,每组各自归一化,比如图例中的通道数是128,分为32组,那么每组就是4个通道进行归一化
像了解更多归一化知识可以看文章
全面解读Group Normalization
2.2 激活函数层
没有采用传统的ReLU什么的,而是采用了一个组合形式
这是作者尝试不同实验效果比较好的
其实启发我们在自己的平时网络设计过程中,也可以进行损失函数的调整
2.3 卷积层
卷积层就很熟悉了
如果是中间层的ResNetBlock我们要实现两点
- 维持前后通道数不变
- 维持前后尺寸大小不变
所以用了这样的设计
- 小卷积核3*3的大小,步长1,填充1
经过以后,尺寸大小可以维持不变
公式
特征图长或宽 − 卷积核尺寸 + 2 ∗ 填充尺寸 步长 + 1 \frac{特征图长或宽-卷积核尺寸+2*填充尺寸}{步长}+1 步长特征图长或宽−卷积核尺寸+2∗填充尺寸+1
如果是最后的ResNetBlock
则进行通道数的改变
2.4 dropout层
防止过拟合的经典操作,让部分神经元失活(变为0)组织信息传递,避免模型能力过强
最后再经过一个残差输出即可
3 代码实现
from torch import nn
import torchclass ResnetBlock(nn.Module):def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,dropout):super().__init__()self.in_channels = in_channelsout_channels = in_channels if out_channels is None else out_channelsself.out_channels = out_channelsself.use_conv_shortcut = conv_shortcutself.norm1 = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)self.conv1 = torch.nn.Conv2d(in_channels,out_channels,kernel_size=3,stride=1,padding=1)self.norm2 = torch.nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)self.dropout = torch.nn.Dropout(dropout)self.conv2 = torch.nn.Conv2d(out_channels,out_channels,kernel_size=3,stride=1,padding=1)if self.in_channels != self.out_channels:#如果分辨率改变则进行更小的卷积核if self.use_conv_shortcut:self.conv_shortcut = torch.nn.Conv2d(in_channels,out_channels,kernel_size=3,stride=1,padding=1)else:self.nin_shortcut = torch.nn.Conv2d(in_channels,out_channels,kernel_size=1,stride=1,padding=0)def forward(self, x):h = xh = self.norm1(h) #靠后面的时候有问题 shape [80, 256, 8, 8]h = h*torch.sigmoid(h)h = self.conv1(h)h = self.norm2(h)h = h*torch.sigmoid(h)h = self.dropout(h)h = self.conv2(h)if self.in_channels != self.out_channels:if self.use_conv_shortcut:x = self.conv_shortcut(x)else:x = self.nin_shortcut(x)return x+h
其中有几个特别的参数需要说明
- in_channels 就是输入ResNet的网络通道数,out_channels 就是输出ResNet块的网络通道数,注意后者默认是None,这样的话就不会改变通道数,适用于中间ResNet块部分,最后一层的ResNet块要给出out_channels,方便做通道更改
- use_conv_shortcut决定着最后一层的ResNet改变通道数的时候的方式,为True的话则是size=3的卷积核,而如果是false的话则是size为1的小卷积核
- dropout是失活比率
注:与源码相比,简化了time_emb相关的内容,因为在源码中这部分也没有排上用场