第T7周:Tensorflow实现咖啡豆识别

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊

目标

具体实现

(一)环境

语言环境:Python 3.10
编 译 器: PyCharm
框 架:

(二)具体步骤
1. 使用GPU
--------------------------utils.py-------------------
import tensorflow as tf  
import PIL  
import matplotlib.pyplot as plt  def GPU_ON():  # 查询tensorflow版本  print("Tensorflow Version:", tf.__version__)  # 设置使用GPU  gpus = tf.config.list_physical_devices("GPU")  print(gpus)  if gpus:  gpu0 = gpus[0]  # 如果有多个GPU,仅使用第0个GPU  tf.config.experimental.set_memory_growth(gpu0, True)  # 设置GPU显存按需使用  tf.config.set_visible_devices([gpu0], "GPU")

使用GPU并查看数据

import tensorflow as tf  
from tensorflow import keras  
import numpy as np  
import matplotlib.pyplot as plt  
import os, PIL, pathlib  
from utils import GPU_ONGPU_ON()  data_dir = "./datasets/coffee/"  
data_dir = pathlib.Path(data_dir)  image_count = len(list(data_dir.glob("*/*.png")))  
print("图片总数量为:", image_count)
------------------
图片总数量为: 1200
2. 加载数据
# 加载数据  
batch_size = 32  
img_height, img_width = 224, 224  train_ds = tf.keras.preprocessing.image_dataset_from_directory(  data_dir,  validation_split=0.2,  subset="training",  seed=123,  image_size=(img_height, img_width),  batch_size=batch_size,  
)  val_ds = tf.keras.preprocessing.image_dataset_from_directory(  data_dir,  validation_split=0.2,  subset="validation",  seed=123,  image_size=(img_height, img_width),  batch_size=batch_size,  
)
--------------------
Found 1200 files belonging to 4 classes.
Using 960 files for training.
Found 1200 files belonging to 4 classes.
Using 240 files for validation.

获取标签:

# 获取标签  
class_names = train_ds.class_names  
print(class_names)
------------------
['Dark', 'Green', 'Light', 'Medium']

可视化数据:

# 可视化数据  
plt.figure(figsize=(10, 10))  
for images, labels in train_ds.take(2):  for i in range(30):  ax = plt.subplot(5, 6, i+1)  plt.imshow(images[i].numpy().astype("uint8"))  plt.title(class_names[labels[i]])  plt.axis("off")  
plt.show()


检查一下数据:

# 检查一下数据  
for image_batch, labels_batch in train_ds:  print(image_batch.shape)  print(labels_batch.shape)  break
----------------------------
(32, 224, 224, 3)
(32,)
**3.**配置数据集
# 配置数据集  
AUTOTUNE = tf.data.AUTOTUNE  
train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)  
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)  normalization_layer = layers.experimental.preprocessing.Rescaling(1./255)  
train_ds = train_ds.map(lambda x, y: (normalization_layer(x), y))  
val_ds = val_ds.map(lambda x, y: (normalization_layer(x), y))  image_batch, labels_batch = next(iter(train_ds))  
first_image = image_batch[0]  # 查看归一化后的数据  
print(np.min(first_image), np.max(first_image))
--------------------
0.0 1.0
4.搭建VGG-16网络

本次准备直接调用官方模型

