LeNet实验 四分类 与 四分类变为多个二分类

 

目录

1. 划分二分类

2. 训练独立的二分类模型

3. 二分类预测结果代码

4. 二分类预测结果

5 改进训练模型

6 优化后 预测结果代码

 7 优化后预测结果

8 训练四分类模型 

9 预测结果代码

10 四分类结果识别


1. 划分二分类

可以根据不同的类别进行多个划分,以实现NonDemented为例,划分为NonDemented和Demented两类,不属于NonDemented的全都属于Demented

2. 训练独立的二分类模型

import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.preprocessing.image import ImageDataGeneratorfrom 文件准备 import data_dir# 数据生成器
train_datagen = ImageDataGenerator(rescale=1./255,shear_range=0.2,zoom_range=0.2,horizontal_flip=True,validation_split=0.2  # 20%用于验证
)train_generator = train_datagen.flow_from_directory(data_dir,target_size=(28, 28),batch_size=32,class_mode='binary',subset='training'
)validation_generator = train_datagen.flow_from_directory(data_dir,target_size=(28, 28),batch_size=32,class_mode='binary',subset='validation'
)# 构建LeNet-5模型
model = models.Sequential()
model.add(layers.Conv2D(6, (5, 5), activation='relu', input_shape=(28, 28, 3), padding='same'))
model.add(layers.AveragePooling2D((2, 2)))
model.add(layers.Conv2D(16, (5, 5), activation='relu', padding='same'))
model.add(layers.AveragePooling2D((2, 2)))
model.add(layers.Conv2D(120, (5, 5), activation='relu', padding='same'))
model.add(layers.Flatten())
model.add(layers.Dense(84, activation='relu'))
model.add(layers.Dense(1, activation='sigmoid'))# 编译模型
model.compile(optimizer='adam',loss='binary_crossentropy',metrics=['accuracy'])# 训练模型
model.fit(train_generator,steps_per_epoch=train_generator.samples // train_generator.batch_size,epochs=10,validation_data=validation_generator,validation_steps=validation_generator.samples // validation_generator.batch_size
)# 保存模型
model.save('lenet_binary_classification_model.h5')

3. 预测结果代码

import tensorflow as tf
from tensorflow.keras.preprocessing import image
import numpy as np
import matplotlib.pyplot as plt# 加载模型
model = tf.keras.models.load_model('lenet_binary_classification_model.h5')# 预处理图像
def preprocess_image(img_path):img = image.load_img(img_path, target_size=(28, 28))img_array = image.img_to_array(img) / 255.0img_array = np.expand_dims(img_array, axis=0)return img_array# 预测图像
img_path = 'D:\Pycharm_workspace\LeNet实验_二分类\Demented\moderateDem24.jpg'  # 测试图像路径
img_array = preprocess_image(img_path)
prediction = model.predict(img_array)
predicted_class = 'Demented' if prediction[0][0] > 0.5 else 'NonDemented'print(f'The predicted class is: {predicted_class}')# 显示图像
img = image.load_img(img_path, target_size=(28, 28))
plt.imshow(img)
plt.title(f'Predicted: {predicted_class}')
plt.show()

4. 预测结果

Demented结果

 NonDemented结果没有。。。。。。

竟然全都没有。。。。因为预测的全部都是Demented

疯狂找原因中

猜测是像素太低使得训练的模型准确率太低

于是重新训练

5 改进训练模型

进行重新训练

import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt# 定义LeNet模型
def create_lenet_model(input_shape):model = Sequential([Conv2D(6, (5, 5), activation='relu', input_shape=input_shape, padding='same'),MaxPooling2D((2, 2), strides=2),Conv2D(16, (5, 5), activation='relu'),MaxPooling2D((2, 2), strides=2),Flatten(),Dense(120, activation='relu'),Dense(84, activation='relu'),Dense(1, activation='sigmoid')])model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])return model# 数据增强和数据生成器
train_datagen = ImageDataGenerator(rescale=1./255, validation_split=0.2)# 训练数据生成器
train_generator = train_datagen.flow_from_directory('D:\Pycharm_workspace\LeNet实验_二分类\image',target_size=(176, 208),batch_size=32,class_mode='binary',subset='training'
)# 验证数据生成器
validation_generator = train_datagen.flow_from_directory('D:\Pycharm_workspace\LeNet实验_二分类\image',target_size=(176, 208),batch_size=32,class_mode='binary',subset='validation'
)# 创建并训练模型
input_shape = (176, 208, 3)
model = create_lenet_model(input_shape)
history = model.fit(train_generator, epochs=10, validation_data=validation_generator)# 保存模型
model.save('dementia_classification_model.h5')# 绘制训练和验证损失
plt.plot(history.history['loss'], label='训练损失')
plt.plot(history.history['val_loss'], label='验证损失')
plt.title('训练和验证损失')
plt.xlabel('时期')
plt.ylabel('损失')
plt.legend()
plt.show()# 绘制训练和验证准确率
plt.plot(history.history['accuracy'], label='训练准确率')
plt.plot(history.history['val_accuracy'], label='验证准确率')
plt.title('训练和验证准确率')
plt.xlabel('时期')
plt.ylabel('准确率')
plt.legend()
plt.show()

 这里还有图形画loss与准确率但是我忘记保存了,就用控制台的输出

 可以看到loss值非常小而且准确率是100

6 优化后 预测结果代码

import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing import image
import numpy as np
import matplotlib.pyplot as plt
import os# 加载模型
model = load_model('dementia_classification_model.h5')# 定义类别标签
class_labels = ['Demented', 'NonDemented']# 预测函数
def predict_image(img_path):img = image.load_img(img_path, target_size=(176, 208))img_array = image.img_to_array(img)img_array = np.expand_dims(img_array, axis=0)img_array /= 255.0prediction = model.predict(img_array)predicted_class = class_labels[int(prediction[0] > 0.5)]# 显示图像和预测结果plt.imshow(image.load_img(img_path))plt.title(f'Predicted: {predicted_class}')plt.axis('off')plt.show()# 预测并展示结果
img_path = r'D:\Pycharm_workspace\LeNet实验_二分类\image\NonDemented\nonDem1.jpg'  # 替换为你的图片路径
predict_image(img_path)

 7 优化后预测结果

 图片与预测结果对应上了(右侧是图片链接可以看到是Dem的类型)

 NonDem的也是对应上了

 就此训练完成

 

8 训练四分类模型 

import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt# 定义LeNet模型
def create_lenet_model(input_shape):model = Sequential([Conv2D(6, (5, 5), activation='relu', input_shape=input_shape, padding='same'),MaxPooling2D((2, 2), strides=2),Conv2D(16, (5, 5), activation='relu'),MaxPooling2D((2, 2), strides=2),Flatten(),Dense(120, activation='relu'),Dense(84, activation='relu'),Dense(4, activation='softmax')])model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])return model# 数据增强和数据生成器
train_datagen = ImageDataGenerator(rescale=1./255, validation_split=0.2)# 训练数据生成器
train_generator = train_datagen.flow_from_directory('D:\Pycharm_workspace\LeNet实验_四分类\image',target_size=(176, 208),batch_size=32,class_mode='categorical',subset='training'
)# 验证数据生成器
validation_generator = train_datagen.flow_from_directory('D:\Pycharm_workspace\LeNet实验_四分类\image',target_size=(176, 208),batch_size=32,class_mode='categorical',subset='validation'
)# 创建并训练模型
input_shape = (176, 208, 3)
model = create_lenet_model(input_shape)
history = model.fit(train_generator, epochs=10, validation_data=validation_generator)# 保存模型
model.save('dementia_classification_model.h5')# 绘制训练和验证损失
plt.plot(history.history['loss'], label='训练损失')
plt.plot(history.history['val_loss'], label='验证损失')
plt.title('训练和验证损失')
plt.xlabel('时期')
plt.ylabel('损失')
plt.legend()
plt.show()# 绘制训练和验证准确率
plt.plot(history.history['accuracy'], label='训练准确率')
plt.plot(history.history['val_accuracy'], label='验证准确率')
plt.title('训练和验证准确率')
plt.xlabel('时期')
plt.ylabel('准确率')
plt.legend()
plt.show()

 

loss值与准确率的变化图

可以看到才第四轮准确率就已经很高了 

