动手学深度学习—使用块的网络VGG(代码详解)

目录

  • 1. VGG块
  • 2. VGG网络
  • 3. 训练模型

1. VGG块

经典卷积神经网络的基本组成部分是下面的这个序列:
1.带填充以保持分辨率的卷积层;
2.非线性激活函数,如ReLU;
3.汇聚层,如最大汇聚层。

定义网络块,便于我们重复构建某些网络架构,不仅利于代码编写与阅读也利于后面参数的优化

"""定义了一个名为vgg_block的函数来实现一个VGG块:1、卷积层的数量num_convs2、输入通道的数量in_channels 3、输出通道的数量out_channels
"""
import torch
from torch import nn
from d2l import torch as d2l# 定义vgg块,(卷积层数,输入通道,输出通道)
def vgg_block(num_convs, in_channels, out_channels):# 创建空网络结果,之后通过循环操作使用append函数进行添加layers = []# 循环操作,添加卷积层和非线性激活层for _ in range(num_convs):layers.append(nn.Conv2d(in_channels, out_channels,kernel_size=3, padding=1))layers.append(nn.ReLU())in_channels = out_channels# 最后添加最大值汇聚层layers.append(nn.MaxPool2d(kernel_size=2, stride=2))return nn.Sequential(*layers)

2. VGG网络

在这里插入图片描述
由于会重复用到卷积层、激活函数ReLU和汇聚层,我们将这三个组合成一个块,每次引用这个块来构建网络模型。
通过定义VGG块,使得重复的网络结构实现起来更加容易,也利于代码阅读。

# 原VGG网络有5个卷积块,前两个有一个卷积层,后三个块有两个卷积层
# 该网络使用8个卷积层和3个全连接层,因此它通常被称为VGG-11# (卷积层数,输出通道数)
conv_arch = ((1, 64), (1, 128), (2, 256), (2, 512), (2, 512))

实现VGG-11:使用8个卷积层和3个全连接层

# 通过for循环实现VGG-11
def vgg(conv_arch):# 定义空网络结构conv_blks = []in_channels = 1# 卷积层部分for (num_convs, out_channels) in conv_arch:# 添加vgg块conv_blks.append(vgg_block(num_convs, in_channels, out_channels))# 下一层输入通道数=当前层输出通道数in_channels = out_channelsreturn nn.Sequential(*conv_blks, nn.Flatten(),# 全连接层部分nn.Linear(out_channels * 7 * 7, 4096), nn.ReLU(), nn.Dropout(0.5),nn.Linear(4096, 4096), nn.ReLU(), nn.Dropout(0.5),nn.Linear(4096, 10))net = vgg(conv_arch)

构建一个高度和宽度为224的单通道数据样本,以观察每个层输出的形状

# 构建一个高度和宽度为224的单通道数据样本,以观察每个层输出的形状
X = torch.randn(size=(1, 1, 224, 224))
for blk in net:X = blk(X)print(blk.__class__.__name__, 'output shape:\t', X.shape)

每一层的输出形状
在这里插入图片描述

3. 训练模型

构建了一个通道数较少的网络,足够用于训练Fashion-MNIST数据集

# 构建了一个通道数较少的网络,足够用于训练Fashion-MNIST数据集
ratio = 4
# //为整除
small_conv_arch = [(pair[0], pair[1] // 4) for pair in conv_arch]
net = vgg(small_conv_arch)

定义精度评估函数

"""定义精度评估函数:1、将数据集复制到显存中2、通过调用accuracy计算数据集的精度
"""
def evaluate_accuracy_gpu(net, data_iter, device=None): #@save# 判断net是否属于torch.nn.Module类if isinstance(net, nn.Module):net.eval()# 如果不在参数选定的设备,将其传输到设备中if not device:device = next(iter(net.parameters())).device# Accumulator是累加器,定义两个变量:正确预测的数量,总预测的数量。metric = d2l.Accumulator(2)with torch.no_grad():for X, y in data_iter:# 将X, y复制到设备中if isinstance(X, list):# BERT微调所需的(之后将介绍)X = [x.to(device) for x in X]else:X = X.to(device)y = y.to(device)# 计算正确预测的数量,总预测的数量,并存储到metric中metric.add(d2l.accuracy(net(X), y), y.numel())return metric[0] / metric[1]

定义GPU 训练函数

"""定义GPU训练函数:1、为了使用gpu,首先需要将每一小批量数据移动到指定的设备(例如GPU)上;2、使用Xavier随机初始化模型参数;3、使用交叉熵损失函数和小批量随机梯度下降。
"""
#@save
def train_ch6(net, train_iter, test_iter, num_epochs, lr, device):"""用GPU训练模型(在第六章定义)"""# 定义初始化参数,对线性层和卷积层生效def init_weights(m):if type(m) == nn.Linear or type(m) == nn.Conv2d:nn.init.xavier_uniform_(m.weight)net.apply(init_weights)# 在设备device上进行训练print('training on', device)net.to(device)# 优化器:随机梯度下降optimizer = torch.optim.SGD(net.parameters(), lr=lr)# 损失函数:交叉熵损失函数loss = nn.CrossEntropyLoss()# Animator为绘图函数animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs],legend=['train loss', 'train acc', 'test acc'])# 调用Timer函数统计时间timer, num_batches = d2l.Timer(), len(train_iter)for epoch in range(num_epochs):# Accumulator(3)定义3个变量:损失值,正确预测的数量,总预测的数量metric = d2l.Accumulator(3)net.train()# enumerate() 函数用于将一个可遍历的数据对象for i, (X, y) in enumerate(train_iter):timer.start() # 进行计时optimizer.zero_grad() # 梯度清零X, y = X.to(device), y.to(device) # 将特征和标签转移到devicey_hat = net(X)l = loss(y_hat, y) # 交叉熵损失l.backward() # 进行梯度传递返回optimizer.step()with torch.no_grad():# 统计损失、预测正确数和样本数metric.add(l * X.shape[0], d2l.accuracy(y_hat, y), X.shape[0])timer.stop() # 计时结束train_l = metric[0] / metric[2] # 计算损失train_acc = metric[1] / metric[2] # 计算精度# 进行绘图if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:animator.add(epoch + (i + 1) / num_batches,(train_l, train_acc, None))# 测试精度test_acc = evaluate_accuracy_gpu(net, test_iter) animator.add(epoch + 1, (None, None, test_acc))# 输出损失值、训练精度、测试精度print(f'loss {train_l:.3f}, train acc {train_acc:.3f},'f'test acc {test_acc:.3f}')# 设备的计算能力print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec'f'on {str(device)}')

在这里插入图片描述

进行训练

# 学习率略高
lr, num_epochs, batch_size = 0.05, 10, 128
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=224)
d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())

在这里插入图片描述
块的使用导致网络定义的非常简洁。使用块可以有效地设计复杂的网络。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/113273.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

底层驱动day2作业

控制三盏灯亮灭 代码: //head.h#ifndef __HEAD_H__ #define __HEAD_H__ #define PHY_RCC 0x50000A28 #define PHY_GPIOE_MODER 0x50006000 #define PHY_GPIOF_MODER 0x50007000 #define PHY_GPIOE_ODR 0x50006014 #define PHY_GPIOF_ODR 0x50007014#endif //demo…

vue-pdf多页预览异常,Rendering cancelled, page 1 Error at BaseExceptionClosure xxx

项目开发使用vue-pdf,单页情况预览正常,多页vue-pdf预览异常,第一次预览时,会先弹出异常模态窗口,关闭模态窗口,pdf又是正常显示,报错信息及异常截图如下: 报错信息 Rendering cancelled, page…

【PADS封装】2.4G PCB天线封装(量产用)

包含了我们平时常用的2.4GPCB天线封装,总共11种封装。完全能满足日常设计使用。 下载链接!!https://mp.weixin.qq.com/s?__bizMzU2OTc4ODA4OA&mid2247548815&idx1&sne625e51a06755a34ab4404497770df48&chksmfcfb2c58cb8ca5…

探索图像分割技术:使用 OpenCV 的分水岭算法

贾斯卡兰巴蒂亚 一、说明 图像分割是计算机视觉的一个基本方面,多年来经历了巨大的转变。这将是一系列三篇博客文章,深入研究三种不同的图像分割技术 - 1使用OpenCV的经典分水岭算法,2使用PyTorch实现的基于深度学习的UNet模型,3 …

Ubuntu小知识总结

Ubuntu相关的小知识总结 一、Ubuntu系统下修改用户开机密码二、Vmware虚拟机和主机之间复制、粘贴内容、拖拽文件的详细方法问题描述Vmware tools灰色不能安装解决方法小知识点:MarkDown的空格 三、Ubuntu虚拟机网络无法连接的几种解决方法1.重启网络编辑器2. 重启虚…

Linux下使用openssl为harbor制作证书

openssl是一个功能丰富且自包含的开源安全工具箱。它提供的主要功能有:SSL协议实现(包括SSLv2、SSLv3和TLSv1)、大量软算法(对称/非对称/摘要)、大数运算、非对称算法密钥生成、ASN.1编解码库、证书请求(PKCS10)编解码、数字证书编解码、CRL编解码、OCSP协议、数字证…

免费高清壁纸下载(静态和动态壁纸)

一、网址下载(静态壁纸) 高清图片直接另存为就可以了。然后在电脑空白处右键——个性化设置即可替换壁纸。 ①网址:https://www.hippopx.com ②极简壁纸:https://bz.zzzmh.cn/index ③彼岸图网:http://pic.netbian…

OpenCV17-图像形态学操作

