1、Resnet是什么?
Resnet是一种深度神经网络架构,被广泛用于计算机视觉任务,特别是图像分类。它是由微软研究院的研究员于2015年提出的,是深度学习领域的重要里程碑之一。
2、网络退化问题
理论上来讲,随着网络的层数的增加,网络能够进行更加复杂的特征提取,可以取得更好的结果。但是实验发现深度网络出现了退化问题,如下图所示。网络深度增加时,网络准确度出现饱和,之后甚至还快速下降。而且这种下降不是因为过拟合引起的,而是因为在适当的深度模型上添加更多的层会导致了更高的训练误差,从而使其下降。
图1 网络深度对比(来源:Resnet的论文)
当你使用深度神经网络进行训练时,网络层可以被看作是一系列的函数堆叠,每个函数代表一个网络层的操作,这里我们就记作。在反向传播过程中,梯度是通过链式法则逐层计算得出的。假设每个操作的梯度都小于1,因为多个小于1的数相乘可能会导致结果变得更小。在神经网络中,随着反向传播的逐层传递,梯度可能会逐渐变得非常小,甚至接近于零,这就是梯度消失问题。
而如果经过网络层操作后的输出值大于1,那么反向传播时梯度可能会相应地增大。这种情况下,梯度爆炸问题可能会出现。梯度爆炸问题指的是在深度神经网络中,梯度逐渐放大,导致底层网络的参数更新过大,甚至可能导致数值溢出。
3、残差结构
在ResNet提出之前,所有的神经网络都是通过卷积层和池化层的叠加组成的。所以,Resnet对后面计算机视觉的发展影响是巨大的。
图2 残差结构(来源:Resnet的论文)
它这里完成的一个很简单的过程,我先举一个例子:
想象一张经过神经网络处理后的低分辨率图像。为了提高图像的质量,我们引入了一个创新的思想:将原始高分辨率图像与低分辨率图像之间的差异提取出来,形成了一个残差图像。这个残差图像代表了低分辨率图像与目标高分辨率图像之间的差异或缺失的细节。
图3 残差图像
然后,我们将这个残差图像与低分辨率图像相加,得到一个结合了低分辨率信息和残差细节的新图像。这个新图像作为下一个神经网络层的输入,使网络能够同时利用原始低分辨率信息和残差细节信息进行更精确的学习。
图4 残差+低分辨率图像
通过这种方式,我们的神经网络能够逐步地从低分辨率图像中提取信息,并通过残差图像的相加操作将遗漏的细节加回来。这使得网络能够更有效地进行图像恢复或其他任务,提高了模型的性能和准确性。
我相信我已经成功表达了残差结构的思想和操作过程。其实这个思想也并非是resnet创新的,在我们过去的其他领域中早已有这种思想,ResNet将这一思想引入了计算机视觉领域,并在深度神经网络中的训练中取得了重要突破。这种创新在一定程度上解决了深层神经网络训练中的梯度消失和梯度爆炸问题,使得网络能够更深更准确地学习特征和表示。
4、Resnet网络结构
(1)对于相同的输出特征图尺寸,层具有相同数量的滤波器
(2)当feature map大小降低一半时,feature map的数量增加一倍【过滤器(可以看作是卷积核的集合)的数量增加一倍】,这保持了网络层的复杂度。然后通过步长为2的卷积层直接执行下采样。
网络结构具体如下图所示:
图5 左为VGG-19,中为34个参数层的简单网络,右为34个参数层的残差网络
ResNet相比普通网络每两层间增加了短路机制,这就形成了残差学习,其中实线表示快捷连接,虚线表示feature map数量发生了改变。
有两种情况需要考虑:
(1)输入和输出具有相同的维度时(对应实线部分):
直接使用恒等快捷连接
(2)维度增加(当快捷连接跨越两种尺寸的特征图时,它们执行时步长为2):
①快捷连接仍然执行恒等映射,额外填充零输入以增加维度。这样就不会引入额外的参数。
②用下面公式的投影快捷连接用于匹配维度(由1×1卷积完成)
论文中也提供了更详细的结构,如下图所示:
5、使用Pytorch实现Resnet
本来是按照论文手写的代码,但用的时候发现维度不匹配,用不了预训练权重,所以这里就照着pytorch源码进行了修改。
import torch
import torchvision
import torch.nn as nn
import torchsummary
from torch.hub import load_state_dict_from_urlmodel_urls = {"resnet18" : "https://download.pytorch.org/models/resnet18-f37072fd.pth","resnet34" : "https://download.pytorch.org/models/resnet34-b627a593.pth","resnet50" : "https://download.pytorch.org/models/resnet50-0676ba61.pth","resnet101" : "https://download.pytorch.org/models/resnet101-63fe2227.pth","resnet152" : "https://download.pytorch.org/models/resnet152-394f9c45.pth",
}cfgs = {"resnet18": [2, 2, 2, 2],"resnet34": [3, 4, 6, 3],"resnet50": [3, 4, 6, 3],"resnet101": [3, 4, 23, 3],"resnet152": [3, 8, 36, 3],
}def conv1x1(in_planes, out_planes, stride = 1):"""1x1 convolution"""return nn.Conv2d(in_planes, out_planes, kernel_size=(1,1), stride=(stride,stride), bias=False)
def conv3x3(in_planes, out_planes, stride= 1, groups=1, dilation=1):"""3x3 convolution with padding"""return nn.Conv2d(in_planes,out_planes,kernel_size=(3,3),stride=(stride,stride),padding=dilation,groups=groups,bias=False,dilation=(dilation,dilation))
class BasicBlock(nn.Module):expansion = 1def __init__(self,inplanes: int,planes,stride = 1,downsample = None,groups =1,base_width = 64,dilation= 1,norm_layer= None,):super(BasicBlock,self).__init__()if norm_layer is None:norm_layer = nn.BatchNorm2dif groups != 1 or base_width != 64:raise ValueError("BasicBlock only supports groups=1 and base_width=64")if dilation > 1:raise NotImplementedError("Dilation > 1 not supported in BasicBlock")self.conv1 = conv3x3(inplanes, planes, stride)self.bn1 = norm_layer(planes)self.relu = nn.ReLU(inplace=True)self.conv2 = conv3x3(planes, planes)self.bn2 = norm_layer(planes)self.downsample = downsampleself.stride = stridedef forward(self, x):identity = xout = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)if self.downsample is not None:identity = self.downsample(x)out += identityout = self.relu(out)return outclass Bottleneck(nn.Module):expansion = 4def __init__(self,inplanes,planes,stride = 1,downsample = None,groups = 1,base_width = 64,dilation = 1,norm_layer = None,):super(Bottleneck,self).__init__()if norm_layer is None:norm_layer = nn.BatchNorm2dwidth = int(planes * (base_width / 64.0)) * groupsself.conv1 = conv1x1(inplanes, width)self.bn1 = norm_layer(width)self.conv2 = conv3x3(width, width, stride, groups, dilation)self.bn2 = norm_layer(width)self.conv3 = conv1x1(width, planes * self.expansion)self.bn3 = norm_layer(planes * self.expansion)self.relu = nn.ReLU(inplace=True)self.downsample = downsampleself.stride = stridedef forward(self, x):identity = xout = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)out = self.relu(out)out = self.conv3(out)out = self.bn3(out)if self.downsample is not None:identity = self.downsample(x)out += identityout = self.relu(out)return outclass ResNet(nn.Module):def __init__(self,block,layers,in_channels=3,num_classes = 1000,zero_init_residual = False,groups = 1,width_per_group = 64,replace_stride_with_dilation = None,):super(ResNet,self).__init__()norm_layer = nn.BatchNorm2dself._norm_layer = norm_layerself.num=num_classesself.inplanes = 64self.dilation = 1self.groups = groupsself.base_width = width_per_groupself.conv1 = nn.Conv2d(in_channels, self.inplanes, kernel_size=(7,7), stride=(2,2), padding=3, bias=False)self.bn1 = norm_layer(self.inplanes)self.relu = nn.ReLU(inplace=True)self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)self.layer1 = self._make_layer(block, 64, layers[0])self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=False)self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=False)self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=False)self.avgpool = nn.AdaptiveAvgPool2d((1, 1))self.fc = nn.Linear(512 * block.expansion, self.num)for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):nn.init.constant_(m.weight, 1)nn.init.constant_(m.bias, 0)if zero_init_residual:for m in self.modules():if isinstance(m, Bottleneck) and m.bn3.weight is not None:nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type]elif isinstance(m, BasicBlock) and m.bn2.weight is not None:nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type]def _make_layer(self,block, planes, blocks, stride = 1,dilate = False,):norm_layer = self._norm_layerdownsample = Noneprevious_dilation = self.dilationif dilate:self.dilation *= stridestride = 1if stride != 1 or self.inplanes != planes * block.expansion:downsample = nn.Sequential(conv1x1(self.inplanes, planes * block.expansion, stride),norm_layer(planes * block.expansion),)layers = []layers.append(block(self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer))self.inplanes = planes * block.expansionfor _ in range(1, blocks):layers.append(block(self.inplanes,planes,groups=self.groups,base_width=self.base_width,dilation=self.dilation,norm_layer=norm_layer,))return nn.Sequential(*layers)def forward(self, x):x = self.conv1(x)x = self.bn1(x)x = self.relu(x)x = self.maxpool(x)x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)x = self.layer4(x)x = self.avgpool(x)x = torch.flatten(x, 1)x = self.fc(x)return xdef resnet(in_channels, num_classes, mode='resnet50', pretrained=False):if mode == "resnet18" or mode == "resnet34":block = BasicBlockelse:block = Bottleneckmodel = ResNet(block, cfgs[mode], in_channels=in_channels, num_classes=num_classes)if pretrained:state_dict = load_state_dict_from_url(model_urls[mode], model_dir='./model', progress=True) # 预训练模型地址if num_classes != 1000:num_new_classes = num_classesfc_weight = state_dict['fc.weight']fc_bias = state_dict['fc.bias']fc_weight_new = fc_weight[:num_new_classes, :]fc_bias_new = fc_bias[:num_new_classes]state_dict['fc.weight'] = fc_weight_newstate_dict['fc.bias'] = fc_bias_newmodel.load_state_dict(state_dict)return model
这种写法是按照先前VGG那样写的,这样有助于使用同一个model_urls和cfgs。
BasicBlock类中的init()函数是先定义网络架构,forward()的函数是前向传播,实现的功能就是残差块:
Bottleneck类是另一种blcok类型,同上,init()函数是预定义网络架构,forward函数是进行前向传播。该block中有三个卷积,分别是1x1,3x3,1x1,分别完成的功能就是维度压缩,卷积,恢复维度,所以bottleneck实现的功能就是对通道数进行压缩,再放大。
注意:这里的plane不再是输出的通道数,输出通道数应该就是plane*expansion,即4*plane。
- resnet18: BasicBlock, [2, 2, 2, 2]
- resnet34: BasicBlock, [3, 4, 6, 3]
- resnet50: Bottleneck, [3, 4, 6, 3]
- resnet101: Bottleneck, [3, 4, 23, 3]
- resnet152: Bottleneck, [3, 8, 36, 3]
这个后面的结构是作者自己挑的一个参数,所以不用管它为什么。BasicBlock主要用于resnet18和34,Bottleneck用于resnet50,101和152。