基于python开发用于深度学习模型训练过程loss值曲线的平滑处理模块

深度学习网络模型的loss曲线是训练过程中非常重要的一个监控指标,它能够直观地反映模型的学习状态以及可能存在的问题。以下是对深度学习网络模型loss曲线的详细介绍:

一、loss曲线的基本概念

在深度学习的训练过程中,loss函数用于衡量模型预测结果与实际标签之间的差异。loss曲线则是通过记录每个epoch(或者迭代步数)的loss值,并将其以图形化的方式展现出来,以便我们更好地理解和分析模型的训练过程。

二、loss曲线的解读

  1. loss值的变化趋势:
    • 如果loss值随着训练的进行而逐渐降低,说明模型正在学习并优化,这是一个正常的训练过程。
    • 如果loss值在训练初期迅速下降,但随后趋于稳定或波动较小,可能意味着模型已经收敛,或者陷入了局部最优解。
    • 如果loss值在训练过程中出现剧烈波动,可能是学习率设置不当、模型结构复杂度过高等原因导致的。
  2. 训练和验证loss的对比:
    • 训练loss和验证loss的差距可以反映模型的过拟合程度。如果训练loss持续下降而验证loss却开始上升,说明模型可能出现了过拟合现象。
    • 理想的训练过程应该是训练loss和验证loss都逐渐下降,且两者之间的差距较小。
  3. 不同阶段的loss变化:
    • 在训练初期,由于模型参数是随机初始化的,因此loss值通常会比较大。随着训练的进行,loss值会逐渐降低并趋于稳定。
    • 在训练后期,如果模型没有出现过拟合现象,loss值应该能够稳定在一个较低的水平上。

三、loss曲线的绘制与监控

在深度学习框架(如TensorFlow、PyTorch等)中,通常都提供了绘制loss曲线的功能。通过调用相应的API或库(如matplotlib、Visdom等),我们可以方便地绘制出训练过程中的loss曲线,并对其进行实时监控和分析。

四、loss曲线的优化策略

针对loss曲线反映出的问题,我们可以采取以下优化策略:

  1. 调整学习率:学习率是影响loss曲线变化的重要因素之一。如果学习率设置得过大,可能会导致loss值在训练过程中出现剧烈波动;如果学习率设置得过小,则可能会导致训练过程过于缓慢。因此,我们需要根据loss曲线的变化情况来适时调整学习率的大小。
  2. 添加正则化项:正则化项可以有效地防止模型过拟合。通过向损失函数中添加正则化项(如L1正则化、L2正则化等),我们可以限制模型参数的复杂度,从而降低过拟合的风险。
  3. 使用更复杂的模型结构:如果模型的复杂度不够高,可能无法充分拟合训练数据中的复杂模式。在这种情况下,我们可以尝试使用更复杂的模型结构(如增加网络层数、使用更复杂的激活函数等)来提高模型的拟合能力。
  4. 增加训练数据:增加训练数据可以提供更多的信息供模型学习,从而降低过拟合的风险。如果条件允许的话,我们可以尝试增加训练数据的数量或多样性来提高模型的性能。

实际工作中,经常会需要训练构建深度学习模型,相信做这块工作的同学对于loss曲线一定不会陌生的,大家肯定也都经常在模型开发过程中实际去绘制模型的loss曲线,在一些特殊的场景下需要对原始产生的loss曲线进行平滑处理,这里主要是记录实践这块的内容。

这里我以经常使用的keras框架,介绍下我常用的讲模型训练过程日志进行记录存储的方式,核心代码实现如下:

#记录日志
history = model.fit(X_train,y_train,validation_data=(X_test, y_test),#传入回调callbacks=[checkpoint],epochs=nepoch,batch_size=32,
)
print(history.history.keys())
# loss提取
lossdata, vallossdata = history.history["loss"], history.history["val_loss"]
# 绘制loss曲线
plot_both_loss_acc_pic(lossdata, vallossdata, picpath=saveDir + "train_val_loss.png"
)
history = {}
#提取训练过程对应的log
history["loss"], history["val_loss"] = lossdata, vallossdata
#存储日志数据
with open(saveDir + "history.json", "w") as f:f.write(json.dumps(history))

这里我给出history.json的样例数据,如下所示:

{"loss": [0.0239631230070447, 0.0075705342770514186, 0.004838935030165967, 0.0037340148873459002, 0.002886130001751231, 0.0024534663854011295, 0.0023201104651917924, 0.002976924244579323, 0.002085769131966776, 0.0018753843622715224, 0.0019806173175960172, 0.002174197305382795, 0.001658159012761194, 0.001545024904081888, 0.001667008826952705, 0.0013947380403409929, 0.0012537746476829388, 0.0014786866023657216, 0.0016623390785946131, 0.0016191040555174983, 0.0014966548395261134, 0.001477120676483648, 0.0016280919435364668, 0.0017182350213351422, 0.0038554738028685545, 0.0027464564262130392, 0.0017087835722348526, 0.0014510032096478255, 0.001268975875018749, 0.001481830868523139, 0.001604654047318627, 0.0011948789410770326, 0.001490574051798416, 0.001524109376014187, 0.0015062743931394868, 0.0013054789145924908, 0.0011241542828178905, 0.0010764475793075279, 0.0011480460991939996, 0.0012678029520214276, 0.0012396599495106504, 0.0011639709738934618, 0.0012134075943700145, 0.0012499850020485322, 0.001329989843023205, 0.0011846670753724083, 0.001357856133803473, 0.0015265580290890642, 0.0012421558107066537, 0.001249045898042552, 0.0013697822622925414, 0.0010749583650784015, 0.0010974660338928532, 0.0010916401195769782, 0.0010911698460223627, 0.001078350035803241, 0.001045568893730859, 0.001084814094107926, 0.0011569271895574074, 0.0011443737715600441, 0.001247118570911225, 0.0012540589338988402, 0.0011518743058927274, 0.0015513227900919035, 0.0017111857056945697, 0.0015170943725414776, 0.001481423723410487, 0.0011165965530857377, 0.0016210588698042031, 0.002381780790270182, 0.0011541179547393296, 0.0013710694562572288, 0.0012280710985459404, 0.001037340645381916, 0.0010694707121014697, 0.0009750368017871479, 0.001008019566722004, 0.0011101727457727573, 0.0012511928422843225, 0.001071397447170223, 0.0011470449074543591, 0.0015238439674756194, 0.0010109543884446336, 0.0011297101726488506, 0.001058421874235954, 0.001103364821769398, 0.001025826505723811, 0.0010999036314539848, 0.001329398845137427, 0.0017114742325290903, 0.0011102726525873048, 0.0011274378092930091, 0.0011542693009646294, 0.0011940637438370937, 0.0012636104229160712, 0.0013925317771055863, 0.00100061368093664, 0.0011615896552567776, 0.0010081333990953022, 0.001092779955855081], "val_loss": [0.004923976918584422, 0.00991965542106252, 0.01076323433877214, 0.003843901578434988, 0.011352231488318036, 0.0016448196832482753, 0.0016787166668923179, 0.02015221753696862, 0.003941944209049995, 0.0029026116281257648, 0.0045372380556440665, 0.0014563935330155992, 0.001654032355260202, 0.0013641683946173688, 0.0015195327850769421, 0.0011578488043405262, 0.0012232080662598539, 0.00509458419768826, 0.0012246073744455843, 0.0023663273782738923, 0.0011423173363590124, 0.006865146876263775, 0.0020036918448137217, 0.007316410553788668, 0.001553758288929729, 0.0013593508740948317, 0.0025380967877266045, 0.0023082743653120765, 0.0013224915555359697, 0.00858367411909919, 0.0009927515703326974, 0.0010470627885201553, 0.0011798253622959907, 0.0024045295798905976, 0.0015412836871722614, 0.0038771789925368992, 0.0015362703578399592, 0.001756014192697445, 0.00334732801114258, 0.000975109149983741, 0.0046767660281866, 0.0018946981394516401, 0.0021767043220614524, 0.004211987026869074, 0.0009522750635177975, 0.0021094563270085734, 0.0037733877482088772, 0.001548874757549799, 0.0027838850510306656, 0.008273044527557335, 0.00123940688829048, 0.0016841785786183257, 0.0009756766973479994, 0.001928586675479126, 0.0011492695222075685, 0.0012013394433827336, 0.0010477521618380897, 0.00121309975940293, 0.0030147337820380926, 0.0013649057897150909, 0.0023210468165895067, 0.0011219763923068775, 0.0017544153219971217, 0.0030385015789713516, 0.0016239398731206739, 0.0031037202962723217, 0.002162101651590906, 0.003717466969484169, 0.0033957000386803165, 0.0009902583321826043, 0.00193247984708777, 0.001976198960389746, 0.0027693257654870028, 0.0025635553493262514, 0.0013357499459648113, 0.0012082410958675226, 0.001168333794186382, 0.0025652966841957286, 0.0010059437105634347, 0.0009358364489775054, 0.0036403173617528457, 0.0009317236960142556, 0.0015049418612187238, 0.0017247698554695634, 0.0010254738238903596, 0.001047871537096063, 0.0009514076437641818, 0.0036001800608478096, 0.014663169037942824, 0.002012193938227076, 0.001970677826351388, 0.0037272977164799445, 0.0012484785829560438, 0.002330363199182198, 0.0011025683723059237, 0.0013020975239525893, 0.001059662765137067, 0.0009167807317632986, 0.0009290355350362676, 0.0012791703282058924]}

