loss低但精确度低_低光照图像增强网络-RetinexNet(model.py解析【2】)

51b2c2d4-d92e-eb11-8da9-e4434bdf6706.png

53b2c2d4-d92e-eb11-8da9-e4434bdf6706.png

论文地址:https://arxiv.org/pdf/1808.04560.pdf

代码地址:https://github.com/weichen582/RetinexNet

解析目录:https://zhuanlan.zhihu.com/p/88761829


整个模型架构被实现为一个类:

class lowlight_enhance(object):

其构造函数实现了网络结构的搭建、损失函数的定义、训练的配置和参数的初始化,具体如下。

网络结构的搭建(该部分包括低/正常光照图像输入的定义以及Decom-Net、Enhance-Net和重建这三部分的对接,注意这里并没有对Rlow进行去噪的部分):

# build the model
self.input_low = tf.placeholder(tf.float32, [None, None, None, 3], name='input_low')
self.input_high = tf.placeholder(tf.float32, [None, None, None, 3], name='input_high')[R_low, I_low] = DecomNet(self.input_low, layer_num=self.DecomNet_layer_num)
[R_high, I_high] = DecomNet(self.input_high, layer_num=self.DecomNet_layer_num)I_delta = RelightNet(I_low, R_low)I_low_3 = concat([I_low, I_low, I_low])
I_high_3 = concat([I_high, I_high, I_high])
I_delta_3 = concat([I_delta, I_delta, I_delta])self.output_R_low = R_low
self.output_I_low = I_low_3
self.output_I_delta = I_delta_3
self.output_S = R_low * I_delta_3

损失函数的定义(该部分包括低/正常光照图像的重建损失、反射分量一致性损失、光照分量平滑损失以及最后分别计算的Decom-Net和Enhance-Net的总损失):

# loss
self.recon_loss_low = tf.reduce_mean(tf.abs(R_low * I_low_3 - self.input_low))
self.recon_loss_high = tf.reduce_mean(tf.abs(R_high * I_high_3 - self.input_high))
self.recon_loss_mutal_low = tf.reduce_mean(tf.abs(R_high * I_low_3 - self.input_low))
self.recon_loss_mutal_high = tf.reduce_mean(tf.abs(R_low * I_high_3 - self.input_high))
self.equal_R_loss = tf.reduce_mean(tf.abs(R_low - R_high))
self.relight_loss = tf.reduce_mean(tf.abs(R_low * I_delta_3 - self.input_high))self.Ismooth_loss_low = self.smooth(I_low, R_low)
self.Ismooth_loss_high = self.smooth(I_high, R_high)
self.Ismooth_loss_delta = self.smooth(I_delta, R_low)self.loss_Decom = self.recon_loss_low + self.recon_loss_high + 0.001 * self.recon_loss_mutal_low + 0.001 * self.recon_loss_mutal_high + 0.1 * self.Ismooth_loss_low + 0.1 * self.Ismooth_loss_high + 0.01 * self.equal_R_loss
self.loss_Relight = self.relight_loss + 3 * self.Ismooth_loss_delta

训练的配置(该部分包括学习率以及Decom-Net和Enhance-Net的优化器设置):

self.lr = tf.placeholder(tf.float32, name='learning_rate')
optimizer = tf.train.AdamOptimizer(self.lr, name='AdamOptimizer')self.var_Decom = [var for var in tf.trainable_variables() if 'DecomNet' in var.name]
self.var_Relight = [var for var in tf.trainable_variables() if 'RelightNet' in var.name]self.train_op_Decom = optimizer.minimize(self.loss_Decom, var_list = self.var_Decom)
self.train_op_Relight = optimizer.minimize(self.loss_Relight, var_list = self.var_Relight)

训练参数的初始化:

self.sess.run(tf.global_variables_initializer())self.saver_Decom = tf.train.Saver(var_list = self.var_Decom)
self.saver_Relight = tf.train.Saver(var_list = self.var_Relight)print("[*] Initialize model successfully...")