9 预测结果代码

import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing import image
import numpy as np
import matplotlib.pyplot as plt# 加载模型
model = load_model('dementia_classification_model.h5')# 定义类别标签
class_labels = ['MildDemented', 'ModerateDemented', 'NonDemented', 'VeryMildDemented']# 预测函数
def predict_image(img_path):img = image.load_img(img_path, target_size=(176, 208))img_array = image.img_to_array(img)img_array = np.expand_dims(img_array, axis=0)img_array /= 255.0prediction = model.predict(img_array)predicted_class = class_labels[np.argmax(prediction)]# 显示图像和预测结果plt.imshow(image.load_img(img_path))plt.title(f'Predicted: {predicted_class}')plt.axis('off')plt.show()# 预测并展示结果
img_path = r'D:\Pycharm_workspace\LeNet实验_四分类\image\VeryMildDemented\verymildDem0.jpg'  # 你的图片路径
predict_image(img_path)

 

10 四分类结果识别

1 MildDem成功识别(右侧有图片名称)

2 ModerateDem 成功识别

3 NonDem成功识别

 4 VeryMildDem成功识别

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/47228.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

科研绘图系列:R语言分割小提琴图(Split-violin)

介绍 分割小提琴图(Split-violin plot)是一种数据可视化工具,它结合了小提琴图(violin plot)和箱线图(box plot)的特点。小提琴图是一种展示数据分布的图形,它通过在箱线图的两侧添加曲线来表示数据的密度分布,曲线的宽度表示数据点的密度。而分割小提琴图则是将小提…

绿色算力|暴雨服务器用芯片筑起“十四五”转型新篇章

面对全球气候变化、技术革新以及能源转型的新形势,发展低碳、高效的绿色算力不仅是顺应时代的要求,更是我国建设数字基础设施和展现节能减碳大国担当的重要命题,在此背景下也要求在提升算力规模和性能的同时,积极探索推动算力基础…

【iOS】APP仿写——网易云音乐

网易云音乐 启动页发现定时器控制轮播图UIButtonConfiguration 发现换头像 我的总结 启动页 这里我的启动页是使用Xcode自带的启动功能,将图片放置在LaunchScreen中即可。这里也可以通过定时器控制,来实现启动的效果 效果图: 这里放一篇大…

31_MobileViT网络讲解

VIT:https://blog.csdn.net/qq_51605551/article/details/140445491?spm1001.2014.3001.5501 1.1 简介 MobileVIT是“Mobile Vision Transformer”的简称,是一种专门为移动设备设计的高效视觉模型。它结合了Transformer架构的优点与移动优先的设计原则&#xff0…

在eclipse中导入本地的jar包配置Junit环境步骤(包含Junit中的方法一直标红的解决方法)

