基于卷积神经网络的图像二分类检测模型训练与推理实现教程 | 幽络源

前言

对于本教程,说白了,就是期望能通过一个程序判断一张图片是否为某个物体,或者说判断一张图片是否为某个缺陷。因为本教程是针对二分类问题,因此主要处理 是 与 不是 的问题,比如我的模型是判断一张图片是否为苹果,那么拿一张图片给模型去推理,他会得出这张图是苹果的概率,如果概率大于0.5(这个概率在0~1之间),那么就判断为是苹果。

教程内容

使用了Python的 TensorFlow 和 Keras 库 构建卷积神经网络来完成二分类模型训练,以及使用模型完成对一张图片的推理。原文链接:基于卷积神经网络的图像二分类检测模型训练与推理实现教程 | 幽络源

大致步骤

1.确定环境与库

2.准备数据集并且划分

3.数据集的命名问题注意事项

4.编写训练代码完成模型训练

5.编写推理代码

6.测试二分类检测结果

7.根据结果优化数据集

步骤1.确定环境与库

Python环境是必备的,我这里所使用的Python版本为3.12.3

其次还需要以下库,依次执行如下命令即可

pip install tensorflow
pip install pillow
pip install scipy

如图

1

2

步骤2.准备数据集并且划分

我这里以判断图片是否为冲沟缺陷 来准备数据集,首先创建数据集的目录结构,结构如下

data/train/true_sample/ false_sample/  val/true_sample/false_sample/

QQ_1734065732662

目录解释:

data:作为数据集的根目录

train和val分别为训练集、验证集目录

true_sample:正类样本,也就是我这里需要把含有冲沟缺陷的图放到这个目录

false_sample:负类样本,也就是这里需要将不含有冲沟缺陷的图片放进这个目录

如图,我向train和val的true_sample目录加入了一些含有冲沟缺陷的图片

3

对于负类样本,也不是无脑的只要不是冲沟就往里面放,而是放置你认为训练出的模型可能会将什么识别为正类样本。比如滑坡和冲沟其实是有联系的,但不完全等同于,所以我需要将滑坡相关的,但是没有冲沟情况的图片放入false_sample中,期望模型不要误判。再比如一个苹果,你可能需要把红色气球作为父类样本,防止模型将红气球判断为是苹果,如图是我的负类样本

4

步骤3.数据集的命名问题注意事项

关于数据集的命名,这里其实有一个坑,但是先说避免坑的做法:就像步骤2一样,你的正类样本所放置的目录命名为true_sample、负类样本所放置的目录命名为false_sample就行了。(如果看不懂下面的解释,按照这里做法做就是了)

然后我来解释下是什么坑,对于这个二分类模型训练,训练出来的模型,无非是识别 是 与 不是 的问题,但是模型怎么区分我的哪个目录放置的为是,哪个目录放置的为不是呢,步骤4会给出训练代码,训练代码中的加载数据集时有一行如下代码

class_mode='binary'  # 二分类(冲沟缺陷 vs. 非冲沟缺陷)

这表示我们要做二分类模型训练,加上这行代码,在加载数据集时,Keras 会自动将这些文件夹的名称作为标签,分别命名为1 和 0,如果被命名为标签1 的目录,则在推理时,概率越接近于1,则越表示是标为1的目录的样本,反之概率越接近于0,则越表示是标为0的目录的样本。而keras自动命名标签1和0时是根据目录名首字母的顺序来的字,字母靠前的标为0,后者为1,true_sample的首字母为t,false_sample的首字母为f,因此false_sample标为0,true_sample标为1,这是符合我们的正常预期的。

反面例子:

如果我把正类样本放置于名为defect的目录,负类样本放置于no_defect目录会怎样呢,按照如上解释,defect目录会被标为0,no_defect目录会被标为1,这就和我们预期相反了,什么意思呢。我把正类样本放置defect目录中,其推理结果将会是越接近0,则越表示为正类了,因此这里特别需要注意(如果你要自定义目录名的话)。

