Pytorch | 从零构建AlexNet对CIFAR10进行分类

Pytorch | 从零构建AlexNet对CIFAR10进行分类

  • CIFAR10数据集
  • AlexNet
    • 网络结构
    • 技术创新点
    • 性能表现
    • 影响和意义
  • AlexNet结构代码详解
    • 结构代码
    • 代码详解
      • 特征提取层 self.features
      • 分类部分self.classifier
      • 前向传播forward
  • 训练过程和测试结果
  • 代码汇总
    • alexnet.py
    • train.py
    • test.py

CIFAR10数据集

CIFAR-10数据集是由加拿大高级研究所(CIFAR)收集整理的用于图像识别研究的常用数据集,基本信息如下:

  • 数据规模:该数据集包含60,000张彩色图像,分为10个不同的类别,每个类别有6,000张图像。通常将其中50,000张作为训练集,用于模型的训练;10,000张作为测试集,用于评估模型的性能。
  • 图像尺寸:所有图像的尺寸均为32×32像素,这相对较小的尺寸使得模型在处理该数据集时能够相对快速地进行训练和推理,但也增加了图像分类的难度。
  • 类别内容:涵盖了飞机(plane)、汽车(car)、鸟(bird)、猫(cat)、鹿(deer)、狗(dog)、青蛙(frog)、马(horse)、船(ship)、卡车(truck)这10个不同的类别,这些类别都是现实世界中常见的物体,具有一定的代表性。

下面是一些示例样本:
在这里插入图片描述

AlexNet

AlexNet是由Alex Krizhevsky、Ilya Sutskever和Geoffrey Hinton在2012年提出的一种深度卷积神经网络,在ImageNet图像识别挑战赛中取得了巨大成功,推动了深度学习在计算机视觉领域的快速发展。以下是对它的详细介绍:

网络结构

  • 卷积层:包含5个卷积层,这些卷积层通过不同的卷积核大小、步长和填充方式,逐步提取图像的特征。
  • 池化层:有3个最大池化层,用于减小特征图的尺寸,同时保留关键特征,减少计算量和过拟合风险。
  • 全连接层:包括3个全连接层,用于对提取的特征进行分类,最后一层输出分类结果。
    在这里插入图片描述
    上图为AlexNet原文中的网络结构(针对ImageNet,图片尺寸为224×224),本文是针对CIFAR10,其尺寸为32×32,因此结构不太相同,比如卷积核的大小,具体可以参考下面的代码。

技术创新点

  • ReLU激活函数:使用ReLU(Rectified Linear Unit)作为激活函数,解决了传统激活函数在深度网络中梯度消失的问题,加快了训练速度。
  • Dropout正则化:在全连接层中使用了Dropout技术,随机丢弃部分神经元,防止过拟合,提高模型的泛化能力。
  • 重叠池化:采用重叠池化(Overlapping Pooling),即池化窗口之间有重叠,有助于提取更多的特征信息,提升模型的性能。
  • 多GPU训练:首次利用多GPU进行并行训练,大大提高了训练速度,使得在大规模数据集上训练深度网络成为可能。

性能表现

  • 在ImageNet数据集上,AlexNet的top-5错误率大幅降低至15.3%,相比之前的方法有了显著提升,展示了其强大的图像识别能力。
  • 能够学习到丰富的图像特征,对不同类别的物体具有很好的区分能力,在实际应用中取得了很好的效果。

影响和意义

  • 推动深度学习发展:AlexNet的成功引起了学术界和工业界对深度学习的广泛关注,激发了更多研究人员对深度神经网络的研究兴趣,推动了深度学习技术的快速发展。
  • 开启卷积神经网络新时代:为后续的卷积神经网络研究提供了重要的参考和借鉴,许多新的网络结构和技术都是在AlexNet的基础上发展而来的。
  • 拓展应用领域:由于其在图像识别任务上的出色表现,AlexNet及其改进模型被广泛应用于计算机视觉的各个领域,如目标检测、图像分割、人脸识别等。

AlexNet结构代码详解

结构代码