# 搭建VGG-16网络模型  
model = tf.keras.applications.VGG16(weights="imagenet")  
print(model.summary())
-------------------------------
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/vgg16/vgg16_weights_tf_dim_ordering_tf_kernels.h5
553467096/553467096 [==============================] - 14s 0us/step
Model: "vgg16"
_________________________________________________________________Layer (type)                Output Shape              Param #   
=================================================================input_1 (InputLayer)        [(None, 224, 224, 3)]     0         block1_conv1 (Conv2D)       (None, 224, 224, 64)      1792      block1_conv2 (Conv2D)       (None, 224, 224, 64)      36928     block1_pool (MaxPooling2D)  (None, 112, 112, 64)      0         block2_conv1 (Conv2D)       (None, 112, 112, 128)     73856     block2_conv2 (Conv2D)       (None, 112, 112, 128)     147584    block2_pool (MaxPooling2D)  (None, 56, 56, 128)       0         block3_conv1 (Conv2D)       (None, 56, 56, 256)       295168    block3_conv2 (Conv2D)       (None, 56, 56, 256)       590080    block3_conv3 (Conv2D)       (None, 56, 56, 256)       590080    block3_pool (MaxPooling2D)  (None, 28, 28, 256)       0         block4_conv1 (Conv2D)       (None, 28, 28, 512)       1180160   block4_conv2 (Conv2D)       (None, 28, 28, 512)       2359808   block4_conv3 (Conv2D)       (None, 28, 28, 512)       2359808   block4_pool (MaxPooling2D)  (None, 14, 14, 512)       0         block5_conv1 (Conv2D)       (None, 14, 14, 512)       2359808   block5_conv2 (Conv2D)       (None, 14, 14, 512)       2359808   block5_conv3 (Conv2D)       (None, 14, 14, 512)       2359808   block5_pool (MaxPooling2D)  (None, 7, 7, 512)         0         flatten (Flatten)           (None, 25088)             0         fc1 (Dense)                 (None, 4096)              102764544 fc2 (Dense)                 (None, 4096)              16781312  predictions (Dense)         (None, 1000)              4097000   =================================================================
Total params: 138,357,544
Trainable params: 138,357,544
Non-trainable params: 0
_________________________________________________________________

简简单单1亿的参数的模型。哈哈。

编译一下:

# 编译模型  
# 设置初始学习率  
initial_learning_rate = 1e-4  
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(  initial_learning_rate=initial_learning_rate,  decay_steps=30,  decay_rate=0.92,  staircase=True  
)  # 设置优化器  
opt = tf.keras.optimizers.Adam(learning_rate=initial_learning_rate)  model.compile(  optimizer=opt,  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),  metrics=['accuracy']  
)

训练模型:

# 训练模型  
epochs = 20  
history = model.fit(  train_ds,  validation_data=val_ds,  epochs=epochs,  
)

image.png
训练效果不错,可视化看看:

acc = history.history['accuracy']
val_acc = history.history['val_accuracy']loss = history.history['loss']
val_loss = history.history['val_loss']epochs_range = range(epochs)plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

image.png
果然超赞。
改成动态学习率的结果:

opt = tf.keras.optimizers.Adam(learning_rate=lr_schedule)

image.png

5. 手动搭建VGG-16模型

image.png
image.png
image.png
VGG-16的网络 有13个卷积层(被5个max-pooling层分割)和3个全连接层(FC),所有卷积层过滤器的大小都是3X3,步长为1,进行padding。5个max-pooling层分别在第2、4、7、10,13卷积层后面。每次进行池化(max-pooling)后,特征图的长宽都缩小一半,但是channel都翻倍了,一直到512。最后三个全连接层大小分别是4096,4096, 1000,我们使用的是咖啡豆识别,根据数据集的类别数量修改最后的分类数量(即从1000改成len(class_names))


-----------------------------
Model: "model"
_________________________________________________________________Layer (type)                Output Shape              Param #   
=================================================================input_1 (InputLayer)        [(None, 224, 224, 3)]     0         block1_conv1 (Conv2D)       (None, 224, 224, 64)      1792      block1_conv2 (Conv2D)       (None, 224, 224, 64)      36928     block1_pool (MaxPooling2D)  (None, 112, 112, 64)      0         block2_conv1 (Conv2D)       (None, 112, 112, 128)     73856     block2_conv2 (Conv2D)       (None, 112, 112, 128)     147584    block2_pool (MaxPooling2D)  (None, 56, 56, 128)       0         block3_conv1 (Conv2D)       (None, 56, 56, 256)       295168    block3_conv2 (Conv2D)       (None, 56, 56, 256)       590080    block3_conv3 (Conv2D)       (None, 56, 56, 256)       590080    block3_pool (MaxPooling2D)  (None, 28, 28, 256)       0         block4_conv1 (Conv2D)       (None, 28, 28, 512)       1180160   block4_conv2 (Conv2D)       (None, 28, 28, 512)       2359808   block4_conv3 (Conv2D)       (None, 28, 28, 512)       2359808   block4_pool (MaxPooling2D)  (None, 14, 14, 512)       0         block5_conv1 (Conv2D)       (None, 14, 14, 512)       2359808   block5_conv2 (Conv2D)       (None, 14, 14, 512)       2359808   block5_conv3 (Conv2D)       (None, 14, 14, 512)       2359808   block5_pool (MaxPooling2D)  (None, 7, 7, 512)         0         flatten (Flatten)           (None, 25088)             0         fc1 (Dense)                 (None, 4096)              102764544 fc2 (Dense)                 (None, 4096)              16781312  predictions (Dense)         (None, 4)                 16388     =================================================================
Total params: 134,276,932
Trainable params: 134,276,932
Non-trainable params: 0
_________________________________________________________________

