keras卷积处理rgb输入_CNN卷积神经网络模型搭建

前言

前段时间尝试使用深度学习来识别评测过程中的图片,以减少人力成本。目前是在深度学习框架Keras(后端使用TensorFlow)下搭建了一个CNN卷积神经网络模型,下面就如何搭建一个最简单的数字图像识别模型做下介绍。

模型的建立

(1) 卷积层(convolution layer):至于什么是卷积大家可以自己去找资料看看,这里重点讲讲Convolution2D()函数。根据keras官方文档描述,2D代表这是一个2维卷积,其功能为对2维输入进行滑窗卷积计算。我们的数字图像尺寸为28*28,拥有长、宽两维,所以在这里我们使用2维卷积函数计算卷积。所谓的滑窗计算,其实就是利用卷积核逐个像素、顺序进行计算,如下图:3d50d207708863c39573d6b2fe4ea725.png

上图选择了最简单的均值卷积核,3x3大小,我们用这个卷积核作为掩模对前面4x4大小的图像逐个像素作卷积运算。首先我们将卷积核中心对准图像第一个像素,在这里就是像素值为237的那个像素。卷积核覆盖的区域(掩模之称即由此来),其下所有像素取均值然后相加:

C(1) = 0 * 0.5 + 0 * 0.5 + 0 * 0.5 + 0 * 0.5 + 237 * 0.5 + 203 * 0.5 + 0 * 0.5 + 123 * 0.5 + 112 * 0.5

结果直接替换卷积核中心覆盖的像素值,接着是第二个像素、然后第三个,从左至右,由上到下……以此类推,卷积核逐个覆盖所有像素。整个操作过程就像一个滑动的窗口逐个滑过所有像素,最终生成一副尺寸相同但已经过卷积处理的图像。上图我们采用的是均值卷积核,实际效果就是将图像变模糊了。显然,卷积核覆盖图像边界像素时,会有部分区域越界,越界的部分我们以0填充,如上图。对于此种情况,还有一种处理方法,就是丢掉边界像素,从覆盖区域不越界的像素开始计算。像上图,如果采用丢掉边界像素的方法,3x3的卷积核就应该从第2行第2列的像素(值为112)开始,到第3行第3列结束,最终我们会得到一个2x2的图像。这种处理方式会丢掉图像的边界特征;而第一种方式则保留了图像的边界特征。在我们建立的模型中,卷积层采用哪种方式处理图像边界,卷积核尺寸有多大等参数都可以通过Convolution2D()函数来指定:

#第一个卷积层,4个卷积核,每个卷积核大小5*5。1表示输入的图片的通道,灰度图为1通道。
model.add(Conv2D(4, (5, 5), border_mode='valid',input_shape=(1, 28, 28), data_format='channels_first'))

第一个卷积层包含4个卷积核,每个卷积核大小为5x5,border_mode值为“same”意味着我们采用保留边界特征的方式滑窗,而值“valid”则指定丢掉边界像素(数字图像边缘是没用的,所以用vaild)。根据keras开发文档的说明,当我们将卷积层作为网络的第一层时,我们还应指定input_shape参数,显式地告知输入数据的形状,对我们的程序来说,input_shape的值为(1, 28, 28),代表28x28的灰度图。

PS:“channels_first”或“channels_last”之一,代表图像的通道维的位置。该参数是Keras 1.x中的image_dim_ordering,“channels_last”对应原本的“tf”,“channels_first”对应原本的“th”。以128x128的RGB图像为例,“channels_first”应将数据组织为(3,128,128),而“channels_last”应将数据组织为(128,128,3)。该参数的默认值是~/.keras/keras.json中设置的值,若从未设置过,则为“channels_last”。

(2) 激活函数层:这里讲一下最简单的relu(Rectified Linear Units,修正线性单元)函数,它的数学形式如下:

ƒ(x) = max(0, x)

