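Contents
1. Project overview
2. Project workflow in detail
2.1. Data loading and configuration
2.2. Building the generator network
2.3. Building the discriminator network
2.4. VGG feature-extraction network
2.5. Loss functions
3. Complete code
4. Dataset
5. Testing the network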
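1. Project overview
Super-resolution (SR) is the process of recovering image detail and other information from known image data, drawing on optics and related knowledge. Put simply, it increases an image's resolution while keeping the image quality from degrading.
GAN stands for Generative Adversarial Network. A GAN generally consists of a generator (generative network) and a discriminator (discriminative network).
SRGAN applies adversarial training to image super-resolution reconstruction, and proposes a loss function composed of an adversarial loss and a content loss.
Paper: Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network (IEEE Conference Publication, IEEE Xplore): https://ieeexplore.ieee.org/document/8099502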
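For reference, the loss described above can be written out as follows. This is a sketch in the paper's notation ($G_{\theta_G}$ is the generator, $D_{\theta_D}$ the discriminator, $\phi_{i,j}$ a VGG19 feature map of size $W_{i,j} \times H_{i,j}$); the code later in this post uses its own weights, so treat it as background rather than as a description of the implementation.

```latex
\begin{align}
l^{SR} &= l^{SR}_{X} + 10^{-3}\, l^{SR}_{Gen} \\
l^{SR}_{VGG/i,j} &= \frac{1}{W_{i,j} H_{i,j}} \sum_{x=1}^{W_{i,j}} \sum_{y=1}^{H_{i,j}}
  \left( \phi_{i,j}(I^{HR})_{x,y} - \phi_{i,j}\!\left(G_{\theta_G}(I^{LR})\right)_{x,y} \right)^2 \\
l^{SR}_{Gen} &= \sum_{n=1}^{N} -\log D_{\theta_D}\!\left(G_{\theta_G}(I^{LR})\right)
\end{align}
```

Here $l^{SR}_{X}$ is the content loss (the VGG loss above, or a pixel-wise MSE) and $l^{SR}_{Gen}$ is the adversarial loss.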
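Network architecture:
The model consists of two networks: a generator and a discriminator.
In a generic GAN, the generator learns the features of the training data and, guided by the discriminator, maps a random-noise distribution as closely as possible onto the real data distribution, so that it produces samples sharing the training set's characteristics. In SRGAN the generator's input is a low-resolution image rather than random noise.
The discriminator decides whether its input is real data or fake data produced by the generator, and feeds that signal back to the generator.
The two networks are trained alternately and improve together, until the generator's output can pass for real data and the two networks reach a rough equilibrium.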
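The alternating training described above corresponds to a min-max game. A sketch of the objective as it is stated in the SRGAN paper, where the generator receives a low-resolution image $I^{LR}$ and the discriminator sees either a real high-resolution image $I^{HR}$ or a generated one:

```latex
\min_{\theta_G} \max_{\theta_D} \;
\mathbb{E}_{I^{HR} \sim p_{\text{train}}(I^{HR})} \left[ \log D_{\theta_D}(I^{HR}) \right]
+ \mathbb{E}_{I^{LR} \sim p_{G}(I^{LR})} \left[ \log \left( 1 - D_{\theta_D}\!\left(G_{\theta_G}(I^{LR})\right) \right) \right]
```

The discriminator maximizes this expression by scoring real images high and generated images low, while the generator minimizes it by making its outputs harder to tell apart from real ones.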
2. Project workflow in detail
2.1. Data loading and configuration
Configuration parameters:
from easydict import EasyDict as edict
import json

config = edict()
config.TRAIN = edict()

## Adam
config.TRAIN.batch_size = 4
config.TRAIN.lr_init = 1e-4
config.TRAIN.beta1 = 0.9

## initialize G
config.TRAIN.n_epoch_init = 100
# config.TRAIN.lr_decay_init = 0.1
# config.TRAIN.decay_every_init = int(config.TRAIN.n_epoch_init / 2)

## adversarial learning (SRGAN)
config.TRAIN.n_epoch = 2000
config.TRAIN.lr_decay = 0.1
config.TRAIN.decay_every = int(config.TRAIN.n_epoch / 2)

## train set location
config.TRAIN.hr_img_path = './srdata/DIV2K_train_HR'
config.TRAIN.lr_img_path = './srdata/DIV2K_train_LR_bicubic/X4'

config.VALID = edict()
## test set location
config.VALID.hr_img_path = './srdata/DIV2K_valid_HR'
config.VALID.lr_img_path = './srdata/DIV2K_valid_LR_bicubic/X4'


def log_config(filename, cfg):
    with open(filename, 'w') as f:
        f.write("================================================\n")
        f.write(json.dumps(cfg, indent=4))
        f.write("\n================================================\n")
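The learning-rate settings above imply a step schedule for the adversarial stage: main.py multiplies lr_init by lr_decay once every decay_every epochs. A minimal sketch of that arithmetic follows; the helper name learning_rate is hypothetical and not part of the project, and the default values are copied from the config.

```python
def learning_rate(epoch, lr_init=1e-4, lr_decay=0.1, n_epoch=2000):
    """Learning rate in effect at `epoch` during the adversarial (SRGAN) stage."""
    decay_every = int(n_epoch / 2)  # same expression as config.TRAIN.decay_every
    # main.py reassigns the rate whenever epoch % decay_every == 0,
    # which is equivalent to this step function.
    return lr_init * lr_decay ** (epoch // decay_every)


for epoch in (0, 999, 1000, 1999, 2000):
    print(epoch, learning_rate(epoch))
```

With the values above the rate stays at 1e-4 for the first 1000 epochs, drops to about 1e-5 for the next 1000, and reaches about 1e-6 only at the final epoch.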
Data loading:
# tl.files.load_file_list returns the image file names.
# Its first argument is the folder that holds the images; the second is a regex that selects the file type.
# The [:x] slice on the sorted list keeps the first x images (loading too many can cause a MemoryError).
train_hr_img_list = sorted(tl.files.load_file_list(path=config.TRAIN.hr_img_path, regx='.*.png', printable=False))[:100]
train_lr_img_list = sorted(tl.files.load_file_list(path=config.TRAIN.lr_img_path, regx='.*.png', printable=False))[:100]
valid_hr_img_list = sorted(tl.files.load_file_list(path=config.VALID.hr_img_path, regx='.*.png', printable=False))[:50]
valid_lr_img_list = sorted(tl.files.load_file_list(path=config.VALID.lr_img_path, regx='.*.png', printable=False))[:50]

# If your machine has enough memory, pre-load the whole train set.
# tl.vis.read_images reads the images into memory.
# Its first argument is the list of file names obtained above, the second is the folder that holds them,
# and n_threads is the number of threads used for reading.
train_hr_imgs = tl.vis.read_images(train_hr_img_list, path=config.TRAIN.hr_img_path, n_threads=8)
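The training loops in main.py walk over the pre-loaded images in steps of batch_size. A minimal sketch of that slicing is shown below; iter_batches and the file names are hypothetical, chosen only for illustration.

```python
def iter_batches(items, batch_size):
    """Yield consecutive slices of at most `batch_size` items."""
    for idx in range(0, len(items), batch_size):
        yield items[idx:idx + batch_size]


# Hypothetical file names standing in for the 100 training images kept by [:100].
file_names = ["%04d.png" % i for i in range(1, 101)]
batches = list(iter_batches(file_names, 4))
print(len(batches), len(batches[-1]))

# When the count is not a multiple of batch_size, the last slice is shorter.
print([len(b) for b in iter_batches(list(range(10)), 4)])
```

Because the placeholders later in this post are declared with a fixed first dimension of batch_size, a shorter final slice would not fit them. Keeping 100 images with batch_size = 4 avoids the problem.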
2.2. Building the generator network
tf.compat.v1.disable_eager_execution()
t_image = tf.compat.v1.placeholder('float32', [batch_size, 96, 96, 3], name='t_image_input_to_SRGAN_generator')

# Build the generator.
# reuse=False creates the generator's variables (first call, nothing is shared yet).
net_g = SRGAN_g(t_image, is_train=True, reuse=False)
SRGAN_g:
def SRGAN_g(t_image, is_train=False, reuse=False):
    """ Generator in Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network
    feature maps (n) and stride (s) feature maps (n) and stride (s)
    """
    # weight initializer
    w_init = tf.random_normal_initializer(stddev=0.02)
    b_init = None  # tf.constant_initializer(value=0.0)
    # gamma initializer (a BatchNormalization parameter)
    g_init = tf.random_normal_initializer(1., 0.02)
    with tf.variable_scope("SRGAN_g", reuse=reuse) as vs:
        # tl.layers.set_name_reuse(reuse) # remove for TL 1.8.0+
        # input layer
        n = InputLayer(t_image, name='in')
        # first convolution
        n = Conv2d(n, 64, (3, 3), (1, 1), act=tf.nn.relu, padding='SAME', W_init=w_init, name='n64s1/c')
        temp = n

        # B residual blocks (16 residual blocks)
        for i in range(16):
            nn = Conv2d(n, 64, (3, 3), (1, 1), act=None, padding='SAME', W_init=w_init, b_init=b_init, name='n64s1/c1/%s' % i)
            nn = BatchNormLayer(nn, act=tf.nn.relu, is_train=is_train, gamma_init=g_init, name='n64s1/b1/%s' % i)
            nn = Conv2d(nn, 64, (3, 3), (1, 1), act=None, padding='SAME', W_init=w_init, b_init=b_init, name='n64s1/c2/%s' % i)
            nn = BatchNormLayer(nn, is_train=is_train, gamma_init=g_init, name='n64s1/b2/%s' % i)
            # residual connection: nn = n + nn
            # n is the block input; nn is that input after two convolutions and two batch normalizations
            nn = ElementwiseLayer([n, nn], tf.add, name='b_residual_add/%s' % i)
            n = nn

        n = Conv2d(n, 64, (3, 3), (1, 1), act=None, padding='SAME', W_init=w_init, b_init=b_init, name='n64s1/c/m')
        n = BatchNormLayer(n, is_train=is_train, gamma_init=g_init, name='n64s1/b/m')
        # long skip connection: add the first convolution's output (temp) to the processed features
        n = ElementwiseLayer([n, temp], tf.add, name='add3')
        # B residual blocks end

        # upsampling: reconstruct the high-resolution image from the low-resolution features
        n = Conv2d(n, 256, (3, 3), (1, 1), act=None, padding='SAME', W_init=w_init, name='n256s1/1')
        n = SubpixelConv2d(n, scale=2, n_out_channel=None, act=tf.nn.relu, name='pixelshufflerx2/1')

        n = Conv2d(n, 256, (3, 3), (1, 1), act=None, padding='SAME', W_init=w_init, name='n256s1/2')
        n = SubpixelConv2d(n, scale=2, n_out_channel=None, act=tf.nn.relu, name='pixelshufflerx2/2')

        # a final convolution produces the output image
        n = Conv2d(n, 3, (1, 1), (1, 1), act=tf.nn.tanh, padding='SAME', W_init=w_init, name='out')
        return n
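The two sub-pixel layers are what turn a 96x96 input into a 384x384 output. The shape arithmetic can be sketched as below; both helper names are hypothetical, and the sketch assumes that with n_out_channel=None the layer outputs channels / scale^2 channels, which is how a standard pixel shuffle trades channels for resolution.

```python
def subpixel_shape(height, width, channels, scale=2):
    """Shape after a sub-pixel (pixel-shuffle) layer with the given scale."""
    if channels % (scale * scale) != 0:
        raise ValueError("channels must be divisible by scale * scale")
    return height * scale, width * scale, channels // (scale * scale)


def generator_output_shape(height, width):
    """Trace the spatial size through the generator's upsampling tail."""
    h, w, _ = subpixel_shape(height, width, 256)  # conv 'n256s1/1', then 'pixelshufflerx2/1'
    h, w, _ = subpixel_shape(h, w, 256)           # conv 'n256s1/2', then 'pixelshufflerx2/2'
    return h, w, 3                                # the final 1x1 conv 'out' maps to 3 channels


print(subpixel_shape(96, 96, 256))     # (192, 192, 64)
print(generator_output_shape(96, 96))  # (384, 384, 3)
```

The same arithmetic explains the sizes quoted in the evaluation code: a 339x510 low-resolution image comes out as 1356x2040.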
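2.3. Building the discriminator network
tf.compat.v1.disable_eager_execution()
t_target_image = tf.compat.v1.placeholder('float32', [batch_size, 384, 384, 3], name='t_target_image')

# Build the discriminator.
# First call: feed real images so the discriminator learns what "real" looks like.
# reuse=False creates the discriminator's variables.
net_d, logits_real = SRGAN_d(t_target_image, is_train=True, reuse=False)
# Second call: feed the generator's output so the discriminator learns what "fake" looks like.
# reuse=True shares the variables created by the first call.
_, logits_fake = SRGAN_d(net_g.outputs, is_train=True, reuse=True)
SRGAN_d: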
def SRGAN_d(input_images, is_train=True, reuse=False):
    w_init = tf.random_normal_initializer(stddev=0.02)
    b_init = None  # tf.constant_initializer(value=0.0)
    gamma_init = tf.random_normal_initializer(1., 0.02)
    df_dim = 64
    lrelu = lambda x: tl.act.lrelu(x, 0.2)
    # build the network
    with tf.variable_scope("SRGAN_d", reuse=reuse):
        tl.layers.set_name_reuse(reuse)
        net_in = InputLayer(input_images, name='input/images')
        net_h0 = Conv2d(net_in, df_dim, (4, 4), (2, 2), act=lrelu, padding='SAME', W_init=w_init, name='h0/c')

        net_h1 = Conv2d(net_h0, df_dim * 2, (4, 4), (2, 2), act=None, padding='SAME', W_init=w_init, b_init=b_init, name='h1/c')
        net_h1 = BatchNormLayer(net_h1, act=lrelu, is_train=is_train, gamma_init=gamma_init, name='h1/bn')
        net_h2 = Conv2d(net_h1, df_dim * 4, (4, 4), (2, 2), act=None, padding='SAME', W_init=w_init, b_init=b_init, name='h2/c')
        net_h2 = BatchNormLayer(net_h2, act=lrelu, is_train=is_train, gamma_init=gamma_init, name='h2/bn')
        net_h3 = Conv2d(net_h2, df_dim * 8, (4, 4), (2, 2), act=None, padding='SAME', W_init=w_init, b_init=b_init, name='h3/c')
        net_h3 = BatchNormLayer(net_h3, act=lrelu, is_train=is_train, gamma_init=gamma_init, name='h3/bn')
        net_h4 = Conv2d(net_h3, df_dim * 16, (4, 4), (2, 2), act=None, padding='SAME', W_init=w_init, b_init=b_init, name='h4/c')
        net_h4 = BatchNormLayer(net_h4, act=lrelu, is_train=is_train, gamma_init=gamma_init, name='h4/bn')
        net_h5 = Conv2d(net_h4, df_dim * 32, (4, 4), (2, 2), act=None, padding='SAME', W_init=w_init, b_init=b_init, name='h5/c')
        net_h5 = BatchNormLayer(net_h5, act=lrelu, is_train=is_train, gamma_init=gamma_init, name='h5/bn')
        net_h6 = Conv2d(net_h5, df_dim * 16, (1, 1), (1, 1), act=None, padding='SAME', W_init=w_init, b_init=b_init, name='h6/c')
        net_h6 = BatchNormLayer(net_h6, act=lrelu, is_train=is_train, gamma_init=gamma_init, name='h6/bn')
        net_h7 = Conv2d(net_h6, df_dim * 8, (1, 1), (1, 1), act=None, padding='SAME', W_init=w_init, b_init=b_init, name='h7/c')
        net_h7 = BatchNormLayer(net_h7, is_train=is_train, gamma_init=gamma_init, name='h7/bn')

        net = Conv2d(net_h7, df_dim * 2, (1, 1), (1, 1), act=None, padding='SAME', W_init=w_init, b_init=b_init, name='res/c')
        net = BatchNormLayer(net, act=lrelu, is_train=is_train, gamma_init=gamma_init, name='res/bn')
        net = Conv2d(net, df_dim * 2, (3, 3), (1, 1), act=None, padding='SAME', W_init=w_init, b_init=b_init, name='res/c2')
        net = BatchNormLayer(net, act=lrelu, is_train=is_train, gamma_init=gamma_init, name='res/bn2')
        net = Conv2d(net, df_dim * 8, (3, 3), (1, 1), act=None, padding='SAME', W_init=w_init, b_init=b_init, name='res/c3')
        net = BatchNormLayer(net, is_train=is_train, gamma_init=gamma_init, name='res/bn3')
        net_h8 = ElementwiseLayer([net_h7, net], combine_fn=tf.add, name='res/add')
        net_h8.outputs = tl.act.lrelu(net_h8.outputs, 0.2)

        # flatten the feature map and pass it through a fully connected layer
        net_ho = FlattenLayer(net_h8, name='ho/flatten')
        net_ho = DenseLayer(net_ho, n_units=1, act=tf.identity, W_init=w_init, name='ho/dense')
        logits = net_ho.outputs
        # the sigmoid turns the logit into the final real/fake probability
        net_ho.outputs = tf.nn.sigmoid(net_ho.outputs)

    return net_ho, logits
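To see what the dense layer at the end of SRGAN_d receives, the spatial size can be traced through the six stride-2 convolutions (h0 to h5). A minimal sketch follows; same_conv_size is a hypothetical helper, and it relies on the usual rule that a convolution with padding='SAME' outputs ceil(size / stride) pixels per side.

```python
import math


def same_conv_size(size, stride):
    """Spatial size after a convolution with padding='SAME'."""
    return math.ceil(size / stride)


size = 384  # side length of t_target_image
for name in ("h0", "h1", "h2", "h3", "h4", "h5"):
    size = same_conv_size(size, 2)  # each of these layers uses stride (2, 2)

df_dim = 64
channels = df_dim * 8  # h7 and res/c3 both output df_dim * 8 channels, so res/add keeps 512
flat_units = size * size * channels
print(size, flat_units)  # 6 18432
```

The 1x1 and 3x3 stride-1 layers (h6, h7 and the res branch) leave the 6x6 size unchanged, so the flatten layer hands 18432 values to the single-unit dense layer.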
2.4. VGG feature-extraction network
## vgg inference. 0, 1, 2, 3 BILINEAR NEAREST BICUBIC AREA
# Resize the data to the input size that the VGG network expects.
# resize the real (target) images
t_target_image_224 = tf.image.resize_images(t_target_image, size=[224, 224], method=0, align_corners=False)  # resize_target_image_for_vgg # http://tensorlayer.readthedocs.io/en/latest/_modules/tensorlayer/layers.html#UpSampling2dLayer
# resize the generated images
t_predict_image_224 = tf.image.resize_images(net_g.outputs, size=[224, 224], method=0, align_corners=False)  # resize_generate_image_for_vgg

net_vgg, vgg_target_emb = Vgg19_simple_api((t_target_image_224 + 1) / 2, reuse=False)
_, vgg_predict_emb = Vgg19_simple_api((t_predict_image_224 + 1) / 2, reuse=True)
Vgg19_simple_api:
def Vgg19_simple_api(rgb, reuse):
    """
    Build the VGG 19 Model

    Parameters
    -----------
    rgb : rgb image placeholder [batch, height, width, 3] values scaled [0, 1]
    """
    VGG_MEAN = [103.939, 116.779, 123.68]
    with tf.variable_scope("VGG19", reuse=reuse) as vs:
        start_time = time.time()
        print("build model started")
        rgb_scaled = rgb * 255.0
        # Convert RGB to BGR
        red, green, blue = tf.split(rgb_scaled, 3, 3)
        assert red.get_shape().as_list()[1:] == [224, 224, 1]
        assert green.get_shape().as_list()[1:] == [224, 224, 1]
        assert blue.get_shape().as_list()[1:] == [224, 224, 1]
        # mean subtraction: each color channel has its own mean subtracted
        bgr = tf.concat([
            blue - VGG_MEAN[0],
            green - VGG_MEAN[1],
            red - VGG_MEAN[2],
        ], axis=3)
        assert bgr.get_shape().as_list()[1:] == [224, 224, 3]
        """ input layer """
        net_in = InputLayer(bgr, name='input')
        """ conv1 """
        network = Conv2d(net_in, n_filter=64, filter_size=(3, 3), strides=(1, 1), act=tf.nn.relu, padding='SAME', name='conv1_1')
        network = Conv2d(network, n_filter=64, filter_size=(3, 3), strides=(1, 1), act=tf.nn.relu, padding='SAME', name='conv1_2')
        network = MaxPool2d(network, filter_size=(2, 2), strides=(2, 2), padding='SAME', name='pool1')
        """ conv2 """
        network = Conv2d(network, n_filter=128, filter_size=(3, 3), strides=(1, 1), act=tf.nn.relu, padding='SAME', name='conv2_1')
        network = Conv2d(network, n_filter=128, filter_size=(3, 3), strides=(1, 1), act=tf.nn.relu, padding='SAME', name='conv2_2')
        network = MaxPool2d(network, filter_size=(2, 2), strides=(2, 2), padding='SAME', name='pool2')
        """ conv3 """
        network = Conv2d(network, n_filter=256, filter_size=(3, 3), strides=(1, 1), act=tf.nn.relu, padding='SAME', name='conv3_1')
        network = Conv2d(network, n_filter=256, filter_size=(3, 3), strides=(1, 1), act=tf.nn.relu, padding='SAME', name='conv3_2')
        network = Conv2d(network, n_filter=256, filter_size=(3, 3), strides=(1, 1), act=tf.nn.relu, padding='SAME', name='conv3_3')
        network = Conv2d(network, n_filter=256, filter_size=(3, 3), strides=(1, 1), act=tf.nn.relu, padding='SAME', name='conv3_4')
        network = MaxPool2d(network, filter_size=(2, 2), strides=(2, 2), padding='SAME', name='pool3')
        """ conv4 """
        network = Conv2d(network, n_filter=512, filter_size=(3, 3), strides=(1, 1), act=tf.nn.relu, padding='SAME', name='conv4_1')
        network = Conv2d(network, n_filter=512, filter_size=(3, 3), strides=(1, 1), act=tf.nn.relu, padding='SAME', name='conv4_2')
        network = Conv2d(network, n_filter=512, filter_size=(3, 3), strides=(1, 1), act=tf.nn.relu, padding='SAME', name='conv4_3')
        network = Conv2d(network, n_filter=512, filter_size=(3, 3), strides=(1, 1), act=tf.nn.relu, padding='SAME', name='conv4_4')
        network = MaxPool2d(network, filter_size=(2, 2), strides=(2, 2), padding='SAME', name='pool4')  # (batch_size, 14, 14, 512)
        conv = network
        """ conv5 """
        network = Conv2d(network, n_filter=512, filter_size=(3, 3), strides=(1, 1), act=tf.nn.relu, padding='SAME', name='conv5_1')
        network = Conv2d(network, n_filter=512, filter_size=(3, 3), strides=(1, 1), act=tf.nn.relu, padding='SAME', name='conv5_2')
        network = Conv2d(network, n_filter=512, filter_size=(3, 3), strides=(1, 1), act=tf.nn.relu, padding='SAME', name='conv5_3')
        network = Conv2d(network, n_filter=512, filter_size=(3, 3), strides=(1, 1), act=tf.nn.relu, padding='SAME', name='conv5_4')
        network = MaxPool2d(network, filter_size=(2, 2), strides=(2, 2), padding='SAME', name='pool5')  # (batch_size, 7, 7, 512)
        """ fc 6~8 """
        # flatten the features and pass them through the fully connected layers
        network = FlattenLayer(network, name='flatten')
        network = DenseLayer(network, n_units=4096, act=tf.nn.relu, name='fc6')
        network = DenseLayer(network, n_units=4096, act=tf.nn.relu, name='fc7')
        network = DenseLayer(network, n_units=1000, act=tf.identity, name='fc8')
        print("build model finished: %fs" % (time.time() - start_time))
        return network, conv
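The value ranges are easy to lose track of here: the generator ends in tanh, so images live in [-1, 1]; the call site rescales them with (x + 1) / 2 to [0, 1]; and Vgg19_simple_api multiplies by 255, reorders the channels to BGR and subtracts the per-channel means. A minimal per-pixel sketch of that chain, with a hypothetical helper name:

```python
VGG_MEAN = [103.939, 116.779, 123.68]  # means in B, G, R order, as in Vgg19_simple_api


def preprocess_pixel(rgb):
    """Map one RGB pixel in [-1, 1] to the mean-subtracted BGR triple fed to VGG19."""
    # (v + 1) / 2 happens at the call site, * 255.0 inside Vgg19_simple_api.
    r, g, b = [((v + 1) / 2) * 255.0 for v in rgb]
    return [b - VGG_MEAN[0], g - VGG_MEAN[1], r - VGG_MEAN[2]]


print(preprocess_pixel((-1.0, -1.0, -1.0)))  # black pixel: [-103.939, -116.779, -123.68]
print(preprocess_pixel((1.0, 1.0, 1.0)))     # white pixel: roughly [151.06, 138.22, 131.32]
```

Note also that the second value returned by Vgg19_simple_api is the pool4 output, so the feature maps compared in the VGG loss have shape (batch_size, 14, 14, 512).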
2.5. Loss functions
# ###========================== DEFINE TRAIN OPS ==========================###
# Discriminator loss:
# for real images the target is ones_like
d_loss1 = tl.cost.sigmoid_cross_entropy(logits_real, tf.ones_like(logits_real), name='d1')
# for fake images the target is zeros_like
d_loss2 = tl.cost.sigmoid_cross_entropy(logits_fake, tf.zeros_like(logits_fake), name='d2')
d_loss = d_loss1 + d_loss2

# The generator wants its images to be judged real, so the target is ones_like.
g_gan_loss = 1e-3 * tl.cost.sigmoid_cross_entropy(logits_fake, tf.ones_like(logits_fake), name='g')
# pixel-wise comparison of the generated image with the real image
mse_loss = tl.cost.mean_squared_error(net_g.outputs, t_target_image, is_mean=True)
# comparison of the generated and real images after VGG feature extraction
vgg_loss = 2e-6 * tl.cost.mean_squared_error(vgg_predict_emb.outputs, vgg_target_emb.outputs, is_mean=True)

# total generator loss
g_loss = mse_loss + vgg_loss + g_gan_loss
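To make the loss terms concrete, here is a minimal pure-Python sketch of how they combine. It uses the standard numerically stable form of sigmoid cross-entropy on logits and assumes, as an approximation of tl.cost.sigmoid_cross_entropy, that the per-sample values are averaged over the batch. All numeric inputs are made up for illustration.

```python
import math


def sigmoid_cross_entropy(logit, label):
    """Stable form of -[z*log(sigmoid(x)) + (1-z)*log(1-sigmoid(x))] for logit x, label z."""
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))


def mean(values):
    return sum(values) / len(values)


# Made-up discriminator logits for a batch of two real and two generated images.
logits_real = [2.0, 1.5]
logits_fake = [-1.0, 0.5]

# Discriminator: real images should be labelled 1, generated images 0.
d_loss = (mean([sigmoid_cross_entropy(x, 1.0) for x in logits_real])
          + mean([sigmoid_cross_entropy(x, 0.0) for x in logits_fake]))

# Generator: wants its images labelled 1; weighted by 1e-3 as in the code above.
g_gan_loss = 1e-3 * mean([sigmoid_cross_entropy(x, 1.0) for x in logits_fake])

# Made-up values for the pixel MSE and the (unweighted) VGG feature MSE.
mse_loss = 0.02
vgg_feature_mse = 1500.0
vgg_loss = 2e-6 * vgg_feature_mse

g_loss = mse_loss + vgg_loss + g_gan_loss
print(d_loss, g_loss)
```

The small weights (1e-3 and 2e-6) keep the adversarial and VGG terms on a scale comparable to the pixel MSE, which is computed on values in [-1, 1].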
3. Complete code
main.py
#! /usr/bin/python
# -*- coding: utf8 -*-
#http://tensorlayercn.readthedocs.io/zh/latest/user/installation.html
import os
import time
import pickle, random
#from datetime import datetime
import numpy as np
import logging, scipy

import tensorflow as tf
import tensorlayer as tl
from model import SRGAN_g, SRGAN_d, Vgg19_simple_api
from utils import *
from config import config, log_config

###====================== HYPER-PARAMETERS ===========================###
## Adam
batch_size = config.TRAIN.batch_size
lr_init = config.TRAIN.lr_init
beta1 = config.TRAIN.beta1
## initialize G
n_epoch_init = config.TRAIN.n_epoch_init
## adversarial learning (SRGAN)
n_epoch = config.TRAIN.n_epoch
lr_decay = config.TRAIN.lr_decay
decay_every = config.TRAIN.decay_every

ni = int(np.sqrt(batch_size))


def train():
    ## create folders to save result images and trained model
    save_dir_ginit = "samples/{}_ginit".format(tl.global_flag['mode'])
    save_dir_gan = "samples/{}_gan".format(tl.global_flag['mode'])
    tl.files.exists_or_mkdir(save_dir_ginit)
    tl.files.exists_or_mkdir(save_dir_gan)
    checkpoint_dir = "checkpoint"  # checkpoint_resize_conv
    tl.files.exists_or_mkdir(checkpoint_dir)

    ###====================== PRE-LOAD DATA ===========================###
    # tl.files.load_file_list returns the image file names.
    # Its first argument is the folder that holds the images; the second is a regex that selects the file type.
    # The [:x] slice on the sorted list keeps the first x images (loading too many can cause a MemoryError).
    train_hr_img_list = sorted(tl.files.load_file_list(path=config.TRAIN.hr_img_path, regx='.*.png', printable=False))[:100]
    train_lr_img_list = sorted(tl.files.load_file_list(path=config.TRAIN.lr_img_path, regx='.*.png', printable=False))[:100]
    valid_hr_img_list = sorted(tl.files.load_file_list(path=config.VALID.hr_img_path, regx='.*.png', printable=False))[:50]
    valid_lr_img_list = sorted(tl.files.load_file_list(path=config.VALID.lr_img_path, regx='.*.png', printable=False))[:50]

    # If your machine has enough memory, pre-load the whole train set.
    # tl.vis.read_images reads the images into memory.
    # Its first argument is the list of file names obtained above, the second is the folder that holds them,
    # and n_threads is the number of threads used for reading.
    train_hr_imgs = tl.vis.read_images(train_hr_img_list, path=config.TRAIN.hr_img_path, n_threads=8)
    # for im in train_hr_imgs:
    #     print(im.shape)
    # valid_lr_imgs = tl.vis.read_images(valid_lr_img_list, path=config.VALID.lr_img_path, n_threads=32)
    # for im in valid_lr_imgs:
    #     print(im.shape)
    # valid_hr_imgs = tl.vis.read_images(valid_hr_img_list, path=config.VALID.hr_img_path, n_threads=32)
    # for im in valid_hr_imgs:
    #     print(im.shape)
    # exit()

    ###========================== DEFINE MODEL ============================###
    ## train inference
    tf.compat.v1.disable_eager_execution()
    t_image = tf.compat.v1.placeholder('float32', [batch_size, 96, 96, 3], name='t_image_input_to_SRGAN_generator')
    t_target_image = tf.compat.v1.placeholder('float32', [batch_size, 384, 384, 3], name='t_target_image')

    # Build the generator.
    # reuse=False creates the generator's variables (first call, nothing is shared yet).
    net_g = SRGAN_g(t_image, is_train=True, reuse=False)
    # Build the discriminator.
    # First call: feed real images so the discriminator learns what "real" looks like.
    # reuse=False creates the discriminator's variables.
    net_d, logits_real = SRGAN_d(t_target_image, is_train=True, reuse=False)
    # Second call: feed the generator's output so the discriminator learns what "fake" looks like.
    # reuse=True shares the variables created by the first call.
    _, logits_fake = SRGAN_d(net_g.outputs, is_train=True, reuse=True)

    net_g.print_params(False)
    net_g.print_layers()
    net_d.print_params(False)
    net_d.print_layers()

    ## vgg inference. 0, 1, 2, 3 BILINEAR NEAREST BICUBIC AREA
    # Resize the data to the input size that the VGG network expects.
    # resize the real (target) images
    t_target_image_224 = tf.image.resize_images(t_target_image, size=[224, 224], method=0, align_corners=False)  # resize_target_image_for_vgg # http://tensorlayer.readthedocs.io/en/latest/_modules/tensorlayer/layers.html#UpSampling2dLayer
    # resize the generated images
    t_predict_image_224 = tf.image.resize_images(net_g.outputs, size=[224, 224], method=0, align_corners=False)  # resize_generate_image_for_vgg

    net_vgg, vgg_target_emb = Vgg19_simple_api((t_target_image_224 + 1) / 2, reuse=False)
    _, vgg_predict_emb = Vgg19_simple_api((t_predict_image_224 + 1) / 2, reuse=True)

    ## test inference
    net_g_test = SRGAN_g(t_image, is_train=False, reuse=True)

    # ###========================== DEFINE TRAIN OPS ==========================###
    # Discriminator loss:
    # for real images the target is ones_like
    d_loss1 = tl.cost.sigmoid_cross_entropy(logits_real, tf.ones_like(logits_real), name='d1')
    # for fake images the target is zeros_like
    d_loss2 = tl.cost.sigmoid_cross_entropy(logits_fake, tf.zeros_like(logits_fake), name='d2')
    d_loss = d_loss1 + d_loss2

    # The generator wants its images to be judged real, so the target is ones_like.
    g_gan_loss = 1e-3 * tl.cost.sigmoid_cross_entropy(logits_fake, tf.ones_like(logits_fake), name='g')
    # pixel-wise comparison of the generated image with the real image
    mse_loss = tl.cost.mean_squared_error(net_g.outputs, t_target_image, is_mean=True)
    # comparison of the generated and real images after VGG feature extraction
    vgg_loss = 2e-6 * tl.cost.mean_squared_error(vgg_predict_emb.outputs, vgg_target_emb.outputs, is_mean=True)

    # total generator loss
    g_loss = mse_loss + vgg_loss + g_gan_loss

    # collect the trainable variables of each network
    g_vars = tl.layers.get_variables_with_name('SRGAN_g', True, True)
    d_vars = tl.layers.get_variables_with_name('SRGAN_d', True, True)

    with tf.variable_scope('learning_rate'):
        lr_v = tf.Variable(lr_init, trainable=False)
    ## Pretrain
    g_optim_init = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(mse_loss, var_list=g_vars)
    ## SRGAN
    g_optim = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(g_loss, var_list=g_vars)
    d_optim = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(d_loss, var_list=d_vars)

    ###========================== RESTORE MODEL =============================###
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False))
    tl.layers.initialize_global_variables(sess)
    if tl.files.load_and_assign_npz(sess=sess, name=checkpoint_dir + '/g_{}.npz'.format(tl.global_flag['mode']), network=net_g) is False:
        tl.files.load_and_assign_npz(sess=sess, name=checkpoint_dir + '/g_{}_init.npz'.format(tl.global_flag['mode']), network=net_g)
    tl.files.load_and_assign_npz(sess=sess, name=checkpoint_dir + '/d_{}.npz'.format(tl.global_flag['mode']), network=net_d)

    ###============================= LOAD VGG ===============================###
    vgg19_npy_path = "vgg19.npy"
    if not os.path.isfile(vgg19_npy_path):
        print("Please download vgg19.npy from : https://github.com/machrisaa/tensorflow-vgg")
        exit()
    npz = np.load(vgg19_npy_path, encoding='latin1').item()

    params = []
    for val in sorted(npz.items()):
        W = np.asarray(val[1][0])
        b = np.asarray(val[1][1])
        print(" Loading %s: %s, %s" % (val[0], W.shape, b.shape))
        params.extend([W, b])
    tl.files.assign_params(sess, params, net_vgg)
    # net_vgg.print_params(False)
    # net_vgg.print_layers()
    print('ok')

    ###============================= TRAINING ===============================###
    ## use first `batch_size` of train set to have a quick test during training
    sample_imgs = train_hr_imgs[0:batch_size]
    # sample_imgs = tl.vis.read_images(train_hr_img_list[0:batch_size], path=config.TRAIN.hr_img_path, n_threads=32) # if no pre-load train set
    sample_imgs_384 = tl.prepro.threading_data(sample_imgs, fn=crop_sub_imgs_fn, is_random=False)
    print('sample HR sub-image:', sample_imgs_384.shape, sample_imgs_384.min(), sample_imgs_384.max())
    sample_imgs_96 = tl.prepro.threading_data(sample_imgs_384, fn=downsample_fn)
    print('sample LR sub-image:', sample_imgs_96.shape, sample_imgs_96.min(), sample_imgs_96.max())
    tl.vis.save_images(sample_imgs_96, [ni, ni], save_dir_ginit + '/_train_sample_96.png')
    tl.vis.save_images(sample_imgs_384, [ni, ni], save_dir_ginit + '/_train_sample_384.png')
    tl.vis.save_images(sample_imgs_96, [ni, ni], save_dir_gan + '/_train_sample_96.png')
    tl.vis.save_images(sample_imgs_384, [ni, ni], save_dir_gan + '/_train_sample_384.png')

    ###========================= initialize G ====================###
    ## fixed learning rate
    sess.run(tf.assign(lr_v, lr_init))
    print(" ** fixed learning rate: %f (for init G)" % lr_init)
    for epoch in range(0, n_epoch_init + 1):
        epoch_time = time.time()
        total_mse_loss, n_iter = 0, 0

        ## If your machine cannot load all images into memory, you should use
        ## this one to load batch of images while training.
        # random.shuffle(train_hr_img_list)
        # for idx in range(0, len(train_hr_img_list), batch_size):
        #     step_time = time.time()
        #     b_imgs_list = train_hr_img_list[idx : idx + batch_size]
        #     b_imgs = tl.prepro.threading_data(b_imgs_list, fn=get_imgs_fn, path=config.TRAIN.hr_img_path)
        #     b_imgs_384 = tl.prepro.threading_data(b_imgs, fn=crop_sub_imgs_fn, is_random=True)
        #     b_imgs_96 = tl.prepro.threading_data(b_imgs_384, fn=downsample_fn)

        ## If your machine have enough memory, please pre-load the whole train set.
        for idx in range(0, len(train_hr_imgs), batch_size):
            step_time = time.time()
            b_imgs_384 = tl.prepro.threading_data(train_hr_imgs[idx:idx + batch_size], fn=crop_sub_imgs_fn, is_random=True)
            b_imgs_96 = tl.prepro.threading_data(b_imgs_384, fn=downsample_fn)
            ## update G
            errM, _ = sess.run([mse_loss, g_optim_init], {t_image: b_imgs_96, t_target_image: b_imgs_384})
            print("Epoch [%2d/%2d] %4d time: %4.4fs, mse: %.8f " % (epoch, n_epoch_init, n_iter, time.time() - step_time, errM))
            total_mse_loss += errM
            n_iter += 1
        log = "[*] Epoch: [%2d/%2d] time: %4.4fs, mse: %.8f" % (epoch, n_epoch_init, time.time() - epoch_time, total_mse_loss / n_iter)
        print(log)

        ## quick evaluation on train set
        if (epoch != 0) and (epoch % 10 == 0):
            out = sess.run(net_g_test.outputs, {t_image: sample_imgs_96})  #; print('gen sub-image:', out.shape, out.min(), out.max())
            print("[*] save images")
            tl.vis.save_images(out, [ni, ni], save_dir_ginit + '/train_%d.png' % epoch)

        ## save model
        if (epoch != 0) and (epoch % 10 == 0):
            tl.files.save_npz(net_g.all_params, name=checkpoint_dir + '/g_{}_init.npz'.format(tl.global_flag['mode']), sess=sess)

    ###========================= train GAN (SRGAN) =========================###
    for epoch in range(0, n_epoch + 1):
        ## update learning rate
        if epoch != 0 and (epoch % decay_every == 0):
            new_lr_decay = lr_decay**(epoch // decay_every)
            sess.run(tf.assign(lr_v, lr_init * new_lr_decay))
            log = " ** new learning rate: %f (for GAN)" % (lr_init * new_lr_decay)
            print(log)
        elif epoch == 0:
            sess.run(tf.assign(lr_v, lr_init))
            log = " ** init lr: %f decay_every_init: %d, lr_decay: %f (for GAN)" % (lr_init, decay_every, lr_decay)
            print(log)

        epoch_time = time.time()
        total_d_loss, total_g_loss, n_iter = 0, 0, 0

        ## If your machine cannot load all images into memory, you should use
        ## this one to load batch of images while training.
        # random.shuffle(train_hr_img_list)
        # for idx in range(0, len(train_hr_img_list), batch_size):
        #     step_time = time.time()
        #     b_imgs_list = train_hr_img_list[idx : idx + batch_size]
        #     b_imgs = tl.prepro.threading_data(b_imgs_list, fn=get_imgs_fn, path=config.TRAIN.hr_img_path)
        #     b_imgs_384 = tl.prepro.threading_data(b_imgs, fn=crop_sub_imgs_fn, is_random=True)
        #     b_imgs_96 = tl.prepro.threading_data(b_imgs_384, fn=downsample_fn)

        ## If your machine have enough memory, please pre-load the whole train set.
        for idx in range(0, len(train_hr_imgs), batch_size):
            step_time = time.time()
            b_imgs_384 = tl.prepro.threading_data(train_hr_imgs[idx:idx + batch_size], fn=crop_sub_imgs_fn, is_random=True)
            b_imgs_96 = tl.prepro.threading_data(b_imgs_384, fn=downsample_fn)
            ## update D
            errD, _ = sess.run([d_loss, d_optim], {t_image: b_imgs_96, t_target_image: b_imgs_384})
            ## update G
            errG, errM, errV, errA, _ = sess.run([g_loss, mse_loss, vgg_loss, g_gan_loss, g_optim], {t_image: b_imgs_96, t_target_image: b_imgs_384})
            print("Epoch [%2d/%2d] %4d time: %4.4fs, d_loss: %.8f g_loss: %.8f (mse: %.6f vgg: %.6f adv: %.6f)" %
                  (epoch, n_epoch, n_iter, time.time() - step_time, errD, errG, errM, errV, errA))
            total_d_loss += errD
            total_g_loss += errG
            n_iter += 1

        log = "[*] Epoch: [%2d/%2d] time: %4.4fs, d_loss: %.8f g_loss: %.8f" % (epoch, n_epoch, time.time() - epoch_time, total_d_loss / n_iter,
                                                                                total_g_loss / n_iter)
        print(log)

        ## quick evaluation on train set
        if (epoch != 0) and (epoch % 10 == 0):
            out = sess.run(net_g_test.outputs, {t_image: sample_imgs_96})  #; print('gen sub-image:', out.shape, out.min(), out.max())
            print("[*] save images")
            tl.vis.save_images(out, [ni, ni], save_dir_gan + '/train_%d.png' % epoch)

        ## save model
        if (epoch != 0) and (epoch % 10 == 0):
            tl.files.save_npz(net_g.all_params, name=checkpoint_dir + '/g_{}.npz'.format(tl.global_flag['mode']), sess=sess)
            tl.files.save_npz(net_d.all_params, name=checkpoint_dir + '/d_{}.npz'.format(tl.global_flag['mode']), sess=sess)


def evaluate():
    ## create folders to save result images
    save_dir = "samples/{}".format(tl.global_flag['mode'])
    tl.files.exists_or_mkdir(save_dir)
    checkpoint_dir = "checkpoint"

    ###====================== PRE-LOAD DATA ===========================###
    # train_hr_img_list = sorted(tl.files.load_file_list(path=config.TRAIN.hr_img_path, regx='.*.png', printable=False))
    # train_lr_img_list = sorted(tl.files.load_file_list(path=config.TRAIN.lr_img_path, regx='.*.png', printable=False))
    valid_hr_img_list = sorted(tl.files.load_file_list(path=config.VALID.hr_img_path, regx='.*.png', printable=False))
    valid_lr_img_list = sorted(tl.files.load_file_list(path=config.VALID.lr_img_path, regx='.*.png', printable=False))

    ## If your machine have enough memory, please pre-load the whole train set.
    # train_hr_imgs = tl.vis.read_images(train_hr_img_list, path=config.TRAIN.hr_img_path, n_threads=32)
    # for im in train_hr_imgs:
    #     print(im.shape)
    valid_lr_imgs = tl.vis.read_images(valid_lr_img_list, path=config.VALID.lr_img_path, n_threads=8)
    # for im in valid_lr_imgs:
    #     print(im.shape)
    valid_hr_imgs = tl.vis.read_images(valid_hr_img_list, path=config.VALID.hr_img_path, n_threads=8)
    # for im in valid_hr_imgs:
    #     print(im.shape)
    # exit()

    ###========================== DEFINE MODEL ============================###
    imid = 64  # 0: penguin  81: butterfly  53: bird  64: castle
    valid_lr_img = valid_lr_imgs[imid]
    valid_hr_img = valid_hr_imgs[imid]
    # valid_lr_img = get_imgs_fn('test.png', 'data2017/')  # if you want to test your own image
    valid_lr_img = (valid_lr_img / 127.5) - 1  # rescale to [-1, 1]
    # print(valid_lr_img.min(), valid_lr_img.max())

    size = valid_lr_img.shape
    # t_image = tf.placeholder('float32', [None, size[0], size[1], size[2]], name='input_image') # the old version of TL need to specify the image size
    t_image = tf.placeholder('float32', [1, None, None, 3], name='input_image')

    net_g = SRGAN_g(t_image, is_train=False, reuse=False)

    ###========================== RESTORE G =============================###
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False))
    tl.layers.initialize_global_variables(sess)
    tl.files.load_and_assign_npz(sess=sess, name=checkpoint_dir + '/g_srgan.npz', network=net_g)

    ###======================= EVALUATION =============================###
    start_time = time.time()
    out = sess.run(net_g.outputs, {t_image: [valid_lr_img]})
    print("took: %4.4fs" % (time.time() - start_time))

    print("LR size: %s / generated HR size: %s" % (size, out.shape))  # LR size: (339, 510, 3) / gen HR size: (1, 1356, 2040, 3)
    print("[*] save images")
    tl.vis.save_image(out[0], save_dir + '/valid_gen.png')
    tl.vis.save_image(valid_lr_img, save_dir + '/valid_lr.png')
    tl.vis.save_image(valid_hr_img, save_dir + '/valid_hr.png')

    out_bicu = scipy.misc.imresize(valid_lr_img, [size[0] * 4, size[1] * 4], interp='bicubic', mode=None)
    tl.vis.save_image(out_bicu, save_dir + '/valid_bicubic.png')


if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()

    parser.add_argument('--mode', type=str, default='srgan', help='srgan, evaluate')

    args = parser.parse_args()

    tl.global_flag['mode'] = args.mode

    if tl.global_flag['mode'] == 'srgan':
        train()
    elif tl.global_flag['mode'] == 'evaluate':
        evaluate()
    else:
        raise Exception("Unknown --mode")
config.py
from easydict import EasyDict as edict
import json

config = edict()
config.TRAIN = edict()

## Adam
config.TRAIN.batch_size = 4
config.TRAIN.lr_init = 1e-4
config.TRAIN.beta1 = 0.9

## initialize G
config.TRAIN.n_epoch_init = 100
# config.TRAIN.lr_decay_init = 0.1
# config.TRAIN.decay_every_init = int(config.TRAIN.n_epoch_init / 2)

## adversarial learning (SRGAN)
config.TRAIN.n_epoch = 2000
config.TRAIN.lr_decay = 0.1
config.TRAIN.decay_every = int(config.TRAIN.n_epoch / 2)

## train set location
config.TRAIN.hr_img_path = './srdata/DIV2K_train_HR'
config.TRAIN.lr_img_path = './srdata/DIV2K_train_LR_bicubic/X4'

config.VALID = edict()
## test set location
config.VALID.hr_img_path = './srdata/DIV2K_valid_HR'
config.VALID.lr_img_path = './srdata/DIV2K_valid_LR_bicubic/X4'


def log_config(filename, cfg):
    with open(filename, 'w') as f:
        f.write("================================================\n")
        f.write(json.dumps(cfg, indent=4))
        f.write("\n================================================\n")
download_imagenet.py
import argparse
import socket
import os
import urllib.request  # Python 3: urlretrieve lives in urllib.request
import numpy as np
from PIL import Image

from joblib import Parallel, delayed


def download_image(download_str, save_dir):
    img_name, img_url = download_str.strip().split('\t')
    save_img = os.path.join(save_dir, "{}.jpg".format(img_name))
    downloaded = False
    try:
        if not os.path.isfile(save_img):
            print("Downloading {} to {}.jpg".format(img_url, img_name))
            urllib.request.urlretrieve(img_url, save_img)

            # Check size of the images
            downloaded = True
            with Image.open(save_img) as img:
                width, height = img.size

            img_size_bytes = os.path.getsize(save_img)
            img_size_KB = img_size_bytes / 1024

            # discard images that are too small or too heavily compressed
            if width < 500 or height < 500 or img_size_KB < 200:
                os.remove(save_img)
                print("Remove downloaded images (w:{}, h:{}, s:{}KB)".format(width, height, img_size_KB))
        else:
            print("Already downloaded {}".format(save_img))
    except Exception:
        if not downloaded:
            print("Cannot download.")
        else:
            print("Remove failed, downloaded images.")

        if os.path.isfile(save_img):
            os.remove(save_img)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--img_url_file", type=str, required=True,
                        help="File that contains list of image IDs and urls.")
    parser.add_argument("--output_dir", type=str, required=True,
                        help="Directory where to save outputs.")
    parser.add_argument("--n_download_urls", type=int, default=20000,
                        help="Number of URLs to download.")
    args = parser.parse_args()

    # np.random.seed(123456)

    socket.setdefaulttimeout(10)

    with open(args.img_url_file) as f:
        lines = f.readlines()
        lines = np.random.choice(lines, size=args.n_download_urls, replace=False)

    Parallel(n_jobs=12)(delayed(download_image)(line, args.output_dir) for line in lines)


if __name__ == "__main__":
    main()
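The download script keeps an image only if it is large enough to be useful as a high-resolution training sample. The rule can be isolated as a small predicate, sketched below; keep_image is a hypothetical name, and the thresholds are the ones hard-coded in download_image.

```python
def keep_image(width, height, size_bytes, min_side=500, min_kb=200):
    """Return True if a downloaded image passes the script's size checks."""
    size_kb = size_bytes / 1024
    # The script deletes the file when any one of these conditions holds.
    too_small = width < min_side or height < min_side or size_kb < min_kb
    return not too_small


print(keep_image(1024, 768, 350 * 1024))  # True
print(keep_image(499, 800, 300 * 1024))   # False: width below 500 px
print(keep_image(800, 800, 100 * 1024))   # False: file smaller than 200 KB
```

Pulling the rule into a function like this also makes it easy to test the thresholds without touching the network.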
model.py
#! /usr/bin/python
# -*- coding: utf8 -*-import tensorflow as tf
import tensorlayer as tl
from tensorlayer.layers import *
import time
import os# from tensorflow.python.ops import variable_scope as vs
# from tensorflow.python.ops import math_ops, init_ops, array_ops, nn
# from tensorflow.python.util import nest
# from tensorflow.contrib.rnn.python.ops import core_rnn_cell

# https://github.com/david-gpu/srez/blob/master/srez_model.py


def SRGAN_g(t_image, is_train=False, reuse=False):
    """ Generator in Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network
    feature maps (n) and stride (s) feature maps (n) and stride (s)
    """
    # Weight initialization
    w_init = tf.random_normal_initializer(stddev=0.02)
    b_init = None  # tf.constant_initializer(value=0.0)
    # Initialization of gamma (the scale parameter of batch normalization)
    g_init = tf.random_normal_initializer(1., 0.02)
    # tf.compat.v1.disable_v2_behavior()
    with tf.compat.v1.variable_scope("SRGAN_g", reuse=reuse) as vs:
        # tl.layers.set_name_reuse(reuse) # remove for TL 1.8.0+
        # Input layer
        n = InputLayer(t_image, name='in')
        # First convolution layer
        n = Conv2d(n, 64, (3, 3), (1, 1), act=tf.nn.relu, padding='SAME', W_init=w_init, name='n64s1/c')
        temp = n

        # B residual blocks (16 residual blocks are stacked here)
        for i in range(16):
            nn = Conv2d(n, 64, (3, 3), (1, 1), act=None, padding='SAME', W_init=w_init, b_init=b_init, name='n64s1/c1/%s' % i)
            nn = BatchNormLayer(nn, act=tf.nn.relu, is_train=is_train, gamma_init=g_init, name='n64s1/b1/%s' % i)
            nn = Conv2d(nn, 64, (3, 3), (1, 1), act=None, padding='SAME', W_init=w_init, b_init=b_init, name='n64s1/c2/%s' % i)
            nn = BatchNormLayer(nn, is_train=is_train, gamma_init=g_init, name='n64s1/b2/%s' % i)
            # Residual connection: nn = n + nn
            # n is the block input; nn is the result of two convolutions and two batch normalizations
            nn = ElementwiseLayer([n, nn], tf.add, name='b_residual_add/%s' % i)
            n = nn

        n = Conv2d(n, 64, (3, 3), (1, 1), act=None, padding='SAME', W_init=w_init, b_init=b_init, name='n64s1/c/m')
        n = BatchNormLayer(n, is_train=is_train, gamma_init=g_init, name='n64s1/b/m')
        # Add the output of the first convolution (temp) to the output of the residual stack
        n = ElementwiseLayer([n, temp], tf.add, name='add3')
        # B residual blocks end

        # Reconstruction: upsample from low resolution to high resolution
        n = Conv2d(n, 256, (3, 3), (1, 1), act=None, padding='SAME', W_init=w_init, name='n256s1/1')
        n = SubpixelConv2d(n, scale=2, n_out_channel=None, act=tf.nn.relu, name='pixelshufflerx2/1')

        n = Conv2d(n, 256, (3, 3), (1, 1), act=None, padding='SAME', W_init=w_init, name='n256s1/2')
        n = SubpixelConv2d(n, scale=2, n_out_channel=None, act=tf.nn.relu, name='pixelshufflerx2/2')

        # A final convolution after reconstruction produces the output image
        n = Conv2d(n, 3, (1, 1), (1, 1), act=tf.nn.tanh, padding='SAME', W_init=w_init, name='out')
        return n


def SRGAN_g2(t_image, is_train=False, reuse=False):
    """ Generator in Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network
    feature maps (n) and stride (s) feature maps (n) and stride (s)

    96x96 --> 384x384

    Use Resize Conv
    """
    w_init = tf.random_normal_initializer(stddev=0.02)
    b_init = None  # tf.constant_initializer(value=0.0)
    g_init = tf.random_normal_initializer(1., 0.02)
    size = t_image.get_shape().as_list()
    with tf.variable_scope("SRGAN_g", reuse=reuse) as vs:
        # tl.layers.set_name_reuse(reuse) # remove for TL 1.8.0+
        n = InputLayer(t_image, name='in')
        n = Conv2d(n, 64, (3, 3), (1, 1), act=tf.nn.relu, padding='SAME', W_init=w_init, name='n64s1/c')
        temp = n

        # B residual blocks
        for i in range(16):
            nn = Conv2d(n, 64, (3, 3), (1, 1), act=None, padding='SAME', W_init=w_init, b_init=b_init, name='n64s1/c1/%s' % i)
            nn = BatchNormLayer(nn, act=tf.nn.relu, is_train=is_train, gamma_init=g_init, name='n64s1/b1/%s' % i)
            nn = Conv2d(nn, 64, (3, 3), (1, 1), act=None, padding='SAME', W_init=w_init, b_init=b_init, name='n64s1/c2/%s' % i)
            nn = BatchNormLayer(nn, is_train=is_train, gamma_init=g_init, name='n64s1/b2/%s' % i)
            nn = ElementwiseLayer([n, nn], tf.add, name='b_residual_add/%s' % i)
            n = nn

        n = Conv2d(n, 64, (3, 3), (1, 1), act=None, padding='SAME', W_init=w_init, b_init=b_init, name='n64s1/c/m')
        n = BatchNormLayer(n, is_train=is_train, gamma_init=g_init, name='n64s1/b/m')
        n = ElementwiseLayer([n, temp], tf.add, name='add3')
        # B residual blocks end

        # n = Conv2d(n, 256, (3, 3), (1, 1), act=None, padding='SAME', W_init=w_init, name='n256s1/1')
        # n = SubpixelConv2d(n, scale=2, n_out_channel=None, act=tf.nn.relu, name='pixelshufflerx2/1')
        #
        # n = Conv2d(n, 256, (3, 3), (1, 1), act=None, padding='SAME', W_init=w_init, name='n256s1/2')
        # n = SubpixelConv2d(n, scale=2, n_out_channel=None, act=tf.nn.relu, name='pixelshufflerx2/2')

        # 0, 1, 2, 3 BILINEAR NEAREST BICUBIC AREA
        n = UpSampling2dLayer(n, size=[size[1] * 2, size[2] * 2], is_scale=False, method=1, align_corners=False, name='up1/upsample2d')
        n = Conv2d(n, 64, (3, 3), (1, 1), padding='SAME', W_init=w_init, b_init=b_init, name='up1/conv2d')  # <-- may need to increase n_filter
        n = BatchNormLayer(n, act=tf.nn.relu, is_train=is_train, gamma_init=g_init, name='up1/batch_norm')

        n = UpSampling2dLayer(n, size=[size[1] * 4, size[2] * 4], is_scale=False, method=1, align_corners=False, name='up2/upsample2d')
        n = Conv2d(n, 32, (3, 3), (1, 1), padding='SAME', W_init=w_init, b_init=b_init, name='up2/conv2d')  # <-- may need to increase n_filter
        n = BatchNormLayer(n, act=tf.nn.relu, is_train=is_train, gamma_init=g_init, name='up2/batch_norm')

        n = Conv2d(n, 3, (1, 1), (1, 1), act=tf.nn.tanh, padding='SAME', W_init=w_init, name='out')
        return n


def SRGAN_d2(t_image, is_train=False, reuse=False):
    """ Discriminator in Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network
    feature maps (n) and stride (s) feature maps (n) and stride (s)
    """
    w_init = tf.random_normal_initializer(stddev=0.02)
    b_init = None
    g_init = tf.random_normal_initializer(1., 0.02)
    lrelu = lambda x: tl.act.lrelu(x, 0.2)
    with tf.variable_scope("SRGAN_d", reuse=reuse) as vs:
        # tl.layers.set_name_reuse(reuse) # remove for TL 1.8.0+
        n = InputLayer(t_image, name='in')
        n = Conv2d(n, 64, (3, 3), (1, 1), act=lrelu, padding='SAME', W_init=w_init, name='n64s1/c')

        n = Conv2d(n, 64, (3, 3), (2, 2), act=lrelu, padding='SAME', W_init=w_init, b_init=b_init, name='n64s2/c')
        n = BatchNormLayer(n, is_train=is_train, gamma_init=g_init, name='n64s2/b')

        n = Conv2d(n, 128, (3, 3), (1, 1), act=lrelu, padding='SAME', W_init=w_init, b_init=b_init, name='n128s1/c')
        n = BatchNormLayer(n, is_train=is_train, gamma_init=g_init, name='n128s1/b')

        n = Conv2d(n, 128, (3, 3), (2, 2), act=lrelu, padding='SAME', W_init=w_init, b_init=b_init, name='n128s2/c')
        n = BatchNormLayer(n, is_train=is_train, gamma_init=g_init, name='n128s2/b')

        n = Conv2d(n, 256, (3, 3), (1, 1), act=lrelu, padding='SAME', W_init=w_init, b_init=b_init, name='n256s1/c')
        n = BatchNormLayer(n, is_train=is_train, gamma_init=g_init, name='n256s1/b')

        n = Conv2d(n, 256, (3, 3), (2, 2), act=lrelu, padding='SAME', W_init=w_init, b_init=b_init, name='n256s2/c')
        n = BatchNormLayer(n, is_train=is_train, gamma_init=g_init, name='n256s2/b')

        n = Conv2d(n, 512, (3, 3), (1, 1), act=lrelu, padding='SAME', W_init=w_init, b_init=b_init, name='n512s1/c')
        n = BatchNormLayer(n, is_train=is_train, gamma_init=g_init, name='n512s1/b')

        n = Conv2d(n, 512, (3, 3), (2, 2), act=lrelu, padding='SAME', W_init=w_init, b_init=b_init, name='n512s2/c')
        n = BatchNormLayer(n, is_train=is_train, gamma_init=g_init, name='n512s2/b')

        n = FlattenLayer(n, name='f')
        n = DenseLayer(n, n_units=1024, act=lrelu, name='d1024')
        n = DenseLayer(n, n_units=1, name='out')

        logits = n.outputs
        n.outputs = tf.nn.sigmoid(n.outputs)

        return n, logits


def SRGAN_d(input_images, is_train=True, reuse=False):
    w_init = tf.random_normal_initializer(stddev=0.02)
    b_init = None  # tf.constant_initializer(value=0.0)
    gamma_init = tf.random_normal_initializer(1., 0.02)
    df_dim = 64
    lrelu = lambda x: tl.act.lrelu(x, 0.2)
    # Build the network
    with tf.variable_scope("SRGAN_d", reuse=reuse):
        tl.layers.set_name_reuse(reuse)
        net_in = InputLayer(input_images, name='input/images')
        net_h0 = Conv2d(net_in, df_dim, (4, 4), (2, 2), act=lrelu, padding='SAME', W_init=w_init, name='h0/c')

        net_h1 = Conv2d(net_h0, df_dim * 2, (4, 4), (2, 2), act=None, padding='SAME', W_init=w_init, b_init=b_init, name='h1/c')
        net_h1 = BatchNormLayer(net_h1, act=lrelu, is_train=is_train, gamma_init=gamma_init, name='h1/bn')
        net_h2 = Conv2d(net_h1, df_dim * 4, (4, 4), (2, 2), act=None, padding='SAME', W_init=w_init, b_init=b_init, name='h2/c')
        net_h2 = BatchNormLayer(net_h2, act=lrelu, is_train=is_train, gamma_init=gamma_init, name='h2/bn')
        net_h3 = Conv2d(net_h2, df_dim * 8, (4, 4), (2, 2), act=None, padding='SAME', W_init=w_init, b_init=b_init, name='h3/c')
        net_h3 = BatchNormLayer(net_h3, act=lrelu, is_train=is_train, gamma_init=gamma_init, name='h3/bn')
        net_h4 = Conv2d(net_h3, df_dim * 16, (4, 4), (2, 2), act=None, padding='SAME', W_init=w_init, b_init=b_init, name='h4/c')
        net_h4 = BatchNormLayer(net_h4, act=lrelu, is_train=is_train, gamma_init=gamma_init, name='h4/bn')
        net_h5 = Conv2d(net_h4, df_dim * 32, (4, 4), (2, 2), act=None, padding='SAME', W_init=w_init, b_init=b_init, name='h5/c')
        net_h5 = BatchNormLayer(net_h5, act=lrelu, is_train=is_train, gamma_init=gamma_init, name='h5/bn')
        net_h6 = Conv2d(net_h5, df_dim * 16, (1, 1), (1, 1), act=None, padding='SAME', W_init=w_init, b_init=b_init, name='h6/c')
        net_h6 = BatchNormLayer(net_h6, act=lrelu, is_train=is_train, gamma_init=gamma_init, name='h6/bn')
        net_h7 = Conv2d(net_h6, df_dim * 8, (1, 1), (1, 1), act=None, padding='SAME', W_init=w_init, b_init=b_init, name='h7/c')
        net_h7 = BatchNormLayer(net_h7, is_train=is_train, gamma_init=gamma_init, name='h7/bn')

        net = Conv2d(net_h7, df_dim * 2, (1, 1), (1, 1), act=None, padding='SAME', W_init=w_init, b_init=b_init, name='res/c')
        net = BatchNormLayer(net, act=lrelu, is_train=is_train, gamma_init=gamma_init, name='res/bn')
        net = Conv2d(net, df_dim * 2, (3, 3), (1, 1), act=None, padding='SAME', W_init=w_init, b_init=b_init, name='res/c2')
        net = BatchNormLayer(net, act=lrelu, is_train=is_train, gamma_init=gamma_init, name='res/bn2')
        net = Conv2d(net, df_dim * 8, (3, 3), (1, 1), act=None, padding='SAME', W_init=w_init, b_init=b_init, name='res/c3')
        net = BatchNormLayer(net, is_train=is_train, gamma_init=gamma_init, name='res/bn3')
        net_h8 = ElementwiseLayer([net_h7, net], combine_fn=tf.add, name='res/add')
        net_h8.outputs = tl.act.lrelu(net_h8.outputs, 0.2)

        # Flatten the convolution output and pass it through a fully connected layer
        net_ho = FlattenLayer(net_h8, name='ho/flatten')
        net_ho = DenseLayer(net_ho, n_units=1, act=tf.identity, W_init=w_init, name='ho/dense')
        logits = net_ho.outputs
        # Apply sigmoid to obtain the final score, which indicates real or fake
        net_ho.outputs = tf.nn.sigmoid(net_ho.outputs)

    return net_ho, logits


def Vgg19_simple_api(rgb, reuse):
    """
    Build the VGG 19 Model

    Parameters
    -----------
    rgb : rgb image placeholder [batch, height, width, 3] values scaled [0, 1]
    """
    VGG_MEAN = [103.939, 116.779, 123.68]
    with tf.variable_scope("VGG19", reuse=reuse) as vs:
        start_time = time.time()
        print("build model started")

        rgb_scaled = rgb * 255.0
        # Convert RGB to BGR
        red, green, blue = tf.split(rgb_scaled, 3, 3)
        assert red.get_shape().as_list()[1:] == [224, 224, 1]
        assert green.get_shape().as_list()[1:] == [224, 224, 1]
        assert blue.get_shape().as_list()[1:] == [224, 224, 1]
        # Mean subtraction: subtract the per-channel mean from each color channel
        bgr = tf.concat([
            blue - VGG_MEAN[0],
            green - VGG_MEAN[1],
            red - VGG_MEAN[2],
        ], axis=3)
        assert bgr.get_shape().as_list()[1:] == [224, 224, 3]

        """ input layer """
        net_in = InputLayer(bgr, name='input')
        """ conv1 """
        network = Conv2d(net_in, n_filter=64, filter_size=(3, 3), strides=(1, 1), act=tf.nn.relu, padding='SAME', name='conv1_1')
        network = Conv2d(network, n_filter=64, filter_size=(3, 3), strides=(1, 1), act=tf.nn.relu, padding='SAME', name='conv1_2')
        network = MaxPool2d(network, filter_size=(2, 2), strides=(2, 2), padding='SAME', name='pool1')
        """ conv2 """
        network = Conv2d(network, n_filter=128, filter_size=(3, 3), strides=(1, 1), act=tf.nn.relu, padding='SAME', name='conv2_1')
        network = Conv2d(network, n_filter=128, filter_size=(3, 3), strides=(1, 1), act=tf.nn.relu, padding='SAME', name='conv2_2')
        network = MaxPool2d(network, filter_size=(2, 2), strides=(2, 2), padding='SAME', name='pool2')
        """ conv3 """
        network = Conv2d(network, n_filter=256, filter_size=(3, 3), strides=(1, 1), act=tf.nn.relu, padding='SAME', name='conv3_1')
        network = Conv2d(network, n_filter=256, filter_size=(3, 3), strides=(1, 1), act=tf.nn.relu, padding='SAME', name='conv3_2')
        network = Conv2d(network, n_filter=256, filter_size=(3, 3), strides=(1, 1), act=tf.nn.relu, padding='SAME', name='conv3_3')
        network = Conv2d(network, n_filter=256, filter_size=(3, 3), strides=(1, 1), act=tf.nn.relu, padding='SAME', name='conv3_4')
        network = MaxPool2d(network, filter_size=(2, 2), strides=(2, 2), padding='SAME', name='pool3')
        """ conv4 """
        network = Conv2d(network, n_filter=512, filter_size=(3, 3), strides=(1, 1), act=tf.nn.relu, padding='SAME', name='conv4_1')
        network = Conv2d(network, n_filter=512, filter_size=(3, 3), strides=(1, 1), act=tf.nn.relu, padding='SAME', name='conv4_2')
        network = Conv2d(network, n_filter=512, filter_size=(3, 3), strides=(1, 1), act=tf.nn.relu, padding='SAME', name='conv4_3')
        network = Conv2d(network, n_filter=512, filter_size=(3, 3), strides=(1, 1), act=tf.nn.relu, padding='SAME', name='conv4_4')
        network = MaxPool2d(network, filter_size=(2, 2), strides=(2, 2), padding='SAME', name='pool4')  # (batch_size, 14, 14, 512)
        conv = network
        """ conv5 """
        network = Conv2d(network, n_filter=512, filter_size=(3, 3), strides=(1, 1), act=tf.nn.relu, padding='SAME', name='conv5_1')
        network = Conv2d(network, n_filter=512, filter_size=(3, 3), strides=(1, 1), act=tf.nn.relu, padding='SAME', name='conv5_2')
        network = Conv2d(network, n_filter=512, filter_size=(3, 3), strides=(1, 1), act=tf.nn.relu, padding='SAME', name='conv5_3')
        network = Conv2d(network, n_filter=512, filter_size=(3, 3), strides=(1, 1), act=tf.nn.relu, padding='SAME', name='conv5_4')
        network = MaxPool2d(network, filter_size=(2, 2), strides=(2, 2), padding='SAME', name='pool5')  # (batch_size, 7, 7, 512)
        """ fc 6~8 """
        # Flatten the feature maps and pass them through the fully connected layers
        network = FlattenLayer(network, name='flatten')
        network = DenseLayer(network, n_units=4096, act=tf.nn.relu, name='fc6')
        network = DenseLayer(network, n_units=4096, act=tf.nn.relu, name='fc7')
        network = DenseLayer(network, n_units=1000, act=tf.identity, name='fc8')
        print("build model finished: %fs" % (time.time() - start_time))
        return network, conv


# def vgg16_cnn_emb(t_image, reuse=False):
# """ t_image = 244x244 [0~255] """
# with tf.variable_scope("vgg16_cnn", reuse=reuse) as vs:
# tl.layers.set_name_reuse(reuse)
#
# mean = tf.constant([123.68, 116.779, 103.939], dtype=tf.float32, shape=[1, 1, 1, 3], name='img_mean')
# net_in = InputLayer(t_image - mean, name='vgg_input_im')
# """ conv1 """
# network = tl.layers.Conv2dLayer(net_in,
# act = tf.nn.relu,
# shape = [3, 3, 3, 64], # 64 features for each 3x3 patch
# strides = [1, 1, 1, 1],
# padding='SAME',
# name ='vgg_conv1_1')
# network = tl.layers.Conv2dLayer(network,
# act = tf.nn.relu,
# shape = [3, 3, 64, 64], # 64 features for each 3x3 patch
# strides = [1, 1, 1, 1],
# padding='SAME',
# name ='vgg_conv1_2')
# network = tl.layers.PoolLayer(network,
# ksize=[1, 2, 2, 1],
# strides=[1, 2, 2, 1],
# padding='SAME',
# pool = tf.nn.max_pool,
# name ='vgg_pool1')
# """ conv2 """
# network = tl.layers.Conv2dLayer(network,
# act = tf.nn.relu,
# shape = [3, 3, 64, 128], # 128 features for each 3x3 patch
# strides = [1, 1, 1, 1],
# padding='SAME',
# name ='vgg_conv2_1')
# network = tl.layers.Conv2dLayer(network,
# act = tf.nn.relu,
# shape = [3, 3, 128, 128], # 128 features for each 3x3 patch
# strides = [1, 1, 1, 1],
# padding='SAME',
# name ='vgg_conv2_2')
# network = tl.layers.PoolLayer(network,
# ksize=[1, 2, 2, 1],
# strides=[1, 2, 2, 1],
# padding='SAME',
# pool = tf.nn.max_pool,
# name ='vgg_pool2')
# """ conv3 """
# network = tl.layers.Conv2dLayer(network,
# act = tf.nn.relu,
# shape = [3, 3, 128, 256], # 256 features for each 3x3 patch
# strides = [1, 1, 1, 1],
# padding='SAME',
# name ='vgg_conv3_1')
# network = tl.layers.Conv2dLayer(network,
# act = tf.nn.relu,
# shape = [3, 3, 256, 256], # 256 features for each 3x3 patch
# strides = [1, 1, 1, 1],
# padding='SAME',
# name ='vgg_conv3_2')
# network = tl.layers.Conv2dLayer(network,
# act = tf.nn.relu,
# shape = [3, 3, 256, 256], # 256 features for each 3x3 patch
# strides = [1, 1, 1, 1],
# padding='SAME',
# name ='vgg_conv3_3')
# network = tl.layers.PoolLayer(network,
# ksize=[1, 2, 2, 1],
# strides=[1, 2, 2, 1],
# padding='SAME',
# pool = tf.nn.max_pool,
# name ='vgg_pool3')
# """ conv4 """
# network = tl.layers.Conv2dLayer(network,
# act = tf.nn.relu,
# shape = [3, 3, 256, 512], # 512 features for each 3x3 patch
# strides = [1, 1, 1, 1],
# padding='SAME',
# name ='vgg_conv4_1')
# network = tl.layers.Conv2dLayer(network,
# act = tf.nn.relu,
# shape = [3, 3, 512, 512], # 512 features for each 3x3 patch
# strides = [1, 1, 1, 1],
# padding='SAME',
# name ='vgg_conv4_2')
# network = tl.layers.Conv2dLayer(network,
# act = tf.nn.relu,
# shape = [3, 3, 512, 512], # 512 features for each 3x3 patch
# strides = [1, 1, 1, 1],
# padding='SAME',
# name ='vgg_conv4_3')
#
# network = tl.layers.PoolLayer(network,
# ksize=[1, 2, 2, 1],
# strides=[1, 2, 2, 1],
# padding='SAME',
# pool = tf.nn.max_pool,
# name ='vgg_pool4')
# conv4 = network
#
# """ conv5 """
# network = tl.layers.Conv2dLayer(network,
# act = tf.nn.relu,
# shape = [3, 3, 512, 512], # 512 features for each 3x3 patch
# strides = [1, 1, 1, 1],
# padding='SAME',
# name ='vgg_conv5_1')
# network = tl.layers.Conv2dLayer(network,
# act = tf.nn.relu,
# shape = [3, 3, 512, 512], # 512 features for each 3x3 patch
# strides = [1, 1, 1, 1],
# padding='SAME',
# name ='vgg_conv5_2')
# network = tl.layers.Conv2dLayer(network,
# act = tf.nn.relu,
# shape = [3, 3, 512, 512], # 512 features for each 3x3 patch
# strides = [1, 1, 1, 1],
# padding='SAME',
# name ='vgg_conv5_3')
# network = tl.layers.PoolLayer(network,
# ksize=[1, 2, 2, 1],
# strides=[1, 2, 2, 1],
# padding='SAME',
# pool = tf.nn.max_pool,
# name ='vgg_pool5')
#
# network = FlattenLayer(network, name='vgg_flatten')
#
# # # network = DropoutLayer(network, keep=0.6, is_fix=True, is_train=is_train, name='vgg_out/drop1')
# # new_network = tl.layers.DenseLayer(network, n_units=4096,
# # act = tf.nn.relu,
# # name = 'vgg_out/dense')
# #
# # # new_network = DropoutLayer(new_network, keep=0.8, is_fix=True, is_train=is_train, name='vgg_out/drop2')
# # new_network = DenseLayer(new_network, z_dim, #num_lstm_units,
# # b_init=None, name='vgg_out/out')
# return conv4, network
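The generator's upsampling stage pairs each 256-channel convolution with SubpixelConv2d(scale=2), which trades channels for spatial resolution: an (H, W, C·r²) feature map becomes (H·r, W·r, C). The rearrangement can be sketched in plain NumPy. The helper name `pixel_shuffle` is hypothetical and not part of TensorLayer, and the channel ordering below assumes the same convention as `tf.nn.depth_to_space` for channels-last data.

```python
import numpy as np


def pixel_shuffle(x, scale):
    """Rearrange an (H, W, C * scale**2) array into (H * scale, W * scale, C).

    Hypothetical helper for illustration; the channel ordering is assumed
    to follow tf.nn.depth_to_space for channels-last tensors.
    """
    h, w, c_in = x.shape
    if c_in % (scale * scale) != 0:
        raise ValueError("channel count must be divisible by scale**2")
    c_out = c_in // (scale * scale)
    # Split the channel axis into (row offset, column offset, output channel).
    x = x.reshape(h, w, scale, scale, c_out)
    # Interleave the offsets with the spatial axes: (H, r, W, r, C).
    x = x.transpose(0, 2, 1, 3, 4)
    return x.reshape(h * scale, w * scale, c_out)


# One 96x96 feature map with 256 channels, as produced by 'n256s1/1'.
features = np.zeros((96, 96, 256), dtype=np.float32)
stage1 = pixel_shuffle(features, 2)
print(stage1.shape)  # (192, 192, 64)
```

Applying this step twice, with a convolution back to 256 channels in between, takes a 96x96 input to 384x384, which matches the 4x factor used throughout this project.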
utils.py
import tensorflow as tf
import tensorlayer as tl
from tensorlayer.prepro import *
# from config import config, log_config
#
# img_path = config.TRAIN.img_path

import scipy
import numpy as np
import os


def get_imgs_fn(file_name, path):
    """ Input an image path and name, return an image array """
    # return scipy.misc.imread(path + file_name).astype(np.float)
    return scipy.misc.imread(path + file_name, mode='RGB')


def crop_sub_imgs_fn(x, is_random=True):
    x = crop(x, wrg=384, hrg=384, is_random=is_random)
    x = x / (255. / 2.)
    x = x - 1.
    return x


def downsample_fn(x):
    # We obtained the LR images by downsampling the HR images using bicubic kernel with downsampling factor r = 4.
    x = imresize(x, size=[96, 96], interp='bicubic', mode=None)
    x = x / (255. / 2.)
    x = x - 1.
    return x
IV. Dataset
Download link:
DIV2K Dataset: https://data.vision.ee.ethz.ch/cvl/DIV2K/
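The loading code sorts the HR and LR file lists separately and then relies on them lining up index by index, so it is worth checking the pairing once after unpacking the archives. The sketch below does that with the standard library only. It assumes HR files are named like `0001.png` and their X4 counterparts like `0001x4.png`; that naming is an assumption about the unpacked archives, so adjust `lr_suffix` if your copy differs. `find_unpaired` is a hypothetical helper.

```python
import os


def find_unpaired(hr_dir, lr_dir, lr_suffix="x4"):
    """Return HR file stems that have no LR partner.

    Assumed naming (not confirmed by this article): an HR file '<stem>.png'
    is paired with an LR file '<stem><lr_suffix>.png'.
    """
    hr_stems = {os.path.splitext(f)[0] for f in os.listdir(hr_dir) if f.endswith(".png")}
    lr_stems = {os.path.splitext(f)[0] for f in os.listdir(lr_dir) if f.endswith(".png")}
    return sorted(stem for stem in hr_stems if stem + lr_suffix not in lr_stems)


if __name__ == "__main__":
    # Paths taken from the configuration shown earlier in this article.
    hr_dir = "./srdata/DIV2K_train_HR"
    lr_dir = "./srdata/DIV2K_train_LR_bicubic/X4"
    if os.path.isdir(hr_dir) and os.path.isdir(lr_dir):
        print("HR images without an LR partner:", find_unpaired(hr_dir, lr_dir))
    else:
        print("Dataset folders not found; download and unpack DIV2K first.")
```

An empty result means every HR image has a partner, so the two sorted lists can be indexed together safely.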
V. Testing the Network
def evaluate():
    ## create folders to save result images
    save_dir = "samples/{}".format(tl.global_flag['mode'])
    tl.files.exists_or_mkdir(save_dir)
    checkpoint_dir = "checkpoint"

    ###====================== PRE-LOAD DATA ===========================###
    # train_hr_img_list = sorted(tl.files.load_file_list(path=config.TRAIN.hr_img_path, regx='.*.png', printable=False))
    # train_lr_img_list = sorted(tl.files.load_file_list(path=config.TRAIN.lr_img_path, regx='.*.png', printable=False))
    valid_hr_img_list = sorted(tl.files.load_file_list(path=config.VALID.hr_img_path, regx='.*.png', printable=False))
    valid_lr_img_list = sorted(tl.files.load_file_list(path=config.VALID.lr_img_path, regx='.*.png', printable=False))

    ## If your machine has enough memory, please pre-load the whole train set.
    # train_hr_imgs = tl.vis.read_images(train_hr_img_list, path=config.TRAIN.hr_img_path, n_threads=32)
    # for im in train_hr_imgs:
    #     print(im.shape)
    valid_lr_imgs = tl.vis.read_images(valid_lr_img_list, path=config.VALID.lr_img_path, n_threads=8)
    # for im in valid_lr_imgs:
    #     print(im.shape)
    valid_hr_imgs = tl.vis.read_images(valid_hr_img_list, path=config.VALID.hr_img_path, n_threads=8)
    # for im in valid_hr_imgs:
    #     print(im.shape)
    # exit()

    ###========================== DEFINE MODEL ============================###
    imid = 64  # 0: penguin  81: butterfly  53: bird  64: castle
    valid_lr_img = valid_lr_imgs[imid]
    valid_hr_img = valid_hr_imgs[imid]
    # valid_lr_img = get_imgs_fn('test.png', 'data2017/')  # if you want to test your own image
    valid_lr_img = (valid_lr_img / 127.5) - 1  # rescale to [-1, 1]
    # print(valid_lr_img.min(), valid_lr_img.max())

    size = valid_lr_img.shape
    # t_image = tf.placeholder('float32', [None, size[0], size[1], size[2]], name='input_image')  # the old version of TL need to specify the image size
    t_image = tf.placeholder('float32', [1, None, None, 3], name='input_image')

    net_g = SRGAN_g(t_image, is_train=False, reuse=False)

    ###========================== RESTORE G =============================###
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False))
    tl.layers.initialize_global_variables(sess)
    tl.files.load_and_assign_npz(sess=sess, name=checkpoint_dir + '/g_srgan.npz', network=net_g)

    ###======================= EVALUATION =============================###
    start_time = time.time()
    out = sess.run(net_g.outputs, {t_image: [valid_lr_img]})
    print("took: %4.4fs" % (time.time() - start_time))

    print("LR size: %s / generated HR size: %s" % (size, out.shape))  # LR size: (339, 510, 3) / gen HR size: (1, 1356, 2040, 3)
    print("[*] save images")
    tl.vis.save_image(out[0], save_dir + '/valid_gen.png')
    tl.vis.save_image(valid_lr_img, save_dir + '/valid_lr.png')
    tl.vis.save_image(valid_hr_img, save_dir + '/valid_hr.png')

    out_bicu = scipy.misc.imresize(valid_lr_img, [size[0] * 4, size[1] * 4], interp='bicubic', mode=None)
    tl.vis.save_image(out_bicu, save_dir + '/valid_bicubic.png')
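Because the placeholder in evaluate() leaves height and width unspecified, the generator accepts LR images of any size and the output size follows directly from the two 2x upsampling stages. The arithmetic is small enough to check by hand; the sketch below does so for the size quoted in the comment inside evaluate(). `expected_sr_shape` is a hypothetical helper, not part of the original code.

```python
def expected_sr_shape(lr_shape, scale=4, batch=1):
    """Shape of the generator output for one LR image of shape (H, W, C).

    Hypothetical helper: scale=4 reflects the two 2x sub-pixel stages
    in SRGAN_g, and batch=1 matches the placeholder used in evaluate().
    """
    height, width, channels = lr_shape
    return (batch, height * scale, width * scale, channels)


lr_shape = (339, 510, 3)  # size quoted in the comment inside evaluate()
print(expected_sr_shape(lr_shape))  # (1, 1356, 2040, 3)
```

If the printed "generated HR size" differs from this value, the most likely cause is a checkpoint trained with a different generator definition than the one being restored.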
Low-resolution image:
Image after resizing (bicubic):
Image produced by the generator network:
High-resolution image: