keras卷积处理rgb输入_CNN卷积神经网络模型搭建

前言

前段时间尝试使用深度学习来识别评测过程中的图片,以减少人力成本。目前是在深度学习框架Keras(后端使用TensorFlow)下搭建了一个CNN卷积神经网络模型,下面就如何搭建一个最简单的数字图像识别模型做下介绍。

模型的建立

(1) 卷积层(convolution layer):至于什么是卷积大家可以自己去找资料看看,这里重点讲讲Convolution2D()函数。根据keras官方文档描述,2D代表这是一个2维卷积,其功能为对2维输入进行滑窗卷积计算。我们的数字图像尺寸为28*28,拥有长、宽两维,所以在这里我们使用2维卷积函数计算卷积。所谓的滑窗计算,其实就是利用卷积核逐个像素、顺序进行计算,如下图:3d50d207708863c39573d6b2fe4ea725.png

上图选择了最简单的均值卷积核,3x3大小,我们用这个卷积核作为掩模对前面4x4大小的图像逐个像素作卷积运算。首先我们将卷积核中心对准图像第一个像素,在这里就是像素值为237的那个像素。卷积核覆盖的区域(掩模之称即由此来),其下所有像素取均值然后相加:

C(1) = 0 * 0.5 + 0 * 0.5 + 0 * 0.5 + 0 * 0.5 + 237 * 0.5 + 203 * 0.5 + 0 * 0.5 + 123 * 0.5 + 112 * 0.5

结果直接替换卷积核中心覆盖的像素值,接着是第二个像素、然后第三个,从左至右,由上到下……以此类推,卷积核逐个覆盖所有像素。整个操作过程就像一个滑动的窗口逐个滑过所有像素,最终生成一副尺寸相同但已经过卷积处理的图像。上图我们采用的是均值卷积核,实际效果就是将图像变模糊了。显然,卷积核覆盖图像边界像素时,会有部分区域越界,越界的部分我们以0填充,如上图。对于此种情况,还有一种处理方法,就是丢掉边界像素,从覆盖区域不越界的像素开始计算。像上图,如果采用丢掉边界像素的方法,3x3的卷积核就应该从第2行第2列的像素(值为112)开始,到第3行第3列结束,最终我们会得到一个2x2的图像。这种处理方式会丢掉图像的边界特征;而第一种方式则保留了图像的边界特征。在我们建立的模型中,卷积层采用哪种方式处理图像边界,卷积核尺寸有多大等参数都可以通过Convolution2D()函数来指定:

#第一个卷积层,4个卷积核,每个卷积核大小5*5。1表示输入的图片的通道,灰度图为1通道。
model.add(Conv2D(4, (5, 5), border_mode='valid',input_shape=(1, 28, 28), data_format='channels_first'))

第一个卷积层包含4个卷积核,每个卷积核大小为5x5,border_mode值为“same”意味着我们采用保留边界特征的方式滑窗,而值“valid”则指定丢掉边界像素(数字图像边缘是没用的,所以用vaild)。根据keras开发文档的说明,当我们将卷积层作为网络的第一层时,我们还应指定input_shape参数,显式地告知输入数据的形状,对我们的程序来说,input_shape的值为(1, 28, 28),代表28x28的灰度图。

PS:“channels_first”或“channels_last”之一,代表图像的通道维的位置。该参数是Keras 1.x中的image_dim_ordering,“channels_last”对应原本的“tf”,“channels_first”对应原本的“th”。以128x128的RGB图像为例,“channels_first”应将数据组织为(3,128,128),而“channels_last”应将数据组织为(128,128,3)。该参数的默认值是~/.keras/keras.json中设置的值,若从未设置过,则为“channels_last”。

(2) 激活函数层:这里讲一下最简单的relu(Rectified Linear Units,修正线性单元)函数,它的数学形式如下:

ƒ(x) = max(0, x)