这个函数非常简单,其输出一目了然,小于0的输入,输出全部为0,大于0的则输入与输出相等。该函数的优点是收敛速度快,除了它,keras库还支持其它几种激活函数,如下:

  • softplus

  • softsign

  • tanh

  • sigmoid

  • hard_sigmoid

  • linear

它们的函数式、优缺点网络上有很多资料,大家自己去查。对于不同的需求,我们可以选择不同的激活函数,这也是模型训练可调整的一部分,运用之妙,存乎一心,请自忖之。另外再交代一句,其实激活函数层按照我们前文所讲,其属于人工神经元的一部分,所以我们亦可以在构造层对象时通过传递activation参数设置,如下:

model.add(Convolution2D(4, (5, 5), border_mode='valid',input_shape=(1, 28, 28), data_format='channels_first'))
model.add(Activation('tanh'))

#通过传递activation参数设置,与上两行代码的作用相同
model.add(Convolution2D(4, (5, 5), border_mode='valid',input_shape=(1, 28, 28), data_format='channels_first'), activation='tanh')

(3) 池化层(pooling layer):池化层存在的目的是缩小输入的特征图,简化网络计算复杂度;同时进行特征压缩,突出主要特征。我们通过调用MaxPooling2D()函数建立了池化层,这个函数采用了最大值池化法,这个方法选取覆盖区域的最大值作为区域主要特征组成新的缩小后的特征图:830a08912887dfdc1e8c2e160de188f6.png

显然,池化层与卷积层覆盖区域的方法不同,前者按照池化尺寸逐块覆盖特征图,卷积层则是逐个像素滑动覆盖。对于我们输入的28x28特征图来说,经过2x2池化后,图像变为14x14大小。

model.add(MaxPooling2D(pool_size=(2, 2)))

(4)Dropout层:随机断开一定百分比的输入神经元链接,以防止过拟合。那么什么是过拟合呢?一句话解释就是训练数据预测准确率很高,测试数据预测准确率很低,用图形表示就是拟合曲线较尖,不平滑。导致这种现象的原因是模型的参数很多,但训练样本太少,导致模型拟合过度。为了解决这个问题,Dropout层将有意识的随机减少模型参数,让模型变得简单,而越简单的模型越不容易产生过拟合。代码中Dropout()函数只有一个输入参数——指定抛弃比率,范围为0~1之间的浮点数,其实就是百分比。这个参数亦是一个可调参数,我们可以根据训练结果调整它以达到更好的模型成熟度。

#本样例没有使用到,详见官方文档
keras.layers.core.Dropout(rate, noise_shape=None, seed=None)

(5)Flatten层:截止到Flatten层之前,在网络中流动的数据还是多维的(对于我们的程序就是2维的),经过多次的卷积、池化、Dropout之后,到了这里就可以进入全连接层做最后的处理了。全连接层要求输入的数据必须是一维的,因此,我们必须把输入数据“压扁”成一维后才能进入全连接层,Flatten层的作用即在于此。该层的作用如此纯粹,因此反映到代码上我们看到它不需要任何输入参数。

