Keras】基于SegNet和U-Net的遥感图像语义分割

from:【Keras】基于SegNet和U-Net的遥感图像语义分割

上两个月参加了个比赛,做的是对遥感高清图像做语义分割,美其名曰“天空之眼”。这两周数据挖掘课期末project我们组选的课题也是遥感图像的语义分割,所以刚好又把前段时间做的成果重新整理和加强了一下,故写了这篇文章,记录一下用深度学习做遥感图像语义分割的完整流程以及一些好的思路和技巧。

 

数据集

首先介绍一下数据,我们这次采用的数据集是CCF大数据比赛提供的数据(2015年中国南方某城市的高清遥感图像),这是一个小数据集,里面包含了5张带标注的大尺寸RGB遥感图像(尺寸范围从3000×3000到6000×6000),里面一共标注了4类物体,植被(标记1)、建筑(标记2)、水体(标记3)、道路(标记4)以及其他(标记0)。其中,耕地、林地、草地均归为植被类,为了更好地观察标注情况,我们将其中三幅训练图片可视化如下:蓝色-水体,黄色-房屋,绿色-植被,棕色-马路。更多数据介绍可以参看这里。

现在说一说我们的数据处理的步骤。我们现在拥有的是5张大尺寸的遥感图像,我们不能直接把这些图像送入网络进行训练,因为内存承受不了而且他们的尺寸也各不相同。因此,我们首先将他们做随机切割,即随机生成x,y坐标,然后抠出该坐标下256*256的小图,并做以下数据增强操作:

  1. 原图和label图都需要旋转:90度,180度,270度
  2. 原图和label图都需要做沿y轴的镜像操作
  3. 原图做模糊操作
  4. 原图做光照调整操作
  5. 原图做增加噪声操作(高斯噪声,椒盐噪声)

这里我没有采用Keras自带的数据增广函数,而是自己使用opencv编写了相应的增强函数。

 
  1. img_w = 256

  2. img_h = 256

  3.  
  4. image_sets = ['1.png','2.png','3.png','4.png','5.png']

  5.  
  6. def gamma_transform(img, gamma):

  7. gamma_table = [np.power(x / 255.0, gamma) * 255.0 for x in range(256)]

  8. gamma_table = np.round(np.array(gamma_table)).astype(np.uint8)

  9. return cv2.LUT(img, gamma_table)

  10.  
  11. def random_gamma_transform(img, gamma_vari):

  12. log_gamma_vari = np.log(gamma_vari)

  13. alpha = np.random.uniform(-log_gamma_vari, log_gamma_vari)

  14. gamma = np.exp(alpha)

  15. return gamma_transform(img, gamma)

  16.  
  17.  
  18. def rotate(xb,yb,angle):

  19. M_rotate = cv2.getRotationMatrix2D((img_w/2, img_h/2), angle, 1)

  20. xb = cv2.warpAffine(xb, M_rotate, (img_w, img_h))

  21. yb = cv2.warpAffine(yb, M_rotate, (img_w, img_h))

  22. return xb,yb

  23.  
  24. def blur(img):

  25. img = cv2.blur(img, (3, 3));

  26. return img

  27.  
  28. def add_noise(img):

  29. for i in range(200): #添加点噪声

  30. temp_x = np.random.randint(0,img.shape[0])

  31. temp_y = np.random.randint(0,img.shape[1])

  32. img[temp_x][temp_y] = 255

  33. return img

  34.  
  35.  
  36. def data_augment(xb,yb):

  37. if np.random.random() < 0.25:

  38. xb,yb = rotate(xb,yb,90)

  39. if np.random.random() < 0.25:

  40. xb,yb = rotate(xb,yb,180)

  41. if np.random.random() < 0.25:

  42. xb,yb = rotate(xb,yb,270)

  43. if np.random.random() < 0.25:

  44. xb = cv2.flip(xb, 1) # flipcode > 0:沿y轴翻转

  45. yb = cv2.flip(yb, 1)

  46.  
  47. if np.random.random() < 0.25:

  48. xb = random_gamma_transform(xb,1.0)

  49.  
  50. if np.random.random() < 0.25:

  51. xb = blur(xb)

  52.  
  53. if np.random.random() < 0.2:

  54. xb = add_noise(xb)

  55.  
  56. return xb,yb

  57.  
  58. def creat_dataset(image_num = 100000, mode = 'original'):

  59. print('creating dataset...')

  60. image_each = image_num / len(image_sets)

  61. g_count = 0

  62. for i in tqdm(range(len(image_sets))):

  63. count = 0

  64. src_img = cv2.imread('./data/src/' + image_sets[i]) # 3 channels

  65. label_img = cv2.imread('./data/label/' + image_sets[i],cv2.IMREAD_GRAYSCALE) # single channel

  66. X_height,X_width,_ = src_img.shape

  67. while count < image_each:

  68. random_width = random.randint(0, X_width - img_w - 1)

  69. random_height = random.randint(0, X_height - img_h - 1)

  70. src_roi = src_img[random_height: random_height + img_h, random_width: random_width + img_w,:]

  71. label_roi = label_img[random_height: random_height + img_h, random_width: random_width + img_w]

  72. if mode == 'augment':

  73. src_roi,label_roi = data_augment(src_roi,label_roi)

  74.  
  75. visualize = np.zeros((256,256)).astype(np.uint8)

  76. visualize = label_roi *50

  77.  
  78. cv2.imwrite(('./aug/train/visualize/%d.png' % g_count),visualize)

  79. cv2.imwrite(('./aug/train/src/%d.png' % g_count),src_roi)

  80. cv2.imwrite(('./aug/train/label/%d.png' % g_count),label_roi)

  81. count += 1

  82. g_count += 1