OpenCV17-图像形态学操作 1.形态学操作1.1腐蚀1.2膨胀 2.形态学应用2.1开运算2.2闭运算2.3形态学梯度2.4顶帽运算2.5黑帽运算2.6击中击不中变换2.7形态学应用示例 1.形态学操作 1.1腐蚀 图像腐蚀(Image erosion)可用于减小图像中物体的大小、填充孔洞或…

华为eNSP配置专题-VRRP的配置

文章目录 华为eNSP配置专题-VRRP的配置0、参考文档1、前置环境1.1、宿主机1.2、eNSP模拟器 2、基本环境搭建2.1、基本终端构成和连接 2.VRRP的配置2.1、PC1的配置2.2、接入交换机acsw的配置2.3、核心交换机coresw1的配置2.4、核心交换机coresw2的配置2.5、配置VRRP2.6、配置出口…

C++多态、虚函数、纯虚函数、抽象类

多态的概念 通俗来说,就是多种形态,具体点就是去完成某个行为,当不同的对象去完成时会产生出不同的状态。 举个简单的例子:抢红包,我们每个人都只需要点击一下红包,就会抢到金额。有些人能…

OpenCV中world模块介绍

OpenCV中有很多模块,模块间保持最小的依赖关系,用户可以根据自己的实际需要链接相关的库,而不需链接所有的库,这样在最终交付应用程序时可以减少总库的大小。但如果需要依赖OpenCV的库太多,有时会带来不方便,此时可以使…

vue2 element手术麻醉信息系统源码,手术预约、手术安排、排班查询、手术麻醉监测、麻醉记录单

手术麻醉临床信息系统有着完善的临床业务功能,能够涵盖整个围术期的工作,能够采集、汇总、存储、处理、展现所有的临床诊疗资料。通过该系统的实施,能够规范麻醉科的工作流程,实现麻醉手术过程的信息数字化,自动生成麻…

mac 升级node到指定版本

node版本14.15.1升级到最新稳定版18.18.2 mac系统 先查看一下自己的node版本 node -v开始升级 第一步 清除node的缓存 sudo npm cache clean -f第二步 安装n模块【管理模块 n是管理 nodejs版本】 sudo npm install -g n第三步升级node sudo n stable // 把当前系统的 Node…

计算机毕业设计 基于SpringBoot智慧养老中心管理系统的设计与实现 Javaweb项目 Java实战项目 前后端分离 文档报告 代码讲解 安装调试

🍊作者:计算机编程-吉哥 🍊简介:专业从事JavaWeb程序开发,微信小程序开发,定制化项目、 源码、代码讲解、文档撰写、ppt制作。做自己喜欢的事,生活就是快乐的。 🍊心愿:点…

【趣味随笔】盘点国内外做双足机器人的公司

📢:如果你也对机器人、人工智能感兴趣,看来我们志同道合✨ 📢:不妨浏览一下我的博客主页【https://blog.csdn.net/weixin_51244852】 📢:文章若有幸对你有帮助,可点赞 👍…

leetCode 392. 判断子序列 动态规划 + 优化空间 / 双指针 等多种解法

392. 判断子序列 - 力扣(LeetCode) 给定字符串 s 和 t ,判断 s 是否为 t 的子序列。字符串的一个子序列是原始字符串删除一些(也可以不删除)字符而不改变剩余字符相对位置形成的新字符串。(例如&#xff0c…

Aocoda-RC F405V2 FC(STM32F405RGT6 v.s. AT32F435RGT7) IO Definitions

[TOC](Aocoda-RC F405V2 FC(STM32F405RGT6 v.s. AT32F435RGT7) IO Definitions) 1. 源由 Aocoda-RC F405V2飞控支持betaflight/inav/Ardupilot固件,是一款固件兼容性非常不错的开源硬件。 之前我们对比过STM32F405RGT6 v.s. AT32F435RGT7 Comparison for Flight …

ThreadLocal源码解密

1 背景 作为一只懒懒地程序员,其实我是不太爱看源码的,晦涩、深奥、难懂、耗费时间等等,就觉得不是我这种能力平平地小老百姓能吃得消的,但现实比人强,记得曾经我就被不懂原理的情况下乱用ThreadLocal给毒打了。 犹记得当时在一个JSF服务中的责任链的校验场景中需要在源…

UART、SPI、I2C通信协议超全入门教程

本文引注: https://mp.weixin.qq.com/s/lVWK8xlDt7cOLi8WHYSuPg 1.SPI协议 1.基础 2.简介 3.工作原理 4.SPI数据传输步骤与优缺点 2.UART协议

分布式微服务技术栈-SpringCloud<Eureka,Ribbon,nacos>

微服务技术栈 一、微服务 介绍了解1 架构结构案例与 springboot 兼容关系拆分案例拆分服务拆分-服务远程调用 2 eureka注册中心Eureka-提供者与消费者Eureka-eureka原理分析Eureka-搭建eureka服务Eureka-服务注册Eureka-服务发现 3 Ribbon组件 负载均衡Ribbon-负载均衡原理Ribb…