第64步 深度学习图像识别:多分类建模误判病例分析(Pytorch)

基于WIN10的64位系统演示

一、写在前面

上期我们基于TensorFlow环境介绍了多分类建模的误判病例分析。

本期以健康组、肺结核组、COVID-19组、细菌性(病毒性)肺炎组为数据集,基于Pytorch环境,构建SqueezeNet多分类模型,分析误判病例,因为它建模速度快。

同样,基于GPT-4辅助编程。

二、误判病例分析实战

使用胸片的数据集:肺结核病人和健康人的胸片的识别。其中,健康人900张,肺结核病人700张,COVID-19病人549张、细菌性(病毒性)肺炎组900张,分别存入单独的文件夹中。

直接分享代码:

######################################导入包###################################
import copy
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision import models
from torch.utils.data import DataLoader
from torch import optim, nn
from torch.optim import lr_scheduler
import os
import matplotlib.pyplot as plt
import warnings
import numpy as npwarnings.filterwarnings("ignore")
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False# 设置GPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")################################导入数据集#####################################
from torchvision import datasets, transforms
from torch.nn.functional import softmax
from PIL import Image
import pandas as pd
import torch.nn as nn
import timm
from torch.optim import lr_scheduler# 自定义的数据集类
class ImageFolderWithPaths(datasets.ImageFolder):def __getitem__(self, index):original_tuple = super(ImageFolderWithPaths, self).__getitem__(index)path = self.imgs[index][0]tuple_with_path = (original_tuple + (path,))return tuple_with_path# 数据集路径
data_dir = "./MTB-1"# 图像的大小
img_height = 256
img_width = 256# 数据预处理
data_transforms = {'train': transforms.Compose([transforms.RandomResizedCrop(img_height),transforms.RandomHorizontalFlip(),transforms.RandomVerticalFlip(),transforms.RandomRotation(0.2),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),'val': transforms.Compose([transforms.Resize((img_height, img_width)),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
}# 加载数据集
full_dataset = ImageFolderWithPaths(data_dir, transform=data_transforms['train'])# 获取数据集的大小
full_size = len(full_dataset)
train_size = int(0.8 * full_size)  # 假设训练集占70%
val_size = full_size - train_size  # 验证集的大小# 随机分割数据集
torch.manual_seed(0)  # 设置随机种子以确保结果可重复
train_dataset, val_dataset = torch.utils.data.random_split(full_dataset, [train_size, val_size])# 应用数据增强到训练集和验证集
train_dataset.dataset.transform = data_transforms['train']
val_dataset.dataset.transform = data_transforms['val']# 创建数据加载器
batch_size = 8
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=True, num_workers=0)dataloaders = {'train': train_dataloader, 'val': val_dataloader}
dataset_sizes = {'train': len(train_dataset), 'val': len(val_dataset)}
class_names = full_dataset.classes# 获取数据集的类别
class_names = full_dataset.classes# 保存预测结果的列表
results = []###############################定义SqueezeNet模型################################
# 定义SqueezeNet模型
model = models.squeezenet1_1(pretrained=True)  # 这里以SqueezeNet 1.1版本为例
num_ftrs = model.classifier[1].in_channels# 根据分类任务修改最后一层
# 这里我们改变模型的输出层为4,因为我们做的是四分类
model.classifier[1] = nn.Conv2d(num_ftrs, 4, kernel_size=(1,1))# 修改模型最后的输出层为我们需要的类别数
model.num_classes = 4model = model.to(device)# 打印模型摘要
print(model)#############################编译模型#########################################
# 定义损失函数
criterion = nn.CrossEntropyLoss()# 定义优化器
optimizer = torch.optim.Adam(model.parameters())# 定义学习率调度器
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)# 开始训练模型
num_epochs = 20# 初始化记录器
train_loss_history = []
train_acc_history = []
val_loss_history = []
val_acc_history = []for epoch in range(num_epochs):print('Epoch {}/{}'.format(epoch, num_epochs - 1))print('-' * 10)# 每个epoch都有一个训练和验证阶段for phase in ['train', 'val']:if phase == 'train':model.train()  # 设置模型为训练模式else:model.eval()   # 设置模型为评估模式running_loss = 0.0running_corrects = 0# 遍历数据for inputs, labels, paths in dataloaders[phase]:inputs = inputs.to(device)labels = labels.to(device)# 零参数梯度optimizer.zero_grad()# 前向with torch.set_grad_enabled(phase == 'train'):outputs = model(inputs)_, preds = torch.max(outputs, 1)loss = criterion(outputs, labels)# 只在训练模式下进行反向和优化if phase == 'train':loss.backward()optimizer.step()# 统计running_loss += loss.item() * inputs.size(0)running_corrects += torch.sum(preds == labels.data)epoch_loss = running_loss / dataset_sizes[phase]epoch_acc = (running_corrects.double() / dataset_sizes[phase]).item()# 记录每个epoch的loss和accuracyif phase == 'train':train_loss_history.append(epoch_loss)train_acc_history.append(epoch_acc)else:val_loss_history.append(epoch_loss)val_acc_history.append(epoch_acc)print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))print()# 保存模型
torch.save(model.state_dict(), 'SqueezeNet_model-m-s.pth')# 加载最佳模型权重
#model.load_state_dict(best_model_wts)
#torch.save(model, 'shufflenet_best_model.pth')
#print("The trained model has been saved.")
###########################误判病例分析#################################
import os
import pandas as pd
from collections import defaultdict# 判定组别的字典
group_dict = {("COVID-19", "Normal"): "B",("COVID-19", "Pneumonia"): "C",("COVID-19", "Tuberculosis"): "D",("Normal", "COVID-19"): "E",("Normal", "Pneumonia"): "F",("Normal", "Tuberculosis"): "G",("Pneumonia", "COVID-19"): "H",("Pneumonia", "Normal"): "I",("Pneumonia", "Tuberculosis"): "J",("Tuberculosis", "COVID-19"): "K",("Tuberculosis", "Normal"): "L",("Tuberculosis", "Pneumonia"): "M",
}# 创建一个字典来保存所有的图片信息
image_predictions = {}# 循环遍历所有数据集(训练集和验证集)
for phase in ['train', 'val']:# 设置模型的状态model.eval()# 遍历数据for inputs, labels, paths in dataloaders[phase]:inputs = inputs.to(device)labels = labels.to(device)# 计算模型的输出with torch.no_grad():outputs = model(inputs)_, preds = torch.max(outputs, 1)# 循环遍历每一个批次的结果for path, pred in zip(paths, preds):# 提取图片的类别actual_class = os.path.split(os.path.dirname(path))[-1] # 提取图片的名称image_name = os.path.basename(path)# 获取预测的类别predicted_class = class_names[pred]# 判断预测的分组类型if actual_class == predicted_class:group_type = 'A'elif (actual_class, predicted_class) in group_dict:group_type = group_dict[(actual_class, predicted_class)]else:group_type = 'Other'  # 如果没有匹配的条件,可以归类为其他# 保存到字典中image_predictions[image_name] = [phase, actual_class, predicted_class, group_type]# 将字典转换为DataFrame
df = pd.DataFrame.from_dict(image_predictions, orient='index', columns=['Dataset Type', 'Actual Class', 'Predicted Class', 'Group Type'])# 保存到CSV文件中
df.to_csv('result-m-s.csv')

四、改写过程

先说策略:首先,先把二分类的误判病例分析代码改成四分类的;其次,用咒语让GPT-4帮我们续写代码已达到误判病例分析。

提供咒语如下:

①改写{代码1},改变成4分类的建模。代码1为:{XXX};

在{代码1}的基础上改写代码,达到下面要求:

(1)首先,提取出所有图片的“原始图片的名称”、“属于训练集还是验证集”、“预测为分组类型”;文件的路劲格式为:例如,“MTB-1\Normal\XXX.png”属于Normal,“MTB-1\COVID-19\XXX.jpg”属于COVID-19,“MTB-1\Pneumonia\XXX.jpeg”属于Pneumonia,“MTB-1\Tuberculosis\XXX.png”属于Tuberculosis;

(2)其次,根据样本预测结果,把样本分为以下若干组:(a)预测正确的图片,全部判定为A组;(b)本来就是COVID-19的图片,预测为Normal,判定为B组;(c)本来就是COVID-19的图片,预测为Pneumonia,判定为C组;(d)本来就是COVID-19的图片,预测为Tuberculosis,判定为D组;(e)本来就是Normal的图片,预测为COVID-19,判定为E组;(f)本来就是Normal的图片,预测为Pneumonia,判定为F组;(g)本来就是Normal的图片,预测为Tuberculosis,判定为G组;(h)本来就是Pneumonia的图片,预测为COVID-19,判定为H组;(i)本来就是Pneumonia的图片,预测为Normal,判定为I组;(j)本来就是Pneumonia的图片,预测为Tuberculosis,判定为J组;(k)本来就是Tuberculosis的图片,预测为COVID-19,判定为H组;(l)本来就是Tuberculosis的图片,预测为Normal,判定为I组;(m)本来就是Tuberculosis的图片,预测为Pneumonia,判定为J组;

(3)居于以上计算的结果,生成一个名为result-m.csv表格文件。列名分别为:“原始图片的名称”、“属于训练集还是验证集”、“预测为分组类型”、“判定的组别”。其中,“原始图片的名称”为所有图片的图片名称;“属于训练集还是验证集”为这个图片属于训练集还是验证集;“预测为分组类型”为模型预测该样本是哪一个分组;“判定的组别”为根据步骤(2)判定的组别,从A到J一共十组选择一个。

(4)需要把所有的图片都进行上面操作,注意是所有图片,而不只是一个批次的图片。

代码1为:{XXX}

③还需要根据报错做一些调整即可,自行调整。

最后,看看结果:

模型只运行了2次,所以效果很差哈,全部是预测成了COVID-19。

四、数据

链接:https://pan.baidu.com/s/1rqu15KAUxjNBaWYfEmPwgQ?pwd=xfyn

提取码:xfyn

五、结语

深度学习图像分类的教程到此结束,洋洋洒洒29篇,涉及到的算法和技巧也够发一篇SCI了。当然,图像识别还有图像分割和目标识别两块内容,就放到最后再说了。下一趴,我们来介绍时间序列建模!!!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/68822.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

业务需要咨询?开发遇到 bug 想反馈?开发者在线提单功能上线!

大家是否遇到过下列问题—— 在开发的时候,遇到 bug 需要反馈… 有合作意向的时候,想更多了解业务和相关产品… 在接入的时候,需要得到专业技术支持… 别急,荣耀开发者服务平台在线提单功能上线了~ 处理问题分类说明&#xff1…

MIT6.S081实验环境搭建

MIT6.S081 lab 环境搭建 本文参考了MIT的官方指南和知乎文章环境搭建 step1 首先需要一个ubuntu20.04的系统,我使用的是vscode的WSL2连接的ubuntu20.04,使用virtual box建一个ubuntu20.04的虚拟机应该也可以。 可以用 lsb_release -a 查看一下自己ub…

对可再生能源和微电网集成研究的新控制技术和保护算法进行基线和测试及静态、时域和频率分析研究(Matlab代码实现)

💥💥💞💞欢迎来到本博客❤️❤️💥💥 🏆博主优势:🌞🌞🌞博客内容尽量做到思维缜密,逻辑清晰,为了方便读者。 ⛳️座右铭&a…

基于React实现:弹窗组件与Promise的有机结合

背景 弹窗在现代应用中是最为常见的一种展示信息的形式,二次确认弹窗是其中最为经典的一种。当我们在React,Vue这种数据驱动视图的前端框架中渲染弹窗基本是固定的使用形式。 使用方式:创建新的弹窗组件,在需要弹窗的地方引用并…

【系统设计系列】 回顾可扩展性

系统设计系列初衷 System Design Primer: 英文文档 GitHub - donnemartin/system-design-primer: Learn how to design large-scale systems. Prep for the system design interview. Includes Anki flashcards. 中文版: https://github.com/donnemart…

【React】React学习:从初级到高级(三)

3 状态管理 随着应用不断变大,应该更有意识的去关注应用状态如何组织,以及数据如何在组件之间流动。冗余或重复的状态往往是缺陷的根源。 3.1 用State响应输入 3.1.1 声明式地考虑UI 总体步骤如下: 定位组件中不同的视图状态 确定是什么…

C++零碎记录(四)

6. 深拷贝与浅拷贝 ① 浅拷贝:简单的赋值拷贝操作。 ② 深拷贝:在堆区重新申请空间,进行拷贝操作。 ③ 浅拷贝,如下图所示,带来的问题就是堆区的内存重复释放。 ④ 深拷贝,如下图所示,在堆区…

深度学习-4-二维目标检测-YOLOv5源码测试与训练

本文采用的YOLOv5源码是ultralytics发行版3.1 YOLOv5源码测试与训练 1.Anaconda环境配置 1.1安装Anaconda Anaconda 是一个用于科学计算的 Python 发行版,支持 Linux, Mac, Windows, 包含了众多流行的科学计算、数据分析的 Python 包。 官方网址下载安装包&…

pdf怎么编辑文字?了解一下这几种编辑方法

pdf怎么编辑文字?PDF文件的普及使得它成为了一个重要的文件格式。然而,由于PDF文件的特性,它们不可直接编辑,这就使得PDF文件的修改变得比较麻烦。但是,不用担心,接下来这篇文章就给大家介绍几种编辑pdf文字…

go语言基本操作---三

变量的内存和变量的地址 指针是一个代表着某个内存地址的值。这个内存地址往往是在内存中存储的另一个变量的值的起始位置。Go语言对指针的支持介于java语言和C/C语言之间,它即没有想Java语言那样取消了代码对指针的直接操作的能力,也避免了C/C语言中由…

Redis面试题(笔记)

目录 1.缓存穿透 2.缓存击穿 3.缓存雪崩 小结 4.缓存-双写一致性 5.缓存-持久性 6.缓存-数据过期策略 7.缓存-数据淘汰策略 数据淘汰策略-使用建议 数据淘汰策略总结 8.redis分布式锁 setnx redission 主从一致性 9.主从复制、主从同步 10.哨兵模式 服务状态监…

最快1个月录用!9月SCI/SSCI/EI刊源表已更新!

2023年9月SCI/SSCI/EI期刊目录更新 2023年9月份刊源表已更新!计算机、医学、工程、环境、SSCI均有新增期刊,1区(TOP),最快1个月录用,好刊版面紧俏,切莫错失机会! 01 计算机领域 02 医学与制药领域 03 工…

桌面应用小程序,一种创新的跨端开发方案

Qt Group在提及2023年有桌面端应用程序开发热门趋势时,曾经提及三点: 关注用户体验:无论您是为桌面端、移动端,还是为两者一起开发应用程序,有一点是可以确定的:随着市场竞争日益激烈,对产品的期…

Python爬取天气数据并进行分析与预测

随着全球气候的不断变化,对于天气数据的获取、分析和预测显得越来越重要。本文将介绍如何使用Python编写一个简单而强大的天气数据爬虫,并结合相关库实现对历史和当前天气数据进行分析以及未来趋势预测。 1 、数据源选择 选择可靠丰富的公开API或网站作…

电子科大软件系统架构设计——面向对象建模基础

文章目录 面向对象建模基础UML建模语言UML模型图用例图活动图类图顺序图通信图状态机图构件图部署图包图对象图组合结构图扩展图交互概览图时间图 BPMN建模语言业务建模定义模型元素流对象活动事件网关 流数据人工制品泳池和泳道 建模案例订单采购流程建模电商系统订货业务流程…

搭建最简单的SpringBoot项目

1、创建maven项目 2、引入父pom <parent><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-parent</artifactId><version>2.7.15</version> </parent> 3、引入springboot-web依赖 <dependency…

React【React是什么?、创建项目 、React组件化、 JSX语法、条件渲染、列表渲染、事件处理】(一)

文章目录 React是什么&#xff1f; 为什么要学习React React开发前准备 创建React项目 React项目结构简介 React组件化 初识JSX 渲染JSX描述的页面 JSX语法 JSX的Class与Style属性 JSX生成的React元素 条件渲染&#xff08;一&#xff09; 条件渲染 &#xff0…

系统架构技能之设计模式-工厂模式

一、开篇 本文主要是讲述设计模式中最经典的创建型模式-工厂模式&#xff0c;本文将会从以下几点对工厂模式进行阐述。 本文将会从上面的四个方面进行详细的讲解和说明&#xff0c;当然会的朋友可以之处我的不足之处&#xff0c;不会的朋友也请我们能够相互学习讨论。 二、摘…

Java后端开发面试题——企业场景篇

单点登录这块怎么实现的 单点登录的英文名叫做&#xff1a;Single Sign On&#xff08;简称SSO&#xff09;,只需要登录一次&#xff0c;就可以访问所有信任的应用系统 JWT解决单点登录 用户访问其他系统&#xff0c;会在网关判断token是否有效 如果token无效则会返回401&am…

C#循环定时上传数据,失败重传解决方案,数据库标识

有些时候我们需要定时的上传一些数据库的数据&#xff0c;在数据不完整的情况下可能上传失败&#xff0c;上传失败后我们需要定时在重新上传失败的数据&#xff0c;该怎么合理的制定解决方案呢&#xff1f;下面一起看一下&#xff1a; 当然本篇文章只是提供一个思路&#xff0…