分类神经网络1:VGGNet模型复现

目录

分类网络的常见形式

VGG网络架构

VGG网络部分实现代码


分类网络的常见形式

常见的分类网络通常由特征提取部分分类部分组成。

特征提取部分实质就是各种神经网络,如VGG、ResNet、DenseNet、MobileNet等。其负责捕获数据的有用信息,一般是通过堆叠多个卷积层和池化层来实现的,这些层有助于检测图像中的边缘、纹理和特征。

分类部分通常是一个全连接层,负责将特征提取部分输出的信息映射到最终的类别或标签。这些全连接层通常包括一个或多个隐藏层,以及一个输出层,其中输出层的节点数量等于任务中的类别数量。

VGG网络架构

论文原址:https://arxiv.org/pdf/1409.1556v6.pdf

VGG 网络是由牛津大学的Visual Geometry Group 开发的,其结构特点在于使用了多个 3x3 的小卷积核,并通过这些小卷积层的重复堆叠来构建网络,从而能够捕捉到更加复杂和抽象的特征表示。VGG 网络的模型结构如下:

VGG网络的核心架构可以分为以下几个部分:

  1. 输入层:VGG网络接受224x224像素的RGB图像作为输入。
  2. 卷积层:网络的前几层由多个卷积层组成,每个卷积层都使用3x3的卷积核来提取图像的特征。这些卷积层后面通常跟着一个2x2 最大池化,用于逐步减小特征图的空间尺寸,同时增加特征深度。
  3. 池化层:在卷积层之后,网络使用最大池化层来降低特征图的空间分辨率,这有助于减少计算量并提取更加抽象的特征。
  4. 全连接层:经过多个卷积和池化层之后,网络的特征图被展平并通过几个全连接层进行处理。全连接层的作用是将学习到的特征映射到最终的分类结果。
  5. 输出层:VGG网络的最后是一个softmax层,它将网络的输出转换为概率分布,以便进行多类别的分类任务。

VGG网络的一个显著特点是其深度,其相关配置信息如下:

VGG系列不同变体内容如下:

  • VGG A:这是一个基础的配置,没有特别独特的设计。
  • VGG A-LRN:在这个版本中,加入了局部响应归一化(LRN),这是一种在AlexNet中首次使用的技术。不过,LRN在当前的深度学习实践中已经较少被采用。
  • VGG B:相较于A版本,B版本增加了两个卷积层,以增强网络的学习能力。
  • VGG C:在B的基础上,C版本进一步增加了三个卷积层,但这些层使用的是1x1的卷积核。1x1卷积核可以看作是对输入特征图进行线性变换,有助于减少参数数量并增加非线性。
  • VGG D:D版本在C版本的基础上做了调整,将1x1的卷积核替换为3x3的卷积核,这个配置后来被称为VGG16,因为它总共有16层。
  • VGG E:在D版本的基础上,E版本进一步增加了三个3x3的卷积层,形成了VGG19,总共有19层。

从图中可以看出,随着网络深度的加深,模型变得更为复杂。通常来说,增加网络的深度可以增加模型的表示能力,使其能够学习到更复杂的特征和模式,从而在某些任务上取得更好的性能。然而,随着网络深度的增加,模型的参数数量也会增加,导致模型的复杂度增加,训练和推理的计算成本也会增加,同时可能会增加过拟合的风险。

VGG网络部分实现代码

废话不多说,直接上干货

