【TensorFlow2 之014】在 TF 2.0 中实现 LeNet-5

一、说明

         在这篇文章中,我们将展示如何在 TensorFlow 中实现像 \(LeNet-5\) 这样的基础卷积神经网络。LeNet-5 架构由 Yann LeCun 于 1998 年发明,是第一个卷积神经网络。

 数据黑客变种rs    深度学习 机器学习 TensorFlow    2020 年 2 月 29 日  |  0

1.1 教程概述:

  1. 理论重述
  2. 在 TensorFlow 中的实现

1. 理论重述

\(LeNet-5 \) 的目标是识别手写数字。因此,它作为输入 \(32\times32\times1 \) 图像。它是灰度图像,因此通道数为 \(1 \)。下面我们可以看到该网络的架构。

LeNet5架构

LeNet-5架构

        I. 最初的 MNIST 现在被认为太简单了。在过去的二十年里,许多研究人员针对原始 MNIST 提出了成功的解决方案。您可以在此处查看直接比较结果。

        由于该网络主要是为 MNIST 数据集设计的,因此它的性能明显更好。通过微小的改变,它就可以在 Fashion MNIST 数据集上达到这种准确性。然而,在这篇文章中,我们将坚持网络的原始架构。

关于 LeNet-5 架构,您可以在此处阅读详细的理论文章。

        让我们总结一下 LeNet-5 架构的各层。

图层类型特征图尺寸内核大小跨步激活
图像132×32
卷积628×285×51正值
平均池化614×142×22
卷积1610×105×51正值
平均池化165×52×22
完全连接120正值
完全连接84正值
完全连接10软最大

二、TensorFlow中的实现

        交互式 Colab 笔记本可在以下链接找到  
        在 Google Colab 中运行

为了练习,您可以尝试通过将Fashion_mnist替换为mnist、cifar10或其他来更改数据集。


        让我们从导入所有必需的库开始。导入后,我们可以使用导入的模块来加载数据。load_data  () 函数将自动下载数据并将其拆分为训练集和测试集。

