卷积神经网络(CNN)注意力检测

文章目录

  • 一、前言
  • 二、前期工作
    • 1. 设置GPU(如果使用的是CPU可以忽略这步)
    • 2. 导入数据
    • 3. 查看数据
  • 二、数据预处理
    • 1.加载数据
    • 2. 可视化数据
    • 4. 配置数据集
  • 三、调用官方网络模型
  • 四、设置动态学习率
  • 五、编译
  • 六、训练模型
  • 七、模型评估
    • 1. Accuracy与Loss图
    • 2. 混淆矩阵
  • 八、保存and加载模型
  • 九、预测

一、前言

我的环境:

  • 语言环境:Python3.6.5
  • 编译器:jupyter notebook
  • 深度学习环境:TensorFlow2.4.1

往期精彩内容:

  • 卷积神经网络(CNN)实现mnist手写数字识别
  • 卷积神经网络(CNN)多种图片分类的实现
  • 卷积神经网络(CNN)衣服图像分类的实现
  • 卷积神经网络(CNN)鲜花识别
  • 卷积神经网络(CNN)天气识别
  • 卷积神经网络(VGG-16)识别海贼王草帽一伙
  • 卷积神经网络(ResNet-50)鸟类识别
  • 卷积神经网络(AlexNet)鸟类识别
  • 卷积神经网络(CNN)识别验证码
  • 卷积神经网络(CNN)车牌识别

来自专栏:机器学习与深度学习算法推荐

二、前期工作

1. 设置GPU(如果使用的是CPU可以忽略这步)

import tensorflow as tfgpus = tf.config.list_physical_devices("GPU")if gpus:tf.config.experimental.set_memory_growth(gpus[0], True)  #设置GPU显存用量按需使用tf.config.set_visible_devices([gpus[0]],"GPU")# 打印显卡信息,确认GPU可用
print(gpus)

2. 导入数据

import matplotlib.pyplot as plt
# 支持中文
plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号import os,PIL# 设置随机种子尽可能使结果可以重现
import numpy as np
np.random.seed(1)# 设置随机种子尽可能使结果可以重现
import tensorflow as tf
tf.random.set_seed(1)import pathlib
data_dir = "Eye_dataset"data_dir = pathlib.Path(data_dir)

3. 查看数据

image_count = len(list(data_dir.glob('*/*')))print("图片总数为:",image_count)
图片总数为: 4307

二、数据预处理

1.加载数据

batch_size = 64
img_height = 224
img_width = 224

使用image_dataset_from_directory方法将磁盘中的数据加载到tf.data.Dataset

train_ds = tf.keras.preprocessing.image_dataset_from_directory(data_dir,validation_split=0.2,subset="training",seed=12,image_size=(img_height, img_width),batch_size=batch_size)
Found 4307 files belonging to 4 classes.
Using 3446 files for training.
val_ds = tf.keras.preprocessing.image_dataset_from_directory(data_dir,validation_split=0.2,subset="validation",seed=12,image_size=(img_height, img_width),batch_size=batch_size)
Found 4307 files belonging to 4 classes.
Using 861 files for validation.

我们可以通过class_names输出数据集的标签。标签将按字母顺序对应于目录名称。

class_names = train_ds.class_names
print(class_names)
['close_look', 'forward_look', 'left_look', 'right_look']

2. 可视化数据

plt.figure(figsize=(10, 5))  # 图形的宽为10高为5
plt.suptitle("数据展示")for images, labels in train_ds.take(1):for i in range(8):ax = plt.subplot(2, 4, i + 1)  ax.patch.set_facecolor('yellow')plt.imshow(images[i].numpy().astype("uint8"))plt.title(class_names[labels[i]])plt.axis("off")

在这里插入图片描述

  1. 再次检查数据
for image_batch, labels_batch in train_ds:print(image_batch.shape)print(labels_batch.shape)break
(64, 224, 224, 3)
(64,)
  • Image_batch是形状的张量(8, 224, 224, 3)。这是一批形状240x240x3的8张图片(最后一维指的是彩色通道RGB)。
  • Label_batch是形状(8,)的张量,这些标签对应8张图片

