动手学深度学习—使用块的网络VGG(代码详解)

目录

  • 1. VGG块
  • 2. VGG网络
  • 3. 训练模型

1. VGG块

经典卷积神经网络的基本组成部分是下面的这个序列:
1.带填充以保持分辨率的卷积层;
2.非线性激活函数,如ReLU;
3.汇聚层,如最大汇聚层。

定义网络块,便于我们重复构建某些网络架构,不仅利于代码编写与阅读也利于后面参数的优化

"""定义了一个名为vgg_block的函数来实现一个VGG块:1、卷积层的数量num_convs2、输入通道的数量in_channels 3、输出通道的数量out_channels
"""
import torch
from torch import nn
from d2l import torch as d2l# 定义vgg块,(卷积层数,输入通道,输出通道)
def vgg_block(num_convs, in_channels, out_channels):# 创建空网络结果,之后通过循环操作使用append函数进行添加layers = []# 循环操作,添加卷积层和非线性激活层for _ in range(num_convs):layers.append(nn.Conv2d(in_channels, out_channels,kernel_size=3, padding=1))layers.append(nn.ReLU())in_channels = out_channels# 最后添加最大值汇聚层layers.append(nn.MaxPool2d(kernel_size=2, stride=2))return nn.Sequential(*layers)

2. VGG网络

在这里插入图片描述
由于会重复用到卷积层、激活函数ReLU和汇聚层,我们将这三个组合成一个块,每次引用这个块来构建网络模型。
通过定义VGG块,使得重复的网络结构实现起来更加容易,也利于代码阅读。

# 原VGG网络有5个卷积块,前两个有一个卷积层,后三个块有两个卷积层
# 该网络使用8个卷积层和3个全连接层,因此它通常被称为VGG-11# (卷积层数,输出通道数)
conv_arch = ((1, 64), (1, 128), (2, 256), (2, 512), (2, 512))

实现VGG-11:使用8个卷积层和3个全连接层

# 通过for循环实现VGG-11
def vgg(conv_arch):# 定义空网络结构conv_blks = []in_channels = 1# 卷积层部分for (num_convs, out_channels) in conv_arch:# 添加vgg块conv_blks.append(vgg_block(num_convs, in_channels, out_channels))# 下一层输入通道数=当前层输出通道数in_channels = out_channelsreturn nn.Sequential(*conv_blks, nn.Flatten(),# 全连接层部分nn.Linear(out_channels * 7 * 7, 4096), nn.ReLU(), nn.Dropout(0.5),nn.Linear(4096, 4096), nn.ReLU(), nn.Dropout(0.5),nn.Linear(4096, 10))net = vgg(conv_arch)

构建一个高度和宽度为224的单通道数据样本,以观察每个层输出的形状

# 构建一个高度和宽度为224的单通道数据样本,以观察每个层输出的形状
X = torch.randn(size=(1, 1, 224, 224))
for blk in net:X = blk(X)print(blk.__class__.__name__, 'output shape:\t', X.shape)

每一层的输出形状
在这里插入图片描述

3. 训练模型

构建了一个通道数较少的网络,足够用于训练Fashion-MNIST数据集

