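1. Define the feature-extraction network configurations
The layer configurations of the networks to be built are stored as lists in a dictionary, which makes them easy to reuse.
Each number is the number of convolution kernels (output channels) of a 3×3 convolution layer, and the character 'M' marks a max-pooling layer.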
cfgs = {
    'vgg11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'vgg13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'vgg16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'vgg19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}
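As a quick sanity check on the naming, the number in each key is the count of weight layers: the convolution layers listed in the configuration plus the three fully connected layers of the classifier defined later. A minimal sketch that counts them, assuming the lowercase keys used in the full code (count_weight_layers is a hypothetical helper):

```python
# Layer configurations copied from the article (lowercase keys).
cfgs = {
    "vgg11": [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    "vgg13": [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    "vgg16": [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    "vgg19": [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}

NUM_FC_LAYERS = 3  # the classifier has three nn.Linear layers


def count_weight_layers(cfg):
    """Count conv layers (integer entries) plus the fully connected layers."""
    num_conv = sum(1 for v in cfg if v != 'M')
    return num_conv + NUM_FC_LAYERS


for name, cfg in cfgs.items():
    print(name, count_weight_layers(cfg))
```

Every configuration also contains exactly five 'M' entries, so all four variants downsample the input by the same factor and can share one classifier.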
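2. Build the feature-extraction network
While iterating over the configuration list, an entry v == 'M' adds a max-pooling layer; its kernel size of 2 and stride of 2 are the same for every pooling layer in VGG.
Any number adds a convolution layer with that many output channels, followed by a ReLU activation.
Finally, the layer list is unpacked with * and passed to nn.Sequential as positional (non-keyword) arguments.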
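import torch.nn as nn


def make_features(cfg: list):
    layers = []
    in_channels = 3  # RGB input
    for v in cfg:
        if v == 'M':
            # 2x2 max pooling with stride 2 halves the spatial size
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            # 3x3 convolution; padding=1 keeps the spatial size unchanged
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
            layers += [conv2d, nn.ReLU(True)]
            in_channels = v
    return nn.Sequential(*layers)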
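Because each 3×3 convolution uses padding=1 and stride 1, it keeps the spatial size, while each 2×2 max-pool with stride 2 halves it. A minimal pure-Python sketch (no PyTorch needed) that traces a 224×224 input through a configuration, assuming the standard output-size formula floor((n + 2p - k) / s) + 1; the helper names are hypothetical:

```python
# Shape trace of the feature extractor, without PyTorch.
def conv3x3_out(size, kernel=3, padding=1, stride=1):
    # Standard output-size formula for a convolution.
    return (size + 2 * padding - kernel) // stride + 1


def maxpool_out(size, kernel=2, stride=2):
    # Same formula with no padding (nn.MaxPool2d defaults to padding=0).
    return (size - kernel) // stride + 1


def trace_features(cfg, height=224, width=224, in_channels=3):
    """Return (channels, height, width) after running cfg on one image."""
    c, h, w = in_channels, height, width
    for v in cfg:
        if v == 'M':
            h, w = maxpool_out(h), maxpool_out(w)
        else:
            c, h, w = v, conv3x3_out(h), conv3x3_out(w)
    return c, h, w


vgg16_cfg = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M',
             512, 512, 512, 'M', 512, 512, 512, 'M']
print(trace_features(vgg16_cfg))  # (512, 7, 7)
```

The 512×7×7 result is where the 512*7*7 input size of the classifier comes from; a different input resolution would change it.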
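3. Initialize the network
The constructor takes features (the feature extractor), class_num (the number of output classes) and init_weight (whether to initialize the weights).
The classifier chains fully connected layers with ReLU activations and uses dropout to reduce overfitting. Its first layer expects 512*7*7 inputs because, for a 224×224 image, the five 2×2 poolings reduce the feature map to 512 channels of size 7×7 (224 / 2^5 = 7).
If init_weight is True, the weight-initialization method is called.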
import torch.nn as nn


class VGG(nn.Module):
    def __init__(self, features, class_num=1000, init_weight=False):
        super(VGG, self).__init__()
        self.features = features  # convolutional feature extractor
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(512 * 7 * 7, 2048),  # assumes a 224x224 input
            nn.ReLU(True),
            nn.Dropout(p=0.5),
            nn.Linear(2048, 2048),
            nn.ReLU(True),
            nn.Linear(2048, class_num),
        )
        if init_weight:
            self._initialize_weights()
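The classifier here uses 2048 hidden units, whereas the original VGG paper uses 4096. A small sketch counting the classifier's parameters for both widths (Dropout and ReLU have none) shows how strongly the hidden width drives model size; linear_params and classifier_params are hypothetical helpers:

```python
# Parameter count of the classifier head (Dropout and ReLU have no parameters).
def linear_params(in_features, out_features, bias=True):
    # weight matrix plus optional bias vector, as in nn.Linear
    return in_features * out_features + (out_features if bias else 0)


def classifier_params(hidden, num_classes=1000, flat=512 * 7 * 7):
    return (linear_params(flat, hidden)
            + linear_params(hidden, hidden)
            + linear_params(hidden, num_classes))


print(classifier_params(2048))  # this article's head
print(classifier_params(4096))  # hidden width used in the original VGG paper
```

Most of the parameters sit in the first linear layer, whose input has 25088 features, so halving the hidden width roughly halves the head.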
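4. Weight-initialization method
This method iterates over every submodule of the network.
If the current module is a convolution layer, its weights are initialized with Xavier uniform initialization (nn.init.xavier_uniform_), and its bias, if present, is set to 0.
If the current module is a fully connected layer, its weights are initialized the same way and its bias is set to 0.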
# method of the VGG class
def _initialize_weights(self):
    for m in self.modules():  # iterate over every submodule
        if isinstance(m, nn.Conv2d):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            nn.init.constant_(m.bias, 0)
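nn.init.xavier_uniform_ draws each weight from a uniform distribution U(-a, a) with a = gain * sqrt(6 / (fan_in + fan_out)). For a convolution weight of shape (out_channels, in_channels, kH, kW), PyTorch takes fan_in = in_channels*kH*kW and fan_out = out_channels*kH*kW; for a linear weight of shape (out_features, in_features) the kernel factor is 1. A pure-Python sketch of that rule, where fans, xavier_uniform_bound and xavier_uniform_sample are hypothetical helpers, not PyTorch functions:

```python
import math
import random


def fans(shape):
    """fan_in / fan_out computed the way PyTorch does for a weight tensor.

    shape is (out_features, in_features) for nn.Linear or
    (out_channels, in_channels, kH, kW) for nn.Conv2d.
    """
    receptive_field = 1
    for d in shape[2:]:
        receptive_field *= d
    fan_in = shape[1] * receptive_field
    fan_out = shape[0] * receptive_field
    return fan_in, fan_out


def xavier_uniform_bound(shape, gain=1.0):
    fan_in, fan_out = fans(shape)
    return gain * math.sqrt(6.0 / (fan_in + fan_out))


def xavier_uniform_sample(shape, gain=1.0):
    # Flat list of weights drawn from U(-a, a); a stand-in for nn.init.xavier_uniform_.
    a = xavier_uniform_bound(shape, gain)
    n = 1
    for d in shape:
        n *= d
    return [random.uniform(-a, a) for _ in range(n)]


first_conv = (64, 3, 3, 3)  # weight shape of nn.Conv2d(3, 64, kernel_size=3)
print(fans(first_conv), xavier_uniform_bound(first_conv))
```

Layers with more inputs and outputs get a smaller bound, which keeps the variance of activations roughly constant from layer to layer.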
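5. Define the forward pass
x: the input image batch
features: the feature-extraction network
flatten: flattens the feature map. Dimension 0 is the batch dimension, so flattening starts from dimension 1.
The flattened result is passed through the classifier and returned. The method must be named forward, because calling the model (nn.Module.__call__) dispatches to forward.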
import torch


# method of the VGG class
def forward(self, x):
    x = self.features(x)               # N x 512 x 7 x 7 for a 224x224 input
    x = torch.flatten(x, start_dim=1)  # N x 25088
    x = self.classifier(x)
    return x
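Flattening from dimension 1 turns the (N, 512, 7, 7) feature map into an (N, 25088) matrix with one row per image, matching the 512*7*7 input features of the first linear layer. A shape-only sketch of what torch.flatten does (flatten_shape is a hypothetical helper, not part of PyTorch):

```python
from math import prod


def flatten_shape(shape, start_dim=1, end_dim=-1):
    """Shape produced by torch.flatten(x, start_dim, end_dim) for x of the given shape."""
    ndim = len(shape)
    start = start_dim % ndim  # allow negative dims, like PyTorch
    end = end_dim % ndim
    merged = prod(shape[start:end + 1])
    return tuple(shape[:start]) + (merged,) + tuple(shape[end + 1:])


# A batch of 8 feature maps from the extractor (N, C, H, W).
print(flatten_shape((8, 512, 7, 7)))               # (8, 25088)
print(flatten_shape((8, 512, 7, 7), start_dim=0))  # (200704,)
```

With start_dim=0 the images in the batch would be merged into a single vector, which is why the article flattens from dimension 1.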
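6. Instantiate the model
model_name selects the configuration to build: it is used as the key into the cfgs dictionary.
The VGG class is then instantiated with features produced by the make_features function; any extra keyword arguments (such as class_num or init_weight) are passed through to VGG.
Creating this object completes the construction of the VGG network.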
def vgg(model_name="vgg16", **kwargs):
    try:
        cfg = cfgs[model_name]
    except KeyError:
        print("warning: model name {} not in cfgs dict".format(model_name))
        raise
    model = VGG(make_features(cfg), **kwargs)
    return model


vgg_model = vgg(model_name='vgg13')
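vgg() forwards any extra keyword arguments (**kwargs) unchanged to the VGG constructor, so options such as class_num and init_weight can be set at the call site. The sketch below shows that forwarding with a hypothetical StubVGG class standing in for the real model, so it runs without PyTorch; make_model is likewise a hypothetical name, and it raises ValueError for an unknown name as one possible design:

```python
# Hypothetical stand-in for the VGG class: it only records its arguments.
class StubVGG:
    def __init__(self, features, class_num=1000, init_weight=False):
        self.features = features
        self.class_num = class_num
        self.init_weight = init_weight


cfgs = {"vgg11": [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M']}


def make_model(model_name="vgg11", **kwargs):
    if model_name not in cfgs:
        # Fail early with a clear message instead of hitting an undefined variable.
        raise ValueError(f"model name {model_name!r} not in cfgs; choose from {sorted(cfgs)}")
    # Extra keyword arguments go straight through to the constructor.
    return StubVGG(cfgs[model_name], **kwargs)


m = make_model("vgg11", class_num=5, init_weight=True)
print(m.class_num, m.init_weight)  # 5 True
```

A misspelled keyword argument (for example num_classes instead of class_num) would surface as a TypeError from the constructor, since the factory does not inspect kwargs itself.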
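The script runs without errors and the network is built. Note that this only constructs the model; the forward pass runs when the model is called on an input tensor of shape (N, 3, 224, 224).
Full code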
import torch
import torch.nn as nn


class VGG(nn.Module):
    def __init__(self, features, class_num=1000, init_weight=False):
        super(VGG, self).__init__()
        self.features = features
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(512 * 7 * 7, 2048),
            nn.ReLU(True),
            nn.Dropout(p=0.5),
            nn.Linear(2048, 2048),
            nn.ReLU(True),
            nn.Linear(2048, class_num),
        )
        if init_weight:
            self._initialize_weights()

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, start_dim=1)
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        for m in self.modules():  # iterate over every submodule
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.constant_(m.bias, 0)


def make_features(cfg: list):
    layers = []
    in_channels = 3
    for v in cfg:
        if v == 'M':
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
            layers += [conv2d, nn.ReLU(True)]
            in_channels = v
    return nn.Sequential(*layers)


cfgs = {
    'vgg11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'vgg13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'vgg16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'vgg19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}


def vgg(model_name="vgg16", **kwargs):
    try:
        cfg = cfgs[model_name]
    except KeyError:
        print("warning: model name {} not in cfgs dict".format(model_name))
        raise
    model = VGG(make_features(cfg), **kwargs)
    return model


vgg_model = vgg(model_name='vgg13')
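The full code defines the parts in a different order from the step-by-step sections above, but this does not affect the result.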