【深度学习_TensorFlow】过拟合

写在前面

过拟合与欠拟合


欠拟合: 是指在模型学习能力较弱,而数据复杂度较高的情况下,模型无法学习到数据集中的“一般规律”,因而导致泛化能力弱。此时,算法在训练集上表现一般,但在测试集上表现较差,泛化性能不佳。

过拟合: 是指模型在训练数据上表现很好,但在测试数据上表现不佳。这是由于模型过于复杂,记住了训练数据中的噪声和模式,而没有学到一般规则。本文将探讨过拟合问题以及其解决方法。

过拟合的解决方法


方法描述
增加训练数据
(More data)通过增加训练数据量,简单粗暴最有效,可以减少过拟合现象。
正则化
(regularization)在损失函数中添加一项,以惩罚模型的复杂度。常用的正则化方法包括L1正则化、L2正则化和dropout。
早停法
(Early stopping)在训练过程中,每次迭代后都会评估模型在验证集上的性能。如果性能在连续若干次迭代中没有提高,就停止训练。
数据增强
(Data augmentation)通过改变原有数据减少过拟合。例如,可以通过旋转、缩放等方式对图像数据进行增强。

写在中间

可以看到随着网络层数增加,模型变得复杂,过拟合现象变得愈发严重。接下来,我们将介绍一系列方法来帮助检测并抑制过拟合现象。

在这里插入图片描述

1. 交叉验证

增加数据集是最有效的方法,但是代价往往是昂贵的,所以要充分利用好现有的数据集。前面我们介绍了数据集需要划分为训练集和测试集,但我们为了挑选模型超参数和检测过拟合线性,一般需要将原来的训练集再次切分为新的训练集和验证集(validation set)。最终数据集被切分为 训练集、验证集、测试集。这三部分数据集的功能如下:

类别描述
训练集用于训练模型的参数,通过学习训练数据集来进行模型训练
验证集用于评估训练过程中的模型表现,调整模型的超参数
测试集用于评估最终训练好的模型在真实数据上的表现,测试验证模型的性能,评估模型的预测能力和泛化能力

验证集和测试集的区分

验证集使命:根据验证集的表现来调整模型的各种超参数的设置,提升模型的泛化能力。

测试集使命:就是检验模型的能力,其表现不能用来反馈模型的调整(就如你不能拿着期末考试原题来练习,否则期末高分就不能体现出你平时学习的真实状况),我们的办法就是从平常的练习题中抽取几道题组成验证集来检验你的能力。

这是一个将mnist手写数字识别测试集切分的例子,将6万张图像的前5万张划分为训练集,后1万张划分为验证集

(x, y), (x_test, y_test) = datasets.mnist.load_data()# 60k训练集切分为 50k训练集和 10k验证集
# (x_train, y_train), (x_val, y_val)
x_train, x_val = tf.split(x, num_or_size_splits=[50000, 10000])
y_train, y_val = tf.split(y, num_or_size_splits=[50000, 10000])# 训练集
train_db = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_db = train_db.map(preprocess).shuffle(10000).batch(128)# 验证集
val_db = tf.data.Dataset.from_tensor_slices((x_val, y_val))
val_db = val_db.map(preprocess).shuffle(10000).batch(128)# 测试集
test_db = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_db = test_db.map(preprocess).shuffle(10000).batch(128)

但是这样切分还是有局限性的,训练集只能在前5万张图像中出现,验证集只能在后1万张图像中出现,是否有方法让验证集能使用前5万张的图像呢?还真有,这就是标题所提及的**交叉验证:**我们将训练集的6万张图像划成n份,每训练取其中的 n - 1 份作为训练集,取 1 份作为验证集,而非固定前5万张为训练集,后 1 万张为验证集。

# 创建一个范围为60000的索引数组
idx = tf.range(60000)
# 随机打乱索引数组
idx = tf.random.shuffle(idx)
# 使用索引数组从x和y中获取训练集的样本
# 训练集的样本数量为50000
x_train, y_train = tf.gather(x, idx[:50000]), tf.gather(y, idx[:50000])
# 使用索引数组从x和y中获取验证集的样本
# 验证集的样本数量为10000
x_val, y_val = tf.gather(x, idx[-10000:]), tf.gather(y, idx[-10000:])