接下来看下原始loss数据绘制出来的对比可视化曲线:

整体波形不断,也反映了模型实际训练过程并不够稳定,这里抛开模型训练的因素,单纯地基于曲线数据进行分析,想要对其进行平滑处理,得到的效果如下:

核心实现就是使用scipy.signal.savgol_filter方法,scipy.signal.savgol_filter 是 SciPy 库中的一个函数,用于对一维数据序列应用 Savitzky-Golay 平滑滤波器。这个滤波器是一种局部多项式回归的技术,能够在平滑数据的同时尽量保留数据的特征形状,如峰值和谷值,因此特别适用于信号去噪和数据平滑处理,尤其是那些包含噪声的实验数据或时序数据。

以下是 savgol_filter 函数的主要参数及其说明:

  • x (array_like):要过滤的一维数据序列。如果 x 不是单精度或双精度浮点数数组,它将在过滤前被转换为这种类型。

  • window_length (int):滤波器窗口的长度,即应用于数据点上的局部多项式拟合所使用的相邻数据点的数量。这个值必须是奇数,并且 polyorder + 1 <= window_length

  • polyorder (int):拟合局部数据点的多项式的阶数。它决定了平滑程度,阶数越高,可以拟合更复杂的曲线,但也会更多地改变原始数据的特性。必须满足 polyorder < window_length

  • deriv (int, 可选):指定是否计算导数以及计算哪阶导数。默认为0,表示直接平滑数据;大于0的值用于计算相应阶数的导数。

  • delta (float, 可选):采样点之间的间距,默认为1.0。仅在计算导数(deriv > 0)时使用。

  • axis (int, 可选):当输入数据 x 的维度大于1时,指定沿哪个轴应用滤波器。默认为-1,表示最后一个轴。

  • mode (str, 可选):决定如何处理边界效应,可选值有 'mirror''constant''nearest''wrap''interp'。默认为 'interp',表示通过线性插值来扩展数据以处理边界。选择 'mirror' 会在边界处镜像数据,而 'constant' 则会使用边缘值填充。

  • cval (float, 可选):当 mode'constant' 时使用的常数值。默认为0.0。

使用示例

假设我们有一个包含噪声的一维数据列表 data,我们可以使用 savgol_filter 来平滑这些数据:

from scipy.signal import savgol_filter
import numpy as np# 假设 data 是一个包含噪声的数据序列
data = np.random.randn(100)  # 生成随机噪声数据作为示例
window_length = 5  # 窗口长度
polyorder = 3  # 多项式阶数# 应用 Savitzky-Golay 滤波器
smoothed_data = savgol_filter(data, window_length, polyorder)# 然后可以绘制原始数据和平滑后的数据进行对比
import matplotlib.pyplot as pltplt.figure()
plt.plot(data, label='Noisy data')
plt.plot(smoothed_data, label='Smoothed data')
plt.legend()
plt.show()

借助于scipy.signal.savgol_filter方法,我们可以非常方便快捷地实现对原生loss曲线的平滑化处理,这里为了直观对比效果,我们绘制对比可视化曲线,如下所示:

有需要的也都可以尝试下。

完整代码实现如下:

def lossPloter(train_loss,val_loss):"""loss曲线对比可视化"""iters = range(len(train_loss))#单独绘制原始loss曲线plt.clf()plt.figure(figsize=(10,6))plt.plot(iters, train_loss, 'red', linewidth = 2, label='train loss')plt.plot(iters, val_loss, 'coral', linewidth = 2, label='val loss')plt.xlabel('Epoch')plt.ylabel('Loss')plt.title('A Loss Curve')plt.legend(loc="upper right")plt.savefig("original_loss.png")num = 5 if len(train_loss)<25 else 15#插值平滑处理train_loss_smooth=scipy.signal.savgol_filter(train_loss, num, 3)val_loss_smooth=scipy.signal.savgol_filter(val_loss, num, 3)for i in range(5):val_loss_smooth=scipy.signal.savgol_filter(val_loss_smooth, num, 3)#二者同时绘制plt.clf()plt.figure(figsize=(10,6))plt.plot(iters, train_loss, 'red', linewidth = 2, label='train loss')plt.plot(iters, val_loss, 'coral', linewidth = 2, label='val loss')plt.plot(iters, train_loss_smooth, 'green', linestyle = '--', linewidth = 2, label='smooth train loss')plt.plot(iters, val_loss_smooth, '#8B4513', linestyle = '--', linewidth = 2, label='smooth val loss')plt.xlabel('Epoch')plt.ylabel('Loss')plt.title('A Loss Curve')plt.legend(loc="upper right")plt.savefig("compare_loss.png")plt.cla()plt.close("all")#单独绘制平滑曲线plt.clf()plt.figure(figsize=(10,6))plt.plot(iters, train_loss_smooth, 'green', linewidth = 2, label='smooth train loss')plt.plot(iters, val_loss_smooth, 'blue', linewidth = 2, label='smooth val loss')plt.xlabel('Epoch')plt.ylabel('Loss')plt.title('Loss Curve')plt.legend(loc="upper right")plt.savefig("smooth_loss.png")