经过上面数据增强操作后,我们得到了较大的训练集:100000张256*256的图片。

卷积神经网络

面对这类图像语义分割的任务,我们可以选取的经典网络有很多,比如FCN,U-Net,SegNet,DeepLab,RefineNet,Mask Rcnn,Hed Net这些都是非常经典而且在很多比赛都广泛采用的网络架构。所以我们就可以从中选取一两个经典网络作为我们这个分割任务的解决方案。我们根据我们小组的情况,选取了U-Net和SegNet作为我们的主体网络进行实验。

SegNet

SegNet已经出来好几年了,这不是一个最新、效果最好的语义分割网络,但是它胜在网络结构清晰易懂,训练快速坑少,所以我们也采取它来做同样的任务。SegNet网络结构是编码器-解码器的结构,非常优雅,值得注意的是,SegNet做语义分割时通常在末端加入CRF模块做后处理,旨在进一步精修边缘的分割结果。有兴趣深究的可以看看这里

现在讲解代码部分,首先我们先定义好SegNet的网络结构。

 
  1. def SegNet():

  2. model = Sequential()

  3. #encoder

  4. model.add(Conv2D(64,(3,3),strides=(1,1),input_shape=(3,img_w,img_h),padding='same',activation='relu'))

  5. model.add(BatchNormalization())

  6. model.add(Conv2D(64,(3,3),strides=(1,1),padding='same',activation='relu'))

  7. model.add(BatchNormalization())

  8. model.add(MaxPooling2D(pool_size=(2,2)))

  9. #(128,128)

  10. model.add(Conv2D(128, (3, 3), strides=(1, 1), padding='same', activation='relu'))

  11. model.add(BatchNormalization())

  12. model.add(Conv2D(128, (3, 3), strides=(1, 1), padding='same', activation='relu'))

  13. model.add(BatchNormalization())

  14. model.add(MaxPooling2D(pool_size=(2, 2)))

  15. #(64,64)

  16. model.add(Conv2D(256, (3, 3), strides=(1, 1), padding='same', activation='relu'))

  17. model.add(BatchNormalization())

  18. model.add(Conv2D(256, (3, 3), strides=(1, 1), padding='same', activation='relu'))

  19. model.add(BatchNormalization())

  20. model.add(Conv2D(256, (3, 3), strides=(1, 1), padding='same', activation='relu'))

  21. model.add(BatchNormalization())

  22. model.add(MaxPooling2D(pool_size=(2, 2)))

  23. #(32,32)

  24. model.add(Conv2D(512, (3, 3), strides=(1, 1), padding='same', activation='relu'))

  25. model.add(BatchNormalization())

  26. model.add(Conv2D(512, (3, 3), strides=(1, 1), padding='same', activation='relu'))

  27. model.add(BatchNormalization())

  28. model.add(Conv2D(512, (3, 3), strides=(1, 1), padding='same', activation='relu'))

  29. model.add(BatchNormalization())

  30. model.add(MaxPooling2D(pool_size=(2, 2)))

  31. #(16,16)

  32. model.add(Conv2D(512, (3, 3), strides=(1, 1), padding='same', activation='relu'))

  33. model.add(BatchNormalization())

  34. model.add(Conv2D(512, (3, 3), strides=(1, 1), padding='same', activation='relu'))

  35. model.add(BatchNormalization())

  36. model.add(Conv2D(512, (3, 3), strides=(1, 1), padding='same', activation='relu'))

  37. model.add(BatchNormalization())

  38. model.add(MaxPooling2D(pool_size=(2, 2)))

  39. #(8,8)

  40. #decoder

  41. model.add(UpSampling2D(size=(2,2)))

  42. #(16,16)

  43. model.add(Conv2D(512, (3, 3), strides=(1, 1), padding='same', activation='relu'))

  44. model.add(BatchNormalization())

  45. model.add(Conv2D(512, (3, 3), strides=(1, 1), padding='same', activation='relu'))

  46. model.add(BatchNormalization())

  47. model.add(Conv2D(512, (3, 3), strides=(1, 1), padding='same', activation='relu'))

  48. model.add(BatchNormalization())

  49. model.add(UpSampling2D(size=(2, 2)))

  50. #(32,32)

  51. model.add(Conv2D(512, (3, 3), strides=(1, 1), padding='same', activation='relu'))

  52. model.add(BatchNormalization())

  53. model.add(Conv2D(512, (3, 3), strides=(1, 1), padding='same', activation='relu'))

  54. model.add(BatchNormalization())

  55. model.add(Conv2D(512, (3, 3), strides=(1, 1), padding='same', activation='relu'))

  56. model.add(BatchNormalization())

  57. model.add(UpSampling2D(size=(2, 2)))

  58. #(64,64)

  59. model.add(Conv2D(256, (3, 3), strides=(1, 1), padding='same', activation='relu'))

  60. model.add(BatchNormalization())

  61. model.add(Conv2D(256, (3, 3), strides=(1, 1), padding='same', activation='relu'))

  62. model.add(BatchNormalization())

  63. model.add(Conv2D(256, (3, 3), strides=(1, 1), padding='same', activation='relu'))

  64. model.add(BatchNormalization())

  65. model.add(UpSampling2D(size=(2, 2)))

  66. #(128,128)

  67. model.add(Conv2D(128, (3, 3), strides=(1, 1), padding='same', activation='relu'))

  68. model.add(BatchNormalization())

  69. model.add(Conv2D(128, (3, 3), strides=(1, 1), padding='same', activation='relu'))

  70. model.add(BatchNormalization())

  71. model.add(UpSampling2D(size=(2, 2)))

  72. #(256,256)

  73. model.add(Conv2D(64, (3, 3), strides=(1, 1), input_shape=(3,img_w, img_h), padding='same', activation='relu'))

  74. model.add(BatchNormalization())

  75. model.add(Conv2D(64, (3, 3), strides=(1, 1), padding='same', activation='relu'))

  76. model.add(BatchNormalization())

  77. model.add(Conv2D(n_label, (1, 1), strides=(1, 1), padding='same'))

  78. model.add(Reshape((n_label,img_w*img_h)))

  79. #axis=1和axis=2互换位置,等同于np.swapaxes(layer,1,2)

  80. model.add(Permute((2,1)))

  81. model.add(Activation('softmax'))

  82. model.compile(loss='categorical_crossentropy',optimizer='sgd',metrics=['accuracy'])

  83. model.summary()

  84. return model