搭建JUnit环境 一、配置环境 跟上一篇的那种方法不一样,直接Add to Build Path 是先将jar包复制到项目的lib目录下,然后直接添加 选定项目>>>右键>>>Bulid Path>>>Add Libraries>>>Configure Build Path(配置构建路…

python—爬虫爬取电影页面实例

下面是一个简单的爬虫实例,使用Python的requests库来发送HTTP请求,并使用lxml库来解析HTML页面内容。这个爬虫的目标是抓取一个电影网站,并提取每部电影的主义部分。 首先,确保你已经安装了requests和lxml库。如果没有安装&#x…

Fast Planner规划算法(一)—— Fast Planner前端

本系列文章用于回顾学习记录Fast-Planner规划算法的相关内容,【本系列博客写于2023年9月,共包含四篇文章,现在进行补发第一篇,其余几篇文章将在近期补发】 一、Fast Planner前端 Fast Planner的轨迹规划部分一共分为三个模块&…

4.基础知识-数据库技术基础

基础知识 一、数据库基本概念1、数据库系统基础知识2、三级模式-两级映像3、数据库设计4、数据模型:4.1 E-R模型★4.2 关系模型★ 5、关系代数 二、规范化和并发控制1、函数依赖2、键与约束3、范式★3.1 第一范式1NF实例3.2 第二范式2NF3.3 第三范式3NF3.4 BC范式BC…

rockchip的yolov5 rknn python推理分析

rockchip的yolov5 rknn推理分析 对于rockchip给出的这个yolov5后处理代码的分析,本人能力十分有限,可能有的地方描述的很不好,欢迎大家和我一起讨论,指出我的错误!!! RKNN模型输出 将官方的Y…

直方图的最大长方形面积

前提知识:单调栈基础题-CSDN博客 子数组的最大值-CSDN博客 题目描述: 给定一个非负数(0和正数),代表直方图,返回直方图的最大长方形面积,比如,arr {3, 2, 4, 2, 5}&#xff0c…

景区导航导览系统:基于AR技术+VR技术的功能效益全面解析

在数字化时代背景下,游客对旅游体验的期望不断提升。游客们更倾向于使用手机作为旅行的贴身助手,不仅因为它能提供实时、精准的导航服务,更在于其融合AR(增强现实)、VR(虚拟现实)等前沿技术&…

十三、网络编程正则表达式设计模式(模块23)

网络编程&正则表达式&设计模式 模块23_网络编程&正则表达式&设计模式第一章.网络编程1.软件结构2.服务器概念3.通信三要素4.UDP协议编程4.1.客户端(发送端)4.2.服务端(接收端) 5.TCP协议编程4.1.编写客户端4.2.编写服务端 6.文件上传6.1.文件上传客户端以及服务…

【开发踩坑】 MySQL不支持特殊字符(表情)插入问题

背景 线上功能报错: Cause:java.sql.SQLException:Incorrect string value:xFO\x9F\x9FxBO for column commentat row 1 uncategorized SQLException; SQL state [HY000]:error code [1366]排查 初步觉得是编码问题(utf8 — utf8mb4) 参考上…

Leetcode 2520. 统计能整除数字的位数

问题描述: 给你一个整数 num ,返回 num 中能整除 num 的数位的数目。 如果满足 nums % val 0 ,则认为整数 val 可以整除 nums 。 示例 1: 输入:num 7 输出:1 解释:7 被自己整除&#xff0…

浅谈芯片验证中的仿真运行之 timescale (五)提防陷阱

一 仿真单位 timeunit 我们知道,当我们的代码中写清楚延时语句时,若不指定时间单位,则使用此单位; 例如: `timescale 1ns/1ps 则 #15 语句表示delay15ns; 例:如下代码,module a 的timescale是1ns/1ps, module b 是1ps/1ps; module b中的clk,频率是由输入参…

HTML开发笔记:1.环境、标签和属性、CSS语法

一、环境与新建 在VSCODE里&#xff0c;加载插件&#xff1a;“open in browser” 然后新建一个文件夹&#xff0c;再在VSCODE中打开该文件夹&#xff0c;在右上角图标新建文档&#xff0c;一定要是加.html&#xff0c;不要忘了文件后缀 复制任意一个代码比如&#xff1a; <…

primeflex教学笔记20240720, FastAPI+Vue3+PrimeVue前后端分离开发

练习 先实现基本的页面结构&#xff1a; 代码如下&#xff1a; <template><div class"flex p-3 bg-gray-100 gap-3"><div class"w-20rem h-12rem bg-indigo-200 flex justify-content-center align-items-center text-white text-5xl">…

C++关系运算符重载 函数运算符重载

#include <iostream> #include <string> using namespace std; //实现关系运算符重载 仿函数 class Person { public:Person(int Age,string Name){age Age;name Name;}int age;string name;int operator()(){return age ;}bool operator(const Person& p)…

用html做python教程01

用html做python教程01 前言开肝构思实操额外修饰更换字体自适应 最后 前言 今天打开csdn的时候&#xff0c;看见csdn给我推荐了一个python技能书。 说实话&#xff0c;做得真不错。再看看我自己&#xff0c;有亿点差距&#x1f61f;。 开肝 先创建一个文件&#xff0c;后缀…

获取本地时间(Linux下,C语言)

一、函数 #include <time.h> time_t time(time_t *tloc);函数功能&#xff1a;获取本机时间&#xff08;以秒数存储&#xff0c;从1970年1月1日0:0:0开始到现在&#xff09;。返回值&#xff1a;获得的秒数&#xff0c;如果形参非空&#xff0c;返回值也可以通过传址调用…