使用 CNN 训练自己的数据集

CNN(练习数据集)

  • 1.导包:
  • 2.导入数据集:
  • 3. 使用image_dataset_from_directory()将数据加载tf.data.Dataset中:
  • 4. 查看数据集中的一部分图像,以及它们对应的标签:
  • 5.迭代数据集 train_ds,以便查看第一批图像和标签的形状:
  • 6.使用TensorFlow的ImageDataGenerator类来创建一个数据增强的对象:
  • 7.将数据集缓存到内存中,加快速度:
  • 8. 通过卷积层和池化层提取特征,再通过全连接层进行分类:
  • 9.打印网络结构:
  • 10.设置优化器,定义了训练轮次和批量大小:
  • 11.训练数据集:
  • 12.画出图像:
  • 13.评估您的模型在验证数据集的性能:
  • 14.输出在验证集上的预测结果和真实值的对比:
  • 15.输出可视化报表:

  • 在网上寻找一个新的数据集,自己进行训练

1.导包:

import pandas as pd
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout
from tensorflow.keras.optimizers import Adam
from sklearn.model_selection import train_test_split
import numpy as np
from sklearn.preprocessing import LabelBinarizer
import matplotlib.pyplot as plt
import pickle
import pathlib
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models

输出结果:
在这里插入图片描述

2.导入数据集:

# 定义超参数
data_dir = "D:\JUANJI"
data_dir = pathlib.Path(data_dir)
image_count = len(list(data_dir.glob('*/*.jpg')))
print("图片总数为:", image_count)
batch_size = 30
img_height = 180
img_width = 180

输出结果:
在这里插入图片描述

3. 使用image_dataset_from_directory()将数据加载tf.data.Dataset中:

#  使用image_dataset_from_directory()将数据加载到tf.data.Dataset中
train_ds = tf.keras.preprocessing.image_dataset_from_directory(data_dir,validation_split=0.2,  # 验证集0.2subset="training",seed=123,image_size=(img_height, img_width),batch_size=batch_size)val_ds = tf.keras.preprocessing.image_dataset_from_directory(data_dir,validation_split=0.2,subset="validation",seed=123,image_size=(img_height, img_width),batch_size=batch_size)

输出结果:
在这里插入图片描述

4. 查看数据集中的一部分图像,以及它们对应的标签:

class_names = train_ds.class_names
print(class_names)
# 可视化
plt.figure(figsize=(16, 8))
for images, labels in train_ds.take(1):for i in range(16):ax = plt.subplot(4, 4, i + 1)# plt.imshow(images[i], cmap=plt.cm.binary)plt.imshow(images[i].numpy().astype("uint8"))plt.title(class_names[labels[i]])plt.axis("off")
plt.show()

输出结果:
在这里插入图片描述
在这里插入图片描述

5.迭代数据集 train_ds,以便查看第一批图像和标签的形状:

for image_batch, labels_batch in train_ds:print(image_batch.shape)print(labels_batch.shape)break

输出结果:
在这里插入图片描述

6.使用TensorFlow的ImageDataGenerator类来创建一个数据增强的对象:

aug = ImageDataGenerator(rotation_range=30, width_shift_range=0.1,height_shift_range=0.1, shear_range=0.2, zoom_range=0.2,horizontal_flip=True, fill_mode="nearest")
x = aug.flow(image_batch, labels_batch)
AUTOTUNE = tf.data.AUTOTUNE

输出结果:
在这里插入图片描述

7.将数据集缓存到内存中,加快速度:

train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

输出结果:
在这里插入图片描述

8. 通过卷积层和池化层提取特征,再通过全连接层进行分类:

# 为了增加模型的泛化能力,增加了Dropout层,并将最大池化层更新为平均池化层
num_classes = 3
model = models.Sequential([layers.experimental.preprocessing.Rescaling(1./255,input_shape=(img_height,img_width, 3)),layers.Conv2D(32, (3, 3), activation='relu'),layers.MaxPooling2D((2, 2)),layers.Conv2D(64, (3, 3), activation='relu'),layers.MaxPooling2D((2, 2)),layers.Conv2D(128, (3, 3), activation='relu'),layers.MaxPooling2D((2, 2)),layers.Conv2D(256, (3, 3), activation='relu'),layers.MaxPooling2D((2, 2)),layers.Flatten(),layers.Dense(512, activation='relu'),layers.Dense(num_classes)
])

输出结果:
在这里插入图片描述

9.打印网络结构:

model.summary()

输出结果:
在这里插入图片描述

10.设置优化器,定义了训练轮次和批量大小:

# 设置优化器
opt = tf.keras.optimizers.Adam(learning_rate=0.001)model.compile(optimizer=opt,loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),metrics=['accuracy'])EPOCHS = 100
BS = 5

输出结果:
在这里插入图片描述

11.训练数据集:

# 训练网络
# model.fit 可同时处理训练和即时扩充的增强数据。
# 我们必须将训练数据作为第一个参数传递给生成器。生成器将根据我们先前进行的设置生成批量的增强训练数据。
for images_train, labels_train in train_ds:continue
for images_test, labels_test in val_ds:continue
history = model.fit(x=aug.flow(images_train,labels_train, batch_size=BS),validation_data=(images_test,labels_test),
steps_per_epoch=1,epochs=EPOCHS)

输出结果:
在这里插入图片描述

12.画出图像:

# 画出训练精确度和损失图
N = np.arange(0, EPOCHS)
plt.style.use("ggplot")
plt.figure()
plt.plot(N, history.history["loss"], label="train_loss")
plt.plot(N, history.history["val_loss"], label="val_loss")
plt.plot(N, history.history["accuracy"], label="train_acc")
plt.plot(N, history.history["val_accuracy"], label="val_acc")
plt.title("Aug Training Loss and Accuracy")
plt.xlabel("Epoch #")
plt.ylabel("Loss/Accuracy")
plt.legend(loc='upper right')  # legend显示位置
plt.show()

输出结果:
在这里插入图片描述

13.评估您的模型在验证数据集的性能:

test_loss, test_acc = model.evaluate(val_ds, verbose=2)
print(test_loss, test_acc)

输出结果:
在这里插入图片描述

14.输出在验证集上的预测结果和真实值的对比:

#  优化2 输出在验证集上的预测结果和真实值的对比
pre = model.predict(val_ds)
for images, labels in val_ds.take(1):for i in range(4):ax = plt.subplot(1, 4, i + 1)plt.imshow(images[i].numpy().astype("uint8"))plt.title(class_names[labels[i]])plt.xticks([])plt.yticks([])# plt.xlabel('pre: ' + class_names[np.argmax(pre[i])] + ' real: ' + class_names[labels[i]])plt.xlabel('pre: ' + class_names[np.argmax(pre[i])])print('pre: ' + str(class_names[np.argmax(pre[i])]) + ' real: ' + class_names[labels[i]])
plt.show()

输出结果:
在这里插入图片描述

15.输出可视化报表:

print(labels_test)
print(labels)
print(pre)
print(class_names)
from sklearn.metrics import classification_report
# 优化1 输出可视化报表
print(classification_report(labels_test,pre.argmax(axis=1),
target_names=class_names))

输出结果:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/19424.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【漏洞复现】DT-高清车牌识别摄像机 任意文件读取漏洞

0x01 产品简介 DT-高清 车牌识别摄像机是一款先进的安防设备,采用高清图像传感器和先进的识别算法,能够精准、快速地识别车牌信息。其高清晰该摄像机结合了智能识别技术,支持实时监宴图像质量确保在各种光照和天气条件下都能准确捕捉车牌信息…

【面试八股总结】MySQL事务:事务特性、事务并行、事务的隔离级别

参考资料:小林coding 一、事务的特性ACID 原子性(Atomicity) 一个事务是一个不可分割的工作单位,事务中的所有操作,要么全部完成,要么全部不完成,不会结束在中间某个环节。原子性是通过 undo …

C#根据数据量自动排版标签的样例

这是一个C#根据数据量自动排版标签的样例 using System; using System.Collections.Generic; using System.Data.SqlClient; using System.Drawing; using System.Text; using System.Threading; using System.Threading.Tasks; using System.Windows.Forms; using HslCommuni…

代码随想录算法训练营第四十五天 | 1049. 最后一块石头的重量 II、494. 目标和、474.一和零

1049. 最后一块石头的重量 II 视频讲解: 动态规划之背包问题,这个背包最多能装多少?LeetCode:1049.最后一块石头的重量II_哔哩哔哩_bilibili 代码随想录 解题思路 直接将这一些石头,分为两堆,让他们尽可能…

C语言 | Leetcode C语言题解之第120题三角形最小路径和

题目&#xff1a; 题解&#xff1a; int minimumTotal(int** triangle, int triangleSize, int* triangleColSize) {int f[triangleSize];memset(f, 0, sizeof(f));f[0] triangle[0][0];for (int i 1; i < triangleSize; i) {f[i] f[i - 1] triangle[i][i];for (int j …

【excel】设置二级联动菜单

文章目录 【需求】在一级菜单选定后&#xff0c;二级菜单联动显示一级菜单下的可选项【步骤】step1 制作辅助列1.列转行2.在辅助列中匹配班级成员 之前做完了 【excel】设置可变下拉菜单&#xff08;一级联动下拉菜单&#xff09;&#xff0c;开始做二级联动菜单。 【需求】在…

python实现——综合类型数据挖掘任务(无监督的分类任务)

综合类型数据挖掘任务 航空公司客户价值分析。航空公司客户价值分析。航空公司客户价值分析。航空公司已积累了大量的会员档案信息和其乘坐航班记录&#xff08;air_data.csv&#xff09;&#xff0c;以2014年3月31日为结束时间抽取两年内有乘机记录的所有客户的详细数据。利用…

