LeNet实验 四分类 与 四分类变为多个二分类

 

目录

1. 划分二分类

2. 训练独立的二分类模型

3. 二分类预测结果代码

4. 二分类预测结果

5 改进训练模型

6 优化后 预测结果代码

 7 优化后预测结果

8 训练四分类模型 

9 预测结果代码

10 四分类结果识别


1. 划分二分类

可以根据不同的类别进行多个划分,以实现NonDemented为例,划分为NonDemented和Demented两类,不属于NonDemented的全都属于Demented

2. 训练独立的二分类模型

import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.preprocessing.image import ImageDataGeneratorfrom 文件准备 import data_dir# 数据生成器
train_datagen = ImageDataGenerator(rescale=1./255,shear_range=0.2,zoom_range=0.2,horizontal_flip=True,validation_split=0.2  # 20%用于验证
)train_generator = train_datagen.flow_from_directory(data_dir,target_size=(28, 28),batch_size=32,class_mode='binary',subset='training'
)validation_generator = train_datagen.flow_from_directory(data_dir,target_size=(28, 28),batch_size=32,class_mode='binary',subset='validation'
)# 构建LeNet-5模型
model = models.Sequential()
model.add(layers.Conv2D(6, (5, 5), activation='relu', input_shape=(28, 28, 3), padding='same'))
model.add(layers.AveragePooling2D((2, 2)))
model.add(layers.Conv2D(16, (5, 5), activation='relu', padding='same'))
model.add(layers.AveragePooling2D((2, 2)))
model.add(layers.Conv2D(120, (5, 5), activation='relu', padding='same'))
model.add(layers.Flatten())
model.add(layers.Dense(84, activation='relu'))
model.add(layers.Dense(1, activation='sigmoid'))# 编译模型
model.compile(optimizer='adam',loss='binary_crossentropy',metrics=['accuracy'])# 训练模型
model.fit(train_generator,steps_per_epoch=train_generator.samples // train_generator.batch_size,epochs=10,validation_data=validation_generator,validation_steps=validation_generator.samples // validation_generator.batch_size
)# 保存模型
model.save('lenet_binary_classification_model.h5')

3. 预测结果代码

import tensorflow as tf
from tensorflow.keras.preprocessing import image
import numpy as np
import matplotlib.pyplot as plt# 加载模型
model = tf.keras.models.load_model('lenet_binary_classification_model.h5')# 预处理图像
def preprocess_image(img_path):img = image.load_img(img_path, target_size=(28, 28))img_array = image.img_to_array(img) / 255.0img_array = np.expand_dims(img_array, axis=0)return img_array# 预测图像
img_path = 'D:\Pycharm_workspace\LeNet实验_二分类\Demented\moderateDem24.jpg'  # 测试图像路径
img_array = preprocess_image(img_path)
prediction = model.predict(img_array)
predicted_class = 'Demented' if prediction[0][0] > 0.5 else 'NonDemented'print(f'The predicted class is: {predicted_class}')# 显示图像
img = image.load_img(img_path, target_size=(28, 28))
plt.imshow(img)
plt.title(f'Predicted: {predicted_class}')
plt.show()

4. 预测结果

Demented结果

 NonDemented结果没有。。。。。。

竟然全都没有。。。。因为预测的全部都是Demented

疯狂找原因中

猜测是像素太低使得训练的模型准确率太低

于是重新训练

5 改进训练模型

进行重新训练