4. 配置数据集

AUTOTUNE = tf.data.AUTOTUNEtrain_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds   = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

三、调用官方网络模型

model = tf.keras.applications.VGG16()
# 打印模型信息
model.summary()

四、设置动态学习率

这里先罗列一下学习率大与学习率小的优缺点。

  • 学习率大
    • 优点: 1、加快学习速率。 2、有助于跳出局部最优值。
    • 缺点: 1、导致模型训练不收敛。 2、单单使用大学习率容易导致模型不精确。
  • 学习率小
    • 优点: 1、有助于模型收敛、模型细化。 2、提高模型精度。
    • 缺点: 1、很难跳出局部最优值。 2、收敛缓慢。

注意:这里设置的动态学习率为:指数衰减型(ExponentialDecay)。在每一个epoch开始前,学习率(learning_rate)都将会重置为初始学习率(initial_learning_rate),然后再重新开始衰减。计算公式如下:

learning_rate = initial_learning_rate * decay_rate ^ (step / decay_steps)

# 设置初始学习率
initial_learning_rate = 1e-4lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(initial_learning_rate, decay_steps=5,      # 敲黑板!!!这里是指 steps,不是指epochsdecay_rate=0.96,     # lr经过一次衰减就会变成 decay_rate*lrstaircase=True)# 将指数衰减学习率送入优化器
optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)

五、编译

在准备对模型进行训练之前,还需要再对其进行一些设置。以下内容是在模型的编译步骤中添加的:

  • 损失函数(loss):用于衡量模型在训练期间的准确率。
  • 优化器(optimizer):决定模型如何根据其看到的数据和自身的损失函数进行更新。
  • 指标(metrics):用于监控训练和测试步骤。以下示例使用了准确率,即被正确分类的图像的比率。
model.compile(optimizer=optimizer,loss     ='sparse_categorical_crossentropy',metrics  =['accuracy'])

六、训练模型

epochs = 20history = model.fit(train_ds,validation_data=val_ds,epochs=epochs
)

七、模型评估

1. Accuracy与Loss图

acc = history.history['accuracy']
val_acc = history.history['val_accuracy']loss = history.history['loss']
val_loss = history.history['val_loss']epochs_range = range(epochs)plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

2. 混淆矩阵

Seaborn 是一个画图库,它基于 Matplotlib 核心库进行了更高阶的 API 封装,可以让你轻松地画出更漂亮的图形。Seaborn 的漂亮主要体现在配色更加舒服、以及图形元素的样式更加细腻。

from sklearn.metrics import confusion_matrix
import seaborn as sns
import pandas as pd# 定义一个绘制混淆矩阵图的函数
def plot_cm(labels, predictions):# 生成混淆矩阵conf_numpy = confusion_matrix(labels, predictions)# 将矩阵转化为 DataFrameconf_df = pd.DataFrame(conf_numpy, index=class_names ,columns=class_names)  plt.figure(figsize=(8,7))sns.heatmap(conf_df, annot=True, fmt="d", cmap="BuPu")plt.title('混淆矩阵',fontsize=15)plt.ylabel('真实值',fontsize=14)plt.xlabel('预测值',fontsize=14)
val_pre   = []
val_label = []for images, labels in val_ds:#这里可以取部分验证数据(.take(1))生成混淆矩阵for image, label in zip(images, labels):# 需要给图片增加一个维度img_array = tf.expand_dims(image, 0) # 使用模型预测图片中的人物prediction = model.predict(img_array)val_pre.append(class_names[np.argmax(prediction)])val_label.append(class_names[label])
plot_cm(val_label, val_pre)

在这里插入图片描述

八、保存and加载模型

这是最简单的模型保存与加载方法哈

# 保存模型
model.save('model/16_model.h5')
# 加载模型
new_model = tf.keras.models.load_model('model/16_model.h5')

九、预测