import torch
import torch.nn as nnclass AlexNet(nn.Module):def __init__(self, num_classes):super(AlexNet, self).__init__()self.features = nn.Sequential(# input size: (B, 3, 32, 32)   (Batch_size, Channel, Height, Width)nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1), # (B, 64, 16, 16)nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2),    # (B, 64, 8, 8)nn.Conv2d(64, 192, kernel_size=3, padding=1),   # (B, 192, 8, 8)nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2),    # (B, 192, 4, 4)nn.Conv2d(192, 384, kernel_size=3, padding=1),  # (B, 384, 4, 4)nn.ReLU(inplace=True),nn.Conv2d(384, 256, kernel_size=3, padding=1),  # (B, 256, 4, 4)nn.ReLU(inplace=True),nn.Conv2d(256, 256, kernel_size=3, padding=1),  # (B, 256, 4, 4)nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2),    # (B, 256, 2, 2))self.classifier = nn.Sequential(nn.Dropout(),nn.Linear(256 * 2 * 2, 4096),nn.ReLU(inplace=True),nn.Dropout(),nn.Linear(4096, 4096),nn.ReLU(inplace=True),nn.Linear(4096, num_classes))def forward(self, x):x = self.features(x)x = x.view(x.size(0), 256 * 2 *2)x = self.classifier(x)return x

代码详解

以下是对上述AlexNet代码的详细解释:

特征提取层 self.features

这部分构建了AlexNet的特征提取层,是一个由多个层组成的顺序结构(通过nn.Sequential来定义)。
- 第一个卷积层
nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1)表示输入图像的通道数为3(通常对应RGB图像的红、绿、蓝三个通道),输出的通道数为64(即卷积核的数量为64,意味着会生成64个不同的特征图),卷积核大小是3×3,步长为2(在空间维度上每次移动2个像素),填充为1(在图像边缘进行1个像素的填充,这样可以保证输入输出图像尺寸在卷积操作下能按预期变化),经过这个卷积层后,输入尺寸为(B, 3, 32, 32)的图像数据会变成(B, 64, 16, 16)
- 激活函数层
nn.ReLU(inplace=True)是使用修正线性单元(Rectified Linear Unit)作为激活函数,inplace=True表示直接在输入的张量上进行修改(节省内存空间),对经过卷积后的特征图进行非线性变换,增强网络的表达能力。
- 池化层
nn.MaxPool2d(kernel_size=2)是最大池化层,池化核大小为2×2,它会在每个2×2的窗口内选取最大值作为输出,起到下采样的作用,减少数据量同时保留重要特征,比如经过第一次池化后特征图尺寸从(B, 64, 16, 16)变为(B, 64, 8, 8)

后续依次重复卷积、激活、池化等操作,不断提取图像的特征,逐步降低特征图的尺寸同时增加特征图的深度(通道数),最终经过这一系列操作后得到尺寸为(B, 256, 2, 2)的特征图。

分类部分self.classifier

self.classifier = nn.Sequential(nn.Dropout(),nn.Linear(256 * 2 * 2, 4096),nn.ReLU(inplace=True),nn.Dropout(),nn.Linear(4096, 4096),nn.ReLU(inplace=True),nn.Linear(4096, num_classes)
)

这部分构建了AlexNet的分类器,同样是顺序结构。
- Dropout层
nn.Dropout()是一种正则化技术,在训练过程中以一定概率(默认0.5)随机将神经元的输出设置为0,防止过拟合,提高模型的泛化能力。这里使用了两次Dropout,分别在不同的全连接层之前。
- 全连接层
第一个nn.Linear(256 * 2 * 2, 4096)表示将经过特征提取后展平的特征向量(尺寸为256 * 2 * 2,因为前面特征提取部分最后得到的特征图尺寸是(B, 256, 2, 2),展平后维度就是256 * 2 * 2)映射到一个4096维的向量空间,后面接着激活函数nn.ReLU(inplace=True)进行非线性变换。然后又是一个Dropout层和一个同样输出维度为4096的全连接层以及相应的激活函数,最后通过nn.Linear(4096, num_classes)将4096维的向量映射到指定的类别数(num_classes)维度,得到最终的分类预测结果。

前向传播forward

def forward(self, x):x = self.features(x)x = x.view(x.size(0), 256 * 2 *2)x = self.classifier(x)return x

forward方法定义了数据在网络中的前向传播过程。

  • 特征提取
    首先x = self.features(x),将输入数据x送入到之前定义的特征提取部分(features),按照特征提取层中定义的卷积、激活、池化等操作依次对输入数据进行处理,得到提取后的特征图。
  • 特征图展平
    x = x.view(x.size(0), 256 * 2 *2)这行代码将特征图进行展平操作,使其变成一个二维张量,其中第一维对应批次大小(x.size(0)表示批次中的样本数量),第二维就是展平后的特征向量长度(由前面特征提取最后得到的特征图尺寸计算得出),这样才能输入到后面的全连接层中进行分类处理。
  • 分类预测
    最后x = self.classifier(x)将展平后的特征向量送入分类器部分(classifier),经过全连接层、激活函数、Dropout等操作逐步得到最终的分类预测结果,然后通过return x返回这个预测结果。

