LeNet对MNIST 数据集中的图像进行分类--keras实现

我们将训练一个卷积神经网络来对 MNIST 数据库中的图像进行分类,可以与前面所提到的CNN实现对比CNN对 MNIST 数据库中的图像进行分类-CSDN博客

加载 MNIST 数据库

MNIST 是机器学习领域最著名的数据集之一。

  • 它有 70,000 张手写数字图像 - 下载非常简单 - 图像尺寸为 28x28 - 灰度图像
from keras.datasets import mnist# 使用 Keras 导入预洗牌 MNIST 数据库
(X_train, y_train), (X_test, y_test) = mnist.load_data()print("The MNIST database has a training set of %d examples." % len(X_train))
print("The MNIST database has a test set of %d examples." % len(X_test))

将前六个训练图像可视化 

import matplotlib.pyplot as plt
%matplotlib inline
import matplotlib.cm as cm
import numpy as np# 绘制前六幅训练图像
fig = plt.figure(figsize=(20,20))
for i in range(6):ax = fig.add_subplot(1, 6, i+1, xticks=[], yticks=[])ax.imshow(X_train[i], cmap='gray')ax.set_title(str(y_train[i]))

 查看图像的更多细节

def visualize_input(img, ax):ax.imshow(img, cmap='gray')width, height = img.shapethresh = img.max()/2.5for x in range(width):for y in range(height):ax.annotate(str(round(img[x][y],2)), xy=(y,x),horizontalalignment='center',verticalalignment='center',color='white' if img[x][y]<thresh else 'black')fig = plt.figure(figsize = (12,12)) 
ax = fig.add_subplot(111)
visualize_input(X_train[0], ax)

预处理输入图像:通过将每幅图像中的每个像素除以 255 来调整图像比例

# normalize the data to accelerate learning
mean = np.mean(X_train)
std = np.std(X_train)
X_train = (X_train-mean)/(std+1e-7)
X_test = (X_test-mean)/(std+1e-7)print('X_train shape:', X_train.shape)
print(X_train.shape[0], 'train samples')
print(X_test.shape[0], 'test samples')

对标签进行预处理:使用单热方案对分类整数标签进行编码

from keras.utils import np_utilsnum_classes = 10 
# print first ten (integer-valued) training labels
print('Integer-valued labels:')
print(y_train[:10])# one-hot encode the labels
# convert class vectors to binary class matrices
y_train = np_utils.to_categorical(y_train, num_classes)
y_test = np_utils.to_categorical(y_test, num_classes)# print first ten (one-hot) training labels
print('One-hot labels:')
print(y_train[:10])

重塑数据以适应我们的 CNN(和 input_shape)

# input image dimensions 28x28 pixel images. 
img_rows, img_cols = 28, 28X_train = X_train.reshape(X_train.shape[0], img_rows, img_cols, 1)
X_test = X_test.reshape(X_test.shape[0], img_rows, img_cols, 1)
input_shape = (img_rows, img_cols, 1)print('image input shape: ', input_shape)
print('x_train shape:', X_train.shape)

定义模型架构

论文地址:lecun-01a.pdf

要在 Keras 中实现 LeNet-5,请阅读原始论文并从第 6、7 和 8 页中提取架构信息。以下是构建 LeNet-5 网络的主要启示:

  • 每个卷积层的滤波器数量:从图中(以及论文中的定义)可以看出,每个卷积层的深度(滤波器数量)如下:C1 = 6、C3 = 16、C5 = 120 层。
  • 每个 CONV 层的内核大小:根据论文,内核大小 = 5 x 5
  • 每个卷积层之后都会添加一个子采样层(POOL)。每个单元的感受野是一个 2 x 2 的区域(即 pool_size = 2)。请注意,LeNet-5 创建者使用的是平均池化,它计算的是输入的平均值,而不是我们在早期项目中使用的最大池化层,后者传递的是输入的最大值。如果您有兴趣了解两者的区别,可以同时尝试。在本实验中,我们将采用论文架构。
  • 激活函数:LeNet-5 的创建者为隐藏层使用了 tanh 激活函数,因为对称函数被认为比 sigmoid 函数收敛更快。一般来说,我们强烈建议您为网络中的每个卷积层添加 ReLU 激活函数。

