基于Keras3.x使用CNN实现简单的猫狗分类

使用CNN实现简单的猫狗分类
完整代码见:基于Keras3.x使用CNN实现简单的猫狗分类,置信度约为:85%

文章目录

  • 概述
    • 项目整体目录
    • 环境版本
    • 注意
  • 环境准备
    • 下载miniconda
    • 新建虚拟环境
    • 基于conda虚拟环境新建Pycharm项目
    • 下载分类需要用到的依赖
  • 数据准备
    • 数据目录结构
    • 挪动图片可以采用下列代码
  • config
  • 准备训练、测试数据集
  • 构建模型
  • 训练模型
    • 训练过程
    • 损失和准确度曲线
  • 测试模型
    • 使用带标签的测试图片评估整体准确率
    • 定义模型类
    • 预测单张图片
    • 将不带标签的测试图片分类并存到对应目录中
    • 源码
  • 错误记录
    • 双重归一化问题

概述

项目整体目录

在这里插入图片描述

  • /data 存放数据集
  • /model 存放训练好的模型
  • /config.py 存储一些关键模型参数和路径信息等
  • /dataset.py 返回数据增强后的数据集,用于模型训练
  • /model.py 定义模型
  • /train.py 训练模型并绘制训练损失和准确度曲线
  • /test.py 测试模型精准度

环境版本

  • python 3.11
  • keras 3.9.2
  • tensorflow 2.19.0

注意

本项目使用Keras3.x实现,代码与keras 2.x有部分不同,请仔细甄别

环境准备

下载miniconda

如果没有下载conda,参照上一篇文章进行下载配置

新建虚拟环境

创建一个用于猫狗识别的虚拟环境,可以指定py版本

conda create --name catanddog python=3.11

基于conda虚拟环境新建Pycharm项目

依次选择菜单路径:
File-NewProject-Pure Python
在弹出的窗口选择:

  • custom environment
  • Select existing
  • Type:conda
  • 选择刚刚新建的catanddog虚拟环境
    如图:
    在这里插入图片描述

下载分类需要用到的依赖

主要用到:keras3.9.2和tensorflow2.19.0

pip install keras

数据准备

从Kaggle上下载常用的猫狗分类数据集,下载下来后,有训练数据(共25000张,猫狗各一半,带标签,命名示例:dog.0.jpg)和测试数据(共12500张,不带标签,命名示例:1.jgp)。
将训练数据分为两份,前20000张用于训练数据,后5000张带标签的数据用于预测模型整体准确度。12500张测试数据可以用于单张图片的模型预测,以及将猫狗分类后放入对应的目录中,方便查看。
如图组织数据:

  • test为12500张不带标签的测试数据;
  • test2为5000张带标签的测试数据;
  • train为20000张训练数据

数据目录结构

注意:必须把训练数据、test2图片放在新建好的cats和dogs目录下,模型才能自动推断标签
在这里插入图片描述

挪动图片可以采用下列代码

import os, shutil
# 将train_dir_tag_cat后2500张猫图像移动到test2_dir_tag_cat
cats = ['cat.{}.jpg'.format(i) for i in range(1000)]
for cat in cats:src = os.path.join(train_dir_tag_cat, cat)dst = os.path.join(test2_dir_tag_cat, cat)shutil.move(src, dst)

config

config.py:用于存储一些关键参数和路径信息等。
训练的batch为:32
训练15个EPOCH

"""
@Author      :Ayaki Shi
@Date        :2025/4/18 11:03 
@Description : 配置信息
"""
import os, shutildata_dir = './data'# 训练集、测试集所在路径
test_dir = os.path.join(data_dir, 'test')
test2_dir = os.path.join(data_dir, 'test2')
train_dir = os.path.join(data_dir, 'train')# 划分标签后的数据路径
train_dir_tag_cat = os.path.join(train_dir, 'cats')
test_dir_tag_cat = os.path.join(test_dir, 'cats')
test2_dir_tag_cat = os.path.join(test2_dir, 'cats')train_dir_tag_dog = os.path.join(train_dir, 'dogs')
test_dir_tag_dog = os.path.join(test_dir, 'dogs')
test2_dir_tag_dog = os.path.join(test2_dir, 'dogs')# 训练参数
IMG_SIZE = (256, 256)
BATCH_SIZE = 32
EPOCHS = 15# 模型路径
MODEL_PATH = './model/CatAndDogClassifier.keras'