这个函数非常简单,其输出一目了然,小于0的输入,输出全部为0,大于0的则输入与输出相等。该函数的优点是收敛速度快,除了它,keras库还支持其它几种激活函数,如下:

  • softplus

  • softsign

  • tanh

  • sigmoid

  • hard_sigmoid

  • linear

它们的函数式、优缺点网络上有很多资料,大家自己去查。对于不同的需求,我们可以选择不同的激活函数,这也是模型训练可调整的一部分,运用之妙,存乎一心,请自忖之。另外再交代一句,其实激活函数层按照我们前文所讲,其属于人工神经元的一部分,所以我们亦可以在构造层对象时通过传递activation参数设置,如下:

model.add(Convolution2D(4, (5, 5), border_mode='valid',input_shape=(1, 28, 28), data_format='channels_first'))
model.add(Activation('tanh'))

#通过传递activation参数设置,与上两行代码的作用相同
model.add(Convolution2D(4, (5, 5), border_mode='valid',input_shape=(1, 28, 28), data_format='channels_first'), activation='tanh')

(3) 池化层(pooling layer):池化层存在的目的是缩小输入的特征图,简化网络计算复杂度;同时进行特征压缩,突出主要特征。我们通过调用MaxPooling2D()函数建立了池化层,这个函数采用了最大值池化法,这个方法选取覆盖区域的最大值作为区域主要特征组成新的缩小后的特征图:830a08912887dfdc1e8c2e160de188f6.png

显然,池化层与卷积层覆盖区域的方法不同,前者按照池化尺寸逐块覆盖特征图,卷积层则是逐个像素滑动覆盖。对于我们输入的28x28特征图来说,经过2x2池化后,图像变为14x14大小。

model.add(MaxPooling2D(pool_size=(2, 2)))

(4)Dropout层:随机断开一定百分比的输入神经元链接,以防止过拟合。那么什么是过拟合呢?一句话解释就是训练数据预测准确率很高,测试数据预测准确率很低,用图形表示就是拟合曲线较尖,不平滑。导致这种现象的原因是模型的参数很多,但训练样本太少,导致模型拟合过度。为了解决这个问题,Dropout层将有意识的随机减少模型参数,让模型变得简单,而越简单的模型越不容易产生过拟合。代码中Dropout()函数只有一个输入参数——指定抛弃比率,范围为0~1之间的浮点数,其实就是百分比。这个参数亦是一个可调参数,我们可以根据训练结果调整它以达到更好的模型成熟度。

#本样例没有使用到,详见官方文档
keras.layers.core.Dropout(rate, noise_shape=None, seed=None)

(5)Flatten层:截止到Flatten层之前,在网络中流动的数据还是多维的(对于我们的程序就是2维的),经过多次的卷积、池化、Dropout之后,到了这里就可以进入全连接层做最后的处理了。全连接层要求输入的数据必须是一维的,因此,我们必须把输入数据“压扁”成一维后才能进入全连接层,Flatten层的作用即在于此。该层的作用如此纯粹,因此反映到代码上我们看到它不需要任何输入参数。

