经典卷积神经网络 - VGG

使用块的网络 - VGG。

使用多个 3 × 3 3\times 3 3×3的要比使用少个 5 × 5 5\times 5 5×5的效果要好。
在这里插入图片描述
在这里插入图片描述
VGG全称是Visual Geometry Group,因为是由Oxford的Visual Geometry Group提出的。AlexNet问世之后,很多学者通过改进AlexNet的网络结构来提高自己的准确率,主要有两个方向:小卷积核和多尺度。而VGG的作者们则选择了另外一个方向,即加深网络深度。

网络架构

卷积网络的输入是224 * 224RGB图像,整个网络的组成是非常格式化的,基本上都用的是3 * 3的卷积核以及 2 * 2max pooling,少部分网络加入了1 * 1的卷积核。因为想要体现出“上下左右中”的概念,3*3的卷积核已经是最小的尺寸了。

VGG16相比之前网络的改进是3个33卷积核来代替7x7卷积核,2个33卷积核来代替5*5卷积核,这样做的主要目的是在保证具有相同感知野的条件下,减少参数,提升了网络的深度。

多个VGG块后接全连接层。

不同次数的重复块得到不同的架构,如VGG-16,VGG-19等。

VGG:更大更深的AlexNet。

总结:

  • VGG使用可重复使用的卷积块来构建深度卷积神经网络
  • 不同的卷积块个数和超参数可以得到不同复杂度的变种

代码实现

使用数据集CIFAR

model.py

import torch
from torch import nnclass Vgg16(nn.Module):def __init__(self, *args, **kwargs) -> None:super().__init__(*args, **kwargs)self.model = nn.Sequential(nn.Conv2d(3,64,kernel_size=3,padding=1),nn.ReLU(),nn.Conv2d(64,64,kernel_size=3,padding=1),nn.ReLU(),nn.MaxPool2d(2,2),nn.Conv2d(64,128,kernel_size=3,padding=1),nn.ReLU(),nn.Conv2d(128,128,kernel_size=3,padding=1),nn.ReLU(),nn.MaxPool2d(2,2),nn.Conv2d(128,256,kernel_size=3,padding=1),nn.ReLU(),nn.Conv2d(256,256,kernel_size=3,padding=1),nn.ReLU(),nn.Conv2d(256,256,kernel_size=3,padding=1),nn.ReLU(),nn.MaxPool2d(2,2),nn.Conv2d(256,512,kernel_size=3,padding=1),nn.ReLU(),nn.Conv2d(512,512,kernel_size=3,padding=1),nn.ReLU(),nn.Conv2d(512,512,kernel_size=3,padding=1),nn.ReLU(),nn.MaxPool2d(2,2),nn.Conv2d(512, 512, kernel_size=3, padding=1),nn.ReLU(),nn.Conv2d(512, 512, kernel_size=3, padding=1),nn.ReLU(),nn.Conv2d(512, 512, kernel_size=3, padding=1),nn.ReLU(),nn.MaxPool2d(2,2),nn.Flatten(),nn.Linear(7*7*512,4096),nn.Dropout(0.5),nn.Linear(4096,4096),nn.Dropout(0.5),nn.Linear(4096,10))def forward(self,x):return self.model(x)# 验证模型正确性
if __name__ == '__main__':net = Vgg16()x = torch.ones((64,3,244,244))output = net(x)print(output)

train.py