import datetime
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as pltfrom tensorflow.keras import Model
from tensorflow.keras.models import Sequential
from tensorflow.keras.losses import categorical_crossentropy
from tensorflow.keras.layers import Dense, Flatten, Conv2D, AveragePooling2Dfrom tensorflow.keras import datasets
from tensorflow.keras.utils import to_categorical#from __future__ import absolute_import, division, print_function, unicode_literals
# The data, split between train and test sets:
(x_train, y_train), (x_test, y_test) = datasets.fashion_mnist.load_data()
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz
32768/29515 [=================================] - 0s 2us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-images-idx3-ubyte.gz
26427392/26421880 [==============================] - 7s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz
8192/5148 [===============================================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-images-idx3-ubyte.gz
4423680/4422102 [==============================] - 1s 0us/step

        我们可以检查新数据的形状,发现我们的图像是 28×28 像素,因此我们需要添加一个新轴,它将代表多个通道。此外,对标签进行 one-hot 编码和对输入图像进行归一化也很重要。

print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')
print(x_train[0].shape, 'image shape')

x_train shape: (60000, 28, 28)
60000 train samples
10000 test samples
(28, 28) image shape
 

# Add a new axis
x_train = x_train[:, :, :, np.newaxis]
x_test = x_test[:, :, :, np.newaxis]print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')
print(x_train[0].shape, 'image shape')
x_train shape: (60000, 28, 28, 1)
60000 train samples
10000 test samples
(28, 28, 1) image shape
# Convert class vectors to binary class matrices.num_classes = 10
y_train = to_categorical(y_train, num_classes)
y_test = to_categorical(y_test, num_classes)
# Data normalization
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255


        现在,是时候开始使用 TensorFlow 2.0 来构建我们的卷积神经网络了。最简单的方法是使用 Sequential API。我们将其包装在一个名为LeNet的类中。输入是图像,输出是类概率向量。

在 tf.keras 中,顺序模型表示层的线性堆栈,在本例中它遵循网络架构。

# LeNet-5 model
class LeNet(Sequential):def __init__(self, input_shape, nb_classes):super().__init__()self.add(Conv2D(6, kernel_size=(5, 5), strides=(1, 1), activation='tanh', input_shape=input_shape, padding="same"))self.add(AveragePooling2D(pool_size=(2, 2), strides=(2, 2), padding='valid'))self.add(Conv2D(16, kernel_size=(5, 5), strides=(1, 1), activation='tanh', padding='valid'))self.add(AveragePooling2D(pool_size=(2, 2), strides=(2, 2), padding='valid'))self.add(Flatten())self.add(Dense(120, activation='tanh'))self.add(Dense(84, activation='tanh'))self.add(Dense(nb_classes, activation='softmax'))self.compile(optimizer='adam',loss=categorical_crossentropy,metrics=['accuracy'])
model = LeNet(x_train[0].shape, num_classes)
model.summary()
Model: "le_net"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d (Conv2D)              (None, 28, 28, 6)         156       
_________________________________________________________________
average_pooling2d (AveragePo (None, 14, 14, 6)         0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 10, 10, 16)        2416      
_________________________________________________________________
average_pooling2d_1 (Average (None, 5, 5, 16)          0         
_________________________________________________________________
flatten (Flatten)            (None, 400)               0         
_________________________________________________________________
dense (Dense)                (None, 120)               48120     
_________________________________________________________________
dense_1 (Dense)              (None, 84)                10164     
_________________________________________________________________
dense_2 (Dense)              (None, 10)                850       
=================================================================
Total params: 61,706
Trainable params: 61,706
Non-trainable params: 0

        创建模型后,我们需要训练它的参数以使其强大。让我们训练给定数量的 epoch 模型。我们可以在TensorBoard中看到训练的进度。

# Place the logs in a timestamped subdirectory
# This allows to easy select different training runs
# In order not to overwrite some data, it is useful to have a name with a timestamp
log_dir="logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
# Specify the callback object
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)# tf.keras.callback.TensorBoard ensures that logs are created and stored
# We need to pass callback object to the fit method
# The way to do this is by passing the list of callback objects, which is in our case just one
model.fit(x_train, y=y_train, epochs=20, validation_data=(x_test, y_test), callbacks=[tensorboard_callback],verbose=0)%tensorboard --logdir logs/fit

        仅 20 个 epoch 就已经不错了。

        之后,我们可以做出一些预测并将其可视化。使用predict_classes()函数,我们可以从网络中获取准确的类别而不是概率值。它与使用numpy.argmax()相同。我们将用红色显示错误的预测,用蓝色表示正确的预测。

class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat','Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']prediction_values = model.predict_classes(x_test)# set up the figure
fig = plt.figure(figsize=(15, 7))
fig.subplots_adjust(left=0, right=1, bottom=0, top=1, hspace=0.05, wspace=0.05)# plot the images: each image is 28x28 pixels
for i in range(50):ax = fig.add_subplot(5, 10, i + 1, xticks=[], yticks=[])ax.imshow(x_test[i,:].reshape((28,28)),cmap=plt.cm.gray_r, interpolation='nearest')if prediction_values[i] == np.argmax(y_test[i]):# label the image with the blue textax.text(0, 7, class_names[prediction_values[i]], color='blue')else:# label the image with the red textax.text(0, 7, class_names[prediction_values[i]], color='red')
时尚迷斯特

三、总结

        所以,在这里我们学习了如何在 Tensorflow 2.0 中开发和训练 LeNet-5。在下 一篇文章中 ,我们将继续实现流行的卷积神经网络,并学习如何在 TensorFlow 2.0 中实现AlexNet 。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/107321.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

AUTOSAR组织发布20周年纪念册,东软睿驰NeuSAR列入成功案例

近日,AUTOSAR组织在成立20周年之际发布20周年官方纪念册(20th Anniversary Brochure),记录了AUTOSAR组织从成立到今天的故事、汽车行业当前和未来的发展以及AUTOSAR 伙伴关系和合作在重塑汽车方面的作用。东软睿驰提报的基于AUTOS…

行情分析——加密货币市场大盘走势(10.16)

目前大饼再次止稳,并开始向上攀升,目前MACD来看也是进入了多头趋势。重新调整了蓝色上涨趋势线,目前来看这次的低点并没有跌破上一个低点,可以认为是上涨的中继。注意白天的下跌回调。 以太目前也是走了四连阳线,而MAC…

强化学习案例复现(1)--- MountainCar基于Q-learning

1 搭建环境 1.1 gym自带 import gym# Create environment env gym.make("MountainCar-v0")eposides 10 for eq in range(eposides):obs env.reset()done Falserewards 0while not done:action env.action_space.sample()obs, reward, done, action, info env.…

关于Skywalking Agent customize-enhance-trace对应用复杂参数类型取值

对于Skywalking Agent customize-enhance-trace 大家应该不陌生了,主要支持以非入侵的方式按用户自定义的Span跟踪对应的应用方法,并获取数据。 参考https://skywalking.apache.org/docs/skywalking-java/v9.0.0/en/setup/service-agent/java-agent/cust…

STM32 ---- 再次学习STM32F103C8T6/STM32F409IGT6

目录 一、环境搭建及介绍 关于STM32基础介绍 新建工程 外设案例 LED流水灯 蜂鸣器 上拉电阻和下拉电阻知识 电压比较器 c语言基础知识 类型、结构体、枚举 类型int8_t int16_t int32_t 宏替换 #define 和typedef用法 结构体两种填充方法 和 命名规则 枚举用法 常用…

uniapp中全局页面挂载组件(H5)

前言 我们已经学习了 uniapp中全局页面挂载组件(小程序) 有些小伙伴问在H5怎么做那让我们试一试 直接上代码 //引用组件 import dialog from ./index.vue; //我这里要把小程序的方法和h5方法写一起所以用了混入 import mixins from ./mixins.js //使用…

HTTPS双向认证及密钥总结

公钥私钥: 1)公钥加密,私钥解密:加解密 为什么不能私钥加密公钥解密? 私钥加密后,公钥是公开的都能解密,没有意义。 2)私钥签名,公钥验签:属于身份验证,防串改&#x…

