深度学习——图像分类(CNN)—训练模型

训练模型

    • 1.导入必要的库
    • 2.定义超参数
    • 3.读取训练和测试标签CSV文件
    • 4.确保标签是字符串类型
    • 5.显示两个数据框的前几行以了解它们的结构
    • 6.定义图像处理参数
    • 7.创建图像数据生成器
    • 8.设置目录路径
    • 9.创建训练和验证数据生成器
    • 10.构建模型
    • 11.编译模型
    • 12.训练模型并收集历史
    • 13.绘制损失和准确率曲线
    • 14.保存图表
    • 15.保存模型到本地

1.导入必要的库

pandas as pd: Pandas是一个强大的数据分析和处理库,它提供了数据结构(如DataFrame)和工具,用于数据操作和分析。
tensorflow.keras.preprocessing.image import ImageDataGenerator: ImageDataGenerator是Keras的一部分,它用于图像数据的预处理和增强,例如,随机裁剪、旋转、缩放等。
tensorflow.keras.models import Sequential: Sequential模型是Keras中的一种模型,它允许您顺序地堆叠层。
tensorflow.keras.layers: 包含了Keras中所有的层类型,如Conv2D、MaxPooling2D、Flatten、Dense等。
tensorflow.keras.optimizers: 包含了Keras中所有的优化器类型,如Adam、SGD等。
sklearn.model_selection import train_test_split: train_test_split是Scikit-Learn的一部分,它用于将数据集分割为训练集和测试集。
numpy as np: NumPy是一个用于科学计算的库,它提供了高效的数组处理能力,对于图像处理等任务非常有用。
sklearn.preprocessing import LabelBinarizer: LabelBinarizer是Scikit-Learn的一部分,它用于将类别标签转换为二进制数组。
matplotlib.pyplot as plt: Matplotlib是一个绘图库,pyplot是其中的一个模块,它提供了一个类似于MATLAB的绘图框架。
import pickle: pickle是Python的标准库,它用于序列化Python对象,以便将它们保存到文件或从文件中加载。

import pandas as pd
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout
from tensorflow.keras.optimizers import Adam
from sklearn.model_selection import train_test_split
import numpy as np
from sklearn.preprocessing import LabelBinarizer
import matplotlib.pyplot as plt
import pickle

2.定义超参数

INIT_LR = 0.01
EPOCHS = 30
BS = 32

3.读取训练和测试标签CSV文件

train_labels.csv和test_labels.csv在资源中。

# 读取训练标签CSV文件
train_labels_filename = 'train_labels.csv'
train_labels_df = pd.read_csv(train_labels_filename)# 读取测试标签CSV文件
test_labels_filename = 'test_labels.csv'
test_labels_df = pd.read_csv(test_labels_filename)

4.确保标签是字符串类型

train_labels_df[‘label’] = train_labels_df[‘label’].astype(str):

train_labels_df['label']:这是train_labels_df DataFrame中名为label的列。
.astype(str):这是Pandas中的一个方法,用于将列的数据类型转换为字符串类型。

test_labels_df[‘label’] = test_labels_df[‘label’].astype(str):

test_labels_df['label']:这是test_labels_df DataFrame中名为label的列。
.astype(str):这是Pandas中的一个方法,用于将列的数据类型转换为字符串类型。

train_labels_df['label'] = train_labels_df['label'].astype(str)
test_labels_df['label'] = test_labels_df['label'].astype(str)

5.显示两个数据框的前几行以了解它们的结构

print(train_labels_df.head())
print(test_labels_df.head())

6.定义图像处理参数

img_width:这是一个变量,用于存储图像的宽度。
img_height:这是一个变量,用于存储图像的高度。
= 150, 150:这行代码将img_width和img_height变量分别设置为150。

img_width, img_height = 150, 150

7.创建图像数据生成器

ImageDataGenerator:这是Keras中的一个类,用于创建一个数据生成器,用于图像数据的增强和预处理。
rescale=1./255:这是一个参数,用于将图像的像素值从0到255的范围转换为0到1的范围,这是常见的图像预处理步骤。
validation_split=0.2:这是一个参数,用于指定训练数据中用于验证的比例。在这里,20%的数据将用于验证,80%的数据将用于训练。
data_gen:这是生成的ImageDataGenerator对象,它将在后续的训练过程中用于生成增强的图像数据。

data_gen = ImageDataGenerator(rescale=1./255, validation_split=0.2)

8.设置目录路径

train和test压缩文件在资源中

# 并且数据集应该存储在环境可访问的路径中
train_dir = 'D:/rgzn/face/DATASET/train'  # 包含子文件夹的父目录
test_dir = 'D:/rgzn/face/DATASET/test'    # 包含子文件夹的父目录

9.创建训练和验证数据生成器

