TensorFlow2实战-系列教程3:猫狗识别1

🧡💛💚TensorFlow2实战-系列教程 总目录

有任何问题欢迎在下面留言
本篇文章的代码运行界面均在Jupyter Notebook中进行
本篇文章配套的代码资源已经上传

1、项目介绍

基本流程:

  • 数据预处理:图像数据处理,准备训练和验证数据集
  • 卷积网络模型:构建网络架构
  • 过拟合问题:观察训练和验证效果,针对过拟合问题提出解决方法
  • 数据增强:图像数据增强方法与效果
  • 迁移学习:深度学习必备训练策略

在我们的数据中,有训练和验证,训练集中分别有猫狗两个类别,都有1000张图像,验证集则有500张

2、数据读取

import os
import warnings
warnings.filterwarnings("ignore")
import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import ImageDataGenerator# 数据所在文件夹
base_dir = './data/cats_and_dogs'
train_dir = os.path.join(base_dir, 'train')
validation_dir = os.path.join(base_dir, 'validation')# 训练集
train_cats_dir = os.path.join(train_dir, 'cats')
train_dogs_dir = os.path.join(train_dir, 'dogs')# 验证集
validation_cats_dir = os.path.join(validation_dir, 'cats')
validation_dogs_dir = os.path.join(validation_dir, 'dogs')
  1. 导包
  2. 指定数据路径
  3. 训练数据路径
  4. 验证数据路径
  5. 训练数据猫类别路径
  6. 训练数据狗类别路径
  7. 验证数据猫类别路径
  8. 训练数据狗类别路径

3、构建卷积神经网络

model = tf.keras.models.Sequential([#如果训练慢,可以把数据设置的更小一些tf.keras.layers.Conv2D(32, (3,3), activation='relu', input_shape=(64, 64, 3)),tf.keras.layers.MaxPooling2D(2, 2),tf.keras.layers.Conv2D(64, (3,3), activation='relu'),tf.keras.layers.MaxPooling2D(2,2),tf.keras.layers.Conv2D(128, (3,3), activation='relu'),tf.keras.layers.MaxPooling2D(2,2),#为全连接层准备tf.keras.layers.Flatten(),tf.keras.layers.Dense(512, activation='relu'),# 二分类sigmoid就够了tf.keras.layers.Dense(1, activation='sigmoid')
])

3个3x3卷积,穿插3个2x2池化,拉平操作,两个全连接层

model.summary()

打印一下模型架构:

配置训练器:

model.compile(loss='binary_crossentropy', optimizer=Adam(lr=1e-4), metrics=['acc'])

4、数据预处理

  • 读进来的数据会被自动转换成tensor(float32)格式,分别准备训练和验证
  • 图像数据归一化(0-1)区间