需要记住的事项

  • 始终为 CNN 中的 Conv2D 层添加 ReLU 激活函数。除了网络中的最后一层,密集层也应具有 ReLU 激活函数。
  • 在构建分类网络时,网络的最后一层应该是具有软最大激活函数的密集(FC)层。最终层的节点数应等于数据集中的类别总数。
from keras.models import Sequential
from keras.layers import Conv2D, AveragePooling2D, Flatten, Dense
#Instantiate an empty model
model = Sequential()# C1 Convolutional Layer
model.add(Conv2D(6, kernel_size=(5, 5), strides=(1, 1), activation='tanh', input_shape=input_shape, padding='same'))# S2 Pooling Layer
model.add(AveragePooling2D(pool_size=(2, 2), strides=2, padding='valid'))# C3 Convolutional Layer
model.add(Conv2D(16, kernel_size=(5, 5), strides=(1, 1), activation='tanh', padding='valid'))# S4 Pooling Layer
model.add(AveragePooling2D(pool_size=(2, 2), strides=2, padding='valid'))# C5 Fully Connected Convolutional Layer
model.add(Conv2D(120, kernel_size=(5, 5), strides=(1, 1), activation='tanh', padding='valid'))#Flatten the CNN output so that we can connect it with fully connected layers
model.add(Flatten())# FC6 Fully Connected Layer
model.add(Dense(84, activation='tanh'))# Output Layer with softmax activation
model.add(Dense(10, activation='softmax'))# print the model summary
model.summary()

编译模型

我们将使用亚当优化器

# the loss function is categorical cross entropy since we have multiple classes (10) # compile the model by defining the loss function, optimizer, and performance metric
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

训练模型

LeCun 和他的团队采用了计划衰减学习法,学习率的值按照以下时间表递减:前两个历元为 0.0005,接下来的三个历元为 0.0002,接下来的四个历元为 0.00005,之后为 0.00001。在论文中,作者对其网络进行了 20 个历元的训练。

from keras.callbacks import ModelCheckpoint, LearningRateScheduler# set the learning rate schedule as created in the original paper
def lr_schedule(epoch):if epoch <= 2:     lr = 5e-4elif epoch > 2 and epoch <= 5:lr = 2e-4elif epoch > 5 and epoch <= 9:lr = 5e-5else: lr = 1e-5return lrlr_scheduler = LearningRateScheduler(lr_schedule)# set the checkpointer
checkpointer = ModelCheckpoint(filepath='model.weights.best.hdf5', verbose=1, save_best_only=True)# train the model
hist = model.fit(X_train, y_train, batch_size=32, epochs=20,validation_data=(X_test, y_test), callbacks=[checkpointer, lr_scheduler], verbose=2, shuffle=True)

在验证集上加载分类准确率最高的模型

# load the weights that yielded the best validation accuracy
model.load_weights('model.weights.best.hdf5')

计算测试集的分类准确率

# evaluate test accuracy
score = model.evaluate(X_test, y_test, verbose=0)
accuracy = 100*score[1]# print test accuracy
print('Test accuracy: %.4f%%' % accuracy)

评估模型

import matplotlib.pyplot as pltf, ax = plt.subplots()
ax.plot([None] + hist.history['accuracy'], 'o-')
ax.plot([None] + hist.history['val_accuracy'], 'x-')
# 绘制图例并自动使用最佳位置: loc = 0。
ax.legend(['Train acc', 'Validation acc'], loc = 0)
ax.set_title('Training/Validation acc per Epoch')
ax.set_xlabel('Epoch')
ax.set_ylabel('acc')
plt.show()

import matplotlib.pyplot as pltf, ax = plt.subplots()
ax.plot([None] + hist.history['loss'], 'o-')
ax.plot([None] + hist.history['val_loss'], 'x-')# Plot legend and use the best location automatically: loc = 0.
ax.legend(['Train loss', "Val loss"], loc = 0)
ax.set_title('Training/Validation Loss per Epoch')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
plt.show()

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/189945.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

