一文简单弄懂tensorflow_在tensorflow中设置梯度衰减

c0a5c80b93744488040cbfc92da8c379.png

我是从keras入门深度学习的,第一个用的demo是keras实现的yolov3,代码很好懂(其实也不是很好懂,第一次也搞了很久才弄懂)

然后是做的车牌识别,用了tiny-yolo来检测车牌位置,当时训练有4w张图片,用了一天来训练,当时觉得时间可能就是这么长,也不懂GPU训练的时候GPU利用率,所以不怎么在意,后来随着项目图片片的增多,训练时间越来越大,受不了了,看了一片文章才注意到GPU利用率的问题.想到要用tensorflow原生的api去训练,比如用tf.data.dataset

就找到了这个tensorflow原生实现yolo的项目,在训练的时候发现他没加梯度衰减,训练了一段时间total loss下不去了,所以加了一个梯度衰减。想写一下文章,小白的第一篇文章哈哈哈,大神别喷我的内容太简单

YunYang1994/tensorflow-yolov3​github.com
dd9643eca9c399274ebf4bd1c4306859.png

他好像改了train.py

原来是这样的

import tensorflow as tf
from core import utils, yolov3
from core.dataset import dataset, Parser
sess = tf.Session()IMAGE_H, IMAGE_W = 416, 416
BATCH_SIZE       = 8
EPOCHS           = 2000*1000
LR               = 0.0001
SHUFFLE_SIZE     = 1000
CLASSES          = utils.read_coco_names('./data/voc.names')
ANCHORS          = utils.get_anchors('./data/voc_anchors.txt')
NUM_CLASSES      = len(CLASSES)train_tfrecord   = "../VOC/train/voc_train*.tfrecords"
test_tfrecord    = "../VOC/test/voc_test*.tfrecords"parser   = Parser(IMAGE_H, IMAGE_W, ANCHORS, NUM_CLASSES)
trainset = dataset(parser, train_tfrecord, BATCH_SIZE, shuffle=SHUFFLE_SIZE)
testset  = dataset(parser, test_tfrecord , BATCH_SIZE, shuffle=None)is_training = tf.placeholder(tf.bool)
example = tf.cond(is_training, lambda: trainset.get_next(), lambda: testset.get_next())images, *y_true = example
model = yolov3.yolov3(NUM_CLASSES, ANCHORS)with tf.variable_scope('yolov3'):y_pred = model.forward(images, is_training=is_training)loss = model.compute_loss(y_pred, y_true)optimizer = tf.train.AdamOptimizer(LR)
saver = tf.train.Saver(max_to_keep=2)tf.summary.scalar("loss/coord_loss",   loss[1])
tf.summary.scalar("loss/sizes_loss",   loss[2])
tf.summary.scalar("loss/confs_loss",   loss[3])
tf.summary.scalar("loss/class_loss",   loss[4])write_op = tf.summary.merge_all()
writer_train = tf.summary.FileWriter("./data/train")
writer_test  = tf.summary.FileWriter("./data/test")update_var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="yolov3/yolo-v3")
with tf.control_dependencies(update_var):train_op = optimizer.minimize(loss[0], var_list=update_var,global_step=global_step) # only update yolo layersess.run(tf.global_variables_initializer())
pretrained_weights = tf.global_variables(scope="yolov3/darknet-53")
load_op = utils.load_weights(var_list=pretrained_weights,weights_file="./darknet53.conv.74")
sess.run(load_op)for epoch in range(EPOCHS):run_items = sess.run([train_op, write_op] + loss, feed_dict={is_training:True})writer_train.add_summary(run_items[1], global_step=epoch)writer_train.flush() # Flushes the event file to diskif (epoch+1)%1000 == 0: saver.save(sess, save_path="./checkpoint/yolov3.ckpt", global_step=epoch)run_items = sess.run([write_op] + loss, feed_dict={is_training:False})writer_test.add_summary(run_items[0], global_step=epoch)writer_test.flush() # Flushes the event file to diskprint("EPOCH:%7d tloss_xy:%7.4f tloss_wh:%7.4f tloss_conf:%7.4f tloss_class:%7.4f"%(epoch, run_items[2], run_items[3], run_items[4], run_items[5]))

