CNN、LeNet、AlexNet基于MNIST数据集进行训练和测试,并可视化对比结果

完成内容:

  1. 构建CNN并基于MNIST数据集进行训练和测试
  2. 构建LeNet并基于MNIST数据集进行训练和测试
  3. 构建AlexNet并基于MNIST数据集进行训练和测试
  4. 对比了不同网络在MNIST数据集上训练的效果

准备工作

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
from matplotlib import pyplot as plt
import pandas as pd
from math import pi

下载数据,加载data_loader

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'device:{device}')
batch_size = 256transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])# 加载数据(本步建议挂梯子)
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform, download=True)# 加载data_loader
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)results = []

定义CNN和LeNet通用的训练函数和测试函数

def train(model, train_loader, criterion, optimizer, device):model.train()running_loss = 0.0for inputs, labels in train_loader:inputs, labels = inputs.to(device), labels.to(device)optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()return running_loss / len(train_loader)def test(model, test_loader, criterion, device):model.eval()correct = 0total = 0with torch.no_grad():for inputs, labels in test_loader:inputs, labels = inputs.to(device), labels.to(device)outputs = model(inputs)_, predicted = torch.max(outputs, 1)total += labels.size(0)correct += (predicted == labels).sum().item()accuracy = correct / totalreturn accuracy

构建CNN并基于MNIST数据集进行训练和测试

class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.features = nn.Sequential(nn.Conv2d(1, 16, kernel_size=5, padding=2),nn.ReLU(),nn.MaxPool2d(kernel_size=2))self.classifier = nn.Linear(16 * 14 * 14, 10)def forward(self, x):x = self.features(x)x = x.view(x.size(0), -1)x = self.classifier(x)return x
# 展示网络内部结构
X = torch.rand(size=(1, 1, 28, 28), dtype=torch.float32)
for layer in CNN().features:X = layer(X)print(layer.__class__.__name__,'output shape: \t',X.shape)
网络结构:
Conv2d output shape: 	 torch.Size([1, 16, 28, 28])
ReLU output shape: 	 torch.Size([1, 16, 28, 28])
MaxPool2d output shape: 	 torch.Size([1, 16, 14, 14])
# 初始化CNN,优化器,损失函数
model = CNN().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()
result = []
# 训练网络
num_epochs = 5
for epoch in tqdm(range(num_epochs), desc="training", unit="epoch"):train_loss = train(model, train_loader, criterion, optimizer, device)test_acc = test(model, test_loader, criterion, device)result.append(test_acc)print(f'Epoch {epoch+1}/{num_epochs}, Loss: {train_loss:.4f}, Test Accuracy: {test_acc:.4f}')
results.append(result)
results    

LeNet-MNIST

