深度学习笔记(八)——构建网络的常用辅助增强方法:数据增强扩充、断点续训、可视化和部署预测

文中程序以Tensorflow-2.6.0为例
部分概念包含笔者个人理解,如有遗漏或错误,欢迎评论或私信指正。
截图和程序部分引用自北京大学机器学习公开课

要构建一个完善可用的神经网络,除了设计网络结构以外,还需要添加一些辅助代码来增强网络运行的稳定性,鲁棒性。可以用来增强的方向主要有 个,首先是数据输入前的预处理环节,其次是数据在训练过程中的优化,最后的数据在训练结束后的导出和可视化,同时能够及时保存结果和继续上一次训练在实际工作中是十分有效的。

训练的前奏曲——数据集

在前面的代码中往往是直接加载现有的数据集,然后送入网络进行学习,实际工作研究中,数据往往需要重新花费不少时间去采集、标准化和标注。对于不同的数据集一般可以有不同的处理方式。首先是最常见的图像数据,在进行图像学习之前先要制作图像数据集,一个好的数据集可以帮助我们达到事半功倍的效果。制作图像数据集时,有几个基础要求:

  • 有比较明显的特征,图像中背景信息和语义信息有比较明显的区分,语义信息就是我们关注的对象;
  • 尺寸合适,图像尺寸能够满足网络输入的基本要求,现在大部分主流手机拍照得到的照片尺寸太大,不合适在验证阶段使用;
  • 数据格式能够加载,图像数据格式最好是RGB或者灰度图;
  • 数据覆盖全面,数据源尽可能多的覆盖研究对象可能出现或者存在的场景;
  • 标注不出错,标注时不能只图快,要保证一定的精确度

当然,研究的数据不一定都是图像,还有可能是时间序列数据,是多组传感器的采集值,是一段文本等等,但是都避不开 数据格式、数据规模、标签准确度这几个关键因素。
下面还是以基础的图像处理来说明数据集处理中常用的几个方法。

数据增强,扩充数据集

当被用来的训练的数据由于成本、时间等原因受到限制的时候可以通过数据的扩充适当的将原始的数据增加出一定比例。但这不意味着可以无限制的增加数据规模,只有当已经采集到的数据达到一定规模之后,数据的扩充作为锦上添花才能起到比较好的作用。
在数据扩充之前要先解决一个问题,就是如何把数据从原始的文件夹中读取到程序中。通常可以使用python自带的os库遍历某个路径下所有符合要求格式的数据,然后以此加载数据,随后通过PIL或Pands对读取的数据进行格式化操作,最后转化为numpy数组,再导入为tensor格式就可以用来训练了。在加载数据和格式化数据时,要注意做好特征数据和标签数据的对应。
下面我们尝试从指定的本文文件中读取数据和对应的标签,首先假设我们有这样的一组数据:
在这里插入图片描述
右键,在所在位置打开终端输入:

dir/b>file_name.csv

这段代码会将当前目录下的所有文件名以此写到一个csv表格中。我们可以使用excel打开这个表格,修改其中数据的分类,或者根据自己的需要进行修改。当然不一定要使用csv,也可以使用txt后缀,这样文件就直接输出到txt文本中。
针对csv我们可以逐行或者逐列读取:

import pandas as pd
# 读取 CSV 文件
df = pd.read_csv('../Data/MNIST/file_name.csv')
# 获取列数据
train_data_file = df['train']
train_data_label = df['train_label']
test_data_file = df['test']
test_data_label = df['test_label']
# 打印列数据
print(train_data_file )
print(train_data_label)

表格数据:
在这里插入图片描述
程序输出:

0    0.png
1    1.png
2    2.png
3    3.png
4    4.png
5    5.png
Name: train, dtype: object
0    0
1    1
2    2
3    3
4    4
5    5
Name: train_label, dtype: int64

那么我们直接构建一个函数读取指定的数据内容的函数

