卷积神经网络(CNN)注意力检测

文章目录

  • 一、前言
  • 二、前期工作
    • 1. 设置GPU(如果使用的是CPU可以忽略这步)
    • 2. 导入数据
    • 3. 查看数据
  • 二、数据预处理
    • 1.加载数据
    • 2. 可视化数据
    • 4. 配置数据集
  • 三、调用官方网络模型
  • 四、设置动态学习率
  • 五、编译
  • 六、训练模型
  • 七、模型评估
    • 1. Accuracy与Loss图
    • 2. 混淆矩阵
  • 八、保存and加载模型
  • 九、预测

一、前言

我的环境:

  • 语言环境:Python3.6.5
  • 编译器:jupyter notebook
  • 深度学习环境:TensorFlow2.4.1

往期精彩内容:

  • 卷积神经网络(CNN)实现mnist手写数字识别
  • 卷积神经网络(CNN)多种图片分类的实现
  • 卷积神经网络(CNN)衣服图像分类的实现
  • 卷积神经网络(CNN)鲜花识别
  • 卷积神经网络(CNN)天气识别
  • 卷积神经网络(VGG-16)识别海贼王草帽一伙
  • 卷积神经网络(ResNet-50)鸟类识别
  • 卷积神经网络(AlexNet)鸟类识别
  • 卷积神经网络(CNN)识别验证码
  • 卷积神经网络(CNN)车牌识别

来自专栏:机器学习与深度学习算法推荐

二、前期工作

1. 设置GPU(如果使用的是CPU可以忽略这步)

import tensorflow as tfgpus = tf.config.list_physical_devices("GPU")if gpus:tf.config.experimental.set_memory_growth(gpus[0], True)  #设置GPU显存用量按需使用tf.config.set_visible_devices([gpus[0]],"GPU")# 打印显卡信息,确认GPU可用
print(gpus)

2. 导入数据

import matplotlib.pyplot as plt
# 支持中文
plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号import os,PIL# 设置随机种子尽可能使结果可以重现
import numpy as np
np.random.seed(1)# 设置随机种子尽可能使结果可以重现
import tensorflow as tf
tf.random.set_seed(1)import pathlib
data_dir = "Eye_dataset"data_dir = pathlib.Path(data_dir)

3. 查看数据

image_count = len(list(data_dir.glob('*/*')))print("图片总数为:",image_count)
图片总数为: 4307

二、数据预处理

1.加载数据

batch_size = 64
img_height = 224
img_width = 224

使用image_dataset_from_directory方法将磁盘中的数据加载到tf.data.Dataset

train_ds = tf.keras.preprocessing.image_dataset_from_directory(data_dir,validation_split=0.2,subset="training",seed=12,image_size=(img_height, img_width),batch_size=batch_size)
Found 4307 files belonging to 4 classes.
Using 3446 files for training.
val_ds = tf.keras.preprocessing.image_dataset_from_directory(data_dir,validation_split=0.2,subset="validation",seed=12,image_size=(img_height, img_width),batch_size=batch_size)
Found 4307 files belonging to 4 classes.
Using 861 files for validation.

我们可以通过class_names输出数据集的标签。标签将按字母顺序对应于目录名称。

class_names = train_ds.class_names
print(class_names)
['close_look', 'forward_look', 'left_look', 'right_look']

2. 可视化数据

plt.figure(figsize=(10, 5))  # 图形的宽为10高为5
plt.suptitle("数据展示")for images, labels in train_ds.take(1):for i in range(8):ax = plt.subplot(2, 4, i + 1)  ax.patch.set_facecolor('yellow')plt.imshow(images[i].numpy().astype("uint8"))plt.title(class_names[labels[i]])plt.axis("off")

在这里插入图片描述

  1. 再次检查数据
