深度学习实战基础案例——卷积神经网络(CNN)基于SqueezeNet的眼疾识别|第1例

文章目录

  • 前言
  • 一、数据准备
      • 1.1 数据集介绍
      • 1.2 数据集文件结构
  • 二、项目实战
      • 2.1 数据标签划分
      • 2.2 数据预处理
      • 2.3 构建模型
      • 2.4 开始训练
      • 2.5 结果可视化
  • 三、数据集个体预测

前言

SqueezeNet是一种轻量且高效的CNN模型,它参数比AlexNet少50倍,但模型性能(accuracy)与AlexNet接近。顾名思义,Squeeze的中文意思是压缩和挤压的意思,所以我们通过算法的名字就可以猜想到,该算法一定是通过压缩模型来降低模型参数量的。当然任何算法的改进都是在原先的基础上提升精度或者降低模型参数,因此该算法的主要目的就是在于降低模型参数量的同时保持模型精度。


我的环境:

  • 基础环境:python3.7
  • 编译器:pycharm
  • 深度学习框架:pytorch
  • 数据集代码获取:链接(提取码:2357 )

一、数据准备

本案例使用的数据集是眼疾识别数据集iChallenge-PM。

1.1 数据集介绍

iChallenge-PM是百度大脑和中山大学中山眼科中心联合举办的iChallenge比赛中,提供的关于病理性近视(Pathologic Myopia,PM)的医疗类数据集,包含1200个受试者的眼底视网膜图片,训练、验证和测试数据集各400张。

  • training.zip:包含训练中的图片和标签
  • validation.zip:包含验证集的图片
  • valid_gt.zip:包含验证集的标签

该数据集是从AI Studio平台中下载的,具体信息如下:
在这里插入图片描述

1.2 数据集文件结构

数据集中共有三个压缩文件,分别是:

  • training.zip
├── PALM-Training400
│   ├── PALM-Training400.zip
│   │   ├── H0002.jpg
│   │   └── ...
│   ├── PALM-Training400-Annotation-D&F.zip
│   │   └── ...
│   └── PALM-Training400-Annotation-Lession.zip└── ...
  • valid_gt.zip:标记的位置 里面的PM_Lable_and_Fovea_Location.xlsx就是标记文件
├── PALM-Validation-GT
│   ├── Lession_Masks
│   │   └── ...
│   ├── Disc_Masks
│   │   └── ...
│   └── PM_Lable_and_Fovea_Location.xlsx
  • validation.zip:测试数据集
├── PALM-Validation
│   ├── V0001.jpg
│   ├── V0002.jpg
│   └── ...

二、项目实战

项目结构如下:
在这里插入图片描述

2.1 数据标签划分

该眼疾数据集格式有点复杂,这里我对数据集进行了自己的处理,将训练集和验证集写入txt文本里面,分别对应它的图片路径和标签。

import os
import pandas as pd
# 将训练集划分标签
train_dataset = r"F:\SqueezeNet\data\PALM-Training400\PALM-Training400"
train_list = []
label_list = []train_filenames = os.listdir(train_dataset)for name in train_filenames:filepath = os.path.join(train_dataset, name)train_list.append(filepath)if name[0] == 'N' or name[0] == 'H':label = 0label_list.append(label)elif name[0] == 'P':label = 1label_list.append(label)else:raise('Error dataset!')with open('F:/SqueezeNet/train.txt', 'w', encoding='UTF-8') as f:i = 0for train_img in train_list:f.write(str(train_img) + ' ' +str(label_list[i]))i += 1f.write('\n')
# 将验证集划分标签
valid_dataset = r"F:\SqueezeNet\data\PALM-Validation400"
valid_filenames = os.listdir(valid_dataset)
valid_label = r"F:\SqueezeNet\data\PALM-Validation-GT\PM_Label_and_Fovea_Location.xlsx"
data = pd.read_excel(valid_label)
valid_data = data[['imgName', 'Label']].values.tolist()with open('F:/SqueezeNet/valid.txt', 'w', encoding='UTF-8') as f:for valid_img in valid_data:f.write(str(valid_dataset) + '/' + valid_img[0] + ' ' + str(valid_img[1]))f.write('\n')

2.2 数据预处理

这里采用到的数据预处理,主要有调整图像大小、随机翻转、归一化等。

