loss低但精确度低_低光照图像增强网络-RetinexNet(model.py解析【2】)

51b2c2d4-d92e-eb11-8da9-e4434bdf6706.png

53b2c2d4-d92e-eb11-8da9-e4434bdf6706.png

论文地址:https://arxiv.org/pdf/1808.04560.pdf

代码地址:https://github.com/weichen582/RetinexNet

解析目录:https://zhuanlan.zhihu.com/p/88761829


整个模型架构被实现为一个类:

class lowlight_enhance(object):

其构造函数实现了网络结构的搭建、损失函数的定义、训练的配置和参数的初始化,具体如下。

网络结构的搭建(该部分包括低/正常光照图像输入的定义以及Decom-Net、Enhance-Net和重建这三部分的对接,注意这里并没有对Rlow进行去噪的部分):

# build the model
self.input_low = tf.placeholder(tf.float32, [None, None, None, 3], name='input_low')
self.input_high = tf.placeholder(tf.float32, [None, None, None, 3], name='input_high')[R_low, I_low] = DecomNet(self.input_low, layer_num=self.DecomNet_layer_num)
[R_high, I_high] = DecomNet(self.input_high, layer_num=self.DecomNet_layer_num)I_delta = RelightNet(I_low, R_low)I_low_3 = concat([I_low, I_low, I_low])
I_high_3 = concat([I_high, I_high, I_high])
I_delta_3 = concat([I_delta, I_delta, I_delta])self.output_R_low = R_low
self.output_I_low = I_low_3
self.output_I_delta = I_delta_3
self.output_S = R_low * I_delta_3

损失函数的定义(该部分包括低/正常光照图像的重建损失、反射分量一致性损失、光照分量平滑损失以及最后分别计算的Decom-Net和Enhance-Net的总损失):

# loss
self.recon_loss_low = tf.reduce_mean(tf.abs(R_low * I_low_3 - self.input_low))
self.recon_loss_high = tf.reduce_mean(tf.abs(R_high * I_high_3 - self.input_high))
self.recon_loss_mutal_low = tf.reduce_mean(tf.abs(R_high * I_low_3 - self.input_low))
self.recon_loss_mutal_high = tf.reduce_mean(tf.abs(R_low * I_high_3 - self.input_high))
self.equal_R_loss = tf.reduce_mean(tf.abs(R_low - R_high))
self.relight_loss = tf.reduce_mean(tf.abs(R_low * I_delta_3 - self.input_high))self.Ismooth_loss_low = self.smooth(I_low, R_low)
self.Ismooth_loss_high = self.smooth(I_high, R_high)
self.Ismooth_loss_delta = self.smooth(I_delta, R_low)self.loss_Decom = self.recon_loss_low + self.recon_loss_high + 0.001 * self.recon_loss_mutal_low + 0.001 * self.recon_loss_mutal_high + 0.1 * self.Ismooth_loss_low + 0.1 * self.Ismooth_loss_high + 0.01 * self.equal_R_loss
self.loss_Relight = self.relight_loss + 3 * self.Ismooth_loss_delta

训练的配置(该部分包括学习率以及Decom-Net和Enhance-Net的优化器设置):

self.lr = tf.placeholder(tf.float32, name='learning_rate')
optimizer = tf.train.AdamOptimizer(self.lr, name='AdamOptimizer')self.var_Decom = [var for var in tf.trainable_variables() if 'DecomNet' in var.name]
self.var_Relight = [var for var in tf.trainable_variables() if 'RelightNet' in var.name]self.train_op_Decom = optimizer.minimize(self.loss_Decom, var_list = self.var_Decom)
self.train_op_Relight = optimizer.minimize(self.loss_Relight, var_list = self.var_Relight)

训练参数的初始化:

self.sess.run(tf.global_variables_initializer())self.saver_Decom = tf.train.Saver(var_list = self.var_Decom)
self.saver_Relight = tf.train.Saver(var_list = self.var_Relight)print("[*] Initialize model successfully...")

