VGG16模型实现新冠肺炎图片多分类

1. 项目简介

本项目的目标是通过深度学习模型VGG16,实现对新冠肺炎图像的多分类任务,以帮助医疗人员对患者的影像进行快速、准确的诊断。新冠肺炎自爆发以来,利用医学影像如X光和CT扫描进行疾病诊断已成为重要手段之一。随着数据量的增加,基于人工智能的图像分析方法逐渐显现出其优势,能够有效提高检测效率并减少误诊率。该项目基于预训练的VGG16模型,通过对肺部CT或X光影像进行分类,实现对不同类型的肺部病变的分类识别。VGG16模型是深度卷积神经网络中的经典网络,具有16层网络结构,能够捕捉图像中的细微特征,适用于医学图像分析。本项目通过迁移学习,将VGG16的卷积层权重应用于新冠肺炎图片分类任务,并通过微调模型,使其适应于具体的医学影像数据集。最终目标是构建一个高效且稳定的深度学习模型,帮助医疗人员对肺炎患者进行辅助诊断,提高诊断的准确性和效率,同时减轻医疗系统的负担。

2.技术创新点摘要

迁移学习的应用:该项目利用VGG16模型进行迁移学习,这是该项目的重要创新之一。VGG16是一个预训练模型,已经在大规模图像数据集ImageNet上进行训练,具有强大的特征提取能力。通过冻结预训练模型的卷积层权重,模型可以专注于当前新冠肺炎图像的分类任务,避免从头开始训练,有效缩短了模型的训练时间,并提升了训练的稳定性和准确性。

3. 数据集与预处理

本项目使用的新冠肺炎医学图像数据集主要由CT或X光图像组成,数据集包含了正常、轻度感染及重度感染的肺部影像。这些医学图像具有高分辨率,能够反映患者肺部的病变情况。数据集中的标签对应不同的病理分类,这些标签用于训练模型进行多分类任务。医学影像的特征在于其复杂的结构和细节,因此需要经过严格的预处理,以确保模型能够从中学习到有效的特征。

在数据预处理阶段,首先对原始图像进行统一的尺寸调整。所有图像被缩放到224x224像素,以匹配VGG16模型的输入尺寸。此外,图像通过 transforms.ToTensor() 函数转换为张量,并将像素值从0-255的范围标准化为0-1之间。接着,使用预训练模型ImageNet的均值和标准差对图像进行归一化处理,将像素值调整到(-1,1)的区间。这一步能够确保输入数据的分布与预训练模型的输入分布相一致,进而提高模型的性能。

在数据增强方面,项目引入了多种增强策略,以增强模型的泛化能力。这些增强操作包括随机裁剪、翻转等,这能够有效增加训练数据的多样性,从而防止模型过拟合。同时,这些增强手段能够模拟不同条件下的医学图像变化,使模型更加稳健。

4. 模型架构

本项目使用了VGG16模型,这是一种深度卷积神经网络,具有16个权重层。其模型结构包括卷积层、池化层、全连接层等,具体如下:

  • 卷积层:VGG16由多个卷积层构成,每一层卷积操作的公式为:

y i , j , k = ∑ m , n w m , n , k ⋅ x i + m , j + n + b k y_{i,j,k} = \sum_{m,n} w_{m,n,k} \cdot x_{i+m,j+n} + b_k yi,j,k=m,nwm,n,kxi+m,j+n+bk

其中,x 是输入图像,w 是卷积核权重,b 是偏置项,y 是输出的特征图。这些卷积操作主要用于提取图像中的局部特征,尤其适合复杂的医学图像。

池化层:卷积后会经过最大池化层(Max Pooling),其公式为:

y i , j , k = max ⁡ { x i + m , j + n , k } , ( m , n ) ∈ S y_{i,j,k} = \max \{ x_{i+m,j+n,k} \}, \, (m,n) \in S yi,j,k=max{xi+m,j+n,k},(m,n)S

池化操作减少了特征图的大小,从而降低了模型的计算复杂度,同时保留了重要的特征信息。

全连接层:卷积和池化层的输出最终会通过全连接层,该层将多维的特征映射为一维向量,公式为:

y = W ⋅ x + b y = W \cdot x + b y=Wx+b

其中 W是权重矩阵,x 是输入向量,b是偏置项,y 是输出。这一层用于完成分类任务,将卷积提取到的特征映射到具体的分类结果上。

Softmax层:在最后的输出层,使用Softmax激活函数生成每个类别的概率分布,公式为:

P ( y = k ∣ x ) = e z k ∑ j e z j P(y = k | x) = \frac{e^{z_k}}{\sum_{j} e^{z_j}} P(y=kx)=jezjezk

