基于U-Net的视网膜血管分割(Pytorch完整版)

基于 U-Net 的视网膜血管分割是一种应用深度学习的方法,特别是 U-Net 结构,用于从眼底图像中分割出视网膜血管。U-Net 是一种全卷积神经网络(FCN),通常用于图像分割任务。以下是基于 U-Net 的视网膜血管分割的内容:
框架结构:
在这里插入图片描述
代码结构:
在这里插入图片描述

U-Net分割代码:

unet_model.py

import torch.nn.functional as F
from .unet_parts import *
class UNet(nn.Module):def __init__(self, n_channels, n_classes, bilinear=True):super(UNet, self).__init__()self.n_channels = n_channelsself.n_classes = n_classesself.bilinear = bilinearself.inc = DoubleConv(n_channels, 64)self.down1 = Down(64, 128)self.down2 = Down(128, 256)self.down3 = Down(256, 512)self.down4 = Down(512, 512)self.up1 = Up(1024, 256, bilinear)self.up2 = Up(512, 128, bilinear)self.up3 = Up(256, 64, bilinear)self.up4 = Up(128, 64, bilinear)self.outc = OutConv(64, n_classes)def forward(self, x):x1 = self.inc(x)# 在编码器下采样过程加空间注意力# x2 = self.down1(self.sp1(x1))# x3 = self.down2(self.sp2(x2))# x4 = self.down3(self.sp3(x3))# x5 = self.down4(self.sp4(x4))x2 = self.down1(x1)x3 = self.down2(x2)x4 = self.down3(x3)x5 = self.down4(x4)x = self.up1(x5, x4)x = self.up2(x, x3)x = self.up3(x, x2)x = self.up4(x, x1)logits = self.outc(x)return logitsif __name__ == '__main__':net = UNet(n_channels=3, n_classes=1)print(net)

unet_parts.py

import torch
import torch.nn as nn
import torch.nn.functional as Fclass DoubleConv(nn.Module):"""(convolution => [BN] => ReLU) * 2"""def __init__(self, in_channels, out_channels):super().__init__()self.double_conv = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True),nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True))for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight)elif isinstance(m, nn.BatchNorm2d):nn.init.constant_(m.weight, 1)nn.init.constant_(m.bias, 0)elif isinstance(m, nn.Linear):nn.init.constant_(m.bias, 0)def forward(self, x):return self.double_conv(x)class Down(nn.Module):"""Downscaling with maxpool then double conv"""def __init__(self, in_channels, out_channels):super().__init__()self.maxpool_conv = nn.Sequential(nn.MaxPool2d(2),DoubleConv(in_channels, out_channels))def forward(self, x):return self.maxpool_conv(x)class Up(nn.Module):"""Upscaling then double conv"""def __init__(self, in_channels, out_channels, bilinear=True):super().__init__()# if bilinear, use the normal convolutions to reduce the number of channelsif bilinear:self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)else:self.up = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, kernel_size=2, stride=2)self.conv = DoubleConv(in_channels, out_channels)def forward(self, x1, x2):x1 = self.up(x1)# input is CHWdiffY = torch.tensor([x2.size()[2] - x1.size()[2]])diffX = torch.tensor([x2.size()[3] - x1.size()[3]])x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,diffY // 2, diffY - diffY // 2])x = torch.cat([x2, x1], dim=1)return self.conv(x)class OutConv(nn.Module):def __init__(self, in_channels, out_channels):super(OutConv, self).__init__()self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight)elif isinstance(m, nn.BatchNorm2d):nn.init.constant_(m.weight, 1)nn.init.constant_(m.bias, 0)elif isinstance(m, nn.Linear):nn.init.constant_(m.bias, 0)def forward(self, x):return self.conv(x)

trainval.py