训练过程和测试结果

训练过程损失函数变化曲线:
在这里插入图片描述
在这里插入图片描述
训练过程准确率变化曲线:
在这里插入图片描述
测试结果:
在这里插入图片描述

代码汇总

项目github地址
项目结构:

|--data
|--models|--__init__.py|--alexnet.py
|--results
|--weights
|--train.py
|--test.py

alexnet.py

import torch
import torch.nn as nnclass AlexNet(nn.Module):def __init__(self, num_classes):super(AlexNet, self).__init__()self.features = nn.Sequential(# input size: (B, 3, 32, 32)   (Batch_size, Channel, Height, Width)nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1), # (B, 64, 16, 16)nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2),    # (B, 64, 8, 8)nn.Conv2d(64, 192, kernel_size=3, padding=1),   # (B, 192, 8, 8)nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2),    # (B, 192, 4, 4)nn.Conv2d(192, 384, kernel_size=3, padding=1),  # (B, 384, 4, 4)nn.ReLU(inplace=True),nn.Conv2d(384, 256, kernel_size=3, padding=1),  # (B, 256, 4, 4)nn.ReLU(inplace=True),nn.Conv2d(256, 256, kernel_size=3, padding=1),  # (B, 256, 4, 4)nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2),    # (B, 256, 2, 2))self.classifier = nn.Sequential(nn.Dropout(),nn.Linear(256 * 2 * 2, 4096),nn.ReLU(inplace=True),nn.Dropout(),nn.Linear(4096, 4096),nn.ReLU(inplace=True),nn.Linear(4096, num_classes))def forward(self, x):x = self.features(x)x = x.view(x.size(0), 256 * 2 *2)x = self.classifier(x)return x

train.py

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from models import AlexNet
import matplotlib.pyplot as pltimport ssl
ssl._create_default_https_context = ssl._create_unverified_context# 定义数据预处理操作
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.491, 0.482, 0.446), (0.247, 0.243, 0.261))])# 加载CIFAR10训练集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,download=False, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,shuffle=True, num_workers=2)# 定义设备(GPU优先,若可用)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 实例化模型
model = AlexNet(num_classes=10).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)# 训练轮次
epochs = 15def train(model, trainloader, criterion, optimizer, device):model.train()running_loss = 0.0correct = 0total = 0for i, data in enumerate(trainloader, 0):inputs, labels = data[0].to(device), data[1].to(device)optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()_, predicted = outputs.max(1)total += labels.size(0)correct += predicted.eq(labels).sum().item()epoch_loss = running_loss / len(trainloader)epoch_acc = 100. * correct / totalreturn epoch_loss, epoch_accif __name__ == "__main__":loss_history, acc_history = [], []for epoch in range(epochs):train_loss, train_acc = train(model, trainloader, criterion, optimizer, device)print(f'Epoch {epoch + 1}: Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')loss_history.append(train_loss)acc_history.append(train_acc)# 保存模型权重,每5轮次保存到weights文件夹下if (epoch + 1) % 5 == 0:torch.save(model.state_dict(), f'weights/alexnet_epoch_{epoch + 1}.pth')# 绘制损失曲线plt.plot(range(1, epochs+1), loss_history, label='Loss', marker='o')plt.xlabel('Epoch')plt.ylabel('Loss')plt.title('Training Loss Curve')plt.legend()plt.savefig('results\\train_loss_curve.png')plt.close()# 绘制准确率曲线plt.plot(range(1, epochs+1), acc_history, label='Accuracy', marker='o')plt.xlabel('Epoch')plt.ylabel('Accuracy (%)')plt.title('Training Accuracy Curve')plt.legend()plt.savefig('results\\train_acc_curve.png')plt.close()