import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets
from torchvision.transforms import transforms
from model import Vgg16# 扫描数据次数
epochs = 3
# 分组大小
batch = 64
# 学习率
learning_rate = 0.01
# 训练次数
train_step = 0
# 测试次数
test_step = 0# 定义图像转换
transform = transforms.Compose([transforms.Resize(224),transforms.ToTensor()
])
# 读取数据
train_dataset = datasets.CIFAR10(root="./dataset",train=True,transform=transform,download=True)
test_dataset = datasets.CIFAR10(root="./dataset",train=False,transform=transform,download=True)
# 加载数据
train_dataloader = DataLoader(train_dataset,batch_size=batch,shuffle=True,num_workers=0)
test_dataloader = DataLoader(test_dataset,batch_size=batch,shuffle=True,num_workers=0)
# 数据大小
train_size = len(train_dataset)
test_size = len(test_dataset)
print("训练集大小:{}".format(train_size))
print("验证集大小:{}".format(test_size))# GPU
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(device)
# 创建网络
net = Vgg16()
net = net.to(device)
# 定义损失函数
loss = nn.CrossEntropyLoss()
loss = loss.to(device)
# 定义优化器
optimizer = torch.optim.SGD(net.parameters(),lr=learning_rate)writer = SummaryWriter("logs")
# 训练
for epoch in range(epochs):print("-------------------第 {} 轮训练开始-------------------".format(epoch))net.train()for data in train_dataloader:train_step = train_step + 1images,targets = dataimages = images.to(device)targets = targets.to(device)outputs = net(images)loss_out = loss(outputs,targets)optimizer.zero_grad()loss_out.backward()optimizer.step()if train_step%100==0:writer.add_scalar("Train Loss",scalar_value=loss_out.item(),global_step=train_step)print("训练次数:{},Loss:{}".format(train_step,loss_out.item()))# 测试net.eval()total_loss = 0total_accuracy = 0with torch.no_grad():for data in test_dataloader:test_step = test_step + 1images, targets = dataimages = images.to(device)targets = targets.to(device)outputs = net(images)loss_out = loss(outputs, targets)total_loss = total_loss + loss_outaccuracy = (targets == torch.argmax(outputs,dim=1)).sum()total_accuracy = total_accuracy + accuracy# 计算精确率print(total_accuracy)accuracy_rate = total_accuracy / test_sizeprint("第 {} 轮,验证集总损失为:{}".format(epoch+1,total_loss))print("第 {} 轮,精确率为:{}".format(epoch+1,accuracy_rate))writer.add_scalar("Test Total Loss",scalar_value=total_loss,global_step=epoch+1)writer.add_scalar("Accuracy Rate",scalar_value=accuracy_rate,global_step=epoch+1)torch.save(net,"./model/net_{}.pth".format(epoch+1))print("模型net_{}.pth已保存".format(epoch+1))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/116279.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

TDengine小知识-数据文件命名规则

TDengine 时序数据库对数据文件有自己的命名规则,文件名中包含了vnodeID、时间范围、版本、文件类型等多种信息。了解数据文件命名规则,可以让运维工作更简单。 废话不多说,直接上图: v4:文件所属 Vgroup 组&#xf…

leetcode:2347. 最好的扑克手牌(python3解法)

难度:简单 给你一个整数数组 ranks 和一个字符数组 suit 。你有 5 张扑克牌,第 i 张牌大小为 ranks[i] ,花色为 suits[i] 。 下述是从好到坏你可能持有的 手牌类型 : "Flush":同花,五张相同花色的…

安装visual studio报错“无法安装msodbcsql“

在安装visual studio2022时安装完成后提示无法安装msodbcsql, 查看日志文件详细信息提示:指定账户已存在。 未能安装包“msodbcsql,version17.2.30929.1,chipx64,languagezh-CN”。 搜索 URL https://aka.ms/VSSetupErrorReports?qPackageIdmsodbcsql;PackageActi…

用matlab求解线性规划

文章目录 1、用单纯形表求解线性规划绘制单纯形表求解: 2、用matlab求解线性规划——linprog()函数问题:补充代码:显示出完整的影子价格向量 1、用单纯形表求解线性规划 求解线性规划 m i n − 3 x 1 − 4 x 2 x 3 min -3x_1-4x_2x_3 min−…

【ArcGIS模型构建器】04:根据矢量范围批量裁剪影像栅格数据

本文以中国2000-2010-2020年3期GLC30土地覆盖数据为例,演示用模型构建器批量裁剪出四川省3年的数据。 文章目录 一、结果预览二、模型构建三、运行模型四、注意事项一、结果预览 用四川省行政区数据裁剪出的3年Globeland30(配套实验数据data04.rar中有三年中国区域成品数据)…

Java编写图片转base64