import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt# 定义LeNet模型
def create_lenet_model(input_shape):model = Sequential([Conv2D(6, (5, 5), activation='relu', input_shape=input_shape, padding='same'),MaxPooling2D((2, 2), strides=2),Conv2D(16, (5, 5), activation='relu'),MaxPooling2D((2, 2), strides=2),Flatten(),Dense(120, activation='relu'),Dense(84, activation='relu'),Dense(1, activation='sigmoid')])model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])return model# 数据增强和数据生成器
train_datagen = ImageDataGenerator(rescale=1./255, validation_split=0.2)# 训练数据生成器
train_generator = train_datagen.flow_from_directory('D:\Pycharm_workspace\LeNet实验_二分类\image',target_size=(176, 208),batch_size=32,class_mode='binary',subset='training'
)# 验证数据生成器
validation_generator = train_datagen.flow_from_directory('D:\Pycharm_workspace\LeNet实验_二分类\image',target_size=(176, 208),batch_size=32,class_mode='binary',subset='validation'
)# 创建并训练模型
input_shape = (176, 208, 3)
model = create_lenet_model(input_shape)
history = model.fit(train_generator, epochs=10, validation_data=validation_generator)# 保存模型
model.save('dementia_classification_model.h5')# 绘制训练和验证损失
plt.plot(history.history['loss'], label='训练损失')
plt.plot(history.history['val_loss'], label='验证损失')
plt.title('训练和验证损失')
plt.xlabel('时期')
plt.ylabel('损失')
plt.legend()
plt.show()# 绘制训练和验证准确率
plt.plot(history.history['accuracy'], label='训练准确率')
plt.plot(history.history['val_accuracy'], label='验证准确率')
plt.title('训练和验证准确率')
plt.xlabel('时期')
plt.ylabel('准确率')
plt.legend()
plt.show()

 这里还有图形画loss与准确率但是我忘记保存了,就用控制台的输出

 可以看到loss值非常小而且准确率是100

6 优化后 预测结果代码

import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing import image
import numpy as np
import matplotlib.pyplot as plt
import os# 加载模型
model = load_model('dementia_classification_model.h5')# 定义类别标签
class_labels = ['Demented', 'NonDemented']# 预测函数
def predict_image(img_path):img = image.load_img(img_path, target_size=(176, 208))img_array = image.img_to_array(img)img_array = np.expand_dims(img_array, axis=0)img_array /= 255.0prediction = model.predict(img_array)predicted_class = class_labels[int(prediction[0] > 0.5)]# 显示图像和预测结果plt.imshow(image.load_img(img_path))plt.title(f'Predicted: {predicted_class}')plt.axis('off')plt.show()# 预测并展示结果
img_path = r'D:\Pycharm_workspace\LeNet实验_二分类\image\NonDemented\nonDem1.jpg'  # 替换为你的图片路径
predict_image(img_path)

 7 优化后预测结果

 图片与预测结果对应上了(右侧是图片链接可以看到是Dem的类型)

 NonDem的也是对应上了

 就此训练完成

 

8 训练四分类模型 

import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt# 定义LeNet模型
def create_lenet_model(input_shape):model = Sequential([Conv2D(6, (5, 5), activation='relu', input_shape=input_shape, padding='same'),MaxPooling2D((2, 2), strides=2),Conv2D(16, (5, 5), activation='relu'),MaxPooling2D((2, 2), strides=2),Flatten(),Dense(120, activation='relu'),Dense(84, activation='relu'),Dense(4, activation='softmax')])model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])return model# 数据增强和数据生成器
train_datagen = ImageDataGenerator(rescale=1./255, validation_split=0.2)# 训练数据生成器
train_generator = train_datagen.flow_from_directory('D:\Pycharm_workspace\LeNet实验_四分类\image',target_size=(176, 208),batch_size=32,class_mode='categorical',subset='training'
)# 验证数据生成器
validation_generator = train_datagen.flow_from_directory('D:\Pycharm_workspace\LeNet实验_四分类\image',target_size=(176, 208),batch_size=32,class_mode='categorical',subset='validation'
)# 创建并训练模型
input_shape = (176, 208, 3)
model = create_lenet_model(input_shape)
history = model.fit(train_generator, epochs=10, validation_data=validation_generator)# 保存模型
model.save('dementia_classification_model.h5')# 绘制训练和验证损失
plt.plot(history.history['loss'], label='训练损失')
plt.plot(history.history['val_loss'], label='验证损失')
plt.title('训练和验证损失')
plt.xlabel('时期')
plt.ylabel('损失')
plt.legend()
plt.show()# 绘制训练和验证准确率
plt.plot(history.history['accuracy'], label='训练准确率')
plt.plot(history.history['val_accuracy'], label='验证准确率')
plt.title('训练和验证准确率')
plt.xlabel('时期')
plt.ylabel('准确率')
plt.legend()
plt.show()

 