九、预测
# 采用加载的模型(new_model)来看预测结果plt.figure(figsize=(10, 5))  # 图形的宽为10高为5
plt.suptitle("预测结果展示")for images, labels in val_ds.take(1):for i in range(8):ax = plt.subplot(2, 4, i + 1)  # 显示图片plt.imshow(images[i].numpy().astype("uint8"))# 需要给图片增加一个维度img_array = tf.expand_dims(images[i], 0) # 使用模型预测图片中的人物predictions = new_model.predict(img_array)plt.title(class_names[np.argmax(predictions)])plt.axis("off")

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/184050.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

外汇天眼:外汇市场中的“双向交易”是什么意思?

说到外汇市场,总免不了提到它双向交易的优势,很多新手会对这一点有所疑问,今天我们就帮大家解决这一个疑问。 何谓双向交易? 金融市场上,交易者最常接触到的股票,多属于单向交易。 单向交易的模式便是「先…

1145. 北极通讯网络(Kruskal,并查集维护)

北极的某区域共有 n 座村庄,每座村庄的坐标用一对整数 (x,y) 表示。 为了加强联系,决定在村庄之间建立通讯网络,使每两座村庄之间都可以直接或间接通讯。 通讯工具可以是无线电收发机,也可以是卫星设备。 无线电收发机有多种不…

java获取当天的开始时间和结束时间,并通过mybatis查询

一、一天中的开始时间 LocalDateTime date LocalDateTime.now(); // 获取一天的开始时间LocalDateTime startOfTheDay LocalDateTime.of(date.toLocalDate(), LocalTime.MIN);System.out.println("startOfTheDay " startOfTheDay);二、一天中的结束时间 // 获取…

MySQL之redo log

聊聊REDO LOG 为什么需要redolog? 那redolog主要是为了保证数据的持久化,我们知道innodb存储引擎中数据是以页为单位进行存储,每一个页中有很多行记录来存储数据,我们的数据最终是要持久化到硬盘中,那如果我们每进行…

MySQL修改已存在数据的字符集

在实际应用中,如果一开始没有正确的设置字符集,在运行一段时间以后,才发现当前字符集不能满足要求,需要进行调整,但又不想丢弃这段时间的数据,这个时候就需要修改字符集。 在MySQL设置默认字符集和校对规则…

【探索Linux】—— 强大的命令行工具 P.18(进程信号 —— 信号捕捉 | 信号处理 | sigaction() )

阅读导航 引言一、信号捕捉1. 内核实现信号捕捉过程2. sigaction() 函数(1)函数原型(2)参数说明(3)返回值(4)函数使用 二、可重入函数与不可重入函数1. 可重入函数条件2. 不可重入函…

Pytorch模型编译报错 UserWarning: (Resize(), RandomResizedCrop(), etc.)——解决办法

1、问题描述 使用Pytorch训练模型时,编译报错: UserWarning: The default value of the antialias parameter of all the resizing transforms (Resize(), RandomResizedCrop(), etc.) will change from None to True in v0.17, in order to be consis…

redis 存储一个map 怎么让map中其中一个值设置过期时间,而不是过期掉整个map?

🌷🍁 博主猫头虎(🐅🐾)带您 Go to New World✨🍁 🦄 博客首页——🐅🐾猫头虎的博客🎐 🐳 《面试题大全专栏》 🦕 文章图文并茂🦖生动形象🐅简单易学!欢迎大家来踩踩~🌺 🌊 《IDEA开发秘籍专栏》 🐾 学会IDEA常用操作,工作效率翻倍~💐 �…

linux socket套接字

文章目录 socket流socket(TCP)数据报socket(UDP) 讨论 socket 所谓套接字,就是对网络中不同主机上的应用程序之间进行双向通信的端点的抽象。一个套接字就是网络上进程通信的一端,套接字提供了应用层进程利…

【内网安全】搭建网络拓扑,CS内网横向移动实验