也可以不用手动划分,在训练函数中增加参数,即可自动化操作

# 使用60k训练集对网络进行训练,设置训练周期数为6
# 设置验证集占比为0.1,即数据集中10%的数据将作为验证集
# 设置每个2个训练周期进行一次验证
network.fit(train_db, epochs=6, validation_split=0.1, validation_freq=2)

2. 正则化

之所以出现过拟合的现象,是因为模型太过复杂,这里通过限制网络参数的稀疏性来约束网络的实际容量,这种约束一般通过在损失函数上添加额外的参数稀疏性惩罚项实现,常用的正则化的方式有L0、L1、L2正则化,dropout正则化。


简单介绍

关于正则化的数学原理,本人也搞不明白,也就不滥竽充数了。

有关实现简单概括就是在损失函数中引入模型权重参数的L范数,使学习到的权重参数稀疏化。

正则化的方式可以手动实现,也可以调用API实现:其中手动实现主要在计算loss值的后面,调用API主要在创建层的时候。

import tensorflow as tf
from tensorflow.keras import layers, regularizers# 网络构建
# 这会在模型损失函数中加入权重参数的L2范数作为惩罚项,力度由0.001控制。
network = Sequential([layers.Dense(256, kernel_regularizer=regularizers.l2(0.001), activation='relu'),layers.Dense(128, kernel_regularizer=regularizers.l2(0.001), activation='relu'),layers.Dense(64, kernel_regularizer=regularizers.l2(0.001), activation='relu'),layers.Dense(32, kernel_regularizer=regularizers.l2(0.001), activation='relu'),layers.Dense(10)])
# 参数构建
network.build(input_shape=(None, 28*28))
# 模型展示
network.summary()
# 截取手动前向计算的代码
for step, (x, y) in enumerate(train_db):# 创建一个 GradientTape,用于记录计算过程with tf.GradientTape() as tape:x = tf.reshape(x, (-1, 28*28))  # [b, 28, 28] => [b, 784]out = network(x)  # [b, 784] => [b, 10]y_onehot = tf.one_hot(y, depth=10)  # [b] => [b, 10]# 使用交叉熵损失函数计算 lossloss = tf.reduce_mean(tf.losses.categorical_crossentropy(y_onehot, out, from_logits=True))loss_regularization = []for p in network.trainable_variables:  # 重点在这里,遍历网络中所有的可训练参数(network.trainable_variables)。loss_regularization.append(tf.nn.l2_loss(p))  # # 对每个参数计算L2正则化项(tf.nn.l2_loss(p)),这会返回一个标量。loss_regularization = tf.reduce_sum(tf.stack(loss_regularization))  # # 将所有参数的L2正则化项求和,得到正则化损失loss_regularization。# 将损失函数定义为交叉熵损失和 L2 正则化损失的和loss = loss + 0.0001 * loss_regularization# 使用 tape 计算损失函数关于网络参数的梯度,并应用优化器进行反向传播更新参数grads = tape.gradient(loss, network.trainable_variables)optimizer.apply_gradients(zip(grads, network.trainable_variables))

正则化效果

在这里插入图片描述

Dropout


通过随机断开神经网络的连接,减少每次训练时实际参与计算的模型的参数量,但是在测试时,Dropout会恢复所有的连接,保证模型测试时获得最好的性能。

在这里插入图片描述

我们以层方式来实现以上功能

import tensorflow as tf
from tensorflow.keras import layers, regularizers# 网络构建network = Sequential([layers.Dense(256, activation='relu'),layers.Dropout(0.5),  # 有0.5的概率断开与下一层神经元的连接layers.Dense(128, activation='relu'),layers.Dropout(0.5),layers.Dense(64, activation='relu'),layers.Dense(32, activation='relu'),layers.Dense(10)])
# 参数构建
network.build(input_shape=(None, 28*28))
# 模型展示
network.summary()for step, (x, y) in enumerate(train_db):# 训练时with tf.GradientTape() as tape:out = network(x, training=True)# 测试时out = network(x, training=False)

3. Early stopping

早停法