train_datagen = ImageDataGenerator(rescale=1./255)
test_datagen = ImageDataGenerator(rescale=1./255)
train_generator = train_datagen.flow_from_directory(train_dir,  # 文件夹路径target_size=(64, 64),  # 指定resize成的大小batch_size=20,# 如果one-hot就是categorical,二分类用binary就可以class_mode='binary')validation_generator = test_datagen.flow_from_directory(validation_dir,target_size=(64, 64),batch_size=20,class_mode='binary')

打印结果:
Found 2000 images belonging to 2 classes.
Found 1000 images belonging to 2 classes.

5、模型训练

  • 直接fit也可以,但是通常咱们不能把所有数据全部放入内存,fit_generator相当于一个生成器,动态产生所需的batch数据
  • steps_per_epoch相当给定一个停止条件,因为生成器会不断产生batch数据,说白了就是它不知道一个epoch里需要执行多少个step
history = model.fit_generator(train_generator,steps_per_epoch=100,  # 2000 images = batch_size * stepsepochs=20,validation_data=validation_generator,validation_steps=50,  # 1000 images = batch_size * stepsverbose=2)

部分打印结果:
Epoch 1/20 100/100 - 9s - loss: 0.6909 - acc: 0.5240 - val_loss: 0.6952 - val_acc: 0.5000
Epoch 2/20 100/100 - 9s - loss: 0.6645 - acc: 0.5960 - val_loss: 0.6906 - val_acc: 0.5360

Epoch 19/20 100/100 - 9s - loss: 0.1750 - acc: 0.9460 - val_loss: 0.6277 - val_acc: 0.7390
Epoch 20/20 100/100 - 9s - loss: 0.1593 - acc: 0.9505 - val_loss: 0.5901 - val_acc: 0.7490

6、预测效果展示

import matplotlib.pyplot as plt
acc = history.history['acc']
val_acc = history.history['val_acc']
loss = history.history['loss']
val_loss = history.history['val_loss']epochs = range(len(acc))plt.plot(epochs, acc, 'bo', label='Training accuracy')
plt.plot(epochs, val_acc, 'b', label='Validation accuracy')
plt.title('Training and validation accuracy')plt.figure()plt.plot(epochs, loss, 'bo', label='Training Loss')
plt.plot(epochs, val_loss, 'b', label='Validation Loss')
plt.title('Training and validation loss')
plt.legend()plt.show()

在这里插入图片描述
在这里插入图片描述
将训练损失、准确率和对应的epoch分别画图展示

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/649728.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Web3:B站chainlink课程Lesson5遇到的小坑汇总

ethers代码 我用的ethers.js 6 ,和视频里一样用的是5的不用看代码部分 ethers.providers.JsonRpcProvider("server") //无了 ethers.JsonRpcProvider("server") //现在的wallet.getTransactionCount() //无了 wallet.getNonce() //现在的Big…

已解决:安卓,怎么优雅接入科大讯飞语音评测功能?

写在前面: 网上关于讯飞接入的博客都很少,按说讯飞都是业界翘楚,不知为何,很少搜索到精品,一搜就是一个要求开会员的博客,我也是醉了。讯飞提供的文档也是不清晰,我是摸着石头过河,…

java集合ArrayList和HashSet的fail-fast与fail-safe以及ConcurrentModificationException

在 java 的集合工具类中&#xff0c;例如对 ArrayList 或者 HashSet 进行删除元素后再遍历元素时&#xff0c;会抛出 ConcurrentModificationException 异常。 fail-fast ArrayList public class TestList {public static void main(String[] args) {ArrayList<Integer>…

【iOS ARKit】BlendShapes

BlendShapes 基础介绍 利用前置摄像头采集到的用户面部表情特征&#xff0c;ARKit 提供了一种更加抽象的表示面部表情的方式&#xff0c;这种表示方式叫作 BlendShapes,BlendShapes 可以翻译成形状融合&#xff0c;在3ds Max 中也叫变形器&#xff0c;这个概念原本用于描述通过…

Ubuntu18编译jdk8源码

环境 系统 ubuntu18 Linux ubuntu 5.4.0-150-generic #167~18.04.1-Ubuntu SMP Wed May 24 00:51:42 UTC 2023 x86_64 x86_64 x86_64 GNU/Linux jdk源码openjdk-8u41-src-b04-14_jan_2020.zip bootJdk jdk-8u391-linux-x64.tar.gz ps -e|grep ssh sudo apt-get install ssh…

【MATLAB第92期】基于MATLAB的集成聚合多输入单输出回归预测方法(LSBoost、Bag)含自动优化超参数和特征敏感性分析功能

【MATLAB第92期】基于MATLAB的集成聚合多输入单输出回归预测方法&#xff08;LSBoost、Bag&#xff09;含自动优化超参数和特征敏感性分析功能 本文展示多种非常用多输入单输出回归预测模型效果。 注&#xff1a;每次运行数据训练集测试集为随机&#xff0c;故对比不严谨&…

京东广告算法架构体系建设--在线模型系统分布式异构计算演变 | 京东零售广告技术团队

一、现状介绍 算法策略在广告行业中起着重要的作用&#xff0c;它可以帮助广告主和广告平台更好地理解用户行为和兴趣&#xff0c;从而优化广告投放策略&#xff0c;提高广告点击率和转化率。模型系统作为承载算法策略的载体&#xff0c;目前承载搜索、推荐、首焦、站外等众多广…

Word插入音乐视频文件快速方法 exe zip doc apk txt pdf bat等

需求&#xff1a; Word插入文件有哪些极限操作&#xff1f;如何快速插入音乐视频等文件 问题解决&#xff1a; 使用拖动进行文件快速插入&#xff08;PPT Excle 同理&#xff09; 操作 1.让文件和word界面处于同一屏幕&#xff0c;可以使用分屏 2.鼠标选中文件左键或者使用笔…

一些反序列化总结

1 反序列化漏洞原理 如果反序列化的内容就是那串字符串&#xff0c;是用户可以控制的&#xff08;即变量的值&#xff09;&#xff0c;且后台不正当的使用了PHP中的魔法函数&#xff0c;就会导致反序列化漏洞&#xff0c;可以执行任意命令。Java 序列化指 Java 对象转换为字节序…

Flink问题解决及性能调优-【Flink根据不同场景状态后端使用调优】

Flink 实时groupby聚合场景操作时&#xff0c;由于使用的是rocksdb状态后端&#xff0c;发现CPU的高负载卡在rocksdb的读写上&#xff0c;导致上游算子背压特别大。通过调优使用hashmap状态后端代替rocksdb状态后端&#xff0c;使吞吐量有了质的飞跃&#xff08;20倍的性能提升…

2024年,你是否还在迷茫?

2024年&#xff0c;你是否还在迷茫&#xff1f; 别担心&#xff01;鸿蒙来了&#xff0c;这个未来技术的制高点&#xff0c;为你提供了答案&#xff01; 诸多大厂疯抢、24年预计鸿蒙相关的岗位需求将达到百万级、就业均薪达到19K&#xff0c;全国高校开课…… 种种现象都在表…

VirtualBox安装Ubuntu22.04

目录 1、新建虚拟机 1.1、设置内存大小 1.2、创建虚拟硬盘 2、虚拟机设置 2.1、设置启动顺序​编辑 2.2、选择iso镜像文件 2.3、设置网络(桥接网卡) 3、启动 3.1、设置语言环境 3.2、系统更新安装(不更新) 3.3、选择键盘布局(默认即可) 3.4、选择安装类型 3.5、网…

硬件知识(1) 手机的长焦镜头

#灵感# 手机总是配备好几个镜头&#xff0c;研究一下 目录 手机常配备的摄像头&#xff0c;及效果举例 长焦的焦距 焦距的定义和示图&#xff1a; IPC的焦距和适用场景&#xff1a; 手机常配备的摄像头&#xff0c;及效果举例 以下是小米某个手机的摄像头介绍&#xff1a…

EXCEL VBA抓取网页JSON数据并解析

EXCEL VBA抓取网页JSON数据并解析 链接地址&#xff1a; https://api.api68.com/CQShiCai/getBaseCQShiCaiList.do?lotCode10036&date2024-01-26 Sub test() On Error Resume Next Sheet.Select Sheet1.Cells.ClearContents [a1:g1] Split("preDrawIssue|preDrawTi…

用Visual Studio Code创建JavaScript运行环境【2024版】

用Visual Studio Code创建JavaScript运行环境 JavaScript 的历史 JavaScript 最初被称为 LiveScript&#xff0c;由 Netscape&#xff08;Netscape Communications Corporation&#xff0c;网景通信公司&#xff09;公司的布兰登艾奇&#xff08;Brendan Eich&#xff09;在 …

mysql 存储过程学习

存储过程介绍 1.1 SQL指令执行过程 从SQL执行的流程中我们分析存在的问题: 1.如果我们需要重复多次执行相同的SQL&#xff0c;SQL执行都需要通过连接传递到MySQL&#xff0c;并且需要经过编译和执行的步骤; 2.如果我们需要执行多个SQL指令&#xff0c;并且第二个SQL指令需要…

Topaz Video AI:无损放大,让你的视频更清晰!

在当今的数字时代&#xff0c;视频内容的重要性越来越受到人们的关注。无论是在社交媒体上分享生活片段&#xff0c;还是在商业领域中制作宣传视频&#xff0c;人们都希望能够展现出更高质量的视频内容。 然而&#xff0c;由于各种原因&#xff0c;我们经常会面临一个问题&…

C++版QT:分割窗口

目录 mainwindow.h mainwindow.cpp main.cpp Qt的分割窗口功能允许用户将一个窗口分割成多个区域&#xff0c;每个区域可以独立地显示不同的内容。这种功能在许多应用程序中非常有用&#xff0c;例如编辑器、浏览器和IDE等。 理解Qt的分割窗口&#xff0c;需要从以下几个方面…

音频格式之AAC:(2)AAC封装格式ADIF,ADTS,LATM,extradata及AAC ES存储格式

系列文章目录 音频格式的介绍文章系列&#xff1a; 音频编解码格式介绍(1) ADPCM&#xff1a;adpcm编解码原理及其代码实现 音频编解码格式介绍(2) MP3 &#xff1a;音频格式之MP3&#xff1a;(1)MP3封装格式简介 音频编解码格式介绍(2) MP3 &#xff1a;音频格式之MP3&#x…

IDEA jdk版本切换问题

打开 IntelliJ IDEA 的 Project Structure&#xff08;快捷键通常是 Ctrl Alt Shift S&#xff09;。 转到 Project Settings > Modules。 选择相应的模块&#xff0c;然后在 Sources 标签页下&#xff0c;查看 Language level 是否设置为 自己需要的jdk版本语言。 接…