from model.unet_model import UNet
from utils.dataset import FundusSeg_Loaderfrom torch import optim
import torch.nn as nn
import torch
import sys
import matplotlib.pyplot as plt
from tqdm import tqdm
import timetrain_data_path = "DRIVE/drive_train/"
valid_data_path = "DRIVE/drive_test/"
# hyperparameter-settings
N_epochs = 500
Init_lr = 0.00001def train_net(net, device, epochs=N_epochs, batch_size=1, lr=Init_lr):# 加载训练集train_dataset = FundusSeg_Loader(train_data_path, 1)valid_dataset = FundusSeg_Loader(valid_data_path, 0)train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)valid_loader = torch.utils.data.DataLoader(dataset=valid_dataset, batch_size=batch_size, shuffle=False)print('Traing images: %s' % len(train_loader))print('Valid  images: %s' % len(valid_loader))# 定义RMSprop算法optimizer = optim.RMSprop(net.parameters(), lr=lr, weight_decay=1e-8, momentum=0.9)# 定义Loss算法# BCEWithLogitsLoss会对predict进行sigmoid处理# criterion 常被用来定义损失函数,方便调换损失函数criterion = nn.BCEWithLogitsLoss()# 训练epochs次# 求最小值,所以初始化为正无穷best_loss = float('inf')train_loss_list = []val_loss_list = []for epoch in range(epochs):# 训练模式net.train()train_loss = 0print(f'Epoch {epoch + 1}/{epochs}')# SGD# train_loss_list = []# val_loss_list = []with tqdm(total=train_loader.__len__()) as pbar:for i, (image, label, filename) in enumerate(train_loader):optimizer.zero_grad()# 将数据拷贝到device中image = image.to(device=device, dtype=torch.float32)label = label.to(device=device, dtype=torch.float32)# 使用网络参数,输出预测结果pred = net(image)# print(pred)# 计算lossloss = criterion(pred, label)# print(loss)train_loss = train_loss + loss.item()loss.backward()optimizer.step()pbar.set_postfix(loss=float(loss.cpu()), epoch=epoch)pbar.update(1)train_loss_list.append(train_loss / i)print('Loss/train', train_loss / i)# Validationnet.eval()val_loss = 0for i, (image, label, filename) in tqdm(enumerate(valid_loader), total=len(valid_loader)):image = image.to(device=device, dtype=torch.float32)label = label.to(device=device, dtype=torch.float32)pred = net(image)loss = criterion(pred, label)val_loss = val_loss + loss.item()# net.state_dict()就是用来保存模型参数的if val_loss < best_loss:best_loss = val_losstorch.save(net.state_dict(), 'best_model.pth')print('saving model............................................')val_loss_list.append(val_loss / i)print('Loss/valid', val_loss / i)sys.stdout.flush()return val_loss_list, train_loss_listif __name__ == "__main__":# 选择设备cudadevice = torch.device('cuda')# 加载网络,图片单通道1,分类为1。net = UNet(n_channels=3, n_classes=1)# 将网络拷贝到deivce中net.to(device=device)# 开始训练val_loss_list, train_loss_list = train_net(net, device)# 保存loss值到txt文件fileObject1 = open('train_loss.txt', 'w')for train_loss in train_loss_list:fileObject1.write(str(train_loss))fileObject1.write('\n')fileObject1.close()fileObject2 = open('val_loss.txt', 'w')for val_loss in val_loss_list:fileObject2.write(str(val_loss))fileObject2.write('\n')fileObject2.close()# 我这里迭代了5次,所以x的取值范围为(0,5),然后再将每次相对应的5损失率附在x上x = range(0, N_epochs)y1 = val_loss_listy2 = train_loss_list# 两行一列第一个plt.subplot(1, 1, 1)plt.plot(x, y1, 'r.-', label=u'val_loss')plt.plot(x, y2, 'g.-', label =u'train_loss')plt.title('loss')plt.xlabel('epochs')plt.ylabel('loss')plt.savefig("accuracy_loss.jpg")plt.show()

predict.py

