Keras实战之图像分类识别

文章目录

    • 整体流程
      • 数据加载与预处理
      • 搭建网络模型
      • 优化网络模型
        • 学习率
        • Drop-out操作
        • 权重初始化方法对比
        • 正则化
        • 加载模型进行测试

实战:利用Keras框架搭建神经网络模型实现基本图像分类识别,使用自己的数据集进行训练测试。

问:为什么选择Keras?
答:使用Keras便捷快速。用起来简单,入门容易,上手快。没有tensorflow那么复杂的规范。

整体流程

  1. 读取数据
  2. 数据预处理
  3. 切分数据集(分为训练集和测试集)
  4. 搭建网络模型(初始化参数)
  5. 训练网络模型
  6. 评估测试模型(通过对比不同参数下损失函数不断优化模型)
  7. 保存模型到本地

(1)手动配置参数,设置数据存储路径、模型保存路径、图片保存路径

# 输入参数,手动设置数据存储路径、模型保存路径、图片保存路径等
ap = argparse.ArgumentParser()
ap.add_argument("-d", "--dataset", required=True,help="path to input dataset of images")
ap.add_argument("-m", "--model", required=True,help="path to output trained model")
ap.add_argument("-l", "--label-bin", required=True,help="path to output label binarizer")
ap.add_argument("-p", "--plot", required=True,help="path to output accuracy/loss plot")
args = vars(ap.parse_args())

在这里插入图片描述

数据加载与预处理

# 拿到图像数据路径,方便后续读取
imagePaths = sorted(list(utils_paths.list_images(args["dataset"])))
random.seed(42)
random.shuffle(imagePaths)
# 数据洗牌前设置随机种子确保后面调参过程中训练数据集一样# 遍历读取数据
for imagePath in imagePaths:# 读取图像数据,由于使用神经网络,需要输入数据给定成一维image = cv2.imread(imagePath)# 而最初获取的图像数据是三维的,则需要将三维数据进行拉长image = cv2.resize(image, (32, 32)).flatten()data.append(image)# 读取标签,通过读取数据存储位置文件夹来判断图片标签label = imagePath.split(os.path.sep)[-2]labels.append(label)# scale图像数据,归一化
data = np.array(data, dtype="float") / 255.0
labels = np.array(labels)# 转换标签,one-hot格式
lb = LabelBinarizer()
trainY = lb.fit_transform(trainY)
testY = lb.transform(testY)

数据预处理:①通过数据除以255进行数据归一化;②对数据标签进行格式转换。

搭建网络模型

  1. 创建序列结构
model = Sequential()
  1. 添加全连接层
  • 第一层全连接层Dense设计512个神经元,当前输入特征个数(输入神经元个数)为3072,设置激活函数为"relu";
  • 第二层设计256个神经元;
  • 第三层设计类别数个神经元(即3个),并作softmax操作得到最终分类类别。
# 第一层
model.add(Dense(512, input_shape=(3072,),activation="relu"))
# 第二层
model.add(Dense(256, activation="relu",))
# 第三层
model.add(Dense(len(lb.classes_), activation="softmax",))
  1. 初始化参数
# 学习率
INIT_LR = 0.01
# 迭代次数
EPOCHS = 200
  1. 训练网络模型
# 给定损失函数和评估方法
opt = SGD(lr=INIT_LR) # 指定优化器为梯度下降的优化器
model.compile(loss="categorical_crossentropy", optimizer=opt,metrics=["accuracy"])# 训练网络模型
H = model.fit(trainX, trainY, validation_data=(testX, testY),epochs=EPOCHS, batch_size=32)
  1. 测试网络模型

使用上面训练所得网络模型对测试集进行预测,并对比预测解国和数据集真实结果打印结果报告(包括准确率、recall、f1-score),并将损失函数以折线图的效果直观展示出来

predictions = model.predict(testX, batch_size=32)
print(classification_report(testY.argmax(axis=1),predictions.argmax(axis=1), target_names=lb.classes_))
  1. 评估结果
    在这里插入图片描述
    在这里插入图片描述
    从损失函数图像中可看出,模型出现明显过拟合现象,故而该初始参数所构建的模型效果较差,需要通过调参优化模型。

优化网络模型

学习率

对比学习率为0.01和0.001的损失函数图像。

在这里插入图片描述

