Transfer Learning in Practice

Contents

    • Preprocessing
    • Building the Network Model
    • Training the Network
    • Validating the Network Model

Preprocessing

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from tensorflow.python.ops import control_flow_ops


def apply_with_random_selector(x, func, num_cases):
  """Computes func(x, sel), with sel sampled from [0...num_cases-1].

  Args:
    x: input Tensor.
    func: Python function to apply.
    num_cases: Python int32, number of cases to sample sel from.

  Returns:
    The result of func(x, sel), where func receives the value of the
    selector as a python integer, but sel is sampled dynamically.
  """
  sel = tf.random_uniform([], maxval=num_cases, dtype=tf.int32)
  # Pass the real x only to one of the func calls.
  return control_flow_ops.merge([
      func(control_flow_ops.switch(x, tf.equal(sel, case))[1], case)
      for case in range(num_cases)])[0]


def distort_color(image, color_ordering=0, fast_mode=True, scope=None):
  """Distort the color of a Tensor image.

  Each color distortion is non-commutative and thus ordering of the color ops
  matters. Ideally we would randomly permute the ordering of the color ops.
  Rather than adding that level of complication, we select a distinct ordering
  of color ops for each preprocessing thread.

  Args:
    image: 3-D Tensor containing single image in [0, 1].
    color_ordering: Python int, a type of distortion (valid values: 0-3).
    fast_mode: Avoids slower ops (random_hue and random_contrast)
    scope: Optional scope for name_scope.

  Returns:
    3-D Tensor color-distorted image on range [0, 1]

  Raises:
    ValueError: if color_ordering not in [0, 3]
  """
  with tf.name_scope(scope, 'distort_color', [image]):
    if fast_mode:
      if color_ordering == 0:
        image = tf.image.random_brightness(image, max_delta=32. / 255.)
        image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
      else:
        image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
        image = tf.image.random_brightness(image, max_delta=32. / 255.)
    else:
      if color_ordering == 0:
        image = tf.image.random_brightness(image, max_delta=32. / 255.)
        image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
        image = tf.image.random_hue(image, max_delta=0.2)
        image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
      elif color_ordering == 1:
        image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
        image = tf.image.random_brightness(image, max_delta=32. / 255.)
        image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
        image = tf.image.random_hue(image, max_delta=0.2)
      elif color_ordering == 2:
        image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
        image = tf.image.random_hue(image, max_delta=0.2)
        image = tf.image.random_brightness(image, max_delta=32. / 255.)
        image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
      elif color_ordering == 3:
        image = tf.image.random_hue(image, max_delta=0.2)
        image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
        image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
        image = tf.image.random_brightness(image, max_delta=32. / 255.)
      else:
        raise ValueError('color_ordering must be in [0, 3]')

    # The random_* ops do not necessarily clamp.
    return tf.clip_by_value(image, 0.0, 1.0)


def distorted_bounding_box_crop(image,
                                bbox,
                                min_object_covered=0.1,
                                aspect_ratio_range=(0.75, 1.33),
                                area_range=(0.05, 1.0),
                                max_attempts=100,
                                scope=None):
  """Generates cropped_image using one of the bboxes randomly distorted.

  See `tf.image.sample_distorted_bounding_box` for more documentation.

  Args:
    image: 3-D Tensor of image (it will be converted to floats in [0, 1]).
    bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords]
      where each coordinate is [0, 1) and the coordinates are arranged
      as [ymin, xmin, ymax, xmax]. If num_boxes is 0 then it would use the
      whole image.
    min_object_covered: An optional `float`. Defaults to `0.1`. The cropped
      area of the image must contain at least this fraction of any bounding
      box supplied.
    aspect_ratio_range: An optional list of `floats`. The cropped area of the
      image must have an aspect ratio = width / height within this range.
    area_range: An optional list of `floats`. The cropped area of the image
      must contain a fraction of the supplied image within this range.
    max_attempts: An optional `int`. Number of attempts at generating a
      cropped region of the image of the specified constraints. After
      `max_attempts` failures, return the entire image.
    scope: Optional scope for name_scope.

  Returns:
    A tuple, a 3-D Tensor cropped_image and the distorted bbox
  """
  with tf.name_scope(scope, 'distorted_bounding_box_crop', [image, bbox]):
    # Each bounding box has shape [1, num_boxes, box coords] and
    # the coordinates are ordered [ymin, xmin, ymax, xmax].

    # A large fraction of image datasets contain a human-annotated bounding
    # box delineating the region of the image containing the object of
    # interest. We choose to create a new bounding box for the object which
    # is a randomly distorted version of the human-annotated bounding box
    # that obeys an allowed range of aspect ratios, sizes and overlap with
    # the human-annotated bounding box. If no box is supplied, then we
    # assume the bounding box is the entire image.
    sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box(
        tf.shape(image),
        bounding_boxes=bbox,
        min_object_covered=min_object_covered,
        aspect_ratio_range=aspect_ratio_range,
        area_range=area_range,
        max_attempts=max_attempts,
        use_image_if_no_bounding_boxes=True)
    bbox_begin, bbox_size, distort_bbox = sample_distorted_bounding_box

    # Crop the image to the specified bounding box.
    cropped_image = tf.slice(image, bbox_begin, bbox_size)
    return cropped_image, distort_bbox


def preprocess_for_train(image, height, width, bbox,
                         fast_mode=True,
                         scope=None):
  """Distort one image for training a network.

  Distorting images provides a useful technique for augmenting the data
  set during training in order to make the network invariant to aspects
  of the image that do not affect the label.

  Additionally it would create image_summaries to display the different
  transformations applied to the image.

  Args:
    image: 3-D Tensor of image. If dtype is tf.float32 then the range should
      be [0, 1], otherwise it would be converted to tf.float32 assuming that
      the range is [0, MAX], where MAX is the largest positive representable
      number for int(8/16/32) data type (see `tf.image.convert_image_dtype`
      for details).
    height: integer
    width: integer
    bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords]
      where each coordinate is [0, 1) and the coordinates are arranged
      as [ymin, xmin, ymax, xmax].
    fast_mode: Optional boolean, if True avoids slower transformations (i.e.
      bi-cubic resizing, random_hue or random_contrast).
    scope: Optional scope for name_scope.

  Returns:
    3-D float Tensor of distorted image used for training with range [-1, 1].
  """
  with tf.name_scope(scope, 'distort_image', [image, height, width, bbox]):
    if bbox is None:
      bbox = tf.constant([0.0, 0.0, 1.0, 1.0],
                         dtype=tf.float32,
                         shape=[1, 1, 4])
    if image.dtype != tf.float32:
      image = tf.image.convert_image_dtype(image, dtype=tf.float32)
    # Each bounding box has shape [1, num_boxes, box coords] and
    # the coordinates are ordered [ymin, xmin, ymax, xmax].
    image_with_box = tf.image.draw_bounding_boxes(tf.expand_dims(image, 0),
                                                  bbox)
    tf.summary.image('image_with_bounding_boxes', image_with_box)

    distorted_image, distorted_bbox = distorted_bounding_box_crop(image, bbox)
    # Restore the shape since the dynamic slice based upon the bbox_size
    # loses the third dimension.
    distorted_image.set_shape([None, None, 3])
    image_with_distorted_box = tf.image.draw_bounding_boxes(
        tf.expand_dims(image, 0), distorted_bbox)
    tf.summary.image('images_with_distorted_bounding_box',
                     image_with_distorted_box)

    # This resizing operation may distort the images because the aspect
    # ratio is not respected. We select a resize method in a round robin
    # fashion based on the thread number.
    # Note that ResizeMethod contains 4 enumerated resizing methods.

    # We select only 1 case for fast_mode bilinear.
    num_resize_cases = 1 if fast_mode else 4
    distorted_image = apply_with_random_selector(
        distorted_image,
        lambda x, method: tf.image.resize_images(x, [height, width],
                                                 method=method),
        num_cases=num_resize_cases)

    tf.summary.image('cropped_resized_image',
                     tf.expand_dims(distorted_image, 0))

    # Randomly flip the image horizontally.
    distorted_image = tf.image.random_flip_left_right(distorted_image)

    # Randomly distort the colors. There are 4 ways to do it.
    distorted_image = apply_with_random_selector(
        distorted_image,
        lambda x, ordering: distort_color(x, ordering, fast_mode),
        num_cases=4)

    tf.summary.image('final_distorted_image',
                     tf.expand_dims(distorted_image, 0))
    distorted_image = tf.subtract(distorted_image, 0.5)
    distorted_image = tf.multiply(distorted_image, 2.0)
    return distorted_image


def preprocess_for_eval(image, height, width,
                        central_fraction=0.875, scope=None):
  """Prepare one image for evaluation.

  If height and width are specified it would output an image with that size
  by applying resize_bilinear.

  If central_fraction is specified it would crop the central fraction of the
  input image.

  Args:
    image: 3-D Tensor of image. If dtype is tf.float32 then the range should
      be [0, 1], otherwise it would be converted to tf.float32 assuming that
      the range is [0, MAX], where MAX is the largest positive representable
      number for int(8/16/32) data type (see `tf.image.convert_image_dtype`
      for details)
    height: integer
    width: integer
    central_fraction: Optional Float, fraction of the image to crop.
    scope: Optional scope for name_scope.

  Returns:
    3-D float Tensor of prepared image.
  """
  with tf.name_scope(scope, 'eval_image', [image, height, width]):
    if image.dtype != tf.float32:
      image = tf.image.convert_image_dtype(image, dtype=tf.float32)
    # Crop the central region of the image with an area containing 87.5% of
    # the original image.
    if central_fraction:
      image = tf.image.central_crop(image, central_fraction=central_fraction)

    if height and width:
      # Resize the image to the specified height and width.
      image = tf.expand_dims(image, 0)
      image = tf.image.resize_bilinear(image, [height, width],
                                       align_corners=False)
      image = tf.squeeze(image, [0])
    image = tf.subtract(image, 0.5)
    image = tf.multiply(image, 2.0)
    return image


def preprocess_image(image, height, width,
                     is_training=False,
                     bbox=None,
                     fast_mode=True):
  """Pre-process one image for training or evaluation.

  Args:
    image: 3-D Tensor [height, width, channels] with the image.
    height: integer, image expected height.
    width: integer, image expected width.
    is_training: Boolean. If true it would transform an image for train,
      otherwise it would transform it for evaluation.
    bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords]
      where each coordinate is [0, 1) and the coordinates are arranged as
      [ymin, xmin, ymax, xmax].
    fast_mode: Optional boolean, if True avoids slower transformations.

  Returns:
    3-D float Tensor containing an appropriately scaled image

  Raises:
    ValueError: if user does not provide bounding box
  """
  if is_training:
    return preprocess_for_train(image, height, width, bbox, fast_mode)
  else:
    return preprocess_for_eval(image, height, width)
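To see how the module is meant to be called, here is a minimal, hypothetical usage sketch. The file name example.jpg and the single-image session run are assumptions for illustration only; the module itself is assumed to be saved as inception_preprocessing.py, which is the name it is imported under in the training script below.

import tensorflow as tf

import inception_preprocessing

# Hypothetical input: decode one local JPEG into a [height, width, 3] uint8 tensor.
image_raw = tf.read_file('example.jpg')
image = tf.image.decode_jpeg(image_raw, channels=3)

# Training preprocessing: random crop, flip and color distortion, output scaled to [-1, 1].
train_image = inception_preprocessing.preprocess_image(
    image, height=299, width=299, is_training=True)

# Evaluation preprocessing: central crop and bilinear resize, output scaled to [-1, 1].
eval_image = inception_preprocessing.preprocess_image(
    image, height=299, width=299, is_training=False)

with tf.Session() as sess:
    print(sess.run(train_image).shape)   # (299, 299, 3)
    print(sess.run(eval_image).shape)    # (299, 299, 3)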

Building the Network Model

The model combines the Inception architecture with ResNet-style residual connections: each Inception block's output is projected back to the input depth and added to the block input as a scaled residual. A minimal sketch of this pattern is shown below, before the full model definition.
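The following sketch is illustrative only: the function name residual_inception_block, the branch widths, and the scale value are assumptions, not taken from the real model, but the structure mirrors what block35, block17 and block8 do in the code that follows.

import tensorflow as tf

slim = tf.contrib.slim

def residual_inception_block(net, scale=0.17):
    """Illustrative sketch of the residual-Inception pattern used below."""
    with tf.variable_scope(None, 'SketchBlock', [net]):
        # Parallel Inception-style branches (widths chosen arbitrarily here).
        branch_0 = slim.conv2d(net, 32, 1, scope='Conv2d_1x1')
        branch_1 = slim.conv2d(net, 32, 3, scope='Conv2d_3x3')
        mixed = tf.concat(axis=3, values=[branch_0, branch_1])
        # 1x1 projection back to the input depth, without activation or normalization.
        up = slim.conv2d(mixed, net.get_shape()[3], 1,
                         normalizer_fn=None, activation_fn=None,
                         scope='Conv2d_proj')
        # ResNet-style shortcut: add a scaled residual to the block input.
        net += scale * up
        return tf.nn.relu(net)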

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

slim = tf.contrib.slim


def block35(net, scale=1.0, activation_fn=tf.nn.relu, scope=None, reuse=None):
  """Builds the 35x35 resnet block."""
  with tf.variable_scope(scope, 'Block35', [net], reuse=reuse):
    with tf.variable_scope('Branch_0'):
      tower_conv = slim.conv2d(net, 32, 1, scope='Conv2d_1x1')
    with tf.variable_scope('Branch_1'):
      tower_conv1_0 = slim.conv2d(net, 32, 1, scope='Conv2d_0a_1x1')
      tower_conv1_1 = slim.conv2d(tower_conv1_0, 32, 3, scope='Conv2d_0b_3x3')
    with tf.variable_scope('Branch_2'):
      tower_conv2_0 = slim.conv2d(net, 32, 1, scope='Conv2d_0a_1x1')
      tower_conv2_1 = slim.conv2d(tower_conv2_0, 48, 3, scope='Conv2d_0b_3x3')
      tower_conv2_2 = slim.conv2d(tower_conv2_1, 64, 3, scope='Conv2d_0c_3x3')
    mixed = tf.concat(axis=3, values=[tower_conv, tower_conv1_1, tower_conv2_2])
    up = slim.conv2d(mixed, net.get_shape()[3], 1, normalizer_fn=None,
                     activation_fn=None, scope='Conv2d_1x1')
    net += scale * up
    if activation_fn:
      net = activation_fn(net)
  return net


def block17(net, scale=1.0, activation_fn=tf.nn.relu, scope=None, reuse=None):
  """Builds the 17x17 resnet block."""
  with tf.variable_scope(scope, 'Block17', [net], reuse=reuse):
    with tf.variable_scope('Branch_0'):
      tower_conv = slim.conv2d(net, 192, 1, scope='Conv2d_1x1')
    with tf.variable_scope('Branch_1'):
      tower_conv1_0 = slim.conv2d(net, 128, 1, scope='Conv2d_0a_1x1')
      tower_conv1_1 = slim.conv2d(tower_conv1_0, 160, [1, 7],
                                  scope='Conv2d_0b_1x7')
      tower_conv1_2 = slim.conv2d(tower_conv1_1, 192, [7, 1],
                                  scope='Conv2d_0c_7x1')
    mixed = tf.concat(axis=3, values=[tower_conv, tower_conv1_2])
    up = slim.conv2d(mixed, net.get_shape()[3], 1, normalizer_fn=None,
                     activation_fn=None, scope='Conv2d_1x1')
    net += scale * up
    if activation_fn:
      net = activation_fn(net)
  return net


def block8(net, scale=1.0, activation_fn=tf.nn.relu, scope=None, reuse=None):
  """Builds the 8x8 resnet block."""
  with tf.variable_scope(scope, 'Block8', [net], reuse=reuse):
    with tf.variable_scope('Branch_0'):
      tower_conv = slim.conv2d(net, 192, 1, scope='Conv2d_1x1')
    with tf.variable_scope('Branch_1'):
      tower_conv1_0 = slim.conv2d(net, 192, 1, scope='Conv2d_0a_1x1')
      tower_conv1_1 = slim.conv2d(tower_conv1_0, 224, [1, 3],
                                  scope='Conv2d_0b_1x3')
      tower_conv1_2 = slim.conv2d(tower_conv1_1, 256, [3, 1],
                                  scope='Conv2d_0c_3x1')
    mixed = tf.concat(axis=3, values=[tower_conv, tower_conv1_2])
    up = slim.conv2d(mixed, net.get_shape()[3], 1, normalizer_fn=None,
                     activation_fn=None, scope='Conv2d_1x1')
    net += scale * up
    if activation_fn:
      net = activation_fn(net)
  return net


def inception_resnet_v2(inputs, num_classes=1001, is_training=True,
                        dropout_keep_prob=0.8,
                        reuse=None,
                        scope='InceptionResnetV2'):
  """Creates the Inception Resnet V2 model.

  Args:
    inputs: a 4-D tensor of size [batch_size, height, width, 3].
    num_classes: number of predicted classes.
    is_training: whether is training or not.
    dropout_keep_prob: float, the fraction to keep before final layer.
    reuse: whether or not the network and its variables should be reused. To
      be able to reuse 'scope' must be given.
    scope: Optional variable_scope.

  Returns:
    logits: the logits outputs of the model.
    end_points: the set of end_points from the inception model.
  """
  end_points = {}

  with tf.variable_scope(scope, 'InceptionResnetV2', [inputs], reuse=reuse):
    with slim.arg_scope([slim.batch_norm, slim.dropout],
                        is_training=is_training):
      with slim.arg_scope([slim.conv2d, slim.max_pool2d, slim.avg_pool2d],
                          stride=1, padding='SAME'):

        # 149 x 149 x 32
        net = slim.conv2d(inputs, 32, 3, stride=2, padding='VALID',
                          scope='Conv2d_1a_3x3')
        end_points['Conv2d_1a_3x3'] = net
        # 147 x 147 x 32
        net = slim.conv2d(net, 32, 3, padding='VALID',
                          scope='Conv2d_2a_3x3')
        end_points['Conv2d_2a_3x3'] = net
        # 147 x 147 x 64
        net = slim.conv2d(net, 64, 3, scope='Conv2d_2b_3x3')
        end_points['Conv2d_2b_3x3'] = net
        # 73 x 73 x 64
        net = slim.max_pool2d(net, 3, stride=2, padding='VALID',
                              scope='MaxPool_3a_3x3')
        end_points['MaxPool_3a_3x3'] = net
        # 73 x 73 x 80
        net = slim.conv2d(net, 80, 1, padding='VALID',
                          scope='Conv2d_3b_1x1')
        end_points['Conv2d_3b_1x1'] = net
        # 71 x 71 x 192
        net = slim.conv2d(net, 192, 3, padding='VALID',
                          scope='Conv2d_4a_3x3')
        end_points['Conv2d_4a_3x3'] = net
        # 35 x 35 x 192
        net = slim.max_pool2d(net, 3, stride=2, padding='VALID',
                              scope='MaxPool_5a_3x3')
        end_points['MaxPool_5a_3x3'] = net

        # 35 x 35 x 320
        with tf.variable_scope('Mixed_5b'):
          with tf.variable_scope('Branch_0'):
            tower_conv = slim.conv2d(net, 96, 1, scope='Conv2d_1x1')
          with tf.variable_scope('Branch_1'):
            tower_conv1_0 = slim.conv2d(net, 48, 1, scope='Conv2d_0a_1x1')
            tower_conv1_1 = slim.conv2d(tower_conv1_0, 64, 5,
                                        scope='Conv2d_0b_5x5')
          with tf.variable_scope('Branch_2'):
            tower_conv2_0 = slim.conv2d(net, 64, 1, scope='Conv2d_0a_1x1')
            tower_conv2_1 = slim.conv2d(tower_conv2_0, 96, 3,
                                        scope='Conv2d_0b_3x3')
            tower_conv2_2 = slim.conv2d(tower_conv2_1, 96, 3,
                                        scope='Conv2d_0c_3x3')
          with tf.variable_scope('Branch_3'):
            tower_pool = slim.avg_pool2d(net, 3, stride=1, padding='SAME',
                                         scope='AvgPool_0a_3x3')
            tower_pool_1 = slim.conv2d(tower_pool, 64, 1,
                                       scope='Conv2d_0b_1x1')
          net = tf.concat(axis=3, values=[tower_conv, tower_conv1_1,
                                          tower_conv2_2, tower_pool_1])

        end_points['Mixed_5b'] = net
        net = slim.repeat(net, 10, block35, scale=0.17)

        # 17 x 17 x 1024
        with tf.variable_scope('Mixed_6a'):
          with tf.variable_scope('Branch_0'):
            tower_conv = slim.conv2d(net, 384, 3, stride=2, padding='VALID',
                                     scope='Conv2d_1a_3x3')
          with tf.variable_scope('Branch_1'):
            tower_conv1_0 = slim.conv2d(net, 256, 1, scope='Conv2d_0a_1x1')
            tower_conv1_1 = slim.conv2d(tower_conv1_0, 256, 3,
                                        scope='Conv2d_0b_3x3')
            tower_conv1_2 = slim.conv2d(tower_conv1_1, 384, 3,
                                        stride=2, padding='VALID',
                                        scope='Conv2d_1a_3x3')
          with tf.variable_scope('Branch_2'):
            tower_pool = slim.max_pool2d(net, 3, stride=2, padding='VALID',
                                         scope='MaxPool_1a_3x3')
          net = tf.concat(axis=3, values=[tower_conv, tower_conv1_2,
                                          tower_pool])

        end_points['Mixed_6a'] = net
        net = slim.repeat(net, 20, block17, scale=0.10)

        # Auxiliary tower
        with tf.variable_scope('AuxLogits'):
          aux = slim.avg_pool2d(net, 5, stride=3, padding='VALID',
                                scope='Conv2d_1a_3x3')
          aux = slim.conv2d(aux, 128, 1, scope='Conv2d_1b_1x1')
          aux = slim.conv2d(aux, 768, aux.get_shape()[1:3],
                            padding='VALID', scope='Conv2d_2a_5x5')
          aux = slim.flatten(aux)
          aux = slim.fully_connected(aux, num_classes, activation_fn=None,
                                     scope='Logits')
          end_points['AuxLogits'] = aux

        with tf.variable_scope('Mixed_7a'):
          with tf.variable_scope('Branch_0'):
            tower_conv = slim.conv2d(net, 256, 1, scope='Conv2d_0a_1x1')
            tower_conv_1 = slim.conv2d(tower_conv, 384, 3, stride=2,
                                       padding='VALID', scope='Conv2d_1a_3x3')
          with tf.variable_scope('Branch_1'):
            tower_conv1 = slim.conv2d(net, 256, 1, scope='Conv2d_0a_1x1')
            tower_conv1_1 = slim.conv2d(tower_conv1, 288, 3, stride=2,
                                        padding='VALID', scope='Conv2d_1a_3x3')
          with tf.variable_scope('Branch_2'):
            tower_conv2 = slim.conv2d(net, 256, 1, scope='Conv2d_0a_1x1')
            tower_conv2_1 = slim.conv2d(tower_conv2, 288, 3,
                                        scope='Conv2d_0b_3x3')
            tower_conv2_2 = slim.conv2d(tower_conv2_1, 320, 3, stride=2,
                                        padding='VALID', scope='Conv2d_1a_3x3')
          with tf.variable_scope('Branch_3'):
            tower_pool = slim.max_pool2d(net, 3, stride=2, padding='VALID',
                                         scope='MaxPool_1a_3x3')
          net = tf.concat(axis=3, values=[tower_conv_1, tower_conv1_1,
                                          tower_conv2_2, tower_pool])

        end_points['Mixed_7a'] = net

        net = slim.repeat(net, 9, block8, scale=0.20)
        net = block8(net, activation_fn=None)

        net = slim.conv2d(net, 1536, 1, scope='Conv2d_7b_1x1')
        end_points['Conv2d_7b_1x1'] = net

        with tf.variable_scope('Logits'):
          end_points['PrePool'] = net
          net = slim.avg_pool2d(net, net.get_shape()[1:3], padding='VALID',
                                scope='AvgPool_1a_8x8')
          net = slim.flatten(net)

          net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
                             scope='Dropout')

          end_points['PreLogitsFlatten'] = net
          logits = slim.fully_connected(net, num_classes, activation_fn=None,
                                        scope='Logits')
          end_points['Logits'] = logits
          end_points['Predictions'] = tf.nn.softmax(logits, name='Predictions')

    return logits, end_points
inception_resnet_v2.default_image_size = 299


def inception_resnet_v2_arg_scope(weight_decay=0.00004,
                                  batch_norm_decay=0.9997,
                                  batch_norm_epsilon=0.001):
  """Yields the scope with the default parameters for inception_resnet_v2.

  Args:
    weight_decay: the weight decay for weights variables.
    batch_norm_decay: decay for the moving average of batch_norm momentums.
    batch_norm_epsilon: small float added to variance to avoid dividing by
      zero.

  Returns:
    an arg_scope with the parameters needed for inception_resnet_v2.
  """
  # Set weight_decay for weights in conv2d and fully_connected layers.
  with slim.arg_scope([slim.conv2d, slim.fully_connected],
                      weights_regularizer=slim.l2_regularizer(weight_decay),
                      biases_regularizer=slim.l2_regularizer(weight_decay)):
    batch_norm_params = {
        'decay': batch_norm_decay,
        'epsilon': batch_norm_epsilon,
    }
    # Set activation_fn and parameters for batch_norm.
    with slim.arg_scope([slim.conv2d], activation_fn=tf.nn.relu,
                        normalizer_fn=slim.batch_norm,
                        normalizer_params=batch_norm_params) as scope:
      return scope
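As a usage sketch (not part of the original file): the arg_scope and the model function are combined like this when fine-tuning on the five flower classes, restoring every pretrained variable except the two classification heads that will be retrained. This mirrors what the training script below does; the placeholder input is just for illustration.

import tensorflow as tf
from inception_resnet_v2 import inception_resnet_v2, inception_resnet_v2_arg_scope

slim = tf.contrib.slim

# Batch of preprocessed images in [-1, 1], sized to the model's default 299x299.
images = tf.placeholder(tf.float32, [None, 299, 299, 3])

with slim.arg_scope(inception_resnet_v2_arg_scope()):
    logits, end_points = inception_resnet_v2(images, num_classes=5, is_training=False)

# Restore the pretrained ImageNet weights, excluding the heads we retrain from scratch.
exclude = ['InceptionResnetV2/Logits', 'InceptionResnetV2/AuxLogits']
variables_to_restore = slim.get_variables_to_restore(exclude=exclude)
saver = tf.train.Saver(variables_to_restore)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.restore(sess, './inception_resnet_v2_2016_08_30.ckpt')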

Training the Network

import tensorflow as tf
from tensorflow.contrib.framework.python.ops.variables import get_or_create_global_step
from tensorflow.python.platform import tf_logging as logging
import inception_preprocessing
from inception_resnet_v2 import inception_resnet_v2, inception_resnet_v2_arg_scope
import os
import time
slim = tf.contrib.slim

#================ DATASET INFORMATION ======================
#State dataset directory where the tfrecord files are located
dataset_dir = '.'

#State where your log file is at. If it doesn't exist, create it.
log_dir = './log'

#State where your checkpoint file is
checkpoint_file = './inception_resnet_v2_2016_08_30.ckpt'

#State the image size you're resizing your images to. We will use the default inception size of 299.
image_size = 299

#State the number of classes to predict:
num_classes = 5

#State the labels file and read it
labels_file = './labels.txt'
labels = open(labels_file, 'r')

#Create a dictionary to refer each label to their string name
labels_to_name = {}
for line in labels:
    label, string_name = line.split(':')
    string_name = string_name[:-1] #Remove newline
    labels_to_name[int(label)] = string_name

#Create the file pattern of your TFRecord files so that it could be recognized later on
file_pattern = 'flowers_%s_*.tfrecord'

#Create a dictionary that will help people understand your dataset better. This is required by the Dataset class later.
items_to_descriptions = {
    'image': 'A 3-channel RGB coloured flower image that is either tulips, sunflowers, roses, dandelion, or daisy.',
    'label': 'A label that is as such -- 0:daisy, 1:dandelion, 2:roses, 3:sunflowers, 4:tulips'
}

#================= TRAINING INFORMATION ==================
#State the number of epochs to train
num_epochs = 1

#State your batch size
batch_size = 8

#Learning rate information and configuration (Up to you to experiment)
initial_learning_rate = 0.0002
learning_rate_decay_factor = 0.7
num_epochs_before_decay = 2

#============== DATASET LOADING ======================
#We now create a function that creates a Dataset class which will give us many TFRecord files to feed in the examples into a queue in parallel.
def get_split(split_name, dataset_dir, file_pattern=file_pattern, file_pattern_for_counting='flowers'):
    '''
    Obtains the split - training or validation - to create a Dataset class for feeding the examples into a queue later on. This function will
    set up the decoder and dataset information all into one Dataset class so that you can avoid the brute work later on.
    Your file_pattern is very important in locating the files later.

    INPUTS:
    - split_name(str): 'train' or 'validation'. Used to get the correct data split of tfrecord files
    - dataset_dir(str): the dataset directory where the tfrecord files are located
    - file_pattern(str): the file name structure of the tfrecord files in order to get the correct data
    - file_pattern_for_counting(str): the string name to identify your tfrecord files for counting

    OUTPUTS:
    - dataset (Dataset): A Dataset class object where we can read its various components for easier batch creation later.
    '''
    #First check whether the split_name is train or validation
    if split_name not in ['train', 'validation']:
        raise ValueError('The split_name %s is not recognized. Please input either train or validation as the split_name' % (split_name))

    #Create the full path for a general file_pattern to locate the tfrecord_files
    file_pattern_path = os.path.join(dataset_dir, file_pattern % (split_name))

    #Count the total number of examples in all of these shards
    num_samples = 0
    file_pattern_for_counting = file_pattern_for_counting + '_' + split_name
    tfrecords_to_count = [os.path.join(dataset_dir, file) for file in os.listdir(dataset_dir) if file.startswith(file_pattern_for_counting)]
    for tfrecord_file in tfrecords_to_count:
        for record in tf.python_io.tf_record_iterator(tfrecord_file):
            num_samples += 1

    #Create a reader, which must be a TFRecord reader in this case
    reader = tf.TFRecordReader

    #Create the keys_to_features dictionary for the decoder
    keys_to_features = {
        'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
        'image/format': tf.FixedLenFeature((), tf.string, default_value='jpg'),
        'image/class/label': tf.FixedLenFeature([], tf.int64, default_value=tf.zeros([], dtype=tf.int64)),
    }

    #Create the items_to_handlers dictionary for the decoder.
    items_to_handlers = {
        'image': slim.tfexample_decoder.Image(),
        'label': slim.tfexample_decoder.Tensor('image/class/label'),
    }

    #Start to create the decoder
    decoder = slim.tfexample_decoder.TFExampleDecoder(keys_to_features, items_to_handlers)

    #Create the labels_to_name dictionary
    labels_to_name_dict = labels_to_name

    #Actually create the dataset
    dataset = slim.dataset.Dataset(
        data_sources = file_pattern_path,
        decoder = decoder,
        reader = reader,
        num_readers = 4,
        num_samples = num_samples,
        num_classes = num_classes,
        labels_to_name = labels_to_name_dict,
        items_to_descriptions = items_to_descriptions)

    return dataset


def load_batch(dataset, batch_size, height=image_size, width=image_size, is_training=True):
    '''
    Loads a batch for training.

    INPUTS:
    - dataset(Dataset): a Dataset class object that is created from the get_split function
    - batch_size(int): determines how big of a batch to train
    - height(int): the height of the image to resize to during preprocessing
    - width(int): the width of the image to resize to during preprocessing
    - is_training(bool): to determine whether to perform a training or evaluation preprocessing

    OUTPUTS:
    - images(Tensor): a Tensor of the shape (batch_size, height, width, channels) that contain one batch of images
    - labels(Tensor): the batch's labels with the shape (batch_size,) (requires one_hot_encoding).
    '''
    #First create the data_provider object
    data_provider = slim.dataset_data_provider.DatasetDataProvider(
        dataset,
        common_queue_capacity = 24 + 3 * batch_size,
        common_queue_min = 24)

    #Obtain the raw image using the get method
    raw_image, label = data_provider.get(['image', 'label'])

    #Perform the correct preprocessing for this image depending if it is training or evaluating
    image = inception_preprocessing.preprocess_image(raw_image, height, width, is_training)

    #As for the raw images, we just do a simple reshape to batch it up
    raw_image = tf.expand_dims(raw_image, 0)
    raw_image = tf.image.resize_nearest_neighbor(raw_image, [height, width])
    raw_image = tf.squeeze(raw_image)

    #Batch up the image by enqueueing the tensors internally in a FIFO queue and dequeueing many elements with tf.train.batch.
    images, raw_images, labels = tf.train.batch(
        [image, raw_image, label],
        batch_size = batch_size,
        num_threads = 4,
        capacity = 4 * batch_size,
        allow_smaller_final_batch = True)

    return images, raw_images, labels


def run():
    #Create the log directory here. Must be done here otherwise import will activate this unneededly.
    if not os.path.exists(log_dir):
        os.mkdir(log_dir)

    #======================= TRAINING PROCESS =========================
    #Now we start to construct the graph and build our model
    with tf.Graph().as_default() as graph:
        tf.logging.set_verbosity(tf.logging.INFO) #Set the verbosity to INFO level

        #First create the dataset and load one batch
        dataset = get_split('train', dataset_dir, file_pattern=file_pattern)
        images, _, labels = load_batch(dataset, batch_size=batch_size)

        #Know the number of steps to take before decaying the learning rate and batches per epoch
        num_batches_per_epoch = int(dataset.num_samples / batch_size)
        num_steps_per_epoch = num_batches_per_epoch #Because one step is one batch processed
        decay_steps = int(num_epochs_before_decay * num_steps_per_epoch)

        #Create the model inference
        with slim.arg_scope(inception_resnet_v2_arg_scope()):
            logits, end_points = inception_resnet_v2(images, num_classes = dataset.num_classes, is_training = True)

        #Define the scopes that you want to exclude for restoration
        exclude = ['InceptionResnetV2/Logits', 'InceptionResnetV2/AuxLogits']
        variables_to_restore = slim.get_variables_to_restore(exclude = exclude)

        #Perform one-hot-encoding of the labels (Try one-hot-encoding within the load_batch function!)
        one_hot_labels = slim.one_hot_encoding(labels, dataset.num_classes)

        #Performs the equivalent to tf.nn.sparse_softmax_cross_entropy_with_logits but enhanced with checks
        loss = tf.losses.softmax_cross_entropy(onehot_labels = one_hot_labels, logits = logits)
        total_loss = tf.losses.get_total_loss()    #obtain the regularization losses as well

        #Create the global step for monitoring the learning_rate and training.
        global_step = get_or_create_global_step()

        #Define your exponentially decaying learning rate
        lr = tf.train.exponential_decay(
            learning_rate = initial_learning_rate,
            global_step = global_step,
            decay_steps = decay_steps,
            decay_rate = learning_rate_decay_factor,
            staircase = True)

        #Now we can define the optimizer that takes on the learning rate
        optimizer = tf.train.AdamOptimizer(learning_rate = lr)

        #Create the train_op.
        train_op = slim.learning.create_train_op(total_loss, optimizer)

        #State the metrics that you want to predict. We get predictions that are not one_hot_encoded.
        predictions = tf.argmax(end_points['Predictions'], 1)
        probabilities = end_points['Predictions']
        accuracy, accuracy_update = tf.contrib.metrics.streaming_accuracy(predictions, labels)
        metrics_op = tf.group(accuracy_update, probabilities)

        #Now finally create all the summaries you need to monitor and group them into one summary op.
        tf.summary.scalar('losses/Total_Loss', total_loss)
        tf.summary.scalar('accuracy', accuracy)
        tf.summary.scalar('learning_rate', lr)
        my_summary_op = tf.summary.merge_all()

        #Now we need to create a training step function that runs both the train_op, metrics_op and updates the global_step concurrently.
        def train_step(sess, train_op, global_step):
            '''
            Simply runs a session for the three arguments provided and gives a logging on the time elapsed for each global step
            '''
            #Check the time for each sess run
            start_time = time.time()
            total_loss, global_step_count, _ = sess.run([train_op, global_step, metrics_op])
            time_elapsed = time.time() - start_time

            #Run the logging to print some results
            logging.info('global step %s: loss: %.4f (%.2f sec/step)', global_step_count, total_loss, time_elapsed)

            return total_loss, global_step_count

        #Now we create a saver function that actually restores the variables from a checkpoint file in a session
        saver = tf.train.Saver(variables_to_restore)

        def restore_fn(sess):
            return saver.restore(sess, checkpoint_file)

        #Define your supervisor for running a managed session. Do not run the summary_op automatically or else it will consume too much memory
        sv = tf.train.Supervisor(logdir = log_dir, summary_op = None, init_fn = restore_fn)

        #Run the managed session
        with sv.managed_session() as sess:
            for step in xrange(num_steps_per_epoch * num_epochs):
                #At the start of every epoch, show the vital information:
                if step % num_batches_per_epoch == 0:
                    logging.info('Epoch %s/%s', step/num_batches_per_epoch + 1, num_epochs)
                    learning_rate_value, accuracy_value = sess.run([lr, accuracy])
                    logging.info('Current Learning Rate: %s', learning_rate_value)
                    logging.info('Current Streaming Accuracy: %s', accuracy_value)

                    # optionally, print your logits and predictions for a sanity check that things are going fine.
                    logits_value, probabilities_value, predictions_value, labels_value = sess.run([logits, probabilities, predictions, labels])
                    print 'logits: \n', logits_value
                    print 'Probabilities: \n', probabilities_value
                    print 'predictions: \n', predictions_value
                    print 'Labels:\n:', labels_value

                #Log the summaries every 10 steps.
                if step % 10 == 0:
                    loss, _ = train_step(sess, train_op, sv.global_step)
                    summaries = sess.run(my_summary_op)
                    sv.summary_computed(sess, summaries)

                #If not, simply run the training step
                else:
                    loss, _ = train_step(sess, train_op, sv.global_step)

            #We log the final training loss and accuracy
            logging.info('Final Loss: %s', loss)
            logging.info('Final Accuracy: %s', sess.run(accuracy))

            #Once all the training has been done, save the log files and checkpoint model
            logging.info('Finished training! Saving model to disk now.')
            # saver.save(sess, "./flowers_model.ckpt")
            sv.saver.save(sess, sv.save_path, global_step = sv.global_step)


if __name__ == '__main__':
    run()
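The script expects labels.txt to map each integer label to its class name, one label:name pair per line. Given the class list in items_to_descriptions, a plausible file would look like this (assumed content, shown for illustration):

0:daisy
1:dandelion
2:roses
3:sunflowers
4:tulips

With the TFRecord shards and labels.txt in place, training is started by running the script directly (the validation code below imports it as train_flowers, so it is presumably saved as train_flowers.py), and the summaries written to ./log can be watched with tensorboard --logdir=./log.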

Validating the Network Model

import tensorflow as tf
from tensorflow.python.platform import tf_logging as logging
from tensorflow.contrib.framework.python.ops.variables import get_or_create_global_step
import inception_preprocessing
from inception_resnet_v2 import inception_resnet_v2, inception_resnet_v2_arg_scope
import time
import os
from train_flowers import get_split, load_batch
import matplotlib.pyplot as plt
plt.style.use('ggplot')
slim = tf.contrib.slim

#State your log directory where you can retrieve your model
log_dir = './log'

#Create a new evaluation log directory to visualize the validation process
log_eval = './log_eval_test'

#State the dataset directory where the validation set is found
dataset_dir = '.'

#State the batch_size to evaluate each time, which can be a lot more than the training batch
batch_size = 36

#State the number of epochs to evaluate
num_epochs = 3

#Get the latest checkpoint file
checkpoint_file = tf.train.latest_checkpoint(log_dir)


def run():
    #Create log_dir for evaluation information
    if not os.path.exists(log_eval):
        os.mkdir(log_eval)

    #Just construct the graph from scratch again
    with tf.Graph().as_default() as graph:
        tf.logging.set_verbosity(tf.logging.INFO)

        #Get the dataset first and load one batch of validation images and labels tensors. Set is_training as False so as to use the evaluation preprocessing
        dataset = get_split('validation', dataset_dir)
        images, raw_images, labels = load_batch(dataset, batch_size = batch_size, is_training = False)

        #Create some information about the training steps
        num_batches_per_epoch = dataset.num_samples / batch_size
        num_steps_per_epoch = num_batches_per_epoch

        #Now create the inference model but set is_training=False
        with slim.arg_scope(inception_resnet_v2_arg_scope()):
            logits, end_points = inception_resnet_v2(images, num_classes = dataset.num_classes, is_training = False)

        #Get all the variables to restore from the checkpoint file and create the saver function to restore
        variables_to_restore = slim.get_variables_to_restore()
        saver = tf.train.Saver(variables_to_restore)

        def restore_fn(sess):
            return saver.restore(sess, checkpoint_file)

        #Just define the metrics to track without the loss or whatsoever
        predictions = tf.argmax(end_points['Predictions'], 1)
        accuracy, accuracy_update = tf.contrib.metrics.streaming_accuracy(predictions, labels)
        metrics_op = tf.group(accuracy_update)

        #Create the global step and an increment op for monitoring
        global_step = get_or_create_global_step()
        global_step_op = tf.assign(global_step, global_step + 1) #no apply_gradient method so manually increasing the global_step

        #Create an evaluation step function
        def eval_step(sess, metrics_op, global_step):
            '''
            Simply takes in a session, runs the metrics op and some logging information.
            '''
            start_time = time.time()
            _, global_step_count, accuracy_value = sess.run([metrics_op, global_step_op, accuracy])
            time_elapsed = time.time() - start_time

            #Log some information
            logging.info('Global Step %s: Streaming Accuracy: %.4f (%.2f sec/step)', global_step_count, accuracy_value, time_elapsed)

            return accuracy_value

        #Define some scalar quantities to monitor
        tf.summary.scalar('Validation_Accuracy', accuracy)
        my_summary_op = tf.summary.merge_all()

        #Get your supervisor
        sv = tf.train.Supervisor(logdir = log_eval, summary_op = None, saver = None, init_fn = restore_fn)

        #Now we are ready to run in one session
        with sv.managed_session() as sess:
            for step in xrange(num_steps_per_epoch * num_epochs):
                sess.run(sv.global_step)
                #print vital information every start of the epoch as always
                if step % num_batches_per_epoch == 0:
                    logging.info('Epoch: %s/%s', step / num_batches_per_epoch + 1, num_epochs)
                    logging.info('Current Streaming Accuracy: %.4f', sess.run(accuracy))

                #Compute summaries every 10 steps and continue evaluating
                if step % 10 == 0:
                    eval_step(sess, metrics_op = metrics_op, global_step = sv.global_step)
                    summaries = sess.run(my_summary_op)
                    sv.summary_computed(sess, summaries)

                #Otherwise just run as per normal
                else:
                    eval_step(sess, metrics_op = metrics_op, global_step = sv.global_step)

            #At the end of all the evaluation, show the final accuracy
            logging.info('Final Streaming Accuracy: %.4f', sess.run(accuracy))

            #Now we want to visualize the last batch's images just to see what our model has predicted
            raw_images, labels, predictions = sess.run([raw_images, labels, predictions])
            for i in range(10):
                image, label, prediction = raw_images[i], labels[i], predictions[i]
                prediction_name, label_name = dataset.labels_to_name[prediction], dataset.labels_to_name[label]
                text = 'Prediction: %s \n Ground Truth: %s' % (prediction_name, label_name)
                img_plot = plt.imshow(image)

                #Set up the plot and hide axes
                plt.title(text)
                img_plot.axes.get_yaxis().set_ticks([])
                img_plot.axes.get_xaxis().set_ticks([])
                plt.show()

            logging.info('Model evaluation has completed! Visit TensorBoard for more information regarding your evaluation.')


if __name__ == '__main__':
    run()
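Once validation accuracy looks reasonable, the fine-tuned checkpoint in ./log can also be used for single-image inference. The sketch below is a hypothetical example (the file name test_flower.jpg is an assumption); it reuses the evaluation preprocessing and reads the newest checkpoint, exactly as the validation script does.

import tensorflow as tf
import inception_preprocessing
from inception_resnet_v2 import inception_resnet_v2, inception_resnet_v2_arg_scope

slim = tf.contrib.slim

# Decode and preprocess one image with the evaluation pipeline (output in [-1, 1]).
image_raw = tf.image.decode_jpeg(tf.read_file('test_flower.jpg'), channels=3)
image = inception_preprocessing.preprocess_image(image_raw, 299, 299, is_training=False)
images = tf.expand_dims(image, 0)

with slim.arg_scope(inception_resnet_v2_arg_scope()):
    logits, end_points = inception_resnet_v2(images, num_classes=5, is_training=False)
probabilities = end_points['Predictions']

saver = tf.train.Saver(slim.get_variables_to_restore())

with tf.Session() as sess:
    # Restore the newest fine-tuned checkpoint written by the training script.
    saver.restore(sess, tf.train.latest_checkpoint('./log'))
    probs = sess.run(probabilities)[0]
    print('Predicted label: %d (probability %.3f)' % (probs.argmax(), probs.max()))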