import numpy as np
import torch
import cv2
from model.unet_model import UNetfrom utils.dataset import FundusSeg_Loader
import copy
from sklearn.metrics import roc_auc_scoremodel_path='./best_model.pth'
test_data_path = "DRIVE/drive_test/"
save_path='./results/'if __name__ == "__main__":test_dataset = FundusSeg_Loader(test_data_path,0)test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=1, shuffle=False)print('Testing images: %s' %len(test_loader))# 选择设备CUDAdevice = torch.device('cuda')# 加载网络,图片单通道,分类为1。net = UNet(n_channels=3, n_classes=1)# 将网络拷贝到deivce中net.to(device=device)# 加载模型参数print(f'Loading model {model_path}')net.load_state_dict(torch.load(model_path, map_location=device))# 测试模式net.eval()tp = 0tn = 0fp = 0fn = 0pred_list = []label_list = []for image, label, filename in test_loader:image = image.to(device=device, dtype=torch.float32)pred = net(image)# Normalize to [0, 1]pred = torch.sigmoid(pred)pred = np.array(pred.data.cpu()[0])[0]pred_list.append(pred)# ConfusionMAtrixpred_bin = copy.deepcopy(pred)label = np.array(label.data.cpu()[0])[0]label_list.append(label)pred_bin[pred_bin >= 0.5] = 1pred_bin[pred_bin < 0.5] = 0tp += ((pred_bin == 1) & (label == 1)).sum()tn += ((pred_bin == 0) & (label == 0)).sum()fn += ((pred_bin == 0) & (label == 1)).sum()fp += ((pred_bin == 1) & (label == 0)).sum()# 保存图片pred = pred * 255save_filename = save_path + filename[0] + '.png'cv2.imwrite(save_filename, pred)print(f'{save_filename} done!')# Evaluaiton Indicatorsprecision = tp / (tp + fp)   # 预测为真并且正确/预测正确样本总和sen = tp / (tp + fn)    # 预测为真并且正确/正样本总和spe = tn / (tn + fp)acc = (tp + tn) / (tp + tn + fp + fn)f1score = 2 * precision * sen / (precision + sen)# auc computingpred_auc = np.stack(pred_list, axis=0)label_auc = np.stack(label_list, axis=0)auc = roc_auc_score(label_auc.reshape(-1), pred_auc.reshape(-1))print(f'Precision: {precision} Sen: {sen} Spe:{spe} F1-score: {f1score} Acc: {acc} AUC: {auc}')

dataset.py

import torch
import cv2
import os
import glob
from torch.utils.data import Dataset# import random
# from PIL import Image
# import numpy as npclass FundusSeg_Loader(Dataset):def __init__(self, data_path, is_train):# 初始化函数,读取所有data_path下的图片self.data_path = data_pathself.imgs_path = glob.glob(os.path.join(data_path, 'image/*.tif'))self.labels_path = glob.glob(os.path.join(data_path, 'label/*.tif'))self.is_train = is_trainprint(self.imgs_path)print(self.labels_path)def __getitem__(self, index):# 根据index读取图片image_path = self.imgs_path[index]if self.is_train == 1:label_path = image_path.replace('image', 'label')label_path = label_path.replace('training', 'manual1')else:label_path = image_path.replace('image', 'label')label_path = label_path.replace('test.tif', 'manual1.tif')# 读取训练图片和标签图片image = cv2.imread(image_path)label = cv2.imread(label_path)# image = np.array(image)# label = np.array(label)# label = cv2.imread(label_path)# image = cv2.resize(image, (600,400))# label = cv2.resize(label, (600,400))# 转为单通道的图片# image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)label = cv2.cvtColor(label, cv2.COLOR_BGR2GRAY)# label = Image.fromarray(label)# label = label.convert("1")# reshape()函数可以改变数组的形状,并且原始数据不发生变化。image = image.transpose(2, 0, 1)# image = image.reshape(1, label.shape[0], label.shape[1])label = label.reshape(1, label.shape[0], label.shape[1])# 处理标签,将像素值为255的改为1if label.max() > 1:label[label > 1] = 1return image, label, image_path[len(image_path)-12:len(image_path)-4]def __len__(self):# 返回训练集大小return len(self.imgs_path)

visual.py

import numpy as np
import matplotlib.pyplot as plt
import pylab as plfrom mpl_toolkits.axes_grid1.inset_locator import inset_axes
data1_loss =np.loadtxt("E:\\code\\UNet_lr00001\\train_loss.txt",dtype=str )
data2_loss = np.loadtxt("E:\\code\\UNet_lr00001\\val_loss.txt",dtype=str)
x = range(0,10)
y = data1_loss[:, 0]
x1 = range(0,10)
y1 = data2_loss[:, 0]
fig = plt.figure(figsize = (7,5))    #figsize是图片的大小`
ax1 = fig.add_subplot(1, 1, 1) # ax1是子图的名字`
pl.plot(x,y,'g-',label=u'Dense_Unet(block layer=5)')
# ‘'g‘'代表“green”,表示画出的曲线是绿色,“-”代表画的曲线是实线,可自行选择,label代表的是图例的名称,一般要在名称前面加一个u,如果名称是中文,会显示不出来,目前还不知道怎么解决。
p2 = pl.plot(x, y,'r-', label = u'train_loss')
pl.legend()
#显示图例
p3 = pl.plot(x1,y1, 'b-', label = u'val_loss')
pl.legend()
pl.xlabel(u'epoch')
pl.ylabel(u'loss')
plt.title('Compare loss for different models in training')