loss值与准确率的变化图

可以看到才第四轮准确率就已经很高了 

9 预测结果代码

import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing import image
import numpy as np
import matplotlib.pyplot as plt# 加载模型
model = load_model('dementia_classification_model.h5')# 定义类别标签
class_labels = ['MildDemented', 'ModerateDemented', 'NonDemented', 'VeryMildDemented']# 预测函数
def predict_image(img_path):img = image.load_img(img_path, target_size=(176, 208))img_array = image.img_to_array(img)img_array = np.expand_dims(img_array, axis=0)img_array /= 255.0prediction = model.predict(img_array)predicted_class = class_labels[np.argmax(prediction)]# 显示图像和预测结果plt.imshow(image.load_img(img_path))plt.title(f'Predicted: {predicted_class}')plt.axis('off')plt.show()# 预测并展示结果
img_path = r'D:\Pycharm_workspace\LeNet实验_四分类\image\VeryMildDemented\verymildDem0.jpg'  # 你的图片路径
predict_image(img_path)

 

10 四分类结果识别

1 MildDem成功识别(右侧有图片名称)

2 ModerateDem 成功识别

3 NonDem成功识别

 4 VeryMildDem成功识别

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/47228.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Unity运行时节点编辑器——互动电影案例

Unity运行时节点编辑器——互动电影案例 引子 最近需要做一个互动电影的小项目,需求很简单,就是有一堆的视频,然后在某视频播放完的时候,让观众做一个选择题,然后根据观众做出的选择,继续播放不同的视频&…

科研绘图系列:R语言分割小提琴图(Split-violin)

介绍 分割小提琴图(Split-violin plot)是一种数据可视化工具,它结合了小提琴图(violin plot)和箱线图(box plot)的特点。小提琴图是一种展示数据分布的图形,它通过在箱线图的两侧添加曲线来表示数据的密度分布,曲线的宽度表示数据点的密度。而分割小提琴图则是将小提…

【题解】2014年408计网真题

15.【2014统考真题】使用浏览器访问某大学的Web网站主页时,不可能使用到的协议是()。A. PPP B. ARP C. UDP D. SMTP 我们逐一分析给定的选项: A. PPP(Point-to-Point Protocol,点对点协议)&…

Python 模块导入方式

在Python 中,导入外部模块有2种方式 以 Pyhton 自带的 time 模块 为例: 使用 import time 导入方式 import time print(time.ctime()) 注意事项: time 模块导入后,使用以下格式来调用模块中的函数: 模块名.函数名 如果导入的模…

绿色算力|暴雨服务器用芯片筑起“十四五”转型新篇章

面对全球气候变化、技术革新以及能源转型的新形势,发展低碳、高效的绿色算力不仅是顺应时代的要求,更是我国建设数字基础设施和展现节能减碳大国担当的重要命题,在此背景下也要求在提升算力规模和性能的同时,积极探索推动算力基础…

day2加餐 Go 接口型函数的使用场景

文章目录 问题价值使用场景其他语言类似特性 问题 在 动手写分布式缓存 - GeeCache day2 单机并发缓存 这篇文章中,有一个接口型函数的实现: // A Getter loads data for a key. type Getter interface {Get(key string) ([]byte, error) }// A Getter…

【iOS】APP仿写——网易云音乐

网易云音乐 启动页发现定时器控制轮播图UIButtonConfiguration 发现换头像 我的总结 启动页 这里我的启动页是使用Xcode自带的启动功能,将图片放置在LaunchScreen中即可。这里也可以通过定时器控制,来实现启动的效果 效果图: 这里放一篇大…

31_MobileViT网络讲解

VIT:https://blog.csdn.net/qq_51605551/article/details/140445491?spm1001.2014.3001.5501 1.1 简介 MobileVIT是“Mobile Vision Transformer”的简称,是一种专门为移动设备设计的高效视觉模型。它结合了Transformer架构的优点与移动优先的设计原则&#xff0…

在eclipse中导入本地的jar包配置Junit环境步骤(包含Junit中的方法一直标红的解决方法)