然后我发现没有梯度下降,所以就找了怎么实现

实现如下
optimizer = tf.train.AdamOptimizer(LR)
改为
global_step = tf.Variable(0, trainable=False)
learning_rate = tf.train.exponential_decay(LR,100,0.93,staircase=True,global_step=global_step)
optimizer = tf.train.AdamOptimizer(learning_rate)

learningrate 是梯度的类,LR是初始梯度,100是每一百次初始梯度乘以衰减度,这里是第三个参数0.93代表了衰减度,globalstep_step = global_step是一定要加的,不然梯度一直保持了初始梯度。

最后加个打印

tf.summary.scalar('learning_rate',learning_rate)

就可以爽快的去训练了

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/514268.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

打破“单点防护”缺陷,山石网科发布“云网端”XDR解决方案

编辑 | 宋 慧 供稿 | 山石网科 出品 | CSDN云计算 近年来,CISO面临的安全形势可谓“内忧外患”,对内面临多样化的网络接入途径、庞大且繁杂的IT资产;对外面临攻防关系、攻防手段、网络攻击的数量等呈指数级增长等问题,给组织的…

Serverless 场景下 Pod 创建效率优化

简介: 众所周知,Kubernetes 是云原生领域的基石,作为容器编排的基础设施,被广泛应用在 Serverless 领域。弹性能力是 Serverless 领域的核心竞争力,本次分享将重点介绍基于 Kubernetes 的 Serverless 服务中&#xff0…

安装wordcloud_COVID19数据分析实战:WordCloud 词云分析

↑↑点击上方蓝字,回复资料,N个G的惊喜前言上一篇文章(链接)我们对COVID19_line_list数据集进行了清洗以及初步分析。本文中我们将分析如何用词云来展示文本信息的概要。比如我们从词云百度百科截取文字,制作词云。简单来说,词云就…

到达率99.9%:闲鱼消息在高速上换引擎(集大成)

