系列文章目录
9种经典图片分类卷积模型系列合集(推荐程度依次递减):
- Se_resnet50
- Resnet50
- Xception
- inceptionresnetv2
- resnext
- bninception
- shufflenetv2
- polynet
- vggm
ImageNet 预训练的 inceptionresnetv2 输出 1000 个类别;笔者在其基础上添加了一个 bottleneck 层和一个 head 层,使其可以针对自定义类别数进行训练。
源码
from __future__ import print_function, division, absolute_import
import torch
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
import os
import sys__all__ = ['InceptionResNetV2', 'inceptionresnetv2']pretrained_settings = {'inceptionresnetv2': {'imagenet': {'url': 'http://data.lip6.fr/cadene/pretrainedmodels/inceptionresnetv2-520b38e4.pth','input_space': 'RGB','input_size': [3, 299, 299],'input_range': [0, 1],'mean': [0.5, 0.5, 0.5],'std': [0.5, 0.5, 0.5],'num_classes': 1000},'imagenet+background': {'url': 'http://data.lip6.fr/cadene/pretrainedmodels/inceptionresnetv2-520b38e4.pth','input_space': 'RGB','input_size': [3, 299, 299],'input_range': [0, 1],'mean': [0.5, 0.5, 0.5],'std': [0.5, 0.5, 0.5],'num_classes': 1001}}
}def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):"""3x3 convolution with padding"""return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,padding=dilation, groups=groups, bias=False, dilation=dilation)def conv1x1(in_planes, out_planes, stride=1):"""1x1 convolution"""return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)class BasicBlock(nn.Module):expansion = 1__constants__ = ['downsample']def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,base_width=64, dilation=1, norm_layer=None):super(BasicBlock, self).__init__()if norm_layer is None:norm_layer = nn.BatchNorm2dif groups != 1 or base_width != 64:raise ValueError('BasicBlock only supports groups=1 and base_width=64')if dilation > 1:raise NotImplementedError("Dilation > 1 not supported in BasicBlock")# Both self.conv1 and self.downsample layers downsample the input when stride != 1self.conv1 = conv3x3(inplanes, planes, stride)self.bn1 = norm_layer(planes)self.relu = nn.ReLU(inplace=True)self.conv2 = conv3x3(planes, planes)self.bn2 = norm_layer(planes)self.downsample = downsampleself.stride = stridedef forward(self, x):identity = xout = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)if self.downsample is not None:identity = self.downsample(x)out += identityout = self.relu(out)return outclass Bottleneck(nn.Module):expansion = 4__constants__ = ['downsample']def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,base_width=64, dilation=1, norm_layer=None):super(Bottleneck, self).__init__()if norm_layer is None:norm_layer = nn.BatchNorm2dwidth = int(planes * (base_width / 64.)) * groupsself.conv1 = conv1x1(inplanes, width)self.bn1 = norm_layer(width)self.conv2 = conv3x3(width, width, stride, groups, dilation)self.bn2 = norm_layer(width)self.conv3 = conv1x1(width, planes * self.expansion)self.bn3 = norm_layer(planes * self.expansion)self.relu = 
nn.ReLU(inplace=True)self.downsample = downsampleself.stride = stridedef forward(self, x):identity = xout = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)out = self.relu(out)out = self.conv3(out)out = self.bn3(out)if self.downsample is not None:identity = self.downsample(x)out += identityout = self.relu(out)return outclass BasicConv2d(nn.Module):def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0):super(BasicConv2d, self).__init__()self.conv = nn.Conv2d(in_planes, out_planes,kernel_size=kernel_size, stride=stride,padding=padding, bias=False) # verify bias falseself.bn = nn.BatchNorm2d(out_planes,eps=0.001, # value found in tensorflowmomentum=0.1, # default pytorch valueaffine=True)self.relu = nn.ReLU(inplace=False)def forward(self, x):x = self.conv(x)x = self.bn(x)x = self.relu(x)return xclass Mixed_5b(nn.Module):def __init__(self):super(Mixed_5b, self).__init__()self.branch0 = BasicConv2d(192, 96, kernel_size=1, stride=1)self.branch1 = nn.Sequential(BasicConv2d(192, 48, kernel_size=1, stride=1),BasicConv2d(48, 64, kernel_size=5, stride=1, padding=2))self.branch2 = nn.Sequential(BasicConv2d(192, 64, kernel_size=1, stride=1),BasicConv2d(64, 96, kernel_size=3, stride=1, padding=1),BasicConv2d(96, 96, kernel_size=3, stride=1, padding=1))self.branch3 = nn.Sequential(nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False),BasicConv2d(192, 64, kernel_size=1, stride=1))def forward(self, x):x0 = self.branch0(x)x1 = self.branch1(x)x2 = self.branch2(x)x3 = self.branch3(x)out = torch.cat((x0, x1, x2, x3), 1)return outclass Block35(nn.Module):def __init__(self, scale=1.0):super(Block35, self).__init__()self.scale = scaleself.branch0 = BasicConv2d(320, 32, kernel_size=1, stride=1)self.branch1 = nn.Sequential(BasicConv2d(320, 32, kernel_size=1, stride=1),BasicConv2d(32, 32, kernel_size=3, stride=1, padding=1))self.branch2 = nn.Sequential(BasicConv2d(320, 32, kernel_size=1, stride=1),BasicConv2d(32, 48, 
kernel_size=3, stride=1, padding=1),BasicConv2d(48, 64, kernel_size=3, stride=1, padding=1))self.conv2d = nn.Conv2d(128, 320, kernel_size=1, stride=1)self.relu = nn.ReLU(inplace=False)def forward(self, x):x0 = self.branch0(x)x1 = self.branch1(x)x2 = self.branch2(x)out = torch.cat((x0, x1, x2), 1)out = self.conv2d(out)out = out * self.scale + xout = self.relu(out)return outclass Mixed_6a(nn.Module):def __init__(self):super(Mixed_6a, self).__init__()self.branch0 = BasicConv2d(320, 384, kernel_size=3, stride=2)self.branch1 = nn.Sequential(BasicConv2d(320, 256, kernel_size=1, stride=1),BasicConv2d(256, 256, kernel_size=3, stride=1, padding=1),BasicConv2d(256, 384, kernel_size=3, stride=2))self.branch2 = nn.MaxPool2d(3, stride=2)def forward(self, x):x0 = self.branch0(x)x1 = self.branch1(x)x2 = self.branch2(x)out = torch.cat((x0, x1, x2), 1)return outclass Block17(nn.Module):def __init__(self, scale=1.0):super(Block17, self).__init__()self.scale = scaleself.branch0 = BasicConv2d(1088, 192, kernel_size=1, stride=1)self.branch1 = nn.Sequential(BasicConv2d(1088, 128, kernel_size=1, stride=1),BasicConv2d(128, 160, kernel_size=(1,7), stride=1, padding=(0,3)),BasicConv2d(160, 192, kernel_size=(7,1), stride=1, padding=(3,0)))self.conv2d = nn.Conv2d(384, 1088, kernel_size=1, stride=1)self.relu = nn.ReLU(inplace=False)def forward(self, x):x0 = self.branch0(x)x1 = self.branch1(x)out = torch.cat((x0, x1), 1)out = self.conv2d(out)out = out * self.scale + xout = self.relu(out)return outclass Mixed_7a(nn.Module):def __init__(self):super(Mixed_7a, self).__init__()self.branch0 = nn.Sequential(BasicConv2d(1088, 256, kernel_size=1, stride=1),BasicConv2d(256, 384, kernel_size=3, stride=2))self.branch1 = nn.Sequential(BasicConv2d(1088, 256, kernel_size=1, stride=1),BasicConv2d(256, 288, kernel_size=3, stride=2))self.branch2 = nn.Sequential(BasicConv2d(1088, 256, kernel_size=1, stride=1),BasicConv2d(256, 288, kernel_size=3, stride=1, padding=1),BasicConv2d(288, 320, kernel_size=3, 
stride=2))self.branch3 = nn.MaxPool2d(3, stride=2)def forward(self, x):x0 = self.branch0(x)x1 = self.branch1(x)x2 = self.branch2(x)x3 = self.branch3(x)out = torch.cat((x0, x1, x2, x3), 1)return outclass Block8(nn.Module):def __init__(self, scale=1.0, noReLU=False):super(Block8, self).__init__()self.scale = scaleself.noReLU = noReLUself.branch0 = BasicConv2d(2080, 192, kernel_size=1, stride=1)self.branch1 = nn.Sequential(BasicConv2d(2080, 192, kernel_size=1, stride=1),BasicConv2d(192, 224, kernel_size=(1,3), stride=1, padding=(0,1)),BasicConv2d(224, 256, kernel_size=(3,1), stride=1, padding=(1,0)))self.conv2d = nn.Conv2d(448, 2080, kernel_size=1, stride=1)if not self.noReLU:self.relu = nn.ReLU(inplace=False)def forward(self, x):x0 = self.branch0(x)x1 = self.branch1(x)out = torch.cat((x0, x1), 1)out = self.conv2d(out)out = out * self.scale + xif not self.noReLU:out = self.relu(out)return outclass InceptionResNetV2(nn.Module):def __init__(self, num_classes=1001, zero_init_residual=False):super(InceptionResNetV2, self).__init__()# Special attributsself.input_space = Noneself.input_size = (299, 299, 3)self.mean = Noneself.std = None# Modulesself.conv2d_1a = BasicConv2d(3, 32, kernel_size=3, stride=2)self.conv2d_2a = BasicConv2d(32, 32, kernel_size=3, stride=1)self.conv2d_2b = BasicConv2d(32, 64, kernel_size=3, stride=1, padding=1)self.maxpool_3a = nn.MaxPool2d(3, stride=2)self.conv2d_3b = BasicConv2d(64, 80, kernel_size=1, stride=1)self.conv2d_4a = BasicConv2d(80, 192, kernel_size=3, stride=1)self.maxpool_5a = nn.MaxPool2d(3, stride=2)self.mixed_5b = Mixed_5b()self.repeat = nn.Sequential(Block35(scale=0.17),Block35(scale=0.17),Block35(scale=0.17),Block35(scale=0.17),Block35(scale=0.17),Block35(scale=0.17),Block35(scale=0.17),Block35(scale=0.17),Block35(scale=0.17),Block35(scale=0.17))self.mixed_6a = Mixed_6a()self.repeat_1 = 
nn.Sequential(Block17(scale=0.10),Block17(scale=0.10),Block17(scale=0.10),Block17(scale=0.10),Block17(scale=0.10),Block17(scale=0.10),Block17(scale=0.10),Block17(scale=0.10),Block17(scale=0.10),Block17(scale=0.10),Block17(scale=0.10),Block17(scale=0.10),Block17(scale=0.10),Block17(scale=0.10),Block17(scale=0.10),Block17(scale=0.10),Block17(scale=0.10),Block17(scale=0.10),Block17(scale=0.10),Block17(scale=0.10))self.mixed_7a = Mixed_7a()self.repeat_2 = nn.Sequential(Block8(scale=0.20),Block8(scale=0.20),Block8(scale=0.20),Block8(scale=0.20),Block8(scale=0.20),Block8(scale=0.20),Block8(scale=0.20),Block8(scale=0.20),Block8(scale=0.20))self.block8 = Block8(noReLU=True)self.conv2d_7b = BasicConv2d(2080, 1536, kernel_size=1, stride=1)self.avgpool_1a = nn.AvgPool2d(5, count_include_pad=False)self.bottleneck = nn.Sequential(nn.Linear(1536, 512),nn.BatchNorm1d(512),nn.ReLU(),nn.Dropout(0.5))self.bottleneck[0].weight.data.normal_(0, 0.005)self.bottleneck[0].bias.data.fill_(0.1)self.head = nn.Sequential(nn.Linear(512, 512),nn.ReLU(),nn.Dropout(0.5),nn.Linear(512, num_classes))# self.fc = nn.Linear(512, num_classes)for dep in range(2):self.head[dep * 3].weight.data.normal_(0, 0.01)self.head[dep * 3].bias.data.fill_(0.0)for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):nn.init.constant_(m.weight, 1)nn.init.constant_(m.bias, 0)if zero_init_residual:for m in self.modules():if isinstance(m, Bottleneck):nn.init.constant_(m.bn3.weight, 0)elif isinstance(m, BasicBlock):nn.init.constant_(m.bn2.weight, 0)def features(self, input):x = self.conv2d_1a(input)x = self.conv2d_2a(x)x = self.conv2d_2b(x)x = self.maxpool_3a(x)x = self.conv2d_3b(x)x = self.conv2d_4a(x)x = self.maxpool_5a(x)x = self.mixed_5b(x)x = self.repeat(x)x = self.mixed_6a(x)x = self.repeat_1(x)x = self.mixed_7a(x)x = self.repeat_2(x)x = self.block8(x)x = self.conv2d_7b(x)return xdef logits(self, 
features):x = self.avgpool_1a(features)# print("x1.size={}".format(x.shape))x = x.view(x.size(0), -1)# print("x2.size={}".format(x.shape))x = self.bottleneck(x)x = self.head(x)# x = self.last_linear(x)return xdef forward(self, input):x = self.features(input)# print("x0.size={}".format(x.shape))x = self.logits(x)return xdef inceptionresnetv2(num_classes=1000, pretrained='imagenet'):r"""InceptionResNetV2 model architecture from the`"InceptionV4, Inception-ResNet..." <https://arxiv.org/abs/1602.07261>`_ paper."""if pretrained:pretrained = 'imagenet+background'num_classes_hat = 1001settings = pretrained_settings['inceptionresnetv2'][pretrained]# print(settings)# print('num=%d\n',num_classes)# assert num_classes == settings['num_classes'], \# "num_classes should be {}, but is {}".format(settings['num_classes'], num_classes)# both 'imagenet'&'imagenet+background' are loaded from same parametersmodel = InceptionResNetV2(num_classes=num_classes_hat)model.load_state_dict(model_zoo.load_url(settings['url']), strict=False)# if pretrained == 'imagenet+background':# # print("yes")# # model.last_linear = nn.Linear(1536, num_classes).cuda()# new_last_linear = nn.Linear(1536, num_classes).cuda()# new_last_linear.weight.data = model.last_linear.weight.data[1:]# new_last_linear.bias.data = model.last_linear.bias.data[1:]# model.last_linear = new_last_linearmodel.input_space = settings['input_space']model.input_size = settings['input_size']model.input_range = settings['input_range']model.mean = settings['mean']model.std = settings['std']else:model = InceptionResNetV2(num_classes=num_classes)return model'''
TEST
Run this code with:
cd $HOME/pretrained-models.pytorch
python -m pretrainedmodels.inceptionresnetv2
'''
if __name__ == '__main__':
    # Smoke tests.  All but the first call download the pretrained checkpoint
    # over the network.  Note that ``assert model`` only checks the object is
    # truthy (any nn.Module is), so these verify construction, not accuracy.
    assert inceptionresnetv2(num_classes=10, pretrained=None)
    print('success')
    assert inceptionresnetv2(num_classes=1000, pretrained='imagenet')
    print('success')
    assert inceptionresnetv2(num_classes=1001, pretrained='imagenet+background')
    print('success')
    # fail
    assert inceptionresnetv2(num_classes=1001, pretrained='imagenet')