接下来是该类的一些成员函数。

def gradient(self, input_tensor, direction):self.smooth_kernel_x = tf.reshape(tf.constant([[0, 0], [-1, 1]], tf.float32), [2, 2, 1, 1])self.smooth_kernel_y = tf.transpose(self.smooth_kernel_x, [1, 0, 2, 3])if direction == "x":kernel = self.smooth_kernel_xelif direction == "y":kernel = self.smooth_kernel_yreturn tf.abs(tf.nn.conv2d(input_tensor, kernel, strides=[1, 1, 1, 1], padding='SAME'))

该函数实现的是通过与指定梯度算子进行卷积的方式求图像的水平/垂直梯度图。

def ave_gradient(self, input_tensor, direction):return tf.layers.average_pooling2d(self.gradient(input_tensor, direction), pool_size=3, strides=1, padding='SAME')

该函数实现的是通过平均池化的方式来对图像的水平/垂直梯度图进行平滑。

def smooth(self, input_I, input_R):input_R = tf.image.rgb_to_grayscale(input_R)return tf.reduce_mean(self.gradient(input_I, "x") * tf.exp(-10 * self.ave_gradient(input_R, "x")) + self.gradient(input_I, "y") * tf.exp(-10 * self.ave_gradient(input_R, "y")))

该函数是对光照分量平滑损失的具体实现(可对应原论文中的公式来看)。

def evaluate(self, epoch_num, eval_low_data, sample_dir, train_phase):print("[*] Evaluating for phase %s / epoch %d..." % (train_phase, epoch_num))for idx in range(len(eval_low_data)):input_low_eval = np.expand_dims(eval_low_data[idx], axis=0)if train_phase == "Decom":result_1, result_2 = self.sess.run([self.output_R_low, self.output_I_low], feed_dict={self.input_low: input_low_eval})if train_phase == "Relight":result_1, result_2 = self.sess.run([self.output_S, self.output_I_delta], feed_dict={self.input_low: input_low_eval})save_images(os.path.join(sample_dir, 'eval_%s_%d_%d.png' % (train_phase, idx + 1, epoch_num)), result_1, result_2)

该函数是对训练epoch_num次后的Decom-Net/Enhance-Net模型进行评估,并保存评估结果图。

接下来是关于模型的训练:

def train(self, train_low_data, train_high_data, eval_low_data, batch_size, patch_size, epoch, lr, sample_dir, ckpt_dir, eval_every_epoch, train_phase):

该函数中包含了预训练模型的加载、数据的读取与处理、模型的训练、评估和保存这几个部分。

assert len(train_low_data) == len(train_high_data)
numBatch = len(train_low_data) // int(batch_size)

检查所有需要参与训练的低/正常光照样本数量是否一致,若一致则计算训练集含有的batch数量。

# load pretrained model
if train_phase == "Decom":train_op = self.train_op_Decomtrain_loss = self.loss_Decomsaver = self.saver_Decom
elif train_phase == "Relight":train_op = self.train_op_Relighttrain_loss = self.loss_Relightsaver = self.saver_Relightload_model_status, global_step = self.load(saver, ckpt_dir)
if load_model_status:iter_num = global_stepstart_epoch = global_step // numBatchstart_step = global_step % numBatchprint("[*] Model restore success!")
else:iter_num = 0start_epoch = 0start_step = 0
print("[*] Not find pretrained model!")

若存在Decom-Net/Enhance-Net对应的预训练模型,则进行加载;否则从头开始训练。

