【PyTorch实战演练】AlexNet网络模型构建并使用Cifar10数据集进行批量训练(附代码)

目录

0. 前言

1. Cifar10数据集

2. AlexNet网络模型

2.1 AlexNet的网络结构

2.2 激活函数ReLu

2.3 Dropout方法

2.4 数据增强

3. 使用GPU加速进行批量训练

4. 网络模型构建

5. 训练过程

6. 完整代码


0. 前言

按照国际惯例,首先声明:本文只是我自己学习的理解,虽然参考了他人的宝贵见解及成果,但是内容可能存在不准确的地方。如果发现文中错误,希望批评指正,共同进步。

本文的写作目的主要有以下3点:

  1. 介绍经典卷积神经元网络——AlexNet;
  2. 基于AlexNet进行改造,使用PyTorch进行编码;
  3. 使用批量训练的方法,训练Cifar10数据集。

为什么非要对AlexNet进行改造呢?因为我们要训练的数据集Cifar10中的图片尺寸是32×32×3,比AlexNet输入224×224×3还要小。

当然,我们也可以选择先把Cifar10数据集transform成224×224×3的图像,而不用改造AlexNet的网络结构,但是这样有些“浪费”AlexNet的网络结构。

1. Cifar10数据集

Cifar10是一个包含10个类别的图像分类数据集,每个类别包含6000张32x32像素的RGB三通道彩色图像,总计60000张图像,其中50000个图像用于训练网络模型(训练组),10000个图像用于验证网络模型(验证组)。

关于Cifar10数据集的下载及解析,这里不再赘述,之前的文章有过详细说明:【PyTorch实战演练】使用Cifar10数据集训练LeNet5网络并实现图像分类(附代码)

2. AlexNet网络模型

AlexNet是深度学习领域中的一个经典卷积神经网络模型,由Geoffrey Hinton的学生Alex Krizhevsky和Ilya Sutskever在2012年《ImageNet Classification with Deep Convolutional Neural Networks》提出。AlexNet在ImageNet图像分类挑战赛ILSVRC(ImageNet Large Scale Visual Recognition Challenge)上获得了远远超过第二名的成绩,它的出现标志着深度学习在计算机视觉领域的爆发。

讲到这里可能有人会疑问本文为什么不直接用ImageNet数据集?因为这个数据集实在是太大了!作为学习实例实在没必要。(主要是我的电脑性能也跟不上……)

ILSVRC2012

 |-Training images (Task 1 & 2). 138GB.
 |-Training images (Task 3). 728MB.
 |-Validation images (all tasks). 6.3GB.
 |-Test images (all tasks). 13GB.

相比于另一个经典的卷积神经网络模型LeNet,AlexNet的模型更深更广,这点通过模型参数数量可以直观地比较:LeNet总共有60,840个训练参数,而AlexNet的训练参数多达6000万个!

2.1 AlexNet的网络结构

AlexNet论文原文中的结构图:

受限于当时GPU的性能,上面的结构图中分为了一模一样的上下两行,这是为了分在两个GPU中训练,而现在我们完全没必要这么做了。

加入每层的具体参数,我整理AlexNet的网络结构以及每层的输入输出张量维度如下:

AlexNet网络整体由5个卷积层和3个全连接层构成,网络的输入为224×224的3通道图像,最终输出为长度1000的张量,代表1000个分类的置信度。

只要根据卷积层及池化层输出的特征图尺寸计算公式:

output = (input-kernel+2\times padding)/stride+1

不难计算出过程中每层输出的特征图尺寸,计算结果也已在上图中标注出。

这里有必要说明下卷积层C1,因为按照输入=224,卷积核kernel=11,padding=2,stride=4,按照上面的公式计算的输出应该为55.25,不为整数。这时卷积层会向下取整,输出特征图为55。似乎看起来卷积层C1的卷积核设为kernel=12更为合理。

也有的文章说AlexNet输入图像尺寸为227×227×3,我不清楚这个说法是怎么来的,因为Alex的论文原文已经写明:The first convolutional layer filters the 224×224×3 input image。

2.2 激活函数ReLu