步骤4.编写训练代码完成模型训练

先直接上训练代码

from tensorflow.keras import layers, models
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import tensorflow as tftrain_dir='data/train'
val_dir='data/val'# 设置图像的尺寸和批量大小,不用改,保持150是最平衡的
IMG_HEIGHT = 150
IMG_WIDTH = 150
BATCH_SIZE = 12# 数据预处理与增强
train_datagen = ImageDataGenerator(rescale=1./255,  # 将像素值归一化到 [0, 1] 区间shear_range=0.2,zoom_range=0.2,horizontal_flip=True
)validation_datagen = ImageDataGenerator(rescale=1./255)# 加载训练和验证数据
train_generator = train_datagen.flow_from_directory(train_dir,  # 训练数据目录target_size=(IMG_HEIGHT, IMG_WIDTH),  # 图像尺寸batch_size=BATCH_SIZE,class_mode='binary'  # 二分类(冲沟缺陷 vs. 非冲沟缺陷)
)train_class_labels = train_generator.class_indices
print("训练集自动标签映射关系为:"+str(train_class_labels))validation_generator = validation_datagen.flow_from_directory(val_dir,  # 验证数据目录target_size=(IMG_HEIGHT, IMG_WIDTH),batch_size=BATCH_SIZE,class_mode='binary'
)val_class_labels = validation_generator.class_indices
print("测试集自动标签映射关系为:"+str(val_class_labels))# 将数据生成器转换为 tf.data.Dataset 并应用 repeat() 方法
train_dataset = tf.data.Dataset.from_generator(lambda: train_generator,output_signature=(tf.TensorSpec(shape=(None, IMG_HEIGHT, IMG_WIDTH, 3), dtype=tf.float32),tf.TensorSpec(shape=(None,), dtype=tf.int32))
)
train_dataset = train_dataset.repeat()  # 确保数据重复validation_dataset = tf.data.Dataset.from_generator(lambda: validation_generator,output_signature=(tf.TensorSpec(shape=(None, IMG_HEIGHT, IMG_WIDTH, 3), dtype=tf.float32),tf.TensorSpec(shape=(None,), dtype=tf.int32))
)
validation_dataset = validation_dataset.repeat()  # 确保数据重复# 构建模型
model = models.Sequential([layers.InputLayer(shape=(IMG_HEIGHT, IMG_WIDTH, 3)),  # 添加 Input 层layers.Conv2D(32, (3, 3), activation='relu'),layers.MaxPooling2D((2, 2)),layers.Conv2D(64, (3, 3), activation='relu'),layers.MaxPooling2D((2, 2)),layers.Conv2D(128, (3, 3), activation='relu'),layers.MaxPooling2D((2, 2)),layers.Flatten(),layers.Dense(128, activation='relu'),layers.Dense(1, activation='sigmoid')  # 输出层,二分类问题
])# 编译模型
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])# 训练模型
model.fit(train_dataset,steps_per_epoch=train_generator.samples // BATCH_SIZE,epochs=30,validation_data=validation_dataset,validation_steps=validation_generator.samples // BATCH_SIZE
)# 保存模型
model.save('defect_detector_model.keras')  # 使用 .keras 格式保存模型

使用这段代码训练数据集你唯一需要注意的是保持代码文件于数据集文件在同一目录,或者使用绝对路径,如图

QQ_1734070943655

我们启动训练代码,可以看到控制台在按照规定的轮次30在训练中,而且可以看到我在训练代码中加入了输出标签映射关系来确保正类与负类的映射关系正确,如图

QQ_1734071390702

训练后,你会得到一个名为defect_detector_nodel.keras的文件,推理时会使用该模型进行推理

步骤5.编写推理代码

代码如下:

import os
from tensorflow.keras.models import load_model
import numpy as np
from tensorflow.keras.preprocessing import image# 加载训练好的模型
model = load_model('defect_detector_model.keras')  # 注意加载的是 .keras 格式# 设置输入图像的目标尺寸(与训练时相同)
IMG_HEIGHT = 150
IMG_WIDTH = 150# 定义函数来加载并预测图像
def predict_image(img_path):# 加载图像并进行预处理img = image.load_img(img_path, target_size=(IMG_HEIGHT, IMG_WIDTH))img_array = image.img_to_array(img)  # 将图像转换为数组img_array = np.expand_dims(img_array, axis=0)  # 扩展维度,成为一个 batchimg_array = img_array / 255.0  # 归一化处理(与训练时一致)# 预测图像类别prediction = model.predict(img_array)  # 返回的是一个包含概率的数组return prediction[0][0]  # 提取预测的概率值picPath=r"测试图.jpg"
confidence = predict_image(picPath)
print("有冲沟缺陷的概率为:"+str(confidence))

这段推理代码中,我们加载了刚才训练出的模型,然后使用了一张名为测试图.jpg的图片来进行推理,然后输出他有缺陷的概率

步骤6.测试二分类检测结果

我这里就不用一张图片来测试了,我这里指定一个目录,进行整个目录来测试里面的图片,还是附上我这个推理代码吧

import os
from tensorflow.keras.models import load_model
import numpy as np
from tensorflow.keras.preprocessing import image# 加载训练好的模型
model = load_model('defect_detector_model.keras')  # 注意加载的是 .keras 格式# 设置输入图像的目标尺寸(与训练时相同)
IMG_HEIGHT = 150
IMG_WIDTH = 150# 定义函数来加载并预测图像
def predict_image(img_path):# 加载图像并进行预处理img = image.load_img(img_path, target_size=(IMG_HEIGHT, IMG_WIDTH))img_array = image.img_to_array(img)  # 将图像转换为数组img_array = np.expand_dims(img_array, axis=0)  # 扩展维度,成为一个 batchimg_array = img_array / 255.0  # 归一化处理(与训练时一致)# 预测图像类别prediction = model.predict(img_array)  # 返回的是一个包含概率的数组return prediction[0][0]  # 提取预测的概率值# 测试目录,包含要进行推理的图像
testDir = r"D:\virtualTemp\pythonProject\CNN分类检测\data\train\true_sample"
pics = os.listdir(testDir)
# 遍历目录中的所有图片并进行预测
for pic in pics:picPath = os.path.join(testDir, pic)  # 获取图片的完整路径# 获取预测结果的置信度confidence = predict_image(picPath)# 输出图像的置信度和类别print(f"{pic} 置信度: {confidence:.4f}, 预测结果: {'有缺陷' if confidence >= 0.5 else '无缺陷'}")

我先使用正类样本来测试,先看看拿训练的数据如何,然后再用另外的图片来测试

结果如下图,正类样本中只有一张图判定为了无冲沟,但是我正类样本中其实都应当是冲沟,而我有101张图,因此这里正确率为99.009%

QQ_1734071615033

拿训练的数据来说话可能没有说服力,现在我使用爬图器来批量的爬取一些图片,需要的可以这里拿=> 幽络源爬图器

如图我爬取了3轮桥梁破损图,2轮冲沟地貌图,对于冲沟图,最好是手动删一些莫名奇妙的图,便于验证

QQ_1734072068792

QQ_1734072170259

ok,然后先测试桥梁破损,如果足够符合预期,足够表示模型很好,那么推理出的有缺陷数量应该没有或者很少才对,结果如下

QQ_1734072462049

看起来结果并不好,90张图中,居然有44张判定为了有冲沟缺陷,正确率只有46/90=51.11%,再测试下正类检测呢,如图48张图中只有11张判定为了无,还是不错的。

步骤7.根据结果优化数据集

在步骤6的测试中可知,所训练的模型对正类比较适应,对负类的学习还有所欠缺,处理方法有如下

1.调整判定指标confidence,一般为0.5,可以调大以提高正确率,但是不推荐这么做

2.加大训练轮次

3.训练时的父类样本图片多加一些

ok,方法1我不是很推荐,现在首先加大训练次数到100,然后多爬取一些非冲沟图加入到负类样本之中,当然,桥梁破损的图也放进去一些,然后重新训练获取模型。