接下来是该类的一些成员函数。

def gradient(self, input_tensor, direction):self.smooth_kernel_x = tf.reshape(tf.constant([[0, 0], [-1, 1]], tf.float32), [2, 2, 1, 1])self.smooth_kernel_y = tf.transpose(self.smooth_kernel_x, [1, 0, 2, 3])if direction == "x":kernel = self.smooth_kernel_xelif direction == "y":kernel = self.smooth_kernel_yreturn tf.abs(tf.nn.conv2d(input_tensor, kernel, strides=[1, 1, 1, 1], padding='SAME'))

该函数实现的是通过与指定梯度算子进行卷积的方式求图像的水平/垂直梯度图。

def ave_gradient(self, input_tensor, direction):return tf.layers.average_pooling2d(self.gradient(input_tensor, direction), pool_size=3, strides=1, padding='SAME')

该函数实现的是通过平均池化的方式来对图像的水平/垂直梯度图进行平滑。

def smooth(self, input_I, input_R):input_R = tf.image.rgb_to_grayscale(input_R)return tf.reduce_mean(self.gradient(input_I, "x") * tf.exp(-10 * self.ave_gradient(input_R, "x")) + self.gradient(input_I, "y") * tf.exp(-10 * self.ave_gradient(input_R, "y")))

该函数是对光照分量平滑损失的具体实现(可对应原论文中的公式来看)。

def evaluate(self, epoch_num, eval_low_data, sample_dir, train_phase):print("[*] Evaluating for phase %s / epoch %d..." % (train_phase, epoch_num))for idx in range(len(eval_low_data)):input_low_eval = np.expand_dims(eval_low_data[idx], axis=0)if train_phase == "Decom":result_1, result_2 = self.sess.run([self.output_R_low, self.output_I_low], feed_dict={self.input_low: input_low_eval})if train_phase == "Relight":result_1, result_2 = self.sess.run([self.output_S, self.output_I_delta], feed_dict={self.input_low: input_low_eval})save_images(os.path.join(sample_dir, 'eval_%s_%d_%d.png' % (train_phase, idx + 1, epoch_num)), result_1, result_2)

该函数是对训练epoch_num次后的Decom-Net/Enhance-Net模型进行评估,并保存评估结果图。

接下来是关于模型的训练:

def train(self, train_low_data, train_high_data, eval_low_data, batch_size, patch_size, epoch, lr, sample_dir, ckpt_dir, eval_every_epoch, train_phase):

该函数中包含了预训练模型的加载、数据的读取与处理、模型的训练、评估和保存这几个部分。

assert len(train_low_data) == len(train_high_data)
numBatch = len(train_low_data) // int(batch_size)

检查所有需要参与训练的低/正常光照样本数量是否一致,若一致则计算训练集含有的batch数量。

# load pretrained model
if train_phase == "Decom":train_op = self.train_op_Decomtrain_loss = self.loss_Decomsaver = self.saver_Decom
elif train_phase == "Relight":train_op = self.train_op_Relighttrain_loss = self.loss_Relightsaver = self.saver_Relightload_model_status, global_step = self.load(saver, ckpt_dir)
if load_model_status:iter_num = global_stepstart_epoch = global_step // numBatchstart_step = global_step % numBatchprint("[*] Model restore success!")
else:iter_num = 0start_epoch = 0start_step = 0
print("[*] Not find pretrained model!")

若存在Decom-Net/Enhance-Net对应的预训练模型,则进行加载;否则从头开始训练。

# generate data for a batch
batch_input_low = np.zeros((batch_size, patch_size, patch_size, 3), dtype="float32")
batch_input_high = np.zeros((batch_size, patch_size, patch_size, 3), dtype="float32")
for patch_id in range(batch_size):h, w, _ = train_low_data[image_id].shapex = random.randint(0, h - patch_size)y = random.randint(0, w - patch_size)rand_mode = random.randint(0, 7)batch_input_low[patch_id, :, :, :] = data_augmentation(train_low_data[image_id][x : x+patch_size, y : y+patch_size, :], rand_mode)batch_input_high[patch_id, :, :, :] = data_augmentation(train_high_data[image_id][x : x+patch_size, y : y+patch_size, :], rand_mode)image_id = (image_id + 1) % len(train_low_data)if image_id == 0:tmp = list(zip(train_low_data, train_high_data))random.shuffle(list(tmp))train_low_data, train_high_data = zip(*tmp)

