第T7周:Tensorflow实现咖啡豆识别

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊

目标

具体实现

(一)环境

语言环境:Python 3.10
编 译 器: PyCharm
框 架:

(二)具体步骤
1. 使用GPU
--------------------------utils.py-------------------
import tensorflow as tf  
import PIL  
import matplotlib.pyplot as plt  def GPU_ON():  # 查询tensorflow版本  print("Tensorflow Version:", tf.__version__)  # 设置使用GPU  gpus = tf.config.list_physical_devices("GPU")  print(gpus)  if gpus:  gpu0 = gpus[0]  # 如果有多个GPU,仅使用第0个GPU  tf.config.experimental.set_memory_growth(gpu0, True)  # 设置GPU显存按需使用  tf.config.set_visible_devices([gpu0], "GPU")

使用GPU并查看数据

import tensorflow as tf  
from tensorflow import keras  
import numpy as np  
import matplotlib.pyplot as plt  
import os, PIL, pathlib  
from utils import GPU_ONGPU_ON()  data_dir = "./datasets/coffee/"  
data_dir = pathlib.Path(data_dir)  image_count = len(list(data_dir.glob("*/*.png")))  
print("图片总数量为:", image_count)
------------------
图片总数量为: 1200
2. 加载数据
# 加载数据  
batch_size = 32  
img_height, img_width = 224, 224  train_ds = tf.keras.preprocessing.image_dataset_from_directory(  data_dir,  validation_split=0.2,  subset="training",  seed=123,  image_size=(img_height, img_width),  batch_size=batch_size,  
)  val_ds = tf.keras.preprocessing.image_dataset_from_directory(  data_dir,  validation_split=0.2,  subset="validation",  seed=123,  image_size=(img_height, img_width),  batch_size=batch_size,  
)
--------------------
Found 1200 files belonging to 4 classes.
Using 960 files for training.
Found 1200 files belonging to 4 classes.
Using 240 files for validation.

获取标签:

# 获取标签  
class_names = train_ds.class_names  
print(class_names)
------------------
['Dark', 'Green', 'Light', 'Medium']

可视化数据:

# 可视化数据  
plt.figure(figsize=(10, 10))  
for images, labels in train_ds.take(2):  for i in range(30):  ax = plt.subplot(5, 6, i+1)  plt.imshow(images[i].numpy().astype("uint8"))  plt.title(class_names[labels[i]])  plt.axis("off")  
plt.show()


检查一下数据:

# 检查一下数据  
for image_batch, labels_batch in train_ds:  print(image_batch.shape)  print(labels_batch.shape)  break
----------------------------
(32, 224, 224, 3)
(32,)
**3.**配置数据集
# 配置数据集  
AUTOTUNE = tf.data.AUTOTUNE  
train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)  
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)  normalization_layer = layers.experimental.preprocessing.Rescaling(1./255)  
train_ds = train_ds.map(lambda x, y: (normalization_layer(x), y))  
val_ds = val_ds.map(lambda x, y: (normalization_layer(x), y))  image_batch, labels_batch = next(iter(train_ds))  
first_image = image_batch[0]  # 查看归一化后的数据  
print(np.min(first_image), np.max(first_image))
--------------------
0.0 1.0
4.搭建VGG-16网络

本次准备直接调用官方模型

