基于tensorflow的咖啡豆识别

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊

一、前期工作

1. 设置GPU

import tensorflow as tfgpus = tf.config.list_physical_devices("GPU")if gpus:tf.config.experimental.set_memory_growth(gpus[0], True)  #设置GPU显存用量按需使用tf.config.set_visible_devices([gpus[0]],"GPU")print("GPU is available")

2. 导入数据

from tensorflow       import keras
from tensorflow.keras import layers,models
import numpy             as np
import matplotlib.pyplot as plt
import os,PIL,pathlibdata_dir = "F:/host/Data/咖啡豆识别数据/"
data_dir = pathlib.Path(data_dir)
image_count = len(list(data_dir.glob('*/*.png')))print("图片总数为:",image_count)

在这里插入图片描述

二、数据预处理

1. 加载数据

使用image_dataset_from_directory方法将磁盘中的数据加载到tf.data.Dataset

batch_size = 8
img_height = 224
img_width = 224
train_ds = tf.keras.preprocessing.image_dataset_from_directory(data_dir,validation_split=0.2,subset="training",seed=123,image_size=(img_height, img_width),batch_size=batch_size)
val_ds = tf.keras.preprocessing.image_dataset_from_directory(data_dir,validation_split=0.1,subset="validation",seed=123,image_size=(img_height, img_width),batch_size=batch_size)
class_names = train_ds.class_names
print(class_names)

在这里插入图片描述

2. 可视化数据

plt.figure(figsize=(10, 4))  # 图形的宽为10高为5for images, labels in train_ds.take(1):for i in range(8):ax = plt.subplot(2, 4, i + 1)  plt.imshow(images[i].numpy().astype("uint8"))plt.title(class_names[labels[i]])plt.axis("off")

在这里插入图片描述

for image_batch, labels_batch in train_ds:print(image_batch.shape)print(labels_batch.shape)break

在这里插入图片描述

3. 配置数据集

  • shuffle() :打乱数据,关于此函数的详细介绍可以参考:https://zhuanlan.zhihu.com/p/42417456
  • prefetch() :预取数据,加速运行,其详细介绍可以参考我前两篇文章,里面都有讲解。
  • cache() :将数据集缓存到内存当中,加速运行
AUTOTUNE = tf.data.AUTOTUNEtrain_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds   = val_ds.cache().prefetch(buffer_size=AUTOTUNE)
normalization_layer = layers.experimental.preprocessing.Rescaling(1./255)train_ds = train_ds.map(lambda x, y: (normalization_layer(x), y))
val_ds   = val_ds.map(lambda x, y: (normalization_layer(x), y)) 
image_batch, labels_batch = next(iter(val_ds))
first_image = image_batch[0]# 查看归一化后的数据
print(np.min(first_image), np.max(first_image))

三、构建VGG-16网络

from tensorflow.keras import layers, models, Input
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Dense, Flatten, Dropoutdef VGG16(nb_classes, input_shape):# 输入层input_tensor = Input(shape=input_shape)# 卷积层1x = Conv2D(64, (3,3), activation='relu', padding='same',name='block1_conv1')(input_tensor)x = Conv2D(64, (3,3), activation='relu', padding='same',name='block1_conv2')(x)x = MaxPooling2D((2,2), strides=(2,2), name = 'block1_pool')(x)# 卷积层2x = Conv2D(128, (3,3), activation='relu', padding='same',name='block2_conv1')(x)x = Conv2D(128, (3,3), activation='relu', padding='same',name='block2_conv2')(x)x = MaxPooling2D((2,2), strides=(2,2), name = 'block2_pool')(x)# 卷积层3x = Conv2D(256, (3,3), activation='relu', padding='same',name='block3_conv1')(x)x = Conv2D(256, (3,3), activation='relu', padding='same',name='block3_conv2')(x)x = Conv2D(256, (3,3), activation='relu', padding='same',name='block3_conv3')(x)x = MaxPooling2D((2,2), strides=(2,2), name = 'block3_pool')(x)# 卷积层4x = Conv2D(512, (3,3), activation='relu', padding='same',name='block4_conv1')(x)x = Conv2D(512, (3,3), activation='relu', padding='same',name='block4_conv2')(x)x = Conv2D(512, (3,3), activation='relu', padding='same',name='block4_conv3')(x)x = MaxPooling2D((2,2), strides=(2,2), name = 'block4_pool')(x)# 卷积层5x = Conv2D(512, (3,3), activation='relu', padding='same',name='block5_conv1')(x)x = Conv2D(512, (3,3), activation='relu', padding='same',name='block5_conv2')(x)x = Conv2D(512, (3,3), activation='relu', padding='same',name='block5_conv3')(x)x = MaxPooling2D((2,2), strides=(2,2), name = 'block5_pool')(x)# 展平层x = Flatten()(x)# 全连接层1x = Dense(4096, activation='relu',name='fc1')(x)# 全连接层2x = Dense(4096, activation='relu',name='fc2')(x)# 输出层output_tensor = Dense(nb_classes, activation='softmax',name='predictions')(x)# 创建模型model = Model(input_tensor, output_tensor)return model# 创建模型
model=VGG16(len(class_names), (img_width, img_height, 3))# 打印模型结构
model.summary()

