深度学习——图像分类(CNN)—训练模型

训练模型

    • 1.导入必要的库
    • 2.定义超参数
    • 3.读取训练和测试标签CSV文件
    • 4.确保标签是字符串类型
    • 5.显示两个数据框的前几行以了解它们的结构
    • 6.定义图像处理参数
    • 7.创建图像数据生成器
    • 8.设置目录路径
    • 9.创建训练和验证数据生成器
    • 10.构建模型
    • 11.编译模型
    • 12.训练模型并收集历史
    • 13.绘制损失和准确率曲线
    • 14.保存图表
    • 15.保存模型到本地

1.导入必要的库

pandas as pd: Pandas是一个强大的数据分析和处理库,它提供了数据结构(如DataFrame)和工具,用于数据操作和分析。
tensorflow.keras.preprocessing.image import ImageDataGenerator: ImageDataGenerator是Keras的一部分,它用于图像数据的预处理和增强,例如,随机裁剪、旋转、缩放等。
tensorflow.keras.models import Sequential: Sequential模型是Keras中的一种模型,它允许您顺序地堆叠层。
tensorflow.keras.layers: 包含了Keras中所有的层类型,如Conv2D、MaxPooling2D、Flatten、Dense等。
tensorflow.keras.optimizers: 包含了Keras中所有的优化器类型,如Adam、SGD等。
sklearn.model_selection import train_test_split: train_test_split是Scikit-Learn的一部分,它用于将数据集分割为训练集和测试集。
numpy as np: NumPy是一个用于科学计算的库,它提供了高效的数组处理能力,对于图像处理等任务非常有用。
sklearn.preprocessing import LabelBinarizer: LabelBinarizer是Scikit-Learn的一部分,它用于将类别标签转换为二进制数组。
matplotlib.pyplot as plt: Matplotlib是一个绘图库,pyplot是其中的一个模块,它提供了一个类似于MATLAB的绘图框架。
import pickle: pickle是Python的标准库,它用于序列化Python对象,以便将它们保存到文件或从文件中加载。

import pandas as pd
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout
from tensorflow.keras.optimizers import Adam
from sklearn.model_selection import train_test_split
import numpy as np
from sklearn.preprocessing import LabelBinarizer
import matplotlib.pyplot as plt
import pickle

2.定义超参数

INIT_LR = 0.01
EPOCHS = 30
BS = 32

3.读取训练和测试标签CSV文件

train_labels.csv和test_labels.csv在资源中。

# 读取训练标签CSV文件
train_labels_filename = 'train_labels.csv'
train_labels_df = pd.read_csv(train_labels_filename)# 读取测试标签CSV文件
test_labels_filename = 'test_labels.csv'
test_labels_df = pd.read_csv(test_labels_filename)

4.确保标签是字符串类型

train_labels_df[‘label’] = train_labels_df[‘label’].astype(str):

train_labels_df['label']:这是train_labels_df DataFrame中名为label的列。
.astype(str):这是Pandas中的一个方法,用于将列的数据类型转换为字符串类型。

test_labels_df[‘label’] = test_labels_df[‘label’].astype(str):

test_labels_df['label']:这是test_labels_df DataFrame中名为label的列。
.astype(str):这是Pandas中的一个方法,用于将列的数据类型转换为字符串类型。

train_labels_df['label'] = train_labels_df['label'].astype(str)
test_labels_df['label'] = test_labels_df['label'].astype(str)

5.显示两个数据框的前几行以了解它们的结构

print(train_labels_df.head())
print(test_labels_df.head())

6.定义图像处理参数

img_width:这是一个变量,用于存储图像的宽度。
img_height:这是一个变量,用于存储图像的高度。
= 150, 150:这行代码将img_width和img_height变量分别设置为150。

img_width, img_height = 150, 150

7.创建图像数据生成器

ImageDataGenerator:这是Keras中的一个类,用于创建一个数据生成器,用于图像数据的增强和预处理。
rescale=1./255:这是一个参数,用于将图像的像素值从0到255的范围转换为0到1的范围,这是常见的图像预处理步骤。
validation_split=0.2:这是一个参数,用于指定训练数据中用于验证的比例。在这里,20%的数据将用于验证,80%的数据将用于训练。
data_gen:这是生成的ImageDataGenerator对象,它将在后续的训练过程中用于生成增强的图像数据。

data_gen = ImageDataGenerator(rescale=1./255, validation_split=0.2)

8.设置目录路径

train和test压缩文件在资源中

# 并且数据集应该存储在环境可访问的路径中
train_dir = 'D:/rgzn/face/DATASET/train'  # 包含子文件夹的父目录
test_dir = 'D:/rgzn/face/DATASET/test'    # 包含子文件夹的父目录

9.创建训练和验证数据生成器