# 搭建VGG-16网络模型  
model = tf.keras.applications.VGG16(weights="imagenet")  
print(model.summary())
-------------------------------
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/vgg16/vgg16_weights_tf_dim_ordering_tf_kernels.h5
553467096/553467096 [==============================] - 14s 0us/step
Model: "vgg16"
_________________________________________________________________Layer (type)                Output Shape              Param #   
=================================================================input_1 (InputLayer)        [(None, 224, 224, 3)]     0         block1_conv1 (Conv2D)       (None, 224, 224, 64)      1792      block1_conv2 (Conv2D)       (None, 224, 224, 64)      36928     block1_pool (MaxPooling2D)  (None, 112, 112, 64)      0         block2_conv1 (Conv2D)       (None, 112, 112, 128)     73856     block2_conv2 (Conv2D)       (None, 112, 112, 128)     147584    block2_pool (MaxPooling2D)  (None, 56, 56, 128)       0         block3_conv1 (Conv2D)       (None, 56, 56, 256)       295168    block3_conv2 (Conv2D)       (None, 56, 56, 256)       590080    block3_conv3 (Conv2D)       (None, 56, 56, 256)       590080    block3_pool (MaxPooling2D)  (None, 28, 28, 256)       0         block4_conv1 (Conv2D)       (None, 28, 28, 512)       1180160   block4_conv2 (Conv2D)       (None, 28, 28, 512)       2359808   block4_conv3 (Conv2D)       (None, 28, 28, 512)       2359808   block4_pool (MaxPooling2D)  (None, 14, 14, 512)       0         block5_conv1 (Conv2D)       (None, 14, 14, 512)       2359808   block5_conv2 (Conv2D)       (None, 14, 14, 512)       2359808   block5_conv3 (Conv2D)       (None, 14, 14, 512)       2359808   block5_pool (MaxPooling2D)  (None, 7, 7, 512)         0         flatten (Flatten)           (None, 25088)             0         fc1 (Dense)                 (None, 4096)              102764544 fc2 (Dense)                 (None, 4096)              16781312  predictions (Dense)         (None, 1000)              4097000   =================================================================
Total params: 138,357,544
Trainable params: 138,357,544
Non-trainable params: 0
_________________________________________________________________

简简单单1亿的参数的模型。哈哈。

编译一下:

# 编译模型  
# 设置初始学习率  
initial_learning_rate = 1e-4  
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(  initial_learning_rate=initial_learning_rate,  decay_steps=30,  decay_rate=0.92,  staircase=True  
)  # 设置优化器  
opt = tf.keras.optimizers.Adam(learning_rate=initial_learning_rate)  model.compile(  optimizer=opt,  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),  metrics=['accuracy']  
)

训练模型:

# 训练模型  
epochs = 20  
history = model.fit(  train_ds,  validation_data=val_ds,  epochs=epochs,  
)

image.png
训练效果不错,可视化看看:

acc = history.history['accuracy']
val_acc = history.history['val_accuracy']loss = history.history['loss']
val_loss = history.history['val_loss']epochs_range = range(epochs)plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

image.png
果然超赞。
改成动态学习率的结果:

opt = tf.keras.optimizers.Adam(learning_rate=lr_schedule)

image.png

5. 手动搭建VGG-16模型

image.png
image.png
image.png
VGG-16的网络 有13个卷积层(被5个max-pooling层分割)和3个全连接层(FC),所有卷积层过滤器的大小都是3X3,步长为1,进行padding。5个max-pooling层分别在第2、4、7、10,13卷积层后面。每次进行池化(max-pooling)后,特征图的长宽都缩小一半,但是channel都翻倍了,一直到512。最后三个全连接层大小分别是4096,4096, 1000,我们使用的是咖啡豆识别,根据数据集的类别数量修改最后的分类数量(即从1000改成len(class_names))


-----------------------------
Model: "model"
_________________________________________________________________Layer (type)                Output Shape              Param #   
=================================================================input_1 (InputLayer)        [(None, 224, 224, 3)]     0         block1_conv1 (Conv2D)       (None, 224, 224, 64)      1792      block1_conv2 (Conv2D)       (None, 224, 224, 64)      36928     block1_pool (MaxPooling2D)  (None, 112, 112, 64)      0         block2_conv1 (Conv2D)       (None, 112, 112, 128)     73856     block2_conv2 (Conv2D)       (None, 112, 112, 128)     147584    block2_pool (MaxPooling2D)  (None, 56, 56, 128)       0         block3_conv1 (Conv2D)       (None, 56, 56, 256)       295168    block3_conv2 (Conv2D)       (None, 56, 56, 256)       590080    block3_conv3 (Conv2D)       (None, 56, 56, 256)       590080    block3_pool (MaxPooling2D)  (None, 28, 28, 256)       0         block4_conv1 (Conv2D)       (None, 28, 28, 512)       1180160   block4_conv2 (Conv2D)       (None, 28, 28, 512)       2359808   block4_conv3 (Conv2D)       (None, 28, 28, 512)       2359808   block4_pool (MaxPooling2D)  (None, 14, 14, 512)       0         block5_conv1 (Conv2D)       (None, 14, 14, 512)       2359808   block5_conv2 (Conv2D)       (None, 14, 14, 512)       2359808   block5_conv3 (Conv2D)       (None, 14, 14, 512)       2359808   block5_pool (MaxPooling2D)  (None, 7, 7, 512)         0         flatten (Flatten)           (None, 25088)             0         fc1 (Dense)                 (None, 4096)              102764544 fc2 (Dense)                 (None, 4096)              16781312  predictions (Dense)         (None, 4)                 16388     =================================================================
Total params: 134,276,932
Trainable params: 134,276,932
Non-trainable params: 0
_________________________________________________________________