MMdetection3.0 问题:DETR验证集上AP值全为0.0000

MMdetection3.0 问题&#xff1a;DETR验证集上AP值全为0.0000 条件&#xff1a; 1、选择的是NWPU-VHR-10数据集&#xff0c;其中训练集455张&#xff0c;验证、测试相同均为195张。 2、在Faster-rcnn、YOLOv3模型上均能训练成功&#xff0c;但是在DETR训练时&#xff0c;出现验…

《地理信息系统原理》笔记/期末复习资料(7. 空间分析)

目录 7. 空间分析 7.1 空间分析的内容与步骤 7.2 数据检索及表格分析 7.2.1 属性统计分析 7.2.2 布尔逻辑查询 7.2.3 空间数据库查询语言 7.2.4 重分类&#xff0c;边界消除与合并 7.3 叠置分析 7.3.1 栅格系统的叠加分析 7.3.2 矢量系统的叠加分析&#xff08;拓扑叠…

视频后期特效处理软件 Motion 5 mac中文版

Motion mac是一款运动图形和视频合成软件&#xff0c;适用于Mac OS平台。 Motion mac软件特点 - 精美的效果&#xff1a;Motion提供了多种高质量的运动图形和视频效果&#xff0c;例如3D效果、烟雾效果、粒子效果等&#xff0c;方便用户制作出丰富多彩的视频和动画。 - 高效的工…

设计模式-结构型模式之桥接设计模式

文章目录 三、桥接模式 三、桥接模式 桥接模式&#xff08;Bridge&#xff09;是用于把抽象化与实现化解耦&#xff0c;使得二者可以独立变化。它通过提供抽象化和实现化之间的桥接结构&#xff0c;来实现二者的解耦。 这种模式涉及到一个作为桥接的接口&#xff0c;使得实体类…

Windows本地搭建Emby媒体库服务器并实现远程访问「内网穿透」

文章目录 1.前言2. Emby网站搭建2.1. Emby下载和安装2.2 Emby网页测试 3. 本地网页发布3.1 注册并安装cpolar内网穿透3.2 Cpolar云端设置3.3 Cpolar内网穿透本地设置 4.公网访问测试5.结语 1.前言 在现代五花八门的网络应用场景中&#xff0c;观看视频绝对是主力应用场景之一&…

jQuery选择器、操作DOM、事件处理机制、动画、ADJX操作知识点梳理

jQuery 核心理念就是写的更少&#xff0c;做的更多实现的代码更加简洁有效的提高开发效率 jQuery跟JavaScript的用法是不一样的 跟jQuery相继诞生的JavaScript库还有很多&#xff0c;不包括node.js 关于代码$("li").get(0)&#xff0c;获取DOM对象 jQuery对象声…

前缀和 LeetCode1094 拼车

1094. 拼车 车上最初有 capacity 个空座位。车 只能 向一个方向行驶&#xff08;也就是说&#xff0c;不允许掉头或改变方向&#xff09; 给定整数 capacity 和一个数组 trips , trip[i] [numPassengersi, fromi, toi] 表示第 i 次旅行有 numPassengersi 乘客&#xff0c;接…

Tomcat 漏洞修复

1、去掉请求响应中Server信息 修复方法&#xff1a; 在Tomcat的配置文件的Connector中增加 server" " &#xff0c;server 的值可以改成你任意想返回的值。

Elasticsearch:什么是自然语言处理(NLP)?

自然语言处理定义 自然语言处理 (natural language processing - NLP) 是人工智能 (AI) 的一种形式&#xff0c;专注于计算机和人们使用人类语言进行交互的方式。 NLP 技术帮助计算机使用我们的自然交流模式&#xff08;语音和书面文本&#xff09;来分析、理解和响应我们。 自…

一进三出宿舍限电模块的改造升级

一进三出宿舍限电模块改造升级石家庄光大远通电气有限公司智能模块功能特点&#xff1a; 电能控制功能&#xff1a;可实施剩余电量管理&#xff0c;电量用完时将自动断电&#xff1b; 剩余电量可视报警提示功能&#xff1a;剩余电量可视&#xff0c;并当电量剩余5度时&#xff…

