数据集链接:https://pan.baidu.com/s/1IBe_P0AyHgZC39NqzOxZhA?pwd=nztc
提取码:nztc
- UNet模型
import torch
import torch.nn as nn


class conv_block(nn.Module):
    """Two stacked 3x3 convolution stages, each followed by BatchNorm and ReLU.

    Both convolutions use ``kernel_size=3, stride=1, padding=1``, so the
    height and width of the input are preserved; only the channel count
    changes (``ch_in`` -> ``ch_out`` in the first stage, ``ch_out`` ->
    ``ch_out`` in the second).

    Args:
        ch_in: Number of channels of the input feature map.
        ch_out: Number of channels produced by the block.
    """

    def __init__(self, ch_in, ch_out):
        super(conv_block, self).__init__()
        # Layer order and arguments (including bias=True) are kept exactly as
        # in the original so the module's parameter names are unchanged.
        self.conv = nn.Sequential(
            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True),
            nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Apply the two conv-BN-ReLU stages to ``x`` and return the result."""
        return self.conv(x)


class up_conv(nn.Module):
    """Upsample by a factor of 2, then apply a 3x3 conv, BatchNorm and ReLU.

    ``nn.Upsample(scale_factor=2)`` is used with its default interpolation
    mode; the following convolution (``kernel_size=3, stride=1, padding=1``)
    keeps the upsampled spatial size and maps ``ch_in`` to ``ch_out`` channels.

    Args:
        ch_in: Number of channels of the input feature map.
        ch_out: Number of channels produced by the block.
    """

    def __init__(self, ch_in, ch_out):
        super(up_conv, self).__init__()
        self.up = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Upsample ``x`` and apply conv-BN-ReLU; return the result."""
        return self.up(x)
class UNet(nn.Module):def __init__(self, img_ch=3, output_ch=1):super(UNet, self).__init__()self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)self.Conv1 = conv_block(ch_in=img_ch, ch_out=64)self.Conv2 = conv_block(ch_in=64, ch_out=128)self.Conv3 = conv_block(ch_in=128, ch_out=256)self.Conv4 = conv_block(ch_in=256, ch_out=512)self.Conv5 = conv_block(ch_in=512, ch_out=1024)self.Up5 = up_conv(ch_in=1024, ch_out=512)self.Up_conv5 = conv_block(ch_in=1024, ch_out=512)self.Up4 = up_conv(ch_in=512, ch_out=256)self.Up_conv4 = conv_block(ch_in=512, ch_out=256)self.Up3 = up_conv(ch_in=256, ch_out=128)self.Up_conv3 = conv_block(ch_in=