训练完后还是按照步骤6中来测试桥梁破损,如图,这一次,90张图中判定为有缺陷的只有7个了,非常不错,正确率提高到了82/90=91.11%

QQ_1734073419635

结语

以上是幽络源的基于卷积神经网络的图像二分类检测模型训练与推理实现教程,对Python、Java感兴趣的小伙伴可加群交流

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/62811.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

RabbitMQ个人理解与基本使用

目录 一. 作用: 二. RabbitMQ的5中队列模式: 1. 简单模式 2. Work模式 3. 发布/订阅模式 4. 路由模式 5. 主题模式 三. 消息持久化: 消息过期时间 ACK应答 四. 同步接收和异步接收: 应用场景 五. 基本使用 &#xff…

前端怎么预览pdf

1.背景 后台返回了一个在线的pdf地址,需要我这边去做一个pdf的预览(需求1),并且支持配置是否可以下载(需求2),需要在当前页就能预览(需求3)。之前我写过一篇预览pdf的文…

滑动窗口算法专题

滑动窗口简介 滑动窗口就是利用单调性,配合同向双指针来优化暴力枚举的一种算法。 该算法主要有四个步骤 1. 先进进窗口 2. 判断条件,后续根据条件来判断是出窗口还是进窗口 3. 出窗口 4.更新结果,更新结果这个步骤是不确定的&#xff0c…

C# 中的Task

文章目录 前言一、Task 的基本概念二、创建 Task使用异步方法使用 Task.Run 方法 三、等待 Task 完成使用 await 关键字使用 Task.Wait 方法 四、处理 Task 的异常使用 try-catch 块使用 Task.Exception 属性 五、Task 的延续使用 ContinueWith 方法使用 await 关键字和异步方法…

【AIGC】如何高效使用ChatGPT挖掘AI最大潜能?26个Prompt提问秘诀帮你提升300%效率的!

还记得第一次使用ChatGPT时,那种既兴奋又困惑的心情吗?我是从一个对AI一知半解的普通用户,逐步成长为现在的“ChatGPT大神”。这一过程并非一蹴而就,而是通过不断的探索和实践,掌握了一系列高效使用的技巧。今天&#…

浩辰CAD教程004:柱梁板

文章目录 柱梁板标准柱角柱构造柱柱齐墙边绘制梁绘制楼板 柱梁板 标准柱 绘制标准柱: ①:点选插入柱子②:沿着一根轴线布置柱子③:指定的矩形区域内的轴线交点插入柱子 替换现有柱子:选择替换之后的柱子形状&#x…

UNIX数据恢复—UNIX系统常见故障问题和数据恢复方案

UNIX系统常见故障表现: 1、存储结构出错; 2、数据删除; 3、文件系统格式化; 4、其他原因数据丢失。 UNIX系统常见故障解决方案: 1、检测UNIX系统故障涉及的设备是否存在硬件故障,如果存在硬件故障&#xf…

桥接模式的理解和实践

桥接模式(Bridge Pattern),又称桥梁模式,是一种结构型设计模式。它的核心思想是将抽象部分与实现部分分离,使它们可以独立地进行变化,从而提高系统的灵活性和可扩展性。本文将详细介绍桥接模式的概念、原理…

HTML综合

一.HTML的初始结构 <!DOCTYPE html> <html lang"en"><head><!-- 设置文本字符 --><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><!-- 设置网页…

二维码数据集,使用yolov,voc,coco标注,3044张各种二维码原始图片(未图像增强)

二维码数据集&#xff0c;使用yolov&#xff0c;voc&#xff0c;coco标注&#xff0c;3044张各种二维码原始图片&#xff08;未图像增强&#xff09; 数据集分割 训练组70&#xff05; 2132图片 有效集20&#xff05; 607图片 测试集10&#xff05; 305图…

用豆包MarsCode IDE,从0到1画出精美数据大屏!