for image_batch, labels_batch in train_ds:print(image_batch.shape)print(labels_batch.shape)break
(64, 224, 224, 3)
(64,)
  • Image_batch是形状的张量(8, 224, 224, 3)。这是一批形状240x240x3的8张图片(最后一维指的是彩色通道RGB)。
  • Label_batch是形状(8,)的张量,这些标签对应8张图片

4. 配置数据集

AUTOTUNE = tf.data.AUTOTUNEtrain_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds   = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

三、调用官方网络模型

model = tf.keras.applications.VGG16()
# 打印模型信息
model.summary()

四、设置动态学习率

这里先罗列一下学习率大与学习率小的优缺点。

  • 学习率大
    • 优点: 1、加快学习速率。 2、有助于跳出局部最优值。
    • 缺点: 1、导致模型训练不收敛。 2、单单使用大学习率容易导致模型不精确。
  • 学习率小
    • 优点: 1、有助于模型收敛、模型细化。 2、提高模型精度。
    • 缺点: 1、很难跳出局部最优值。 2、收敛缓慢。

注意:这里设置的动态学习率为:指数衰减型(ExponentialDecay)。在每一个epoch开始前,学习率(learning_rate)都将会重置为初始学习率(initial_learning_rate),然后再重新开始衰减。计算公式如下:

learning_rate = initial_learning_rate * decay_rate ^ (step / decay_steps)

# 设置初始学习率
initial_learning_rate = 1e-4lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(initial_learning_rate, decay_steps=5,      # 敲黑板!!!这里是指 steps,不是指epochsdecay_rate=0.96,     # lr经过一次衰减就会变成 decay_rate*lrstaircase=True)# 将指数衰减学习率送入优化器
optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)

五、编译

在准备对模型进行训练之前,还需要再对其进行一些设置。以下内容是在模型的编译步骤中添加的:

  • 损失函数(loss):用于衡量模型在训练期间的准确率。
  • 优化器(optimizer):决定模型如何根据其看到的数据和自身的损失函数进行更新。
  • 指标(metrics):用于监控训练和测试步骤。以下示例使用了准确率,即被正确分类的图像的比率。
model.compile(optimizer=optimizer,loss     ='sparse_categorical_crossentropy',metrics  =['accuracy'])

六、训练模型

epochs = 20history = model.fit(train_ds,validation_data=val_ds,epochs=epochs
)

七、模型评估

1. Accuracy与Loss图

acc = history.history['accuracy']
val_acc = history.history['val_accuracy']loss = history.history['loss']
val_loss = history.history['val_loss']epochs_range = range(epochs)plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

2. 混淆矩阵

Seaborn 是一个画图库,它基于 Matplotlib 核心库进行了更高阶的 API 封装,可以让你轻松地画出更漂亮的图形。Seaborn 的漂亮主要体现在配色更加舒服、以及图形元素的样式更加细腻。

from sklearn.metrics import confusion_matrix
import seaborn as sns
import pandas as pd# 定义一个绘制混淆矩阵图的函数
def plot_cm(labels, predictions):# 生成混淆矩阵conf_numpy = confusion_matrix(labels, predictions)# 将矩阵转化为 DataFrameconf_df = pd.DataFrame(conf_numpy, index=class_names ,columns=class_names)  plt.figure(figsize=(8,7))sns.heatmap(conf_df, annot=True, fmt="d", cmap="BuPu")plt.title('混淆矩阵',fontsize=15)plt.ylabel('真实值',fontsize=14)plt.xlabel('预测值',fontsize=14)
val_pre   = []
val_label = []for images, labels in val_ds:#这里可以取部分验证数据(.take(1))生成混淆矩阵for image, label in zip(images, labels):# 需要给图片增加一个维度img_array = tf.expand_dims(image, 0) # 使用模型预测图片中的人物prediction = model.predict(img_array)val_pre.append(class_names[np.argmax(prediction)])val_label.append(class_names[label])
plot_cm(val_label, val_pre)

在这里插入图片描述

八、保存and加载模型

这是最简单的模型保存与加载方法哈

# 保存模型
model.save('model/16_model.h5')
# 加载模型
new_model = tf.keras.models.load_model('model/16_model.h5')

九、预测

九、预测
# 采用加载的模型(new_model)来看预测结果plt.figure(figsize=(10, 5))  # 图形的宽为10高为5
plt.suptitle("预测结果展示")for images, labels in val_ds.take(1):for i in range(8):ax = plt.subplot(2, 4, i + 1)  # 显示图片plt.imshow(images[i].numpy().astype("uint8"))# 需要给图片增加一个维度img_array = tf.expand_dims(images[i], 0) # 使用模型预测图片中的人物predictions = new_model.predict(img_array)plt.title(class_names[np.argmax(predictions)])plt.axis("off")

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/184050.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

外汇天眼:外汇市场中的“双向交易”是什么意思?

说到外汇市场,总免不了提到它双向交易的优势,很多新手会对这一点有所疑问,今天我们就帮大家解决这一个疑问。 何谓双向交易? 金融市场上,交易者最常接触到的股票,多属于单向交易。 单向交易的模式便是「先…

1145. 北极通讯网络(Kruskal,并查集维护)

北极的某区域共有 n 座村庄,每座村庄的坐标用一对整数 (x,y) 表示。 为了加强联系,决定在村庄之间建立通讯网络,使每两座村庄之间都可以直接或间接通讯。 通讯工具可以是无线电收发机,也可以是卫星设备。 无线电收发机有多种不…

MySQL之redo log

聊聊REDO LOG 为什么需要redolog? 那redolog主要是为了保证数据的持久化,我们知道innodb存储引擎中数据是以页为单位进行存储,每一个页中有很多行记录来存储数据,我们的数据最终是要持久化到硬盘中,那如果我们每进行…

MySQL修改已存在数据的字符集

在实际应用中,如果一开始没有正确的设置字符集,在运行一段时间以后,才发现当前字符集不能满足要求,需要进行调整,但又不想丢弃这段时间的数据,这个时候就需要修改字符集。 在MySQL设置默认字符集和校对规则…

【探索Linux】—— 强大的命令行工具 P.18(进程信号 —— 信号捕捉 | 信号处理 | sigaction() )

阅读导航 引言一、信号捕捉1. 内核实现信号捕捉过程2. sigaction() 函数(1)函数原型(2)参数说明(3)返回值(4)函数使用 二、可重入函数与不可重入函数1. 可重入函数条件2. 不可重入函…

Pytorch模型编译报错 UserWarning: (Resize(), RandomResizedCrop(), etc.)——解决办法

1、问题描述 使用Pytorch训练模型时,编译报错: UserWarning: The default value of the antialias parameter of all the resizing transforms (Resize(), RandomResizedCrop(), etc.) will change from None to True in v0.17, in order to be consis…

linux socket套接字

文章目录 socket流socket(TCP)数据报socket(UDP) 讨论 socket 所谓套接字,就是对网络中不同主机上的应用程序之间进行双向通信的端点的抽象。一个套接字就是网络上进程通信的一端,套接字提供了应用层进程利…

【内网安全】搭建网络拓扑,CS内网横向移动实验

文章目录 搭建网络拓扑 ☁环境CS搭建,木马生成上传一句话,获取WebShellCS上线reGeorg搭建代理,访问内网域控IIS提权信息收集横向移动 实验拓扑结构如下: 搭建网络拓扑 ☁ 环境 **攻击者win10地址:**192.168.8.3 dmz win7地址&…

VSCode 代码调试

断点调试(debug): 指在程序的某一行设置一个断点,调试时,程序运行到这一行就会停住,然后你可以一步一步往下调试,调试过程中可以看各个变量当前的值,出错的话,调试到出错…

PostgreSQL-SQL联表查询LEFT JOIN 数据去重复

我们在使用left join联表查询时,如果table1中的一条记录对应了table2的多条记录,则会重复查出id相同的多条记录。 1、解决方法一 SELECT t1.* FROM table1 t1 LEFT JOIN table2 t2 ON t1.id t2.tid 第一种方法我们发现还是有重复数据 2、解决方法二…

无限移动的风景 css3 动画

<style>*{margin:0;padding:0;/* box-sizing: border-box; */}ul{list-style: none;}#nav{width:900px;height:100px;border:2px solid rgb(70, 69, 69);margin:100px auto; overflow: hidden;}#nav ul{animation:moving 5s linear infinite;width:200%; /*怎么模拟动画…

【数据挖掘】国科大刘莹老师数据挖掘课程作业 —— 第二次作业

Written Part 1. 给定包含属性&#xff5b;Height, Hair, Eye&#xff5d;和两个类别&#xff5b;C1, C2&#xff5d;的数据集。构建基于信息增益&#xff08;info gain&#xff09;的决策树。 HeightHairEyeClass1TallBlondBrownC12TallDarkBlueC13TallDarkBrownC14ShortDark…

Java实现简单的王者荣耀游戏

一、创建新项目 首先创建一个新的项目&#xff0c;并命名为wangzherongyao。 其次在飞翔的鸟项目下创建一个名为img的文件夹用来存放游戏相关图片。详细如下图&#xff1a; 二、游戏代码 1、创建怪物类 1.bear&#xff1a; package beast;import wangzherogyao.GameFrame;…

a-table:表格组件常用功能记录——基础积累2

antdvue是我目前项目的主流&#xff0c;在工作过程中&#xff0c;经常用到table组件。下面就记录一下工作中经常用到的部分知识点。 a-table&#xff1a;表格组件常用功能记录——基础积累2 效果图1.table 点击行触发点击事件1.1 实现单选 点击事件1.2 实现多选 点击事件1.3 实…

知识社区问答平台源码系统 开源的知识问答平台 附带完整的搭建教程

互联网的快速发展&#xff0c;人们对于知识的需求越来越高。知识社区问答平台源码系统是一款基于开源框架搭建的知识问答平台&#xff0c;旨在帮助人们快速、准确地获取所需知识&#xff0c;提高学习效率。 以下是部分代码示例&#xff1a; 系统特色功能一览&#xff1a; 1.知…

什么是消息队列

什么是消息队列 MQ(message queue)&#xff0c;从字面意思上看&#xff0c;本质是个队列&#xff0c;FIFO 先入先出队列&#xff0c;只不过队列中存放的内容是 message 而已&#xff0c;还是一种跨进程的通信机制&#xff0c;用于上下游传递消息。在互联网架构中&#xff0c;M…

二叉树leetcode(求二叉树深度问题)

today我们来练习三道leetcode上的有关于二叉树的题目&#xff0c;都是一些基础的二叉树题目&#xff0c;那让我们一起来学习一下吧。 https://leetcode.cn/problems/maximum-depth-of-binary-tree/submissions/ 看题目描述是让我们来求出二叉树的深度&#xff0c;我们以第一个父…

HT for Web (Hightopo) 使用心得(5)- 动画的实现

其实&#xff0c;在 HT for Web 中&#xff0c;有多种手段可以用来实现动画。我们这里仍然用直升机为例&#xff0c;只是更换了场景。增加了巡游过程。 使用 HT 开发的一个简单网页直升机巡逻动画&#xff08;Hightopo 使用心得&#xff08;5&#xff09;&#xff09; 这里主…

UWB高精度定位系统项目源码

在现代社会中&#xff0c;精准定位技术对于各行各业都至关重要。为了满足对高精度定位的需求&#xff0c;超宽带&#xff08;Ultra-Wideband, UWB&#xff09;技术应运而生。UWB高精度定位系统以其出色的定位精度和多样化的应用领域而备受关注。本文将深入探讨UWB高精度定位系统…