(6)全连接层(dense layer):全连接层的作用就是用于分类或回归,对于我们来说就是分类。keras将全连接层定义为Dense层,其含义就是这里的神经元连接非常“稠密”。我们通过Dense()函数定义全连接层。这个函数的一个必填参数就是神经元个数,其实就是指定该层有多少个输出。在我们的代码中,第一个全连接层(#14 Dense层)指定了512个神经元,也就是保留了512个特征输出到下一层。这个参数可以根据实际训练情况进行调整,依然是没有可参考的调整标准,自调之。

#全连接层,先将前一层输出的二维特征图flatten为一维的。
#Dense就是隐藏层。16就是上一层输出的特征图个数。4是根据每个卷积层计算出来的:(28-5+1)得到24,(24-3+1)/2得到11,(11-3+1)/2得到4
#全连接有128个神经元节点,初始化方式为normal
model.add(Flatten())
model.add(Dense(128, init='normal'))
model.add(Activation('tanh'))

(7)分类层:全连接层最终的目的就是完成我们的分类要求:0到9,模型构建代码的最后两行完成此项工作:

#Softmax分类,输出是10类别
model.add(Dense(10, init='normal'))
model.add(Activation('softmax'))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/502946.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

python装饰器应用论文_Python装饰器的应用场景代码总结

装饰器的应用场景 附加功能 数据的清理或添加: 函数参数类型验证 require_ints 类似请求前拦截 数据格式转换 将函数返回字典改为 json/YAML 类似响应后篡改 为函数提供额外的数据 mock.patch 函数注册 在任务中心注册一个任务 注册一个带信号处理器的函数不同应用场景下装饰器…

python中try命令_Python 异常处理 Python 基础教程 try..except

异常处理在之前的学习中我们一直没有接触过。 哦对,我们甚至还不知道怎么向程序输入一段字符串。那么我们在这里提供一个小例子。 在命令行中,我们输入 s raw_input(Enter something --> )好了,我们已经知道如何输入一个字符串了&#xf…

python读取大文件性能_强悍的Python读取大文件的解决方案

Python 环境下文件的读取问题,请参见拙文 Python基础之文件读取的讲解这是一道著名的 Python 面试题,考察的问题是,Python 读取大文件和一般规模的文件时的区别,也即哪些接口不适合读取大文件。1. read() 接口的问题f open(filen…

python mysql 保存csv_使用Python将csv文件快速转存到Mysql

因为一些工作需要,我们经常会做一些数据持久化的事情,例如将临时数据存到文件里,又或者是存到数据库里。对于一个规范的表文件(例如csv),我们如何才能快速将数据存到Mysql里面呢?这个时候,我们可以使用pyth…

python分词_Python 结巴分词实现关键词抽取分析

1 简介 关键词抽取就是从文本里面把跟这篇文档意义最相关的一些词抽取出来。这个可以追溯到文献检索初期,当时还不支持全文搜索的时候,关键词就可以作为搜索这篇论文的词语。因此,目前依然可以在论文中看到关键词这一项。 除了这些&#xff0…

redis 如何 mysql_Redis 如何保持和 MySQL 数据一致

一、需求起因在高并发的业务场景下,数据库大多数情况都是用户并发访问最薄弱的环节。所以,就需要使用redis做一个缓冲操作,让请求先访问到redis,而不是直接访问MySQL等数据库。这个业务场景,主要是解决读数据从Redis缓…

truncate python是删除文件内容吗_在Python中操作文件之truncate()方法的使用教程

truncate()方法截断该文件的大小。如果可选的尺寸参数存在,该文件被截断(最多)的大小。 大小默认为当前位置。当前文件位置不改变。注意,如果一个指定的大小超过了文件的当前大小,其结果是依赖于平台。 注意:此方法不会在当文件工…

sqlserver mysql时间格式化_SqlServer时间格式化

最近用的SqlServer比较多, 时间 格式化 老是忘记,现整理如下:(来源于网上,具体来源地址忘记了,归根到底MSDN吧) SELECT CONVERT(varchar(50), GETDATE(), 0): 05 16 2006 10:57AM SELECT CONVERT(varchar(50), GETDATE…

iframe 跨域_【梯云纵】搞定前端跨域

韦陀掌法,难陀时间善恶;梯云纵,难纵过乱世纷扰。现在开始写代码o(╯□╰)o什么是跨域1.跨域的定义广义的跨域是指一个域下对的文档或者脚本试图去请求另外一个域下的资源。a链接、重定向、表单提交、、、等标签background:url()、font-face()ajax 跨域请求……狭义的…

java中exception_Java中的异常 Exceptions

1. 概念exception是“exceptional event”的缩写,是指执行程序中发生的事件,破坏了程序的正常执行流程。Java 异常处理机制使程序更加健壮易于调试,它可以告诉程序员三个问题:错误的类型、位置、原因,帮助程序员解决错…

python异步asy_Python 异步编程之asyncio【转载】

一、协程的认识 协程(Coroutine),也可以被称为微线程,是一种用户态内的上下文切换技术。 简而言之,其实就是通过一个线程实现代码块相互切换执行。例如:deffunc1():print(1) ...print(2)deffunc2():print(3…

bitcount java_Java源码解释之Integer.bitCount

Java中的Integer.bitCount(i)的返回值是i的二进制表示中1的个数。源码如下:public static int bitCount(int i) {// HD, Figure 5-2i i - ((i >>> 1) & 0x55555555);i (i & 0x33333333) ((i >>> 2) & 0x33333333);i (i (i >&…

python自定义全局异常_如何在python中进行全局异常捕获

使用sys.excepthook函数进行全局异常的获取。 首先定义异常处理函数, 并使用该函数接收系统异常信息。 import wx import sys class TestFrame(wx.Frame): def __init__(self): wx.Frame.__init__(self, None, -1, test) btn wx.Button(self, -1, test) btn.Bind(w…

git merge 冲突_卧槽!小姐姐用动画图解 Git 命令,这也太秀了吧?!

公众号关注 “GitHubDaily”设为 “星标”,每天带你逛 GitHub!大家好,我是小 G。在座的各位应该都知道,Git 作为居家必备、团队协作之利器,自从 Linus Torvalds 发布这款工具后,便一直受到各路开发者的喜爱…

freebsd java 能用吗_在FreeBSD 4.9下安装JAVA环境

导读:资源下载地址:1.http://www.sun.com/softwarre/java2/download.html2.http://ftp.csie.chu.edu.tw/FreeBSD/distfiles/openmotif/3.http://ameba.sc-uni.ktu.lt/pub/FreeBSD/4.http://www.wormwang.net/mirrors/java/一、以下的包要先下载放到各自的…

python中exec是什么意思_Python中的进程分支fork和exec详解

在python中,任务并发一种方式是通过进程分支来实现的.在linux系统在,通过fork()方法来实现进程分支. 1.fork()调用后会创建一个新的子进程,这个子进程是原父进程的副本.子进程可以独立父进程外运行. 2.fork()是一个很特殊的方法,一次调用,两次返回. 3.fork()它会返回2个值,一个…

java冒泡排序原理_冒泡排序原理及其java实现

冒泡排序原理:临近的数字两两进行比较,按照从小到大或者从大到小的顺序进行交换,这样外层循环每循环一次,都会把一个数的顺序排好(从小到大的话每次都会把上回剩余的数据最大的放在剩余数的最后面,反之则是最小的放剩余…

java holder_java.sql.SQLException: connection holder is null

错误信息2017-11-15 14:53:16.931 [ ] ERROR com.hzcf.flagship.web.AssetPlanController 126 :### Error updating database. Cause: java.sql.SQLException: connection holder is null### Cause: java.sql.SQLException: connection holder is null; uncategorized SQLExcep…

java signed_如何从java中的字节读取signed int?

我有一个规范读取接下来的两个字节是signed int.要在java中读取我有以下内容当我使用以下代码在java中读取signed int时,我得到值65449计算无符号的逻辑int a (byte[1] & 0xff) <<8int b (byte[0] & 0xff) <<0int c ab我认为这是错误的,因为如果我和0xff我…

android 删除文件 代码_代码审计之某系统后台存在任意删除文件

本文作者&#xff1a;霾团队交流群&#xff1a;673441920-----------------------------------------------------------前言POC镇楼&#xff01;&#xff01;&#xff01;POST 漏洞演示过程&#xff1a;首先我们利用D盾监听下我们的项目以外的目录。这里刚刚我们创建了这个文件…