test.py

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from models import AlexNetimport ssl
ssl._create_default_https_context = ssl._create_unverified_context
# 定义数据预处理操作
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.491, 0.482, 0.446), (0.247, 0.243, 0.261))])# 加载CIFAR10测试集
testset = torchvision.datasets.CIFAR10(root='./data', train=False,download=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=128,shuffle=False, num_workers=2)# 定义设备(GPU优先,若可用)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 实例化模型
model = AlexNet(num_classes=10).to(device)
criterion = nn.CrossEntropyLoss()# 加载模型权重
weights_path = "weights/alexnet_epoch_15.pth"  
model.load_state_dict(torch.load(weights_path, map_location=device))def test(model, testloader, criterion, device):model.eval()running_loss = 0.0correct = 0total = 0with torch.no_grad():for data in testloader:inputs, labels = data[0].to(device), data[1].to(device)outputs = model(inputs)loss = criterion(outputs, labels)running_loss += loss.item()_, predicted = outputs.max(1)total += labels.size(0)correct += predicted.eq(labels).sum().item()epoch_loss = running_loss / len(testloader)epoch_acc = 100. * correct / totalreturn epoch_loss, epoch_accif __name__ == "__main__":test_loss, test_acc = test(model, testloader, criterion, device)print("================AlexNet Test================")print(f"Load Model Weights From: {weights_path}")print(f'Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.2f}%')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/64595.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

c语言-----数组

基本概念 数组是C语言中一种用于存储多个相同类型数据的数据结构。这些数据在内存中是连续存储的,可以通过索引(下标)来访问数组中的各个元素。数组的索引从0开始,这是C语言的规定。例如,一个有n个元素的数组&#xff…

【最新攻略】腾讯云双十一最强攻略密码

引言一起来薅羊毛准备工作双人成团PK有大礼! 福利总结 引言 腾讯云(Tencent Cloud)想必大家都听说过吧?腾讯云是腾讯提供的“云计算”服务。你可以把它想成一个超级强大的网络平台,帮助公司和开发者把自己的技术、数据、网站等东西…

【C#】List求并集、交集、差集