image.png

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/59507.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

vue2项目中在线预览csv文件

简介 希望在项目中,在线预览.csv文件,本以为插件很多,结果都只是支持excel(.xls、.xlsx)一到.csv就歇菜。。。 关于文件预览 vue-office:文档、 查看在线演示demo,支持docx、.xlsx、pdf、ppt…

【Excel】身份证号最后一位“X”怎么计算

大多数人身份证号最后一位都是数字,但有个别号码最后一位却是“X"。 如果你查百度,会得到如下答案: 当最后一位编码是10的时候,因为多出一位,所以就用X替换。 可大多数人不知道的是,这个10是怎么来的…

【HAProxy09】企业级反向代理HAProxy高级功能之压缩功能与后端服务器健康性监测

HAProxy 高级功能 介绍 HAProxy 高级配置及实用案例 压缩功能 对响应给客户端的报文进行压缩,以节省网络带宽,但是会占用部分CPU性能 建议在后端服务器开启压缩功能,而非在HAProxy上开启压缩 注意:默认Ubuntu的包安装nginx开…

【Java Web】JSON 以及 JSON 转换

JSON(JavaScript Object Notation)一种灵活、高效、轻量级的数据交换格式,广泛应用于各种数据交换和存储场景。 基本特点 1、简单易用:JSON格式非常简单,易于理解和使用。 2、轻量级:相比XML等其他数据格…

第四十一章 Vue之初识VueX

目录 一、引言 1.1. vuex的概念 1.2. vuex使用场景 1.3. 优势 二、创建演示项目 2.1. 构建项目步骤 2.2. 项目最终生成结构 2.3. 创建项目文件 2.3.1. App.vue 2.3.2. Son1.vue 2.3.3. Son2.vue 三、创建一个空仓库 3.1. 安装vuex 3.2. 新建仓库 3.3. 挂载仓库…

编程之路,从0开始:内存函数

Hello大家好!很高兴我们又见面了。 给生活添点passion,开始今天的编程之路! 今天我们来讲C语言中的内存函数。 目录 1、memcpy内存复制 2、memmove可重叠内存拷贝 3、memset设置字符 4、memcmp比较 1、memcpy内存复制 memcpy就是内存复制…

【C语言】值传递和地址传递

值传递 引用传递(传地址,传引用)的区别 传值,是把实参的值赋值给行参 ,那么对行参的修改,不会影响实参的值。 传地址,是传值的一种特殊方式,只是他传递的是地址,不是普通…

Springboot采用jasypt加密配置

目录 前言 一、Jasypt简介 二、运用场景 三、整合Jasypt 2.1.环境配置 2.2.添加依赖 2.3.添加Jasypt配置 2.4.编写加/解密工具类 2.5.自定义加密属性前缀和后缀 2.6.防止密码泄露措施 2.61.自定义加密器 2.6.2通过环境变量指定加密盐值 总结 前言 在以往的多数项目中&#xff0…

axios平替!用浏览器自带的fetch处理AJAX(兼容表单/JSON/文件上传)

