使用 CNN 训练自己的数据集

CNN(练习数据集)

  • 1.导包:
  • 2.导入数据集:
  • 3. 使用image_dataset_from_directory()将数据加载tf.data.Dataset中:
  • 4. 查看数据集中的一部分图像,以及它们对应的标签:
  • 5.迭代数据集 train_ds,以便查看第一批图像和标签的形状:
  • 6.使用TensorFlow的ImageDataGenerator类来创建一个数据增强的对象:
  • 7.将数据集缓存到内存中,加快速度:
  • 8. 通过卷积层和池化层提取特征,再通过全连接层进行分类:
  • 9.打印网络结构:
  • 10.设置优化器,定义了训练轮次和批量大小:
  • 11.训练数据集:
  • 12.画出图像:
  • 13.评估您的模型在验证数据集的性能:
  • 14.输出在验证集上的预测结果和真实值的对比:
  • 15.输出可视化报表:

  • 在网上寻找一个新的数据集,自己进行训练

1.导包:

import pandas as pd
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout
from tensorflow.keras.optimizers import Adam
from sklearn.model_selection import train_test_split
import numpy as np
from sklearn.preprocessing import LabelBinarizer
import matplotlib.pyplot as plt
import pickle
import pathlib
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models

输出结果:
在这里插入图片描述

2.导入数据集:

# 定义超参数
data_dir = "D:\JUANJI"
data_dir = pathlib.Path(data_dir)
image_count = len(list(data_dir.glob('*/*.jpg')))
print("图片总数为:", image_count)
batch_size = 30
img_height = 180
img_width = 180

输出结果:
在这里插入图片描述

3. 使用image_dataset_from_directory()将数据加载tf.data.Dataset中:

#  使用image_dataset_from_directory()将数据加载到tf.data.Dataset中
train_ds = tf.keras.preprocessing.image_dataset_from_directory(data_dir,validation_split=0.2,  # 验证集0.2subset="training",seed=123,image_size=(img_height, img_width),batch_size=batch_size)val_ds = tf.keras.preprocessing.image_dataset_from_directory(data_dir,validation_split=0.2,subset="validation",seed=123,image_size=(img_height, img_width),batch_size=batch_size)

输出结果:
在这里插入图片描述

4. 查看数据集中的一部分图像,以及它们对应的标签:

class_names = train_ds.class_names
print(class_names)
# 可视化
plt.figure(figsize=(16, 8))
for images, labels in train_ds.take(1):for i in range(16):ax = plt.subplot(4, 4, i + 1)# plt.imshow(images[i], cmap=plt.cm.binary)plt.imshow(images[i].numpy().astype("uint8"))plt.title(class_names[labels[i]])plt.axis("off")
plt.show()

输出结果:
在这里插入图片描述
在这里插入图片描述

5.迭代数据集 train_ds,以便查看第一批图像和标签的形状:

for image_batch, labels_batch in train_ds:print(image_batch.shape)print(labels_batch.shape)break

输出结果:
在这里插入图片描述

6.使用TensorFlow的ImageDataGenerator类来创建一个数据增强的对象:

aug = ImageDataGenerator(rotation_range=30, width_shift_range=0.1,height_shift_range=0.1, shear_range=0.2, zoom_range=0.2,horizontal_flip=True, fill_mode="nearest")
x = aug.flow(image_batch, labels_batch)
AUTOTUNE = tf.data.AUTOTUNE

输出结果:
在这里插入图片描述

7.将数据集缓存到内存中,加快速度:

train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

输出结果:
在这里插入图片描述

8. 通过卷积层和池化层提取特征,再通过全连接层进行分类:

# 为了增加模型的泛化能力,增加了Dropout层,并将最大池化层更新为平均池化层
num_classes = 3
model = models.Sequential([layers.experimental.preprocessing.Rescaling(1./255,input_shape=(img_height,img_width, 3)),layers.Conv2D(32, (3, 3), activation='relu'),layers.MaxPooling2D((2, 2)),layers.Conv2D(64, (3, 3), activation='relu'),layers.MaxPooling2D((2, 2)),layers.Conv2D(128, (3, 3), activation='relu'),layers.MaxPooling2D((2, 2)),layers.Conv2D(256, (3, 3), activation='relu'),layers.MaxPooling2D((2, 2)),layers.Flatten(),layers.Dense(512, activation='relu'),layers.Dense(num_classes)
])

输出结果:
在这里插入图片描述

9.打印网络结构:

model.summary()

输出结果:
在这里插入图片描述

10.设置优化器,定义了训练轮次和批量大小:

# 设置优化器
opt = tf.keras.optimizers.Adam(learning_rate=0.001)model.compile(optimizer=opt,loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),metrics=['accuracy'])EPOCHS = 100
BS = 5

输出结果:
在这里插入图片描述

11.训练数据集:

# 训练网络
# model.fit 可同时处理训练和即时扩充的增强数据。
# 我们必须将训练数据作为第一个参数传递给生成器。生成器将根据我们先前进行的设置生成批量的增强训练数据。
for images_train, labels_train in train_ds:continue
for images_test, labels_test in val_ds:continue
history = model.fit(x=aug.flow(images_train,labels_train, batch_size=BS),validation_data=(images_test,labels_test),
steps_per_epoch=1,epochs=EPOCHS)

输出结果:
在这里插入图片描述

12.画出图像:

# 画出训练精确度和损失图
N = np.arange(0, EPOCHS)
plt.style.use("ggplot")
plt.figure()
plt.plot(N, history.history["loss"], label="train_loss")
plt.plot(N, history.history["val_loss"], label="val_loss")
plt.plot(N, history.history["accuracy"], label="train_acc")
plt.plot(N, history.history["val_accuracy"], label="val_acc")
plt.title("Aug Training Loss and Accuracy")
plt.xlabel("Epoch #")
plt.ylabel("Loss/Accuracy")
plt.legend(loc='upper right')  # legend显示位置
plt.show()

输出结果:
在这里插入图片描述

13.评估您的模型在验证数据集的性能:

test_loss, test_acc = model.evaluate(val_ds, verbose=2)
print(test_loss, test_acc)

输出结果:
在这里插入图片描述

14.输出在验证集上的预测结果和真实值的对比:

#  优化2 输出在验证集上的预测结果和真实值的对比
pre = model.predict(val_ds)
for images, labels in val_ds.take(1):for i in range(4):ax = plt.subplot(1, 4, i + 1)plt.imshow(images[i].numpy().astype("uint8"))plt.title(class_names[labels[i]])plt.xticks([])plt.yticks([])# plt.xlabel('pre: ' + class_names[np.argmax(pre[i])] + ' real: ' + class_names[labels[i]])plt.xlabel('pre: ' + class_names[np.argmax(pre[i])])print('pre: ' + str(class_names[np.argmax(pre[i])]) + ' real: ' + class_names[labels[i]])
plt.show()

输出结果:
在这里插入图片描述

15.输出可视化报表:

print(labels_test)
print(labels)
print(pre)
print(class_names)
from sklearn.metrics import classification_report
# 优化1 输出可视化报表
print(classification_report(labels_test,pre.argmax(axis=1),
target_names=class_names))

输出结果:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/19424.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Shell 编程之免交互

一:Here Document 免交互 1.1:概述 Here Document 是一个特殊用途的代码块。它在 Linux Shell 中使用 I/O 重定向的方式将命令列表提供给交互式程序或命令,比如 ftp 、 cat 或 read 命令。 Here Document 是标准输入的一种替代品&a…

MySQL学习——在批处理模式下使用mysql

除了交互式地使用mysql来输入语句并查看结果。也可以以批处理模式运行mysql。为此&#xff0c;将你想要运行的语句放入一个文件中&#xff0c;然后告诉mysql从该文件读取输入&#xff1a; $> mysql < batch-file 如果你在Windows下运行mysql&#xff0c;并且文件中包含…

【前端每日基础】day31——uni-app

uni-app 开发详细介绍 基本概念 uni-app&#xff1a;uni-app 是一个使用 Vue.js 开发多端应用的框架&#xff0c;可以编译到微信小程序、支付宝小程序、百度小程序、字节跳动小程序、H5、App等多个平台。 跨平台&#xff1a;一次开发&#xff0c;多端部署。通过条件编译实现多…

【漏洞复现】DT-高清车牌识别摄像机 任意文件读取漏洞

0x01 产品简介 DT-高清 车牌识别摄像机是一款先进的安防设备&#xff0c;采用高清图像传感器和先进的识别算法&#xff0c;能够精准、快速地识别车牌信息。其高清晰该摄像机结合了智能识别技术&#xff0c;支持实时监宴图像质量确保在各种光照和天气条件下都能准确捕捉车牌信息…

【面试八股总结】MySQL事务:事务特性、事务并行、事务的隔离级别

参考资料&#xff1a;小林coding 一、事务的特性ACID 原子性&#xff08;Atomicity&#xff09; 一个事务是一个不可分割的工作单位&#xff0c;事务中的所有操作&#xff0c;要么全部完成&#xff0c;要么全部不完成&#xff0c;不会结束在中间某个环节。原子性是通过 undo …

CSS-in-JS学习

CSS-in-JS CSS-in-JS 是一种将样式直接写入JavaScript代码中的方法,它通常与React、Vue等现代前端框架结合使用。 1. 什么是CSS-in-JS? CSS-in-JS 是一种编写样式的方法,它允许开发者在JavaScript组件内部定义样式,通常使用类似于CSS的语法。这种方式提高了代码的可复用…

C#根据数据量自动排版标签的样例

这是一个C#根据数据量自动排版标签的样例 using System; using System.Collections.Generic; using System.Data.SqlClient; using System.Drawing; using System.Text; using System.Threading; using System.Threading.Tasks; using System.Windows.Forms; using HslCommuni…

go mongo 唯一索引创建