C#和MySQL技巧分享:日期的模糊查询

文章目录 前言一、EF Core 模糊查询二、MySql 日期模糊查询分析和优化2.1 测试环境准备2.1.1 创建数据库2.1.2 查看测试数据 2.2 查询日期的运行效率对比2.3 运行效率优化 三、EF Core 模糊查询优化3.1 字符串转日期3.2 使用日期格式查询 四、优化建议总结 前言 在处理数据库查…

go开发之个人微信号机器人开发

简要描述&#xff1a; 下载消息中的文件 请求URL&#xff1a; http://域名地址/getMsgFile 请求方式&#xff1a; POST 请求头Headers&#xff1a; Content-Type&#xff1a;application/jsonAuthorization&#xff1a;login接口返回 参数&#xff1a; 参数名必选类型…

Linux Spug自动化运维平台本地部署与公网远程访问

文章目录 前言1. Docker安装Spug2 . 本地访问测试3. Linux 安装cpolar4. 配置Spug公网访问地址5. 公网远程访问Spug管理界面6. 固定Spug公网地址 前言 Spug 面向中小型企业设计的轻量级无 Agent 的自动化运维平台&#xff0c;整合了主机管理、主机批量执行、主机在线终端、文件…

版本控制系统Git学习笔记-Git分支操作

文章目录 概述一、Git分支简介1.1 基本概念1.2 创建分支1.3 分支切换1.4 删除分支 二、新建和合并分支2.1 工作流程示意图2.2 新建分支2.3 合并分支2.4 分支示例2.4.1 当前除了主分支&#xff0c;再次创建了两个分支2.4.2 先合并test1分支2.4.3 合并testbranch分支 2.5 解决合并…

算法基础三

电话号码的字母组合 给定一个仅包含数字 2-9 的字符串&#xff0c;返回所有它能表示的字母组合。答案可以按 任意顺序 返回。给出数字到字母的映射如下&#xff08;与电话按键相同&#xff09;。注意 1 不对应任何字母。 示例 1&#xff1a; 输入&#xff1a;digits "…

【集合篇】Java集合概述

Java 集合概述 集合与容器 容器&#xff08;Container&#xff09;是一个更广泛的术语&#xff0c;用于表示可以容纳、组织和管理其他对象的对象。它是一个更高层次的概念&#xff0c;包括集合&#xff08;Collection&#xff09;在内。集合&#xff08;Collection&#xff0…

深入理解Servlet(中)

作者简介&#xff1a;大家好&#xff0c;我是smart哥&#xff0c;前中兴通讯、美团架构师&#xff0c;现某互联网公司CTO 联系qq&#xff1a;184480602&#xff0c;加我进群&#xff0c;大家一起学习&#xff0c;一起进步&#xff0c;一起对抗互联网寒冬 上篇有一张图&#xff…

Redis的安装

本文采用原生的方式安装Redis&#xff0c;Redis的版本为5.0.5 安装 下载 下载网站&#xff1a;https://download.redis.io/releases/ wget http://download.redis.io/releases/redis-5.0.5.tar.gz解压 tar -zxvf redis-5.0.5.tar.gz进入redis目录 cd redis-5.0.5执行编译…

u盘一插上就提示格式化解决办法,帮助重新使用,避免数据丢失

在我们使用U盘的过程中&#xff0c;有时会遇到一插上就提示格式化的问题。这个问题可能会给我们带来很多麻烦&#xff0c;因为格式化操作会导致数据的丢失。为了解决这一问题&#xff0c;本文将介绍一些解决办法&#xff0c;帮助读者重新使用U盘&#xff0c;并避免数据丢失的风…

【开源】基于Vue和SpringBoot的校园二手交易系统

项目编号&#xff1a; S 009 &#xff0c;文末获取源码。 \color{red}{项目编号&#xff1a;S009&#xff0c;文末获取源码。} 项目编号&#xff1a;S009&#xff0c;文末获取源码。 目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块2.1 数据中心模块2.2 二手商品档案管理模…