值类型List List<int> intList1 new List<int>() { 1, 2, 3 };List<int> intList2 new List<int>() { 3, 4, 5 };var result intList1.Union(intList2);Console.WriteLine($"并 {string.Join(,,result)}");result intList1.Intersect(in…

游戏渠道假量解决方案

某推广公司在推广过程中被查出“短期内点击量激增”“存在同一地址多次访问”“已注册用户重复注册”等数据作弊行为&#xff0c;法院判罚退还服务费200余万元&#xff0c;并赔偿违约金约350万元。 某公司为提升其游戏在应用商店榜单排名&#xff0c;委托某网络公司进行下载、注…

【Linux运维】配置ssh免密登录

1.场景描述 内网环境&#xff0c;需要同步17服务器的文件到10服务器进行备份。因为每次输入密码比较繁琐&#xff0c;如果实现免密登录后&#xff0c;即可简化脚本。 要求&#xff1a;需要2台服务器-免密登录 2.方案分析 &#xff08;1&#xff09;现状&#xff1a;登录需要输…

Android实现RecyclerView边缘渐变效果

Android实现RecyclerView边缘渐变效果 1.前言&#xff1a; 是指在RecyclerView中实现淡入淡出效果的边缘效果。通过这种效果&#xff0c;可以使RecyclerView的边缘在滚动时逐渐淡出或淡入&#xff0c;以提升用户体验。 2.Recyclerview属性&#xff1a; 2.1、requiresFading…

C语言——实现找出最高分

问题描述&#xff1a;分别有6名学生的学号、姓名、性别、年龄和考试分数&#xff0c;找出这些学生当中考试成绩最高的学生姓名。 //找出最高分#include<stdio.h>struct student {char stu_num[10]; //学号 char stu_name[10]; //姓名 char sex; //性别 int age; …

Kafka Streams 在监控场景的应用与实践

作者&#xff1a;来自 vivo 互联网服务器团队- Pang Haiyun 介绍 Kafka Streams 的原理架构&#xff0c;常见配置以及在监控场景的应用。 一、背景 在当今大数据时代&#xff0c;实时数据处理变得越来越重要&#xff0c;而监控数据的实时性和可靠性是监控能力建设最重要的一环…

作业Day4: 链表函数封装 ; 思维导图

目录 作业&#xff1a;实现链表剩下的操作&#xff1a; 任意位置删除 按位置修改 按值查找返回地址 反转 销毁 运行结果 思维导图 作业&#xff1a;实现链表剩下的操作&#xff1a; 1>任意位置删除 2>按位置修改 3>按值查找返回地址 4>反转 5>销毁 任意…

省略内容在句子中间

一、使用二分查找法 每次查找时&#xff0c;将查找范围分成两半&#xff0c;并判断目标值位于哪一半&#xff0c;从而逐步缩小查找范围。 循环查找 计算中间位置 mid Math.floor((low high) / 2)。比较目标值 target 和中间位置的元素 arr[mid]&#xff1a; 如果 target ar…

IDEA中解决Edit Configurations中没有tomcat Server选项的问题

今天使用IDEA2024专业版的时候,发现Edit Configurations里面没有tomcat Server,最终找到解决方案。 一、解决办法 1、打开Settings 2、搜索tomcat插件 搜索tomcat插件之后,找到tomcat 发现tomcat插件处于未勾选状态,然后我们将其勾选保存即可。 二、结果展示 最后,再次编…

UE5中实现Billboard公告板渲染

公告板&#xff08;Billboard&#xff09;通常指永远面向摄像机的面片&#xff0c;游戏中许多技术都基于公告板&#xff0c;例如提示拾取图标、敌人血槽信息等&#xff0c;本文将使用UE5和材质节点制作一个公告板。 Gif效果&#xff1a; 网格效果&#xff1a; 1.思路 通过…

LabVIEW在电液比例控制与伺服控制中的应用

LabVIEW作为一种图形化编程环境&#xff0c;广泛应用于各类控制系统中&#xff0c;包括电液比例控制和伺服控制领域。在这些高精度、高动态要求的控制系统中&#xff0c;LabVIEW的优势尤为突出。以下从多个角度探讨其应用与优势&#xff1a; ​ 1. 灵活的控制架构 LabVIEW为电…

《深入浅出Apache Spark》系列⑤:Spark SQL的表达式优化

导读&#xff1a;随着数据量的快速增长&#xff0c;传统的数据处理方法难以满足对计算速度、资源利用率以及查询响应时间的要求。为了应对这些挑战&#xff0c;Spark SQL 引入了多种优化技术&#xff0c;以提高查询效率&#xff0c;降低计算开销。本文从表达式层面探讨了 Spark…

unity webgl部署到iis报错

Unable to parse Build/WebGLOut.framework.js.unityweb! The file is corrupt, or compression was misconfigured? (check Content-Encoding HTTP Response Header on web server) iis报错的 .unityweb application/octet-stream iis中添加 MIME类型 .data applicatio…

CXF WebService SpringBoot 添加拦截器,处理响应报文格式

描述 XFIRE升级CXF框架&#xff0c;但是对接的系统不做调整&#xff0c;这时候就要保证参数报文和响应报文和以前是一致的。但是不同的框架有不同的规则&#xff0c;想要将报文调整的一致&#xff0c;就需要用到拦截器拦截报文&#xff0c;自定义解析处理。 CXF框架本身就是支…

基于Spring Boot的雅苑小区管理系统

一、系统背景与意义 随着信息化技术的快速发展&#xff0c;传统的小区物业管理方式已经难以满足现代居民对于高效、便捷服务的需求。因此&#xff0c;开发一款基于Spring Boot的小区管理系统显得尤为重要。该系统旨在通过信息化手段&#xff0c;实现小区物业管理的智能化、自动…

Docke_常用命令详解

这篇文章分享一下笔者常用的Docker命令供各位读者参考。 为什么要用Docker? 简单来说&#xff1a;Docker通过提供轻量级、隔离且可移植的容器化环境&#xff0c;使得应用在不同平台上保持一致性、易于部署和管理&#xff0c;具体如下 环境一致性&#xff1a; Docker容器使得…

Ubuntu 20.04下Kinect2驱动环境配置与测试【稳定无坑版】

一、引言 微软Kinect2传感器作为一个包含深度传感器、RGB摄像头以及红外摄像头的多模态采集设备&#xff0c;在计算机视觉、机器人感知、人体姿态识别、3D建模等领域有着广泛应用。相比第一代Kinect&#xff0c;Kinect2拥有更好的深度分辨率和更高的数据质量。本文将详细介绍如…

Flask入门:打造简易投票系统

目录 准备工作 创建项目结构 编写HTML模板 编写Flask应用 代码解读 进一步优化 结语 Flask,这个轻量级的Python Web框架,因其简洁和易用性,成为很多开发者入门Web开发的首选。今天,我们就用Flask来做一个简单的投票系统,让你快速上手Web开发,同时理解Flask的核心概…