然后需要读入数据集。这里我们选择的验证集大小是训练集的0.25。

 
  1. def get_train_val(val_rate = 0.25):

  2. train_url = []

  3. train_set = []

  4. val_set = []

  5. for pic in os.listdir(filepath + 'src'):

  6. train_url.append(pic)

  7. random.shuffle(train_url)

  8. total_num = len(train_url)

  9. val_num = int(val_rate * total_num)

  10. for i in range(len(train_url)):

  11. if i < val_num:

  12. val_set.append(train_url[i])

  13. else:

  14. train_set.append(train_url[i])

  15. return train_set,val_set

  16.  
  17. # data for training

  18. def generateData(batch_size,data=[]):

  19. #print 'generateData...'

  20. while True:

  21. train_data = []

  22. train_label = []

  23. batch = 0

  24. for i in (range(len(data))):

  25. url = data[i]

  26. batch += 1

  27. #print (filepath + 'src/' + url)

  28. #img = load_img(filepath + 'src/' + url, target_size=(img_w, img_h))

  29. img = load_img(filepath + 'src/' + url)

  30. img = img_to_array(img)

  31. # print img

  32. # print img.shape

  33. train_data.append(img)

  34. #label = load_img(filepath + 'label/' + url, target_size=(img_w, img_h),grayscale=True)

  35. label = load_img(filepath + 'label/' + url, grayscale=True)

  36. label = img_to_array(label).reshape((img_w * img_h,))

  37. # print label.shape

  38. train_label.append(label)

  39. if batch % batch_size==0:

  40. #print 'get enough bacth!\n'

  41. train_data = np.array(train_data)

  42. train_label = np.array(train_label).flatten()

  43. train_label = labelencoder.transform(train_label)

  44. train_label = to_categorical(train_label, num_classes=n_label)

  45. train_label = train_label.reshape((batch_size,img_w * img_h,n_label))

  46. yield (train_data,train_label)

  47. train_data = []

  48. train_label = []

  49. batch = 0

  50.  
  51. # data for validation

  52. def generateValidData(batch_size,data=[]):

  53. #print 'generateValidData...'

  54. while True:

  55. valid_data = []

  56. valid_label = []

  57. batch = 0

  58. for i in (range(len(data))):

  59. url = data[i]

  60. batch += 1

  61. #img = load_img(filepath + 'src/' + url, target_size=(img_w, img_h))

  62. img = load_img(filepath + 'src/' + url)

  63. #print img

  64. #print (filepath + 'src/' + url)

  65. img = img_to_array(img)

  66. # print img.shape

  67. valid_data.append(img)

  68. #label = load_img(filepath + 'label/' + url, target_size=(img_w, img_h),grayscale=True)

  69. label = load_img(filepath + 'label/' + url, grayscale=True)

  70. label = img_to_array(label).reshape((img_w * img_h,))

  71. # print label.shape

  72. valid_label.append(label)

  73. if batch % batch_size==0:

  74. valid_data = np.array(valid_data)

  75. valid_label = np.array(valid_label).flatten()

  76. valid_label = labelencoder.transform(valid_label)

  77. valid_label = to_categorical(valid_label, num_classes=n_label)

  78. valid_label = valid_label.reshape((batch_size,img_w * img_h,n_label))

  79. yield (valid_data,valid_label)

  80. valid_data = []

  81. valid_label = []

  82. batch = 0