顺序读取训练图像,在每次读取的低/正常光照图像对上随机取patch,并进行数据扩增(具体见 中对函数data_augmentation的描述)。这里,应当注意的是,训练数据每满一个batch时将会重新打乱整个训练集。

# train
_, loss = self.sess.run([train_op, train_loss], feed_dict={self.input_low: batch_input_low, self.input_high: batch_input_high, self.lr: lr[epoch]})print("%s Epoch: [%2d] [%4d/%4d] time: %4.4f, loss: %.6f" % (train_phase, epoch + 1, batch_id + 1, numBatch, time.time() - start_time, loss))
iter_num += 1

训练一个iter并打印相关信息。

# evalutate the model and save a checkpoint file for it
if (epoch + 1) % eval_every_epoch == 0:self.evaluate(epoch + 1, eval_low_data, sample_dir=sample_dir, train_phase=train_phase)self.save(saver, iter_num, ckpt_dir, "RetinexNet-%s" % train_phase)

每训练eval_every_epoch次评估并保存一次模型。

保存指定iter的模型:

def save(self, saver, iter_num, ckpt_dir, model_name):if not os.path.exists(ckpt_dir):os.makedirs(ckpt_dir)print("[*] Saving model %s" % model_name)saver.save(self.sess, os.path.join(ckpt_dir, model_name), global_step=iter_num)

加载最新的模型:

def load(self, saver, ckpt_dir):ckpt = tf.train.get_checkpoint_state(ckpt_dir)if ckpt and ckpt.model_checkpoint_path:full_path = tf.train.latest_checkpoint(ckpt_dir)try:global_step = int(full_path.split('/')[-1].split('-')[-1])except ValueError:global_step = Nonesaver.restore(self.sess, full_path)return True, global_stepelse:print("[*] Failed to load model from %s" % ckpt_dir)return False, 0

最后是关于模型的测试(其中test_high_data并没有用到):

def test(self, test_low_data, test_high_data, test_low_data_names, save_dir, decom_flag):

该函数中包含了模型的加载、模型的测试和结果图的保存这几个部分。

tf.global_variables_initializer().run()print("[*] Reading checkpoint...")
load_model_status_Decom, _ = self.load(self.saver_Decom, './model/Decom')
load_model_status_Relight, _ = self.load(self.saver_Relight, './model/Relight')
if load_model_status_Decom and load_model_status_Relight:print("[*] Load weights successfully...")

初始化所有参数并加载最新的Decom-Net和Enhance-Net模型。

print("[*] Testing...")
for idx in range(len(test_low_data)):print(test_low_data_names[idx])[_, name] = os.path.split(test_low_data_names[idx])suffix = name[name.find('.') + 1:]name = name[:name.find('.')]input_low_test = np.expand_dims(test_low_data[idx], axis=0)[R_low, I_low, I_delta, S] = self.sess.run([self.output_R_low, self.output_I_low, self.output_I_delta, self.output_S], feed_dict = {self.input_low: input_low_test})if decom_flag == 1:save_images(os.path.join(save_dir, name + "_R_low." + suffix), R_low)save_images(os.path.join(save_dir, name + "_I_low." + suffix), I_low)save_images(os.path.join(save_dir, name + "_I_delta." + suffix), I_delta)save_images(os.path.join(save_dir, name + "_S." + suffix), S)

遍历测试样本进行测试,并保存最终结果图(可自行指定是否保存Decom-Net的分解结果)。

欢迎关注公众号:huangxiaobai880

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/397330.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