万界星空科技MES系统功能介绍

制造执行系统或MES 是一个全面的动态软件系统&#xff0c;用于监视、跟踪、记录和控制从原材料到成品的制造过程。MES在企业资源规划(ERP) 和过程控制系统之间提供了一个功能层&#xff0c;为决策者提供了提高车间效率和优化生产所需的数据。 万界星空科技MES 系统基础功能&am…

【全开源】Java短剧系统微信小程序+H5+微信公众号+APP 源码

打造属于你的精彩短视频平台 一、引言&#xff1a;为何选择短剧系统小程序&#xff1f; 在当今数字化时代&#xff0c;短视频已经成为人们日常生活中不可或缺的一部分。而短剧系统小程序源码&#xff0c;作为构建短视频平台的强大工具&#xff0c;为广大开发者提供了快速搭建…

03-树1 树的同构(浙大数据结构PTA习题)

03-树1 树的同构 分数 25 作者 陈越 单位 浙江大学 给定两棵树 T1​ 和 T2​。如果 T1​ 可以通过若干次左右孩子互换就变成 T2​&#xff0c;则我们称两棵树是“同构”的。例如图1给出的两棵树就是同构的&#xff0c;因为我们把其中一棵树的结点A、B、G…

CSPM.pdf

PDF转图片 归档&#xff1a;

跨境电商多店铺:怎么管理?风险如何规避?

跨境电商的市场辽阔&#xff0c;有非常多的商业机会。你可能已经在Amazon、eBay、Etsy等在线平台向潜在客户销售产品了。为了赚更多的钱&#xff0c;你可能还在经营多个店铺和品牌。 但是&#xff0c;像Amazon、eBay、Etsy等知名平台会有自己的规则&#xff0c;他们开发了很多…

手拉手springboot整合kafka发送消息

环境介绍技术栈springbootmybatis-plusmysqlrocketmq软件版本mysql8IDEAIntelliJ IDEA 2022.2.1JDK17Spring Boot3.1.7kafka2.13-3.7.0 创建topic时&#xff0c;若不指定topic的分区(Partition主题分区数)数量使&#xff0c;则默认为1个分区(partition) springboot加入依赖kafk…

探索无限可能性——微软 Visio 2021 改变您的思维方式

在当今信息化时代&#xff0c;信息流动和数据处理已经成为各行各业的关键。微软 Visio 2021 作为领先的流程图和图表软件&#xff0c;帮助用户以直观、动态的方式呈现信息和数据&#xff0c;从而提高工作效率&#xff0c;优化业务流程。本文将介绍 Visio 2021 的特色功能及其在…

华为OD机试 - 游戏分组 - 递归(Java 2024 C卷 100分)

华为OD机试 2024C卷题库疯狂收录中&#xff0c;刷题点这里 专栏导读 本专栏收录于《华为OD机试&#xff08;JAVA&#xff09;真题&#xff08;A卷B卷C卷&#xff09;》。 刷的越多&#xff0c;抽中的概率越大&#xff0c;每一题都有详细的答题思路、详细的代码注释、样例测试…

精准检测,安全无忧:安全阀检测实践指南

安全阀作为一种重要的安全装置&#xff0c;在各类工业系统和设备中发挥着举足轻重的作用。 它通过自动控制内部压力&#xff0c;有效防止因压力过高而引发的设备损坏和事故风险&#xff0c;因此&#xff0c;对安全阀进行定期检测&#xff0c;确保其性能完好、工作可靠&#xf…

使用pytorch构建ResNet50模型训练猫狗数据集

数据集 1.导包 import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import DataLoader from torchvision import datasets, transforms, models import numpy as np import matplotlib.pyplot as plt import os from tqdm.auto import t…

流媒体服务器SMS-语音对讲(一)

1.简介 在国标语音对讲对接中&#xff0c;会发现不同的厂商或不同型号的设备&#xff0c;对讲流程都不一样&#xff0c;本文主要介绍流媒体与设备之间的交互情况。 SMS流媒体服务代码库地址&#xff1a;https://gitee.com/inyeme/simple-media-server 2.流媒体与设备交互的可能…

max6675热电偶温度采集

思路来源 参考价格 概述 MAX6675具有冷端补偿和将来自K型热电偶的信号数字化。数据以12位分辨率输出&#xff0c;SPI™兼容&#xff0c; 只读格式。该转换器将温度分解为0.25C&#xff0c;允许读数高达1024C&#xff0c;并显示热电偶8LSB在0C至 700C 引脚连接 温度采样电路 …

中间件复习之-消息队列

消息队列在分布式架构的作用 消息队列&#xff1a;在消息的传输过程中保存消息的容器&#xff0c;生产者和消费者不直接通讯&#xff0c;依靠队列保证消息的可靠性&#xff0c;避免了系统间的相互影响。 主要作用&#xff1a; 业务解耦异步调用流量削峰 业务解耦 将模块间的…