从网络结构上可以看出,除了最后一个全连接层选择Softmax作为激活函数(因为要进行归一化),其他所有层都清一色地选择了ReLu。Alex等人对比了ReLu和Tanh两个激活函数的训练错误率下降速度如下图所示:

其中实线为ReLu,虚线为Tanh,可见选择ReLu作为激活函数比Tanh训练错误率下降速度要快得多(6倍)。而且这并不是针对某个特定的网络,在同样条件下(选用最快的学习率,不使用任何正则化方法),ReLu方法总是会比饱和神经元方法快几倍

论文中的原文:The learning rates for each network were chosen independently to make training as fast as possible. No regularization of any kind was employed. The magnitude of the effect demonstrated here varies with network architecture, but networks with ReLUs consistently learn several times faster than equivalents with saturating neurons.

这里需要再解释下饱和神经元(saturating neurons)是指神经元的输出会限制在一定范围,例如Sigmoid限制在(0,1),Tanh限制在(-1,1),采用这些激活函数的神经元即为饱和神经元。不饱和神经元(non-saturating neurons)是指神经元的输出不会限制在某一范围,例如使用激活函数为ReLu的神经元。

2.3 Dropout方法

AlexNet在全连接层FC6和FC7引入了Dropout层,因为拥有6000万个参数的AlexNet算是比较复杂的深度学习网络,使用较小的数据集进行训练时容易出现过拟合的情况。

没错,即便是ILSVRC数据集对AlexNet来说,都只能算是一个“较小的”数据集。

Dropout方法的本质是随机将深度学习网络中某个单元(神经元)丢弃,即将其输出置0。Dropout层的引入提高了网络结构的鲁棒性,因为其使得网络中随机丢失一些神经元之后仍能保证输出的准确性。

Dropout方法的具体使用在此前的基于torch.nn.Dropout通过实例说明Dropout丢弃法(附代码)已详细介绍过了,这里也不在赘述。但是有一点我必须再强调下:

Dropout是一个训练深度学习网络的方法,在验证输出时需要取消Dropout!

顺便提一下,ReLu和Dropout也都是由Alex的老师Hinton提出的,可见有一个牛逼的老板……

Dropout是一种有效抑制过拟合的方法,但这是以牺牲训练速度为代价的。抑制过拟合的根本手段是要增大训练数据集,但是现实情况往往数据集的量十分有限,这时数据增强就非常有必要了!

2.4 数据增强

在AlexNet论文中介绍的数据增强主要有两个方法

  1. 图像切割和镜像:这个方法非常好理解,即从一个大图像中切割出若干个更小的图像,以及再基于这些图像做镜像。虽然从同一个图像中切割或者镜像出的小图像在训练结果上肯定有高度的相关性,但这仍是一个抑制过拟合的有效手段;
  2. 调整图像的RGB数值:这个方法操作起来比较复杂,可以简单理解其作用就相当于是给各个图像加了不同的“滤镜”。其详细原理非本文重点,感兴趣的童鞋可以参见:主成分分析(PCA)原理详解

3. 使用GPU加速进行批量训练

由于GPU在并行计算相比CPU有着巨大的优势,因此使用GPU进行批量训练可以节省大量的时间!

关于GPU和CPU的运算时间对比,可以参考我的往期文章:【PyTorch&TensorBoard实战】GPU与CPU的计算速度对比(附代码)

比如在训练Cifar10数据集时,我们可以让Batch_size=256个图片作为一个整体一起进行训练:

注意:这里的图像是我手工排列的,数据的真实size是[256, 3, 32, 32],即[batch_size, channel, H, W]。

使用.DataLoader()方法可以实现数据集的分批,以Cifar10数据为例,在PyTorch中的实现方法为:

from torchvision import datasets
import torch.utils.data as data
import torch
from torchvision import transformsbatch_size = 256data_path = 'D:\\DL\\CIFAR10\\CIFAR10\\IMG_file'  #数据集路径
cifar10_train = datasets.CIFAR10(data_path, train=True, download=False,transform=transforms.ToTensor())   #第一次下载download要设定为True
cifar10_train_loader = data.DataLoader(dataset=cifar10_train, batch_size=batch_size , shuffle=False)