# 构建了一个通道数较少的网络,足够用于训练Fashion-MNIST数据集
ratio = 4
# //为整除
small_conv_arch = [(pair[0], pair[1] // 4) for pair in conv_arch]
net = vgg(small_conv_arch)

定义精度评估函数

"""定义精度评估函数:1、将数据集复制到显存中2、通过调用accuracy计算数据集的精度
"""
def evaluate_accuracy_gpu(net, data_iter, device=None): #@save# 判断net是否属于torch.nn.Module类if isinstance(net, nn.Module):net.eval()# 如果不在参数选定的设备,将其传输到设备中if not device:device = next(iter(net.parameters())).device# Accumulator是累加器,定义两个变量:正确预测的数量,总预测的数量。metric = d2l.Accumulator(2)with torch.no_grad():for X, y in data_iter:# 将X, y复制到设备中if isinstance(X, list):# BERT微调所需的(之后将介绍)X = [x.to(device) for x in X]else:X = X.to(device)y = y.to(device)# 计算正确预测的数量,总预测的数量,并存储到metric中metric.add(d2l.accuracy(net(X), y), y.numel())return metric[0] / metric[1]

定义GPU 训练函数

"""定义GPU训练函数:1、为了使用gpu,首先需要将每一小批量数据移动到指定的设备(例如GPU)上;2、使用Xavier随机初始化模型参数;3、使用交叉熵损失函数和小批量随机梯度下降。
"""
#@save
def train_ch6(net, train_iter, test_iter, num_epochs, lr, device):"""用GPU训练模型(在第六章定义)"""# 定义初始化参数,对线性层和卷积层生效def init_weights(m):if type(m) == nn.Linear or type(m) == nn.Conv2d:nn.init.xavier_uniform_(m.weight)net.apply(init_weights)# 在设备device上进行训练print('training on', device)net.to(device)# 优化器:随机梯度下降optimizer = torch.optim.SGD(net.parameters(), lr=lr)# 损失函数:交叉熵损失函数loss = nn.CrossEntropyLoss()# Animator为绘图函数animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs],legend=['train loss', 'train acc', 'test acc'])# 调用Timer函数统计时间timer, num_batches = d2l.Timer(), len(train_iter)for epoch in range(num_epochs):# Accumulator(3)定义3个变量:损失值,正确预测的数量,总预测的数量metric = d2l.Accumulator(3)net.train()# enumerate() 函数用于将一个可遍历的数据对象for i, (X, y) in enumerate(train_iter):timer.start() # 进行计时optimizer.zero_grad() # 梯度清零X, y = X.to(device), y.to(device) # 将特征和标签转移到devicey_hat = net(X)l = loss(y_hat, y) # 交叉熵损失l.backward() # 进行梯度传递返回optimizer.step()with torch.no_grad():# 统计损失、预测正确数和样本数metric.add(l * X.shape[0], d2l.accuracy(y_hat, y), X.shape[0])timer.stop() # 计时结束train_l = metric[0] / metric[2] # 计算损失train_acc = metric[1] / metric[2] # 计算精度# 进行绘图if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:animator.add(epoch + (i + 1) / num_batches,(train_l, train_acc, None))# 测试精度test_acc = evaluate_accuracy_gpu(net, test_iter) animator.add(epoch + 1, (None, None, test_acc))# 输出损失值、训练精度、测试精度print(f'loss {train_l:.3f}, train acc {train_acc:.3f},'f'test acc {test_acc:.3f}')# 设备的计算能力print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec'f'on {str(device)}')

在这里插入图片描述

进行训练

# 学习率略高
lr, num_epochs, batch_size = 0.05, 10, 128
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=224)
d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())

在这里插入图片描述
块的使用导致网络定义的非常简洁。使用块可以有效地设计复杂的网络。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/113273.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【安全体系架构】——零信任网络架构

什么是零信任网络架构? 零信任网络架构是一种网络和信息安全模型,它将传统的信任模型颠覆,不再信任内部或外部用户、设备或网络。相反,它将每个访问请求都视为不受信任,要求对每个用户、设备和流量都进行认证和授权&a…

2023-10-20 LeetCode每日一题(根据规则将箱子分类)

2023-10-20每日一题 一、题目编号 2525. 根据规则将箱子分类二、题目链接 点击跳转到题目位置 三、题目描述 给你四个整数 length ,width ,height 和 mass ,分别表示一个箱子的三个维度和质量,请你返回一个表示箱子 类别 的字…

底层驱动day2作业

控制三盏灯亮灭 代码: //head.h#ifndef __HEAD_H__ #define __HEAD_H__ #define PHY_RCC 0x50000A28 #define PHY_GPIOE_MODER 0x50006000 #define PHY_GPIOF_MODER 0x50007000 #define PHY_GPIOE_ODR 0x50006014 #define PHY_GPIOF_ODR 0x50007014#endif //demo…

播放svga动画的时候 第一次加载资源,然后切换动画 会动画会重影

如果在切换 SVGA 动画的过程中,第一次加载时出现重影,但第二次及以后的切换没有重影,这可能是由于第一次加载时资源缓存不完整导致的。为了解决这个问题,你可以尝试以下方法: 1.在每次切换动画之前,预先加…

C#经典十大排序算法(完结)

C#冒泡排序算法 简介 冒泡排序算法是一种基础的排序算法,它的实现原理比较简单。核心思想是通过相邻元素的比较和交换来将最大(或最小)的元素逐步"冒泡"到数列的末尾。 详细文章描述 https://mp.weixin.qq.com/s/z_LPZ6QUFNJcw…

vue-pdf多页预览异常,Rendering cancelled, page 1 Error at BaseExceptionClosure xxx

项目开发使用vue-pdf,单页情况预览正常,多页vue-pdf预览异常,第一次预览时,会先弹出异常模态窗口,关闭模态窗口,pdf又是正常显示,报错信息及异常截图如下: 报错信息 Rendering cancelled, page…

【PADS封装】2.4G PCB天线封装(量产用)

包含了我们平时常用的2.4GPCB天线封装,总共11种封装。完全能满足日常设计使用。 下载链接!!https://mp.weixin.qq.com/s?__bizMzU2OTc4ODA4OA&mid2247548815&idx1&sne625e51a06755a34ab4404497770df48&chksmfcfb2c58cb8ca5…

MMWHS数据集