# generate data for a batch
batch_input_low = np.zeros((batch_size, patch_size, patch_size, 3), dtype="float32")
batch_input_high = np.zeros((batch_size, patch_size, patch_size, 3), dtype="float32")
for patch_id in range(batch_size):h, w, _ = train_low_data[image_id].shapex = random.randint(0, h - patch_size)y = random.randint(0, w - patch_size)rand_mode = random.randint(0, 7)batch_input_low[patch_id, :, :, :] = data_augmentation(train_low_data[image_id][x : x+patch_size, y : y+patch_size, :], rand_mode)batch_input_high[patch_id, :, :, :] = data_augmentation(train_high_data[image_id][x : x+patch_size, y : y+patch_size, :], rand_mode)image_id = (image_id + 1) % len(train_low_data)if image_id == 0:tmp = list(zip(train_low_data, train_high_data))random.shuffle(list(tmp))train_low_data, train_high_data = zip(*tmp)

顺序读取训练图像,在每次读取的低/正常光照图像对上随机取patch,并进行数据扩增(具体见 中对函数data_augmentation的描述)。这里,应当注意的是,训练数据每满一个batch时将会重新打乱整个训练集。

# train
_, loss = self.sess.run([train_op, train_loss], feed_dict={self.input_low: batch_input_low, self.input_high: batch_input_high, self.lr: lr[epoch]})print("%s Epoch: [%2d] [%4d/%4d] time: %4.4f, loss: %.6f" % (train_phase, epoch + 1, batch_id + 1, numBatch, time.time() - start_time, loss))
iter_num += 1

训练一个iter并打印相关信息。

# evalutate the model and save a checkpoint file for it
if (epoch + 1) % eval_every_epoch == 0:self.evaluate(epoch + 1, eval_low_data, sample_dir=sample_dir, train_phase=train_phase)self.save(saver, iter_num, ckpt_dir, "RetinexNet-%s" % train_phase)

每训练eval_every_epoch次评估并保存一次模型。

保存指定iter的模型:

def save(self, saver, iter_num, ckpt_dir, model_name):if not os.path.exists(ckpt_dir):os.makedirs(ckpt_dir)print("[*] Saving model %s" % model_name)saver.save(self.sess, os.path.join(ckpt_dir, model_name), global_step=iter_num)

加载最新的模型:

def load(self, saver, ckpt_dir):ckpt = tf.train.get_checkpoint_state(ckpt_dir)if ckpt and ckpt.model_checkpoint_path:full_path = tf.train.latest_checkpoint(ckpt_dir)try:global_step = int(full_path.split('/')[-1].split('-')[-1])except ValueError:global_step = Nonesaver.restore(self.sess, full_path)return True, global_stepelse:print("[*] Failed to load model from %s" % ckpt_dir)return False, 0

最后是关于模型的测试(其中test_high_data并没有用到):

def test(self, test_low_data, test_high_data, test_low_data_names, save_dir, decom_flag):

该函数中包含了模型的加载、模型的测试和结果图的保存这几个部分。

tf.global_variables_initializer().run()print("[*] Reading checkpoint...")
load_model_status_Decom, _ = self.load(self.saver_Decom, './model/Decom')
load_model_status_Relight, _ = self.load(self.saver_Relight, './model/Relight')
if load_model_status_Decom and load_model_status_Relight:print("[*] Load weights successfully...")

初始化所有参数并加载最新的Decom-Net和Enhance-Net模型。

print("[*] Testing...")
for idx in range(len(test_low_data)):print(test_low_data_names[idx])[_, name] = os.path.split(test_low_data_names[idx])suffix = name[name.find('.') + 1:]name = name[:name.find('.')]input_low_test = np.expand_dims(test_low_data[idx], axis=0)[R_low, I_low, I_delta, S] = self.sess.run([self.output_R_low, self.output_I_low, self.output_I_delta, self.output_S], feed_dict = {self.input_low: input_low_test})if decom_flag == 1:save_images(os.path.join(save_dir, name + "_R_low." + suffix), R_low)save_images(os.path.join(save_dir, name + "_I_low." + suffix), I_low)save_images(os.path.join(save_dir, name + "_I_delta." + suffix), I_delta)save_images(os.path.join(save_dir, name + "_S." + suffix), S)

