https://github.com/zonghaofan/pig-seg/tree/master/disk_segmentation
网络架构:
# coding:utf-8
import tensorflow as tf
import cv2
import numpy as np
import matplotlib.pyplot as plt

# Load the test image and turn it into a one-image float32 batch that matches
# the ``x_input`` placeholder below.
img = cv2.imread('./data/test.png')
if img is None:
    # cv2.imread reports a missing/unreadable file by returning None instead of
    # raising; fail here with a clear message rather than inside cv2.resize.
    raise FileNotFoundError("could not read image './data/test.png'")
img = cv2.resize(img, (1024, 1024))
img = img.astype(np.float32)
img = img[np.newaxis, ...]  # add a leading batch axis
print(img.shape)

# Graph input: any batch size, 1024x1024 spatial size, 3 channels.
x_input = tf.placeholder(shape=[None, 1024, 1024, 3], dtype=tf.float32)

# NOTE(review): not referenced by the code visible here (conv2d takes its own
# ``n_filters`` argument); kept in case other code relies on it.
n_filters = [8, 8]
def conv2d(x, n_filters, training, name, pool=True, activation=tf.nn.relu):
    """Apply a stack of 3x3 conv -> batch-norm -> activation layers.

    Args:
        x: 4-D input tensor (presumably NHWC, as produced by ``x_input``).
        n_filters: output channel count for each conv layer, in order; one
            conv/batch-norm/activation triple is built per entry.
        training: forwarded to ``tf.layers.batch_normalization``.
        name: suffix for the variable scope ``layer{name}`` and the op names.
        pool: when true, also apply 2x2, stride-2 max pooling.
        activation: callable accepting ``(tensor, name=...)``.

    Returns:
        The last activation if ``pool`` is false, otherwise a tuple
        ``(activation, pooled_activation)``.
    """
    with tf.variable_scope('layer{}'.format(name)):
        conv = x
        for index, n_out in enumerate(n_filters):
            # Chain the layers: each conv consumes the previous activation,
            # not the block input ``x``.
            conv = tf.layers.conv2d(conv, n_out, (3, 3), strides=1,
                                    padding='same', activation=None,
                                    name='conv_{}'.format(index + 1))
            conv = tf.layers.batch_normalization(
                conv, training=training, name='bn_{}'.format(index + 1))
            conv = activation(conv, name='relu{}_{}'.format(name, index + 1))
        if not pool:
            return conv
        pooled = tf.layers.max_pooling2d(conv, pool_size=(2, 2), strides=2,
                                         name='pool_{}'.format(name))
        return conv, pooled


def upsampling_2d(tensor, name, size=(2, 2)):
    """Nearest-neighbour upsample ``tensor`` along height/width by ``size``.

    Needs a statically known height and width (true for the fixed-size
    placeholder used by this script).

    Raises:
        ValueError: if the height or width is not statically known.
    """
    h_, w_ = tensor.get_shape().as_list()[1:3]
    if h_ is None or w_ is None:
        raise ValueError('upsampling_2d needs a static height and width, '
                         'got shape {}'.format(tensor.get_shape()))
    h_multi, w_multi = size
    new_size = (h_multi * h_, w_multi * w_)
    return tf.image.resize_nearest_neighbor(tensor, size=new_size,
                                            name='upsample_{}'.format(name))


def upsampling_concat(input_A, input_B, name):
    """Upsample ``input_A`` by 2x and concatenate it with ``input_B`` on channels."""
    upsampling = upsampling_2d(input_A, name=name, size=(2, 2))
    return tf.concat([upsampling, input_B], axis=-1,
                     name='up_concat_{}'.format(name))


def unet(input, training=True):
    """Build a 4-level U-Net producing a one-channel sigmoid map.

    Args:
        input: float32 image batch with pixel values in [0, 255]. The
            height and width should be divisible by 16 so that the four 2x
            poolings and upsamplings line up for the skip concatenations.
        training: forwarded to every batch-norm layer. It defaults to True,
            matching the previously hard-coded value.

    Returns:
        Tensor with one channel of per-pixel values in (0, 1).
    """
    # Normalise pixel values from [0, 255] to [-1, 1].
    input = (input - 127.5) / 127.5

    # Encoder: layers 1-4 double the width and halve the resolution.
    skips = []
    net = input
    for level, width in enumerate((8, 16, 32, 64), start=1):
        conv, net = conv2d(net, [width, width], training=training, name=level)
        print(conv.shape)
        print(net.shape)
        skips.append(conv)

    # Bottleneck (layer 5), no pooling.
    net = conv2d(net, [128, 128], training=training, pool=False, name=5)
    print(net.shape)

    # Decoder: layers 6-9 upsample and concatenate the mirrored encoder output.
    for level, width, skip in zip(range(6, 10), (64, 32, 16, 8), reversed(skips)):
        up = upsampling_concat(net, skip, name=level)
        print('up{}'.format(level), up.shape)
        net = conv2d(up, [width, width], training=training, pool=False,
                     name=level)
        print(net.shape)

    final = tf.layers.conv2d(net, 1, (1, 1), name='final',
                             activation=tf.nn.sigmoid, padding='same')
    print('final', final.shape)
    return final


if __name__ == '__main__':
    final = unet(x_input)
    with tf.Session() as sess:
        # Weights are only initialised here (nothing is restored), so the
        # output comes from an untrained network.
        sess.run(tf.global_variables_initializer())
        y_final = sess.run(final, feed_dict={x_input: img})
    result = y_final[0, ...]  # drop the batch axis
    print(result.shape)
    print(result[:10, :10, 0])  # top-left 10x10 values of the single channel
    # ``result`` is float32 in (0, 1). cv2.imshow multiplies float images by
    # 255 before display, so it renders correctly. cv2.imwrite does not, and
    # the values collapse to 0/1. To save, convert first, e.g.
    # ``(result * 255).astype(np.uint8)``.
    cv2.imshow('1.jpg', result)
    cv2.waitKey(0)
    cv2.destroyAllWindows()
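打印结果：网络最后一层是 sigmoid，输出是 0~1 之间的小数（float32）。cv2.imshow 显示浮点图像时会先把像素值乘以 255，所以直接 imshow 就能看到正常的输出图；而 cv2.imwrite 保存时不做这种缩放，像素值会被转换成 0 或 1，得到的图片轮廓虽然还能看出来，但很不清晰。如果要保存结果，应先转换为 8 位图像再 imwrite，例如 (result * 255).astype(np.uint8)。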
输入:
输出：（截图没有截全）