文章目录 搭建网络拓扑 ☁环境CS搭建,木马生成上传一句话,获取WebShellCS上线reGeorg搭建代理,访问内网域控IIS提权信息收集横向移动 实验拓扑结构如下: 搭建网络拓扑 ☁ 环境 **攻击者win10地址:**192.168.8.3 dmz win7地址&…

MyBatis教程之搭建(二)

1、开发环境 IDE&#xff1a;idea 2019.2 构建工具&#xff1a;maven 3.5.4 MySQL版本&#xff1a;MySQL 5.7 MyBatis版本&#xff1a;MyBatis 3.5.7 2、创建maven工程 <dependencies><!-- Mybatis核心 --><dependency><groupId>org.mybatis</…

VSCode 代码调试

断点调试&#xff08;debug&#xff09;&#xff1a; 指在程序的某一行设置一个断点&#xff0c;调试时&#xff0c;程序运行到这一行就会停住&#xff0c;然后你可以一步一步往下调试&#xff0c;调试过程中可以看各个变量当前的值&#xff0c;出错的话&#xff0c;调试到出错…

PostgreSQL-SQL联表查询LEFT JOIN 数据去重复

我们在使用left join联表查询时&#xff0c;如果table1中的一条记录对应了table2的多条记录&#xff0c;则会重复查出id相同的多条记录。 1、解决方法一 SELECT t1.* FROM table1 t1 LEFT JOIN table2 t2 ON t1.id t2.tid 第一种方法我们发现还是有重复数据 2、解决方法二…

Spark_spark hints 详细介绍

spark 中hints 的优先级高于&#xff0c;代码中的config, 高于spark_submit 中的commit Hints - Spark 3.5.0 Documentation

无限移动的风景 css3 动画

<style>*{margin:0;padding:0;/* box-sizing: border-box; */}ul{list-style: none;}#nav{width:900px;height:100px;border:2px solid rgb(70, 69, 69);margin:100px auto; overflow: hidden;}#nav ul{animation:moving 5s linear infinite;width:200%; /*怎么模拟动画…

【数据挖掘】国科大刘莹老师数据挖掘课程作业 —— 第二次作业

Written Part 1. 给定包含属性&#xff5b;Height, Hair, Eye&#xff5d;和两个类别&#xff5b;C1, C2&#xff5d;的数据集。构建基于信息增益&#xff08;info gain&#xff09;的决策树。 HeightHairEyeClass1TallBlondBrownC12TallDarkBlueC13TallDarkBrownC14ShortDark…

selenium写一个自动化程序查看作业的未提交名单

selenium写一个自动化程序 7.17 结合selenium写一个自动化程序&#xff0c;完成的功能&#xff08;看哪些人提交了&#xff0c;哪些人没提交&#xff09; 遇到的难点1&#xff1a;定位某一个用户名 解决方案1&#xff1a;元素定位并优化 经常使用的jupyter快捷键命令&#x…

Java实现简单的王者荣耀游戏

一、创建新项目 首先创建一个新的项目&#xff0c;并命名为wangzherongyao。 其次在飞翔的鸟项目下创建一个名为img的文件夹用来存放游戏相关图片。详细如下图&#xff1a; 二、游戏代码 1、创建怪物类 1.bear&#xff1a; package beast;import wangzherogyao.GameFrame;…

Java给钉钉机器人发送消息

示例代码 import org.slf4j.Logger; import org.slf4j.LoggerFactory;import java.io.OutputStream; import java.net.HttpURLConnection; import java.net.URL; import java.nio.charset.StandardCharsets;/*** author : yx-0176* description* date : 2023/11/27*/ public a…

数据结构与算法之递归: LeetCode 46. 全排列 (Typescript版)

全排列 https://leetcode.cn/problems/permutations/ 描述 给定一个不含重复数字的数组 nums &#xff0c;返回其 所有可能的全排列 。你可以 按任意顺序 返回答案。 示例 1 输入&#xff1a;nums [1,2,3] 输出&#xff1a;[[1,2,3],[1,3,2],[2,1,3],[2,3,1],[3,1,2],[3,…