第63步 深度学习图像识别:多分类建模误判病例分析(Tensorflow)

基于WIN10的64位系统演示

一、写在前面

上两期我们基于TensorFlow和Pytorch环境做了图像识别的多分类任务建模。这一期我们做误判病例分析,分两节介绍,分别基于TensorFlow和Pytorch环境的建模和分析。

本期以健康组、肺结核组、COVID-19组、细菌性(病毒性)肺炎组为数据集,基于TensorFlow环境,构建mobilenet_v2多分类模型,因为它建模速度快。

同样,基于GPT-4辅助编程,这次改写过程会简单展示。

二、误判病例分析实战

使用胸片的数据集:肺结核病人和健康人的胸片的识别。其中,健康人900张,肺结核病人700张,COVID-19病人549张、细菌性(病毒性)肺炎组900张,分别存入单独的文件夹中。

直接分享代码:

######################################导入包###################################
from tensorflow import keras
import tensorflow as tf
from tensorflow.python.keras.layers import Dense, Flatten, Conv2D, MaxPool2D, Dropout, Activation, Reshape, Softmax, GlobalAveragePooling2D, BatchNormalization
from tensorflow.python.keras.layers.convolutional import Convolution2D, MaxPooling2D
from tensorflow.python.keras import Sequential
from tensorflow.python.keras import Model
from tensorflow.python.keras.optimizers import adam_v2
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.python.keras.preprocessing.image import ImageDataGenerator, image_dataset_from_directory
from tensorflow.python.keras.layers.preprocessing.image_preprocessing import RandomFlip, RandomRotation, RandomContrast, RandomZoom, RandomTranslation
import os,PIL,pathlib
import warnings#设置GPU
gpus = tf.config.list_physical_devices("GPU")warnings.filterwarnings("ignore")             #忽略警告信息
plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False    # 用来正常显示负号################################导入数据集#####################################
data_dir = "./MTB-1"
data_dir = pathlib.Path(data_dir)
image_count = len(list(data_dir.glob('*/*')))
print("图片总数为:", image_count)batch_size = 32
img_height = 100
img_width  = 100# 创建一个数据集,其中包含所有图像的路径。
list_ds = tf.data.Dataset.list_files(str(data_dir/'*/*'), shuffle=True)
# 切分为训练集和验证集
val_size = int(image_count * 0.2)
train_ds = list_ds.skip(val_size)
val_ds = list_ds.take(val_size)class_names = np.array(sorted([item.name for item in data_dir.glob('*') if item.name != "LICENSE.txt"]))
print(class_names)def get_label(file_path):parts = tf.strings.split(file_path, os.path.sep)one_hot = parts[-2] == class_namesreturn tf.argmax(one_hot)def decode_img(img):img = tf.image.decode_image(img, channels=3, expand_animations=False)  # 指定 channels 参数img = tf.image.resize(img, [img_height, img_width])img = img / 255.0  # normalize to [0,1] rangereturn img# 在创建数据集时,添加一个新的元素:数据集类型
def process_path_with_filename_and_dataset_type(file_path, dataset_type):label = get_label(file_path)img = tf.io.read_file(file_path)img = decode_img(img)return img, label, file_path, dataset_typeAUTOTUNE = tf.data.AUTOTUNE# 在此处对train_ds和val_ds进行图像处理,包括添加文件名信息和数据集类型信息
train_ds_with_filenames_and_type = train_ds.map(lambda x: process_path_with_filename_and_dataset_type(x, 'Train'), num_parallel_calls=AUTOTUNE)
val_ds_with_filenames_and_type = val_ds.map(lambda x: process_path_with_filename_and_dataset_type(x, 'Val'), num_parallel_calls=AUTOTUNE)# 合并训练集和验证集
all_ds_with_filenames_and_type = train_ds_with_filenames_and_type.concatenate(val_ds_with_filenames_and_type)# 对训练数据集进行批处理和预加载
train_ds_with_filenames_and_type = train_ds_with_filenames_and_type.batch(batch_size)
train_ds_with_filenames_and_type = train_ds_with_filenames_and_type.prefetch(buffer_size=AUTOTUNE)# 对验证数据集进行批处理和预加载
val_ds_with_filenames_and_type = val_ds_with_filenames_and_type.batch(batch_size)
val_ds_with_filenames_and_type = val_ds_with_filenames_and_type.prefetch(buffer_size=AUTOTUNE)# 在进行模型训练时,不需要文件名和数据集类型信息,所以在此处移除
train_ds = train_ds_with_filenames_and_type.map(lambda x, y, z, t: (x, y))
val_ds = val_ds_with_filenames_and_type.map(lambda x, y, z, t: (x, y))for image, label, path, dataset_type in train_ds_with_filenames_and_type.take(1):print("Image shape: ", image.numpy().shape)print("Label: ", label.numpy())print("Path: ", path.numpy())print("Dataset type: ", dataset_type.numpy())train_ds = train_ds.prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.prefetch(buffer_size=AUTOTUNE)plt.figure(figsize=(10, 8))  # 图形的宽为10高为5
plt.suptitle("数据展示")for images, labels, paths, dataset_types in train_ds_with_filenames_and_type.take(1):for i in range(15):plt.subplot(4, 5, i + 1)plt.xticks([])plt.yticks([])plt.grid(False)plt.imshow(images[i].numpy())plt.xlabel(class_names[labels[i]])
plt.show()######################################数据增强函数################################data_augmentation = Sequential([RandomFlip("horizontal_and_vertical"),RandomRotation(0.2),RandomContrast(1.0),RandomZoom(0.5, 0.2),RandomTranslation(0.3, 0.5),
])def prepare(ds, augment=False):ds = ds.map(lambda x, y, z, t: (data_augmentation(x, training=True), y, z, t) if augment else (x, y, z, t), num_parallel_calls=AUTOTUNE)return ds# 注意这里变量名的更改
train_ds_with_filenames_and_type = prepare(train_ds_with_filenames_and_type, augment=True)# 在进行模型训练时,不需要文件名和数据集类型信息,所以在此处移除
train_ds = train_ds_with_filenames_and_type.map(lambda x, y, z, t: (x, y))
val_ds = val_ds_with_filenames_and_type.map(lambda x, y, z, t: (x, y))train_ds = train_ds.prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.prefetch(buffer_size=AUTOTUNE)###############################导入mobilenet_v2################################
#获取预训练模型对输入的预处理方法
from tensorflow.python.keras.applications import mobilenet_v2
from tensorflow.python.keras import Input, regularizers
IMG_SIZE = (img_height, img_width, 3)base_model = mobilenet_v2.MobileNetV2(input_shape=IMG_SIZE, include_top=False, #是否包含顶层的全连接层weights='imagenet')inputs = Input(shape=IMG_SIZE)
#模型
x = base_model(inputs, training=False) #参数不变化
#全局池化
x = GlobalAveragePooling2D()(x)
#BatchNormalization
x = BatchNormalization()(x)
#Dropout
x = Dropout(0.8)(x)
#Dense
x = Dense(128, kernel_regularizer=regularizers.l2(0.1))(x)  # 全连接层减少到128,添加 L2 正则化
#BatchNormalization
x = BatchNormalization()(x)
#激活函数
x = Activation('relu')(x)
#输出层
outputs = Dense(4, kernel_regularizer=regularizers.l2(0.1))(x)  # 添加 L2 正则化,改变输出层的神经元数量为4
#BatchNormalization
outputs = BatchNormalization()(outputs)
#激活函数
outputs = Activation('softmax')(outputs)  # 使用softmax激活函数,因为是多分类问题
#整体封装
model = Model(inputs, outputs)
#打印模型结构
print(model.summary())
#############################编译模型#########################################
#定义优化器
from tensorflow.python.keras.optimizers import adam_v2, rmsprop_v2optimizer = adam_v2.Adam()#编译模型
model.compile(optimizer=optimizer,loss='sparse_categorical_crossentropy',  # 因为是多分类问题,所以损失函数选择sparse_categorical_crossentropymetrics=['accuracy'])#训练模型
from tensorflow.python.keras.callbacks import ModelCheckpoint, Callback, EarlyStopping, ReduceLROnPlateau, LearningRateSchedulerNO_EPOCHS = 50
PATIENCE  = 10
VERBOSE   = 1# 设置动态学习率
annealer = LearningRateScheduler(lambda x: 1e-5 * 0.99 ** (x+NO_EPOCHS))# 设置早停
earlystopper = EarlyStopping(monitor='loss', patience=PATIENCE, verbose=VERBOSE)# 
checkpointer = ModelCheckpoint('mtb_jet_best_model_mobilenetv3samll-1.h5',monitor='val_accuracy',verbose=VERBOSE,save_best_only=True,save_weights_only=True)train_model  = model.fit(train_ds,epochs=NO_EPOCHS,verbose=1,validation_data=val_ds,callbacks=[earlystopper, checkpointer, annealer])#保存模型
#model.save('mtb_jet_best_model_mobilenet-1.h5')
#print("The trained model has been saved.")###########################误判病例分析#################################
import pandas as pd# 提取图片的信息并预测
data_list = []
for image, label, path, dataset_type in all_ds_with_filenames_and_type:# 获取图片名称、类别信息path_parts = path.numpy().decode('utf-8').split('/')dataset_type = dataset_type.numpy().decode('utf-8')true_class = class_names[label.numpy()]image_name = path_parts[-1]# 使用模型预测图片的类别img_array = np.expand_dims(image, axis=0)predictions = model.predict(img_array)pred_class = class_names[np.argmax(predictions)]# 根据预测结果判断所属的组别if true_class == pred_class:group = 'A'elif true_class == 'COVID-19':if pred_class == 'Normal':group = 'B'elif pred_class == 'Pneumonia':group = 'C'elif pred_class == 'Tuberculosis':group = 'D'elif true_class == 'Normal':if pred_class == 'COVID-19':group = 'E'elif pred_class == 'Pneumonia':group = 'F'elif pred_class == 'Tuberculosis':group = 'G'elif true_class == 'Pneumonia':if pred_class == 'COVID-19':group = 'H'elif pred_class == 'Normal':group = 'I'elif pred_class == 'Tuberculosis':group = 'J'elif true_class == 'Tuberculosis':if pred_class == 'COVID-19':group = 'H'elif pred_class == 'Normal':group = 'I'elif pred_class == 'Pneumonia':group = 'J'# 保存图片的信息和预测结果data_list.append([image_name, dataset_type, pred_class, group])# 将结果转化为DataFrame并保存为csv文件
result = pd.DataFrame(data_list, columns=["原始图片的名称", "属于训练集还是验证集", "预测为分组类型", "判定的组别"])
result.to_csv("result-m-t.csv", index=False)

三、改写过程

先说策略:首先,先把二分类的误判病例分析代码改成四分类的;其次,用咒语让GPT-4帮我们续写代码已达到误判病例分析。

策略的理由:之前介绍过,做误判病例分析是需要读取图片的路径信息。悲剧的是,我们之前在读取数据的时候使用的是“image_dataset_from_directory”函数,它不提供路径信息。因此,在二分类的误判病例分析的教程中,我们修改了数据读取的代码,因此,在此基础上进行修改,效率最高!

提供咒语如下:

①改写{代码1},改变成4分类的建模。代码1为:{XXX};

在{代码1}的基础上改写代码,达到下面要求:

(1)首先,提取出所有图片的“原始图片的名称”、“属于训练集还是验证集”、“预测为分组类型”;文件的路劲格式为:例如,“MTB-1\Normal\XXX.png”属于Normal,“MTB-1\COVID-19\XXX.jpg”属于COVID-19,“MTB-1\Pneumonia\XXX.jpeg”属于Pneumonia,“MTB-1\Tuberculosis\XXX.png”属于Tuberculosis;

(2)其次,根据样本预测结果,把样本分为以下若干组:(a)预测正确的图片,全部判定为A组;(b)本来就是COVID-19的图片,预测为Normal,判定为B组;(c)本来就是COVID-19的图片,预测为Pneumonia,判定为C组;(d)本来就是COVID-19的图片,预测为Tuberculosis,判定为D组;(e)本来就是Normal的图片,预测为COVID-19,判定为E组;(f)本来就是Normal的图片,预测为Pneumonia,判定为F组;(g)本来就是Normal的图片,预测为Tuberculosis,判定为G组;(h)本来就是Pneumonia的图片,预测为COVID-19,判定为H组;(i)本来就是Pneumonia的图片,预测为Normal,判定为I组;(j)本来就是Pneumonia的图片,预测为Tuberculosis,判定为J组;(k)本来就是Tuberculosis的图片,预测为COVID-19,判定为H组;(l)本来就是Tuberculosis的图片,预测为Normal,判定为I组;(m)本来就是Tuberculosis的图片,预测为Pneumonia,判定为J组;

(3)居于以上计算的结果,生成一个名为result-m.csv表格文件。列名分别为:“原始图片的名称”、“属于训练集还是验证集”、“预测为分组类型”、“判定的组别”。其中,“原始图片的名称”为所有图片的图片名称;“属于训练集还是验证集”为这个图片属于训练集还是验证集;“预测为分组类型”为模型预测该样本是哪一个分组;“判定的组别”为根据步骤(2)判定的组别,从A到J一共十组选择一个。

(4)需要把所有的图片都进行上面操作,注意是所有图片,而不只是一个批次的图片。

代码1为:{XXX}

③还需要根据报错做一些调整即可,自行调整。

最后,看看结果:

四、数据

链接:https://pan.baidu.com/s/1rqu15KAUxjNBaWYfEmPwgQ?pwd=xfyn

提取码:xfyn

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/66947.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Python安装与Pycharm配置

Python与Pycharm安装 用了一年的Python最近被一个问题难倒了,pip安装一直不能用,报错说被另一个程序使用。被逼到只能重新安装python了,正好记录一下这个过程,写这篇笔记。(突然想到可能是配Arcgis的python接口&#…

金融风控数据分析-信用评分卡建模(附数据集下载地址)

本文引用自: 金融风控:信用评分卡建模流程 - 知乎 (zhihu.com) 在原文的基础上加上了一部分自己的理解,转载在CSDN上作为保留记录。 本文涉及到的数据集可直接从天池上面下载: Give Me Some Credit给我一些荣誉_数据集-阿里云…

2023年新风口,Kwai快手海外公会入驻详细指南!

Kwai快手海外公会发展前景作为中国短视频市场的领头羊,Kwai快手在海外市场也取得了一定的成绩。但是,海外市场的发展前景还存在一些不确定因素,需要快手进一步探索和努力。首先,海外市场对于内容的申请找cmxyci要求与国内市场有所…

MonoDETR: Depth-guided Transformer for Monocular 3D Object Detection 论文解读

文章目录 1.Abstract2. Introduction3. Related workDETR base methods 4. Method4.1Feature ExtractionVisual Featuresdepth featuresforeground depth map 4. 2 Depth guided transformerVisual and depth encodersDepth-guided-decoderDepth positional encoding 4. 3 Dete…

【vue2第十一章】v-model的原理详解 与 如何使用v-model对父子组件的value绑定 和修饰符.sync

v-model的原理详解 v-model的本质就是一个语法糖,实际上就是 :value"msg" 与 input"msg $event.target.value" 的简写。 :value"msg" 从数据单向绑定到input框,当data数据中的msg内容一旦改变,而input框数据…

OpenCV(十八):图像直方图

目录 1.直方图统计 2.直方图均衡化 3.直方图匹配 1.直方图统计 直方图统计是一种用于分析图像或数据的统计方法,它通过统计每个数值或像素值的频率分布来了解数据的分布情况。 在OpenCV中,可以使用函数cv::calcHist()来计算图像的直方图。 calcHist(…

uni-app之android离线打包

一 AndroidStudio创建项目 1.1,上一节演示了uni-app云打包,下面演示怎样androidStudio离线打包。在AndroidStudio里面新建空项目 1.2,下载uni-app离线SDK,离线SDK主要用于App本地离线打包及扩展原生能力,SDK下载链接h…

电商平台api对接货源

如今,电商平台已经成为了人们购物的主要途径之一。 然而,对于电商平台来说,货源对接一直是一个比较棘手的问题。为了解决这个问题,越来越多的电商平台开始使用API来对接货源。 API,即应用程序接口,是一种允…

分支创建查看切换

1、初始化git目录,创建文件并将其推送到本地库 git init echo "123" > hello.txt git add hello.txt git commit -m "first commit" hello.txt$ git init Initialized empty Git repository in D:/Git/git-demo/.git/ AdministratorDESKT…

django中配置使用websocket终极解决方案

django ASGI/Channels 启动和 ASGI/daphne的区别 Django ASGI/Channels 是 Django 框架的一个扩展,它提供了异步服务器网关接口(ASGI)协议的支持,以便处理实时应用程序的并发连接。ASGI 是一个用于构建异步 Web 服务器和应用程序…

ipad手写笔什么牌子好?适合开学买的电容笔推荐

随着社会和经济水平的不断提高,对电容笔的使用日益增加。国产平替的电容笔和原装苹果的电容笔没有太大的区别,不管是功能还是手感都很像,书写上的笔锋也很流畅,这让我很惊讶,因为其的价格很便宜,但其的体验…

【实训项目】精点考研

1.设计摘要 如果说高考是一次能够改变命运的考试,那么考研应该是另外一次。为什么那么多人都要考研呢?从中国教育在线官方公布是考研动机调查来看,大家扎堆考研的原因大概集中在这6个方面:本科就业压力大,提升竞争力、…

开源照片管理服务LibrePhotos

本文是为了解决网友 赵云遇到的问题,顺便折腾的。虽然软件跑起来了,但是他遇到的问题,超出了老苏的认知。当然最终问题还是得到了解决,不过与 LibrePhotos 无关; 什么是 LibrePhotos ? LibrePhotos 是一个自托管的开源…

学习高级数据结构:探索平衡树与图的高级算法

文章目录 1. 平衡树:维护数据的平衡与高效性1.1 AVL 树:严格的平衡1.2 红黑树:近似平衡 2. 图的高级算法:建模复杂关系与优化2.1 最小生成树:寻找最优连接方式2.2 拓扑排序:解决依赖关系 拓展思考 &#x1…

支付参考文档

支付宝官方提供的样例代码与支付宝通信: 小程序文档 - 支付宝文档中心支付宝文档中心https://docs.open.alipay.com/api_1/alipay.trade.query 参考sdk示例: 小程序文档 - 支付宝文档中心支付宝文档中心https://docs.open.alipay.com/api_1/alipay.tra…

unity界面上Global 与Local xyz- right up forward

gloabal 如果要沿这个方向移动就比较困难 local下就不一样了

docker高级(DockerFile解析)

1、构建三步骤 编写Dockerfile文件 docker build命令构建镜像 docker run依镜像运行容器实例 2、DockerFile构建过程解析 Dockerfile内容基础知识 1:每条保留字指令都必须为大写字母且后面要跟随至少一个参数 2:指令按照从上到下,顺序执行…

【2023集创赛】加速科技杯二等奖作品:基于ATE的电源芯片测试设计与性能分析

本文为2023年第七届全国大学生集成电路创新创业大赛(“集创赛”)加速科技杯二等奖作品分享,参加极术社区的【有奖征集】分享你的2023集创赛作品,秀出作品风采,分享2023集创赛作品扩大影响力,更有丰富电子礼…

c++11 标准模板(STL)(std::basic_stringstream)(四)

定义于头文件 <sstream> template< class CharT, class Traits std::char_traits<CharT> > class basic_stringstream;(C11 前)template< class CharT, class Traits std::char_traits<CharT>, class Allocator std::alloc…

C++算法 —— 动态规划(1)斐波那契数列模型

文章目录 1、动规思路简介2、第N个泰波那契数列3、三步问题4、使用最小花费爬楼梯5、解码方法6、动规分析总结 1、动规思路简介 动规的思路有五个步骤&#xff0c;且最好画图来理解细节&#xff0c;不要怕麻烦。当你开始画图&#xff0c;仔细阅读题时&#xff0c;学习中的沉浸…