5 Tensorflow图像识别(下)模型构建

上一篇:4 Tensorflow图像识别模型——数据预处理-CSDN博客

1、数据集标签

上一篇介绍了图像识别的数据预处理,下面是完整的代码:

import os
import tensorflow as tf# 获取训练集和验证集目录
train_dir = os.path.join('cats_and_dogs_filtered/train')
validation_dir = os.path.join('cats_and_dogs_filtered/validation')# 模型参数设置
BATCH_SIZE = 100# 图片尺寸统一为150*150
IMG_SHAPE = 150# 处理图像尺寸
img_generator = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1. / 255, horizontal_flip=True, )train_data_gen = img_generator.flow_from_directory(directory=train_dir,shuffle=True,batch_size=BATCH_SIZE,target_size=(IMG_SHAPE, IMG_SHAPE),class_mode='binary')
val_data_gen = img_generator.flow_from_directory(directory=validation_dir,shuffle=True,batch_size=BATCH_SIZE,target_size=(IMG_SHAPE, IMG_SHAPE),class_mode='binary')

上一篇提到系统的输入是“特征-标签”对,特征是输入的图片,标签就是标记该图片是猫还是狗。上面的代码如何知道输入的照片是猫还是狗?

这里用到了keras的一个函数flow_from_directory(),从目录中生成数据流,子目录会自动帮你生成标签。先看看train训练集的这两个子目录生成的标签是什么:

使用下面代码查看

print(train_data_gen.class_indices)

运行结果:

Found 2000 images belonging to 2 classes.
Found 1000 images belonging to 2 classes.
{'cats': 0, 'dogs': 1}

从运行结果可以看到,猫的照片系统自动打上了0的标签,狗的标签是1。

2、Relu激活函数

构建模型的完整代码如下:

import os
import tensorflow as tf
import numpy as np# 获取训练集和验证集目录
train_dir = os.path.join('cats_and_dogs_filtered/train')
validation_dir = os.path.join('cats_and_dogs_filtered/validation')# 模型参数设置
BATCH_SIZE = 100# 图片尺寸统一为150*150
IMG_SHAPE = 150# 处理图像尺寸
img_generator = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1. / 255, horizontal_flip=True, )# 训练数据
train_data_gen = img_generator.flow_from_directory(directory=train_dir,shuffle=True,batch_size=BATCH_SIZE,target_size=(IMG_SHAPE, IMG_SHAPE),class_mode='binary')# 验证数据
val_data_gen = img_generator.flow_from_directory(directory=validation_dir,shuffle=True,batch_size=BATCH_SIZE,target_size=(IMG_SHAPE, IMG_SHAPE),class_mode='binary')model = tf.keras.Sequential([tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(150, 150, 3)),tf.keras.layers.MaxPooling2D(2, 2),tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),tf.keras.layers.MaxPooling2D(2, 2),tf.keras.layers.Conv2D(100, (3, 3), activation='relu'),tf.keras.layers.MaxPooling2D(2, 2),tf.keras.layers.Flatten(),tf.keras.layers.Dense(512, activation='relu'),tf.keras.layers.Dense(2)])model.compile(optimizer='adam',loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),metrics=['accuracy'])EPOCHS = 20
history = model.fit_generator(train_data_gen,steps_per_epoch=int(np.ceil(2000 / float(BATCH_SIZE))),epochs=EPOCHS,validation_data=val_data_gen,validation_steps=int(np.ceil(1000 / float(BATCH_SIZE)))
)

model中加入了和之前不一样的代码:

tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(150, 150, 3)),

这里使用了卷积神经,主要是为了突出区分不同对象的特征。一张图片的信息很多的,但往往我们只需要一些特征进行训练就可以了,后续会详细介绍。

现在先介绍 activation='relu',激活函数Relu。

ReLU,全称是线性整流函数(Rectified Linear Unit),是人工神经网络中常用的激活函数。它的图像如下:

当x<=0时,f(x)=0;

当x>0时,f(x)=x;

可以运行代码看看:

例1:

import tensorflow as tfx = -19
print(tf.nn.relu(x))

运行结果:
tf.Tensor(0, shape=(), dtype=int32)

输入-19,使用relu激活函数后的结果为0

例2:

import tensorflow as tfx = 8
print(tf.nn.relu(x))

运行结果:

tf.Tensor(8, shape=(), dtype=int32)

3、损失函数

代码:


model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy']
              )
 

其中损失函数为SparseCategoricalCrossentropy,它是用于计算多分类问题的交叉熵,如果是两个或两个以上的分类问题可以始终这样设置。对其原理及计算过程的读者可以自行百度,此处不详细介绍。

4、训练过程详解
(1)训练准确率

运行上面的完整代码:

可以看到训练集和验证集的loss值在慢慢下降,准确率在提升。

划线部分是最后一个epoch的训练结果:

accuracy:0.8175,也就是说你的神经网络在分类训练数据方面的准确率约为82%;

val_accuracy:0.7160,在验证集的准确率约为72%

(2)batch_size批次大小

代码中batche_size设置的大小为100,意思是每批次生成的样本数量为100。

例如上述代码的train训练集一共有2000张图片,一个周期(epoch)分20个批次(2000/100=20)样本数据进行训练,每个批次训练完后利用优化器更新模型参数。

所以一个周期(epoch)的模型参数更新次数就是20:2000/batch_size=20

截图中红色部分,就是一个epoch分了20个批次用来更新模型参数。

训练结果会因为模型的参数的设置、训练集图片的数量等等原因结果大不相同,学习的时候可以自己动手去调整模型参数来看看训练结果。



 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/135262.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

AI+BI行业数字化转型研讨会 - 总结精华回顾

带您一起观看研讨会精彩内容回顾&#xff01; || 导语 AIBI行业数字化转型研讨会—引领未来&#xff0c;智慧转型 德昂信息技术(北京)有限公司于2023年10月26日成功举办了AIBI行业数字化转型研讨会。此次盛会汇聚了产业精英、企业领袖以及技术专家&#xff0c;共同探讨在快速…

Python的编码规范:PEP 8介绍及基本遵循原则

文章目录 PEP 8简介基本遵循原则1. 缩进2. 行宽3. 空行4. 导入5. 空格6. 命名约定7. 表达式和语句中的空格8. 注释9. 编码声明10. 文档字符串PEP 8简介 PEP 8,或Python Enhancement Proposal 8,是一个官方文档,发布于2001年。它由Guido van Rossum,Python语言的创始人,以…

前端框架Vue学习 ——(二)Vue常用指令

文章目录 常用指令 常用指令 指令: HTML 标签上带有 “v-” 前缀的特殊属性&#xff0c;不同指令具有不同含义。例如: v-if, v-for… 常用指令&#xff1a; v-bind&#xff1a;为 HTML 标签绑定属性值&#xff0c;如设置 href&#xff0c;css 样式等 <a v-bind:href"…

Spark 新特性+核心回顾

Spark 新特性核心 本文来自 B站 黑马程序员 - Spark教程 &#xff1a;原地址 1. 掌握Spark的Shuffle流程 1.1 Spark Shuffle Map和Reduce 在Shuffle过程中&#xff0c;提供数据的称之为Map端&#xff08;Shuffle Write&#xff09;接收数据的称之为Reduce端&#xff08;Sh…

MybatisPlus之新增操作并返回主键ID

在应用mybatisplus持久层框架的项目中&#xff0c;经常遇到执行新增操作后需要获取主键ID的场景&#xff0c;下面将分析及测试过程记录分享出来。 1、MybatisPlus新增方法 持久层新增方法源码如下&#xff1a; public interface BaseMapper<T> extends Mapper<T> …

js处理赎金信

给你两个字符串&#xff1a;ransomNote 和 magazine &#xff0c;判断 ransomNote 能不能由 magazine 里面的字符构成。 如果可以&#xff0c;返回 true &#xff1b;否则返回 false 。 magazine 中的每个字符只能在 ransomNote 中使用一次。 示例 1&#xff1a; 输入&…

自动控制原理--面试问答题

以下文中的&#xff0c;例如 s_1 为 s下角标1。面试加油&#xff01; 控制系统的三要素&#xff1a;稳准快。稳&#xff0c;系统最后不能震荡、发散&#xff0c;一定要收敛于某一个值&#xff1b;快&#xff0c;能够迅速达到系统的预设值&#xff1b;准&#xff0c;最后稳态值…

一台电脑生成两个ssh,绑定两个GitHub账号

背景 一般一台电脑账号生成一个ssh绑定一个GitHub&#xff0c;即一一对应的关系&#xff01;我之前有一个账号也配置了ssh&#xff0c;但是我想经营两个GitHub账号&#xff0c;当我用https url clone新账号的仓库时&#xff0c;直接超时。所以想起了配置ssh。于是有了今天这篇…

【自然语言处理】利用python创建简单的聊天系统