在这里插入图片描述

3. 网络结构图

关于卷积的相关知识可以参考文章:https://mtyjkh.blog.csdn.net/article/details/114278995

结构说明:

  • 13个卷积层(Convolutional Layer),分别用blockX_convX表示
  • 3个全连接层(Fully connected Layer),分别用fcX与predictions表示
  • 5个池化层(Pool layer),分别用blockX_pool表示

VGG-16包含了16个隐藏层(13个卷积层和3个全连接层),故称为VGG-16

在这里插入图片描述

四、编译

在准备对模型进行训练之前,还需要再对其进行一些设置。以下内容是在模型的编译步骤中添加的:

  • 损失函数(loss):用于衡量模型在训练期间的准确率。
  • 优化器(optimizer):决定模型如何根据其看到的数据和自身的损失函数进行更新。
  • 指标(metrics):用于监控训练和测试步骤。以下示例使用了准确率,即被正确分类的图像的比率。
# 设置初始学习率
initial_learning_rate = 1e-4lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(initial_learning_rate, decay_steps=30,      # 敲黑板!!!这里是指 steps,不是指epochsdecay_rate=0.92,     # lr经过一次衰减就会变成 decay_rate*lrstaircase=True)# 设置优化器
opt = tf.keras.optimizers.Adam(learning_rate=initial_learning_rate)model.compile(optimizer=opt,loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),metrics=['accuracy'])

五、训练模型

epochs = 20history = model.fit(train_ds,validation_data=val_ds,epochs=epochs
)

在这里插入图片描述

六、可视化结果

acc = history.history['accuracy']
val_acc = history.history['val_accuracy']loss = history.history['loss']
val_loss = history.history['val_loss']epochs_range = range(epochs)plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

在这里插入图片描述

七、个人小结

在本次咖啡豆识别项目中,我们通过设置GPU、导入并预处理数据、构建深度学习模型,以及对模型进行训练和评估,实现了对咖啡豆图像的自动识别。整个过程涵盖了数据加载与可视化、数据集配置、模型构建与优化等关键步骤,最终显著提升了图像分类的准确性,同时也加深了我们对深度学习技术的实践理解。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/18594.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

全球伦敦银收盘时间一致吗

跟伦敦金市场相似,伦敦银市场也是一个全球化的无形市场,无论来自世界上什么地方的投资者参与其中,都可以得到全天接近24个小时的连贯行情,只要精力足够,根本不用担心没有交易获利的机会。但由于交易平台始终有维护的需…

轻松实现PDF文件的在线浏览

福昕软件最近发布了一款名为Cloud API的产品,通过几行代码即可轻松实现PDF文件的在线浏览。先一睹为快吧。 简介 先看看产品官网:福昕 Cloud API Cloud API包括两个形态产品,一个是在线的PDF查看工具,叫PDF Embed API,另外一个…

git 学习(一)

一、版本控制 (一)介绍 版本迭代 每一次更新代码 都会出现新的版本如果我们需要之前的版本的文件 我们就得需要版本控制的文件 每一次更新的结果我们都保存下来 多人开发必须要用版本控制器 否则代价会很大 (二)主流的版本控制…

uniapp开发微信小程序:用户手机号授权获取全流程详解与实战示例

随着多端小程序研发工具的日益普及,诸如uniapp、Taro、Flutter等跨平台解决方案使得开发者能够高效地构建同时适配多个主流小程序平台(如微信、支付宝、百度、字节跳动等)的应用。尽管各平台间存在一定的差异性,但在获取用户手机号…

防火墙技术基础篇:基于Ensp配置防火墙NAT server(服务器映射)

配置防火墙NAT server(服务器映射) 什么是NAT Server (服务器映射) NAT(Network Address Translation,网络地址转换)是一种允许多个设备共享一个公共IP地址的技术。NAT Server,也称为服务器映射,是NAT技术中的一种应…

7 Series FPGAs Integrated Block for PCI Express IP核设计中的物理层控制核状态接口