#flow_from_dataframe:这是Keras中的一个方法,用于创建一个数据生成器,它可以从DataFrame中加载图像和标签。
train_data_gen = data_gen.flow_from_dataframe(#要加载的数据源
dataframe=train_labels_df,
#包含图像文件的目录
directory=train_dir,  
#DataFrame中包含图像路径的列名。
x_col='image',
#DataFrame中包含标签的列名。
y_col='label',
#目标图像的大小
target_size=(img_width, img_height),
#每次迭代中从数据生成器中获取的样本数量。
batch_size=32,
#随机种子,用于确保每次运行时生成相同的数据增强
seed=42,
#数据集的子集,用于训练。subset='training',
)
validation_data_gen = data_gen.flow_from_dataframe(dataframe=test_labels_df,directory=test_dir,  # 包含子文件夹的父目录x_col='image',y_col='label',target_size=(img_width, img_height),batch_size=32,
seed=42,
#数据集的子集,用于验证。subset='validation',
)

10.构建模型

# 构建模型
model = Sequential()
model.add(Conv2D(32, (3, 3), activation='relu', input_shape=(150, 150, 3)))
model.add(MaxPooling2D(pool_size=(2, 2)))# 新增的卷积层
model.add(Conv2D(64, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))model.add(Conv2D(128, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))# 展平层
model.add(Flatten())# 全连接层
model.add(Dense(128, activation='relu'))
model.add(Dropout(0.5))# 输出层
model.add(Dense(7, activation='softmax'))

11.编译模型

model.compile(loss='categorical_crossentropy',optimizer='adam',metrics=['accuracy'])

model:这是之前创建和配置的Keras模型。
compile:这是Keras中的一个方法,用于编译模型,指定训练过程中使用的损失函数、优化器和评估指标。
loss='categorical_crossentropy':这是模型使用的损失函数,适用于多类分类问题。
optimizer='adam':这是模型使用的优化器,用于调整模型的权重以最小化损失函数。
metrics=['accuracy']:这是模型使用的评估指标,用于评估模型在训练数据上的性能。

12.训练模型并收集历史

history = model.fit(train_data_gen, epochs=EPOCHS, validation_data=validation_data_gen, batch_size=BS)

fit:这是Keras中的一个方法,用于训练模型。
train_data_gen:这是之前创建的训练数据生成器。
epochs=EPOCHS:这是训练过程中重复训练数据的次数。
validation_data=validation_data_gen:这是用于验证模型的数据。
batch_size=BS:这是每次迭代中从数据生成器中获取的样本数量。
history:这是训练过程中记录的性能指标,如损失和准确率。

13.绘制损失和准确率曲线

N = np.arange(0, EPOCHS)
#设置图表的样式
plt.style.use('ggplot')
plt.figure()plt.plot(N, history.history['loss'], label='train_loss')
plt.plot(N, history.history['val_loss'], label='val_loss')
plt.plot(N, history.history['accuracy'], label='train_acc')
plt.plot(N, history.history['val_accuracy'], label='val_acc')plt.title("Training Loss And Accuracy (CNN)")
plt.xlabel('Epoch #')
plt.ylabel('Loss/Accuracy')
plt.legend()
plt.axis([0, EPOCHS, 0, 2])

14.保存图表

plt.savefig('plot.png')

15.保存模型到本地

print('[INFO] 正在保存模型')
model.save('model.h5')

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/13927.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Text2SQL 经典模型】SQLNet

论文:SQLNet: Generating Structured Queries From Natural Language Without Reinforcement Learning ⭐⭐⭐⭐ Code: SQLNet | paperwithcodeSQLNet| GitHub 一、论文速读 这篇论文强调了一个问题:order-matters problem —— 意思是说,对…

2024.5组队学习——MetaGPT(0.8.1)智能体理论与实战(中):订阅智能体OSS实现

传送门: 《2024.5组队学习——MetaGPT(0.8.1)智能体理论与实战(上):MetaGPT安装、单智能体开发》《2024.5组队学习——MetaGPT(0.8.1)智能体理论与实战(下)&…

【线段图案】

描述 KiKi学习了循环,BoBo老师给他出了一系列打印图案的练习,该任务是打印用“*”组成的线段图案。 输入描述: 多组输入,一个整数(1~100),表示线段长度,即“*”的数量。 输出描述…

是德科技 DSOS054A MSOS054A示波器

产品 带宽 通道数 最大存储器深度 DSOS054A 高清晰度示波器 500 MHz 4 个模拟通道 800 Mpts MSOS054A 高清晰度示波器 500 MHz 4 个模拟通道和 16 个数字通道 800 Mpts Infiniium S 系列示波…

R语言使用 ggscidca包优雅的绘制支持向量机决策曲线

DCA(Decision Curve Analysis)临床决策曲线是一种用于评价诊断模型诊断准确性的方法,在2006年由AndrewVickers博士创建,我们通常判断一个疾病喜欢使用ROC曲线的AUC值来判定模型的准确性,但ROC曲线通常是通过特异度和敏感度来评价,…

vue项目报错:internal/modules/cjs/loader.js:892 throw err;

前言: vue项目中无法正常使用git,并报错情况。 报错信息: internal/modules/cjs/loader.js:892throw err;^ Error: Cannot find module D:\project\sd_wh_yth_front\node_modules\yorkie\src\runner.js 报错处理: npm install y…