准备训练、测试数据集

dataset.py
训练数据经过数据增强后返回,用于测试模型整体准确度的test2无需数据增强直接返回。
注意: 这个方法ImageDataGenerator已经不推荐使用了,因此使用image_dataset_from_directory这个方法,可以根据目录自动推断标签,只是数据增强稍微复杂了点

"""
@Author      :Ayaki Shi
@Date        :2025/4/18 11:02
@Description : 返回dataset
"""from keras.api.utils import image_dataset_from_directory
from config import train_dir,test2_dir, BATCH_SIZE,IMG_SIZE
from keras import layers, models
import tensorflow as tf# 数据增强
def create_augmentation_model():return models.Sequential([layers.RandomFlip("horizontal", seed=42),layers.RandomRotation(0.2, fill_mode='nearest', seed=42),layers.RandomZoom(0.2, fill_mode='nearest', seed=42),layers.RandomContrast(0.3, seed=42),layers.RandomTranslation(0.1, 0.1, fill_mode='nearest', seed=42),], name="data_augmentation")def create_train_dataset():train_dataset = image_dataset_from_directory(train_dir,label_mode = 'binary',batch_size = BATCH_SIZE,image_size = IMG_SIZE,shuffle=True,  # 必须启用 shuffleseed=42)# 创建预处理模型augmentation_model = create_augmentation_model()# 定义预处理函数def preprocess_train(image, label):image = augmentation_model(image, training=True)  # 训练模式激活增强return image, labeltrain_dataset = train_dataset.map(preprocess_train,num_parallel_calls= tf.data.AUTOTUNE)print('--------------返回增强后的训练数据集--------------')return train_dataset.prefetch(buffer_size=tf.data.AUTOTUNE)def create_test2_dataset():test2_dataset = image_dataset_from_directory(test2_dir,label_mode = 'binary',batch_size = BATCH_SIZE,image_size = IMG_SIZE,shuffle=False)print('--------------返回测试数据集[带标签]--------------')return test2_dataset

构建模型

model.py 模型结构:

  • 输入层:指定输入数据形状
  • 数据归一化
  • 四层卷积层和四层池化层交替
  • 展平层:将输出的多维特征图展平为一维向量
  • Dropout防止过拟合
  • 两个全连接层,用于特征提取和最终分类
"""
@Author      :Ayaki Shi
@Date        :2025/4/18 11:02 
@Description : 创建模型
"""
from keras import layers, models, optimizersfrom config import IMG_SIZEdef create_model():model = models.Sequential([# 输入层:指定输入数据形状layers.Input(shape=(*IMG_SIZE, 3)),layers.Rescaling(1./255),  # 归一化到 [0,1]# 四层卷积层和四层池化层layers.Conv2D(32, (3, 3), activation='relu'),layers.MaxPooling2D(2, 2),layers.Conv2D(64, (3, 3), activation='relu'),layers.MaxPooling2D(2, 2),layers.Conv2D(128, (3, 3), activation='relu'),layers.MaxPooling2D(2, 2),layers.Conv2D(128, (3, 3), activation='relu'),layers.MaxPooling2D(2, 2),# 展平层:将输出的多维特征图展平为一维向量layers.Flatten(),# 防止过拟合layers.Dropout(0.5),# 两个全连接层,用于特征提取和最终分类layers.Dense(512, activation='relu'),layers.Dense(1, activation='sigmoid')])# 编译模型model.compile(loss='binary_crossentropy',  # 损失函数optimizer= optimizers.Adam(learning_rate=1e-4), # 优化器metrics=['accuracy']) # 评估标准:准确率print('--------------构建模型成功--------------')return model

训练模型

train.py
获取数据集-创建模型-训练模型-保存模型-绘制损失和准确度曲线

"""
@Author      :Ayaki Shi
@Date        :2025/4/18 16:08 
@Description : 训练模型
"""from dataset import create_train_dataset
from model import create_model
from config import EPOCHS, BATCH_SIZE, MODEL_PATH
import matplotlib.pyplot as pltdef train_model():# 获取datasettrain_dataset = create_train_dataset()# 生成模型model = create_model()# 训练模型print('--------------开始训练模型--------------')history = model.fit(train_dataset,epochs = EPOCHS,batch_size = BATCH_SIZE)# 保存模型print('--------------开始保存模型--------------')model.save(MODEL_PATH)print('--------------开始绘制损失和准确性曲线--------------')# 绘制训练损失曲线plt.figure(figsize=(10, 4))plt.plot(history.history['loss'], label='Training Loss', color='blue', marker='o')plt.title('Training Loss Over Epochs')plt.xlabel('Epochs')plt.ylabel('Loss')plt.legend()plt.grid(True)plt.show()# 绘制训练准确率曲线plt.figure(figsize=(10, 4))plt.plot(history.history['accuracy'], label='Training Accuracy', color='green', marker='s')plt.title('Training Accuracy Over Epochs')plt.xlabel('Epochs')plt.ylabel('Accuracy')plt.legend()plt.grid(True)plt.show()if __name__ == '__main__':train_model()

训练过程

在这里插入图片描述

损失和准确度曲线

在这里插入图片描述
在这里插入图片描述

测试模型

test.py

使用带标签的测试图片评估整体准确率

代码见后面,可以看到整体准确度为85%左右
在这里插入图片描述

定义模型类

代码见后面

预测单张图片

代码见后面,预测为狗的概率是73%。
在这里插入图片描述

将不带标签的测试图片分类并存到对应目录中

代码见后面,可以看到大部分测试图片都被放到了正确的目录里,但是也有少数错的

在这里插入图片描述
在这里插入图片描述

源码

from keras import models
import numpy as np
import os, shutil
from keras_preprocessing import image
from config import MODEL_PATH,IMG_SIZE,test_dir,test_dir_tag_cat,test_dir_tag_dogDOG_TAG_STR = 'dog'
CAT_TAG_STR = 'cat'
NUM_IMAGES = 12500             # 测试图片数class CatAndDogClassifier:def __init__(self):self.model = models.load_model(MODEL_PATH)print("模型加载成功!")def predict_single_image(self, img_path):img = image.load_img(img_path, target_size=IMG_SIZE)img_array = image.img_to_array(img)# 错误代码:双重归一化# img_array = np.expand_dims(img_array, axis=0) / 255.0img_array = np.expand_dims(img_array, axis=0)prediction = self.model.predict(img_array)[0][0]print(prediction)return DOG_TAG_STR if prediction > 0.5 else CAT_TAG_STR, predictiondef classify_all_images(self):# 遍历所有图片# for i in range(1, NUM_IMAGES + 1):filename = ''for i in range(1, NUM_IMAGES + 1):try:#(文件名为1.jpg到12500.jpg)filename = f"{i}.jpg"src_path = os.path.join(test_dir, filename)# 跳过不存在的文件if not os.path.exists(src_path):print(f"Warning: {filename} 不存在,已跳过")continue# 进行预测label, confidence = self.predict_single_image(src_path)# 确定目标目录dest_dir = test_dir_tag_dog if label == DOG_TAG_STR else test_dir_tag_catdest_path = os.path.join(dest_dir, filename)# 移动文件shutil.move(src_path, dest_path)if i%500 == 0: # 打印12500行太多了,每500行打印一次print(f"[{i}/12500] {filename} -> {dest_dir} (置信度: {confidence:.2%})")except Exception as e:print(f"处理 {filename} 时发生错误: {str(e)}")continuedef evaluate_model():from dataset import create_test2_datasettest2_dataset = create_test2_dataset()model = models.load_model(MODEL_PATH)loss, acc = model.evaluate(test2_dataset)print(f'\nTest accuracy: {acc:.2%}')if __name__ == '__main__':# 初始化分类器classifier = CatAndDogClassifier()# # 评估整体准确率# evaluate_model()# # 单张图片预测# img_path = os.path.join('./data/train/dogs/dog.100.jpg')# label, prob = classifier.predict_single_image(img_path)# print(f'预测为: {label} (置信度: {prob if label == DOG_TAG_STR else 1 - prob:.2%})')# 将不带标签的测试图片分类放入不同的文件夹classifier.classify_all_images()

错误记录

双重归一化问题

在预测单张图片过程中,出现了不管什么图片,预测度总是特别低,只有7%左右
在这里插入图片描述
首先预测结果不对第一时间考虑到是不是模型欠拟合或者过拟合的问题。
但是基于以下两个原因:

  • 首先训练过程中记录的准确度和测试整体准确率都是85%,说明模型大概率是没有问题的
  • 其次这个置信度已经低的离谱了
    所以考虑是在测试单张图片对图片处理出现了问题,经过排查,发现问题出在了,我在对单张图片进行了归一化,然后模型中又进行了一次归一化,导致预测置信度极低。
    test.py
    在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/902762.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

中介者模式:解耦对象间复杂交互的设计模式

中介者模式:解耦对象间复杂交互的设计模式 一、模式核心:用中介者统一管理对象交互,避免两两直接依赖 当系统中多个对象之间存在复杂的网状交互时(如 GUI 界面中按钮、文本框、下拉框的联动),对象间直接调…

豆包桌面版 1.47.4 可做浏览器,免安装绿色版

自己动手升级更新办法: 下载新版本后安装,把 C:\Users\用户名\AppData\Local\Doubao\Application 文件夹的文件,拷贝替换 DoubaoPortable\App\Doubao 文件夹的文件,就升级成功了。 再把安装的豆包彻底卸载就可以。 桌面版比网页版…

Android PackageManagerService(PMS)框架深度解析

目录 一、概念与核心作用 二、技术架构与模块组成 1. 分层架构 1.1 应用层架构细节 1.2 Binder接口层实现 1.3 PMS核心服务层 1.4 底层支持层实现 2. 核心模块技术要点与工作流程 2.1 PackageParser 2.2 Settings 2.3 PermissionManager 2.4 Installer 2.5 ComponentM…

TensorFlow深度学习实战(14)——循环神经网络详解

TensorFlow深度学习实战(14)——循环神经网络详解 0. 前言1. 基本循环神经网络单元1.1 循环神经网络工作原理1.2 时间反向传播1.3 梯度消失和梯度爆炸问题2. RNN 单元变体2.1 长短期记忆2.2 门控循环单元2.3 Peephole LSTM3. RNN 变体3.1 双向 RNN3.2 状态 RNN4. RNN 拓扑结构…

PySide6 GUI 学习笔记——常用类及控件使用方法(常用类矩阵QRectF)

文章目录 类描述构造方法主要方法1. 基础属性2. 边界操作3. 几何运算4. 坐标调整5. 转换方法6. 状态判断 类特点总结1. 浮点精度:2. 坐标系统:3. 有效性判断:4. 几何运算:5. 类型转换:6. 特殊处理: 典型应用…

Electron主进程渲染进程间通信的方式

在 Electron 中,主进程和渲染进程之间的通信主要通过 IPC(进程间通信)机制实现。以下是几种常见的通信方式: 1. 渲染进程向主进程发送消息(单向) 渲染进程可以通过 ipcRenderer.send 向主进程发送消息&am…

【C++基础知识】C++类型特征组合:`disjunction_v` 和 `conjunction_v` 深度解析

这两个模板是C17引入的类型特征组合工具,用于构建更复杂的类型判断逻辑。下面我将从技术实现到实际应用进行全面剖析: 一、基本概念与C引入版本 1. std::disjunction_v (逻辑OR) 引入版本:C17功能:对多个类型特征进行逻辑或运算…

私有知识库 Coco AI 实战(二):摄入 MongoDB 数据

在之前的文章中,我们介绍过如何使用《 Logstash 迁移 MongoDB 数据到 Easyseach》,既然 Coco AI 后台数据存储也使用 Easysearch,我们能否直接把 MongoDB 的数据迁移到 Coco AI 的 Easysearch,使用 Coco AI 对数据进行检索呢&…

sql server 与navicat测试后,连接qt

先用Navicat测试和sql的连通性,Navicat和sql连通之后,qt也能和sql连通了。 Navicat和Sqlserver Management 能连上,项目无法连接本地 Navicat 连接SQLServer 数据库 QT国内镜像网站 Navicat连接SqlServer的问题点 Sql Server的基本配置以及使…

2025年3月电子学会青少年机器人技术(六级)等级考试试卷-理论综合

青少年机器人技术等级考试理论综合试卷(六级) 分数:100 题数:30 一、单选题(共20题,共80分) 1. 2025年初,中国科技初创公司深度求索在大模型领域迅速崛起,其开源的大模型成为全球AI领域的焦…

spark local模式搭建运行示例

Apache Spark 是一个强大的分布式计算框架,但在本地模式下,它也可以作为一个单机程序运行,非常适合开发和测试阶段。以下是一个简单的示例,展示如何在本地模式下搭建和运行 Spark 程序。 一、环境准备 安装 Java Spark 需要 Java…

【人工智能】解锁 AI 潜能:DeepSeek 大模型迁移学习与特定领域微调的实践

《Python OpenCV从菜鸟到高手》带你进入图像处理与计算机视觉的大门! 解锁Python编程的无限可能:《奇妙的Python》带你漫游代码世界 随着大型语言模型(LLMs)的快速发展,迁移学习与特定领域微调成为提升模型性能的关键技术。本文深入探讨了 DeepSeek 大模型在迁移学习中的…

视频智能分析平台EasyCVR无线监控:全流程安装指南与功能应用解析

在当今数字化安防时代,无线监控系统的安装与调试对于保障各类场所的安全至关重要。本文将结合EasyCVR视频监控的强大功能,为您详细阐述监控系统安装过程中的关键步骤和注意事项,帮助您打造一个高效、可靠的监控解决方案。 一、调试物资准备与…

【k8s系列7-更新中】kubeadm搭建Kubernetes高可用集群-三主两从

主机准备 结合前面的章节,这里需要5台机器,可以先创建一台虚拟机作为基础虚拟机。优先把5台机器的公共部分优先在一台机器上配置好 1、配置好静态IP地址 2、主机名宇IP地址解析 [root@localhost ~]# cat /etc/hosts 127.0.0.1 localhost localhost.localdomain localhost…

【Java后端】MyBatis 与 MyBatis-Plus 如何防止 SQL 注入?从原理到实战

在日常开发中,SQL 注入是一种常见但危害巨大的安全漏洞。如果你正在使用 MyBatis 或 MyBatis-Plus 进行数据库操作,这篇文章将带你系统了解:这两个框架是如何防止 SQL 注入的,我们又该如何写出安全的代码。 什么是 SQL 注入&#…

数据分析案例:医疗健康数据分析

目录 数据分析案例:医疗健康数据分析1. 项目背景2. 数据加载与预处理2.1 加载数据2.2 数据清洗3. 探索性数据分析(EDA)3.1 再入院率概览3.2 按年龄分组的再入院率3.3 住院时长与再入院4. 特征工程与可视化5. 模型构建与评估5.1 数据划分5.2 训练逻辑回归5.3 模型评估6. 业务…

3台CentOS虚拟机部署 StarRocks 1 FE+ 3 BE集群

背景:公司最近业务数据量上去了,需要做一个漏斗分析功能,实时性要求较高,mysql已经已经不在适用,做了个大数据技术栈选型调研后,决定使用StarRocks StarRocks官网:StarRocks | A High-Performa…

软件设计师/系统架构师---计算机网络

概要 什么是计算机网络? 计算机网络是指将多台计算机和其他设备通过通信线路互联,以便共享资源和信息的系统。计算机网络可以有不同的规模,从家庭网络到全球互联网。它们可以通过有线(如以太网)或无线(如W…

1.5软考系统架构设计师:架构师的角色与能力要求 - 超简记忆要点、知识体系全解、考点深度解析、真题训练附答案及解析

超简记忆要点 角色职责 需求规划→架构设计→质量保障 能力要求 技术(架构模式/性能优化) 业务(模型抽象→技术方案) 管理(团队协作/风险控制) 知识体系 基础:CAP/设计模式/网络协议案例&am…

基于STM32的汽车主门电动窗开关系统设计方案

芯片和功能模块选型 主控芯片 STM32F103C8T6:基于 ARM Cortex - M3 内核,有丰富的 GPIO 接口用于连接各类外设,具备 ADC 模块可用于电流检测,还有 CAN 控制器方便实现 CAN 总线通信。它资源丰富、成本低,适合学生进行 DIY 项目开发。按键模块 轻触按键:用于控制车窗的自…