豆包MarsCode IDE 是一个云端 AI IDE 平台&#xff0c;通过内置的 AI 编程助手&#xff0c;开箱即用的开发环境&#xff0c;可以帮助开发者更专注于各类项目的开发。 作为一名前端开发工程师&#xff0c;今天想尝试利用豆包MarsCode IDE&#xff0c;选择 Vue Echarts 创建一个…

游戏引擎学习第42天

仓库: https://gitee.com/mrxiao_com/2d_game 简介 目前我们正在研究的内容是如何构建一个基本的游戏引擎。我们将深入了解游戏开发的每一个环节&#xff0c;从最基础的技术实现到高级的游戏编程。 角色移动代码 我们主要讨论的是角色的移动代码。我一直希望能够使用一些基…

Redis是什么?Redis和MongoDB的区别在那里?

Redis介绍 Redis&#xff08;Remote Dictionary Server&#xff09;是一个开源的、基于内存的数据结构存储系统&#xff0c;它可以用作数据库、缓存和消息中间件。以下是关于Redis的详细介绍&#xff1a; 一、数据结构支持 字符串&#xff08;String&#xff09; 这是Redis最…

Bug 解决 无法正常登录或获取不到用户信息

目录 1、跨域问题 2、后端代码问题 3、前端代码问题 我相信登录这个功能是很多人做项目时候遇到第一个槛&#xff01; **看起来好像很简单的登录功能&#xff0c;实际上还是有点坑的&#xff0c;比如明明账号密码都填写正确了&#xff0c;**为什么登录后请求接口又说我没登…

论文翻译 | ChunkRAG: Novel LLM-Chunk Filtering Method for RAG Systems

摘要 使用大型语言模型&#xff08;LLM&#xff09;的检索-增强生成&#xff08;RAG&#xff09;系统经常由于检索不相关或松散相关的信息而生成不准确的响应。现有的在文档级别操作的方法无法有效地过滤掉此类内容。我们提出了LLM驱动的块过滤&#xff0c;ChunkRAG&#xff0…

Maven(生命周期、POM、模块化、聚合、依赖管理)详解

Maven构建项目的生命周期 在Maven出现之前&#xff0c;项目构建的生命周期就已经存在&#xff0c;软件开发人员每天都在对项目进行清理&#xff0c;编译&#xff0c;测试&#xff0c;部署等工作&#xff0c;这个过程就是项目构建的生命周期。虽然大家都在不停的做构建工作&…

jenkins harbor安装

Harbor是一个企业级Docker镜像仓库‌。 文章目录 1. 什么是Docker私有仓库2. Docker有哪些私有仓库3. Harbor简介4. Harbor安装 1. 什么是Docker私有仓库 Docker私有仓库是用于存储和管理Docker镜像的私有存储库。Docker默认会有一个公共的仓库Docker Hub&#xff0c;而与Dock…

【Python网络爬虫笔记】10- os库存储爬取数据

os库的作用 操作系统交互&#xff1a;os库提供了一种使用Python与操作系统进行交互的方式。使用os库来创建用于存储爬取数据的文件夹&#xff0c;或者获取当前工作目录的路径&#xff0c;以便将爬取的数据存储在合适的位置。环境变量操作&#xff1a;可以读取和设置环境变量。在…

微信小程序从后端获取的图片,展示的时候上下没有完全拼接,有缝隙【已解决】

文章目录 1、index.wxml2、index.js3、detail.detail为什么 .rich-text-style 样式可以生效&#xff1f;1. <rich-text> 组件的特殊性2. 类选择器的作用范围3. 样式优先级4. line-height: 0 的作用5. 为什么直接使用 rich-text 选择器无效&#xff1f; 总结 上下两张图片…

Linux-apache虚拟主机配置笔记

一、 安装apache 有需要的话&#xff0c;可以去查看具体的apache的安装apache安装https://blog.csdn.net/m0_68472908/article/details/139348739?spm1001.2014.3001.5501 都可以使用本地yum源搭建本地yum源搭建https://blog.csdn.net/m0_68472908/article/details/14385692…