7.2 手撕VGG11模型 使用Fashion_mnist数据训练VGG

VGG首先引入块的思想将模型通用模板化

VGG模型的特点

与AlexNet,LeNet一样,VGG网络可以分为两部分,第一部分主要由卷积层和汇聚层组成,第二部分由全连接层组成。

VGG有5个卷积块,前两个块包含一个卷积层,后三个块包含两个卷积层。 2 * 1 + 3 * 2 = 8个卷积层和后面3个全连接层,所以它被称为VGG11

AlexNet模型架构与VGG模型架构对比

在这里插入图片描述

import torch
from torch import nn
from d2l import torch as d2l
import time
# 卷积块函数
def vgg_block(num_convs,in_channels,out_channels):layers = []for _ in range(num_convs):layers.append(nn.Conv2d(in_channels,out_channels,kernel_size=3,padding=1))layers.append(nn.ReLU())in_channels = out_channelslayers.append(nn.MaxPool2d(kernel_size=2,stride=2))'''`nn.Sequential(*layers)`中的`*layers`将会展开`layers`列表,将其中的每个层作为单独的参数传递给`nn.Sequential`函数,以便构建一个顺序模型。'''return nn.Sequential(*layers)
# 定义卷积块的输入输出
conv_arch = ((1,64),(1,128),(2,256),(2,512),(2,512))
# VGG有5个卷积块,前两个块包含一个卷积层,后三个块包含两个卷积层。 2 * 1 + 3 * 2 = 8个卷积层和后面3个全连接层,所以它被称为VGG11
def vgg(conv_arch):conv_blks = []in_channels = 1# 卷积层部分for (num_convs,out_channels) in conv_arch:conv_blks.append(vgg_block(num_convs,in_channels,out_channels))in_channels = out_channelsreturn nn.Sequential(# 5个卷积块部分*conv_blks,nn.Flatten(),# 3个全连接部分nn.Linear(out_channels*7*7,4096),nn.ReLU(),nn.Dropout(0.5),nn.Linear(4096,4096),nn.ReLU(),nn.Dropout(0.5),nn.Linear(4096,10))
net = vgg(conv_arch)
X = torch.randn(size=(1,1,224,224))
for blk in net:X = blk(X)print(blk.__class__.__name__,'output shape:\t',X.shape)
Sequential output shape:	 torch.Size([1, 64, 112, 112])
Sequential output shape:	 torch.Size([1, 128, 56, 56])
Sequential output shape:	 torch.Size([1, 256, 28, 28])
Sequential output shape:	 torch.Size([1, 512, 14, 14])
Sequential output shape:	 torch.Size([1, 512, 7, 7])
Flatten output shape:	 torch.Size([1, 25088])
Linear output shape:	 torch.Size([1, 4096])
ReLU output shape:	 torch.Size([1, 4096])
Dropout output shape:	 torch.Size([1, 4096])
Linear output shape:	 torch.Size([1, 4096])
ReLU output shape:	 torch.Size([1, 4096])
Dropout output shape:	 torch.Size([1, 4096])
Linear output shape:	 torch.Size([1, 10])

为了使用Fashion-MNIST数据集,使用缩小VGG11的通道数的VGG11