import pandas as pd
from PIL import Image
import numpy as np
# 读取 CSV 文件
def readImage(image_path, file_path):csv_file = pd.read_csv(file_path)# 获取列数据train_data_file = csv_file['train']train_data_label = csv_file['train_label']test_data_file = csv_file['test']test_data_label = csv_file['test_label']x, y_ , t, yt_= [], [], [], []for _index in np.arange(0, train_data_file.shape[0], 1):if pd.notna(train_data_file.iloc[_index]):  # 判断如果数据非空img_ = Image.open(image_path + train_data_file[_index])img_ = np.array(img_.convert('L'))img_ = img_ / 255.  # 数据标准归一化x.append(img_)y_.append(train_data_label[_index])for _index in np.arange(0, test_data_file.shape[0], 1):if pd.notna(test_data_file.iloc[_index]):img_ = Image.open(image_path + test_data_file[_index])img_ = np.array(img_.convert('L'))img_ = img_ / 255.t.append(img_)yt_.append(test_data_label[_index])return (x, y_), (t, yt_)(train_img, train_lab), (test_img, test_lab)  = readImage(image_path='../Data/MNIST/', file_path='../Data/MNIST/file_name.csv')

把数据读取为numpy之后就可以进一步载入tf中,利用tf函数进行数据的扩充。扩充的方法主要有:随机平移、缩放、0填充、随机旋转。

img_prossess_Gen = tf.keras.preprocessing.image.ImageDataGenerator(rescale='所有数据乘以这个数(倍乘)',horizontal_flip='是否随机水平旋转 Boolean',rotation_rang='随机旋转的角度范围 Int',width_shift_range='随机宽度偏移量',height_shift_range='随机高度便宜量',zoom_range='随机缩放的范围 Float or [lower, upper].'
)
img_prossess_Gen.fit(train_img)

所以结合前面分类博客中的代码,我们可以得到一个简单的例子:

cifar10 = tf.keras.datasets.cifar10
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
datagen = ImageDataGenerator(	# 定义数据扩充项featurewise_center=True,featurewise_std_normalization=True,rotation_range=20,width_shift_range=0.2,height_shift_range=0.2,horizontal_flip=True,validation_split=0.2)
datagen.fit(x_train)	# 扩充训练数据
model.fit(datagen.flow(x_train, y_train, batch_size=32,	# 根据扩充数据并辅助分组后进行训练subset='training'),validation_data=datagen.flow(x_train, y_train,batch_size=8, subset='validation'),steps_per_epoch=len(x_train) / 32, epochs=epochs)

到此我们实现了加载数据并且扩充数据,利用扩充数据实现网络训练的操作。

训练的曲谱——过程优化

当数据量较大的时候并且网络结构比较复杂的情况下,我们比较希望能够在训练的过程中按照一定阶段保存训练模型。并且在未来某个时候重新加载数据继续训练。当然,我们也可以在一个足够大的通用数据集先进行预训练,让网络学习到数据的通用特征。然后将得到的模型导出,并重新加载细节上符合工作研究要求的数据重新开始训练。
同时在训练的过程中,也希望网络能够记住每个迭代计算的结果,并且计时的把训练过程中最好的模型保存下来。这个操作在网络训练的后期会显得十分重要。

断点续训,模型保存和读取

首先应该指定一个模型的保存路径,如果在路径中已经有需要的历史训练数据,就直接加载历史模型。值得注意,保存的动态模型格式是 ckpt

# 记录模型保存路径
checkpoint_save_path = "./checkpoint/Baseline.ckpt"
if os.path.exists(checkpoint_save_path + '.index'):print('-------------load the model-----------------')model.load_weights(checkpoint_save_path)     # 如果有模型则加载后使用

在加载之后还需要能够保存模型,这里使用keras提供的训练过程记录器,通过提供一个训练过程的回调函数来检测训练和保存模型。

# 定义保存和记录数据的回调器
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,save_weights_only=True,	# 保存模型save_best_only=True)	# 只保存最好的模型
# 使用数据扩充,并添加回调控制器,用来记录模型
history = model.fit(image_gen_train.flow(x_train, y_train, batch_size=32),epochs=5, validation_data=(x_test, y_test), validation_freq=1,callbacks=[cp_callback])
# 训练过程的历史参数可以通过 history 查看

参数查看

在上面的程序中已经实现了模型的保存。但是跟具体的还可以把网络中每个层中每个连接的权重参数偏置项,和卷积计算的结果,卷积核的参数,偏置保存下来。

# 保存网络权重参数
# print(model.trainable_variables)
file = open('./weights.txt', 'w')
for v in model.trainable_variables:		# 逐行的将网络中的所有参数写入到 weights.txt 文本文件中file.write(str(v.name) + '\n')file.write(str(v.shape) + '\n')file.write(str(v.numpy()) + '\n')
file.close()

训练过程可视化、tensorboard

如何观察参数的效果,除了在终端打印训练过程中的Loss,准确度以外,可以将这些关键的数据保存下来,这样调整参数后不同的效果就可以通过曲线图像的形式保存出来,便于观察变化趋势,指导设计者调节参数。在绘制曲线之前首先要在训练位置的函数输出 history 参数,后续通过调用这个参数中的数据在加上matplotlib来画出曲线。

# 显示训练集和验证集的acc和loss曲线
acc = history.history['sparse_categorical_accuracy']
val_acc = history.history['val_sparse_categorical_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']
plt.subplot(1, 2, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.title('Training and Validation Accuracy')
plt.legend()
plt.subplot(1, 2, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.show()

在这里插入图片描述

根据图像不难看出随着迭代次数的增加,训练集和测试集的测试准确率不断上升,损失值不断下降。此时可以增加训练的迭代次数,并微调相关参数,查看网络可能出现的不同的效果。
除了上面提及的讲训练结果导出画图的方法,还可以安装tensorboard,一般安装tensorflow2时会配套自动安装。只需要在模型训练结果保存的位置打开终端,启动对应的虚拟环境,然后输入 tensorboard,就可以在给出的网页中查看到实时的训练参数。在训练的函数中需要加入关于tensorboard的回调函数。

# 设置TensorBoard输出的回调函数
tf_callback = tf.keras.callbacks.TensorBoard(log_dir="./logs")	# 设置log文件存放地址,这里是相对路径,也可以使用绝对路径
# 使用数据扩充
history = model.fit(image_gen_train.flow(x_train, y_train, batch_size=64),epochs=8, validation_data=(x_test, y_test), validation_freq=1,callbacks=[cp_callback, tf_callback])	# 这里的回调包含保存模型和输出tensorboard训练参数

运行后再终端启动虚拟环境,输入指令加上你设置的log文件存放地址

tensorboard --logdir YOU_LOG_PATH

在这里插入图片描述
在tensorboard中可以观察模型的结构和训练中的参数数据。

训练的终章——结果可视化与数据评价

训练得到一个新的模型了,那如何使用它。这就需要再继续添加新的代码,用来单独加载模型并实现模型的前向推理,最后推理结果给出。
根据上面的例子,我们以手写数字的代码为例,可以很容易得到如下代码:

import cv2
# opencv-python==4.5.1.48
# 1. 加载图像
model_path = './checkpoint/mnist/mnist.ckpt'
image_path = '../Data/MNIST/'new_model = mnisModel()
new_model.load_weights(model_path)preNum = int(input("place input how many jpg file while be test:"))
for i in range(preNum):imgNum = int(input("place input png name:"))img_path = image_path+str(imgNum)+'.png'print("read image:{}".format(img_path))img_ = cv2.imread(img_path)resized_img = cv2.resize(img_, (28, 28), interpolation=cv2.INTER_AREA)gray_img = cv2.cvtColor(resized_img, cv2.COLOR_BGR2GRAY)# 4. 准备图像数据,进行归一化和添加批次维度cv2.imshow("input num", img_)img_for_prediction = gray_img.astype(np.float32) / 255.0  # 归一化到 [0, 1]img_for_prediction = np.expand_dims(img_for_prediction, axis=0)  # 添加批次维度result = model.predict(img_for_prediction)predNum = tf.argmax(result, axis=1)print("predice num is: ")tf.print(predNum)

通过上面的程序最终我们实现了加载已经有的模型,然后继续开始前向推理,并输出推理的结果。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/630039.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Macos系统】安装VOSviewer及使用VOSviewer教程!!以ESN网络的研究进行案例分析

【Macos系统】安装VOSviewer及使用VOSviewer教程 以ESN网络的研究进行案例分析 本文介绍如何安装和使用VOSviewer软件,并以ESN(Echo State Network)网络的研究为案例进行分析。利用VOSviewer对相关文献进行可视化分析,并深入了解…

Linux之引导和服务篇

系统引导是操作系统运行的开始,在用户能够正常登录之前,Linux的引导过程完成了一系列的初始化任务,并加载必要的程序和命令终端,为用户登录做好准备。 一. 引导过程 开机自检--->MBR引导--->GRUB菜单--->加载Linux内核-…

商汤书生大模型一次可读 30 万汉字;2023 年 Shopee Live 超100万马来人注册;2023年中国出生人口902万人;

今日精选 • 商汤“书生・浦语”2.0 大语言模型开源:200K 上下文,一次可读 30 万汉字• 2023年中国出生人口902万人• 2023 年 Shopee Live 有超 100 万马来人注册并观看直播 投融资 • Airbnb 2 亿美元收购人工智能初创公司 Gameplanner.AI• 哥伦比…

【JavaEEj进阶】 Spring实现留言板

文章目录 🎍预期结果🍀前端代码🎄约定前后端交互接⼝🚩需求分析🚩接⼝定义 🌳实现服务器端代码🚩lombok 🌲服务器代码实现🌴运⾏测试 🎍预期结果 可以发布并…

Vcast工程创建

Vcast工程创建 1.新建项目,创建工程名称 2.创建该工程下的项目组 3.设置项目组环境的名字 4.选择需要测试的源代码文件 5.选择被测文件,点击build 6.出现报错,点击报错窗口的按钮 进入报错详细页,查看详细信息 报错内容如下 Unstubbed Enti…

使用原生input模拟器样式正常,但是真机上 input框溢出

目录 一、问题 二、解决方法 三、总结 tiips:如嫌繁琐,直接移步总结即可! 一、问题 1.使用原生input写了一个搜索框,在模拟器和pc上一切正常。但是打包放到手机上,样式就有问题:这个搜索框的布局是正常的&#xf…

广州银行IPO再添堵:原董事长被查,资产质量承压,罚单频现

撰稿|行星 来源|贝多财经 广州银行的上市之路,或因前高管涉嫌违纪再添一层阴云。 前不久,广州市纪委监委披露的信息显示,广州银行原党委书记、董事长姚建军涉嫌严重违纪违法,正在接受纪律审查和监察调查。据贝多财经了解&#…

代码随想录-刷题第五十七天

42. 接雨水 题目链接:42. 接雨水 思路:本题十分经典,使用单调栈需要理解的几个问题: 首先单调栈是按照行方向来计算雨水,如图: 使用单调栈内元素的顺序 从大到小还是从小到大呢? 从栈头&…

【AI】RTX2060 6G Ubuntu 22.04.1 LTS (Jammy Jellyfish) 部署Chinese-LLaMA-Alpaca-2

下载源码 cd ~/Downloads/ai git clone --depth1 https://gitee.com/ymcui/Chinese-LLaMA-Alpaca-2 创建venv python3 -m venv venv source venv/bin/activate安装依赖 pip install -r requirements.txt 已安装依赖列表 (venv) yeqiangyeqiang-MS-7B23:~/Downloads/ai/Chi…

软件测试【测试用例设计】面试题详解

前言 今天笔者想和大家来聊聊测试用例,这篇文章主要是想要写给测试小伙伴们的,因为我发现还是有很多小伙伴在遇到写测试用例的时候无从下手,我就想和大家简单的聊聊,这篇文章主要是针对功能测试的。 一、微信功能测试 1.点击点…

QGroundControl Qt安卓环境搭建及编译出现的问题

记录Qt 5.15.2搭建安卓环境出现的各种问题。 zipalign tool not found: D:/JavaAndroid/Android/sdk/build-tools//zipalign.exe? 答:需要将DANDROID_PLATFORM升级到已下载的版本. bin/llvm-readobj.exe: error: unknown argument ‘–libs’ 答&…

07-微服务getaway网关详解

一、初识网关 在微服务架构中,一个系统会被拆分为很多个微服务。那么作为客户端要如何去调用这么多的微服务呢?如果没有网关的存在,我们只能在客户端记录每个微服务的地址,然后分别去调用。这样的话会产生很多问题,例…

链表的中间节点

链表的中间节点 力扣(LeetCode)官网 - 全球极客挚爱的技术成长平台备战技术面试?力扣提供海量技术面试资源,帮助你高效提升编程技能,轻松拿下世界 IT 名企 Dream Offer。https://leetcode.cn/problems/middle-of-the-…

【1】SM4 CBC-MAC 机制

0x01 题目 MSG1: e55e3e24a3ae7797808fdca05a16ac15eb5fa2e6185c23a814a35ba32b4637c2 MAC1: 0712c867aa6ec7c1bb2b66312367b2c8 ----------------------------------------------------- MSG2: d8d94f33797e1f41cab9217793b2d0f02b93d46c2ead104dce4bfec453767719 MAC2: 4366…

Vue3的使用

一 Vue3的变化 1.性能的提升 打包大小减少41% 初次渲染快55%, 更新渲染快133% 内存减少54% 2.源码的升级 使用Proxy代替defineProperty实现响应式 重写虚拟DOM的实现和Tree-Shaking 3.拥抱TypeScript Vue3可以更好的支持TypeScript 4.新的特性 Composition API&#…

【UE 材质】简单的纹理失真、溶解效果

目录 1. 失真效果 2. 溶解效果 3. 失真溶解 我们一开始有这样一个纹理 1. 失真效果 其中纹理节点“DistortTexture”的纹理为引擎自带的纹理“T_Noise01”,我们可以通过控制参数“失真度”来控制纹理的失真程度 2. 溶解效果 3. 失真溶解

kafka简单介绍和代码示例

“这是一篇理论文章,给大家讲一讲kafka” 简介 在大数据领域开发者常常会听到MQ这个术语,该术语便是消息队列的意思, Kafka是分布式的发布—订阅消息系统。它最初由LinkedIn(领英)公司发布,使用Scala语言编写,与2010年…

HTML---Jquery选择器

文章目录 目录 文章目录 本章目标 一.Jquery选择器概述 二.Jquery选择器分类 基本选择器 层次选择器 属性选择器 三.基本过滤选择器 练习 本章目标 会使用基本选择器获取元素会使用层次选择器获取元素会使用属性选择器获取元素会使用过滤选择器获取元素 …

SQL Server 数据类型

文章目录 一、文本类型(字母、符号或数字字符的组合)二、整数类型三、精确数字类型四、近似数字(浮点)类型五、日期类型六、货币类型七、位类型八、二进制类型 一、文本类型(字母、符号或数字字符的组合) 在…

【物联网】物联网设备和应用程序涉及协议的概述

物联网设备和应用程序涉及协议的概述。帮助澄清IoT层技术栈和头对头比较。 物联网涵盖了广泛的行业和用例,从单一受限制的设备扩展到大量跨平台部署嵌入式技术和实时连接的云系统。 将它们捆绑在一起是许多传统和新兴的通信协议,允许设备和服务器以新的…