图像分割中的编码解码结构(Encoder-Decoder Model)是一种广泛应用的网络架构,它有效地结合了特征提取(编码)和分割结果生成(解码)两个过程。以下是对图像分割中编码解码结构的详细解析:
一、编码器(Encoder)
功能:
编码器负责从输入图像中提取有用的特征信息。这些特征信息通常是图像中不同对象或区域的抽象表示,有助于后续的分割任务。
工作原理:
- 特征提取:编码器通过一系列卷积层(Convolutional Layers)、池化层(Pooling Layers)等网络结构,逐步从输入图像中提取高级语义特征。这些特征图(Feature Maps)的分辨率逐渐降低,但包含的信息量逐渐增加,能够捕捉到图像中的关键信息。
- 下采样:在编码过程中,通常会使用池化层或步长大于1的卷积层进行下采样,以降低特征图的分辨率,减少计算量,并增加感受野(Receptive Field),使模型能够捕捉到更全局的信息。
二、解码器(Decoder)
功能:
解码器负责将编码器提取的特征信息转换为最终的分割结果。它逐步恢复特征图的分辨率,并生成与输入图像相同尺寸的分割掩码(Segmentation Mask),其中每个像素都被分配了一个类别标签。
工作原理:
- 上采样:解码器通过上采样操作(如转置卷积、双线性插值等)逐步恢复特征图的分辨率。上采样过程与编码器的下采样过程相反,旨在将特征图的尺寸恢复到与输入图像相同或接近。
- 特征融合:在解码过程中,通常会将编码器中的某些特征图与解码器中对应尺度的特征图进行融合(如拼接、相加等),以结合不同尺度的信息,提高分割结果的准确性。这种特征融合方式有助于模型捕捉到更精细的细节信息。
- 输出层:解码器的最后一层通常是一个卷积层,用于将特征图转换为分割掩码。该卷积层的输出通道数与类别数相同,每个通道对应一个类别的预测概率图。通过应用softmax函数或argmax操作,可以将这些概率图转换为最终的分割掩码。
三、代码实现
采用卷积-BN-池化的结构,设计下采样模块,每经过一次下采样模块,通道数翻倍,特征图长宽缩小一倍。
采用转置卷积实现上采样模块,每经过一次上采样,通道数减半,特征图长宽放大一倍
import keras
from tensorflow.keras import Model, layers
from tensorflow.keras.utils import plot_model
def DownSample(filters):'''卷积-BN-池化结构,每经过一次下采样,通道数翻倍,特征图长宽缩小一半:param filters:卷积核参数 ,输出通道数:return:'''Layer = keras.Sequential(layers=[layers.Conv2D(filters=filters, strides=1, kernel_size=3, padding="same"),layers.BatchNormalization(),layers.MaxPooling2D(strides=2, pool_size=2, padding="same"),layers.ReLU()], name="DownSample"+str(filters))return Layer
def UpSample(filters):'''转置卷积,每经过一次转置卷积,通道数减半,特征图长宽翻倍:param filters: 输出通道数:return:'''Layer = keras.Sequential(layers=[layers.Convolution2DTranspose(filters=filters, kernel_size=3, strides=2, padding="same",activation="relu")], name="UpSample"+str(filters))return Layerdef En_De():x = layers.Input(shape=(256,256,3))Down = [DownSample(64),DownSample(128),DownSample(256),DownSample(512)]y = Down[0](x)for Layer in Down[1:]:y=Layer(y)Up=[UpSample(256),UpSample(128),UpSample(64),UpSample(32)]for Layer in Up:y = Layer(y)y = layers.Conv2D(name="Result", filters=1, padding="same", kernel_size=3, strides=1, activation="sigmoid")(y)model = Model(x, y)return modelEn_De = En_De()
plot_model(En_De, "Model.png", show_shapes=True)
En_De.summary()
结果:
模型图:
四、网络的训练
网络的训练