image.png

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/59507.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

vue2项目中在线预览csv文件

简介 希望在项目中,在线预览.csv文件,本以为插件很多,结果都只是支持excel(.xls、.xlsx)一到.csv就歇菜。。。 关于文件预览 vue-office:文档、 查看在线演示demo,支持docx、.xlsx、pdf、ppt…

【Excel】身份证号最后一位“X”怎么计算

大多数人身份证号最后一位都是数字,但有个别号码最后一位却是“X"。 如果你查百度,会得到如下答案: 当最后一位编码是10的时候,因为多出一位,所以就用X替换。 可大多数人不知道的是,这个10是怎么来的…

【HAProxy09】企业级反向代理HAProxy高级功能之压缩功能与后端服务器健康性监测

HAProxy 高级功能 介绍 HAProxy 高级配置及实用案例 压缩功能 对响应给客户端的报文进行压缩,以节省网络带宽,但是会占用部分CPU性能 建议在后端服务器开启压缩功能,而非在HAProxy上开启压缩 注意:默认Ubuntu的包安装nginx开…

【Java Web】JSON 以及 JSON 转换

JSON(JavaScript Object Notation)一种灵活、高效、轻量级的数据交换格式,广泛应用于各种数据交换和存储场景。 基本特点 1、简单易用:JSON格式非常简单,易于理解和使用。 2、轻量级:相比XML等其他数据格…

第四十一章 Vue之初识VueX

目录 一、引言 1.1. vuex的概念 1.2. vuex使用场景 1.3. 优势 二、创建演示项目 2.1. 构建项目步骤 2.2. 项目最终生成结构 2.3. 创建项目文件 2.3.1. App.vue 2.3.2. Son1.vue 2.3.3. Son2.vue 三、创建一个空仓库 3.1. 安装vuex 3.2. 新建仓库 3.3. 挂载仓库…

gitlab-development-kit部署gitlab《二》

gitlab-development-kit部署gitlab《一》 环境 mac 12.7.4 xcode 14.2 gdk 0.2.16 gitlab-foss 13.7 QA xcode源码安装 # https://crifan.github.io/xcode_dev_summary/website/xcode_dev/install_xcode/ # https://xcodereleases.comopenssl1.1 源码安装 # https://open…

编程之路,从0开始:内存函数

Hello大家好!很高兴我们又见面了。 给生活添点passion,开始今天的编程之路! 今天我们来讲C语言中的内存函数。 目录 1、memcpy内存复制 2、memmove可重叠内存拷贝 3、memset设置字符 4、memcmp比较 1、memcpy内存复制 memcpy就是内存复制…

【C语言】值传递和地址传递

值传递 引用传递(传地址,传引用)的区别 传值,是把实参的值赋值给行参 ,那么对行参的修改,不会影响实参的值。 传地址,是传值的一种特殊方式,只是他传递的是地址,不是普通…

摘要与登记

10.15:mysql 10.16:redis, 10.17:k8s,netty,dubbo,设计模式 10.18:juc、 10.21:rabbitMQ、ElasticSearch 10.22:docker 10.23:k8s 10.24:springsecurity 10.30:spring事务 11.01:mysql 11.05:redis 11.06:k8s 11.07:netty、docker 11.08:设计模式 11.09:juc 11.11:rabbitMQ、sp…

Springboot采用jasypt加密配置