在数据集分批后,使用.cuda()方法把分批后的数据发送到GPU上进行训练。

4. 网络模型构建

为了适配Cifar10数据集的尺寸32×32×3及输出类别只有10类,在AlexNet的网络结构基础上进行改造:

Python代码如下:

class AlexNet(nn.Module):def __init__(self, dropout=0.9):super(AlexNet, self).__init__()self.model = nn.Sequential(nn.Conv2d(in_channels=3, out_channels=96, kernel_size=5, stride=1),nn.ReLU(),nn.Conv2d(in_channels=96, out_channels=256, kernel_size=2, stride=1),nn.ReLU(),nn.Conv2d(in_channels=256, out_channels=384, kernel_size=3, stride=2), nn.ReLU(),nn.Conv2d(in_channels=384, out_channels=384, kernel_size=3, stride=1, padding=1),nn.ReLU(),nn.Conv2d(in_channels=384, out_channels=256, kernel_size=3, stride=1, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=3, stride=2),nn.Flatten(),nn.Linear(in_features=256*6*6,out_features=4096),nn.ReLU(),nn.Dropout(p=dropout),nn.Linear(in_features=4096, out_features=256),nn.ReLU(),nn.Dropout(p=dropout),nn.Linear(in_features=256, out_features=10),nn.Softmax())def forward(self,x):return self.model(x)

5. 训练过程

训练的损失函数采用交叉熵损失函数,优化器采用Adam,加入余弦退火自调整学习率方法。

自调整学习率方法可以参考【PyTorch实战演练】自调整学习率实例应用(附代码)

criterion = nn.CrossEntropyLoss()
opt = torch.optim.Adam(alexnet.parameters(),lr=initial_lr)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=opt,T_max=100,last_epoch=-1)

训练过程如下图所示,整体采用了5段的分段训练,每段的迭代次数epoch以及初始学习率initial_lr已在图中标注出:

这里有两点需要说明下:

  1. 受限于电脑性能,我只训练了Cifar10的前2560张图像,即前10个Batch(即使这样整个训练也耗费了大概5~6个小时(T_T))
  2. 从上面训练过程可以看出,损失值下降到一定范围后,就不再下降,我认为这是dropout导致的。

6. 完整代码

from torchvision import datasets
import torch.utils.data as data
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torchvision import transforms
from tqdm import tqdmbatch_size = 256data_path = 'D:\\DL\\CIFAR10\\CIFAR10\\IMG_file'
cifar10_train = datasets.CIFAR10(data_path, train=True, download=False,transform=transforms.ToTensor())small_cifar10 = []
for i in range(2560):small_cifar10.append(cifar10_train[i])cifar10_train_loader = data.DataLoader(dataset= small_cifar10, batch_size=batch_size , shuffle=False)cifar10_train_loader = list(cifar10_train_loader)class AlexNet(nn.Module):def __init__(self, dropout=0.9):super(AlexNet, self).__init__()self.model = nn.Sequential(nn.Conv2d(in_channels=3, out_channels=96, kernel_size=5, stride=1),nn.ReLU(),nn.Conv2d(in_channels=96, out_channels=256, kernel_size=2, stride=1),nn.ReLU(),nn.Conv2d(in_channels=256, out_channels=384, kernel_size=3, stride=2), nn.ReLU(),nn.Conv2d(in_channels=384, out_channels=384, kernel_size=3, stride=1, padding=1),nn.ReLU(),nn.Conv2d(in_channels=384, out_channels=256, kernel_size=3, stride=1, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=3, stride=2),nn.Flatten(),nn.Linear(in_features=256*6*6,out_features=4096),nn.ReLU(),nn.Dropout(p=dropout),nn.Linear(in_features=4096, out_features=256),nn.ReLU(),nn.Dropout(p=dropout),nn.Linear(in_features=256, out_features=10),nn.Softmax())def forward(self,x):return self.model(x)alexnet = AlexNet(dropout=0.9).cuda()
alexnet.load_state_dict(torch.load('weight/epoch=2000_initial_lr=0.000020.pth'))# img,label=cifar10_train_loader[0]  #用于测试网络正向传播,正式代码中不用这两行
# print(alexnet(img))def train(epoch, initial_lr):criterion = nn.CrossEntropyLoss()opt = torch.optim.Adam(alexnet.parameters(),lr=initial_lr)scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=opt,T_max=100,last_epoch=-1)for e in tqdm(range(epoch)):print('current %i epoch'%e)iter_loss = 0opt.zero_grad()for iter,(img,label) in enumerate(cifar10_train_loader):img = img.cuda()label = label.cuda()output = alexnet(img)loss = criterion(output, label)loss_plt = loss.detach()loss_plt = loss_plt.cpu()iter_loss = iter_loss+loss_pltloss.backward()opt.step()plt.scatter(e, iter_loss,s=2,c='r')scheduler.step()if __name__ == '__main__':epoch = 500initial_lr = 1e-6train(epoch, initial_lr)torch.save(alexnet.state_dict(), 'weight/epoch=%i_initial_lr=%f.pth'%(epoch, initial_lr))plt.title('epoch=%i---initial_lr=%f'%(epoch, initial_lr))plt.xlabel('epoch')plt.ylabel('loss')plt.show()plt.savefig('epoch=%i---initial_lr=%f.jpg'%(epoch, initial_lr))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/131010.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