# 由于VGG11比AlexNet计算量更大,所以构建一个通道数校小的网络
ratio = 4
# 样本数pair[0]不变,通道数pair[1]缩小四倍
small_conv_arch = [(pair[0],pair[1] // ratio) for pair in conv_arch]
net = vgg(small_conv_arch)
X = torch.randn(size=(1,1,224,224))
for blk in net:X = blk(X)print(blk.__class__.__name__,'output shape:\t',X.shape)
Sequential output shape:	 torch.Size([1, 16, 112, 112])
Sequential output shape:	 torch.Size([1, 32, 56, 56])
Sequential output shape:	 torch.Size([1, 64, 28, 28])
Sequential output shape:	 torch.Size([1, 128, 14, 14])
Sequential output shape:	 torch.Size([1, 128, 7, 7])
Flatten output shape:	 torch.Size([1, 6272])
Linear output shape:	 torch.Size([1, 4096])
ReLU output shape:	 torch.Size([1, 4096])
Dropout output shape:	 torch.Size([1, 4096])
Linear output shape:	 torch.Size([1, 4096])
ReLU output shape:	 torch.Size([1, 4096])
Dropout output shape:	 torch.Size([1, 4096])
Linear output shape:	 torch.Size([1, 10])
'''开始计时'''
start_time = time.time()
lr,num_epochs,batch_size = 0.05,10,128
train_iter,test_iter = d2l.load_data_fashion_mnist(batch_size,resize=224)
d2l.train_ch6(net,train_iter,test_iter,num_epochs,lr,d2l.try_gpu())
'''时间结束'''
end_time = time.time()
run_time = end_time - start_time
# 将输出的秒数保留两位小数
print(f'{round(run_time,2)}s')

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/32256.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

P8604 [蓝桥杯 2013 国 C] 危险系数

题目背景 抗日战争时期,冀中平原的地道战曾发挥重要作用。 题目描述 地道的多个站点间有通道连接,形成了庞大的网络。但也有隐患,当敌人发现了某个站点后,其它站点间可能因此会失去联系。 我们来定义一个危险系数 DF(x,y)&…

MySQL_SQL性能分析

SQL执行频次 语法: SHOW GLOBAL STATUS LIKE COM_类型; COM_insert; 插入次数 com_delete; 删除次数 com_update; 更新次数 com_select; 查询次数 com_______; 注意:通过语法,可以查询到数据库的实际状态,就可以知道数据库是以增删…

TDesign中后台管理系统-用户登录

目录 1 创建用户表2 开发后端接口3 测试接口4 修改登录页面调用后端接口最终效果总结 中后台系统第一个要实现的功能就是登录了,我们通常的逻辑是让用户在登录页面输入用户名和密码,调用后端接口去验证用户的合法性,然后根据接口返回的结果进…

【T3】金蝶kis凭证数据转换到畅捷通T3软件中。

【问题需求】 将金蝶软件中的账套转换到畅捷通T3软件中。 由于金蝶老版本使用的是非sql server数据库。 进而需要将其数据导入到sql中,在转换到T3。 【转换环境】 金蝶中数据:凭证;科目无项目核算。 1、金蝶的数据文件后缀为.AIS; 2、安装office2003全版软件; 3、安装sq…

【算法】双指针——leetcode盛最多水的容器、剑指Offer57和为s的两个数字

盛水最多的容器 (1)暴力解法 算法思路:我们枚举出所有的容器大小,取最大值即可。 容器容积的计算方式: 设两指针 i , j ,分别指向水槽板的最左端以及最右端,此时容器的宽度为 j - i 。由于容器…

【CDH集群】无法发出查询:Host Monitor未运行

无法发出查询:Host Monitor未运行 【CDH集群】无法发出查询:Host Monitor未运行同事的解决方案解决方法:删除原uuid重启agent查看新uuid修改scm数据库中HOSTS表中的agent的uuid 【CDH集群】无法发出查询:Host Monitor未运行 起初是impala报错,连接不上&…

使用 React Native CLI 创建项目

React Native 安装的先决条件和设置 需要掌握的知识点 掌握 JavaScript 基础知识掌握 React 相关基础知识掌握 TypeScript 相关基础知识 安装软件前需要首先安装Chocolatey。Chocolatey 是一种流行的 Windows 包管理器。 安装 nodejs 和 JDK choco install -y nodejs-lts …

【工作记录】mysql中实现分组统计的三种方式

前言 实际工作中对范围分组统计的需求还是相对普遍的,本文记录下在mysql中通过函数和sql完成分组统计的实现过程。 数据及期望 比如我们获取到了豆瓣电影top250,现在想知道各个分数段的电影总数. 表数据如下: 期望结果: 实现方案 主要思路是根据s…

解决Vue+Element-UI 进行From表单校验时出现了英文提示问题

说明:该篇博客是博主一字一码编写的,实属不易,请尊重原创,谢谢大家! 问题描述 在使用form表单时,往往会对表单字段进行校验,字段为必填项时会添加required属性,此时自定义rules规则…

RabbitMQ 消息队列

文章目录 🍰有几个原因可以解释为什么要选择 RabbitMQ:🥩mq之间的对比🌽RabbitMQ vs Apache Kafka🌽RabbitMQ vs ActiveMQ🌽RabbitMQ vs RocketMQ🌽RabbitMQ vs Redis 🥩linux docke…

MyBatis Plus-个人笔记

前言 学习视频 尚硅谷-Mybatis-Plus教程学习主要内容 本文章记录尚硅谷-Mybatis-Plus教程内容,只是作为自己学习笔记,如有侵扰请联系删除 一、MyBatis-Plus简介 1、简介 MyBatis-Plus(简称 MP)是一个 MyBatis的增强工具&#…

用python来爬取某鱼的商品信息(1/2)

目录 前言 第一大难题——找到网站入口 曲线救国 模拟搜索 第二大难题——登录 提一嘴 登录cookie获取 第一种 第二种 第四大难题——无法使用导出的cookie 原因 解决办法 最后 出现小问题 总结 前言 本章讲理论,后面一节讲代码 拿来练练手的&#xff…

聊城大学823软件工程考研

1.什么是软件工程?它目标和内容是什么? 软件工程就是用科学的知识和技术原理来定义,开发,维护软件的一门学科。 软件工程目标:付出较低开发成本;达到要求的功能;取得较好的性能;开发的软件易于移植&…

使用 Python 中的 Langchain 从零到高级快速进行工程

大型语言模型 (LLM) 的一个重要方面是这些模型用于学习的参数数量。模型拥有的参数越多,它就能更好地理解单词和短语之间的关系。这意味着具有数十亿个参数的模型有能力生成各种创造性的文本格式,并以信息丰富的方式回答开放式和挑战性的问题。 ChatGPT 等法学硕士利用 T

秋招算法备战第41天 | 343. 整数拆分、96.不同的二叉搜索树

343. 整数拆分 - 力扣(LeetCode) 数学方法 观察数字拆分的模式,我们可以发现以下事实: 将数字拆分为尽可能多的3会使乘积最大化。这是因为当 n > 4 时,3(n-3) > n,所以我们总是更喜欢3的拆分&…

【Spring专题】Spring底层核心原理解析

目录 前言阅读导航前置知识Q1:你能描述一下JVM对象创建过程吗?Q2:Spring的特性是什么?前置知识总结 课程内容一、Spring容器的启动二、一般流程推测2.1 扫描2.2 IOC2.3 AOP 2.4 小结三、【扫描】过程简单推测四、【IOC】过程简单推…

laravel\lumen rabbitmq

1、安装扩展 composer require bschmitt/laravel-amqp 2、config文件下增加amqp.php配置文件 <?phpreturn [/*|--------------------------------------------------------------------------| Define which configuration should be used|----------------------------…

c++ STL--容器 (第二部分)

c STL–容器 &#xff08;第二部分&#xff09; 1.vector向量&#xff08;序列性容器&#xff09; 1.特点&#xff1a; ​ 数据的存储访问比较方便&#xff0c;可以像数组一阿姨那个使用[index]访问或修改值&#xff0c;适用于对元素修改和查看比较多的情况&#xff0c;对于…

数据库--MySQL

一、什么是范式&#xff1f; 范式是数据库设计时遵循的一种规范&#xff0c;不同的规范要求遵循不同的范式。 最常用的三大范式 第一范式(1NF)&#xff1a;属性不可分割&#xff0c;即每个属性都是不可分割的原子项。(实体的属性即表中的列) 第二范式(2NF)&#xff1a;满足…

代码随想录day03

链表理论基础 ● 链表是一种通过指针串联在一起的线性结构&#xff0c;每一个节点有两个部分&#xff0c;数据域和指针域&#xff0c;最后一个节点指针域指向null 链表类型 ● 单链表 ● 双链表 ○ 每个节点有两个指针域&#xff0c;一个指向下一个节点&#xff0c;一个是上…