train_loss与val_loss之间差异仍然存在,但是可看出学习率越大,过拟合现象越明显。

Drop-out操作

Dropout操作:在搭建网络模型中,通过设置一0到1范围内的参数从而防止过拟合。
在这里插入图片描述

权重初始化方法对比

(1)RandomNormal随机高斯初始化

kernel_initializer =initializers.random_normal(mean=0.0,stddev=0.05)
model.add(Dense(512, input_shape=(3072,),activation="relu",kernel_initializer =initializers.random_normal(mean=0.0,stddev=0.05)))
model.add(Dense(256, activation="relu",kernel_initializer =initializers.random_normal(mean=0.0,stddev=0.05)))
model.add(Dense(len(lb.classes_), activation="softmax",kernel_initializer =initializers.random_normal(mean=0.0,stddev=0.05)))

在这里插入图片描述
图中可看出,添加RandomNormal初始化后,过拟合现象减弱了一丢丢。

(2)TruncatedNormal截断

kernel_initializer = initializers.TruncatedNormal(mean=0.0, stddev=0.05, seed=None)

相比于正常高斯分布截断了两边,只取小于2倍stddev的值

model.add(Dense(512, input_shape=(3072,), activation="relu" ,kernel_initializer = initializers.TruncatedNormal(mean=0.0, stddev=0.05)))
model.add(Dense(256, activation="relu",kernel_initializer = initializers.TruncatedNormal(mean=0.0, stddev=0.05)))
model.add(Dense(len(lb.classes_), activation="softmax",kernel_initializer = initializers.TruncatedNormal(mean=0.0, stddev=0.05)))

在这里插入图片描述

对比stddev取不同值时的loss函数图可得,TruncatedNormal中stddev值越小,过拟合风险越低,模型效果越好。TruncatedNormal消除过拟合的效果RandomNormal好。

正则化
kernel_regularizer=regularizers.l2(0.01)

正则化后,损失函数loss = 初始loss + aR(W)。正则化惩罚W,让稳定的W减少过拟合。

model.add(Dense(512, input_shape=(3072,), activation="relu" ,kernel_initializer = initializers.TruncatedNormal(mean=0.0, stddev=0.05, seed=None),kernel_regularizer=regularizers.l2(0.01)))
model.add(Dense(256, activation="relu",kernel_initializer = initializers.TruncatedNormal(mean=0.0, stddev=0.05, seed=None),kernel_regularizer=regularizers.l2(0.01)))
model.add(Dense(len(lb.classes_), activation="softmax",kernel_initializer = initializers.TruncatedNormal(mean=0.0, stddev=0.05, seed=None),kernel_regularizer=regularizers.l2(0.01)))

对比正则化前后取迭代150到200的loss波动图,可发现正则化后虽然开始时loss值较大,但后期过拟合现象有明显减弱
在这里插入图片描述
再对比正则化参数l2 = 0.01和0.05的结果可得,l2越大,W的惩罚力度越大,过拟合风险越小
在这里插入图片描述

加载模型进行测试
# 导入所需工具包
from keras.models import load_model
import argparse
import pickle
import cv2# 设置输入参数
ap = argparse.ArgumentParser()
ap.add_argument("-i", "--image", required=True,help="path to input image we are going to classify")
ap.add_argument("-m", "--model", required=True,help="path to trained Keras model")
ap.add_argument("-l", "--label-bin", required=True,help="path to label binarizer")
ap.add_argument("-w", "--width", type=int, default=28,help="target spatial dimension width")
ap.add_argument("-e", "--height", type=int, default=28,help="target spatial dimension height")
ap.add_argument("-f", "--flatten", type=int, default=-1,help="whether or not we should flatten the image")
args = vars(ap.parse_args())# 加载测试数据并进行相同预处理操作
image = cv2.imread(args["image"])
output = image.copy()
image = cv2.resize(image, (args["width"], args["height"]))# scale the pixel values to [0, 1]
image = image.astype("float") / 255.0# 对图像进行拉平操作
image = image.flatten()
image = image.reshape((1, image.shape[0]))# 读取模型和标签
print("[INFO] loading network and label binarizer...")
model = load_model(args["model"])
lb = pickle.loads(open(args["label_bin"], "rb").read())# 预测
preds = model.predict(image)# 得到预测结果以及其对应的标签
i = preds.argmax(axis=1)[0]
label = lb.classes_[i]# 在图像中把结果画出来
text = "{}: {:.2f}%".format(label, preds[0][i] * 100)
cv2.putText(output, text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7,(0, 0, 255), 2)# 绘图
cv2.imshow("Image", output)
cv2.waitKey(0)