[开源]企业级在线办公系统,基于实时音视频完成在线视频会议功能

一、开源项目简介 企业级在线办公系统 本项目使用了SpringBootMybatisSpringMVC框架,技术功能点应用了WebSocket、Redis、Activiti7工作流引擎, 基于TRTC腾讯实时音视频完成在线视频会议功能。 二、开源协议 使用GPL-3.0开源协议 三、界面展示 部分…

2024天津理工大学中环信息学院专升本机械设计制造自动化专业考纲

2024年天津理工大学中环信息学院高职升本科《机械设计制造及其自动化》专业课考试大纲《机械设计》《机械制图》 《机械设计》考试大纲 教 材:《机械设计》(第十版),高等教育出版社,濮良贵、陈国定、吴立言主编&#…

1.如何实现统一的API前缀-web组件篇

文章目录 1. 问题的由来2.实现原理3. 总结 1. 问题的由来 系统提供了 2 种类型的用户,分别满足对应的管理后台、用户 App 场景。 两种场景的前缀不同,分别为/admin-api/和/app-api/,都写在一个controller里面,显然比较混乱。分开…

AI:60-基于深度学习的瓜果蔬菜分类识别

🚀 本文选自专栏:AI领域专栏 从基础到实践,深入了解算法、案例和最新趋势。无论你是初学者还是经验丰富的数据科学家,通过案例和项目实践,掌握核心概念和实用技能。每篇案例都包含代码实例,详细讲解供大家学习。 📌📌📌在这个漫长的过程,中途遇到了不少问题,但是…

网络基础扫盲-多路转发

博客内容:多路转发的常见方式select,poll,epoll 文章目录 一、五种IO模型二、多路转发的常见接口1.select2、poll3、epoll 总结 前言 Linux下一切皆文件,是文件就会存在IO的情况,IO的方式决定了效率的高低。 一、五种…

基于java+springboot+vue在线选课系统

项目介绍 本系统结合计算机系统的结构、概念、模型、原理、方法,在计算机各种优势的情况下,采用JAVA语言,结合SpringBoot框架与Vue框架以及MYSQL数据库设计并实现的。员工管理系统主要包括个人中心、课程管理、专业管理、院系信息管理、学生…

Cube MX 开发高精度电流源跳坑过程/SPI连接ADS1255/1256系列问题总结/STM32 硬件SPI开发过程

文章目录 概要整体架构流程技术名词解释技术细节小结 概要 1.使用STM32F系列开发一款高精度恒流电源,用到了24位高精度采样芯片ADS1255/ADS1256系列。 2.使用时发现很多的坑,详细介绍了每个坑的具体情况和实际的解决办法。 坑1:波特率设置…

如何使用Ruby 多线程爬取数据

现在比较主流的爬虫应该是用python,之前也写了很多关于python的文章。今天在这里我们主要说说ruby。我觉得ruby也是ok的,我试试看写了一个爬虫的小程序,并作出相应的解析。 Ruby中实现网页抓取,一般用的是mechanize,使…