在这里插入图片描述
在这里插入图片描述

这种基于 U-Net 的方法已在医学图像分割领域取得了一些成功,特别是在视网膜图像处理中。通过深度学习的方法,这种技术能够更准确地提取视网膜血管,为眼科医生提供辅助诊断和治疗的信息。
如有疑问,请评论。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/172044.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

mysql高级知识点

一、mysql架构 连接层&#xff1a;负责接收客户端的连接请求&#xff0c;可以进行授权、认证(验证账号密码)。服务层&#xff1a;负责调用sql接口&#xff0c;对sql语法进行解析&#xff0c;对查询进行优化&#xff0c;缓存。引擎层&#xff1a;是真正进行执行sql的地方&#x…

Linux面试题(二)

目录 17、怎么使一个命令在后台运行? 18、利用 ps 怎么显示所有的进程? 怎么利用 ps 查看指定进程的信息&#xff1f; 19、哪个命令专门用来查看后台任务? 20、把后台任务调到前台执行使用什么命令?把停下的后台任务在后台执行起来用什么命令? 21、终止进程用什么命令…

Vue框架学习笔记——事件修饰符

文章目录 前文提要事件修饰符prevent&#xff08;常用&#xff09;stop&#xff08;不常用&#xff09;事件冒泡stop使用方法三层嵌套下的stop三层嵌套看出的stop&#xff1a; once&#xff08;常用&#xff09;capture&#xff08;不常用&#xff09;self&#xff08;不常用&a…

Vue轻松入门,附带学习笔记和相关案例

目录 一Vue基础 什么是Vue&#xff1f; 补充&#xff1a;mvvm框架 mvvm的组成 详解 Vue的使用方法 1.直接下载并引入 2.通过 CDN 使用 Vue 3.通过npm安装 4.使用Vue CLI创建项目 二插值表达式 什么是插值表达式&#xff1f; 插值表达式的缺点 解决方法 相关代…

【数据结构】树与二叉树(廿五):树搜索指定数据域的结点(算法FindTarget)

文章目录 5.3.1 树的存储结构5. 左儿子右兄弟链接结构 5.3.2 获取结点的算法1. 获取大儿子、大兄弟结点2. 搜索给定结点的父亲3. 搜索指定数据域的结点a. 算法FindTargetb. 算法解析c. 代码实现a. 使用指向指针的指针b. 直接返回找到的节点 4. 代码整合 5.3.1 树的存储结构 5.…

VUE限制文件上传大小和上传格式

<el-form-item label"图片&#xff1a;" prop"tempImagePath"><el-uploadclass"upload"accept"image/jpeg":show-file-list"false"list-type"picture-card":headers"{ token: token}":action&…

linux的netstat命令和ss命令

1. 网络状态 State状态LISTENING监听中&#xff0c;服务端需要打开一个socket进行监听&#xff0c;侦听来自远方TCP端口的连接请求ESTABLISHED已连接&#xff0c;代表一个打开的连接&#xff0c;双方可以进行或已经在数据交互了SYN_SENT客户端通过应用程序调用connect发送一个…

人力资源管理后台 === 基础环境+登陆

目录 1.人力资源项目介绍 1.1 项目架构和解决方案 1.2 课程安排 1.3 课程具备能力 1.4 课程地址 2. 拉取项目基础代码 3.项目目录和入口文件介绍 4.App.vue根组件解析 5.基础设置settings.js和导航守卫permission.js 6.Vuex的结构 7.使用模板中的Icon图标 8.扩展…

最新世界银行WDI面板数据(1960-2022年)

The World Development Indicators 是由世界银行编制和发布的全面数据集&#xff0c;旨在提供全球发展的详尽统计信息。这份数据集收录了1960-2022年间&#xff0c;世界266个国家共计1477个指标&#xff0c;涵盖经济、社会、环境、教育、公共卫生等20个领域 一、数据介绍 数据…

chromium通信系统-mojo系统(一)-ipcz系统代码实现-同Node通信

