使用CNN或resnet,分别在flower5,flower17,flower102数据集上实现花朵识别分类-附源码-免费

前言

使用cnn和resnet实现了对flower5,flower17,flower102数据集上实现花朵识别分类。也就是6份代码,全部在Gitee仓库里,记得点个start支持谢谢。

本文给出flower17在cnn网络实现,flower102在resnet网络实现的代码。其余在Gitee仓库中,还附有学习其他博主的模型的代码。

前置准备

理论:一定的深度学习,卷积神经网络的理论知识学习,python基础语法。

环境:Anaconda3安装,python安装,pycharm安装,相应的依赖包安装,如TensorFlow,matplotlib,pillow,pandas等。

数据集

介绍

flower5

flower17

flower102

下载

https://gitee.com/karrysmile/flower_data.git

每个flower.*文件夹就是一个数据集。

每个数据集中包含train,valid文件夹,分别作训练集和数据集用。

训练集和数据集文件架构相同,包含文件夹相同,同种花归为一个文件夹,以花名为文件夹名。

运行要求

我的电脑配置是

flower5,17可以在本地运行,flower102建议用显卡跑。没有显卡的可以到腾讯云或其他平台,租一个服务器来跑,我租了一个Tesla V4显卡来跑,1.6r一小时,用钱换时间。

代码

代码思路

  1. 导入数据集
  2. 数据预处理
  3. 构建模型
  4. 训练模型
  5. 调参优化
  6. 结果可视化
  7. 模型复用

代码解释

以flower17数据集的cnn模型,flower102数据集的resnet模型作为举例,其余在文末的仓库里。

每行代码都加了注释,看注释吧。