夏天晚上热,早上凉怎么办?

温差太大容易引起感冒 1.定个大概3点的闹钟,起来盖被子。有些土豪可以开空调,我这个咸鱼没有空调。 2.空调调到合适的温度,比如20几度。

【HarmonyOS4学习笔记】《HarmonyOS4+NEXT星河版入门到企业级实战教程》课程学习笔记(十一)

课程地址: 黑马程序员HarmonyOS4NEXT星河版入门到企业级实战教程,一套精通鸿蒙应用开发 (本篇笔记对应课程第 18 节) P18《17.ArkUI-状态管理Observed 和 ObjectLink》 第一件事:嵌套对象的类型上加上 Observed 装饰器…

基于网络爬虫技术的网络新闻分析(四)

目录 4.2 系统异常处理 4.2.1 爬虫异常总体概况 4.2.2 爬虫访问网页被拒绝 5 软件测试 5.1 白盒测试 5.1.1 爬虫系统测试结果 5.1.2 中文分词系统测试结果 5.1.3 中文文章相似度匹配系统测试结果 5.1.4 相似新闻趋势展示系统测试结果 5.2 黑盒测试 5.2.1 爬虫系统测…

2024电工杯数学建模 - 案例:最短时间生产计划安排

# 前言 2024电工杯(中国电机工程学会杯)数学建模思路解析 最新思路更新(看最新发布的文章即可): https://blog.csdn.net/dc_sinor/article/details/138726153 最短时间生产计划模型 该模型出现在好几个竞赛赛题上,预测2022今年国赛也会与该模型相关。 1 模型描…

CoShNet:使用复数改进神经网络

使用复数改进神经网络 文章目录 一、说明二、了解卷积神经网络三、进入混合神经网络四、令人惊叹的 CoSh 网络五、复杂函数的神奇性质六、相位一致性七、结论 一、说明 本文题为“CoShNet:使用Shearlets的混合复杂值神经网络”,提出了在混合神经网络中使…

深入理解SVM和浅层机器学习算法的训练机制

深入理解SVM和浅层机器学习算法的训练机制支持向量机(SVM)的训练过程SVM的基本概念SVM的损失函数训练方法 浅层机器学习算法的训练机制决策树K-最近邻(K-NN)朴素贝叶斯 结论 深入理解SVM和浅层机器学习算法的训练机制 在探讨浅层…

展现金融科技前沿力量,ATFX于哥伦比亚金融博览会绽放光彩

不到半个月的时间里,高光时刻再度降临ATFX。而这一次,是ATFX不曾拥有的桂冠—“全球最佳在线经纪商”(Best Global Online Broker)。2024年5月15日至16日,拉丁美洲首屈一指的金融盛会—2024年哥伦比亚金融博览会(Money Expo Colombia 2024) 于…

AI智能体|使用扣子Coze基于IDE创建自定义插件

大家好,我是无界生长。 在使用Coze的过程中,有些个性化场景无法通过插件商店已有的插件满足,这个时候就需要通过自定义插件的方式来实现业务需求。下面将通过一个实际案例来简单介绍下如何使用Coze基于IDE创建自定义插件,完成在Co…

2024最新流媒体在线音乐系统网站源码| 音乐社区 | 多语言 | 开心版

简介: 2024最新流媒体在线音乐系统网站源码| 音乐社区 | 多语言 | 开心版 下载地址 https://www.kuaiyuanya.com/product/article/index/id/33.html 图片:

使用 Django Rest Framework 构建强大的 Web API

文章目录 安装 Django Rest Framework创建序列化器创建视图和 URL 路由配置认证和权限测试 API Django Rest Framework(DRF)是一个强大的工具,用于在 Django Web 框架中构建灵活且功能丰富的 Web API。它提供了许多功能,包括序列化…

(六)DockerCompose安装与配置

DockerCompose简介 Compose 项目是 Docker 官方的开源项目,负责实现对 Docker 容器集群的快速编排。使用前面介绍的Dockerfile我们很容易定义一个单独的应用容器。然而在日常开发工作中,经常会碰到需要多个容器相互配合来完成某项任务的情况。例如要实现…

protobuf学习

学习了下protobuf这个工具,可以用来序列化数据结构,而且效率很高,数据可以压缩的更小。 记录下,我这里主要在C#里使用,从NuGet程序包安装以下两个 安装好后可以在该程序目录找到 packages\Google.Protobuf.Tools.3.26.…

在windows中使用wsl下的unbuntu环境

1 unbuntu下载编译环境 编译环境安装命令: sudo apt install gdb sudo apt install gcc sudo apt install g 2 使用vscode正常打开项目,在window中打开的项目(官方推荐将项目放在linux中的home目录) 但在windows中也可以使用&a…

汐鹤Key码查询,网站授权系统源码

汐鹤Key码查询和网站授权系统源码主要用于特殊虚拟物品销售商家。 下 载 地 址 : runruncode.com/php/19770.html 附带插件功能(网站授权),但目前开发内容较少,请谅解!同时,代码优化空间很大…