class LeNet(nn.Module):def __init__(self):super(LeNet, self).__init__()self.features = nn.Sequential(nn.Conv2d(1, 6, kernel_size=5, padding=2), nn.ReLU(),  # (1, 6, 28, 28)nn.AvgPool2d(kernel_size=2, stride=2),  # (1, 6, 14, 14)nn.Conv2d(6, 16, kernel_size=5), nn.ReLU(),  # (1, 16, 10, 10)nn.AvgPool2d(kernel_size=2, stride=2),  # (1, 16, 5, 5)nn.Flatten(),  # (1, 400)nn.Linear(16 * 5 * 5, 120), nn.ReLU(),  # (1, 120)nn.Linear(120, 84), nn.ReLU(),  # (1, 84)nn.Linear(84, 10)  # (1, 10))def forward(self, x):x = self.features(x)return x
# 展示LeNet网络内部结构
X = torch.rand(size=(1, 1, 28, 28), dtype=torch.float32)
for layer in LeNet().features:X = layer(X)print(layer.__class__.__name__,'output shape: \t',X.shape)
# 网络结构:
Conv2d output shape: 	 torch.Size([1, 6, 28, 28])
ReLU output shape: 	 torch.Size([1, 6, 28, 28])
AvgPool2d output shape: 	 torch.Size([1, 6, 14, 14])
Conv2d output shape: 	 torch.Size([1, 16, 10, 10])
ReLU output shape: 	 torch.Size([1, 16, 10, 10])
AvgPool2d output shape: 	 torch.Size([1, 16, 5, 5])
Flatten output shape: 	 torch.Size([1, 400])
Linear output shape: 	 torch.Size([1, 120])
ReLU output shape: 	 torch.Size([1, 120])
Linear output shape: 	 torch.Size([1, 84])
ReLU output shape: 	 torch.Size([1, 84])
Linear output shape: 	 torch.Size([1, 10])
# 初始化CNN,优化器,损失函数
model = LeNet().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()
result = []
# 训练模型
num_epochs = 5
for epoch in tqdm(range(num_epochs), desc="training", unit="epoch"):train_loss = train(model, train_loader, criterion, optimizer, device)test_acc = test(model, test_loader, criterion, device)result.append(test_acc)print(f'Epoch {epoch+1}/{num_epochs}, Loss: {train_loss:.4f}, Test Accuracy: {test_acc:.4f}')
results.append(result)
results

AlexNet-MNIST

# 定义AlexNet
class AlexNet(nn.Module):def __init__(self, num_classes=10):super(AlexNet, self).__init__()self.features = nn.Sequential(nn.Conv2d(1, 64, kernel_size=11, stride=4, padding=2),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2),nn.Conv2d(64, 192, kernel_size=5, padding=2),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2),nn.Conv2d(192, 384, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(384, 256, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(256, 256, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2),)self.avgpool = nn.AdaptiveAvgPool2d((6, 6))self.classifier = nn.Sequential(nn.Dropout(),nn.Linear(256 * 6 * 6, 4096),nn.ReLU(inplace=True),nn.Dropout(),nn.Linear(4096, 4096),nn.ReLU(inplace=True),nn.Linear(4096, num_classes),)def forward(self, x):x = self.features(x)x = self.avgpool(x)x = torch.flatten(x, 1)x = self.classifier(x)return x
# 重新加载数据
transform = transforms.Compose([transforms.Resize((227, 227)),transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))
])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
# 初始化AlexNet、优化器、损失函数
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
alexnet = AlexNet(num_classes=10).to(device)
optimizer = optim.Adam(alexnet.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()
result = []# 训练
num_epochs = 5
for epoch in tqdm(range(num_epochs), desc="training", unit="epoch"):alexnet.train()for inputs, labels in train_loader:inputs, labels = inputs.to(device), labels.to(device)optimizer.zero_grad()outputs = alexnet(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()#  测试alexnet.eval()correct = 0total = 0with torch.no_grad():for inputs, labels in test_loader:inputs, labels = inputs.to(device), labels.to(device)outputs = alexnet(inputs)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()accuracy = correct / totalresult.append(accuracy)print(f"Accuracy on test set: {accuracy * 100:.2f}%")
results.append(result)

结果分析

# Set data
df = pd.DataFrame(results)
columns = ['epoch1', 'epoch2', 'epoch3', 'epoch4', 'epoch5']
df.columns = columns
df['Network'] = ['CNN','LeNet', 'AlexNet']
print(df)
# ------- PART 1: Create background# number of variable
categories=list(df)[:-1]
N = len(categories)# What will be the angle of each axis in the plot? (we divide the plot / number of variable)
angles = [n / float(N) * 2 * pi for n in range(N)]
angles += angles[:1]# Initialise the spider plot
ax = plt.subplot(111, polar=True)# If you want the first axis to be on top:
ax.set_theta_offset(pi / 2)
ax.set_theta_direction(-1)# Draw one axe per variable + add labels
plt.xticks(angles[:-1], categories)# Draw ylabels
ax.set_rlabel_position(0)
plt.yticks([0.925,0.95,0.975], ["0.925","0.95","0.975"], color="grey", size=7)
plt.ylim(0.9,1)# ------- PART 2: Add plots# Plot each individual = each line of the data# Ind1
values=df.loc[0].drop('Network').values.flatten().tolist()
values += values[:1]
ax.plot(angles, values, linewidth=1, linestyle='solid', label="CNN")
ax.fill(angles, values, 'b', alpha=0.1)# Ind2
values=df.loc[1].drop('Network').values.flatten().tolist()
values += values[:1]
ax.plot(angles, values, linewidth=1, linestyle='solid', label="LeNet")
ax.fill(angles, values, 'r', alpha=0.1)# Ind3
values=df.loc[2].drop('Network').values.flatten().tolist()
values += values[:1]
ax.plot(angles, values, linewidth=1, linestyle='solid', label="AlexNet")
ax.fill(angles, values, 'g', alpha=0.1)# Add legend
plt.legend(loc='upper right', bbox_to_anchor=(0.1, 0.1))# Show the graph
plt.show()
   epoch1  epoch2  epoch3  epoch4  epoch5  Network
0  0.9629  0.9764  0.9818  0.9823  0.9826      CNN
1  0.9461  0.9706  0.9781  0.9810  0.9869    LeNet
2  0.9844  0.9865  0.9887  0.9855  0.9900  AlexNet

在这里插入图片描述

总体而言:
AlexNet效果更好,但Alex网络更复杂,计算开销更大;
CNN网络最简单,计算开销最小,效果也较好;
LeNet效果不如预期,按理来说LeNet网络更复杂,相较于CNN拟合效果应更好,但实际效果有偏差,怀疑是epoch较少,5个epoch不足以收敛

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/217369.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

高通SDX12:nand flash适配

一、SBL阶段 代码流程如下: boot_images\core\storage\flash\src\dal\flash_nand_init.c nand_probe ->nand_intialize_primary_hal_device ->>nand_get_device_list_supportedboot_images\core\storage\flash\src\dal\flash_nand_config.c ->>>flash_n…

ue4 解决角度万向锁的问题 蓝图节点

问题:当角度值从359-1变化的时候,数值会经历358、357… 解决方法:勾上Shortest Path,角度值的会从359-1

Kubernetes实战(十三)-使用kube-bench检测Kubernetes集群安全

1 概述 在当今云原生应用的开发中,Kubernetes已经成为标准,然而,随着其使用的普及,也带来了安全问题的挑战。本文将介绍如何使用kube-bench工具来评估和增强Kubernetes集群的安全性。 2 CIS (Center for Internet Security)简介…

OneCode低代码引擎 V2.0源码结构详解

前言 OneCode今天(12月10日)正式更新了其V2.0版本。从OneCode的季度版本生命中,可以看到2.0版本还是一个重量级的版本,笔者在收到2.0更新后第一时间下拉了最新的代码。在参考了OneCode 的技术说明后,根据包结构来分析…

MySQL笔记-第16章_变量、流程控制与游标

视频链接:【MySQL数据库入门到大牛,mysql安装到优化,百科全书级,全网天花板】 文章目录 第16章_变量、流程控制与游标1. 变量1.1 系统变量1.1.1 系统变量分类1.1.2 查看系统变量 1.2 用户变量1.2.1 用户变量分类1.2.2 会话用户变量…

GoLang EASY 微服务游戏框架 01

1 Overview EASY 是一个go语言编写的框架,兼容性支持go版本1.19,go mod 方式构建管理。它是一个轻型,灵活,自定义适配强的微服务框架。 它支持多种网络协议TCP,websocket,UDP(待完成&#xf…

数据可视化:解析跨行业普及之道

数据可视化作为一种强大的工具,在众多行业中得到了广泛的应用,其价值和优势不断被发掘和利用。今天就让我以这些年来可视化设计的经验,讨论一下数据可视化在各个行业中备受青睐的原因吧。 无论是商业、科学、医疗保健、金融还是教育领域&…

蚂蚁SEO的百度蜘蛛池有哪些优势

一、介绍 SEO是搜索引擎优化(Search Engine Optimization)的缩写,是一种通过优化网站结构、内容和链接等元素,提高网站在搜索引擎中的排名,从而增加网站流量和吸引更多潜在客户的方法。SEO已成为现代网站管理的重要策…

聚观早报 |一加12首销;华为智能手表释放科技温暖

【聚观365】12月12日消息 一加12首销 华为智能手表释放科技温暖 卡尔动力获地平线战略投资 英伟达希望在越南建立基地 努比亚Z60 Ultra影像规格揭晓 一加12首销 现在有最新消息,近日一加12该机已于昨日开售,售价4299元起。 外观方面,全…

IDC报告:国内游戏云市场,腾讯云用量规模位列第一

12月12日消息,IDC公布最新的《中国游戏云市场跟踪研究,2022H2》报告(以下简称“《报告》”)显示,腾讯云凭借全球化节点布局以及国际领先的游戏技术积累,在整体规模、云游戏流路数、CDN流量峰值带宽等多维度…

“未来医疗揭秘:机器学习+多组学数据,开启生物医学新纪元“

在当今的数字化时代,科技正在不断地改变着我们的生活,同时也为医疗领域带来了巨大的变革。随着机器学习的快速发展,以及多组学数据在生物医学中的应用,我们正开启一个全新的医疗纪元。这个纪元以精准诊断、个性化治疗和高效康复为…

Docker容器:Centos7搭建Docker镜像私服harbor

目录 1、安装docker 1.1、前置条件 1.2、查看当前操作系统的内核版本 1.3、卸载旧版本(可选) 1.4、安装需要的软件包 1.5、设置yum安装源 1.6、查看docker可用版本 1.7、安装docker 1.8、开启docker服务 1.9、安装阿里云镜像加速器 1.10、设置docker开机自启 2、安…

K8S(一)—安装部署

目录 安装部署前提以下的操作指导(在master)之前都是三台机器都需要执行 安装docker服务下面的操作仅在k8smaster执行 安装部署 前提 以下的操作指导(在master)之前都是三台机器都需要执行 关闭防火墙 [rootk8smaster ~]# vim /etc/selinux/config [rootk8smaster ~]# swa…

指针浅谈(三)

在指针浅谈(二)http://t.csdnimg.cn/SKAkD中我们讲到了const修饰指针、指针运算、野指针、assert断言和传址调用的内容,今天我们继续学习有关数组名、指针访问数组、一维数组传参的本质相关的内容,内容比较深入,如果觉得哪里讲解的不行&#…

Docker部署Nacos集群并用nginx反向代理负载均衡

首先找到Nacos官网给的Github仓库,里面有docker compose可以快速启动Nacos集群。 文章目录 一. 脚本概况二. 自定义修改1. example/cluster-hostname.yaml2. example/.env3. env/mysql.env4. env/nacos-hostname.env 三、运行四、nginx反向代理,负载均衡…

【JavaEE学习】初识进程概念

个人主页:兜里有颗棉花糖 欢迎 点赞👍 收藏✨ 留言✉ 加关注💓本文由 兜里有颗棉花糖 原创 收录于专栏【Java系列】【JaveEE学习专栏】 本专栏旨在分享学习JavaEE的一点学习心得,欢迎大家在评论区交流讨论💌 目录 一、…

python:五种算法(GA、OOA、DBO、SSA、PSO)求解23个测试函数(python代码)

一、五种算法简介 1、遗传算法GA 2、鱼鹰优化算法OOA 3、蜣螂优化算法DBO 4、麻雀搜索算法SSA 5、粒子群优化算法PSO 二、5种算法求解23个函数 (1)23个函数简介 参考文献: [1] Yao X, Liu Y, Lin G M. Evolutionary programming made…

百度文库下载要用券?Kotlin爬虫几步解决

百度作为国内知名的网站,尤其是文库里面有各种丰富的内容,对我们学习生活都有很大的帮助,就因为其内容丰富,如果看见好用有意思的文章还用复制粘贴等方式就显得有点落后了,今天我将用我所学的爬虫知识给你们好好上一课…

基于51单片机的语音识别控制系统

0-演示视频 1-功能说明 (1)使用DHT11检测温湿度,然后用LCD12864显示,语音播放,使用STC11l08xe控制LD3320做语音识别, (2)上电时语音提示:欢迎使用声音识别系统&#xf…

【vue实战项目】通用管理系统:信息列表,信息的编辑和删除

本文为博主的vue实战小项目系列中的第七篇,很适合后端或者才入门的小伙伴看,一个前端项目从0到1的保姆级教学。前面的内容: 【vue实战项目】通用管理系统:登录页-CSDN博客 【vue实战项目】通用管理系统:封装token操作…