会得到三幅图像:
original_loss.png: 原始loss对比曲线

smooth_loss.png: 平滑化的loss对比曲线

compare_loss.png: 二者对比曲线

感兴趣的话可以尝试下!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/19471.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

0521_网络编程5

练习1&#xff1a; TFTP通信过程总结 服务器在69号端口等待客户端的请求服务器若批准此请求&#xff0c;则使用 临时端口 与客户端进行通信。每个数据包的编号都有变化&#xff08;从1开始&#xff09;每个数据包都要得到ACK的确认&#xff0c;如果出现超时&#xff0c;则需要…

骑车不戴头盔监测摄像机

骑行是一种健康的出行方式&#xff0c;但是在骑行途中不戴头盔存在安全隐患&#xff0c;容易造成头部受伤。为了规范骑行行为&#xff0c;保障骑行安全&#xff0c;可以考虑使用骑车不戴头盔监测摄像机进行监测和识别。这种摄像机可以通过智能识别技术&#xff0c;实时监测骑自…

装机数台,依旧还会心念i5-12600KF的性能和性价比优势:

近几个月的时间中&#xff0c; 装机差不多4台电脑&#xff0c;由于工作需要&#xff0c;计划年中再增添一台。 目前市场上英特尔CPU促销非常火爆&#xff0c;第12代、第13代以及第14代的产品在年中有适当的优惠。 年中也是装机的旺季&#xff0c;各种相关配件也相对便宜一些。…

PS系统教学02

多个图片同时进行打开 在素材库里面选中两张图片&#xff0c;直接拖进PS软件中&#xff0c;此时会显示其中一张。当按下回车键会显示另一张。 当图层过多&#xff0c;需要进行选择&#xff0c;其中某一张图片&#xff0c;按住Ctrl键&#xff0c;进行选择点击&#xff0c;可以移…

制造企业如何通过PLM系统实现BOM管理的飞跃

摘要 在当今快速变化的制造行业中&#xff0c;产品生命周期管理&#xff08;PLM&#xff09;系统的应用已成为企业提升效率、降低成本和增强竞争力的关键。本文将探讨PLM系统如何通过其先进的BOM&#xff08;物料清单&#xff09;管理功能&#xff0c;帮助制造企业在整个产品生…

idea+tomcat+mysql 从零开始部署Javaweb项目(保姆级别)

文章目录 新建一个项目添加web支持配置tomcat优化tomcat的部署运行tomcatidea数据库连接java连接数据库 新建一个项目 new project&#xff1b;Java&#xff1b;选择jdk的版本&#xff1b;next&#xff1b;next&#xff1b;填写项目名字&#xff0c;选择保存的路径&#xff1b;…

【Linux 网络编程】协议的分层知识!

文章目录 1. 计算机网络背景2. 认识 "协议"3. 协议分层 1. 计算机网络背景 网络互联: 多台计算机连接在一起, 完成数据共享; &#x1f34e;局域网&#xff08;LAN----Local Area Network&#xff09;: 计算机数量更多了, 通过交换机和路由器连接。 &#x1f34e; 广…

基于 Arm 虚拟硬件的 TinyMaix 超轻量级神经网络推理框架的项目实践

本实验过程中所显示的优惠价格及费用报销等相关信息仅在【Arm AI 开发体验创造营】体验活动过程中有效&#xff0c;逾期无效&#xff0c;请根据实时价格自行购买和体验。同时&#xff0c;感谢本次体验活动 Arm 导师 Liliya 对于本实验手册的共创与指导。 详见活动地址&#xff…

Vue使用axios实现调用后端接口