分类结果:
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
通过预测结果可得:该模型在预测猫上存在较大误差,在预测熊猫上较为准确。或许改进增加迭代次数可进一步优化模型。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/42088.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

这门HCIE认证正式发布!

华为认证AI解决方案架构专家HCIE-AI Solution Architect V1.0(中文版)自2024年6月28日起,正式在中国区发布。 基于“平台生态”战略,围绕“云-管-端”协同的新ICT技术架构,华为公司打造了覆盖ICT领域的认证体系&#x…

C++ | Leetcode C++题解之第217题存在重复元素

题目&#xff1a; 题解&#xff1a; class Solution { public:bool containsDuplicate(vector<int>& nums) {unordered_set<int> s;for (int x: nums) {if (s.find(x) ! s.end()) {return true;}s.insert(x);}return false;} };

2024年江苏省研究生数学建模科研创新实践大赛C题气象数据高精度融合技术研究论文和代码分析

经过不懈的努力&#xff0c; 2024年江苏省研究生数学建模科研创新实践大赛C题气象数据高精度融合技术研究论文和代码已完成&#xff0c;代码为C题全部问题的代码&#xff0c;论文包括摘要、问题重述、问题分析、模型假设、符号说明、模型的建立和求解&#xff08;问题1模型的建…

服务器BMC基础知识总结

前言 因为对硬件方面不太理解&#xff0c;所以打算先从服务器开始学习&#xff0c;也想和大家一起分享一下&#xff0c;有什么不对的地方可以纠正一下哦&#xff01;谢谢啦&#xff01;互相学习共同成长~ 1.BMC是什么&#xff1f; 官方解释&#xff1a;BMC全名Baseboard Mana…

【深度学习】-WASB-调试说明

要改这么几个地方&#xff1a; 代码仓库&#xff1a;/Desktop/code/python_project/WASB-SBDT-main/ 篮球数据集xx_xx_11.xml只保留最后一个11.xml 并把11下直接放置11 video&#xff1a; 这里的东西被我改了&#xff0c;要以仓库为准

一气之下,关闭成都400多人的游戏公司

关注卢松松&#xff0c;会经常给你分享一些我的经验和观点。 最近&#xff0c;多益网络宣布关闭成都公司&#xff0c;在未来三年内&#xff0c;关闭成都所有的相关公司。原因竟然是输掉了劳动仲裁&#xff0c;赔偿员工38万多&#xff0c;然后一气之下要退出成都&#xff0c;…

【2024最新华为OD-C/D卷试题汇总】[支持在线评测] LYA的生日聚会(100分) - 三语言AC题解(Python/Java/Cpp)

&#x1f36d; 大家好这里是清隆学长 &#xff0c;一枚热爱算法的程序员 ✨ 本系列打算持续跟新华为OD-C/D卷的三语言AC题解 &#x1f4bb; ACM银牌&#x1f948;| 多次AK大厂笔试 &#xff5c; 编程一对一辅导 &#x1f44f; 感谢大家的订阅➕ 和 喜欢&#x1f497; &#x1f…

使用antd的<Form/>组件获取富文本编辑器输入的数据

前端开发中&#xff0c;嵌入富文本编辑器时&#xff0c;可以通过富文本编辑器自身的事件处理函数将数据传输给后端。有时候&#xff0c;场景稍微复杂点&#xff0c;比如一个输入页面除了要保存富文本编辑器的内容到后端&#xff0c;可能还有一些其他输入组件获取到的数据也一并…

Mac搭建anaconda环境并安装深度学习库

1. 下载anaconda安装包 根据自己的操作系统不同&#xff0c;选择不同的安装包Anaconda3-2024.06-1-MacOSX-x86_64.pkg&#xff0c;我用的还是旧的intel所以下载这个&#xff0c;https://mirrors.tuna.tsinghua.edu.cn/anaconda/archive/&#xff0c;如果mac用的是M1&#xff0…

GuLi商城-商品服务-API-品牌管理-云存储开通与使用