添加dubbo xsd的支持

使用dubbo时遇到问题: org.xml.sax.SAXParseException: schema_reference.4: Failed to read schema document http://code.alibabatech.com/schema/dubbo/dubbo.xsd, because 1) could not find the document; 2) the document could not be read; 3) the root ele…

byte数组穿换成pcm格式_形象地介绍DSD的编解码原理及和PCM的区别

一直有人不清楚DSD到底是啥原理,和MP3, FLAC, APE, WAV等基于PCM编码技术的音频格式又有啥区别。特意做了两张图说明一下。图一是是由很多黑点构成的蒙娜丽莎头像,点击看大图就知道是没有灰阶只有黑白两色。但是人眼是可以看到有丰富的灰阶的。这和DSD一…

UBUNTU : Destination Host Unreachable

介绍我的系统的搭建的方式: WIN7 64 VMWARE STATION,方式是进行桥接的方式。最近突然出现了问题,Ubuntu ping 外网或者 PING WIN 7 的时候,出现 Destination Host Unreachable 的错误;想着去修改网卡的链接形式: 编辑…

怎么把分开的pdf放在一起_糖和盐混在一起了要怎么分开?| 趣问万物

趣 问 万 物来源:把科学带回家撰文:Mirror如何分离糖和盐?图源:Pixabay小手一抖,不小心把糖(蔗糖)和盐(氯化钠)混在一块儿了该怎么办?趁着光棍节,就让我们吃饱了撑着研究研究把糖和盐拆散的N种方…

《JavaScript DOM编程艺术》笔记

1. 把<script>标签放到HTML文档的最后&#xff0c;<body>标签之前能使浏览器更快地加载页面。 2. nodeType的常见取值 元素节点(1) 属性节点(2) 文本节点(3) 3. <a href"http://www.baidu.com" οnclick"popUp(this.href);return false;"&g…

常熟理工学院计算机考研,2018江苏专转本考生必看-常熟理工学院介绍

原标题&#xff1a;2018江苏专转本考生必看-常熟理工学院介绍这次轮到默默学介绍常熟理工学院啦&#xff01;今年常熟理工学院有个专转本的学生&#xff0c;也是默默学专转本视频课程考上常熟理工的一个学生&#xff0c;叫黄群超&#xff0c;当年专转本计算机也考了八九十分吧&…

.net中调用esb_大型ESB服务总线平台服务运行分析和监控预警实践

今天准备谈下ESB总线平台建设项目中的服务运行统计分析&#xff0c;服务心跳监测&#xff0c;服务监控预警方面的设计和实现。可以看到&#xff0c;在一个ESB服务总线平台上线后&#xff0c;SOA治理管控就变得相当重要&#xff0c;而这些运行监控分析本身也是提升ESB总线平台高…

计算机操作系统实验银行家算法,实验六 银行家算法(下)

实验六 银行家算法(下)一、实验说明实验说明&#xff1a;本次实验主要是对银行家算法进行进一步的实践学习&#xff0c;掌握银行家算法的整体流程&#xff0c;理解程序测试时每一步的当前状态&#xff0c;能对当前的资源分配进行预判断。二、实验要求1、获取源代码2、看懂大致框…

什么原因导致芯片短路_华为为什么突然大量用起了联发科芯片,或是这三个产品策略原因...

经常关注数码圈的都知道&#xff0c;近几年来&#xff0c;随着华为自研能力的提升&#xff0c;华为几乎很少采购第三方芯片&#xff0c;近几年来的绝大多数华为手机&#xff0c;几乎都是用的自研芯片麒麟系列。并没有像其它国产品牌那样用联发科或者高通的芯片。不过今年却大不…

如何运行vue项目(维护他人的项目)

假如你是个小白&#xff0c;在公司接手他人的项目&#xff0c;这个时候&#xff0c;该怎么将这个项目跑通&#xff1f; 前提&#xff1a; 首先&#xff0c;这个教程主要针对vue小白&#xff0c;并且不知道安装node.js环境的。言归正传&#xff0c;下面开始教程&#xff1a;在维…