图片转成base64 url , 在我们的工作中也会经常用到,比如说导出 word,pdf 等功能,今天我们尝试写一下。 File file new File("");byte[] data null;InputStream in null;ByteArrayOutputStream out null;try{URL url new URL(&…

NAS搭建指南三——私人云盘

一、私人云盘选择 我选择的是可道云进行私人云盘的搭建可道云官网地址可道云下载地址,下载服务器端和 Windows 客户端可道云官方文档 二、环境配置 PHP 与 MySQL 环境安装:XAMPP 官网地址 下载最新的 windows 版本 安装时只勾选 MySQL 与 PHP相关即可…

信号继电器驱动芯片(led驱动芯片)

驱动继电器需要配合BAV99(防止反向脉冲)使用 具体应用参考开源项目 电阻箱 sbstnh/programmable_precision_resistor: A SCPI programmable precision resistor (github.com) 这个是芯片的输出电流设置 对应到上面的实际开源项目其设置电阻为1.5K&…

侯捷C++面向对象程序设计笔记(上)-Object Based(基于对象)部分

基于对象就是对于单一class的设计。 对于有指针的:complex.h complex-test.cpp 对于没有指针的: string.h string-test.cpp https://blog.csdn.net/ncepu_Chen/article/details/113843775?spm1001.2014.3001.5501#commentBox 没有指针成员——以复数co…

力扣第55题 跳跃游戏 c++ 贪心 + 覆盖 加暴力超时参考

题目 55. 跳跃游戏 中等 相关标签 贪心 数组 动态规划 给你一个非负整数数组 nums ,你最初位于数组的 第一个下标 。数组中的每个元素代表你在该位置可以跳跃的最大长度。 判断你是否能够到达最后一个下标,如果可以,返回 true &…

【RocketMQ系列十二】RocketMQ集群核心概念之主从复制生产者负载均衡策略消费者负载均衡策略

您好,我是码农飞哥(wei158556),感谢您阅读本文,欢迎一键三连哦。 💪🏻 1. Python基础专栏,基础知识一网打尽,9.9元买不了吃亏,买不了上当。 Python从入门到精…

自动驾驶之—车道线感知

零、前言 : 最近在学习自动驾驶方向的东西,简单整理一些学习笔记,学习过程中发现宝藏up 手写AI 一、视觉系统坐标系 视觉系统一共有四个坐标系:像素平面坐标系(u,v)、图像坐标系(x,y&#xff09…

开源WAF--Safeline(雷池)测试手册

长亭科技—雷池(SafeLine)社区版 官方网站:长亭雷池 WAF 社区版 (chaitin.cn) WAF 工作在应用层,对基于 HTTP/HTTPS 协议的 Web 系统有着更好的防护效果,使其免于受到黑客的攻击 1.1 雷池的搭建 1.1.1 配置需求 操作系统:Linux 指令架构&am…

【数据结构与算法】二叉树的运用要点

目录 一,二叉树的结构深入认识 二,二叉树的遍历 三,二叉树的基本运算 3-1,计算二叉树的大小 3-2,统计二叉树叶子结点个数 3-3,计算第k层的节点个数 3-4,查找指定值的结点 一,二叉…

Spring底层原理(二)

Spring底层原理(二) BeanFactory的实现 //创建BeanFactory对象 DefaultListableBeanFactory factory new DefaultListableBeanFactory(); //注册Bean定义对象 AbstractBeanDefinition beanDefinition BeanDefinitionBuilder.genericBeanDefinition(SpringConfig.class).set…

搭建Pytorch的GPU环境超详细

效果 1、下载和安装VS2019 https://visualstudio.microsoft.com/zh-hans/vs/older-downloads/ 登录需要用户名和密码 安装后需要联网下载组件的,安装的时候要勾选使用C++的桌面开发 2、下载和安装显卡驱动 查看自己的显卡型号 从英伟达下载和安装最新驱动

Python学习笔记——MYSQL,SQL核心

食用说明:本笔记适用于有一定编程基础的伙伴们。希望有助于各位! SQL语言分类 SQL注释 库管理 表管理 数据操作 分组聚合 分页限制 需要注意的是关键字的顺序不可以错乱,否则会报错其中LIMIT关键字的n是指从第n个开始,m是指查…

【Zero to One系列】微服务Hystrix的熔断器集成

前期回顾: 【Zero to One系列】springcloud微服务集成nacos,形成分布式系统 【Zero to One系列】SpringCloud Gateway结合Nacos完成微服务的网关路由 1、hystrix依赖包 首先引入hystrix相关的依赖包,版本方面自己和项目内相对应即可&#…

Electron 学习

Electron基本简介 如果你可以建一个网站,你就可以建一个桌面应用程序。Eletron 是一个使用 JavaScript, HTML和 CSS等Web 技术创建原生程序的框架,它负责比较难搞的部分,你只需把精力放在你的应用的核心上即可。 Electron 可以让你使用纯 Jav…

kr第三阶段(二)32 位汇编

编译与链接 环境配置 masm32 masm32 是微软的 masm32 的民间工具集合。该工具集合除了 asm32 本身的汇编器 ml 外还提供了: SDK 对应的函数声明头文件和 lib 库。32 位版本的 link(原版本是 16 位,这里的 32 位版本的 link 来自 VC 6.0&a…