在chromium通信系统-mojo系统(一)-ipcz系统基本概念一文中我们介绍了ipcz的基本概念。 本章我们来通过代码分析它的实现。 handle系统 为了不对上层api暴露太多细节&#xff0c;实现解耦&#xff0c;也方便于传输&#xff0c;ipcz系统使用handle表示一个对象&#xff0c;hand…

MySQL基本SQL语句(下)

MySQL基本SQL语句&#xff08;下&#xff09; 一、扩展常见的数据类型 1、回顾数据表的创建语法 基本语法&#xff1a; mysql> create table 数据表名称(字段名称1 字段类型 字段约束,字段名称2 字段类型 字段约束,...primary key(主键字段 > 不能为空、必须唯一) ) …

WebSocket协议测试实战

当涉及到WebSocket协议测试时&#xff0c;有几个关键方面需要考虑。在本文中&#xff0c;我们将探讨如何使用Python编写WebSocket测试&#xff0c;并使用一些常见的工具和库来简化测试过程。 1、什么是WebSocket协议&#xff1f; WebSocket是一种在客户端和服务器之间提供双向…

KubeVela核心控制器原理浅析

前言 在学习 KubeVela 的核心控制器之前&#xff0c;我们先简单了解一下 KubeVela 的相关知识。 KubeVela 本身是一个应用交付与管理控制平面&#xff0c;它架在 Kubernetes 集群、云平台等基础设施之上&#xff0c;通过开放应用模型来对组件、云服务、运维能力、交付工作流进…

4G模块(EC600N)通过MQTT连接华为云

目录 一、前言 二、EC600N模块使用 1&#xff0e;透传模式 2&#xff0e;非透传模式 3、华为云的MQTT使用教程&#xff1a; 三、具体连接步骤 1、初始化检测 2、打开MQTT客户端网络 3、创建产品 4、创建模型 5、注册设备 6、连接客户端到MQTT服务器 7、发布主题消…

Redis面试题:Redis的数据过期策略有哪些?

目录 面试官&#xff1a;Redis的数据过期策略有哪些 ? 惰性删除 定期删除 面试官&#xff1a;Redis的数据过期策略有哪些 ? 候选人&#xff1a; 嗯~&#xff0c;在redis中提供了两种数据过期删除策略 第一种是惰性删除&#xff0c;在设置该key过期时间后&#xff0c;我们…

Stm32CubeMx生成代码提示缺少“core_cm3.h“

Stm32CubeMx生成代码提示缺少"core_cm3.h" 1.原因分析 1.1问题根源 在我们使用本地解压的方法去安装固件包,但是找错了要下载的固件包&#x1f60a;.在你点击进入下载页面之后,能看到一共有两个下载链接,其中上面的是补丁包,而第二个才是我们应该要下载的固件包 当…

【Web-Note】 JavaScript概述

JavaSript基本语法 JavaSript程序不能独立运行&#xff0c;必须依赖于HTML文件。 <script type "text/javascript" [src "外部文件"]> JS语句块; </script> script标记是成对标记。 type属性&#xff1a;说明脚本的类型。 "text/jav…

王者农药小游戏

游戏运行如下&#xff1a; sxt Background package sxt;import java.awt.*; //背景类 public class Background extends GameObject{public Background(GameFrame gameFrame) {super(gameFrame);}Image bg Toolkit.getDefaultToolkit().getImage("C:\\Users\\24465\\D…

【数据分享】我国12.5米分辨率的坡向数据(免费获取)

地形数据&#xff0c;也叫DEM数据&#xff0c;是我们在各项研究中最常使用的数据之一。之前我们分享过源于NASA地球科学数据网站发布的12.5米分辨率DEM地形数据&#xff01;基于该数据我们处理得到12.5米分辨率的坡度数据、12.5米分辨率的山体阴影数据&#xff08;均可查看之前…

【Hadoop】分布式文件系统 HDFS

目录 一、介绍二、HDFS设计原理2.1 HDFS 架构2.2 数据复制复制的实现原理 三、HDFS的特点四、图解HDFS存储原理1. 写过程2. 读过程3. HDFS故障类型和其检测方法故障类型和其检测方法读写故障的处理DataNode 故障处理副本布局策略 一、介绍 HDFS &#xff08;Hadoop Distribute…