import torch
import torch.nn as nn__all__ = ["VGG", "vgg11_bn", "vgg13_bn", "vgg16_bn", "vgg19_bn"]cfg = {'A': [64,     'M', 128,      'M', 256, 256,           'M', 512, 512,           'M', 512, 512,           'M'],'B': [64, 64, 'M', 128, 128, 'M', 256, 256,           'M', 512, 512,           'M', 512, 512,           'M'],'C': [64, 64, 'M', 128, 128, 'M', 256, 256, 256,      'M', 512, 512, 512,      'M', 512, 512, 512,      'M'],'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}class ConvBNReLU(nn.Module):def __init__(self, in_channels, out_channels, stride=1,  kernel_size=3, padding=1):super(ConvBNReLU, self).__init__()self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding)self.bn = nn.BatchNorm2d(num_features=out_channels)self.relu = nn.ReLU(inplace=True)def forward(self, x):x = self.conv(x)x = self.bn(x)x = self.relu(x)return xclass VGG(nn.Module):def __init__(self, features, num_classes=1000, init_weights=True):super(VGG, self).__init__()self.features = featuresself.avgpool = nn.AdaptiveAvgPool2d((7, 7))self.classifier = nn.Sequential(nn.Linear(512 * 7 * 7, 4096),nn.ReLU(True),nn.Dropout(),nn.Linear(4096, 4096),nn.ReLU(True),nn.Dropout(),nn.Linear(4096, num_classes),)if init_weights:self._initialize_weights()def forward(self, x):for layer in self.features:x = layer(x)x = self.avgpool(x)x = torch.flatten(x, 1)x = self.classifier(x)return xdef _initialize_weights(self):for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')if m.bias is not None:nn.init.constant_(m.bias, 0)elif isinstance(m, nn.BatchNorm2d):nn.init.constant_(m.weight, 1)nn.init.constant_(m.bias, 0)elif isinstance(m, nn.Linear):nn.init.normal_(m.weight, 0, 0.01)nn.init.constant_(m.bias, 0)def make_layers(cfg):layers = nn.ModuleList()in_channels = 3for i in cfg:if i == 'M':layers.append(nn.MaxPool2d(kernel_size=2, stride=2))else:layers.append(ConvBNReLU(in_channels=in_channels, out_channels=i))in_channels = ireturn layersdef vgg11_bn(num_classes):model = VGG(make_layers(cfg['A']), num_classes=num_classes)return modeldef vgg13_bn(num_classes):model = VGG(make_layers(cfg['B']), num_classes=num_classes)return modeldef vgg16_bn(num_classes):model = VGG(make_layers(cfg['C']), num_classes=num_classes)return modeldef vgg19_bn(num_classes):model = VGG(make_layers(cfg['D']), num_classes=num_classes)return modelif __name__=='__main__':import torchsummarydevice = 'cuda' if torch.cuda.is_available() else 'cpu'input = torch.ones(2, 3, 224, 224).to(device)net = vgg16_bn(num_classes=4)net = net.to(device)out = net(input)print(out)print(out.shape)torchsummary.summary(net, input_size=(3, 224, 224))# Total params: 134,285,380

这只是一个网络架构部分实现代码,其中 cfg 列表是 VGG 卷积和池化后的通道数,大家可以结合 VGG 的配置信息图一起对比理解。希望对大家有所帮助呀!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/2941.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

5分钟——测试搭建的springboot接口(二)

5分钟——测试搭建的springboot接口(二) 1. 查看数据库字段2. 测试getAll接口3. 测试add接口4. 测试update接口5. 测试deleteById接口 1. 查看数据库字段 2. 测试getAll接口 3. 测试add接口 4. 测试update接口 5. 测试deleteById接口

Docker 开启远程安全访问

说明 如果你的服务器是公网IP,并且开放了docker的远程访问,如果没有进行保护是非常危险的,任何人都可以向你的docker中推送镜像、运行实例。我曾开放过阿里云服务器中docker的远程访问权限,在没有开启保护的状态下,几…

用 LMDeploy 高效部署 Llama-3-8B,1.8倍vLLM推理效率

节前,我们星球组织了一场算法岗技术&面试讨论会,邀请了一些互联网大厂朋友、参加社招和校招面试的同学,针对算法岗技术趋势、大模型落地项目经验分享、新手如何入门算法岗、该如何准备、面试常考点分享等热门话题进行了深入的讨论。 汇总…

Springboot 整合 Quartz框架做定时任务

在Spring Boot中整合Quartz&#xff0c;可以实现定时任务调度的功能 1、首先&#xff0c;在pom.xml文件中添加Quartz和Spring Boot Starter Quartz的依赖&#xff1a; <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-bo…

一些好听且有心意的英文全名Burwood新南威尔士州伯伍德喝酒上脸就是乙醛中毒1. 康奈尔大学官宣恢复标化要求2. 香港城市大学(东莞)正式设立!

目录 一些好听且有心意的英文全名 Burwood新南威尔士州伯伍德 喝酒上脸就是乙醛中毒 1. 康奈尔大学官宣恢复标化要求 2. 香港城市大学&#xff08;东莞&#xff09;正式设立&#xff01; 一些好听且有心意的英文全名 在选择好听且有意义的英文全名时&#xff0c;我们可…

synchronized的底层原理

目录 介绍 实现原理 对象头 Monitor&#xff08;监视器&#xff09; 锁升级 偏向锁 轻量级锁 重量级锁 锁的优缺点 介绍 synchronized 是 Java 中的关键字&#xff0c;它用于锁定代码块或方法&#xff0c;以确保同一时刻只有一个线程可以进入被锁定的部分。这在多线程…

css盒子设置圆角边框的方法

前言 欢迎来到我的博客 个人主页&#xff1a;北岭敲键盘的荒漠猫-CSDN博客 本文为我整理的设置圆角边框的方法 需求描述 我们在设置盒子边框时&#xff0c;他总是方方正正的。 我们想让这个直直的边框委婉一点该怎么办呢。这个就提到了我们这篇文章讲的东西&#xff1a; bord…

聚观早报 | OpenAI在印度开始招聘;特斯拉将发布一季度财报