其中 zk是类别 k的输出,Softmax函数保证输出结果为概率分布,并用于多分类任务。

模型层次概览

输入:224x224x3的医学影像

卷积层1-2(64个3x3卷积核)

最大池化层1

卷积层3-4(128个3x3卷积核)

最大池化层2

卷积层5-7(256个3x3卷积核)

最大池化层3

卷积层8-10(512个3x3卷积核)

最大池化层4

卷积层11-13(512个3x3卷积核)

最大池化层5

3个全连接层

最后通过Softmax层输出分类概率

模型的整体训练流程

训练流程如下:

数据加载:数据集通过自定义 MyDataset 类加载,并应用了标准化和数据增强等预处理步骤。

模型初始化:加载预训练的VGG16模型,并冻结部分卷积层的权重以保留其在ImageNet上的特征提取能力,只对最后几层进行微调。

前向传播:将图像输入到模型中,经过卷积、池化、全连接等层,生成最终的分类结果。

损失计算:使用交叉熵损失函数(CrossEntropyLoss)计算预测结果与真实标签之间的差异: L = − ∑ i y i log ⁡ ( y ^ i ) L = -\sum_{i} y_i \log(\hat{y}_i) L=iyilog(y^i)

其中 yi是真实标签的概率分布,y^i是预测概率分布。

反向传播:通过计算损失函数的梯度,更新可训练的参数,优化目标是最小化损失函数。

优化器:使用Adam优化器进行梯度更新,该优化器结合了动量与自适应学习率的优点: θ t + 1 = θ t − η ⋅ m t v t + ϵ \theta_{t+1} = \theta_t - \eta \cdot \frac{m_t}{\sqrt{v_t} + \epsilon} θt+1=θtηvt +ϵmt其中 mt和 vt分别是一阶和二阶矩估计,η 是学习率。

训练轮次:设定训练轮次(例如20轮),在每一轮中通过前向传播、损失计算、反向传播进行权重更新。

模型评估:在测试集上进行评估,主要使用准确率(Accuracy)、召回率(Recall)等指标:

  1. 准确率 A c c u r a c y = T P + T N T P + T N + F P + F N Accuracy = \frac{TP + TN}{TP + TN + FP + FN} Accuracy=TP+TN+FP+FNTP+TN
  1. 召回率 R e c a l l = T P T P + F N Recall = \frac{TP}{TP + FN} Recall=TP+FNTP

通过这些步骤,模型能够高效完成新冠肺炎图像的多分类任务,并在实际数据集上进行评估与优化。

5. 核心代码详细讲解