#flow_from_dataframe:这是Keras中的一个方法,用于创建一个数据生成器,它可以从DataFrame中加载图像和标签。
train_data_gen = data_gen.flow_from_dataframe(#要加载的数据源
dataframe=train_labels_df,
#包含图像文件的目录
directory=train_dir,  
#DataFrame中包含图像路径的列名。
x_col='image',
#DataFrame中包含标签的列名。
y_col='label',
#目标图像的大小
target_size=(img_width, img_height),
#每次迭代中从数据生成器中获取的样本数量。
batch_size=32,
#随机种子,用于确保每次运行时生成相同的数据增强
seed=42,
#数据集的子集,用于训练。subset='training',
)
validation_data_gen = data_gen.flow_from_dataframe(dataframe=test_labels_df,directory=test_dir,  # 包含子文件夹的父目录x_col='image',y_col='label',target_size=(img_width, img_height),batch_size=32,
seed=42,
#数据集的子集,用于验证。subset='validation',
)

10.构建模型

# 构建模型
model = Sequential()
model.add(Conv2D(32, (3, 3), activation='relu', input_shape=(150, 150, 3)))
model.add(MaxPooling2D(pool_size=(2, 2)))# 新增的卷积层
model.add(Conv2D(64, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))model.add(Conv2D(128, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))# 展平层
model.add(Flatten())# 全连接层
model.add(Dense(128, activation='relu'))
model.add(Dropout(0.5))# 输出层
model.add(Dense(7, activation='softmax'))

11.编译模型

model.compile(loss='categorical_crossentropy',optimizer='adam',metrics=['accuracy'])

model:这是之前创建和配置的Keras模型。
compile:这是Keras中的一个方法,用于编译模型,指定训练过程中使用的损失函数、优化器和评估指标。
loss='categorical_crossentropy':这是模型使用的损失函数,适用于多类分类问题。
optimizer='adam':这是模型使用的优化器,用于调整模型的权重以最小化损失函数。
metrics=['accuracy']:这是模型使用的评估指标,用于评估模型在训练数据上的性能。

12.训练模型并收集历史

history = model.fit(train_data_gen, epochs=EPOCHS, validation_data=validation_data_gen, batch_size=BS)

fit:这是Keras中的一个方法,用于训练模型。
train_data_gen:这是之前创建的训练数据生成器。
epochs=EPOCHS:这是训练过程中重复训练数据的次数。
validation_data=validation_data_gen:这是用于验证模型的数据。
batch_size=BS:这是每次迭代中从数据生成器中获取的样本数量。
history:这是训练过程中记录的性能指标,如损失和准确率。

13.绘制损失和准确率曲线

N = np.arange(0, EPOCHS)
#设置图表的样式
plt.style.use('ggplot')
plt.figure()plt.plot(N, history.history['loss'], label='train_loss')
plt.plot(N, history.history['val_loss'], label='val_loss')
plt.plot(N, history.history['accuracy'], label='train_acc')
plt.plot(N, history.history['val_accuracy'], label='val_acc')plt.title("Training Loss And Accuracy (CNN)")
plt.xlabel('Epoch #')
plt.ylabel('Loss/Accuracy')
plt.legend()
plt.axis([0, EPOCHS, 0, 2])

14.保存图表

plt.savefig('plot.png')

15.保存模型到本地

print('[INFO] 正在保存模型')
model.save('model.h5')

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/13927.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Text2SQL 经典模型】SQLNet

论文:SQLNet: Generating Structured Queries From Natural Language Without Reinforcement Learning ⭐⭐⭐⭐ Code: SQLNet | paperwithcodeSQLNet| GitHub 一、论文速读 这篇论文强调了一个问题:order-matters problem —— 意思是说,对…

【C语言】8.C语言操作符详解(2)

文章目录 6.单⽬操作符7.逗号表达式8.下标访问[]、函数调⽤()8.1 [ ] 下标引⽤操作符8.2 函数调⽤操作符 9.结构成员访问操作符9.1 结构体9.1.1 结构的声明9.1.2 结构体变量的定义和初始化 9.2 结构成员访问操作符9.2.1 结构体成员的直接访问9.2.2 结构体成员的间接访问 6.单⽬…

2024.5组队学习——MetaGPT(0.8.1)智能体理论与实战(中):订阅智能体OSS实现

传送门: 《2024.5组队学习——MetaGPT(0.8.1)智能体理论与实战(上):MetaGPT安装、单智能体开发》《2024.5组队学习——MetaGPT(0.8.1)智能体理论与实战(下)&…

【线段图案】

描述 KiKi学习了循环,BoBo老师给他出了一系列打印图案的练习,该任务是打印用“*”组成的线段图案。 输入描述: 多组输入,一个整数(1~100),表示线段长度,即“*”的数量。 输出描述…

是德科技 DSOS054A MSOS054A示波器

产品 带宽 通道数 最大存储器深度 DSOS054A 高清晰度示波器 500 MHz 4 个模拟通道 800 Mpts MSOS054A 高清晰度示波器 500 MHz 4 个模拟通道和 16 个数字通道 800 Mpts Infiniium S 系列示波…

R语言使用 ggscidca包优雅的绘制支持向量机决策曲线

DCA(Decision Curve Analysis)临床决策曲线是一种用于评价诊断模型诊断准确性的方法,在2006年由AndrewVickers博士创建,我们通常判断一个疾病喜欢使用ROC曲线的AUC值来判定模型的准确性,但ROC曲线通常是通过特异度和敏感度来评价,…