准备后端接口 首先&#xff0c;我已经写好一个后端接口用来返回我的用户数据&#xff0c;并用Postman测试成功如下&#xff1a; 以我的接口为例&#xff0c;接口地址为&#xff1a;http://localhost:8080/user/selectAll 返回Json为&#xff1a; {"code": "2…

docker制作高版本jdk17镜像踩坑

1、创建目录并下载jdk上传到服务器中 从jdk官网下载jdk17镜像&#xff0c;提示&#xff1a;下载到本地用xftp上传到服务器&#xff08;速度会快点&#xff09; jdk官网&#xff1a;https://www.oracle.com/java/technologies/downloads/#graalvmjava21 创建目录&#xff0c;将…

Ubuntu系统编译内核——deb安装 / install安装

摘要 本文简要记录两种编译内核的方法&#xff1a; 打包成deb模块安装&#xff08;推荐&#xff09;&#xff1b;直接make install安装&#xff1b; 更推荐使用——打包成deb模块安装&#xff0c;因为可以方便的拷贝下次其他机器使用。 1. 编译环境准备 系统&#xff1a;lin…

强化学习——学习笔记3

一、强化学习都有哪些分类&#xff1f; 1、基于模型与不基于模型 根据是否具有环境模型&#xff0c;强化学习算法分为两种&#xff1a;基于模型与不基于模型 基于模型的强化学习(Model-based RL)&#xff1a;可以简单的使用动态规划求解&#xff0c;任务可定义为预测和控制&am…

cesium 实现自定义弹窗并跟随场景移动

cesium 添加点位自定义弹窗跟随场景移动 完整代码演示可直接copy使用 1 效果图&#xff1a; 2 深入理解 就是原始点位的数据 id>property 点位真实渲染到球体上的笛卡尔坐标系 id>_polyline 的路径下 可以通过 3 代码示例 <!DOCTYPE html> <html lang"…

【数据分享】2017-2023年全球范围10米精度土地覆盖数据

土地覆盖数据是我们在各项研究中都非常常用的数据&#xff0c;土地覆盖数据的来源也有很多。之前我们分享过欧空局发布的2020年和2021年的10米分辨率的土地覆盖数据,也分享过我国首套1米分辨率的土地覆盖数据&#xff08;均可查看之前的文章获悉详情&#xff09;&#xff01; …

管道液位传感器可以检测哪些液体?

管道液位传感器是一种专门用于检测流动性比较好的液体的传感器装置。它采用光学感应原理&#xff0c;不涉及任何机械运动&#xff0c;具有长寿命、安装方便和微功耗的特点。相比传统机械式液位传感器&#xff0c;光电管道传感器有效解决了低精度和卡死失效等问题&#xff0c;同…

Django 解决 CSRF 问题

在 Django 出现 CSRF 问题 要解决这个问题&#xff0c;就得在 html 里这么修改 <!DOCTYPE html> <html><head></head><body><form action"/login/" method"post">{% csrf_token %}</form></body> </…

短视频脚本创作的五个方法 沈阳短视频剪辑培训

说起脚本&#xff0c;我们大概都听过影视剧脚本、剧本&#xff0c;偶尔可能在某些综艺节目里听过台本。其中剧本是影视剧拍摄的大纲&#xff0c;用来指导影视剧剧情的走向和发展&#xff0c;而台本则是综艺节目流程走向的指导大纲。 那么&#xff0c;短视频脚本是什么&#xf…

探析GPT-4o:技术之巅的跃进

如何评价GPT-4o? 简介&#xff1a;最近&#xff0c;GPT-4o横空出世。对GPT-4o这一人工智能技术进行评价&#xff0c;包括版本间的对比分析、GPT-4o的技术能力以及个人感受等。 随着人工智能领域的不断发展&#xff0c;GPT系列模型一直处于行业的前沿。最近&#xff0c;GPT-4…

前端实习记录——git篇(一些问题与相关命令)

1、版本控制 &#xff08;1&#xff09;版本回滚 git log // 查看版本git reset --mixed HEAD^ // 回滚到修改状态&#xff0c;文件内容没有变化git reset --soft HEAD^ // 回滚暂存区&#xff0c;^的个数代表几个版本git reset --hard HEAD^ // 回滚到修改状态&#xff…

生态农业:引领未来农业新篇章

生态农业&#xff0c;正以其独特的魅力和创新理念&#xff0c;引领着未来农业发展的新篇章。在这个充满变革的时代&#xff0c;我们需要更加关注农业的可持续发展&#xff0c;而生态农业正是实现这一目标的重要途径。 生态农业产业的王总说&#xff1a;生态农业强调生态平衡和可…