keras优化算法_目标检测算法 - CenterNet - 代码分析

0df77a933b56185144189e715471fbbc.png
代码出处

吃水不忘打井人,分析github上的基于keras的实现:

xuannianz/keras-CenterNet​github.com
923504f8ffdaf0474179b62a2ba97384.png
代码主体结构

模型训练的主函数流程如下所示,该流程也是使用keras的较为标准的流程。其中代码篇幅较大的是数据准备的部分,通常的代码也亦如此。下面按照不同的部分分别进行说明。

b58c041b08884c27d9fac8387b9673fe.png
create_generators 数据集准备

该代码支持Pascal VOC格式、COCO格式以及CSV格式。keras中有三个函数可以用来进行模型的训练:分别是fit,fit_generator和train_on_batch。

fit(train_x, train_y, batchsize, epochs)

在使用fit进行模型训练时,通常假设整个训练集都可以放入RAM,并且没有数据增强(即不需要keras生成器)。常用于简单小型的数据集训练。

fit_generator;常常使用的模型训练函数

fit_generator适用于大数据集无法直接全部放入内存中,以及标注数据较少需要使用数据增强来增加训练模型的泛化能力。fit_generator需要传入一个数据生成器,数据生成器可以每次动态的生成一个batchsize的训练数据,通常我们也将数据增强放入数据生成器中,这样便可以动态的生成增强后的数据。在使用fit_generator时,需要传入steps_per_epoch的值,而fit函数则不需要,这是因为fit函数的steps_per_epoch默认等于总的训练数据/batchsize,而对于fit_generator来说,如果采用了数据增强,则可以产生无限的batchsize训练数据,因此需要指定该参数。

By the way,数据生成器可以使用keras的API或者直接自己手码python的代码,因为其本质上也就是python的函数。

train_on_batch(batchX, batchY)

train_on_batch用于需要对训练迭代进行精细控制,给其传入一批数据即可(数据大小任意),不需要提供batchsize的大小。通常很少使用该函数进行模型训练。

  • 本算法的实现过程就是采用的fit_generator进行的模型训练。因此需要为其构建数据生成器。common.py文件:class Generator(keras.utils.Sequence)构建数据生成器的基类,咱们先说道说道keras.utils.Sequence这个类。
keras.utils.Sequence:这个基类通常应用于数据集生成一个数据序列。使用时需构建一个python类继承自该
基类,并必须实现__len__和__getitem__两个函数,如果要在每个epoch间修改数据集则需要实现on_epoch_end
方法。
NOTE:特别注意,__getitem__要返回一个完整的batchsize数据,__len__统计的也是有多少个batch

Generator类可以当成一个抽象基类,其中主要实现的是batch的划分、数据增强的处理、以及标注数据的转换(将bounding box的标注形式转换成高斯分布的标注)。而真正使用的数据集的生成器如下所示。主要按照不同的数据集生成的类,并均都继承于Generator抽象类,这里区分不同的数据集主要为了能方便区分其不同的数据标注格式,使用起来更为方便。主要是load_annotations()和load_image()函数的实现。至此数据生成器便构建完成了。

class PascalVocGenerator(Generator)
class CocoGenerator(Generator)
centernet网络构建

算法实现采用的Resnet50作为网络的backbone,采用下述引用网络。网络构建这里相对就比较简单了,取出Resnet的C5,先添加了一层dropout,然后进行了上采样,然后分别构建网络head,主要有三支:中心点预测、中心点偏移值预测以及bouding box的size预测。

from keras.applications.resnet50 import ResNet50

最后构建model,使用keras的Lambda层构建loss,作为model的output

loss_ = Lambda(loss, name='centernet_loss')([y1, y2, y3, hm_input, wh_input, reg_input, reg_mask_input, index_input])
model = Model(inputs=[image_input, hm_input, wh_input, reg_input, reg_mask_input, index_input], outputs=[loss_])
预训练模型权重加载

keras的模型加载可以使用load_weights来实现,其模型加载可以按照模型结构加载,此时by_name需设置为False。否则将按照网络层的名字来加载,此时通常将skip_mismatch也设置成True,即仅加载名字相同的层,其他名字不同的层直接跳过。因此可以利用这个特性,对已训练好的网络局部进行修改,然后再加载之前训练好的模型,方便进行模型的调优。

model.load_weights(args.snapshot, by_name=True, skip_mismatch=True)
模型配置

其中loss参数的传递有几种形式。

  • 目标函数/损失函数的字符串,比如keras内置的一些损失函数
  • 目标函数/损失函数,通常为自定义的损失函数
  • 将目标函数/损失函数定义成model的一个层,类似本代码的实现。本代码实现时,因为直接把loss作为model的输出,因此输入y_true和y_pred,实际使用y_pred即输出loss,对其进行优化。