搭建JUnit环境 一、配置环境 跟上一篇的那种方法不一样,直接Add to Build Path 是先将jar包复制到项目的lib目录下,然后直接添加 选定项目>>>右键>>>Bulid Path>>>Add Libraries>>>Configure Build Path(配置构建路…

Kotlin单例、数据类、静态

文章目录 1. 数据类2. 单例3. 静态 1. 数据类 data class Cellphone(val brand: String, val price: Double)自动生成了get set hashcode equals toString等方法。通过反编译(打开AS,在Tools-Kotlin-ShowBytecode-Decompile)可以得到如下结果 public final class …

python—爬虫爬取电影页面实例

下面是一个简单的爬虫实例,使用Python的requests库来发送HTTP请求,并使用lxml库来解析HTML页面内容。这个爬虫的目标是抓取一个电影网站,并提取每部电影的主义部分。 首先,确保你已经安装了requests和lxml库。如果没有安装&#x…

Fast Planner规划算法(一)—— Fast Planner前端

本系列文章用于回顾学习记录Fast-Planner规划算法的相关内容,【本系列博客写于2023年9月,共包含四篇文章,现在进行补发第一篇,其余几篇文章将在近期补发】 一、Fast Planner前端 Fast Planner的轨迹规划部分一共分为三个模块&…

4.基础知识-数据库技术基础

基础知识 一、数据库基本概念1、数据库系统基础知识2、三级模式-两级映像3、数据库设计4、数据模型:4.1 E-R模型★4.2 关系模型★ 5、关系代数 二、规范化和并发控制1、函数依赖2、键与约束3、范式★3.1 第一范式1NF实例3.2 第二范式2NF3.3 第三范式3NF3.4 BC范式BC…

rockchip的yolov5 rknn python推理分析

rockchip的yolov5 rknn推理分析 对于rockchip给出的这个yolov5后处理代码的分析,本人能力十分有限,可能有的地方描述的很不好,欢迎大家和我一起讨论,指出我的错误!!! RKNN模型输出 将官方的Y…

直方图的最大长方形面积

前提知识:单调栈基础题-CSDN博客 子数组的最大值-CSDN博客 题目描述: 给定一个非负数(0和正数),代表直方图,返回直方图的最大长方形面积,比如,arr {3, 2, 4, 2, 5}&#xff0c…

景区导航导览系统:基于AR技术+VR技术的功能效益全面解析

在数字化时代背景下,游客对旅游体验的期望不断提升。游客们更倾向于使用手机作为旅行的贴身助手,不仅因为它能提供实时、精准的导航服务,更在于其融合AR(增强现实)、VR(虚拟现实)等前沿技术&…

C++编程:实现一个跨平台安全的定时器Timer模块

文章目录 0. 概要1. 设计目标2. SafeTimer 类的实现2.1 头文件 safe_timer.h源文件 safe_timer.cpp 3. 工作流程图4. 单元测试 0. 概要 对于C应用编程,定时器模块是一个至关重要的组件。为了确保系统的可靠性和功能安全,我们需要设计一个高效、稳定的定…

十三、网络编程正则表达式设计模式(模块23)

网络编程&正则表达式&设计模式 模块23_网络编程&正则表达式&设计模式第一章.网络编程1.软件结构2.服务器概念3.通信三要素4.UDP协议编程4.1.客户端(发送端)4.2.服务端(接收端) 5.TCP协议编程4.1.编写客户端4.2.编写服务端 6.文件上传6.1.文件上传客户端以及服务…

AI学习指南机器学习篇-SOM算法原理

AI学习指南机器学习篇-SOM算法原理 自组织映射(Self-Organizing Map, SOM)算法是一种常用的无监督学习算法,被广泛应用于数据聚类、可视化和模式识别等领域。SOM算法可以帮助我们发现数据中的隐藏结构,并且能够在高维空间中有效地…

【开发踩坑】 MySQL不支持特殊字符(表情)插入问题

背景 线上功能报错: Cause:java.sql.SQLException:Incorrect string value:xFO\x9F\x9FxBO for column commentat row 1 uncategorized SQLException; SQL state [HY000]:error code [1366]排查 初步觉得是编码问题(utf8 — utf8mb4) 参考上…