目录 前言 一、Jasypt简介 二、运用场景 三、整合Jasypt 2.1.环境配置 2.2.添加依赖 2.3.添加Jasypt配置 2.4.编写加/解密工具类 2.5.自定义加密属性前缀和后缀 2.6.防止密码泄露措施 2.61.自定义加密器 2.6.2通过环境变量指定加密盐值 总结 前言 在以往的多数项目中&#xff0…

axios平替!用浏览器自带的fetch处理AJAX(兼容表单/JSON/文件上传)

fetch 是啥? fetch 函数是 JavaScript 中用于发送网络请求的内置 API,可以替代传统的 XMLHttpRequest。它可以发送 HTTP 请求(如 GET、POST 等),并返回一个 Promise,从而简化异步操作 基本用法 /* 下面是…

贪吃蛇小游戏设计

贪吃蛇小游戏 1.引言1.1 背景1.2 目的1.3 意义1.4 任务1.5 技术可行性分析1.5.1执行平台1.5.2 语言特性与功能方面 2.需求分析2.1 环境需求2.2开发环境分析2.3游戏功能分析2.4 游戏性能分析2.5 数据流图2.6 数据字典 3.概要设计3.1 设计思路3.2 游戏界面设计3.3 总设计模块的划…

go T 泛型

目录 1、类型约束 2、泛型函数 3、泛型结构体 4、泛型接口 5、以接口作为类型约束 关键词:泛型、类型参数、类型约束 Go 语言在 1.18 版本引入了泛型(Generics)特性,可以编写更通用、可复用的代码,泛型可以用于&a…

如何处理 iOS 客户端内 Webview H5 中后台播放的音视频问题

目录 问题描述Page Visibility API 的应用什么是 Page Visibility API?使用 Page Visibility API 暂停音视频完整解决方案1. 监听媒体的播放和暂停事件2. 防止自动播放3. 结合 Intersection Observer 进行媒体控制4. 手动处理应用生命周期中的事件 问题描述 在 iOS…

Matplotlib库中show()函数的用法

在Matplotlib库中使用show()函数是用于显示绘制的图形的函数。它将图形显示在屏幕上或保存到文件中。show()函数通常在绘制完图形后调用。 Matplotlib是一个用于绘制2D图形的Python库,它提供了丰富的绘图工具和函数,可以用于创建各种类型的图表&#xf…

DNS面临的4大类共计11小类安全风险及防御措施

DNS在设计之初,并未考虑网络安全限制,导致了许多问题。DNS安全扩展(DNSSEC)协议的开发旨在解决DNS的安全漏洞,但其部署并不广泛,DNS仍面临各种攻击。接下来我们一起看下DNS都存在哪些安全攻击及缓解措施,旨在对DNS安全…

蓝队知识浅谈(中)

声明:学习视频来自b站up主 泷羽sec,如涉及侵权马上删除文章 感谢泷羽sec 团队的教学 视频地址:蓝队基础之网络七层杀伤链_哔哩哔哩_bilibili 本文主要分享一些蓝队相关的知识。 一、网络杀伤链 网络杀伤链(Cyber Kill Chain&…

vue2在el-dialog打开的时候使该el-dialog中的某个输入框获得焦点方法总结

在 Vue 2 中,如果你想通过 ref 调用一个方法(如 inputFocus)来聚焦一个输入框,确保以下几点: 确保 ref 的设置正确:你需要确保在模板中正确设置了 ref,并且它指向了你想要操作的组件或 DOM 元素…

【大数据学习 | flume】flume的概述与组件的介绍

1. flume概述 Flume是cloudera(CDH版本的hadoop) 开发的一个分布式、可靠、高可用的海量日志收集系统。它将各个服务器中的数据收集起来并送到指定的地方去,比如说送到HDFS、Hbase,简单来说flume就是收集日志的。 Flume两个版本区别: ​ 1&…

Jmeter中的定时器(一)

定时器 1--固定定时器 功能特点 固定延迟:在每个请求之间添加固定的延迟时间。精确控制:可以精确控制请求的发送频率。简单易用:配置简单,易于理解和使用。 配置步骤 添加固定定时器 右键点击需要添加定时器的请求或线程组。选…