一&#xff0c;实现原理 代码设计了一个简单的客户端-服务器聊天应用程序&#xff0c;建立了两个脚本文件&#xff08;.py文件)&#xff0c;其中有一个客户端和一个服务器端。客户端和服务器之间通过网络连接进行通信&#xff0c;客户端发送消息&#xff0c;服务器端接收消息并…

django+drf+vue 简单系统搭建 (2) - drf 应用

按照本系统设置目的&#xff0c;是为了建立一些工具用来处理简单的文件。 1. 准备djangorestframework 关于drf的说明请参见&#xff1a;Django REST Framework教程 | 大江狗的博客 本系列直接使用drf的序列化等其他功能。 安装 conda install djangorestframework conda i…

VSCode使用插件Github Copilot进行AI编程

演示示例 函数封装 根据上下文 根据注释 详情请看GitHub Copilot 安装插件 在VS Code中安装插件 GitHub Copilot 登录账号 点击VS code左下角账户图标&#xff0c;点击【Sign in】&#xff0c;会自动在浏览器打开Github登录页&#xff0c;登录具有 Github Copilot 服务的…

网工内推 | 上市公司,云平台运维,IP认证优先,13薪

01 上海新炬网络信息技术股份有限公司 招聘岗位&#xff1a;云平台运维工程师 职责描述&#xff1a; 1、负责云平台运维&#xff0c;包括例行巡检、版本发布、问题及故障处理、平台重保等&#xff0c;保障平台全年稳定运行&#xff1b; 2、参与制定运维标准规范与流程&#x…

混沌系统在图像加密中的应用(基于哈密顿能量函数的混沌系统构造1.1)

混沌系统在图像加密中的应用&#xff08;基于哈密顿能量函数的混沌系统构造1.1&#xff09; 前言一、基于广义哈密顿系统的一类混沌系统构造1.基本动力学特性分析2.数值分析 待续 前言 本文的主题是“基于哈密顿能量函数的混沌系统构造”&#xff0c;哈密顿能量函数是是全文研…

案例 - 拖拽上传文件,生成缩略图

直接看效果 实现代码 <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>拖拽上传文件</title>&l…

PTA_乙级_1002

思路&#xff1a;不仅超出int还超出Longlong,直接用string类型定义n&#xff0c;for循环来遍历每一位字符然后转换成数字进行累加&#xff0c;再用to_string把数字和转换成字符串&#xff0c;再用for循环把数字和的每一位定位到pinyin字符串数组上输出 #include <iostream&…

人工智能(AI)是一种快速发展的技术,其未来发展前景非常广阔。

人工智能&#xff08;AI&#xff09;是一种快速发展的技术&#xff0c;其未来发展前景非常广阔。以下是一些关于AI未来的可能发展方向和就业前景的详细说明&#xff1a; 1.机器学习工程师&#xff1a;机器学习是AI的核心技术之一&#xff0c;它涉及到从数据中自动学习模式并进…

使用Python爬虫被封ip的解决方案

在使用 Python 程序进行网络爬虫开发时&#xff0c;可能会因为下面原因导致被封IP或封禁爬虫程序&#xff1a; 1、频繁访问网站 爬虫程序可能会在很短的时间内访问网站很多次&#xff0c;从而对目标网站造成较大的负担和压力&#xff0c;这种行为容易引起目标网站的注意并被封…

C语言【趣编程】我们怎样便捷输出空心的金字塔

目录 1问题&#xff1a; 2解题思路&#xff1a; 3代码如下&#xff1a; 4代码运行结果如下图所示&#xff1a; 5总结&#xff1a; r如若后续有不会的问题&#xff0c;可以和我私聊&#xff1b; 1问题&#xff1a; 2解题思路&#xff1a; 方法&#xff1a;找规律&#xff0…

Flask(Jinja2) 服务端模板注入漏洞(SSTI)

Flask&#xff08;Jinja2&#xff09; 服务端模板注入漏洞(SSTI) 参考 https://www.freebuf.com/articles/web/260504.html 验证漏洞存在 ?name{{7*7}} 回显49说明漏洞存在 vulhub给出的payload: {% for c in [].__class__.__base__.__subclasses__() %} {% if c.__name__…

【uniapp+vue3/vue2】ksp-cropper高性能图片裁剪工具,详解

效果图&#xff1a; 1、ksp-cropper是hbuilder插件市场中的一款插件&#xff0c;兼容vue2和vue3 ksp-cropper插件安装地址&#xff0c;直接点击跳转 2、插件用法相对简单 &#xff08;1&#xff09;只要url有值就会显示插件&#xff0c;为空就会隐藏插件 &#xff08;2&#…