(6)全连接层(dense layer):全连接层的作用就是用于分类或回归,对于我们来说就是分类。keras将全连接层定义为Dense层,其含义就是这里的神经元连接非常“稠密”。我们通过Dense()函数定义全连接层。这个函数的一个必填参数就是神经元个数,其实就是指定该层有多少个输出。在我们的代码中,第一个全连接层(#14 Dense层)指定了512个神经元,也就是保留了512个特征输出到下一层。这个参数可以根据实际训练情况进行调整,依然是没有可参考的调整标准,自调之。

#全连接层,先将前一层输出的二维特征图flatten为一维的。
#Dense就是隐藏层。16就是上一层输出的特征图个数。4是根据每个卷积层计算出来的:(28-5+1)得到24,(24-3+1)/2得到11,(11-3+1)/2得到4
#全连接有128个神经元节点,初始化方式为normal
model.add(Flatten())
model.add(Dense(128, init='normal'))
model.add(Activation('tanh'))

(7)分类层:全连接层最终的目的就是完成我们的分类要求:0到9,模型构建代码的最后两行完成此项工作:

#Softmax分类,输出是10类别
model.add(Dense(10, init='normal'))
model.add(Activation('softmax'))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/502946.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

python中try命令_Python 异常处理 Python 基础教程 try..except

异常处理在之前的学习中我们一直没有接触过。 哦对,我们甚至还不知道怎么向程序输入一段字符串。那么我们在这里提供一个小例子。 在命令行中,我们输入 s raw_input(Enter something --> )好了,我们已经知道如何输入一个字符串了&#xf…

python分词_Python 结巴分词实现关键词抽取分析

1 简介 关键词抽取就是从文本里面把跟这篇文档意义最相关的一些词抽取出来。这个可以追溯到文献检索初期,当时还不支持全文搜索的时候,关键词就可以作为搜索这篇论文的词语。因此,目前依然可以在论文中看到关键词这一项。 除了这些&#xff0…

redis 如何 mysql_Redis 如何保持和 MySQL 数据一致

一、需求起因在高并发的业务场景下,数据库大多数情况都是用户并发访问最薄弱的环节。所以,就需要使用redis做一个缓冲操作,让请求先访问到redis,而不是直接访问MySQL等数据库。这个业务场景,主要是解决读数据从Redis缓…

iframe 跨域_【梯云纵】搞定前端跨域

韦陀掌法,难陀时间善恶;梯云纵,难纵过乱世纷扰。现在开始写代码o(╯□╰)o什么是跨域1.跨域的定义广义的跨域是指一个域下对的文档或者脚本试图去请求另外一个域下的资源。a链接、重定向、表单提交、、、等标签background:url()、font-face()ajax 跨域请求……狭义的…

java中exception_Java中的异常 Exceptions

1. 概念exception是“exceptional event”的缩写,是指执行程序中发生的事件,破坏了程序的正常执行流程。Java 异常处理机制使程序更加健壮易于调试,它可以告诉程序员三个问题:错误的类型、位置、原因,帮助程序员解决错…

python异步asy_Python 异步编程之asyncio【转载】

一、协程的认识 协程(Coroutine),也可以被称为微线程,是一种用户态内的上下文切换技术。 简而言之,其实就是通过一个线程实现代码块相互切换执行。例如:deffunc1():print(1) ...print(2)deffunc2():print(3…

bitcount java_Java源码解释之Integer.bitCount

Java中的Integer.bitCount(i)的返回值是i的二进制表示中1的个数。源码如下:public static int bitCount(int i) {// HD, Figure 5-2i i - ((i >>> 1) & 0x55555555);i (i & 0x33333333) ((i >>> 2) & 0x33333333);i (i (i >&…

git merge 冲突_卧槽!小姐姐用动画图解 Git 命令,这也太秀了吧?!

公众号关注 “GitHubDaily”设为 “星标”,每天带你逛 GitHub!大家好,我是小 G。在座的各位应该都知道,Git 作为居家必备、团队协作之利器,自从 Linus Torvalds 发布这款工具后,便一直受到各路开发者的喜爱…

android 删除文件 代码_代码审计之某系统后台存在任意删除文件

本文作者:霾团队交流群:673441920-----------------------------------------------------------前言POC镇楼!!!POST 漏洞演示过程:首先我们利用D盾监听下我们的项目以外的目录。这里刚刚我们创建了这个文件…

ubuntu java8 java9_在Ubuntu/Debian系统上安装Java 9的方法

本文介绍在Ubuntu/Debian系统上安装Oracle Java 9的方法:使用webupd8team/java PPA,相同的PPA提供了Java 8和Java 7等旧版Java的软件包,如果你的应用程序需要这个,可以随意安装它们。要安装新版本可参考在Ubuntu 18.04系统上安装J…

websocket 压力测试_打造最强移动测试平台

笔者今年换掉了服役N年的旧手机,新手机12G的RAM,比自用的本子内存都大,如果只是玩游戏感觉不能完全发挥出全部机能,但又因为怕影响日常使用没有进行root,经过一番折腾,发现即使不root也不影响把它变成一款测…

python银行系统模拟演练_python多线程实现代码(模拟银行服务操作流程)

1.模拟银行服务完成程序代码目前,在以银行营业大厅为代表的窗口行业中大量使用排队(叫号)系统,该系统完全模拟了人群排队全过程,通过取票进队、排队等待、叫号服务等功能,代替了人们站队的辛苦。排队叫号软件的具体操作流程为&…

字符串左侧补0_(48)C++面试之最长不含重复字符的子字符串(动态规划)

// 面试题48&#xff1a;最长不含重复字符的子字符串// 题目&#xff1a;请从字符串中找出一个最长的不包含重复字符的子字符串&#xff0c;计算该最长子// 字符串的长度。假设字符串中只包含从a到z的字符。#include <vector> #include <string> #include <iost…

java udp 同一个端口实现收发_Java网络编程之UDP协议

伙伴们注意了&#xff01;小编在这里给大家送上关注福利&#xff1a;搜索微信公众号“速学Java”关注即可领取小编精心准备的资料一份&#xff01;今天我们来聊聊网络编程这部分的内容网络编程1)计算机网络是指将地理位置不同的具有独立功能的多台计算机及其外部设备&#xff0…

java 多态_Java面向对象 —— 多态

前两天已经相继介绍了Java面向对象的三大特性之中的封装、继承&#xff0c;所以今天就介绍Java面向对象的三大特性的最后一项&#xff0c;多态~首先讲一下什么是多态&#xff0c;以及多态需要注意的细节 什么是多态&#xff1a;一个对象具备多种形态&#xff0c;也可以理解为事…

vb6 方法‘ ’作用于对象 失败_JS基础入门-对象的使用

今日背诵小纸条对象是一组属性方法的组合&#xff0c;其中可包含基本值、对象和函数对象的定义1 对象字面量var hero{name: ‘产品小姐姐’&#xff0c;age: 16&#xff0c;weapon: [ ‘头盔’, ‘靴子’, ‘盔甲 ]&#xff0c;sayHi: function ( ) {console.log( this.name ’…

无法从套接字读取更多的数据 oracle_小伙面试时被追问数据库优化,面试前如何埋点反杀?

前言周五的早高峰, 各地软件园地铁站里中出现了不少穿着长袖加绒格子衫, 背双肩电脑包的年轻码农, 现在节气正值 [ 小雪 ] , 11月的全国性突然降温 , 让经历过996摧残的猿们一出地铁站就冻的打了个激灵 , 很庆幸的告诉大家距离放年假还剩不到 37 个工作日, 要买火车票的赶紧预约…

qlineedit限制输入数字_Excel单元格限制录入,实用小技巧

在填写资料表格的时候&#xff0c;为了不防止出错&#xff0c;会在单元格中设置一些技巧&#xff0c;限制对方输入内容&#xff0c;这样可以更好的预防输入错误。那么单元格限制输入技巧是如何实现的呢&#xff1f;1、限制只能录入数字比如单元格是我们要用来填写年龄数据等数字…

java二维数组 内存分配_java中二维数组内存分配

区分三种初始化方式&#xff1a;格式一&#xff1a;数据类型[][] 数组名 new 数据类型[m][n];m:表示这个二维数组有多少个一维数组。n:表示每一个一维数组的元素有多少个。//例&#xff1a;int arr[][]new int[3][2];如下图格式二&#xff1a;数据类型[][] 数组名 new 数据类…

word公式插件_如何快速输入复杂的数学公式?这里有 3 个实用技巧

不管你是不是科研狗&#xff0c;都可能遇到过在文章中插入公式。而我们最常用的就是使用 Word 自带的公式编辑器输入&#xff0c;Word 公式可以很好地匹配文章的格式&#xff0c;自然地插入文中。有时候处理一个公式简单&#xff0c;但如果你要输入大量公式&#xff0c;键盘、鼠…