【神经网络】基于CNN(卷积神经网络)构建猫狗分类模型

文章目录

    • 解决问题
    • 数据集
    • 探索性数据分析
    • 数据预处理
      • 数据集分割
      • 数据预处理
    • 构建模型并训练
      • 构建模型
      • 训练模型
    • 结果分析与评估
    • 模型保存
    • 结果预测
    • 经验总结

解决问题

针对经典猫狗数据集,基于卷积神经网络,构建猫狗二元分类模型,使用数据集进行参数训练,模型评估,然后使用模型进行分类预测,最后对模型进行保存,供后续使用。

数据集

数据集来源

猫狗数据集

探索性数据分析

查看待训练识别图片

from matplotlib import pyplot as plt
import os
import random# 获取文件名
_,_,cat_images = next(os.walk('../../dataset/kagglecatsanddogs_5340/PetImages/Cat'))# 准备3*3 图表
fig, ax = plt.subplots(3, 3, figsize=(20, 10))
# 随机选择一幅图像并绘制
for idx, img in enumerate(random.sample(cat_images, 9)):img_read = plt.imread('../../dataset/kagglecatsanddogs_5340/PetImages/Cat/' + img)ax[int(idx / 3), idx % 3].imshow(img_read)ax[int(idx / 3), idx % 3].set_title('cat/' + img)ax[int(idx / 3), idx % 3].axis('off')
plt.show()

查看狗图片类似,将Cat目录换成Dog即可

image-20240614224011275

数据预处理

数据集分割

由于下载的图片猫和狗各在一个文件夹内,如下:

image-20240613223710122

需要将数据按80%:20%进行分割,分为训练集和测试集。目录结构如下:

image-20240613223517479

下面进行数据拆分,核心代码(以猫图片为例)如下:

# 训练数据集80% 测试数据集20%
train_size = 0.8
# 获取猫图像数量
_, _, cat_images = next(os.walk(src_folder+'Cat/'))
num_cat_images = len(cat_images)
num_cat_images_train = int(train_size * num_cat_images)
num_cat_images_test = num_cat_images - num_cat_images_train
# 分割猫图像
cat_train_images = random.sample(cat_images, num_cat_images_train)
for img in cat_train_images:shutil.copy(src=src_folder+'Cat/'+img, dst=src_folder+'Train/Cat/')
cat_test_images  = [img for img in cat_images if img not in cat_train_images]
for img in cat_test_images:shutil.copy(src=src_folder+'Cat/'+img, dst=src_folder+'Test/Cat/')

数据预处理

这一步要将分割后的数据集转成和模型结构匹配的数据类型。使用keras提供的ImageDataGenerator类和flow_from_directory()方法

ImageDataGenerator类:图像增强类,可以进行图像旋转、图像平移、水平翻转、图像缩放等操作;

flow_from_directory()方法:ImageDataGenerator类的方法,支持以图像路径为输入,按批次加载图像到内存,防止训练数据量过大,机器内存不足问题;还支持对图像进行预处理操作,例如尺寸缩放和图像增强

# 训练数据预处理
training_data_generator = ImageDataGenerator(rescale=1./255)
training_set = training_data_generator.flow_from_directory('../../dataset/kagglecatsanddogs_5340/PetImages/train/',target_size=(32, 32),batch_size=16,class_mode='binary')# 测试数据预处理
testing_data_generator = ImageDataGenerator(rescale= 1./255)
testing_set = testing_data_generator.flow_from_directory('../../dataset/kagglecatsanddogs_5340/PetImages/test/',target_size=(32, 32),batch_size=16, class_mode='binary')

构建模型并训练

构建模型