model.compile(optimizer=Adam(lr=1e-3), loss={'centernet_loss': lambda y_true, y_pred: y_pred})def compile(self, optimizer,loss=None,metrics=None,loss_weights=None,sample_weight_mode=None,weighted_metrics=None,target_tensors=None,**kwargs):"""Configures the model for training.# Argumentsoptimizer: String (name of optimizer) or optimizer instance.See [optimizers](/optimizers).loss: String (name of objective function) or objective function.See [losses](/losses).If the model has multiple outputs, you can use a different losson each output by passing a dictionary or a list of losses.The loss value that will be minimized by the modelwill then be the sum of all individual losses.metrics: List of metrics to be evaluated by the modelduring training and testing.Typically you will use `metrics=['accuracy']`.To specify different metrics for different outputs of amulti-output model, you could also pass a dictionary,such as `metrics={'output_a': 'accuracy'}`.loss_weights: Optional list or dictionary specifying scalarcoefficients (Python floats) to weight the loss contributionsof different model outputs.The loss value that will be minimized by the modelwill then be the *weighted sum* of all individual losses,weighted by the `loss_weights` coefficients.If a list, it is expected to have a 1:1 mappingto the model's outputs. If a dict, it is expected to mapoutput names (strings) to scalar coefficients.sample_weight_mode: If you need to do timestep-wisesample weighting (2D weights), set this to `"temporal"`.`None` defaults to sample-wise weights (1D).If the model has multiple outputs, you can use a different`sample_weight_mode` on each output by passing adictionary or a list of modes.weighted_metrics: List of metrics to be evaluated and weightedby sample_weight or class_weight during training and testing.target_tensors: By default, Keras will create placeholders for themodel's target, which will be fed with the target data duringtraining. If instead you would like to use your owntarget tensors (in turn, Keras will not expect externalNumpy data for these targets at training time), youcan specify them via the `target_tensors` argument. It can bea single tensor (for a single-output model), a list of tensors,or a dict mapping output names to target tensors.**kwargs: When using the Theano/CNTK backends, these argumentsare passed into `K.function`.When using the TensorFlow backend,these arguments are passed into `tf.Session.run`.

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/528631.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

html checkbox 禁用,js禁用checkbox

两种禁用checkbox的方法:代码示例:-//W3C//DTD HTML 4.01 Transitional//EN” “http://www.w3.org/TR/html4/loose.dtd”>function x(){// document.all.cb1.disabled true;// 方法一document.getElementsByName(“cb1”)[0].disabled true;//方法二}JavaScri…

html5 拍照 清晰度,html5强大的功能(一)

html5强大的功能(一)发布时间:2020-03-26 16:03html5得出现被传的神乎其神的,做前端的总是要跟随着潮流发展,不过在跟风之前还是想要了解一下html5真正的魅力所在。html5创建的目的是以一种标准和直观的UI标记语言来把web设计和开发变得容易起…

徽柏工业机器人_新松机器人股票(中国机器人公司排名是怎样的?)

他们有:新时达、万丰科技、沃德福、徽柏等。5、深圳市汇川技术股份有限公司 。椅子孟安波扔过去&孤覃白曼走出去$安徽埃夫特--奇瑞工业机器人沈阳新松zhidao---新松机器人其实国产机器人主体未有一家达到规模,道4、宁波均胜电子回股份答有限公司 &am…

python初学者用什么开发环境搭建_2019-04-11 python入门学习——配置机器及搭建开发环境...

#  在windows操作系统中搭建python 3.x版本的开发环境,开发工具为 Anaconda 3.#1.1  下载及安装Anaconda 3Anaconda的特点:集成性高,包含很多常用的开发软件包,省去下载和安装软件包的时间。下载地址:https://www.…

html5 视频 showtime,利用function showTime显示不出时间是为什么?

js-01.htmlvar todaynew Date();var yeartoday.getYear();var monthtoday.getMonth();var hourtoday.getHours();var minutetoday.getMinutes();function showTime(){document.getElementById("content").innerHTML"现在为您报时:";document.ge…

c语言中创建一个整数数组_VBA中动态数组的创建及利用方法

大家好,后疫情时代一定会到来,各行各业,都将是一场战胜萧条的无声的战役。无论怎样,我们一定要坚信,疫情终将会过去,曙光一定会到来。后疫情时代将会是一个全新的世界,很多理念都将被打破&#…

用计算机求函数公式,计算机常用的函数公式有哪些?

01计算机常用的函数公式包括RANK函数、COUNTIF函数、IF函数、ABS函数、AND函数、AVERAGE函数、COLUMN 函数等。RANK函数是Excel计算序数的主要工具,它的语法为:RANK(number,ref,order),其中number为参与计算的数字或含…

js判断ipad还是安卓_JS判断客户端是否是iOS或者Android

每个客户端都带有自身的UA标识,通过JavaScript,可以获取客户端标识,我们可以获取浏览器的userAgent,用正则来判断手机是ios(苹果)还是Android(安卓)客户端。代码如下:var u navigator.userAgent;var isAndroid u.ind…

.net 移除html标签,.net去除html标签代码

.net去除html标签代码public string NoHTML(string Htmlstring){//删除脚本Htmlstring Regex.Replace(Htmlstring, "", "", RegexOptions.IgnoreCase);//删除HTMLHtmlstring Regex.Replace(Htmlstring, "", "$br$", RegexOptions.I…

golang 读取文件最后一行_python3从零学习-5.4.3、文件输入流fileinput

源代码: Lib/fileinput.py此模块实现了一个辅助类和一些函数用来快速编写访问标准输入或文件列表的循环。 如果你只想要读写一个文件请参阅 open().典型用法为:import fileinputfor line in fileinput.input(): process(line)这将遍历sys中列出的所有文件的行。argv[1:]如果…

云计算机具体应用场景,云计算的定义、类型及应用场景

云计算是20年来IT行业出现的最激动人心且最具颠覆性的技术,甚至比大型主机向客户端/服务器架构的迁移还更具颠覆性。无论是IT服务的交付方式,还是企业消费这些IT服务的方式,都因云计算而改变。用户也正在快速应对新架构带来的变革&#xff0c…

dataframe for循环 筛选_Python循环12种超强写法,又快又省内存

0 前言说到处理循环,我们习惯使用for, while等,比如依次打印每个列表中的字符:在打印内容字节数较小时,全部载入内存后,再打印,没有问题。可是,如果现在有成千上百万条车辆行驶轨迹,…

html5文字飞入插件,jquery使用CSS3实现文字动画效果插件Textillate.js

jquery使用CSS3实现文字动画效果插件Textillate.jsTextillate是一款基于CSS3动画效果的 JavaScript 库,您可非常轻轻松地把这些动画效果应该于网页中的任何文字。使用方法引入核心文件构建html标签My Title写入JS,初始化$(function () {$(.tlt).textilla…

工业机器人导轨 百度文库_工业机器人或许开创一个全新的PLC时代

自机器人诞生之日起人们就不断地尝试着说明到底什么是机器人。但随着机器人技术的飞速发展和信息时代的到来,机器人所涵盖的内容越来越丰富,机器人的定义也不断充实和创新。机器人技术作为20世纪人类最伟大的发明之一,自20世纪60年代初问世以…

计算机和互联网基础知识作业,计算机作业1基础知识含答案.doc

计算机作业1基础知识含答案.doc跳到主要内容网络课程学院主页 E-Learning 实验室 联系我们 窗体顶端窗体底端页面路径 首页/ 我的课程/ 计算机应用基础2299/ 主题 2/ 第一次作业 计算机基础知识开始时间 2015 年 10 月 1 日 星期四 1301完成于 2015 年 10 月 1 日 星期四 1435耗…

天津计算机的专科学校,天津市电子计算机职业中等专业学校

天津市电子计算机职业中等专业学校创建于980年,是首批国家级重点职业中专,是国家级中等职业教育改革发展示范学校建设单位天津市职业教育先进单位。办学30多年来,学校本着“以人为本传承发展”的原则,培养面向现代化的复合型人才,取得了很好的办学效益和社会效益,学校实训设备先…

内存超频trfc_内存超频教学

一、前言先说说内存超频的作用,在很多应用里,内存超频能带来显著提升,就比如PUBG、CSGO等FPS游戏,超频后的帧数表现和超频前的帧数表现相差很多。也有很多人觉得超频很麻烦,觉得超频会损害硬件的使用寿命,其…

前端html预览,HTML5 上传前预览

下面是前端之家 f2er.com 通过网络收集整理的代码片段。前端之家小编现在分享给大家,也给大家做个参考。HTML5上传图片预览请选择图片文件:JPG/GIF$("#file0").change(function(){var objUrl getObjectURL(this.files[0]) ;console.log("…

银联分账与银联代付_第三方分账系统到底有哪些作用?

随着监管越来越严,业务越来越复杂,所有平台电商企业都需要通过第三方分账系统解决支付清算及二清等问题。作为第三方分账系统行业从业者,整理了部分关于系统的相关问题及解答,希望对大家有所帮助。问题一:第三方分账系…

计算机更改桌面,2010年职称计算机考试:更改桌面背景和颜色

Windows XP提供了各种桌面的颜色和背景方案,用户可以根据自己的爱好进行选择。颜 色充当桌面的最底层,背景覆盖于颜色之上。(l)桌面背景的更改在"显示属性"对话框中,选择"桌面"选项卡。在"桌面"选项卡上有一个"背景"列表框,选择列表框…