产品研发流程(方法)

产品研发流程(方法) IPD(Integrated Product Development,集成产品开发)和敏捷开发是现代软件开发中常见的两种方法论。 IPD IPD研发流程 IPD是一种系统化的产品开发方法,旨在通过跨职能团队的协作和集成…

vue项目报错:internal/modules/cjs/loader.js:892 throw err;

前言: vue项目中无法正常使用git,并报错情况。 报错信息: internal/modules/cjs/loader.js:892throw err;^ Error: Cannot find module D:\project\sd_wh_yth_front\node_modules\yorkie\src\runner.js 报错处理: npm install y…

GitHub的原理及应用详解(二)

本系列文章简介: GitHub是一个基于Git版本控制系统的代码托管平台,为开发者提供了一个方便的协作和版本管理的工具。它广泛应用于软件开发项目中,包括但不限于代码托管、协作开发、版本控制、错误追踪、持续集成等方面。 GitHub的原理可以简单…

夏天晚上热,早上凉怎么办?

温差太大容易引起感冒 1.定个大概3点的闹钟,起来盖被子。有些土豪可以开空调,我这个咸鱼没有空调。 2.空调调到合适的温度,比如20几度。

Redisson的setConnectionMinimumIdleSize配置,设置小一点有什么影响?

设置最小空闲连接数 (connectionMinimumIdleSize) 是影响 RedissonClient 性能和资源利用的一个关键配置。让我们详细了解一下这个参数的作用及其影响。 什么是 connectionMinimumIdleSize? connectionMinimumIdleSize 参数用于定义连接池中保持的最小空闲连接数。…

【HarmonyOS4学习笔记】《HarmonyOS4+NEXT星河版入门到企业级实战教程》课程学习笔记(十一)

课程地址: 黑马程序员HarmonyOS4NEXT星河版入门到企业级实战教程,一套精通鸿蒙应用开发 (本篇笔记对应课程第 18 节) P18《17.ArkUI-状态管理Observed 和 ObjectLink》 第一件事:嵌套对象的类型上加上 Observed 装饰器…

基于网络爬虫技术的网络新闻分析(四)

目录 4.2 系统异常处理 4.2.1 爬虫异常总体概况 4.2.2 爬虫访问网页被拒绝 5 软件测试 5.1 白盒测试 5.1.1 爬虫系统测试结果 5.1.2 中文分词系统测试结果 5.1.3 中文文章相似度匹配系统测试结果 5.1.4 相似新闻趋势展示系统测试结果 5.2 黑盒测试 5.2.1 爬虫系统测…

shell编程-结构化命令

1、if-then语句 如果command命令执行成功(退出状态码为0),then部分的命令就会执行 if comand then commands fi 或 if comand: then commands fi 例如: #!/bin/bash if date who then echo this is test …

2024电工杯数学建模 - 案例:最短时间生产计划安排

# 前言 2024电工杯(中国电机工程学会杯)数学建模思路解析 最新思路更新(看最新发布的文章即可): https://blog.csdn.net/dc_sinor/article/details/138726153 最短时间生产计划模型 该模型出现在好几个竞赛赛题上,预测2022今年国赛也会与该模型相关。 1 模型描…

CoShNet:使用复数改进神经网络

使用复数改进神经网络 文章目录 一、说明二、了解卷积神经网络三、进入混合神经网络四、令人惊叹的 CoSh 网络五、复杂函数的神奇性质六、相位一致性七、结论 一、说明 本文题为“CoShNet:使用Shearlets的混合复杂值神经网络”,提出了在混合神经网络中使…

深入理解SVM和浅层机器学习算法的训练机制

深入理解SVM和浅层机器学习算法的训练机制支持向量机(SVM)的训练过程SVM的基本概念SVM的损失函数训练方法 浅层机器学习算法的训练机制决策树K-最近邻(K-NN)朴素贝叶斯 结论 深入理解SVM和浅层机器学习算法的训练机制 在探讨浅层…

nrf52832自定义蓝牙名字过长,广播显示不全

自定义蓝牙名字过长,广播显示不全 原因:nrf52832默认使用蓝牙4.x的广播,它的广播包数据只有32byte数据,当广播已经包含足够多的数据的时候,广播每次过长就会显示部分名称,即便你选择"BLE_ADVDATA_FULL_NAME"…

Python学习---基于多任务协程的并发下载器案例

目标:使用协程实现网络图片的下载(适合网络io) 多进程: 密集CPU任务,需要充分使用多核CPU资源(服务器,大量的并行计算)的时候,用多进程。 缺陷:多个进程之间通…

Python基础学习笔记(六)——列表

目录 一维列表1. 索引的查询与返回2. 切片3. 添加元素4. 删除元素5. 更改元素6. 排序7. 生成式 一维列表 列表,也称数组,是一种有序、可变、允许重复元素的组合数据结构,属于可变序列,由方括号[]内、用逗号分隔的一组元素组成。 列…