遍历测试样本进行测试,并保存最终结果图(可自行指定是否保存Decom-Net的分解结果)。

欢迎关注公众号:huangxiaobai880

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/397330.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

计算机应用发表论文,计算机应用论文发表.docx

计算机应用论文发表1在工程项目管理中应用计算机技术存在的问题计算机软件是计算机运行的重要保障,一个好的计算机软件直接决定计算机技术在工程项目管理的高效应用。但由于市场上计算机软件种类繁多,质量好坏不一,质量好的价格高&#xff0c…

添加dubbo xsd的支持

使用dubbo时遇到问题: org.xml.sax.SAXParseException: schema_reference.4: Failed to read schema document http://code.alibabatech.com/schema/dubbo/dubbo.xsd, because 1) could not find the document; 2) the document could not be read; 3) the root ele…

byte数组穿换成pcm格式_形象地介绍DSD的编解码原理及和PCM的区别

一直有人不清楚DSD到底是啥原理,和MP3, FLAC, APE, WAV等基于PCM编码技术的音频格式又有啥区别。特意做了两张图说明一下。图一是是由很多黑点构成的蒙娜丽莎头像,点击看大图就知道是没有灰阶只有黑白两色。但是人眼是可以看到有丰富的灰阶的。这和DSD一…

最大熵对应的概率分布