# 定义超参数
# 特征滤波器尺寸
FILTER_SIZE = 3
# 特征滤波器数量
FILTER_NUM = 32
# 图片输入尺寸
INPUT_SIZE = 32
# 最大池化尺寸
MAXPOOL_SIZE = 2
# 批量处理图片的大小
BATCH_SIZE = 16
STEPS_PER_EPOCH = 20000 // BATCH_SIZE
# 训练轮次
EPOCHS = 10
# 定义模型
model = Sequential()
# 添加卷积、池化层 提取特征
model.add(Conv2D(FILTER_NUM, (FILTER_SIZE, FILTER_SIZE), input_shape=(INPUT_SIZE, INPUT_SIZE, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(MAXPOOL_SIZE,MAXPOOL_SIZE)))
# 再添加卷积、池化层 提取特征
model.add(Conv2D(FILTER_NUM, (FILTER_SIZE, FILTER_SIZE), input_shape=(INPUT_SIZE, INPUT_SIZE, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(MAXPOOL_SIZE,MAXPOOL_SIZE)))
# 对输出结果进行降维处理,转成一维张量
model.add(Flatten())
# 添加全链接层,根据特征进行分类预测
model.add(Dense(units=128, activation='relu'))
# 添加dropout层,随机将一部分输入设置为0,防止模型复杂,出现过拟合现象
model.add(Dropout(0.5))
# 添加输出层,一个节点
model.add(Dense(units=1, activation='sigmoid'))

该模型结构分为,卷积池化层,卷积池化层,Flatten层,全链接层1,全链接层2(输出层)如下:

image-20240614221747896

其中,第一列是神经网络的层,第二列是每层的输出形状,第三层是每层训练的参数

可以看到,该模型图像输入尺寸是(32,32),经过一层卷积(32个特征过滤器)输出为(30,30,32),经过一层最大池化层,输出为(15,15,32);其中特征滤波器尺寸为3*3,所以滤波后的尺寸会是32-(3-1)=30,经过最大池化(2x2)尺寸减半,为15。

训练模型

# 模型训练
model.fit(training_set, steps_per_epoch=STEPS_PER_EPOCH, epochs=EPOCHS, verbose=1)
image-20240614123217943

结果分析与评估

model.evaluate(testing_set,steps=len(testing_set),verbose=1)
image-20240614222619817

准确度达到了0.7856

模型保存

from joblib import dump, load# 模型持久化 到磁盘
dump(model, './猫狗分类.onnx')

结果预测

引入保存模型,随机选取一张图片进行预测分类

from matplotlib import pyplot as plt
fig, ax = plt.subplots()
img = plt.imread('../../dataset/kagglecatsanddogs_5340/PetImages/Dog/6.jpg')
ax.imshow(img)
plt.show()
image-20240614223543720
from joblib import dump, load
model = load('./猫狗分类.onnx')from tensorflow.keras.preprocessing.image import img_to_array,load_imgimg = load_img('../../dataset/kagglecatsanddogs_5340/PetImages/Dog/6.jpg',target_size=(32,32))
img = img_to_array(img)
img /= 255
import numpy as np
img_array = np.expand_dims(img, axis=0)
print(img_array.shape)
model.predict(img_array)

在这里插入图片描述

由于是二元分类,0和1分别表示猫狗,输出概率接近表示是狗,接近0表示是猫狗。但具体为啥0表示猫1表示狗而不是反过来表示,还待研究。

经验总结

1 在使用next()加载图像时,要确保路径正确,否则会报StopIteration错误,原因是路径错误,找不到可迭代的数据。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/29469.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

怎么提取视频中的音频?别错过这6个音频提取方法了!(全新)

您是否曾经发现过一个音乐很棒的视频,并想从视频中提取音频?如今,关于提取mp4视频中的音频需求越来越常见。例如,您可能想从mp4格式的电影中提取音频,将音乐用作手机铃声,或在自己的视频项目中使用视频中的…

[Qt] Qt Creator 编码警告:warning:C4819

Qt项目使用VC(2019 64bit)编译器出现此错误。 warning:C4819:该文件包含不能在当前代码页(936)中表示的字符。请将该文件保存为Unicode格式以防止数据丢失。(可能这个警告内容也会在Qt Creator 中乱码) 如…

Matlab只选取自己需要的数据画图

在Matlab作图的时候,经常会在同一个坐标系中作很多数据的图,如下图所示: 这就会导致不同数据所作的线会重叠在一起,不利于数据分析。如果只想对比几个数据的趋势,直接修改代码太过麻烦,可通过Matlab的绘图…

【C语言】数组参数和指针参数详解

在写代码的时候难免要把【数组】或者【指针】传给函数&#xff0c;那函数的参数该如何设计呢&#xff1f; 1 一维数组传参 #include <stdio.h> void test(int arr[])//ok? {} void test(int arr[10])//ok? {} void test(int* arr)//ok? {} void test2(int* arr[20])…

Java毕业设计 基于SSM助学贷款管理系统

Java毕业设计 基于SSM助学贷款管理系统 SSM 助学贷款管理系统 功能介绍 学生&#xff1a;登录 修改密码 学生信息 贷款项目信息 申请贷款 留言信息 公告 学校负责人&#xff1a;登录 修改密码 学生管理 学校负责人信息 贷款项目 贷款申请审批 留言信息 公告 银行负责人&…

Linux中nginx.conf如何配置【搬代码】

Nginx 是一个独立的软件。 它是一款高性能的 Web 服务器、反向代理服务器和负载均衡器等&#xff0c;具有强大的功能和广泛的应用场景。它通常需要单独进行安装和配置来发挥其作用。 下载网址&#xff1a;http://nginx.org/en/download.html nginx.conf写法&#xff1a; #配置…

鸿蒙实现金刚区效果

前言&#xff1a; DevEco Studio版本&#xff1a;4.0.0.600 所谓“金刚区"是位于APP功能入口的导航区域&#xff0c;通常以“图标文字”的宫格导航的形式出现。之所以叫“金刚区”&#xff0c;是因为该区域会随着业务目标的改变&#xff0c;展示不同的功能图标&#xff…

C++ 70 之 类模版中的成员函数,在类外实现

#include <iostream> #include <string> using namespace std;template<class T1, class T2> class Students10{ public:T1 m_name;T2 m_age;Students10(T1 name, T2 age); // 类内声明 类外实现// {// this->m_name name;// this->m_age …

CCAA质量管理【学习笔记】​​ 备考知识点笔记(六)质量改进系统方法与工具

第七节 质量改进系统方法与工具 1 质 量 改 进 方 法 概 述 可以说几乎每种质量管理领域的方法与工具都可以用于质量改进&#xff0c;但是一个组织在改进的整体推进中&#xff0c;往往不是采用单一的方法&#xff0c;会涉及多种改进的工具和手段&#xff0c;并依据一定的模式…

鸿蒙实现自定义Tabbar样式,显示数字红点提示

前言&#xff1a; DevEco Studio版本&#xff1a;4.0.0.600 Tabs的链接参考&#xff1a;OpenHarmony Tabs TabContent的链接参考&#xff1a;OpenHarmony TabContent 通过查看链接参考我们知道可以通过TabContent的tabBar来实现自定义TabBar样式&#xff08;CustomBuilder&…

CloudTopExam考试系统

前言 整个项目的都是自己从0到1完成的&#xff08;包括数据库设计&#xff09;。 这个项目耗费了自己的很多心血&#xff0c;尤其是数据库的设计&#xff08;中途推翻重做了好几次&#x1f494;&#xff09;。在做这个之前也看过很多类似的开源项目&#xff0c;相较于商用的产品…

第六节 未登录与登录分支设立

经常我们在设计中,经常会遇到多条件分支打开相关界面,下面重点基于一个控件判断对未登录与已登录分支跳转案例进行说明。 一、设置元件 注意:动态面板默认设置 二、设置隐藏面板 三、关联条件情形 1、设置触发事件的元件 2、启用情形 3、添加情形,增加面板中“未登录”为…

文件操作(2)(C语言版)

文件的随机读写&#xff1a; fseek函数&#xff1a; 前面讲解了顺序读写的相关函数&#xff0c;这里介绍一些可以“指哪写哪的函数” 有三个参数&#xff1a; 1、文件的地址 2、相对于第三个参数origin偏移的位置 3、起始位置&#xff08;有三种&#xff09; 第一种&#xff…

参数搜索流形学习

目录 一、网格搜索1、介绍2、代码示例 二、HalvingGridSearch1、介绍2、代码示例 三、随机搜索1、介绍2、代码示例 三、贝叶斯搜索1、介绍2、代码示例 四、参数搜索总结五、流形学习1、LLE1、介绍2、官方代码示例 2、t-SNE1、介绍2、官方代码示例 一、网格搜索 1、介绍 网格搜…

Matlab复数相关

文章目录 MATLAB复数相关知识相关函数 MATLAB复数相关知识 相关函数 假定存在复数zabi 函数说明real(z)返回复数z的实部&#xff08;a&#xff09;imag(z)返回复数z的虚部&#xff08;b&#xff09;abs(z)返回复数的模即|z| &#xff08; ( a 2 ) ( b 2 ) \sqrt{(a^2)(b^2)…

重生奇迹MU召唤术师简介

出生地&#xff1a;幻术园 性 别&#xff1a;女 擅 长&#xff1a;召唤幻兽、辅助魔法&攻击魔法 转 职&#xff1a;召唤巫师&#xff08;3转&#xff09; 介 绍&#xff1a;从古代开始流传下来的高贵的血缘&#xff0c;为了种族纯正血缘的延续及特殊使用咒术的天赋&…

【前端】 nvm安装管理多版本node、 npm install失败解决方式

【问题】If you believe this might be a permissions issue, please double-check the npm ERR! permissio或者Error: EPERM: operation not permitted, VScode中npm install或cnpm install报错 简单总结&#xff0c;我们运行npm install 无法安装吧包&#xff0c;提示权限问题…

友思特应用 | 模型链接一应俱全:IC多类别视觉检测一站式解决方案

导读 高精度IC制造工艺需要对产品进行全方位检测以保证工艺质量过关。友思特 Neuro-T 通过调用平台的流程图功能&#xff0c;搭建多类深度学习模型&#xff0c;形成了一站式的视觉检测解决方案。本文将为您详述方案搭建过程与实际应用效果。 在当今集成电路&#xff08;IC&…

SuiNS更新命名标准,增强用户体验

SuiNS将其面向用户的命名标准从 xxx.sui 更新为 xxx&#xff0c;让用户能够以一种适用于Web2和Web3世界的方式来代表自己。通过此更新&#xff0c;用户可以在其选择的名称前使用 &#xff0c;而不是在名称后添加 .sui。 Sui命名服务于去年推出&#xff0c;旨在使Sui上的地址更…

TypeScript写好了,怎么运行啊!!!

环境搭建 Vs code Ctrlshiftp打开首选项—》打开工作区设置—》搜索Typescript 推荐开启的配置项主要是这几个&#xff1a; Function Like Return Types&#xff0c;显示推导得到的函数返回值类型&#xff1b;Parameter Names&#xff0c;显示函数入参的名称&#xff1b;Par…