进程操作

2019独角兽企业重金招聘Python工程师标准>>> 一、创建一个进程 进程是系统中最基本的执行单位。Linux系统允许任何一个用户进程创建一个子进程&#xff0c;创建之后&#xff0c;子进程存在于系统之中并独立于父进程。 关于父进程与子进程这两个概念&#xff0c;除了…

计算机硬件发展的特点有哪些,简述计算机的发展历程及各代计算机的特点。

满意答案Karen0491推荐于 2017.11.25采纳率&#xff1a;40% 等级&#xff1a;6已帮助&#xff1a;608人世界上第一台计算机是1946年问世的&#xff0c;根据计算机的性能和软硬件技术&#xff0c;将计算机发展划分成以下几个阶段&#xff1a;①第一阶段&#xff1a;电子管计算…

电饼锅的样式图片价格_进口珐琅铸铁锅专场,精致小厨娘们来康康!

两个月前&#xff0c;小灰兔我写了《10个高颜值居家好物&#xff0c;让你在朋友圈万众瞩目&#xff01;》一文&#xff0c;曾有小伙伴私信说这张图简直就是梦想中厨房的亚子强烈同意&#xff01;&#xff01;&#xff01;有多少女孩子&#xff0c;看到颜值炒鸡高的锅路都走不动…

在UITouch事件中画圆圈-iOS8 Swift基础教程

这篇教程主要内容展示如何利用Core Graphics Framework画圆圈,当用户点击屏幕时随机生成不同大小的圆,这篇教程在Xcode6和iOS8下编译通过。 打开Xcode,新建项目选择Single View Application,Product Name填写iOS8SwiftDrawingCirclesTutorial,Organization Name和Organization …

浏览器兼容性问题

转载于:https://www.cnblogs.com/python-machine/p/9406084.html

有人在远程使用计算机是什么意思,如何远程控制计算机,计算机远程控制有什么用途...

对于每个人来说&#xff0c;计算机都是至关重要的家用电器. 因为使用计算机可以使我们的业余生活丰富多彩. 随着Internet的普及&#xff0c;越来越多的用户开始学习自己使用计算机. 但是&#xff0c;操作中仍然存在很多问题&#xff0c;只要每个人都学会了远程控制&#xff0c;…

图学java基础篇之IO

java io体系 如图可以看出&#xff0c;java的io按照包来划分的话可以分为三大块&#xff1a;io、nio、aio&#xff0c;但是从使用角度来看&#xff0c;这三块其实揉杂在一起的&#xff0c;下边我们先来概述下这三块&#xff1a; io:主要包含字符流和字节流&#xff0c;我们常用…

boot界面上下键调节键不能动_为什么电脑一开机就自动进入BIOS界面

电脑故障的问题表现形式很多&#xff0c;比如说为什么电脑蓝屏&#xff0c;为什么电脑一开机就自动进入BIOS界面等。这些问题往往另很多网友不知所措。今天小编就针对电脑一开机就自动进入BIOS界面的问题&#xff0c;教下大家具体的解决方法。1、你的BIOS电池没有电了。解决方法…

句子相似度--余弦相似度算法的实现

1、余弦相似度余弦距离&#xff0c;也称为余弦相似度&#xff0c;是用向量空间中两个向量夹角的余弦值作为衡量两个个体间差异的大小的度量。余弦值越接近1&#xff0c;就表明夹角越接近0度&#xff0c;也就是两个向量越相似&#xff0c;这就叫"余弦相似性"。 上图两…

红帽436——HA高可用集群之概念篇

一、集群概念&#xff1a;集群&#xff1a;提高性能&#xff0c;降低成本&#xff0c;提高可扩展性&#xff0c;增强可靠性&#xff0c;任务调度室集群中的核心技术。集群作用:保证业务不断 集群三种网络&#xff1a;业务网络,集群网络,存储网络 二、集群三种类型&#xff1a;…