最大熵对应的概率分布 最大熵定理 设 \(X \sim p(x)\) 是一个连续型随机变量,其微分熵定义为\[ h(X) - \int p(x)\log p(x) dx \]其中,\(\log\) 一般取自然对数 \(\ln\), 单位为 奈特(nats)。 考虑如下优化问题:\[ \b…

UBUNTU : Destination Host Unreachable

介绍我的系统的搭建的方式: WIN7 64 VMWARE STATION,方式是进行桥接的方式。最近突然出现了问题,Ubuntu ping 外网或者 PING WIN 7 的时候,出现 Destination Host Unreachable 的错误;想着去修改网卡的链接形式: 编辑…

焦作师范高等专科学校对口计算机分数线,焦作师范高等专科学校录取分数线2018...

焦作师范高等专科学校录取分数线20182018年 电子信息工程技术 理科 332 3602018年 物联网应用技术 文科 391 4082018年 物联网应用技术 理科 328 3692018年 学前教育 文科 388 4022018年 学前教育 理科 324 3512018年 移动应用开发 文科 02018年 移动应用开发 理科 305 3322018…

在Spring boot 配置过滤器(filter)

在spring boot 配置servlet filter 逻辑上与配置spring 是一样的。 不过相比spring 更加简化配置的难度。 这里只需要两步1 创建一个自定义顾虑器并继承spring filter 例如OncePerRequestFilterpublic class AuthenticationFilter extends OncePerRequestFilter{private final …

Flink之状态之状态存储 state backends

流计算中可能有各种方式来保存状态: 窗口操作使用 了KV操作的函数继承了CheckpointedFunction的函数当开始做checkpointing的时候,状态会被持久化到checkpoints里来规避数据丢失和状态恢复。选择的状态存储策略不同,会导致状态持久化如何和ch…

怎么把分开的pdf放在一起_糖和盐混在一起了要怎么分开?| 趣问万物

趣 问 万 物来源:把科学带回家撰文:Mirror如何分离糖和盐?图源:Pixabay小手一抖,不小心把糖(蔗糖)和盐(氯化钠)混在一块儿了该怎么办?趁着光棍节,就让我们吃饱了撑着研究研究把糖和盐拆散的N种方…

《JavaScript DOM编程艺术》笔记

1. 把<script>标签放到HTML文档的最后&#xff0c;<body>标签之前能使浏览器更快地加载页面。 2. nodeType的常见取值 元素节点(1) 属性节点(2) 文本节点(3) 3. <a href"http://www.baidu.com" οnclick"popUp(this.href);return false;"&g…

maven POM.xml内的标签大全详解

<project xmlns"http://maven.apache.org/POM/4.0.0" xmlns:xsi"http://www.w3.org/2001/XMLSchema-instance"xsi:schemaLocation"http://maven.apache.org/POM/4.0.0http://maven.apache.org/maven-v4_0_0.xsd"><!--父项目的坐标。如果…

常熟理工学院计算机考研,2018江苏专转本考生必看-常熟理工学院介绍

原标题&#xff1a;2018江苏专转本考生必看-常熟理工学院介绍这次轮到默默学介绍常熟理工学院啦&#xff01;今年常熟理工学院有个专转本的学生&#xff0c;也是默默学专转本视频课程考上常熟理工的一个学生&#xff0c;叫黄群超&#xff0c;当年专转本计算机也考了八九十分吧&…

.net中调用esb_大型ESB服务总线平台服务运行分析和监控预警实践

今天准备谈下ESB总线平台建设项目中的服务运行统计分析&#xff0c;服务心跳监测&#xff0c;服务监控预警方面的设计和实现。可以看到&#xff0c;在一个ESB服务总线平台上线后&#xff0c;SOA治理管控就变得相当重要&#xff0c;而这些运行监控分析本身也是提升ESB总线平台高…

使用Maven创建Web项目后,jsp引入静态文件提示报错。JSP 报错:javax.servlet.ServletException cannot be resolved to a type...

用maven创建多模块的web工程后&#xff0c;不同于直接创建普通的web工程。 1、在普通的web工程创建后&#xff0c;在项目中会有tomcat等服务器的jar包&#xff0c;这时创建JSP文件肯定是没有错的&#xff1b; 2、即使是使用maven创建的单模块的web工程&#xff0c;也会自动的在…

ES6之路第十三篇:Iterator和for...of循环

Iterator(遍历器)的概念 JavaScript 原有的表示“集合”的数据结构&#xff0c;主要是数组&#xff08;Array&#xff09;和对象&#xff08;Object&#xff09;&#xff0c;ES6 又添加了Map和Set。这样就有了四种数据集合&#xff0c;用户还可以组合使用它们&#xff0c;定义自…

MyBatis 特殊字符处理

http://blog.csdn.net/zheng0518/article/details/10449549

计算机操作系统实验银行家算法,实验六 银行家算法(下)

实验六 银行家算法(下)一、实验说明实验说明&#xff1a;本次实验主要是对银行家算法进行进一步的实践学习&#xff0c;掌握银行家算法的整体流程&#xff0c;理解程序测试时每一步的当前状态&#xff0c;能对当前的资源分配进行预判断。二、实验要求1、获取源代码2、看懂大致框…

什么原因导致芯片短路_华为为什么突然大量用起了联发科芯片,或是这三个产品策略原因...

经常关注数码圈的都知道&#xff0c;近几年来&#xff0c;随着华为自研能力的提升&#xff0c;华为几乎很少采购第三方芯片&#xff0c;近几年来的绝大多数华为手机&#xff0c;几乎都是用的自研芯片麒麟系列。并没有像其它国产品牌那样用联发科或者高通的芯片。不过今年却大不…

如何运行vue项目(维护他人的项目)

假如你是个小白&#xff0c;在公司接手他人的项目&#xff0c;这个时候&#xff0c;该怎么将这个项目跑通&#xff1f; 前提&#xff1a; 首先&#xff0c;这个教程主要针对vue小白&#xff0c;并且不知道安装node.js环境的。言归正传&#xff0c;下面开始教程&#xff1a;在维…

进程操作

2019独角兽企业重金招聘Python工程师标准>>> 一、创建一个进程 进程是系统中最基本的执行单位。Linux系统允许任何一个用户进程创建一个子进程&#xff0c;创建之后&#xff0c;子进程存在于系统之中并独立于父进程。 关于父进程与子进程这两个概念&#xff0c;除了…