数据预处理与特征工程
pic_transform = transforms.Compose([transforms.Resize([224,224]),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

解释

  1. transforms.Resize([224,224]):这行代码将输入图像的尺寸缩放到224x224像素,以确保输入图像大小一致,符合VGG16模型的输入要求。
  2. transforms.ToTensor():将PIL图像转换为PyTorch的Tensor类型,并将像素值从0-255的范围归一化为0-1。这是标准的PyTorch数据处理步骤。
  3. transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):基于ImageNet的均值和标准差进行图像归一化,将像素值调整到(-1,1)的范围。该归一化策略是ImageNet预训练模型的标准配置,有助于提高模型性能。
自定义数据集加载
class MyDataset(Dataset):def init(self, img_path, file_name ,transform=None):self.root = img_pathself.file_name = file_nameself.csv_root = self.root + '//' + self.file_namedf = pd.read_csv(self.csv_root)rows = df.shape[0]imgs = []labels = []for row in range(0,rows):imgs.append(os.path.join(self.root,df['image_path'][row]))labels.append(df['labels'][row])self.img = imgsself.label = labelsself.transform = transform
def len(self):return len(self.label)
def getitem(self, item):img = self.img[item]label = self.label[item]img = Image.open(img).convert('RGB')if self.transform is not None:img = self.transform(img)label = np.array(label).astype(np.int64)label = torch.from_numpy(label)return img, label

解释

  1. init(self, img_path, file_name ,transform=None):初始化方法,定义了数据集的路径和图像转换方法,并加载图像路径和标签。transform参数用于指定数据增强和预处理步骤。
  2. df = pd.read_csv(self.csv_root):从指定的CSV文件中读取图像路径和标签,CSV文件包含图像的文件路径及其对应的标签。
  3. self.img = imgsself.label = labels:将图像路径和标签分别存储在两个列表中,以供后续数据加载使用。
  4. len(self):返回数据集中样本的数量,这是PyTorch自定义数据集的标准实现。
  5. getitem(self, item):通过索引获取图像和标签。图像通过PIL库打开并转换为RGB格式,然后应用数据预处理(如果有),最终返回Tensor格式的图像和标签。
模型构建与训练
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = models.vgg16(pretrained=True)
for param in model.parameters():param.requires_grad = False
model.classifier[6] = nn.Linear(4096, 3)
model = model.to(device)

解释

  1. device = torch.device('cuda' if torch.cuda.is_available() else 'cpu'):这行代码用于检测当前系统是否有可用的GPU,如果有则将计算设备设为GPU(CUDA),否则使用CPU。
  2. model = models.vgg16(pretrained=True):加载预训练的VGG16模型,这个模型已经在ImageNet上进行过训练,能够有效地提取图像的特征。
  3. for param in model.parameters(): param.requires_grad = False:冻结VGG16模型的所有卷积层权重,使它们在训练过程中不更新。这是典型的迁移学习策略,主要目的是利用预训练模型的特征提取能力,同时减少训练时间和计算资源。
  4. model.classifier[6] = nn.Linear(4096, 3):替换VGG16模型中的最后一层全连接层,将输出从ImageNet的1000类修改为当前任务的3类(例如:正常、轻度感染、重度感染)。
  5. model = model.to(device):将模型移动到GPU或CPU上,以加速训练过程。
模型训练与评估
def vgg_train(model, epochs, train_loader, test_loader, log_step_freq):model.train()optimizer = torch.optim.Adam(model.parameters(), lr=0.001)loss_fn = nn.CrossEntropyLoss()for epoch in range(epochs):for step, (x, y) in enumerate(train_loader):x, y = x.to(device), y.to(device)optimizer.zero_grad()pred = model(x)loss = loss_fn(pred, y)loss.backward()optimizer.step()if step % log_step_freq == 0:print(f"[{epoch+1}/{epochs}] Step: {step}, Loss: {loss.item()}")print('训练成功~')

解释

  1. model.train():将模型设置为训练模式,这会启用诸如Dropout等正则化技术。
  2. optimizer = torch.optim.Adam(model.parameters(), lr=0.001):使用Adam优化器进行模型参数的更新,学习率设置为0.001。Adam优化器结合了动量和自适应学习率,能够加快训练过程并减少震荡。
  3. loss_fn = nn.CrossEntropyLoss():定义交叉熵损失函数,用于计算预测值与真实标签之间的误差,适用于多分类任务。
  4. for epoch in range(epochs):开始模型的训练循环,每个epoch表示模型对整个数据集的完整遍历。
  5. x, y = x.to(device), y.to(device):将输入数据和标签移动到GPU(如果有)或CPU,以确保与模型在同一设备上进行计算。
  6. optimizer.zero_grad():清空优化器中的梯度缓存,避免上一次的梯度对本次计算的影响。
  7. pred = model(x):前向传播,模型对输入数据x进行预测。
  8. loss = loss_fn(pred, y):计算预测结果与真实标签之间的损失。
  9. loss.backward():反向传播计算梯度,更新模型参数。
  10. optimizer.step():根据反向传播计算得到的梯度更新模型参数。
  11. if step % log_step_freq == 0:每隔log_step_freq步打印一次训练日志,包括当前epoch、step和损失值。
  12. print('训练成功~'):训练结束后的提示信息。
评估指标
def line_plotling(df, metric):import seaborn as snsimport matplotlib.pyplot as pltsns.set_theme(style='ticks')sns.lineplot(x='epoch', y=metric, data=df, color='r')sns.lineplot(x='epoch', y='val_'+metric, data=df, color='b')plt.legend(['train_'+metric, 'val_'+metric])

解释

  1. sns.set_theme(style='ticks'):使用Seaborn库设置绘图主题,风格为ticks
  2. sns.lineplot(x='epoch', y=metric, data=df, color='r'):绘制训练集的性能指标(例如准确率或损失)的变化曲线,x轴表示epoch,y轴表示指标值,曲线颜色为红色。
  3. sns.lineplot(x='epoch', y='val_'+metric, data=df, color='b'):绘制验证集的性能指标变化曲线,颜色为蓝色。通过对比训练集和验证集的曲线变化,可以观察到模型是否过拟合或欠拟合。
  4. plt.legend(['train_'+metric, 'val_'+metric]):为图形添加图例,区分训练集和验证集的曲线。

6. 模型优缺点评价

模型优点:
  1. 迁移学习的有效应用:通过使用VGG16的预训练权重,模型在图像特征提取方面表现出色,同时减少了对大规模数据集的依赖,加速了训练过程。
  2. 深度网络的特征提取能力强:VGG16的多层卷积结构能够提取复杂的图像特征,尤其适合医学图像中微小病变的检测。
  3. 数据预处理与增强合理:项目采用了图像归一化和标准化,以及图像尺寸调整等预处理措施,有效提高了模型对不同分辨率图像的泛化能力。
  4. 准确性高:通过使用交叉熵损失和Adam优化器,模型在分类任务中的表现稳定,能够较好地完成多分类任务。
模型缺点:
  1. 计算资源需求大:VGG16网络较深,参数较多,尽管特征提取效果好,但其计算复杂度较高,在推理时可能对计算资源要求较高,不适合实时应用场景。
  2. 适应性有限:模型结构未针对医学图像中的特殊结构(如肺部CT的形态学特征)进行专门优化,可能导致在处理非典型病变时表现不佳。
  3. 超参数未优化:项目中未对学习率、批量大小等超参数进行深入优化,可能存在进一步提高模型表现的空间。
改进方向:
  1. 模型结构优化:可以尝试引入轻量化模型(如MobileNet、EfficientNet)或加入注意力机制,增强对医学图像中病变区域的聚焦能力,降低计算成本并提升效率。
  2. 超参数调整:对学习率、优化器、批量大小等超参数进行调优,通过网格搜索或随机搜索等方法找到最佳配置,提高模型性能。
  3. 更多数据增强:可以引入更加丰富的数据增强方法,如随机裁剪、旋转、色彩调整等,以增加训练数据的多样性,提升模型的泛化能力。
  4. 查看更多项目案例/数据集/代码/视频教程:点击进入>>

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/53330.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

华为---以太网静态路由配置使用下一跳通信正常,而使用出接口无法通信

目录 1. 实验环境 2. 结果测试 3. 分析验证 3.1 以太网静态路由配置使用下一跳跨网段通信抓包分析 3.2 以太网静态路由配置使用出接口跨网段通信抓包分析 3.3 以太网静态路由配置使用出接口无法跨网段通信问题解决办法 1. 实验环境 以太网静态路由配置使用下一跳跨网段通…

网络丢包定位记录(二)

网卡驱动丢包 查看:ifconfig eth1/eth0 等接口 1.RX errors: 表示总的收包的错误数量,还包括too-long-frames错误,Ring Buffer 溢出错误,crc 校验错误,帧同步错误,fifo overruns 以及 missed pkg 等等。 …

Maven的详细解读和配置

目录 一、Maven 1.1 引言 1.2 介绍 1.3 下载安装 1.3.1 解压 1.3.2 配置环境变量 1.3.3 测试 1.4 仓库[了解] 1.5 Maven配置 1.5.1 修改仓库位置 1.5.2 设置镜像 二、IDEA - MAVEN 2.1 idea关联maven 2.2 为新项目设置 2.2 创建java项目[重点] 2.3 java项目结构…

Go-知识-定时器

Go-知识-定时器 1. 介绍2. Timer使用场景2.1 设定超时时间2.2 延迟执行某个方法 3. Timer 对外接口3.1 创建定时器3.2 停止定时器3.3 重置定时器3.4 After3.5 AfterFunc 4. Timer 的实现原理4.1 Timer数据结构4.1.1 Timer4.1.2 runtimeTimer 4.2 Timer 实现原理4.2.1 创建Timer…

特征工程与交叉验证在机器学习中的应用

数据入口:学生考试表现影响因素数据集 - Heywhale.com 本数据集提供了关于影响学生考试成绩的多种因素的全面概述。数据集包含了有关学习习惯、出勤率、家长参与、资源获取等方面的信息。 数据说明 字段名说明Hours_Studied每周学习的小时数Attendance出勤率&…

(笔记自用)位运算总结+LeetCode例题:颠倒二进制位+位1的个数

一.位运算总结: 在解题之前理解一下为什么需要位运算?它的本质是什么? 力扣上不少位运算相关的题,并且很多题也会用到位运算的技巧。这又是为什么? 位运算的由来 在计算机里面,任何数据最终都是用数字来表示的&…

[Linux]:信号(下)

✨✨ 欢迎大家来到贝蒂大讲堂✨✨ 🎈🎈养成好习惯,先赞后看哦~🎈🎈 所属专栏:Linux学习 贝蒂的主页:Betty’s blog 1. 信号的阻塞 1.1 基本概念 信号被操作系统发送给进程之后,进程…

【Linux学习】基本指令其一

命令行界面 命令行终端是一个用户界面,允许用户通过输入文本命令与计算机系统进行交互。 比如Windows下, 键入winR,然后输入cmd,就可以输入文本指令与操作系统交互了。 Windows有另一个命令行界面Powershell,它的功能比cmd更强大…

电商ISV 电商SaaS 是什么

Independent Software Vendors的英文缩写,意为“独立软件开发商” 软件即服务(SaaS) 指一种基于云技术的软件交付模式 订阅收费 这些公司叫做ISV软件供应商,通过SaaS服务交付收费 为什么会有电商ISV 从商家角度划分:有独立品牌商家、大商…

微信支付的委托代扣功能服务如何申请开通?

扣款服务(原委托代扣服务,以下均用委托代扣)是微信支付旗下的重要产品 1、委托代扣是指商户取得用户的扣款授权后,向微信支付发起从用户账户扣款至商户账户的扣款指令,微信支付无需验证用户的支付密码,即可…

记录一下,Vcenter清理/storage/archive空间

一、根因 vpostgres:这个目录可能包含与 vCenter Server 使用的 PostgreSQL 数据库相关的归档文件过多,导致空间被占用。 二、处理过程 1、SSH登陆到Vcenter. 2、df -Th **图中可以看到 /storage/archive 使用占比很高。 /storage/archive 目录通常用…

fiddler抓包06_抓取https请求(chrome)

课程大纲 首次安装Fiddler,抓https请求,除打开抓包功能(F12)还需要: ① Fiddler开启https抓包 ② Fiddler导出证书; ③ 浏览器导入证书。 否则,无法访问https网站(如下图&#xff0…

Qt优秀开源项目之二十三:QSimpleUpdater

QSimpleUpdater是开源的自动升级模块,用于检测、下载和安装更新。 github地址:https://github.com/alex-spataru/QSimpleUpdater QSimpleUpdater目前Star不多(911个),但已在很多开源项目看到其身影,比如Not…

web网站的任意文件上传下载漏洞解析

免责申明 本文仅是用于学习检测自己搭建的任意文件上传下载漏洞相关原理,请勿用在非法途径上,若将其用于非法目的,所造成的一切后果由您自行承担,产生的一切风险和后果与笔者无关;本文开始前请认真详细学习《‌中华人民共和国网络安全法》‌及其所在国家地区相关法规内容【…

【D3.js in Action 3 精译_023】3.3 使用 D3 将数据绑定到 DOM 元素

当前内容所在位置: 第一部分 D3.js 基础知识 第一章 D3.js 简介(已完结) 1.1 何为 D3.js?1.2 D3 生态系统——入门须知1.3 数据可视化最佳实践(上)1.3 数据可视化最佳实践(下)1.4 本…

Three.js 3D人物漫游项目(中)

本文目录 前言最终效果展示1、人物添加阴影1.1 添加地板1.1.1 效果 1.2 模型castShadow1.2.1 效果 1.3 轨道控制器1.3.1 效果 2、创建建筑物2.1 代码2.2 效果 前言 在数字技术的浪潮中,三维图形渲染技术以其独特的魅力,正逐步渗透到我们生活的方方面面&a…

手机、平板电脑编程———未来之窗行业应用跨平台架构

一、平板编程优点 1. 便携性强 - 可以随时随地携带平板进行编程,不受地点限制,方便在旅行、出差或休息时间进行学习和开发。 2. 直观的触摸操作 - 利用触摸屏幕进行代码编辑、缩放、拖动等操作,提供了一种直观和自然的交互方式。 …

联想(lenovo) 小新Pro13锐龙版(新机整理、查看硬件配置和系统版本、无线网络问题、windows可选功能)

新机整理 小新pro13win10新机整理 查看硬件配置和系统版本 设置-》系统-》系统信息 无线网络问题 部分热点可以,部分不可以 问题:是因为自己修改了WLAN的IP分配方式为手动分配,导致只能在连接家里无线网的时候可以,连接其他…

Unity 高亮插件HighlightPlus介绍

主要是对官方文档进行了翻译(我做了一些补充和一些小的调整) 但是如果你只是想快速入门: Unity 高亮插件Highlight Plus快速入门-CSDN博客 注意:官方文档本身就落后实际,但对入门仍很有帮助,核心并没有较大改变,有的功能有差异,以实际为准.(目前我已校正了大部分差异,后续我…

vue3 自定义el-tree树形结构样式

这里样式设置主要用到了 windcss 实现效果 模拟数据 这里也可以用模拟的数据,下面用的是后端请求的真实数据 [{"id": 5,"rule_id": 0,"status": 1,"create_time": "2019-08-11 13:36:09","update_time": "…