1. 登录mongo&#xff0c;创建数据库 mongosh -u $username -p $password use test 2. 查看集合索引 db.$collection_name.getIndexes() 为不存在的集合创建字段唯一索引 package mainimport ("context""fmt""log""time""go…

代码随想录算法训练营第四十五天 | 1049. 最后一块石头的重量 II、494. 目标和、474.一和零

1049. 最后一块石头的重量 II 视频讲解&#xff1a; 动态规划之背包问题&#xff0c;这个背包最多能装多少&#xff1f;LeetCode&#xff1a;1049.最后一块石头的重量II_哔哩哔哩_bilibili 代码随想录 解题思路 直接将这一些石头&#xff0c;分为两堆&#xff0c;让他们尽可能…

假如Redis⾥面有1亿个key,其中有10w个key是以某个固定的已知的前缀开头的,如何将它们全部找出来?

使⽤用 keys 指令可以扫出指定模式的 key 列列表。但是要注意 keys 指令会导致线程阻塞⼀一段时间&#xff0c;线上服务会停 顿&#xff0c;直到指令执⾏行行完毕&#xff0c;服务才能恢复。这个时候可以使⽤用 scan 指令&#xff0c; scan 指令可以⽆无阻塞的提取出指定模式 的…

C语言 | Leetcode C语言题解之第120题三角形最小路径和

题目&#xff1a; 题解&#xff1a; int minimumTotal(int** triangle, int triangleSize, int* triangleColSize) {int f[triangleSize];memset(f, 0, sizeof(f));f[0] triangle[0][0];for (int i 1; i < triangleSize; i) {f[i] f[i - 1] triangle[i][i];for (int j …

SQL语句来实现不使用子查询的方式,直接通过JOIN和MAX函数来筛选出每个主表关联的最新子表记

除了使用JOIN和子查询的方式外&#xff0c;还可以使用窗口函数来实现不带子查询的方式来筛选出每个主表关联的最新子表记录。 以下是使用窗口函数的SQL语句示例&#xff1a; sql SELECT r.*, t.* FROM (SELECT r.*, t.*,ROW_NUMBER() OVER (PARTITION BY r.id ORDER BY t.creat…

latex中对目录的处理

文章目录 设置目录的章节编号宽度和章节标题的缩进设置条目的间距设置章节标题与页码之间的连接线 设置目录的章节编号宽度和章节标题的缩进 \usepackage{tocloft} \setlength{\cftsubsecnumwidth}{4cm} % 设置子章节编号的宽度为4cm \setlength{\cftsubsecindent}{1cm} % 设置…

【excel】设置二级联动菜单

文章目录 【需求】在一级菜单选定后&#xff0c;二级菜单联动显示一级菜单下的可选项【步骤】step1 制作辅助列1.列转行2.在辅助列中匹配班级成员 之前做完了 【excel】设置可变下拉菜单&#xff08;一级联动下拉菜单&#xff09;&#xff0c;开始做二级联动菜单。 【需求】在…

python实现——综合类型数据挖掘任务(无监督的分类任务)

综合类型数据挖掘任务 航空公司客户价值分析。航空公司客户价值分析。航空公司客户价值分析。航空公司已积累了大量的会员档案信息和其乘坐航班记录&#xff08;air_data.csv&#xff09;&#xff0c;以2014年3月31日为结束时间抽取两年内有乘机记录的所有客户的详细数据。利用…

万界星空科技MES系统功能介绍

制造执行系统或MES 是一个全面的动态软件系统&#xff0c;用于监视、跟踪、记录和控制从原材料到成品的制造过程。MES在企业资源规划(ERP) 和过程控制系统之间提供了一个功能层&#xff0c;为决策者提供了提高车间效率和优化生产所需的数据。 万界星空科技MES 系统基础功能&am…

Spark基础:Scala变量与数据类型

在Scala中&#xff0c;变量和数据类型是编程的基础。Scala作为一种强大的静态类型语言&#xff0c;支持多种数据类型&#xff0c;并提供了可变&#xff08;var&#xff09;和不可变&#xff08;val&#xff09;两种类型的变量声明方式。以下是在Scala中变量和数据类型的基础知识…

【全开源】Java短剧系统微信小程序+H5+微信公众号+APP 源码

打造属于你的精彩短视频平台 一、引言&#xff1a;为何选择短剧系统小程序&#xff1f; 在当今数字化时代&#xff0c;短视频已经成为人们日常生活中不可或缺的一部分。而短剧系统小程序源码&#xff0c;作为构建短视频平台的强大工具&#xff0c;为广大开发者提供了快速搭建…

03-树1 树的同构(浙大数据结构PTA习题)

03-树1 树的同构 分数 25 作者 陈越 单位 浙江大学 给定两棵树 T1​ 和 T2​。如果 T1​ 可以通过若干次左右孩子互换就变成 T2​&#xff0c;则我们称两棵树是“同构”的。例如图1给出的两棵树就是同构的&#xff0c;因为我们把其中一棵树的结点A、B、G…

CSPM.pdf

PDF转图片 归档&#xff1a;