物理层控制和状态允许用户应用程序根据数据吞吐量和电源需求来更改链路的宽度和速度。 1 Design Considerations for a Directed Link Change 在Directed Link Change(定向链接更改)期间需要注意的事项有: 链接更改操作(Link c…

C++STL---模拟实现string

我们这篇文章进行string的模拟实现。 为了防止标准库和我们自己写的string类发生命名冲突,我们将我们自己写的string类放在我们自己的命名空间中: 我们先来搭一个class string的框架: namespace CYF{ public://各种成员函数 priva…

基于单片机智能防触电装置的研究与设计

摘 要 : 针对潮湿天气下配电线路附近易发生触电事故等问题 , 对单片机的控制算法进行了研究 , 设 计 了 一 种 基 于 单片机的野外智能防触电装置。 首先建立了该装置的整体结构框架 , 再分别进行硬件设计和软件流程分析 &#xf…

IDEA升级web项目为maven项目乱码

今天将一个java web项目改造为maven项目。 首先&#xff0c;创建一个新的maven项目&#xff0c;将文件拷贝到新项目中。 其次&#xff0c;将旧项目的jar包&#xff0c;在maven的pom.xml做成依赖 接着&#xff0c;把没有maven坐标的jar包在编译的时候也包含进来 <build>…

实战教程:使用Go的net/http/fcgi包打造高性能Web应用

实战教程&#xff1a;使用Go的net/http/fcgi包打造高性能Web应用 简介理解FCGI协议FastCGI工作原理CGI与FastCGI对比其他接口对比 初步使用net/http/fcgi包设置和配置基础环境一个简单的FastCGI应用示例本地测试与调试 深入net/http/fcgi包的高级用法探索net/http/fcgi的主要功…

气膜建筑的运行保障:应对停电的解决方案—轻空间

气膜建筑作为一种现代化的建筑形式&#xff0c;以其独特的结构和多样的应用赢得了广泛关注。这种建筑依靠风机不断往内部吹气来维持其结构形态&#xff0c;那么如果遇到停电的情况&#xff0c;该如何确保其正常运行呢&#xff1f; 气膜建筑的供风系统 气膜建筑内部的气压维持依…

信创崛起:从安可到国产化,中国信息技术创新之路

在信息技术迅速演进的时代背景下&#xff0c;几个核心概念日益凸显其重要性&#xff1a;安全可靠&#xff08;简称安可&#xff09;、信息技术创新&#xff08;简称信创&#xff09;&#xff0c;以及国产化。这些概念紧密关联&#xff0c;共同服务于一个宏伟目标——构建一个独…

MFC 发起 HTTP Post 请求 发送MES消息

文章目录 获取Token将获取的Token写入JSON文件 将测试参数发送到http首先将测试参数写入到TestData.JSON文件rapidjson 库需要将CString 进行类型转换才能使用&#xff0c;将CString 转换为const char* 发送JSON 参数到http中&#xff0c;并且获取返回结果写入TestFinish.JSON文…

SpringSecurity6从入门到实战之SpringSecurity快速入门

SpringSecurity6从入门到实战之SpringSecurity快速入门 环境准备 依赖版本号springsecurity6.0.8springboot3.0.12JDK17 这里尽量与我依赖一致,免得在学习过程中出现位置的bug等 创建工程 这里直接选择springboot初始化快速搭建工程,导入对应的jdk17进行创建 直接勾选一个web…

Redhat9 LAMP安全配置方案及测试

目录 数据库主机 安装Mariadb数据库服务 设置mariadb开机自动启动 Php主机 部署Apache服务器 设置apache服务开机自启 安装php 安装 phpMyAdmin 打开测试机 更新软件包列表&#xff1a; 首先&#xff0c;确保你的软件包列表是最新的。打开终端并输入以下命令&#xf…

Linux查看设备信息命令

dmidecode | grep Product Name 查看grub版本号&#xff1a;rpm -qa | grep -i "grub" 客户端操作系统版本&#xff1a; cat /etc/issue cat /etc/redhat-release 处理器品牌及型号&#xff1a; less /proc/cpuinfo |grep model

【TCP协议中104解析】wireshark抓取流量包工具,群殴协议解析基础

Tcp ,104 ,wireshark工具进行解析 IEC104 是用于监控和诊断工业控制网络的一种标准&#xff0c;而 Wireshark则是一款常用的网络协议分析工具&#xff0c;可以用干解析TEC104 报文。本文将介绍如何使用 Wireshark解析 IEC104报文&#xff0c;以及解析过 程中的注意事项。 一、安…

AI图书推荐:用ChatGPT和Python搭建AI应用来变现

《用ChatGPT和Python搭建AI应用来变现》&#xff08;Building AI Applications with ChatGPT API&#xff09;将ChatGPT API与Python结合使用&#xff0c;可以开启构建非凡AI应用的大门。通过利用这些API&#xff0c;你可以专注于应用逻辑和用户体验&#xff0c;而ChatGPT强大的…

Axios的使用简单说明

axios 请求方式和参数 axios 可以发送 ajax 请求&#xff0c;不同的方法可以发送不同的请求: axios.get&#xff1a;发送get请求 axios.post&#xff1a;发送post请求 axios.put&#xff1a;发送put请求 axios.delete&#xff1a;发送delete请求 无论哪种方法&#xff0c;第一…

【2】:向量与矩阵

向量 既有大小又有方向的量叫做向量 向量的模 向量的长度 单位向量 (只表示方向不表示长度) 向量的加减运算 向量求和 行向量与列向量的置换 图形学中竖着写 向量的长度计算 点乘&#xff08;计算向量间夹角&#xff09; 点乘满足的运算规律 交换律、结合律、分配…