Rail Surface Semantic Segmentation (UNet Combined with the ResNet Family)

Dataset Overview

The dataset contains two kinds of files: the rail surface images themselves and the corresponding pixel-level label masks.
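The loader below assumes the RSDDs Type-II folder layout, in which every rail image in Railsurfaceimages has a same-named mask in GroundTruth. The folder names come from the dataset class further down; the pairing check here is an illustrative sketch with a hypothetical local path, not part of the original post.

import os

# Assumed layout (folder names taken from the dataset class below):
#   Type-II RSDDs dataset/
#       Railsurfaceimages/   grayscale rail surface photos
#       GroundTruth/         same-named grayscale defect masks
data_dir = 'Type-II RSDDs dataset'  # hypothetical local path
images = sorted(os.listdir(os.path.join(data_dir, 'Railsurfaceimages')))
masks = sorted(os.listdir(os.path.join(data_dir, 'GroundTruth')))
assert images == masks, "every image needs a same-named mask"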

Imports and Data Preparation

import os

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import Dataset, random_split

# Dataset class
class SemanticSegmentationDataset(Dataset):
    def __init__(self, data_dir, transform=None):
        self.data_dir = data_dir
        self.transform = transform
        self.images_dir = os.path.join(data_dir, 'Railsurfaceimages')
        self.labels_dir = os.path.join(data_dir, 'GroundTruth')
        self.filenames = sorted(os.listdir(self.images_dir))

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        img_name = self.filenames[idx]
        img_path = os.path.join(self.images_dir, img_name)
        label_path = os.path.join(self.labels_dir, img_name)
        image = Image.open(img_path)
        label = Image.open(label_path)
        image = np.array(image)
        label = np.array(label)
        # Add an explicit channel dimension: (H, W) -> (1, H, W)
        image = image.reshape(1, image.shape[0], image.shape[1])
        label = label.reshape(1, label.shape[0], label.shape[1])
        # Binarize the label: gray values <= 122 become 0 (background), > 122 become 1 (defect)
        label[label <= 122] = 0
        label[label > 122] = 1
        return image, label

# Preprocessing pipeline (defined here, though __getitem__ above does not apply it)
transform = transforms.Compose([transforms.ToTensor()])

Loading the Dataset

# Instantiate the dataset and check one sample
data_dir = 'C:/Users/jiaoyang/Desktop/数据集/RSDDs 数据集/RSDDs 数据集/Type-II RSDDs dataset'
dataset = SemanticSegmentationDataset(data_dir=data_dir, transform=transform)
for i, j in dataset:
    print(i.shape)
    print(j.shape)
    break

Splitting the Dataset

# Split into train / validation / test (80% / 10% / 10%), with a fixed seed for reproducibility
val_size = int(len(dataset) * 0.1)
test_size = int(len(dataset) * 0.1)
train_size = len(dataset) - val_size - test_size
train_dataset, val_dataset, test_dataset = random_split(
    dataset, [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(42))
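A one-line sanity check (not in the original post) that the three subsets really partition the data:

# The three subset sizes should sum to the full dataset size
print(len(train_dataset), len(val_dataset), len(test_dataset), len(dataset))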

Building the DataLoaders

# DataLoaders
batch_size = 2
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=0)

Inspecting the Data

for i, j in train_loader:
    print(i.shape)
    print(j.shape)
    # Count how many pixels of each class appear in the batch
    values, counts = torch.unique(j, return_counts=True)
    for value, count in zip(values, counts):
        print(f"{value}: {count}")
    break


# Check the batch tensor shapes
for i, j in train_loader:
    print(i.shape)
    print(j.shape)
    break


Building the UNet

# Network building blocks
class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)


class Down(nn.Module):
    """Downscaling with maxpool then double conv"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)


class Up(nn.Module):
    """Upscaling then double conv"""

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        # Note: bilinear upsampling keeps the channel count, so in this simplified
        # variant the concatenation only matches DoubleConv's input width when
        # bilinear=False (the transposed conv halves the channels); the UNet below
        # therefore defaults to bilinear=False.
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # input is CHW: pad x1 so its spatial size matches x2 before concatenation
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]
        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        # if you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)
class UNet(nn.Module):
    def __init__(self, n_channels, n_classes, bilinear=False):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024)
        self.up1 = Up(1024, 512, bilinear)
        self.up2 = Up(512, 256, bilinear)
        self.up3 = Up(256, 128, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits

A Quick Model Test

# Quick sanity check: a single-channel 1250x55 input should come back at the same size
model = UNet(n_channels=1, n_classes=1)
X = torch.randn(1, 1, 1250, 55)
out = model(X)
print(out.shape)  # torch.Size([1, 1, 1250, 55])


Training Function and Training

Setting the Training Parameters

# Hyperparameters
lr = 0.0001
# Move the model to the GPU before building the optimizer (the training loop
# below sends each batch to CUDA, so the model must live there too)
model = UNet(n_channels=1, n_classes=1).to(device='cuda', dtype=torch.float32)
optimizer = torch.optim.RMSprop(model.parameters(), lr=lr, weight_decay=1e-8, momentum=0.9)
criterion = nn.BCEWithLogitsLoss()
num_epochs = 50
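BCEWithLogitsLoss applies the sigmoid internally, so the network outputs raw logits. A logit of 0 maps to a probability of 0.5, which is why the validation and inference code below binarizes predictions with out >= 0. A minimal sketch of that equivalence:

# Thresholding logits at 0 is the same as thresholding sigmoid probabilities at 0.5
logits = torch.tensor([-2.0, -0.1, 0.0, 0.3, 5.0])
probs = torch.sigmoid(logits)
print(torch.equal(logits >= 0, probs >= 0.5))  # True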

The Training Function

def train(model, criterion, optimizer, train_loader, val_loader, num_epochs, device='cuda'):
    for epoch in range(num_epochs):
        # Training phase
        model.train()
        train_loss = 0.0
        for images, masks in train_loader:
            # Move the batch to the compute device
            images = images.to(device, dtype=torch.float32)
            masks = masks.to(device, dtype=torch.float32)
            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, masks)
            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item() * images.size(0)

        # Validation phase
        model.eval()
        val_loss = 0.0
        num_correct = 0
        num_pixels = 0
        with torch.no_grad():
            for images, masks in val_loader:
                images = images.to(device, dtype=torch.float32)
                masks = masks.to(device, dtype=torch.float32)
                outputs = model(images)
                loss = criterion(outputs, masks)
                val_loss += loss.item() * images.size(0)
                # Binarize the logits: >= 0 corresponds to a sigmoid probability >= 0.5
                outputs[outputs >= 0] = 255
                outputs[outputs < 0] = 0
                outputs[outputs == 255] = 1
                preds = outputs
                num_correct += torch.sum(preds == masks).item()
                num_pixels += torch.numel(preds)

        train_loss /= len(train_dataset)
        val_loss /= len(val_dataset)
        accuracy = num_correct / num_pixels
        # Report the epoch's metrics
        print('Epoch: {}, Train Loss: {:.4f}, Val Loss: {:.4f}, Accuracy: {:.4f}'.format(
            epoch + 1, train_loss, val_loss, accuracy))

Starting Training

train(model, criterion, optimizer, train_loader, val_loader, num_epochs)


Saving, Prediction, and Evaluation Metrics

Saving the Model

# Save the model parameters (state_dict only, not the full module)
PATH = "./data/resnet+unet++.pt"
torch.save(model.state_dict(), PATH)

Loading the Model Parameters

# Load the saved parameters into a fresh model.
# NestedUResnet (the ResNet + UNet++ hybrid) is defined in the last section below;
# the checkpoint must have been produced by a model with the same architecture.
model = NestedUResnet(block=BasicBlock, layers=[3, 4, 6, 3], num_classes=1).to(device='cuda', dtype=torch.float32)
PATH = "./data/resnet+unet++.pt"
model.load_state_dict(torch.load(PATH))
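One hedged note: torch.load restores tensors to the device they were saved from, so loading a CUDA-trained checkpoint on a CPU-only machine needs map_location. A minimal sketch, assuming the same architecture and PATH as above:

# Load on a machine without a GPU
state_dict = torch.load(PATH, map_location=torch.device('cpu'))
model.load_state_dict(state_dict)
model.eval()  # switch to inference mode before predicting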

Predicting and Saving the Result Images

# Run inference on the test set and save the first prediction/label pair as images
for data, label in test_loader:
    data = data.to(device='cuda', dtype=torch.float32)
    out = model(data)
    # Binarize the logits into a 0/255 image
    out[out >= 0] = 255
    out[out < 0] = 0
    out = out[0][0].cpu().detach().numpy()
    label[label == 1] = 255
    label = label[0][0].cpu()
    label = np.array(label)
    cv2.imwrite('./data/label.png', label)
    cv2.imwrite('./data/out.png', out)
    break

# Same again for the second image of the first batch
for data, label in test_loader:
    data = data.to(device='cuda', dtype=torch.float32)
    out = model(data)
    out[out >= 0] = 255
    out[out < 0] = 0
    out = out[1][0].cpu().detach().numpy()
    label[label == 1] = 255
    label = label[1][0].cpu()
    label = np.array(label)
    cv2.imwrite('./data/label2.png', label)
    cv2.imwrite('./data/out2.png', out)
    break

Computing the Confusion Matrix

# Confusion matrix: class 0 denotes white (defect) pixels, treated as the positive class
from sklearn.metrics import confusion_matrix

TP = []
FN = []
FP = []
TN = []
for data, label in test_loader:
    data = data.to(device='cuda', dtype=torch.float32)
    out = model(data)
    out[out >= 0] = 255
    out[out < 0] = 0
    # Remap predictions and labels so the positive (white) class becomes 0
    out[out == 0] = 1
    out[out == 255] = 0
    label[label == 0] = 255
    label[label == 1] = 0
    label[label == 255] = 1
    out = out.view(-1).cpu().detach().numpy()
    label = label.view(-1).cpu().detach().numpy()
    confusion = confusion_matrix(label, out)
    TP.append(confusion[0][0])
    FN.append(confusion[0][1])
    FP.append(confusion[1][0])
    TN.append(confusion[1][1])

# Aggregate the per-batch counts
TP = np.sum(np.array(TP))
FN = np.sum(np.array(FN))
FP = np.sum(np.array(FP))
TN = np.sum(np.array(TN))

Computing the Evaluation Metrics

# Evaluation metrics
# F1 score
Precision = TP / (TP + FP)
Recall = TP / (TP + FN)
F1 = 2 * (Precision * Recall) / (Precision + Recall)
print('F1:{:.4f}'.format(F1))
# Class pixel accuracy for the defect class
cpa1 = TP / (TP + FP)
print('cpa1:{:.4f}'.format(cpa1))
# Class pixel accuracy for the background class
cpa2 = TN / (TN + FN)
print('cpa2:{:.4f}'.format(cpa2))
# MPA (mean pixel accuracy)
mpa = (cpa2 + cpa1) / 2
print('MPA:{:.4f}'.format(mpa))
# PA (overall pixel accuracy)
pa = (TP + TN) / (TP + TN + FP + FN)
print('PA:{:.4f}'.format(pa))
# IoU for the defect class
Iou1 = TP / (TP + FP + FN)
print('Iou1:{:.4f}'.format(Iou1))
# IoU for the background class
Iou2 = TN / (TN + FN + FP)
print('Iou2:{:.4f}'.format(Iou2))
# MIoU (mean IoU)
MIou = (Iou1 + Iou2) / 2
print('MIou:{:.4f}'.format(MIou))
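For binary segmentation the F1 score above is the same quantity as the Dice coefficient, and it relates to IoU by Dice = 2·IoU / (1 + IoU). A quick numeric check using the aggregated counts (not in the original post):

Dice = 2 * TP / (2 * TP + FP + FN)                # Dice coefficient from raw counts
print(abs(Dice - F1) < 1e-9)                      # F1 and Dice coincide for binary masks
print(abs(Dice - 2 * Iou1 / (1 + Iou1)) < 1e-9)   # Dice = 2*IoU / (1 + IoU)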


Building the UNet++ Network

class VGGBlock(nn.Module):
    def __init__(self, in_channels, middle_channels, out_channels):
        super().__init__()
        self.first = nn.Sequential(
            nn.Conv2d(in_channels, middle_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(middle_channels),
            nn.ReLU(inplace=True)
        )
        self.second = nn.Sequential(
            nn.Conv2d(middle_channels, out_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        out = self.first(x)
        out = self.second(out)
        return out


class Up(nn.Module):
    """Upsample x1, then pad it to the spatial size of x2"""

    def __init__(self):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

    def forward(self, x1, x2):
        x1 = self.up(x1)  # upsample the incoming feature map
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]
        # pad to the same spatial size as x2
        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        return x1


class UNetplusplus(nn.Module):
    def __init__(self, num_classes, input_channels=1, deep_supervision=False, **kwargs):
        super().__init__()
        nb_filter = [64, 128, 256, 512, 1024]
        self.deep_supervision = deep_supervision
        self.Up = Up()
        self.pool = nn.MaxPool2d(2, 2)
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        self.conv0_0 = VGGBlock(input_channels, nb_filter[0], nb_filter[0])
        self.conv1_0 = VGGBlock(nb_filter[0], nb_filter[1], nb_filter[1])
        self.conv2_0 = VGGBlock(nb_filter[1], nb_filter[2], nb_filter[2])
        self.conv3_0 = VGGBlock(nb_filter[2], nb_filter[3], nb_filter[3])
        self.conv4_0 = VGGBlock(nb_filter[3], nb_filter[4], nb_filter[4])

        self.conv0_1 = VGGBlock(nb_filter[0] + nb_filter[1], nb_filter[0], nb_filter[0])
        self.conv1_1 = VGGBlock(nb_filter[1] + nb_filter[2], nb_filter[1], nb_filter[1])
        self.conv2_1 = VGGBlock(nb_filter[2] + nb_filter[3], nb_filter[2], nb_filter[2])
        self.conv3_1 = VGGBlock(nb_filter[3] + nb_filter[4], nb_filter[3], nb_filter[3])

        self.conv0_2 = VGGBlock(nb_filter[0]*2 + nb_filter[1], nb_filter[0], nb_filter[0])
        self.conv1_2 = VGGBlock(nb_filter[1]*2 + nb_filter[2], nb_filter[1], nb_filter[1])
        self.conv2_2 = VGGBlock(nb_filter[2]*2 + nb_filter[3], nb_filter[2], nb_filter[2])

        self.conv0_3 = VGGBlock(nb_filter[0]*3 + nb_filter[1], nb_filter[0], nb_filter[0])
        self.conv1_3 = VGGBlock(nb_filter[1]*3 + nb_filter[2], nb_filter[1], nb_filter[1])

        self.conv0_4 = VGGBlock(nb_filter[0]*4 + nb_filter[1], nb_filter[0], nb_filter[0])

        if self.deep_supervision:
            self.final1 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
            self.final2 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
            self.final3 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
            self.final4 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
        else:
            self.final = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)

    def forward(self, input):
        x0_0 = self.conv0_0(input)
        x1_0 = self.conv1_0(self.pool(x0_0))
        x0_1 = self.conv0_1(torch.cat([x0_0, self.Up(x1_0, x0_0)], 1))

        x2_0 = self.conv2_0(self.pool(x1_0))
        x1_1 = self.conv1_1(torch.cat([x1_0, self.Up(x2_0, x1_0)], 1))
        x0_2 = self.conv0_2(torch.cat([x0_0, x0_1, self.Up(x1_1, x0_0)], 1))

        x3_0 = self.conv3_0(self.pool(x2_0))
        x2_1 = self.conv2_1(torch.cat([x2_0, self.Up(x3_0, x2_0)], 1))
        x1_2 = self.conv1_2(torch.cat([x1_0, x1_1, self.Up(x2_1, x1_0)], 1))
        x0_3 = self.conv0_3(torch.cat([x0_0, x0_1, x0_2, self.Up(x1_2, x0_0)], 1))

        x4_0 = self.conv4_0(self.pool(x3_0))
        x3_1 = self.conv3_1(torch.cat([x3_0, self.Up(x4_0, x3_0)], 1))
        x2_2 = self.conv2_2(torch.cat([x2_0, x2_1, self.Up(x3_1, x2_0)], 1))
        x1_3 = self.conv1_3(torch.cat([x1_0, x1_1, x1_2, self.Up(x2_2, x1_0)], 1))
        x0_4 = self.conv0_4(torch.cat([x0_0, x0_1, x0_2, x0_3, self.Up(x1_3, x0_0)], 1))

        if self.deep_supervision:
            # multiple side outputs, one per nested depth
            output1 = self.final1(x0_1)
            output2 = self.final2(x0_2)
            output3 = self.final3(x0_3)
            output4 = self.final4(x0_4)
            return [output1, output2, output3, output4]
        else:
            output = self.final(x0_4)
            return output

A Quick Model Test

# Quick sanity check
model = UNetplusplus(1)
x = torch.rand(1, 1, 1250, 55)
out = model(x)
print(out.shape)  # torch.Size([1, 1, 1250, 55])
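The deep_supervision flag defined above is never exercised in this post. A hedged sketch of how it could be used, averaging the loss over the four side outputs; the uniform weighting here is an assumption, not part of the original code:

# Deep supervision: the model returns four logits maps, one per nested depth
model_ds = UNetplusplus(num_classes=1, deep_supervision=True)
x = torch.rand(1, 1, 1250, 55)
masks = torch.randint(0, 2, (1, 1, 1250, 55)).float()  # dummy target
criterion = nn.BCEWithLogitsLoss()
outputs = model_ds(x)                                   # list of 4 outputs
loss = sum(criterion(o, masks) for o in outputs) / len(outputs)
print(loss.item())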


Building the ResNet + UNet Network

class Up(nn.Module):
    """Upsample x1, then pad it to the spatial size of x2"""

    def __init__(self):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

    def forward(self, x1, x2):
        x1 = self.up(x1)  # upsample the incoming feature map
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]
        # pad to the same spatial size as x2
        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        return x1


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.residual_function = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels * BasicBlock.expansion, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels * BasicBlock.expansion)
        )
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != BasicBlock.expansion * out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels * BasicBlock.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels * BasicBlock.expansion)
            )

    def forward(self, x):
        return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x))


class BottleNeck(nn.Module):
    '''
    expansion is the channel expansion ratio;
    note the actual output channels = middle_channels * BottleNeck.expansion
    '''
    expansion = 4

    def __init__(self, in_channels, middle_channels, stride=1):
        super().__init__()
        self.residual_function = nn.Sequential(
            nn.Conv2d(in_channels, middle_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(middle_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(middle_channels, middle_channels, stride=stride, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(middle_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(middle_channels, middle_channels * BottleNeck.expansion, kernel_size=1, bias=False),
            nn.BatchNorm2d(middle_channels * BottleNeck.expansion),
        )
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != middle_channels * BottleNeck.expansion:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, middle_channels * BottleNeck.expansion, stride=stride, kernel_size=1, bias=False),
                nn.BatchNorm2d(middle_channels * BottleNeck.expansion)
            )

    def forward(self, x):
        return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x))


class VGGBlock(nn.Module):
    def __init__(self, in_channels, middle_channels, out_channels):
        super().__init__()
        self.first = nn.Sequential(
            nn.Conv2d(in_channels, middle_channels, 3, padding=1),
            nn.BatchNorm2d(middle_channels),
            nn.ReLU()
        )
        self.second = nn.Sequential(
            nn.Conv2d(middle_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )

    def forward(self, x):
        out = self.first(x)
        out = self.second(out)
        return out


class UResnet(nn.Module):
    def __init__(self, block, layers, num_classes, input_channels=1):
        super().__init__()
        nb_filter = [64, 128, 256, 512, 1024]
        self.Up = Up()
        self.in_channel = nb_filter[0]
        self.pool = nn.MaxPool2d(2, 2)
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        self.conv0_0 = VGGBlock(input_channels, nb_filter[0], nb_filter[0])
        self.conv1_0 = self._make_layer(block, nb_filter[1], layers[0], 1)
        self.conv2_0 = self._make_layer(block, nb_filter[2], layers[1], 1)
        self.conv3_0 = self._make_layer(block, nb_filter[3], layers[2], 1)
        self.conv4_0 = self._make_layer(block, nb_filter[4], layers[3], 1)

        self.conv3_1 = VGGBlock((nb_filter[3] + nb_filter[4]) * block.expansion, nb_filter[3], nb_filter[3] * block.expansion)
        self.conv2_2 = VGGBlock((nb_filter[2] + nb_filter[3]) * block.expansion, nb_filter[2], nb_filter[2] * block.expansion)
        self.conv1_3 = VGGBlock((nb_filter[1] + nb_filter[2]) * block.expansion, nb_filter[1], nb_filter[1] * block.expansion)
        self.conv0_4 = VGGBlock(nb_filter[0] + nb_filter[1] * block.expansion, nb_filter[0], nb_filter[0])
        self.final = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)

    def _make_layer(self, block, middle_channel, num_blocks, stride):
        '''
        middle_channel: intermediate width; actual output channels = middle_channel * block.expansion
        num_blocks: number of residual blocks in this layer
        '''
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_channel, middle_channel, stride))
            self.in_channel = middle_channel * block.expansion
        return nn.Sequential(*layers)

    def forward(self, input):
        x0_0 = self.conv0_0(input)
        x1_0 = self.conv1_0(self.pool(x0_0))
        x2_0 = self.conv2_0(self.pool(x1_0))
        x3_0 = self.conv3_0(self.pool(x2_0))
        x4_0 = self.conv4_0(self.pool(x3_0))
        x3_1 = self.conv3_1(torch.cat([x3_0, self.Up(x4_0, x3_0)], 1))
        x2_2 = self.conv2_2(torch.cat([x2_0, self.Up(x3_1, x2_0)], 1))
        x1_3 = self.conv1_3(torch.cat([x1_0, self.Up(x2_2, x1_0)], 1))
        x0_4 = self.conv0_4(torch.cat([x0_0, self.Up(x1_3, x0_0)], 1))
        output = self.final(x0_4)
        return output

A Quick Model Test

# Quick sanity check: a UResnet34-style model (BasicBlock with [3, 4, 6, 3])
UResnet34 = UResnet(block=BasicBlock, layers=[3, 4, 6, 3], num_classes=1)
x = torch.rand(1, 1, 1250, 55)
out = UResnet34(x)
print(out.shape)  # torch.Size([1, 1, 1250, 55])


Building the ResNet + UNet++ Network

class VGGBlock(nn.Module):
    def __init__(self, in_channels, middle_channels, out_channels):
        super().__init__()
        self.first = nn.Sequential(
            nn.Conv2d(in_channels, middle_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(middle_channels),
            nn.ReLU(inplace=True)
        )
        self.second = nn.Sequential(
            nn.Conv2d(middle_channels, out_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        out = self.first(x)
        out = self.second(out)
        return out


class Up(nn.Module):
    """Upsample x1, then pad it to the spatial size of x2"""

    def __init__(self):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

    def forward(self, x1, x2):
        x1 = self.up(x1)  # upsample the incoming feature map
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]
        # pad to the same spatial size as x2
        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        return x1


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.residual_function = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels * BasicBlock.expansion, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels * BasicBlock.expansion)
        )
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != BasicBlock.expansion * out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels * BasicBlock.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels * BasicBlock.expansion)
            )

    def forward(self, x):
        return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x))


class BottleNeck(nn.Module):
    expansion = 4

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.residual_function = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, stride=stride, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels * BottleNeck.expansion, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels * BottleNeck.expansion),
        )
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels * BottleNeck.expansion:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels * BottleNeck.expansion, stride=stride, kernel_size=1, bias=False),
                nn.BatchNorm2d(out_channels * BottleNeck.expansion)
            )

    def forward(self, x):
        return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x))


class NestedUResnet(nn.Module):
    def __init__(self, block, layers, num_classes, input_channels=1, deep_supervision=False):
        super().__init__()
        nb_filter = [64, 128, 256, 512, 1024]
        self.in_channels = nb_filter[0]
        self.relu = nn.ReLU()
        self.deep_supervision = deep_supervision
        self.pool = nn.MaxPool2d(2, 2)
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.Up = Up()

        self.conv0_0 = VGGBlock(input_channels, nb_filter[0], nb_filter[0])
        self.conv1_0 = self._make_layer(block, nb_filter[1], layers[0], 1)
        self.conv2_0 = self._make_layer(block, nb_filter[2], layers[1], 1)
        self.conv3_0 = self._make_layer(block, nb_filter[3], layers[2], 1)
        self.conv4_0 = self._make_layer(block, nb_filter[4], layers[3], 1)

        self.conv0_1 = VGGBlock(nb_filter[0] + nb_filter[1] * block.expansion, nb_filter[0], nb_filter[0])
        self.conv1_1 = VGGBlock((nb_filter[1] + nb_filter[2]) * block.expansion, nb_filter[1], nb_filter[1] * block.expansion)
        self.conv2_1 = VGGBlock((nb_filter[2] + nb_filter[3]) * block.expansion, nb_filter[2], nb_filter[2] * block.expansion)
        self.conv3_1 = VGGBlock((nb_filter[3] + nb_filter[4]) * block.expansion, nb_filter[3], nb_filter[3] * block.expansion)

        self.conv0_2 = VGGBlock(nb_filter[0]*2 + nb_filter[1] * block.expansion, nb_filter[0], nb_filter[0])
        self.conv1_2 = VGGBlock((nb_filter[1]*2 + nb_filter[2]) * block.expansion, nb_filter[1], nb_filter[1] * block.expansion)
        self.conv2_2 = VGGBlock((nb_filter[2]*2 + nb_filter[3]) * block.expansion, nb_filter[2], nb_filter[2] * block.expansion)

        self.conv0_3 = VGGBlock(nb_filter[0]*3 + nb_filter[1] * block.expansion, nb_filter[0], nb_filter[0])
        self.conv1_3 = VGGBlock((nb_filter[1]*3 + nb_filter[2]) * block.expansion, nb_filter[1], nb_filter[1] * block.expansion)

        self.conv0_4 = VGGBlock(nb_filter[0]*4 + nb_filter[1] * block.expansion, nb_filter[0], nb_filter[0])

        if self.deep_supervision:
            self.final1 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
            self.final2 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
            self.final3 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
            self.final4 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
        else:
            self.final = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)

    def _make_layer(self, block, middle_channels, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_channels, middle_channels, stride))
            self.in_channels = middle_channels * block.expansion
        return nn.Sequential(*layers)

    def forward(self, input):
        x0_0 = self.conv0_0(input)
        x1_0 = self.conv1_0(self.pool(x0_0))
        x0_1 = self.conv0_1(torch.cat([x0_0, self.Up(x1_0, x0_0)], 1))

        x2_0 = self.conv2_0(self.pool(x1_0))
        x1_1 = self.conv1_1(torch.cat([x1_0, self.Up(x2_0, x1_0)], 1))
        x0_2 = self.conv0_2(torch.cat([x0_0, x0_1, self.Up(x1_1, x0_0)], 1))

        x3_0 = self.conv3_0(self.pool(x2_0))
        x2_1 = self.conv2_1(torch.cat([x2_0, self.Up(x3_0, x2_0)], 1))
        x1_2 = self.conv1_2(torch.cat([x1_0, x1_1, self.Up(x2_1, x1_0)], 1))
        x0_3 = self.conv0_3(torch.cat([x0_0, x0_1, x0_2, self.Up(x1_2, x0_0)], 1))

        x4_0 = self.conv4_0(self.pool(x3_0))
        x3_1 = self.conv3_1(torch.cat([x3_0, self.Up(x4_0, x3_0)], 1))
        x2_2 = self.conv2_2(torch.cat([x2_0, x2_1, self.Up(x3_1, x2_0)], 1))
        x1_3 = self.conv1_3(torch.cat([x1_0, x1_1, x1_2, self.Up(x2_2, x1_0)], 1))
        x0_4 = self.conv0_4(torch.cat([x0_0, x0_1, x0_2, x0_3, self.Up(x1_3, x0_0)], 1))

        if self.deep_supervision:
            output1 = self.final1(x0_1)
            output2 = self.final2(x0_2)
            output3 = self.final3(x0_3)
            output4 = self.final4(x0_4)
            return [output1, output2, output3, output4]
        else:
            output = self.final(x0_4)
            return output

A Quick Model Test

# Quick sanity check with BottleNeck blocks (a ResNet-50-style encoder)
model = NestedUResnet(block=BottleNeck, layers=[3, 4, 6, 3], num_classes=1)
x = torch.rand(1, 1, 1250, 55)
out = model(x)
print(out.shape)  # torch.Size([1, 1, 1250, 55])
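Since all three networks share the UNet's input/output contract, the train function from earlier can be reused unchanged. A hedged sketch, with the hyperparameters copied from the UNet run and CUDA assumed (the original post does not show this step):

# Train the ResNet-34-style hybrid with the same loop as before
model = NestedUResnet(block=BasicBlock, layers=[3, 4, 6, 3], num_classes=1).to(device='cuda', dtype=torch.float32)
optimizer = torch.optim.RMSprop(model.parameters(), lr=0.0001, weight_decay=1e-8, momentum=0.9)
criterion = nn.BCEWithLogitsLoss()
train(model, criterion, optimizer, train_loader, val_loader, num_epochs=50)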