Multi-Modality Whole Heart Segmentation (MMWHS) 数据集[1] 是多模态医疗图像数据集,有磁共振(Magnetic Resonance Imaging,MRI)和断层扫描(Computed Tomography,CT)两种,[2] 对数…

探索图像分割技术:使用 OpenCV 的分水岭算法

贾斯卡兰巴蒂亚 一、说明 图像分割是计算机视觉的一个基本方面,多年来经历了巨大的转变。这将是一系列三篇博客文章,深入研究三种不同的图像分割技术 - 1使用OpenCV的经典分水岭算法,2使用PyTorch实现的基于深度学习的UNet模型,3 …

【Redis】数据结构之dict

目录 dict的基本结构dict的相关操作函数底层通用的之查找插入key-value对应该放入ht表的哪个槽rehash过程 dict的基本结构 typedef struct dict {dictType *type;void *privdata;dictht ht[2];long rehashidx; /* rehashing not in progress if rehashidx -1 */unsigned long…

Ubuntu小知识总结

Ubuntu相关的小知识总结 一、Ubuntu系统下修改用户开机密码二、Vmware虚拟机和主机之间复制、粘贴内容、拖拽文件的详细方法问题描述Vmware tools灰色不能安装解决方法小知识点:MarkDown的空格 三、Ubuntu虚拟机网络无法连接的几种解决方法1.重启网络编辑器2. 重启虚…

Linux下使用openssl为harbor制作证书

openssl是一个功能丰富且自包含的开源安全工具箱。它提供的主要功能有:SSL协议实现(包括SSLv2、SSLv3和TLSv1)、大量软算法(对称/非对称/摘要)、大数运算、非对称算法密钥生成、ASN.1编解码库、证书请求(PKCS10)编解码、数字证书编解码、CRL编解码、OCSP协议、数字证…

免费高清壁纸下载(静态和动态壁纸)

一、网址下载(静态壁纸) 高清图片直接另存为就可以了。然后在电脑空白处右键——个性化设置即可替换壁纸。 ①网址:https://www.hippopx.com ②极简壁纸:https://bz.zzzmh.cn/index ③彼岸图网:http://pic.netbian…

Linux:firewalld防火墙-介绍(1)

防火墙技术 1.包过滤 packet filtering 2.应用代理 application proxy 3.状态检测 stateful inspection Linux 包过滤防火墙 概述 1.netfilter 位于Linux内核中的包过滤功能体系 称为Linux防火墙的“内核态” 2.firewalld CentOS7默认的管理防火墙规则的工具 称为Linux防火…

OpenCV17-图像形态学操作

OpenCV17-图像形态学操作 1.形态学操作1.1腐蚀1.2膨胀 2.形态学应用2.1开运算2.2闭运算2.3形态学梯度2.4顶帽运算2.5黑帽运算2.6击中击不中变换2.7形态学应用示例 1.形态学操作 1.1腐蚀 图像腐蚀(Image erosion)可用于减小图像中物体的大小、填充孔洞或…

通过数组的指针获得数组个数

这几天学习智能指针时,自己在练习写个管理数组指针的类时碰到了通过数组指针获取数组个数的问题 1.在网上查询了通过数组指针获取数组个数的方法,对于自定义数据在前四个节点保存了数组个数 Student* pAry new Student[3];size_t num *((size_t*)pAry - 1);//3测试是成功的…

华为eNSP配置专题-VRRP的配置

文章目录 华为eNSP配置专题-VRRP的配置0、参考文档1、前置环境1.1、宿主机1.2、eNSP模拟器 2、基本环境搭建2.1、基本终端构成和连接 2.VRRP的配置2.1、PC1的配置2.2、接入交换机acsw的配置2.3、核心交换机coresw1的配置2.4、核心交换机coresw2的配置2.5、配置VRRP2.6、配置出口…

C++多态、虚函数、纯虚函数、抽象类

多态的概念 通俗来说,就是多种形态,具体点就是去完成某个行为,当不同的对象去完成时会产生出不同的状态。 举个简单的例子:抢红包,我们每个人都只需要点击一下红包,就会抢到金额。有些人能…

OpenCV中world模块介绍

OpenCV中有很多模块,模块间保持最小的依赖关系,用户可以根据自己的实际需要链接相关的库,而不需链接所有的库,这样在最终交付应用程序时可以减少总库的大小。但如果需要依赖OpenCV的库太多,有时会带来不方便,此时可以使…

vue2 element手术麻醉信息系统源码,手术预约、手术安排、排班查询、手术麻醉监测、麻醉记录单

手术麻醉临床信息系统有着完善的临床业务功能,能够涵盖整个围术期的工作,能够采集、汇总、存储、处理、展现所有的临床诊疗资料。通过该系统的实施,能够规范麻醉科的工作流程,实现麻醉手术过程的信息数字化,自动生成麻…