import os.path
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import transformstransform_BZ = transforms.Normalize(mean=[0.5, 0.5, 0.5],std=[0.5, 0.5, 0.5]
)class LoadData(Dataset):def __init__(self, txt_path, train_flag=True):self.imgs_info = self.get_images(txt_path)self.train_flag = train_flagself.train_tf = transforms.Compose([transforms.Resize(224),  # 调整图像大小为224x224transforms.RandomHorizontalFlip(),  # 随机左右翻转图像transforms.RandomVerticalFlip(),  # 随机上下翻转图像transforms.ToTensor(),  # 将PIL Image或numpy.ndarray转换为tensor,并归一化到[0,1]之间transform_BZ  # 执行某些复杂变换操作])self.val_tf = transforms.Compose([transforms.Resize(224),  # 调整图像大小为224x224transforms.ToTensor(),  # 将PIL Image或numpy.ndarray转换为tensor,并归一化到[0,1]之间transform_BZ  # 执行某些复杂变换操作])def get_images(self, txt_path):with open(txt_path, 'r', encoding='utf-8') as f:imgs_info = f.readlines()imgs_info = list(map(lambda x: x.strip().split(' '), imgs_info))return imgs_infodef padding_black(self, img):w, h = img.sizescale = 224. / max(w, h)img_fg = img.resize([int(x) for x in [w * scale, h * scale]])size_fg = img_fg.sizesize_bg = 224img_bg = Image.new("RGB", (size_bg, size_bg))img_bg.paste(img_fg, ((size_bg - size_fg[0]) // 2,(size_bg - size_fg[1]) // 2))img = img_bgreturn imgdef __getitem__(self, index):img_path, label = self.imgs_info[index]img_path = os.path.join('', img_path)img = Image.open(img_path)img = img.convert("RGB")img = self.padding_black(img)if self.train_flag:img = self.train_tf(img)else:img = self.val_tf(img)label = int(label)return img, labeldef __len__(self):return len(self.imgs_info)

2.3 构建模型

import torch
import torch.nn as nn
import torch.nn.init as initclass Fire(nn.Module):def __init__(self, inplanes, squeeze_planes,expand1x1_planes, expand3x3_planes):super(Fire, self).__init__()self.inplanes = inplanesself.squeeze = nn.Conv2d(inplanes, squeeze_planes, kernel_size=1)self.squeeze_activation = nn.ReLU(inplace=True)self.expand1x1 = nn.Conv2d(squeeze_planes, expand1x1_planes,kernel_size=1)self.expand1x1_activation = nn.ReLU(inplace=True)self.expand3x3 = nn.Conv2d(squeeze_planes, expand3x3_planes,kernel_size=3, padding=1)self.expand3x3_activation = nn.ReLU(inplace=True)def forward(self, x):x = self.squeeze_activation(self.squeeze(x))return torch.cat([self.expand1x1_activation(self.expand1x1(x)),self.expand3x3_activation(self.expand3x3(x))], 1)class SqueezeNet(nn.Module):def __init__(self, version='1_0', num_classes=1000):super(SqueezeNet, self).__init__()self.num_classes = num_classesif version == '1_0':self.features = nn.Sequential(nn.Conv2d(3, 96, kernel_size=7, stride=2),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),Fire(96, 16, 64, 64),Fire(128, 16, 64, 64),Fire(128, 32, 128, 128),nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),Fire(256, 32, 128, 128),Fire(256, 48, 192, 192),Fire(384, 48, 192, 192),Fire(384, 64, 256, 256),nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),Fire(512, 64, 256, 256),)elif version == '1_1':self.features = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, stride=2),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),Fire(64, 16, 64, 64),Fire(128, 16, 64, 64),nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),Fire(128, 32, 128, 128),Fire(256, 32, 128, 128),nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),Fire(256, 48, 192, 192),Fire(384, 48, 192, 192),Fire(384, 64, 256, 256),Fire(512, 64, 256, 256),)else:# FIXME: Is this needed? SqueezeNet should only be called from the# FIXME: squeezenet1_x() functions# FIXME: This checking is not done for the other modelsraise ValueError("Unsupported SqueezeNet version {version}:""1_0 or 1_1 expected".format(version=version))# Final convolution is initialized differently from the restfinal_conv = nn.Conv2d(512, self.num_classes, kernel_size=1)self.classifier = nn.Sequential(nn.Dropout(p=0.5),final_conv,nn.ReLU(inplace=True),nn.AdaptiveAvgPool2d((1, 1)))for m in self.modules():if isinstance(m, nn.Conv2d):if m is final_conv:init.normal_(m.weight, mean=0.0, std=0.01)else:init.kaiming_uniform_(m.weight)if m.bias is not None:init.constant_(m.bias, 0)def forward(self, x):x = self.features(x)x = self.classifier(x)return torch.flatten(x, 1)