然后定义一下我们训练的过程,在这个任务上,我们把batch size定为16,epoch定为30,每次都存储最佳model(save_best_only=True),并且在训练结束时绘制loss/acc曲线,并存储起来。

 
  1. def train(args):

  2. EPOCHS = 30

  3. BS = 16

  4. model = SegNet()

  5. modelcheck = ModelCheckpoint(args['model'],monitor='val_acc',save_best_only=True,mode='max')

  6. callable = [modelcheck]

  7. train_set,val_set = get_train_val()

  8. train_numb = len(train_set)

  9. valid_numb = len(val_set)

  10. print ("the number of train data is",train_numb)

  11. print ("the number of val data is",valid_numb)

  12. H = model.fit_generator(generator=generateData(BS,train_set),steps_per_epoch=train_numb//BS,epochs=EPOCHS,verbose=1,

  13. validation_data=generateValidData(BS,val_set),validation_steps=valid_numb//BS,callbacks=callable,max_q_size=1)

  14.  
  15. # plot the training loss and accuracy

  16. plt.style.use("ggplot")

  17. plt.figure()

  18. N = EPOCHS

  19. plt.plot(np.arange(0, N), H.history["loss"], label="train_loss")

  20. plt.plot(np.arange(0, N), H.history["val_loss"], label="val_loss")

  21. plt.plot(np.arange(0, N), H.history["acc"], label="train_acc")

  22. plt.plot(np.arange(0, N), H.history["val_acc"], label="val_acc")

  23. plt.title("Training Loss and Accuracy on SegNet Satellite Seg")

  24. plt.xlabel("Epoch #")

  25. plt.ylabel("Loss/Accuracy")

  26. plt.legend(loc="lower left")

  27. plt.savefig(args["plot"])

然后开始漫长的训练,训练时间接近3天,绘制出的loss/acc图如下:

训练loss降到0.1左右,acc可以去到0.9,但是验证集的loss和acc都没那么好,貌似存在点问题。

先不管了,先看看预测结果吧。

这里需要思考一下怎么预测整张遥感图像。我们知道,我们训练模型时选择的图片输入是256×256,所以我们预测时也要采用256×256的图片尺寸送进模型预测。现在我们要考虑一个问题,我们该怎么将这些预测好的小图重新拼接成一个大图呢?这里给出一个最基础的方案:先给大图做padding 0操作,得到一副padding过的大图,同时我们也生成一个与该图一样大的全0图A,把图像的尺寸补齐为256的倍数,然后以256为步长切割大图,依次将小图送进模型预测,预测好的小图则放在A的相应位置上,依次进行,最终得到预测好的整张大图(即A),再做图像切割,切割成原先图片的尺寸,完成整个预测流程。

 
  1. def predict(args):

  2. # load the trained convolutional neural network

  3. print("[INFO] loading network...")

  4. model = load_model(args["model"])

  5. stride = args['stride']

  6. for n in range(len(TEST_SET)):

  7. path = TEST_SET[n]

  8. #load the image

  9. image = cv2.imread('./test/' + path)

  10. # pre-process the image for classification

  11. #image = image.astype("float") / 255.0

  12. #image = img_to_array(image)

  13. h,w,_ = image.shape

  14. padding_h = (h//stride + 1) * stride

  15. padding_w = (w//stride + 1) * stride

  16. padding_img = np.zeros((padding_h,padding_w,3),dtype=np.uint8)

  17. padding_img[0:h,0:w,:] = image[:,:,:]

  18. padding_img = padding_img.astype("float") / 255.0

  19. padding_img = img_to_array(padding_img)

  20. print 'src:',padding_img.shape

  21. mask_whole = np.zeros((padding_h,padding_w),dtype=np.uint8)

  22. for i in range(padding_h//stride):

  23. for j in range(padding_w//stride):

  24. crop = padding_img[:3,i*stride:i*stride+image_size,j*stride:j*stride+image_size]

  25. _,ch,cw = crop.shape

  26. if ch != 256 or cw != 256:

  27. print 'invalid size!'

  28. continue

  29.  
  30. crop = np.expand_dims(crop, axis=0)

  31. #print 'crop:',crop.shape

  32. pred = model.predict_classes(crop,verbose=2)

  33. pred = labelencoder.inverse_transform(pred[0])

  34. #print (np.unique(pred))

  35. pred = pred.reshape((256,256)).astype(np.uint8)

  36. #print 'pred:',pred.shape

  37. mask_whole[i*stride:i*stride+image_size,j*stride:j*stride+image_size] = pred[:,:]

  38.  
  39.  
  40. cv2.imwrite('./predict/pre'+str(n+1)+'.png',mask_whole[0:h,0:w])

预测的效果图如下:

一眼看去,效果真的不错,但是仔细看一下,就会发现有个很大的问题:拼接痕迹过于明显了!那怎么解决这类边缘问题呢?很直接的想法就是缩小切割时的滑动步伐,比如我们把切割步伐改为128,那么拼接时就会有一般的图像发生重叠,这样做可以尽可能地减少拼接痕迹。

U-Net

对于这个语义分割任务,我们毫不犹豫地选择了U-Net作为我们的方案,原因很简单,我们参考很多类似的遥感图像分割比赛的资料,绝大多数获奖的选手使用的都是U-Net模型。在这么多的好评下,我们选择U-Net也就毫无疑问了。

U-Net有很多优点,最大卖点就是它可以在小数据集上也能train出一个好的模型,这个优点对于我们这个任务来说真的非常适合。而且,U-Net在训练速度上也是非常快的,这对于需要短时间就得出结果的期末project来说也是非常合适。U-Net在网络架构上还是非常优雅的,整个呈现U形,故起名U-Net。这里不打算详细介绍U-Net结构,有兴趣的深究的可以看看论文。

现在开始谈谈代码细节。首先我们定义一下U-Net的网络结构,这里用的deep learning框架还是Keras。

注意到,我们这里训练的模型是一个多分类模型,其实更好的做法是,训练一个二分类模型(使用二分类的标签),对每一类物体进行预测,得到4张预测图,再做预测图叠加,合并成一张完整的包含4类的预测图,这个策略在效果上肯定好于一个直接4分类的模型。所以,U-Net这边我们采取的思路就是对于每一类的分类都训练一个二分类模型,最后再将每一类的预测结果组合成一个四分类的结果。

定义U-Net结构,注意了,这里的loss function我们选了binary_crossentropy,因为我们要训练的是二分类模型。

 
  1. def unet():

  2. inputs = Input((3, img_w, img_h))

  3.  
  4. conv1 = Conv2D(32, (3, 3), activation="relu", padding="same")(inputs)

  5. conv1 = Conv2D(32, (3, 3), activation="relu", padding="same")(conv1)

  6. pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)

  7.  
  8. conv2 = Conv2D(64, (3, 3), activation="relu", padding="same")(pool1)

  9. conv2 = Conv2D(64, (3, 3), activation="relu", padding="same")(conv2)

  10. pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)

  11.  
  12. conv3 = Conv2D(128, (3, 3), activation="relu", padding="same")(pool2)

  13. conv3 = Conv2D(128, (3, 3), activation="relu", padding="same")(conv3)

  14. pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)

  15.  
  16. conv4 = Conv2D(256, (3, 3), activation="relu", padding="same")(pool3)

  17. conv4 = Conv2D(256, (3, 3), activation="relu", padding="same")(conv4)

  18. pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)

  19.  
  20. conv5 = Conv2D(512, (3, 3), activation="relu", padding="same")(pool4)

  21. conv5 = Conv2D(512, (3, 3), activation="relu", padding="same")(conv5)

  22.  
  23. up6 = concatenate([UpSampling2D(size=(2, 2))(conv5), conv4], axis=1)

  24. conv6 = Conv2D(256, (3, 3), activation="relu", padding="same")(up6)

  25. conv6 = Conv2D(256, (3, 3), activation="relu", padding="same")(conv6)

  26.  
  27. up7 = concatenate([UpSampling2D(size=(2, 2))(conv6), conv3], axis=1)

  28. conv7 = Conv2D(128, (3, 3), activation="relu", padding="same")(up7)

  29. conv7 = Conv2D(128, (3, 3), activation="relu", padding="same")(conv7)

  30.  
  31. up8 = concatenate([UpSampling2D(size=(2, 2))(conv7), conv2], axis=1)

  32. conv8 = Conv2D(64, (3, 3), activation="relu", padding="same")(up8)

  33. conv8 = Conv2D(64, (3, 3), activation="relu", padding="same")(conv8)

  34.  
  35. up9 = concatenate([UpSampling2D(size=(2, 2))(conv8), conv1], axis=1)

  36. conv9 = Conv2D(32, (3, 3), activation="relu", padding="same")(up9)

  37. conv9 = Conv2D(32, (3, 3), activation="relu", padding="same")(conv9)

  38.  
  39. conv10 = Conv2D(n_label, (1, 1), activation="sigmoid")(conv9)

  40. #conv10 = Conv2D(n_label, (1, 1), activation="softmax")(conv9)

  41.  
  42. model = Model(inputs=inputs, outputs=conv10)

  43. model.compile(optimizer='Adam', loss='binary_crossentropy', metrics=['accuracy'])

  44. return model

读取数据的组织方式有一些改动。

 
  1. # data for training

  2. def generateData(batch_size,data=[]):

  3. #print 'generateData...'

  4. while True:

  5. train_data = []

  6. train_label = []

  7. batch = 0

  8. for i in (range(len(data))):

  9. url = data[i]

  10. batch += 1

  11. img = load_img(filepath + 'src/' + url)

  12. img = img_to_array(img)

  13. train_data.append(img)

  14. label = load_img(filepath + 'label/' + url, grayscale=True)

  15. label = img_to_array(label)

  16. #print label.shape

  17. train_label.append(label)

  18. if batch % batch_size==0:

  19. #print 'get enough bacth!\n'

  20. train_data = np.array(train_data)

  21. train_label = np.array(train_label)

  22.  
  23. yield (train_data,train_label)

  24. train_data = []

  25. train_label = []

  26. batch = 0

  27.  
  28. # data for validation

  29. def generateValidData(batch_size,data=[]):

  30. #print 'generateValidData...'

  31. while True:

  32. valid_data = []

  33. valid_label = []

  34. batch = 0

  35. for i in (range(len(data))):

  36. url = data[i]

  37. batch += 1

  38. img = load_img(filepath + 'src/' + url)

  39. #print img

  40. img = img_to_array(img)

  41. # print img.shape

  42. valid_data.append(img)

  43. label = load_img(filepath + 'label/' + url, grayscale=True)

  44. valid_label.append(label)

  45. if batch % batch_size==0:

  46. valid_data = np.array(valid_data)

  47. valid_label = np.array(valid_label)

  48. yield (valid_data,valid_label)

  49. valid_data = []

  50. valid_label = []

  51. batch = 0

训练:指定输出model名字和训练集位置

python unet.py --model unet_buildings20.h5 --data ./unet_train/buildings/

预测单张遥感图像时我们分别使用4个模型做预测,那我们就会得到4张mask(比如下图就是我们用训练好的buildings模型预测的结果),我们现在要将这4张mask合并成1张,那么怎么合并会比较好呢?我思路是,通过观察每一类的预测结果,我们可以从直观上知道哪些类的预测比较准确,那么我们就可以给这些mask图排优先级了,比如:priority:building>water>road>vegetation,那么当遇到一个像素点,4个mask图都说是属于自己类别的标签时,我们就可以根据先前定义好的优先级,把该像素的标签定为优先级最高的标签。代码思路可以参照下面的代码:

 
  1. def combind_all_mask():

  2. for mask_num in tqdm(range(3)):

  3. if mask_num == 0:

  4. final_mask = np.zeros((5142,5664),np.uint8)#生成一个全黑全0图像,图片尺寸与原图相同

  5. elif mask_num == 1:

  6. final_mask = np.zeros((2470,4011),np.uint8)

  7. elif mask_num == 2:

  8. final_mask = np.zeros((6116,3356),np.uint8)

  9. #final_mask = cv2.imread('final_1_8bits_predict.png',0)

  10.  
  11. if mask_num == 0:

  12. mask_pool = mask1_pool

  13. elif mask_num == 1:

  14. mask_pool = mask2_pool

  15. elif mask_num == 2:

  16. mask_pool = mask3_pool

  17. final_name = img_sets[mask_num]

  18. for idx,name in enumerate(mask_pool):

  19. img = cv2.imread('./predict_mask/'+name,0)

  20. height,width = img.shape

  21. label_value = idx+1 #coressponding labels value

  22. for i in tqdm(range(height)): #priority:building>water>road>vegetation

  23. for j in range(width):

  24. if img[i,j] == 255:

  25. if label_value == 2:

  26. final_mask[i,j] = label_value

  27. elif label_value == 3 and final_mask[i,j] != 2:

  28. final_mask[i,j] = label_value

  29. elif label_value == 4 and final_mask[i,j] != 2 and final_mask[i,j] != 3:

  30. final_mask[i,j] = label_value

  31. elif label_value == 1 and final_mask[i,j] == 0:

  32. final_mask[i,j] = label_value

  33.  
  34. cv2.imwrite('./final_result/'+final_name,final_mask)

  35.  
  36.  
  37. print 'combinding mask...'

  38. combind_all_mask()

模型融合

集成学习的方法在这类比赛中经常使用,要想获得好成绩集成学习必须做得好。在这里简单谈谈思路,我们使用了两个模型,我们模型也会采取不同参数去训练和预测,那么我们就会得到很多预测MASK图,此时 我们可以采取模型融合的思路,对每张结果图的每个像素点采取投票表决的思路,对每张图相应位置的像素点的类别进行预测,票数最多的类别即为该像素点的类别。正所谓“三个臭皮匠,胜过诸葛亮”,我们这种ensemble的思路,可以很好地去掉一些明显分类错误的像素点,很大程度上改善模型的预测能力。

少数服从多数的投票表决策略代码:

 
  1. import numpy as np

  2. import cv2

  3. import argparse

  4.  
  5. RESULT_PREFIXX = ['./result1/','./result2/','./result3/']

  6.  
  7. # each mask has 5 classes: 0~4

  8.  
  9. def vote_per_image(image_id):

  10. result_list = []

  11. for j in range(len(RESULT_PREFIXX)):

  12. im = cv2.imread(RESULT_PREFIXX[j]+str(image_id)+'.png',0)

  13. result_list.append(im)

  14.  
  15. # each pixel

  16. height,width = result_list[0].shape

  17. vote_mask = np.zeros((height,width))

  18. for h in range(height):

  19. for w in range(width):

  20. record = np.zeros((1,5))

  21. for n in range(len(result_list)):

  22. mask = result_list[n]

  23. pixel = mask[h,w]

  24. #print('pix:',pixel)

  25. record[0,pixel]+=1

  26.  
  27. label = record.argmax()

  28. #print(label)

  29. vote_mask[h,w] = label

  30.  
  31. cv2.imwrite('vote_mask'+str(image_id)+'.png',vote_mask)

  32.  
  33.  
  34. vote_per_image(3)

模型融合后的预测结果:

可以看出,模型融合后的预测效果确实有较大提升,明显错误分类的像素点消失了。

额外的思路:GAN

我们对数据方面思考得更多一些,我们针对数据集小的问题,我们有个想法:使用生成对抗网络去生成虚假的卫星地图,旨在进一步扩大数据集。我们的想法就是,使用这些虚假+真实的数据集去训练网络,网络的泛化能力肯定有更大的提升。我们的想法是根据这篇论文(pix2pix)来展开的,这是一篇很有意思的论文,它主要讲的是用图像生成图像的方法。里面提到了用标注好的卫星地图生成虚假的卫星地图的想法,真的让人耳目一新,我们也想根据该思路,生成属于我们的虚假卫星地图数据集。 Map to Aerial的效果是多么的震撼。

但是我们自己实现起来的效果却不容乐观(如下图所示,右面那幅就是我们生成的假图),效果不好的原因有很多,标注的问题最大,因为生成的虚假卫星地图质量不好,所以该想法以失败告终,生成的假图也没有拿去做训练。但感觉思路还是可行的,如果给的标注合适的话,还是可以生成非常像的虚假地图。

总结

对于这类遥感图像的语义分割,思路还有很多,最容易想到的思路就是,将各种语义分割经典网络都实现以下,看看哪个效果最好,再做模型融合,只要集成学习做得好,效果一般都会很不错的。我们仅靠上面那个简单思路(数据增强,经典模型搭建,集成学习),就已经可以获得比赛的TOP 5%了,当然还有一些tricks可以使效果更进一步提升,这里就不细说了,总的建模思路掌握就行。完整的代码可以在我的github获取。

 

数据下载:

链接:https://pan.baidu.com/s/1i6oMukH

密码:yqj2

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/493481.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

实验四51单片机并口实验

一.实验目的&#xff1a; 1. 了解51单片机I/O口的电气特性和驱动能力。 2. 了解LED电路中加入限流电阻的原因。 3. 掌握定时器原理及编程。 4. 掌握并口程序的编辑、编译、调试和运行。 二.实验设备和器件 1. PC机 2. PROTEUS仿真软件 3. 实验箱 4. ISP下载器 5. 51仿真器…

5G 建设拉动光模块量价齐升

来源&#xff1a;国联证券光模块是光通信的核心部件&#xff0c;它主要完成光电转换和电光转换。行业增长稳定&#xff0c;行业内的头部企业通过不断并购完成上下游的整合&#xff0c;提供一体化的解决方案。国内企业在快速的追赶外资标杆企业&#xff0c;产品逐步往高端方向转…

tensorflow权重初始化

一&#xff0c;用10层神经网络&#xff0c;每一层的参数都是随机正态分布&#xff0c;均值为0&#xff0c;标准差为0.01 #10层神经网络 data tf.constant(np.random.randn(2000, 800).astype(float32)) layer_sizes [800 - 50 * i for i in range(0, 10)] num_layers len(l…

单片机实验报告-片内外RAM的数据转移

一、实验目的&#xff1a; 1.掌握C51编程基础&#xff0c;C51程序结构。 2.掌握C51数据类型、函数设计。 3.掌握C51程序的编辑、编译、调试和运行 二、编程提示 编程将片外8000H单元开始的10字节的内容移至8100H开始的各单元中。8000H单元开始的10字节内容用编程方式赋值。…

单片机实验-定时中断

一.实验目的 1.掌握51单片机定时器工作原理。 2.掌握51单片机中断系统工作原理。 3.掌握定时器初始化编程。 4.掌握中断程序的编写和调试。 二.实验设备和器件 1.KEIL软件 2.PROTEUS仿真软件 3.伟福实验箱 三&#xff0e;实验内容 &#xff08;1&#xff09;编程实…

2018全球最值得关注的60家半导体公司,7家中国公司新上榜 | 年度榜单

编译 | 张玺 四月来源&#xff1a;机器之能由《EE Times》每年评选全球值得关注的 60 家新创半导体公司排行榜——『Silicon 60』&#xff0c;今年已经迈向第 19 届&#xff0c;今年的关键词仍然是「机器学习」(machine learning)&#xff0c;它正以硬件支持的运算形式强势崛起…

单片机实验报告-串口实验

一.实验目的 1. 掌握 51 单片机串口工作原理。 2. 掌握 51 单片机串口初始化编程。 3. 掌握 51 单片机串口的软硬件编程。 二.实验设备和器件 1.KEIL软件 2.PROTEUS仿真软件 3.伟福实验箱 三&#xff0e;实验内容 &#xff08;1&#xff09;编程实现&#xff1a…

学习率周期性变化

学习率周期性变化&#xff0c;能后解决陷入鞍点的问题&#xff0c;更多的方式请参考https://github.com/bckenstler/CLR base_lr:最低的学习率 max_lr:最高的学习率 step_size&#xff1a;&#xff08;2-8&#xff09;倍的每个epoch的训练次数。 scale_fn(x)&#xff1a;自…

清华发布《人工智能AI芯片研究报告》,一文读懂人才技术趋势

来源&#xff1a;Future智能摘要&#xff1a;大数据产业的爆炸性增长下&#xff0c;AI 芯片作为人工智能时代的技术核心之一&#xff0c;决定了平台的基础架构和发展生态。 近日&#xff0c;清华大学推出了《 人工智能芯片研究报告 》&#xff0c;全面讲解人工智能芯片&#xf…

开发者账号申请 真机调试 应用发布

技术博客http://www.cnblogs.com/ChenYilong/ 新浪微博http://weibo.com/luohanchenyilong 开发者账号申请 真机调试 应用发布 技术博客http://www.cnblogs.com/ChenYilong/新浪微博http://weibo.com/luohanchenyilong 要解决的问题 • 开发者账号申请 • 真机调试 • 真机调…

单片机实验-DA实验

一、实验目的 1、了解 D/A 转换的基本原理。 2、了解 D/A 转换芯片 0832 的性能及编程方法。 3、了解单片机系统中扩展 D/A 转换的基本方法。 二.实验设备和器件 1.KEIL软件 2.实验箱 三&#xff0e;实验内容 利用 DAC0832&#xff0c;编制程序产生锯齿波、三角波、正弦…

进化三部曲,从互联网大脑发育看产业互联网的未来

摘要&#xff1a;从互联网的左右大脑发育看&#xff0c;产业互联网可以看做互联网的下半场&#xff0c;但从互联网大脑的长远发育看&#xff0c;互联网依然处于大脑尚未发育成熟的婴儿时期&#xff0c;未来还需要漫长的时间发育。参考互联网右大脑的发育历程&#xff0c;我们判…

pycharm远程连接服务器(docker)调试+ssh连接多次报错

一&#xff0c;登入服务器建docker nvidia-docker run -it -v ~/workspace/:/workspace -w /workspace/ --namefzh_tf --shm-size 8G -p 1111:22 -p 1112:6006 -p 1113:8888 tensorflow/tensorflow:1.4.0-devel-gpu bash 二&#xff0c;开ssh服务 apt-get update apt-get i…

Verilog HDL语言设计4个独立的非门

代码&#xff1a; module yanxu11(in,out); input wire[3:0] in; output reg[3:0] out; always (in) begin out[0]~in[0]; out[1]~in[1]; out[2]~in[2]; out[3]~in[3]; end endmodule timescale 1ns/1ns module test(); reg[3:0] in; wire[3:0] out; yanxu11 U(…

深度长文:表面繁荣之下,人工智能的发展已陷入困境

来源&#xff1a;36氪编辑&#xff1a;郝鹏程摘要&#xff1a;《连线》杂志在其最近发布的12月刊上&#xff0c;以封面故事的形式报道了人工智能的发展状况。现在&#xff0c;深度学习面临着无法进行推理的困境&#xff0c;这也就意味着&#xff0c;它无法让机器具备像人一样的…

Verilog HDL语言设计一个比较电路

设计一个比较电路&#xff0c;当输入的一位8421BCD码大于4时&#xff0c;输出为1&#xff0c;否则为0&#xff0c;进行功能仿真&#xff0c;查看仿真结果&#xff0c;将Verilog代码和仿真波形图整理入实验报告。 代码&#xff1a; module yanxu12(in,out); input wire[3:0] i…

交叉熵

1.公式 用sigmoid推导 上式做一下转换&#xff1a; y 视为类后验概率 p(y 1 | x)&#xff0c;则上式可以写为&#xff1a; 则有&#xff1a; 将上式进行简单综合&#xff0c;可写成如下形式&#xff1a; 写成对数形式就是我们熟知的交叉熵损失函数了&#xff0c;这也是交叉熵…

第5章 散列

我们在第4章讨论了查找树ADT&#xff0c;它允许对一组元素进行各种操作。本章讨论散列表(hash table)ADT&#xff0c;不过它只支持二叉查找树所允许的一部分操作。 散列表的实现常常叫作散列(hashing)。散列是一种以常数平均时间执行插入、删除和查找的技术。但是&#xff0c;那…

谷歌自动驾驶是个大坑,还好中国在构建自己的智能驾驶大系统

来源&#xff1a;张国斌中国有堪称全球最复杂的路况&#xff0c;例如上图是去年投入使用的重庆黄桷湾立交桥上下共5层&#xff0c;共20条匝道&#xff0c;堪称中国最复杂立交桥之最&#xff0c;据称走错一个路口要在这里一日游&#xff0c;这样的立交桥如果让谷歌无人驾驶车上去…

Verilog HDL语言设计计数器+加法器

完成课本例题4.12&#xff0c;进行综合和仿真&#xff08;包括功能仿真和时序仿真&#xff09;&#xff0c;查看仿真结果&#xff0c;将Verilog代码和仿真波形图整理入实验报告。 功能文件&#xff1a; module shiyan1(out,reset,clk); input reset,clk; output reg[3:0] ou…