# flower17_cnn
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt
import sys
import datetime
from keras.callbacks import ReduceLROnPlateau, EarlyStopping, ModelCheckpoint
# 打印环境版本信息 作者信息
print("@Author karrysmile")
print("@Date "+datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
print("Python version:"+sys.version)
print("TensorFlow version:", tf.__version__)# 设置GPU设备 有的话动态增长
physical_devices = tf.config.list_physical_devices('GPU')
if physical_devices:tf.config.experimental.set_memory_growth(physical_devices[0], True)# 该数据集总共1360个文件,其中1190用于训练集,170用于验证集# 数据准备和增强
# 文件目录
train_data_dir = 'flower_data/train'
valid_data_dir = 'flower_data/valid'
# 批处理大小
batch_size = 32
# 每张图片的重塑大小
image_size = (128, 128)# 使用 ImageDataGenerator 对图像进行数据增强
train_datagen = ImageDataGenerator(# 设定数据增强的模式参数# 将图像的像素值缩放到 [0, 1] 范围内。rescale=1./255,# 随机旋转图像30度rotation_range=30,# 随机水平平移20%width_shift_range=0.2,# 随机垂直平移20%height_shift_range=0.2,# 随机应用错切变换20度shear_range=0.2,# 随机缩放图像尺寸20%zoom_range=0.2,# 随机进行水平翻转horizontal_flip=True,# 随机亮度变化20%brightness_range=[0.8, 1.2],  # 亮度范围
)
# 验证集只把图像的像素值缩放到 [0, 1] 范围内。
test_datagen = ImageDataGenerator(rescale=1./255)# 应用数据增强模型,设定训练数据,从文件目录读取图像
train_generator = train_datagen.flow_from_directory(# 训练集目录train_data_dir,# 图片重塑大小target_size=image_size,# 批处理张数batch_size=batch_size,# 分类模型 - 多分类class_mode='categorical',
)
# 应用数据增强模型,设定验证集数据,从文件目录读取图像
valid_generator = test_datagen.flow_from_directory(# 验证集目录valid_data_dir,# 重塑图像大小target_size=image_size,# 批处理数batch_size=batch_size,# 设定分类模型class_mode='categorical',
)# 搭建CNN模型
model = tf.keras.models.Sequential([# 卷积层,32个filter,卷积核大小为3x3,激活函数为relu,输入形状为(128, 128, 3),长x宽x3通道(RGB)tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(128, 128, 3)),# 最大池化层 提取主要特征,减少计算量tf.keras.layers.MaxPooling2D(2, 2),tf.keras.layers.Conv2D(32, (3, 3), activation='relu'),tf.keras.layers.MaxPooling2D(2, 2),# 卷积层,64个filter,卷积核大小3x3,激活函数为relutf.keras.layers.Conv2D(64, (3, 3), activation='relu'),tf.keras.layers.MaxPooling2D(2, 2),# 最大池化层 提取主要特征,减少计算量tf.keras.layers.BatchNormalization(),# 卷积层,128个filter,卷积核大小为3x3,激活函数为relutf.keras.layers.Conv2D(128, (3, 3), activation='relu'),tf.keras.layers.BatchNormalization(),tf.keras.layers.MaxPooling2D(2, 2),# 将多维输入数据展平为一维向量,以便连接到全连接层tf.keras.layers.Flatten(),# 全连接层,512维,激活函数为relutf.keras.layers.Dense(512, activation='relu'),# dropout 30%的数据 避免过拟合tf.keras.layers.Dropout(0.3),# 全连接层,输出,17个维度对应17种花,激活函数为softmax,用于多分类tf.keras.layers.Dense(17, activation='softmax')
])# 设定优化器 Adam 初始学习率为0。001
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
# 编译模型,优化器选择Adam,损失函数为交叉熵损失函数,适用于多类别分类问题,准确率作为评估模型性能的指标
model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])
# 打印出模型的摘要信息,包括每一层的名称、输出形状和参数数量等
model.summary()# 训练模型
# 检查点,根据验证准确率,每个epoch判断要不要保存最好的模型  保存整个模型
checkpoint = ModelCheckpoint("model", monitor='val_accuracy', verbose=1,save_best_only=True, save_weights_only=False, mode='auto', save_freq='epoch')
# 早退,当设定的n个epoch发生,验证准确率都没有发生提升,就退出
early = EarlyStopping(monitor='val_accuracy', min_delta=0, patience=50, verbose=1, mode='auto')
# 减少学习率 检测val_loss 如果5个epoch没有发生更好的变化,就变为原来的二分之一,避免过拟合
reduce_lr = ReduceLROnPlateau(monitor='val_loss', patience=5, mode='auto',factor=0.5)
# 模型训练,结果保存到history
history = model.fit(# 训练数据放进来train_generator,# 计算每个epoch的数量(总长度 除以 批处理大小)steps_per_epoch=train_generator.samples // batch_size,# 要跑的轮数epochs=1000,# 批处理大小batch_size=batch_size,validation_data=valid_generator,validation_steps=valid_generator.samples // batch_size,# 回调函数,用于监测和调整超参数callbacks=[reduce_lr,checkpoint,early]
)
# 保存模型
model.save('flower17_cnn.h5')
model.save('flower17_cnn')
# 用全部测试数据评估模型
test_loss, test_acc = model.evaluate(valid_generator, verbose=2)
print('\nTest accuracy:', test_acc)# 绘制训练和测试损失
plt.plot(history.history['loss'], label='train_loss')
plt.plot(history.history['val_loss'], label='valid_loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()# 绘制训练和测试准确率
plt.plot(history.history['accuracy'], label='train_acc')
plt.plot(history.history['val_accuracy'], label='valid_acc')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
plt.show()

# flower102_resnet18
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt
import sys
import datetime
from keras.callbacks import ReduceLROnPlateau, EarlyStopping, ModelCheckpoint
# 打印环境版本信息 作者信息
print("@Author karrysmile")
print("@Date "+datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
print("Python version:"+sys.version)
print("TensorFlow version:", tf.__version__)# 设置GPU设备
physical_devices = tf.config.list_physical_devices('GPU')
if physical_devices:tf.config.experimental.set_memory_growth(physical_devices[0], True)# 数据准备和增强
train_data_dir = 'flower_data/train'
valid_data_dir = 'flower_data/valid'
batch_size = 32
image_size = (128, 128)train_datagen = ImageDataGenerator(rescale=1./255,rotation_range=30,width_shift_range=0.2,height_shift_range=0.2,shear_range=0.2,zoom_range=0.2,horizontal_flip=True,brightness_range=[0.8, 1.2],  # 亮度范围
)valid_datagen = ImageDataGenerator(rescale=1./255)train_generator = train_datagen.flow_from_directory(train_data_dir,target_size=image_size,batch_size=batch_size,class_mode='categorical',
)valid_generator = valid_datagen.flow_from_directory(valid_data_dir,target_size=image_size,batch_size=batch_size,class_mode='categorical',
)def ConvLayer(x,filters,kernel_size,stride):x = tf.keras.layers.Conv2D(filters=filters, kernel_size=kernel_size, strides=stride,padding='same')(x)x = tf.keras.layers.BatchNormalization(epsilon=1e-5,momentum=0.1)(x)return xdef ResNetBlock(input,filters,kernel_size,strides):x = ConvLayer(input,filters,kernel_size,strides)x = tf.keras.layers.Activation('relu')(x)x = ConvLayer(x,filters,kernel_size,(1,1))if strides != (1,1):residual = ConvLayer(input,filters,(1,1),strides)else:residual = inputx = x+residualx = tf.keras.layers.Activation('relu')(x)return xdef ResNet(input_size):# headx = ConvLayer(input_size,64,(7,7),(2,2))x = tf.keras.layers.Activation('relu')(x)x = tf.keras.layers.MaxPooling2D(3,strides=2,padding='same')(x)# layer1-------------------x = ResNetBlock(x,64,(3,3),(1,1))x = ResNetBlock(x,64,(3,3),(1,1))# layer2-------------------x = ResNetBlock(x,128,(3,3),(2,2))x = ResNetBlock(x,128,(3,3),(1,1))# layer3-------------------x = ResNetBlock(x,256,(3,3),(2,2))x = ResNetBlock(x,256,(3,3),(1,1))# layer4-------------------x = ResNetBlock(x,512,(3,3),(2,2))x = ResNetBlock(x,512,(3,3),(1,1))# tailx = tf.keras.layers.AvgPool2D(1)(x)x = tf.keras.layers.Flatten()(x)x = tf.keras.layers.Dense(512, activation='relu')(x)output = tf.keras.layers.Dense(102, activation='softmax')(x)return outputinputs = tf.keras.Input((128,128,3))
outputs = ResNet(inputs)
model = tf.keras.Model(inputs,outputs)optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])
model.summary()
# 训练模型
checkpoint = ModelCheckpoint("model", monitor='val_accuracy', verbose=1,save_best_only=True, save_weights_only=False, mode='auto', save_freq='epoch')
early = EarlyStopping(monitor='val_accuracy', min_delta=0, patience=10, verbose=1, mode='auto')
reduce_lr = ReduceLROnPlateau(monitor='val_loss', patience=3, mode='auto',factor=0.2)
history = model.fit(train_generator,steps_per_epoch=train_generator.samples // batch_size,epochs=50,batch_size=batch_size,validation_data=valid_generator,validation_steps=valid_generator.samples // batch_size,callbacks=[reduce_lr,checkpoint,early]
)# 保存为h5文件
model.save('flower102_resnet.h5')
# 保存为文件夹形式,可以注释掉
model.save('flower102_resnet')test_loss, test_acc = model.evaluate(valid_generator, verbose=2)
print('\nTest accuracy:', test_acc)# 绘制训练和测试损失
plt.plot(history.history['loss'], label='train_loss')
plt.plot(history.history['val_loss'], label='valid_loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()# 绘制训练和测试准确率
plt.plot(history.history['accuracy'], label='train_acc')
plt.plot(history.history['val_accuracy'], label='valid_acc')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
plt.show()

运行结果

flower17_cnn

flower102_resnet18

随机加载一张图片来验证

在根目录下放置一张test.jpg,加载这张图片并输出验证结果。

import keras.models
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.preprocessing import image# 加载SavedModel格式的模型
loaded_model = keras.models.load_model('flower17_cnn')# 进行预测等操作# 读取测试图片
img_path = 'test.jpg'  # 测试图片的路径
img = image.load_img(img_path, target_size=(128, 128))
img_array = image.img_to_array(img)
img_array = np.expand_dims(img_array, axis=0)
img_array /= 255.# 进行预测
prediction = loaded_model.predict(img_array)
predicted_class_index = np.argmax(prediction)
class_labels = ['bluebell', 'buttercup', 'colts_foot', 'cowslip', 'crocus', 'daffodil', 'daisy', 'dandelion', 'fritillary', 'iris', 'lily_valley', 'pansy', 'snowdrop', 'sunflower', 'tigerlily', 'tulip', 'windflower']
predicted_class = class_labels[predicted_class_index]print("当前图片预测的类型是:--->>>", predicted_class)# 显示预测结果
plt.imshow(img)
plt.title('Predicted: {}'.format(predicted_class))
plt.axis('off')
plt.show()

运行结果

总结

  1. 真的需要算力,不然很多时间都留在等待上面,但又恰恰因为等待,可以有更深的思考(所以需要一定时间的等待,但不能过长)
  2. 不要随意更新或者卸载依赖包,会容易影响整个环境的包之间的版本不匹配
  3. 越深层的网络,需要考虑的东西越多,如果不考虑,仅仅是堆深度,可能根本学不到东西,甚至比原来更差。
  4. 图片进行垂直翻转,会出现验证率下降的问题。待验证和解决。
  5. 最好是自动监控与停止,多参考别人的代码。

参考文章

ResNet18详细原理(含tensorflow版源码)_resnet18网络结构-CSDN博客

(四)pytorch图像识别实战之用resnet18实现花朵分类(代码+详细注解)_pytorch中调用resnet18进行分类-CSDN博客

TensorFlow指定GPU使用及监控GPU占用情况_taskflow gpu-CSDN博客

Gitee仓库

包含两种模型(cnn,resnet)在三个数据集(flower5,17,102)上的六个实现,用ipynb存储。

resnet附上了其他作者的迁移预训练结果的代码。文件名包含example的代码不是本人写的。

https://gitee.com/karrysmile/flowers.git

有用请点个star,按赞收藏关注。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/834189.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

docker私有仓库的registry

简介 Docker私有仓库的Registry是一个服务,主要用于存储、管理和分发Docker镜像。具体来说,Registry的功能包括: 存储镜像:Registry提供一个集中的地方来存储Docker镜像,包括镜像的层次结构和元数据。 版本控制&…

基于HSI模型的水下图像增强算法,Matlab实现

博主简介: 专注、专一于Matlab图像处理学习、交流,matlab图像代码代做/项目合作可以联系(QQ:3249726188) 个人主页:Matlab_ImagePro-CSDN博客 原则:代码均由本人编写完成,非中介,提供…

【数据结构】-- 链表专题

链表的分类 前面我们实现了单链表,单链表只是链表的一种。可以根据以下几个标准来判断链表的类型: 1.单向或者双向 如图所示,单向链表中一个节点的指针域只储存了下一个节点的指针,能通过前一个节点访问后一个节点,无…

【4089】基于小程序实现的互动打卡系统

作者主页:Java码库 主营内容:SpringBoot、Vue、SSM、HLMT、Jsp、PHP、Nodejs、Python、爬虫、数据可视化、小程序、安卓app等设计与开发。 收藏点赞不迷路 关注作者有好处 文末获取源码 技术选型 【后端】:Java 【框架】:spring…

怎么写毕业论文的? 推荐4个AI工具

写作这件事一直让我们从小学时期就开始头痛,初高中时期800字的作文让我们焦头烂额,一篇作文里用尽了口水话,拼拼凑凑才勉强完成。 大学时期以为可以轻松顺利毕业,结果毕业前的最后一道坎拦住我们的是毕业论文,这玩意不…

短视频矩阵系统源码/saas--总后台端、商户端、代理端、源头开发

短视频矩阵系统源码/saas--总后台端、商户端、代理端、源头开发 搭建短视频矩阵系统源码的交付步骤可以概括为以下几个关键环节: 1. **系统需求分析**:明确系统需要支持的功能,如短视频的上传、存储、播放、分享、评论、点赞等。 2. **技术选…

Python深度学习基于Tensorflow(5)机器学习基础

文章目录 监督学习线性回归逻辑回归决策树支持向量机朴素贝叶斯 集成学习BaggingBoosting 无监督学习主成分分析KMeans聚类 缺失值和分类数据处理处理缺失数据分类数据转化为OneHot编码 葡萄酒数据集示例 机器学习的流程如下所示: 具体又可以分为以下五个步骤&#…

Python开源工具库使用之运动姿势追踪库mediapipe

文章目录 前言一、姿势估计1.1 姿态关键点1.2 旧版 solution API1.3 新版 solution API1.4 俯卧撑计数 二、手部追踪2.1 手部姿态2.2 API 使用2.3 识别手势含义 参考 前言 Mediapipe 是谷歌出品的一种开源框架,旨在为开发者提供一种简单而强大的工具,用…

[C++核心编程-04]----C++类和对象之封装

目录 引言 正文 01-类和对象简介 02-封装简介 03-封装的意义 04-封装案例之设计学生类 05-封装的权限控制 06-struct和class的区别 07-成员属性设置为私有 08-封装案例1-设计立方体 09-封装案例2-判断点和圆的关系 总结 引言 在C中,…

Failed to build flash-attn:ERROR: Could not build wheels for flash-attn

安装 FlashAttention 的时候遇到报错: Failed to build flash-attn ERROR: Could not build wheels for flash-attn, which is required to install pyproject.toml-based projects可能是安装的版本与环境存在冲突吧,我的环境是: python 3.1…

堆的应用2——TOPK问题

TOPK问题 TOP-K问题:即求数据结合中前K个最大的元素或者最小的元素,一般情况下数据量都比较大。 比如:专业前10名、世界500强、富豪榜、游戏中前100的活跃玩家等。 情况1——数据量小 对于Top-K问题,能想到的最简单直接的方式就…

嵌入式C语言高级教程:实现基于STM32的自适应交通信号控制系统

自适应交通信号控制系统能够基于实时交通流数据调整信号灯的时长,提高路口的通行效率。本教程将指导您如何在STM32微控制器上实现一个基本的自适应交通信号控制系统。 一、开发环境准备 硬件要求 微控制器:STM32F103C8,具备足够的处理能力…

Eclipse下载安装教程(包含JDK安装)【保姆级教学】【2023.10月最新版】

目录 文章最后附下载链接 第一步:下载Eclipse,并安装 第二步:下载JDK,并安装 第三步:Java运行环境配置 安装Eclipse必须同时安装JDK !!! 文章最后附下载链接 第一步&#xf…

[法规规划|数据概念]金融行业数据资产和安全管理系列文件解析(3)

“ 金融行业在自身数据治理和资产化建设方面一直走在前列。” 一直以来,金融行业由于其自身需要,都是国内开展信息化建设最早,信息化程度最高的行业。 在当今数据要素资产化的浪潮下,除了行业自身自身数据治理和资产化建设方面&am…

EditReady for Mac激活版:专业视频转码工具

对于视频专业人员来说,一款高效的视频转码工具是不可或缺的。EditReady for Mac正是这样一款强大的工具,它拥有简洁直观的操作界面和强大的功能,让您的视频处理工作事半功倍。 EditReady for Mac支持多种视频格式的转码,并且支持常…

【Java】初识网络编程

文章目录 前言✍一、互联网的发展1.独立模式2.网络的出现局域网LAN广域网WAN ✍二、网络编程概述✍三、网络编程中的术语介绍IP地址端口号协议OSI七层模型TCP\IP四层模型 ✍四、协议的层级之间是如何配合工作的 前言 在本文中,会对网络编程的一些术语进行解释&#…

动态规划——路径问题:931.下降路径最小和

文章目录 题目描述算法原理1.状态表示(经验题目)2.状态转移方程3.初始化4.填表顺序5.返回值 代码实现CJava 题目描述 题目链接:931.下降路径最小和 关于这⼀类题,看过我之前的博客的朋友对于状态表示以及状态转移是⽐较容易分析…

5分钟了解下HDFS

随着大数据时代的到来,传统的数据存储和管理方式已经无法满足日益增长的数据处理需求。HDFS(Hadoop Distributed File System)作为Apache Hadoop项目的一部分,以其高度的容错性、可扩展性和高吞吐量,成为了处理大规模数…

抖音APP运用的AI技术拆解

1.推荐系统(RS) 用户画像:根据用户的信息(如地区、性别、年龄、收藏、关注......)进行分析,构建用户画像,对用户进行分类; 行为分析:将用户的显形行为数据(如…

搜维尔科技:OptiTrack是基于LED墙虚拟制作舞台的最佳选择

OptiTrack因其绝对精度、易用性、可靠性以及与现场工具的完美集成而被选中&#xff0c;仍然是全球首屈一指的基于 LED 墙的虚拟制作舞台的选择。 当今虚拟制作阶段的低延迟、超精确摄像机跟踪标准 /- 0.2 毫米 位置精度1 < 10 毫秒 系统延迟 /- 0.1 度 旋转精度2 电影…