2.4 开始训练

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from model import SqueezeNet
import torchsummary
from dataloader import LoadData
import copydevice = "cuda:0" if torch.cuda.is_available() else "cpu"
print("Using {} device".format(device))model = SqueezeNet(num_classes=2).to(device)
# print(model)
#print(torchsummary.summary(model, (3, 224, 224), 1))# 加载训练集和验证集
train_data = LoadData(r"F:\SqueezeNet\train.txt", True)
train_dl = torch.utils.data.DataLoader(train_data, batch_size=16, pin_memory=True,shuffle=True, num_workers=0)
test_data = LoadData(r"F:\SqueezeNet\valid.txt", True)
test_dl = torch.utils.data.DataLoader(test_data, batch_size=16, pin_memory=True,shuffle=True, num_workers=0)# 编写训练函数
def train(dataloader, model, loss_fn, optimizer):size = len(dataloader.dataset)  # 训练集的大小num_batches = len(dataloader)  # 批次数目, (size/batch_size,向上取整)print('num_batches:', num_batches)train_loss, train_acc = 0, 0  # 初始化训练损失和正确率for X, y in dataloader:  # 获取图片及其标签X, y = X.to(device), y.to(device)# 计算预测误差pred = model(X)  # 网络输出loss = loss_fn(pred, y)  # 计算网络输出和真实值之间的差距,targets为真实值,计算二者差值即为损失# 反向传播optimizer.zero_grad()  # grad属性归零loss.backward()  # 反向传播optimizer.step()  # 每一步自动更新# 记录acc与losstrain_acc += (pred.argmax(1) == y).type(torch.float).sum().item()train_loss += loss.item()train_acc /= sizetrain_loss /= num_batchesreturn train_acc, train_loss# 编写验证函数
def test(dataloader, model, loss_fn):size = len(dataloader.dataset)  # 测试集的大小num_batches = len(dataloader)  # 批次数目, (size/batch_size,向上取整)test_loss, test_acc = 0, 0# 当不进行训练时,停止梯度更新,节省计算内存消耗with torch.no_grad():for imgs, target in dataloader:imgs, target = imgs.to(device), target.to(device)# 计算losstarget_pred = model(imgs)loss = loss_fn(target_pred, target)test_loss += loss.item()test_acc += (target_pred.argmax(1) == target).type(torch.float).sum().item()test_acc /= sizetest_loss /= num_batchesreturn test_acc, test_loss# 开始训练epochs = 20train_loss = []
train_acc = []
test_loss = []
test_acc = []best_acc = 0  # 设置一个最佳准确率,作为最佳模型的判别指标loss_function = nn.CrossEntropyLoss()  # 定义损失函数
optimizer = torch.optim.Adam(model.parameters(), lr=0.001) # 定义Adam优化器for epoch in range(epochs):model.train()epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_function, optimizer)model.eval()epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_function)# 保存最佳模型到 best_modelif epoch_test_acc > best_acc:best_acc = epoch_test_accbest_model = copy.deepcopy(model)train_acc.append(epoch_train_acc)train_loss.append(epoch_train_loss)test_acc.append(epoch_test_acc)test_loss.append(epoch_test_loss)# 获取当前的学习率lr = optimizer.state_dict()['param_groups'][0]['lr']template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%, Test_loss:{:.3f}, Lr:{:.2E}')print(template.format(epoch + 1, epoch_train_acc * 100, epoch_train_loss,epoch_test_acc * 100, epoch_test_loss, lr))# 保存最佳模型到文件中
PATH = './best_model.pth'  # 保存的参数文件名
torch.save(best_model.state_dict(), PATH)print('Done')

在这里插入图片描述

2.5 结果可视化