这里学习下阿里云对象存储 地址&#xff1a;对象存储 OSS_云存储服务_企业数据管理_存储-阿里云 登录支付宝账号&#xff0c;找到了我以前开通的阿里云对象存储 熟悉下API 文档中心 简介_对象存储(OSS)-阿里云帮助中心 我们将用这种方式上传阿里云OSS

SprongBoot3整合Knife4j实现在线接口文档

大家好&#xff0c;我是晓凡。 写在前面 在上一篇文章&#xff0c;我们详细介绍了SpringBoot3 怎么整合SpringDoc实现在线接口文档。但是&#xff0c;有不少小伙伴 都觉得接口界面太丑了。有没有什么更美观一点的UI界面呢&#xff1f; 当然是有的了&#xff0c;毕竟这是一个…

基于Android Studio电影购票系统

目录 项目介绍 图片展示 运行环境 获取方式 项目介绍 主要实为了方便用户随时随地进行电影购票。在配色方面选择了一些富有电影元素的颜色。主要能够实现的功能与流程为&#xff1a; 1.用户首先需要注册用户名填写密码。 2.用户可以用之前注册的用户名和密码进行登录。 3.登…

【密码学】密码学体系

密码学体系是信息安全领域的基石&#xff0c;它主要分为两大类&#xff1a;对称密码体制和非对称密码体制。 一、对称密码体制&#xff08;Symmetric Cryptography&#xff09; 在对称密码体制中&#xff0c;加密和解密使用相同的密钥。这意味着发送方和接收方都必须事先拥有这…

1-3 NLP为什么这么难做

1-3 NLP为什么这么难做 主目录点这里 字词结构的复杂性 中文以汉字为基础单位&#xff0c;一个词通常由一个或多个汉字组成&#xff0c;而不像英语词汇单元由字母构成。这使得中文分词&#xff08;切分句子为词语&#xff09;成为一个具有挑战性的任务。语言歧义性 中文中常…

网络安全设备——蜜罐

网络安全设备蜜罐&#xff08;Honeypot&#xff09;是一种主动防御技术&#xff0c;它通过模拟真实网络环境中的易受攻击的目标&#xff0c;以吸引和监测攻击者的活动。具体来说&#xff0c;蜜罐是一种虚拟或实体的计算机系统&#xff0c;它模拟了一个真实的网络系统或应用程序…

Shell编程类-网站检测

Shell编程类-网站检测 面试题参考答法 a(1 2 3 4) echo ${a[0]} echo ${a[*]}这里声明一个数值&#xff0c;并选择逐个调用输出还是全部输出 curl -w %{http_code} urL/IPADDR常用-w选项去判断网站的状态&#xff0c;因为不加选择访问到的网站可能出现乱码无法判断是否网站down…

Xilinx FPGA:vivado关于fifo的一些零碎知识

一、FIFO概念 先进先出&#xff0c;是一种组织和操作数据结构的方法。在硬件应用中&#xff0c;FIFO一般由一些读写指针&#xff0c;存储和控制的逻辑组成。 二、xilinx中生成的FIFO的存储类型 &#xff08;1&#xff09;shift register FIFO : 移位寄存器FIFO&#xff0c;这…

自动化设备上位机设计 三

目录 一 设计原型 二 后台源码 一 设计原型 二 后台源码 using SqlSugar;namespace 自动化上位机设计 {public partial class Form1 : Form{SqlHelper sqlHelper new SqlHelper();SqlSugarClient dbContent null;bool IsRun false;int Count 0;public Form1(){Initializ…

【论文笔记】BEVCar: Camera-Radar Fusion for BEV Map and Object Segmentation

原文链接&#xff1a;https://arxiv.org/abs/2403.11761 0. 概述 本文的BEVCar模型是基于环视图像和雷达融合的BEV目标检测和地图分割模型&#xff0c;如图所示。模型的图像分支利用可变形注意力&#xff0c;将图像特征提升到BEV空间中&#xff0c;其中雷达数据用于初始化查询…

Tkinter布局助手

免费的功能基本可以满足快速开发布局&#xff0c; https://pytk.net/ iamxcd/tkinter-helper: 为tkinter打造的可视化拖拽布局界面设计小工具 (github.com) 作者也把项目开源了&#xff0c;有兴趣可以玩玩