fetch 是啥? fetch 函数是 JavaScript 中用于发送网络请求的内置 API,可以替代传统的 XMLHttpRequest。它可以发送 HTTP 请求(如 GET、POST 等),并返回一个 Promise,从而简化异步操作 基本用法 /* 下面是…

贪吃蛇小游戏设计

贪吃蛇小游戏 1.引言1.1 背景1.2 目的1.3 意义1.4 任务1.5 技术可行性分析1.5.1执行平台1.5.2 语言特性与功能方面 2.需求分析2.1 环境需求2.2开发环境分析2.3游戏功能分析2.4 游戏性能分析2.5 数据流图2.6 数据字典 3.概要设计3.1 设计思路3.2 游戏界面设计3.3 总设计模块的划…

DNS面临的4大类共计11小类安全风险及防御措施

DNS在设计之初,并未考虑网络安全限制,导致了许多问题。DNS安全扩展(DNSSEC)协议的开发旨在解决DNS的安全漏洞,但其部署并不广泛,DNS仍面临各种攻击。接下来我们一起看下DNS都存在哪些安全攻击及缓解措施,旨在对DNS安全…

【大数据学习 | flume】flume的概述与组件的介绍

1. flume概述 Flume是cloudera(CDH版本的hadoop) 开发的一个分布式、可靠、高可用的海量日志收集系统。它将各个服务器中的数据收集起来并送到指定的地方去,比如说送到HDFS、Hbase,简单来说flume就是收集日志的。 Flume两个版本区别: ​ 1&…

Jmeter中的定时器(一)

定时器 1--固定定时器 功能特点 固定延迟:在每个请求之间添加固定的延迟时间。精确控制:可以精确控制请求的发送频率。简单易用:配置简单,易于理解和使用。 配置步骤 添加固定定时器 右键点击需要添加定时器的请求或线程组。选…

区块链技术在慈善捐赠中的应用

💓 博客主页:瑕疵的CSDN主页 📝 Gitee主页:瑕疵的gitee主页 ⏩ 文章专栏:《热点资讯》 区块链技术在慈善捐赠中的应用 区块链技术在慈善捐赠中的应用 区块链技术在慈善捐赠中的应用 引言 区块链技术概述 定义与原理 发…

[数组二分查找] 0074. 搜索二维矩阵

文章目录 1. 题目链接2. 题目大意3. 示例4. 解题思路5. 参考代码 1. 题目链接 74. 搜索二维矩阵 - 力扣(LeetCode) 2. 题目大意 描述:给定一个 mn 大小的有序二维矩阵 matrix。矩阵中每行元素从左到右升序排列,每列元素从上到下…

使用 Python 脚本在 Ansys Mechanical 中创建用于后处理的螺栓工具

介绍 由螺栓连接定义的接头在工业应用中非常普遍。在 Ansys Mechanical FEA 中分析它们是一种非常常见的做法。通过Object Generator或Bolt Tools Add-on,使用线体、梁连接甚至3D实体中的梁单元,在Ansys Mechanical中生成螺栓连接非常容易。定义螺栓联接…

【AI声音克隆整合包及教程】第二代GPT-SoVITS V2:创新与应用

一、引言 随着科技的迅猛发展,声音克隆技术已经成为一个炙手可热的研究领域。SoVITS(Sound Voice Intelligent Transfer System),作为该领域的先锋,凭借其卓越的性能和广泛的适用性,正在为多个行业带来前所…

python调用MySql详细步骤

一、下载MySql MySQL :: Download MySQL Installerhttps://dev.mysql.com/downloads/windows/installer/8.0.html点击上面链接,进入MySQL8.0的下载页面,选择离线安装包下载。 不需要登陆,直接点击下方的 No thanks,just start my download. …

《InsCode AI IDE:编程新时代的引领者》

《InsCode AI IDE:编程新时代的引领者》 一、InsCode AI IDE 的诞生与亮相二、独特功能与优势(一)智能编程体验(二)多语言支持与功能迭代 三、实际应用与案例(一)游戏开发案例(二&am…

华为路由策略配置

一、AS_Path过滤 要求: AR1与AR2、AR2与AR3之间建立EBGP连接 AS10的设备和AS30的设备无法相互通信 1.启动设备 2.配置IP地址 3.配置路由器的EBGP对等体连接,引入直连路由 [AR1]bgp 10 [AR1-bgp]router-id 1.1.1.1 [AR1-bgp]peer 200.1.2.2 as-nu…