import matplotlib.pyplot as plt
#隐藏警告
import warnings
warnings.filterwarnings("ignore")               #忽略警告信息
plt.rcParams['font.sans-serif']    = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False      # 用来正常显示负号
plt.rcParams['figure.dpi']         = 100        #分辨率epochs_range = range(epochs)plt.figure(figsize=(12, 3))
plt.subplot(1, 2, 1)plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Test Accuracy')plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Test Loss')
plt.show()

可视化结果如下:
在这里插入图片描述
可以自行调整学习率以及batch_size,这里我的超参数并没有调整。

三、数据集个体预测

import matplotlib.pyplot as plt
from PIL import Image
from torchvision.transforms import transforms
from model import SqueezeNet
import torchdata_transform = transforms.Compose([transforms.ToTensor(),transforms.Resize((224, 224)),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])img = Image.open("F:\SqueezeNet\data\PALM-Validation400\V0008.jpg")
plt.imshow(img)
img = data_transform(img)
img = torch.unsqueeze(img, dim=0)
name = ['非病理性近视', '病理性近视']
model_weight_path = r"F:\SqueezeNet\best_model.pth"
model = SqueezeNet(num_classes=2)
model.load_state_dict(torch.load(model_weight_path))
model.eval()
with torch.no_grad():output = torch.squeeze(model(img))predict = torch.softmax(output, dim=0)# 获得最大可能性索引predict_cla = torch.argmax(predict).numpy()print('索引为', predict_cla)
print('预测结果为:{},置信度为: {}'.format(name[predict_cla], predict[predict_cla].item()))
plt.show()
索引为 1
预测结果为:病理性近视,置信度为: 0.9768268465995789

在这里插入图片描述

更详细的请看paddle版本的实现:深度学习实战基础案例——卷积神经网络(CNN)基于SqueezeNet的眼疾识别

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/39288.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Linkedin为什么要退出中国市场?

在迅速发展的时代,职场也在不断变换,只有不断地提升专业技能和进行培训,才能在职场中获得成功。Linkedin作为一家专注于职业发展的平台,专业的学习体验以及热门技能赢得了人们青睐。然而遗憾的是这个曾经让人备受青睐的平台,如今却在中国市场中黯然落幕,究竟是何种原因让曾经风…

大数据Flink(六十一):Flink流处理程序流程和项目准备

文章目录 Flink流处理程序流程和项目准备 一、Flink流处理程序的一般流程

Spark SQL优化:NOT IN子查询优化解决

背景 有如下的数据查询场景。 SELECT a,b,c,d,e,f FROM xxx.BBBB WHERE dt ${zdt.addDay(0).format(yyyy-MM-dd)} AND predict_type not IN ( SELECT distinct a FROM xxx.AAAAAWHERE dt ${zdt.addDay(0).format(yyyy-MM-dd)} ) 分析 通过查看SQL语句的执行计划基本…

Dubbo基础学习(笔记一)

目录 第一章、概念介绍1.1)什么是RPC框架1.2)什么是分布式系统1.3)Dubbo概述1.3)Dubbo基本架构 第二章、服务提供者2.1)目录结构和依赖2.2)model层2.3)service层2.4)resources配置文…

ARTS 挑战打卡的第8天 ---volatile 关键字在MCU中的作用,四个实例讲解(Tips)

前言 (1)volatile 关键字作为嵌入式面试的常考点,很多人都不是很了解,或者说一知半解。 (2)可能有些人会说了,volatile 关键字不就是防止编译器优化的吗?有啥好详细讲解的&#xff1…

澎峰科技|邀您关注2023 RISC-V中国峰会!

峰会概览 2023 RISC-V中国峰会(RISC-V Summit China 2023)将于8月23日至25日在北京香格里拉饭店举行。本届峰会将以“RISC-V生态共建”为主题,结合当下全球新形势,把握全球新时机,呈现RISC-V全球新观点、新趋势。 本…

《3D 数学基础》12 几何图元

目录 1 表达图元的方法 1.1 隐式表示法 1.2 参数表示 1.3 直接表示 2. 直线和射线 2.1 射线的不同表示法 2.1.1 两点表示 2.1.2 参数表示 2.1.3 相互转换 2.2 直线的不同表示法 2.2.1 隐式表示法 2.2.2 斜截式 2.2.3 相互转换 3. 球 3.1 隐式表示 1 表达图元的方…

C语言的使用技巧--在IO操作中的移位和快速配置