简介: 记录这一年闲鱼消息的优化之路 1. 背景 在2020年年初的时候接手了闲鱼的消息,当时的消息存在各种问题,网上的舆情也是接连不断:“闲鱼消息经常丢失”、“消息用户头像乱了”、“订单状态不对”(相信现在看文章的…

1小时打造HaaS版小小蛮驴智能车

1、认识一下小小蛮驴真面目 1.1、组件部分 HaaS100核心板 HaaS100是一款物联网场景中的标准硬件,并配套嵌入到硬件中的软件驱动及功能模块,为用户提供物联网设备高效开发服务。 HaaS100核心板有着丰富的外设接口,如下所示: 智…

Spring Boot Admin 集成诊断利器 Arthas 实践

简介: Arthas 是 Alibaba 开源的 Java 诊断工具,具有实时查看系统的运行状况;查看函数调用参数、返回值和异常;在线热更新代码;秒解决类冲突问题;定位类加载路径;生成热点;通过网页诊…

设计方案,拿来吧你!

作者:零一来源:前端印象前言大家好,我是零一,今天要跟大家聊聊开发流程中不起眼的环节——设计方案。你们可能没听过,也可能只是简单得走过过场,别划走,这非常重要!在字节&#xff0…

借力阿里云存储产品 延锋彼欧加速数字化重塑

简介: 延锋彼欧作为汽车外饰件生产的领航企业,通过基于业务和数据驱动的数字化管理,释放工业设备数据潜能提升产能。依托阿里云“稳定、安全、可靠、易用”的存储服务,延锋彼欧的发展步伐将更为稳健。 “一日骋千里,无…

英雄帖!移动云首批最有价值专家(MVP)招募开始了!

这是开发者的时代,这是价值重塑的时代。站在科技的潮头,我们期待去引领、去挖掘、去创造……移动云已迎来飞速发展的黄金期,移动云开发者社区将成为业界优秀开发者的聚集地。今天,移动云开发者社区正式开启移动云MVP首批招募&…

无责任畅想:云原生中间件的下一站

简介: 本文源自 2020 年 12 月 20 日作者在云原生社区 meetup 第二期北京站演讲 《Apache Dubbo-go 在云原生时代的实践与探索》的部分内容 自从以 2013 年开源的 docker 为代表的的容器技术和以 2014 年开源的 K8s 为代表的容器编排技术登上舞台之后,相…

深度剖析:Redis 分布式锁到底安全吗?看完这篇文章彻底懂了!

作者 | Kaito 来源 | 水滴与银弹阅读本文大约需要 20 分钟。大家好,我是 Kaito。这篇文章我想和你聊一聊,关于 Redis 分布式锁的「安全性」问题。Redis 分布式锁的话题,很多文章已经写烂了,我为什么还要写这篇文章呢?因…

Spring Boot 微服务性能下降九成!使用 Arthas 定位根因

简介: 接收到公司业务部门的开发反馈,应用在升级公司内部框架后,UAT(预生产)环境接口性能压测不达标。 背景 接收到公司业务部门的开发反馈,应用在升级公司内部框架后,UAT(预生产&a…

阿里研究员:线下环境为何不稳定?怎么破

简介: 为什么线下环境的不稳定是必然的?我们怎么办?怎么让它尽量稳定一点? 这篇文章想讲两件事: 为什么线下环境[1]的不稳定是必然的?我们怎么办?怎么让它尽量稳定一点? 此外&#…

谁说技术男不浪漫!90后程序员2天做出猫咪情绪识别软件

整理 | 王晓曼出品 | CSDN(ID:CSDNnews)9月1日,一则关于#程序员2天做出猫咪情绪识别软件#的话题登上微博热搜,参与阅读的人数达到了8218.1万,讨论次数1.3万,引发网友们的热议。高手在民间&#…

闲鱼如何一招保证推荐流稳如泰山

简介: 风雨不动安如山 背景 近几年互联网的快速发展中,互联网业务发展越来越复杂,业务也被拆分得越来越细,阿里内部业务也发生着翻天覆地的变化,从最初的单体应用,到后面的分布式集群,再到最近…

电商直播平台如何借助容器与中间件实现研发效率提升100%?

简介: 经过实际场景验证及用户的综合评估,电商直播平台借助全面的云原生容器化能力和中间件产品能力,大幅提升开发部署运维效率达50%~100%,极大地提升了用户体验,为业务持续发展打下了坚实的基础。 前言 直播带货是近…

在游戏运营行业,Serverless 如何解决数据采集分析痛点?

简介: 众所周知,游戏行业在当今的互联网行业中算是一棵常青树。在疫情之前的 2019 年,中国游戏市场营收规模约 2884.8 亿元,同比增长 17.1%。2020 年因为疫情,游戏行业更是突飞猛进。玩游戏本就是中国网民最普遍的娱乐…

字节大战腾讯元宇宙;Docker 自己定制镜像;VMware 云桌面助力秦皇岛市第一医院;微软开源 Cloud Katana;...

NEWS本周新闻回顾字节大战腾讯元宇宙:布局社交产品Pixsoul,上线游戏“重启世界”字节投资的代码乾坤,已于近日正式上线了元宇宙游戏《重启世界》。就在两个月前,被称为“元宇宙第一股”的Roblox登陆国内,由腾讯改名为《…

从 RxJS 到 Flink:如何处理数据流?

简介: 前端开发的本质是什么?响应式编程相对于 MVVM 或者 Redux 有什么优点?响应式编程的思想是否可以应用到后端开发中?本文以一个新闻网站为例,阐述在前端开发中如何使用响应式编程思想;再以计算电商平台…

Spring RSocket:基于服务注册发现的 RSocket 负载均衡

简介: RSocket 作为通讯协议的后起之秀,核心是二进制异步化消息通讯,是否也能和 Spring Cloud 技术栈结合,实现服务注册发现、客户端负载均衡,从而更高效地实现面向服务的架构?这篇文章我们就讨论一下 Spri…