Pytorch从零开始实战08

Pytorch从零开始实战——YOLOv5-C3模块实现 本系列来源于365天深度学习训练营 原作者K同学 文章目录 Pytorch从零开始实战——YOLOv5-C3模块实现环境准备数据集模型选择开始训练可视化模型预测总结 环境准备 本文基于Jupyter notebook,使用Python3.8&#xff0c…

webJS基础-----制作一个时间倒计时

1,可以使用以下两个方式制作 方式1:setTimeout ()定时器是在指定的时间后执行某些代码,代码执行一次就会自动停止; 方式2:setInterval ()定时器是按照指定的周期来重复执行某些代码,该定时器不会自动停止…

DL Homework 6

目录 一、概念 (1)卷积 (2)卷积核 (3)特征图 (4)特征选择 (5)步长 (6)填充 (7)感受野 二、探究不同卷…

JVM运行时数据区-堆

目录 一、堆的核心概述 (一)概述 (二)堆空间细分 (三)jvisualvm工具 二、设置堆内存的大小与OOM 三、年轻代与老年代 四、图解对象分配一般过程 五、对象分配特殊过程 六、常用调优工具 七、Mino…

leetCode 416.分割等和子集 + 01背包 + 动态规划 + 记忆化搜索 + 递推 + 空间优化

关于此题我的往期文章: LeetCode 416.分割等和子集(动态规划【0-1背包问题】采用一维数组dp:滚动数组)_呵呵哒( ̄▽ ̄)"的博客-CSDN博客https://heheda.blog.csdn.net/article/details/133212716看本期文章时&…

关于JADX和JEB的小问题

关于JADX和JEB的小问题 很久没水过技术文啦,最近也刚好遇到点小问题,特此记录 第一个问题 在处理app加密逻辑的时候一直拿不到正确的密文,反复看了反编译出来的代码(如下图) public static string n(String str, Stri…

基础课22——云服务(SaaS、Pass、laas、AIaas)

1.云服务概念和类型 云服务是一种基于互联网的计算模式,通过云计算技术将计算、存储、网络等资源以服务的形式提供给用户,用户可以通过网络按需使用这些资源,无需购买、安装和维护硬件设备。云服务具有灵活扩展、按需使用、随时随地访问等优…

linux 查看当前目录下每个文件夹大小

要在 Linux 中查看当前目录下每个文件夹的大小,可以使用 du 命令(磁盘使用情况)结合其他一些选项。下面是几个常用的命令示例: 显示当前目录下每个文件夹的大小——只显示一层文件夹: du -h --max-depth1该命令会以人…

2023年内衣行业分析:京东大数据平台-服饰内衣市场解析

如今,女性消费力的提升正在推动国内女性内衣市场份额逐年提升。而今年,内衣市场更是进入了存量之战,增长趋势明显减弱。 根据鲸参谋数据显示,今年1月至9月,京东平台内衣(文胸)累计销量约500万件…

【数智化案例展】某国际高端酒店品牌——呼叫中心培训数智化转型项目

‍ 维音案例 本项目案例由维音投递并参与数据猿与上海大数据联盟联合推出的《2023中国数智化转型升级创新服务企业》榜单/奖项”评选。 大数据产业创新服务媒体 ——聚焦数据 改变商业 培训是呼叫中心管理的重要环节,由于员工流动性强、培训需求多样、考核流程繁琐…

2003 - Can‘t connect to MysQL server on ‘39.108.169.0‘ (10060 “Unknown error“)

问题描述 某天和往常一样启动java项目,发现数据库出问题了,然后打开navicat,发现数据库的链接都连接不上, 一点击就会弹出报错框: 然后就各种上网搜索。 解决方案 上网查了一些解决方案,大部分都是说看…

hivesql,sql 函数总结:

1、NVL函数与Coalesce差异 -- select nvl(null,8); -- 结果是 8 -- select nvl(,7); -- 结果是"" -- select coalesce(null,null,9); -- 结果是 9 -- select coalesce("",null,9); -- 结果是 "" 1.2、 NVL函数与Coalesce差异 …