在WB32F103(ARM cortex m3内核,96Mhz)的gpio初始化中有一段代码,充分的结合了硬件特征并使用C语言的技巧来快速的配置对应的GPIO的功能,堪称经典和楷模,代码异常简洁,执行速度快,配置…

Python pycparser(c文件解析)模块使用教程

文章目录 安装 pycparser 模块模块开发者网址获取抽象语法树1. 需要导入的模块2. 获取 不关注预处理相关 c语言文件的抽象语法树ast3. 获取 预处理后的c语言文件的抽象语法树ast 语法树组成1. 数据类型定义 Typedef2. 类型声明 TypeDecl3. 标识符类型 IdentifierType4. 变量声明…

语聚AI公测发布,大语言模型时代下新的生产力工具

语聚AI 公测发布 距离语聚AI内测上线已经过去近1个月。 这期间,我们共邀请了近百位资深用户与行业专家加入语聚AI产品体验。通过大家的热情参与积极反馈,我们不断优化并完善了语聚AI的功能与使用体验。 经过研发团队不懈的努力,今天语聚AI终…

梅赛德斯-奔驰将成为首家集成ChatGPT的汽车制造商

ChatGPT的受欢迎程度毋庸置疑。OpenAI这个基于人工智能的工具,每天能够吸引无数用户使用,已成为当下很受欢迎的技术热点。因此,有许多公司都在想方设法利用ChatGPT来提高产品吸引力,卖点以及性能。在汽车领域,梅赛德斯…

代码随想录算法训练营第59天|动态规划part16|583. 两个字符串的删除操作、72. 编辑距离、编辑距离总结篇

代码随想录算法训练营第59天|动态规划part16|583. 两个字符串的删除操作、72. 编辑距离、编辑距离总结篇 583. 两个字符串的删除操作 583. 两个字符串的删除操作 思路: 思路见代码 代码: python class Solution(object):de…

2023年国赛数学建模思路 - 案例:FPTree-频繁模式树算法

文章目录 算法介绍FP树表示法构建FP树实现代码 建模资料 ## 赛题思路 (赛题出来以后第一时间在CSDN分享) https://blog.csdn.net/dc_sinor?typeblog 算法介绍 FP-Tree算法全称是FrequentPattern Tree算法,就是频繁模式树算法&#xff0c…

QT-Mysql数据库图形化接口

QT sql mysqloper.h qsqlrelationaltablemodelview.h /************************************************************************* 接口描述:Mysql数据库图形化接口 拟制: 接口版本:V1.0 时间:20230727 说明:支…

基于VUE3+Layui从头搭建通用后台管理系统(前端篇)九:自定义组件封装下

一、本章内容 续上一张,本章实现一些自定义组件的封装,包括文件上传组件封装、级联选择组件封装、富文本组件封装等。 1. 详细课程地址: 待发布 2. 源码下载地址: 待发布 二、界面预览 三、开发视频 基于VUE3+Layui从头搭建通用后台管

mov转mp4格式怎么转?

mov转mp4格式怎么转?众所周知,MOV视频格式是由苹果公司推出的常用的视频格式,能够在苹果软件及设备上使用。但是,如果将其应用于其他软件和设备上的话,可能会遇到文件无法正常播放的情况。在这个时候,我们需…

Linux MQTT智能家居项目(LED界面的布局设置)

文章目录 前言一、LED界面布局准备工作二、LED界面布局三、逻辑实现总结 前言 上篇文章我们完成了主界面的布局设置那么这篇文章我们就来完成各个界面的布局设置吧。 一、LED界面布局准备工作 首先添加LED灯光控制的图标。 将选择好的LED图标添加进来: 图标可以…

drawio导出矢量图

1.选中要导出的图 2.导出为pdf 3.用adobe打开pdf,另存为eps

华为认证含金量如何

华为认证是指通过华为技术有限公司官方认证考试所获得的认证资格。华为认证主要分为三个级别:华为认证工程师(HCIE)、华为认证专家(HCNP)和华为认证技术专家(HCNA),每个级别都有不同…

你真的了解数据结构与算法吗?

数据结构与算法,是理论和实践必须紧密结合的一门学科,有关数据结构和算法同类的课程或书籍,有些只是名为“数据结构”,而非“数据结构与算法”,它们在内容上并无很大区别。 实际上,数据结构和算法&#xf…