ElementUI编辑表格单元格与查看模式切换的应用

需求:有时候在填写表单的时候,想要在输入的时候是input输入框的状态,但是当鼠标移出输入框失去焦点时,希望是查看的状态,这种场景可以通过 v-if实现 vue2ElementUi里面使用如下: 1.el-table标签注册 cell-…

GitLab(1)——GitLab安装

目录 一、使用设备 二、使用rpm包安装 Gitlab国内清华源下载地址: ①下载命令如下: ②安装命令如下: ③删除rpm包 ④配置 ⑤重载 ⑥重启 ⑦配置自启动 ⑧打开8989端口并重启防火墙 三、GitLab登录 ①访问GitLab的URL ②输入用户…

如何实现 Es 全文检索、高亮文本略缩处理(封装工具接口极致解耦)

如何实现 Es 全文检索、高亮文本略缩处理 前言技术选型JAVA 常用语法说明全文检索开发高亮开发Es Map 转对象使用核心代码 Trans 接口(支持父类属性的复杂映射)Trans 接口可优化的点高亮全局配置类如下真实项目落地效果为什么不用 numOfFragments、fragm…

203、RabbitMQ 之 使用 direct 类型的 Exchange 实现 消息路由 (RoutingKey)

目录 ★ 使用direct实现消息路由代码演示这个情况二ConstantUtil 常量工具类ConnectionUtil 连接RabbitMQ的工具类Publisher 消息生产者测试消息生产者 Consumer01 消息消费者01测试消费者结果: Consumer02 消息消费者02测试消费者结果: 完整代码&#x…

ES6中flat(),flatMap()使用方法

实际应用: 1.代替filtermap的连用 例:现有一组数据,只展示days>30的数据,且work为1设置color:“#ffffff”,work:0设置color:“#ff0000”: const dataList [{days: 31,name: "占位文字31",work: 0 }, {days: 30,n…

android 13.0 添加系统字体并且设置为默认字体

1.概述 在13.0系统定制化开发中,在产品定制中,有产品需求对于系统字体风格不太满意,所以想要更换系统的默认字体,对于系统字体的修改也是常有的功能,而系统默认也支持增加字体,所以就来添加楷体字体为系统字体,并替换为系统默认字体, 接下来就来分析下替换默认字体的方…

微信小程序框架---详细教程

🎬 艳艳耶✌️:个人主页 🔥 个人专栏 :《Spring与Mybatis集成整合》《Vue.js使用》 ⛺️ 越努力 ,越幸运。 目录 1.框架 1.1响应的数据绑定 1.2.页面管理 1.3.基础组件 1.4.丰富的 API 2.视图层 View 2.1.介绍 …

Jmeter脚本参数化和正则匹配

我们在做接口测试过程中,往往会遇到以下几种情况 每次发送请求,都需要更改参数值为未使用的参数值,比如手机号注册、动态时间等 上一个接口的请求体参数用于下一个接口的请求体参数 上一个接口的响应体参数用于下一个接口的请求体参数&#…

【09】基础知识:React组件的生命周期

组件从创建到死亡它会经历一些特定的阶段。 React 组件中包含一系列勾子函数&#xff08;生命周期回调函数 <> 生命周期钩子函数 <> 生命周期函数 <> 生命周期钩子&#xff09;&#xff0c;会在特定的时刻调用。 我们在定义组件时&#xff0c;会在特定的生…

一文理解登录鉴权(Cookie、Session、Jwt、CAS、SSO)

1 前言 登录鉴权是任何一个网站都无法绕开的部分&#xff0c;当系统要正式上线前都会要求接入统一登陆系统&#xff0c;一方面能够让网站只允许合法的用户访问&#xff0c;另一方面&#xff0c;当用户在网站上进行操作时也需要识别操作的用户&#xff0c;用作后期的操作审计。…

社区团购:优势、模式与未来发展趋势

随着互联网的普及和电子商务的快速发展&#xff0c;社区团购作为一种新兴的商业模式&#xff0c;正逐渐在各大城市流行起来。本文将详细阐述社区团购的基本概念、优势、模式以及未来发展趋势。 一、社区团购的基本概念和现状 社区团购是一种基于社区网络的电子商务模式&#…

docker jenkins

mkdir jenkins_home chown -R 1000:1000 /root/jenkins_home/docker run -d --name myjenkins -v /root/jenkins_home:/var/jenkins_home -p 8080:8080 -p 50000:50000 --restarton-failure jenkins/jenkins:lts-jdk17参考 Official Jenkins Docker imageDocker 搭建 Jenkins …

灾害与环境遥感团队本科生在IEEE TGRS 发表高水平论文

2023年9月27日&#xff0c;地球科学和遥感领域顶级期刊《IEEE Transactions on Geoscience and Remote Sensing》&#xff08;IEEE TGRS&#xff09;在线预刊发了灾害与环境遥感团队的最新研究成果“A novel spectral index for rapid dust-proof net mapping based on Sentine…