聚观早报每日整理最值得关注的行业重点事件&#xff0c;帮助大家及时了解最新行业动态&#xff0c;每日读报&#xff0c;就读聚观365资讯简报。 整理丨Cutie 4月23日消息 OpenAI在印度开始招聘 特斯拉将发布一季度财报 理想汽车全线产品降价 优酷升级悬疑剧场为白夜剧场 …

ffmpeg支持MP3编码的方法

目录 现象 解决办法 如果有编译包没有链接上的情况 现象 解决办法 在ffmpeg安装包目录下 &#xff0c;通过./configure --list-encoders 和 ./configure --list-decoders 命令可以看到&#xff0c;ffmpeg只支持mp3解码&#xff0c;但是不支持mp3编码。 上网查寻后发现&…

C++ :设计模式实现

文章目录 原则单一职责原则开闭原则依赖倒置原则接口隔离原则里氏替换原则 设计模式单例模式观察者模式策略模式代理模式 原则 单一职责原则 定义&#xff1a; 即一个类只负责一项职责 问题&#xff1a; 类 T 负责两个不同的职责&#xff1a;职责 P1&#xff0c;职责 P2。当…

Tomcat源码解析——一次请求的处理流程

在上一篇文章中&#xff0c;我们知道Tomcat在启动后&#xff0c;会在Connector中开启一个Acceptor(接收器)绑定线程然后用于监听socket的连接&#xff0c;那么当我们发出请求时&#xff0c;第一步也就是建立TCP连接&#xff0c;则会从Acceptor的run方法处进入。 Acceptor&…

使用CSS+HTML完成导航栏

HTML <!DOCTYPE html> <html lang"en"> <head> <meta charset"UTF-8"> <meta name"viewport" content"widthdevice-width, initial-scale1.0"> <title>导航栏示例</title> &l…

07 内核开发-避免命名冲突经验技巧分享

07 内核开发-避免命名冲突经验技巧分享 目录 07 内核开发-避免命名冲突经验技巧分享 1.如何在内核开发过程中&#xff0c;避免命名冲突 2. 背景 3.避免方法 4.了解下 文件/proc/kallsyms 5.总结 课程简介&#xff1a; Linux内核开发入门是一门旨在帮助学习者从最基本的…

java-异常

一、异常的概念及分类 Exception&#xff1a;异常&#xff0c;代表程序可能出现的问题 Exception分为两类&#xff1a; 1、运行时异常&#xff1a;RuntimeException以及其子类&#xff0c;编译阶段不会出现异常提醒&#xff0c;在运行阶段会出现异常提醒 2、编译时异常&…

基于SpringBoot+Vue网上商城系统的设计与实现

系统介绍 随着社会的不断进步与发展&#xff0c;人们经济水平也不断的提高&#xff0c;于是对各行各业需求也越来越高。特别是从2019年新型冠状病毒爆发以来&#xff0c;利用计算机网络来处理各行业事务这一概念更深入人心&#xff0c;由于用户工作繁忙的原因&#xff0c;去商…

抽象工厂模式设计实验

【实验内容】 楚锋软件公司欲开发一套界面皮肤库&#xff0c;可以对 Java 桌面软件进行界面美化。为了保护版权&#xff0c;该皮肤库源代码不打算公开&#xff0c;而只向用户提供已打包为 jar 文件的 class 字节码文件。用户在使用时可以通过菜单来选择皮肤&#xff0c;不同的…

Java数据类型以及范围

数据类型&#xff1a; 取值范围&#xff1a; 取值&#xff1a;

磁性呼吸传感技术与机器学习结合在COVID-19审断中的应用

介绍 呼吸不仅是人类生存的基础&#xff0c;而且其模式也是评估个体健康状态的关键指标。异常的呼吸模式往往是呼吸系统疾病的一个警示信号&#xff0c;包括但不限于慢性阻塞性肺病&#xff08;COPD&#xff09;、阻塞性睡眠呼吸暂停&#xff08;OSA&#xff09;、肺炎、囊性纤…

idea连接Docker数据库

我们在docker下创建了数据库&#xff0c;想要更方便的查看和操作该数据库&#xff0c;idea和DataGrip或者其他软件都可以。在数据库连接时需要填写数据库名字&#xff0c;主机&#xff0c;端口&#xff0c;数据库用户名和密码。 输入之后先不要点击OK和按Enter键&#xff0c;我…

GAN详解,公式推导解读,详细到每一步的理论推导

在看这一篇文章之前&#xff0c;希望熟悉掌握熵的知识&#xff0c;可看我写的跟熵相关的一篇博客https://blog.csdn.net/m0_59156726/article/details/138128622 1. GAN 原始论文&#xff1a;https://arxiv.org/pdf/1406.2661.pdf 放一张GAN的结构&#xff0c;如下&#xff1…