那么如何选择合适的 Epoch 就提前停止训练(Early Stopping),避免出现过拟合现象呢?我们可以通过观察验证指标的变化,来预测最适合的 Epoch 可能的位置。具体地,对于分类问题,我们可以记录模型的验证准确率,并监控验证准确率的变化,当发现验证准确率连续𝑛个 Epoch 没有下降时,可以预测可能已经达到了最适合的 Epoch 附近,从而提前终止训练。

from tensorflow.keras.callbacks import EarlyStopping# 数据集读取···# 定义早停法回调函数
early_stopping = EarlyStopping(monitor='val_loss',  # 监视验证集losspatience=3,  # 当验证集loss在3个epoch内都没有改善则停止训练mode='min',  # 监测loss时一般设置为min,监测准确值时一般设置为maxverbose=1,  # 检测值改善时打印一条信息restore_best_weights=True  # 将权重恢复到最好的一个epoch)# 网络构建# 参数构建# 模型装配# 模型训练,添加参数network.fit(train_db, epochs=100,validation_data=val_db, validation_steps=10,callbacks=[early_stopping])

这里我们会对手写数字识别的代码再次进行修改,来使用上面提及的方法,你可以通过修改repeat的方式来复制数据集使训练数据增多,更改epochs的方式来增加训练次数,经过测试,这段代码在10 epochs 之后便达到了过拟合

import tensorflow as tf
from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics, regularizers
from tensorflow.keras.callbacks import EarlyStopping
# 处理每一张图像
def preprocess(x, y):x = tf.cast(x, dtype=tf.float32) / 255.x = tf.reshape(x, [28 * 28])y = tf.cast(y, dtype=tf.int32)y = tf.one_hot(y, depth=10)return x, y# 数据集读取
(x, y), (x_test, y_test) = datasets.mnist.load_data()# # 固定切分
# # 60k训练集切分为 50k训练集和 10k验证集
# # (x_train, y_train), (x_val, y_val)
# x_train, x_val = tf.split(x, num_or_size_splits=[50000, 10000])
# y_train, y_val = tf.split(y, num_or_size_splits=[50000, 10000])# 交叉验证切分
idx = tf.range(60000)
idx = tf.random.shuffle(idx)
x_train, y_train = tf.gather(x, idx[:50000]), tf.gather(y, idx[:50000])
x_val, y_val = tf.gather(x, idx[-10000:]), tf.gather(y, idx[-10000:])# 训练集
train_db = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_db = train_db.map(preprocess).shuffle(10000).batch(128).repeat(10)# 验证集
val_db = tf.data.Dataset.from_tensor_slices((x_val, y_val))
val_db = val_db.map(preprocess).shuffle(10000).batch(128)# 测试集
test_db = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_db = test_db.map(preprocess).shuffle(10000).batch(128)# 定义早停法回调函数
early_stopping = EarlyStopping(monitor='val_loss',  # 监视验证集losspatience=3,  # 当验证集loss在2个epoch内都没有改善则停止训练mode='min',  # 监测loss时一般设置为min,监测准确值时一般设置为maxverbose=1,  # 检测值改善时打印一条信息restore_best_weights=True  # 将权重恢复到最好的一个epoch)# 网络构建
# 正则化:在模型损失函数中加入权重参数的L2范数作为惩罚项,力度由0.001控制。
# Dropout:添加dropout层来随机断开连接
network = Sequential([layers.Dense(256, kernel_regularizer=regularizers.l2(0.001), activation='relu'),layers.Dropout(0.5),layers.Dense(128, kernel_regularizer=regularizers.l2(0.001), activation='relu'),layers.Dropout(0.5),layers.Dense(64, kernel_regularizer=regularizers.l2(0.001), activation='relu'),layers.Dense(32, kernel_regularizer=regularizers.l2(0.001), activation='relu'),layers.Dense(10)])
# 参数构建
network.build(input_shape=(None, 28*28))
# 模型展示
network.summary()
# 模型装配
network.compile(optimizer=optimizers.Adam(learning_rate=0.01),loss=tf.losses.CategoricalCrossentropy(from_logits=True),metrics=['accuracy'])
# 模型训练
network.fit(train_db, epochs=50,validation_data=val_db, validation_steps=10,callbacks=[early_stopping])
# 模型评估
print('模型评估:')
network.evaluate(test_db)

写在最后

👍🏻点赞,你的认可是我创作的动力!
⭐收藏,你的青睐是我努力的方向!
✏️评论,你的意见是我进步的财富!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/52875.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Vue快速入门以及基础标签使用

目录 开始示例el挂载点data数据对象 vue基本标签v-textv-htmlv-on计数器示例实现v-showv-ifv-bind图片切换示例v-forv-on补充v-model axios网络请求axios基本使用vue中使用axios 开始示例 1.首先在html页面中引入vue的生产环境&#xff0c;在body标签中粘上下面代码 <scrip…

visual studio 2022.NET Core 3.1 未显示在目标框架下拉列表中

问题描述 在Visual Studio 2022我已经安装了 .NET core 3.1 并验证可以运行 .NET core 3.1 应用程序&#xff0c;但当创建一个新项目时&#xff0c;目标框架的下拉列表只允许 .NET 6.0和7.0。而我在之前用的 Visual Studio 2019&#xff0c;可以正确地添加 .NET 核心项目。 …

Windows平台Unity下播放RTSP或RTMP如何开启硬解码?

我们在做Windows平台Unity播放RTMP或RTSP的时候&#xff0c;遇到这样的问题&#xff0c;比如展会、安防监控等场景下&#xff0c;需要同时播放多路RTMP或RTSP流&#xff0c;这样对设备性能&#xff0c;提出来更高的要求。 虽然我们软解码&#xff0c;已经做的资源占有非常低了…

人力资源小程序的设计与开发步骤

在当前信息化时代&#xff0c;小程序成为了各行各业提升用户体验和服务效率的重要渠道。人力资源部门也可以通过定制开发人力资源小程序来提升招聘、培训、员工福利等方面的工作效率。接下来&#xff0c;我们将介绍人力资源小程序定制系统开发的具体流程。 首先&#xff0c;我们…

[JavaWeb]【十四】web后端开发-MAVEN高级

目录 一、分模块设计与开发 1.1 分模块设计 1.2 分模块设计-实践​编辑 1.2.1 复制老项目改为spring-boot-management 1.2.2 新建maven模块runa-pojo 1.2.2.1 将原项目pojo复制到runa-pojo模块 1.2.2.2 runa-pojo引入新依赖 1.2.2.3 删除原项目pojo包 1.2.2.4 在spring-…

微软 Visual Studio 现已内置 Markdown 编辑器,可直接修改预览 .md 文件

Visual Studio Code V1.66.0 中文版 大小&#xff1a;75.30 MB类别&#xff1a;文字处理 本地下载 Markdown 是一种轻量级标记语言&#xff0c;当开发者想要格式化代码但又不想牺牲易读性时&#xff0c;Markdown 是一个很好的解决方案&#xff0c;比如 GitHub 就使用 Markdo…

Cauchy’s integral formula

见&#xff1a;https://math.mit.edu/~jorloff/18.04/notes/topic4.pdf

uniapp使用sqlite 数据库

uniapp使用sqlite 数据库 傻瓜式使用方式&#xff0c;按步骤&#xff0c;即可使用。 1.开启sqlite 在项目中manifest.json该文件中配置 2.封装数据库的调用方法 const sqlName "zmyalh" //定义的数据库名称 const sqlPath "_doc/zmyalh.db" //定义数…

macOS M1使用TensorFlow GPU加速

本人是在pycharm运行代码&#xff0c;安装了tensorflow版本2.13.0 先运行代码查看有没有使用GPU加速&#xff1a; import tensorflow as tf# Press the green button in the gutter to run the script. if __name__ __main__:physical_devices tf.config.list_physical_dev…

Sentinel dashboard无法查询到应用的限流配置问题以及解决

一。问题引入 使用sentinle-dashboard控制台 项目整体升级后&#xff0c;发现控制台上无法看到流控规则了 之前的问题是无法注册上来 现在是注册上来了。结果看不到流控规则配置了。 关于注册不上来的问题&#xff0c;可以看另一篇文章 https://blog.csdn.net/a15835774652/…

linux 性能分析之内存分析(free,vmstat,top,ps,pmap等工具使用介绍)

引言 学生时代经常听到老师和同学说到学习 linux 的重要性。但是当时看到这个命令行界面就头疼&#xff0c;也就草草地应付学了一下&#xff0c;哎嘛&#xff0c;还是游戏香&#xff01; 但是当前两天自己捣鼓服务器的时候&#xff0c;发现自己部署的一个服务总是崩溃&#x…

python将png格式的图片转换为jpg格式的图片

png图片是4通道 RGBA图像&#xff0c;具有4个通道&#xff08;红色、绿色、蓝色和透明度&#xff09;&#xff0c;用于表示彩色图像以及透明度信息。 只是简单的修改后缀&#xff0c;并不能将png格式图片改为jpg格式。 将png格式的图片转换为jpg格式的图片 确保安装了pillow库…

Visual Studio 2022的MFC框架——theApp全局对象

我是荔园微风&#xff0c;作为一名在IT界整整25年的老兵&#xff0c;今天我们来重新审视一下Visual Studio 2022下开发工具的MFC框架知识。 MFC中的WinMain函数是如何与MFC程序中的各个类组织在一起的呢&#xff1f;MFC程序中的类是如何与WinMain函数关联起来的呢&#xff1f…

UWB高精度人员定位系统源码,微服务+java+ spring boot+ vue+ mysql技术开发

工业物联网感知预警体系&#xff0c;大中小企业工业数字化转型需求的工业互联网平台 工厂人员定位系统是指能够对工厂中的人员、车辆、设备等进行定位&#xff0c;实现对人员和车辆的实时监控与调度的系统&#xff0c;是智慧工厂建设中必不可少的一环。由于工厂的工作环境比较…

Linux(基础篇一)

Linux基础篇 Linux基础篇一1. Linux文件系统与目录结构1.1 Linux文件系统1.2 Linux目录结构 2. VI/VIM编辑器2.1 vi/vim是什么2.2 模式间的转换2.3 一般模式2.4 插入模式2.4.1 进入编辑模式2.4.2 退出编辑模式 2.5 命令模式 3. 网络配置3.1 网络连接模式3.2 修改静态ip3.3 配置…

优维产品最佳实践第5期:什么是持续集成?

谈到到DevOps&#xff0c;持续交付流水线是绕不开的一个话题&#xff0c;相对于其他实践&#xff0c;通过流水线来实现快速高质量的交付价值是相对能快速见效的&#xff0c;特别对于开发测试人员&#xff0c;能够获得实实在在的收益。 本期EasyOps产品使用最佳实践&#xff0c…

先进API生产力工具eqable HTTP,一站式开发调试工具推荐

简介 Reqable是什么? Regable Fiddler/Charles Postman Reqable是HTTP一站式开发调试国产化解决方案&#xff0c;拥有更便捷的体验&#xff0c;更先进的协议&#xff0c;更高效的性能和更精致的界面。 Reqable是一款跨平台的专业HTTP开发和调试工具&#xff0c;在全平台支持…

【业务功能篇83】微服务SpringCloud-ElasticSearch-Kibanan-docke安装-应用层实战

五、ElasticSearch应用 1.ES 的Java API两种方式 Elasticsearch 的API 分为 REST Client API&#xff08;http请求形式&#xff09;以及 transportClient API两种。相比来说transportClient API效率更高&#xff0c;transportClient 是通过Elasticsearch内部RPC的形式进行请求…

基于OpenCV实战(基础知识二)

目录 简介 1.ROI区域 2.边界填充 3.数值计算 4.图像融合 简介 OpenCV是一个流行的开源计算机视觉库&#xff0c;由英特尔公司发起发展。它提供了超过2500个优化算法和许多工具包&#xff0c;可用于灰度、彩色、深度、基于特征和运动跟踪等的图像处理和计算机视觉应用。Ope…

水果flstudio好用吗?中文版FL21最新版本如何下载

FL Studio21版是一款功能强大的音乐制作软件&#xff0c;广泛应用于电子音乐、流行音乐、电影配乐等领域。它提供了丰富多样的音频合成和编辑工具&#xff0c;使音乐制作变得更加灵活多样。无论是初学者还是专业音乐制作人&#xff0c;都可以通过直观的界面和丰富的音频特效来实…