深度学习实际使用经验总结

以下仅是个人在使用过程中的经验总结,请谨慎参考。

常用算法总结

图像分类

  • 常用算法(可作为其他任务的骨干网络):
  • 服务端:VGG、ResNet、ResNeXt、DenseNet
  • 移动端:MobileNet、ShuffleNet等
  • 适用场景:识别区分场景类型

目标检测

  • 常用算法:Yolo系列
  • 适用场景:检测识别场景中的目标类型及位置

目标跟踪

  • 单目标:SiamFC、SiamRPN、SiamRPN++
  • 多目标:ByteTrack、deep sort
  • 适用场景:在上下文连续多帧中持续定位目标位置变化,如跟踪人的位置变化

人脸识别

  • 人脸检测:MTCNN、RetinaFace
  • 特征提取:ArcFace
  • 特征匹配:Nmslib、Annoy
  • 适用场景:直播、点播场景下人脸坐标定位和人物身份识别

度量学习

  • 常用算法:神经网络模型 + Contrastive Learning
  • 适用场景:将输入转换为高维特征,根据特征相似度完成具体任务,如音乐识别等

音频识别

  • 常用算法:音频频谱 + 神经网络模型
  • 适用场景:如识别音频中是否出现音乐,以及出现的是哪一首音乐

传统图像算法

  • 常用算法:边缘检测、直线检测、透视变化、坐标映射
  • 适用场景:如识别足球场地边缘,根据球员实时位置构建赛场态势图

常用技术总结

网络设计

  • 卷积网络
  • 池化
  • 全连接网络
  • 残差网络
  • 分组卷积
  • 深度可分离卷积
  • Inception网络
  • BatchNorm/LayerNorm
  • 循环神经网络
  • Transformer网络

特征融合

  • FPN特征金字塔
  • PAN特征融合
  • UNet细粒度特征融合
  • Attention全局注意力
  • 多分支融合叠加

激活函数

  • Sigmoid
  • ReLU
  • Tanh
  • Leakey-ReLU
  • PReLU
  • Swish

优化算法

  • 随机梯度下降
  • 批量梯度下降
  • 小批量梯度下降
  • 动量梯度下降
  • 自适应梯度下降

损失函数

  • 单分类损失Logistic Loss
  • 多分类损失CE
  • 多标签分类损失BCE
  • 类别不均衡损失Focal Loss
  • 回归损失MSE、L1 Smoothing
  • 多任务损失
  • Contrastive Loss
  • Label Smoothing

数据增强

  • 平移/旋转/翻转/缩放
  • 添加高斯噪声
  • 亮度/对比度/饱和度变换
  • 数据融合MixUp
  • 随机掩码Cutout
  • 频谱掩码(音频数据增强)

模型压缩

  • 模型剪枝
  • 模型重参数化
  • 模型蒸馏
  • 模型量化

训练策略

  • 迁移学习
  • 权重衰减
  • 学习率衰减
  • WarmUp预热
  • DropOut/DropBlock
  • 并行训练
  • Early Stopping

使用经验总结

数据预处理:在数据预处理阶段,通常包含以下流程

  • 数据加载:这部分可能涉及到大规模训练数据的高性能加载,消除数据读取造成的性能瓶颈
  • 数据增强:根据具体任务选择合适的数据增强策略
  • 数据归一化:将数据的数值做归一化处理,加速模型的收敛速度

网络模型设计:根据具体的业务场景,选择/设计合适的网络架构,通常遵循以下原则:

  • 优先选择开源的预训练模型参数,使用自有数据进行微调训练
  • 使用开源预训练模型的基础部分,修改模型的上层适配自有的业务场景
  • 根据业务场景自行设计模型架构(通常只在没有开源参考模型的条件下使用,模型效果很难得到保证)
  • 通常情况下对于图像、视频数据,追求模型效果一般使用ResNet34、ResNet50等模型架构,追求处理性能一般使用ResNet18、MobileNet、ShuffleNet模型架构

损失函数设计:根据具体业务场景,选择/设计合适的损失函数,通常遵循以下原则:

  • 分类场景一般使用交叉熵损失Cross Entropy
  • 类别极度不均衡的分类场景(不均衡比例超过1000以上),分类损失尝试Focal Loss
  • 在分类损失场景下使用Label Smoothing软化模型的学习能力
  • 回归损失根据使用场景可选MSE、MAE、L1 Smoothing等
  • 表征学习(也叫对比学习或度量学习)优先使用Circle Loss、ArcFace Loss、Triplet Loss等损失函数,模型训练时根据具体业务场景适当增大训练批量batch size,提升表征学习的特征区分度

训练策略:迭代训练更新模型参数,逐步提升模型效果

  • 优先使用小批量梯度下降 + Momentum动量 + LR decay学习率衰减 + Weight Decay权重衰减训练策略
  • 如果使用多GPU并行训练,训练的批量大小batch size和学习率LR要同比例改变
  • 先使用极少的训练数据验证模型训练效果,排除模型在设计上的问题和工程化问题,然后再迁移到大数据量,避免造成无效的资源和时间浪费
  • 模型训练过程中保存每轮迭代预测异常的数据,通过bad case分析,逐步提升数据质量或调整模型策略
  • 保存模型每轮迭代的准确率、召回率、loss变化曲线,监控是否发生过拟合、欠拟合等问题

模型验证:使用实际场景数据验证模型的预测性能,包括效果和速度

  • 优先检查网络架构和模型参数的匹配情况,防止参数不匹配带来的潜在错误(问题较难定位)
  • 使用训练数据验证模型的预测效果,排除训练和验证阶段数据处理不一致带来的差异
  • 保证验证数据和训练数据具有相同或相似的数据分布
  • 将模型转换为推理模式,固化BatchNorm、Dropout等具有随机性的操作
  • 避免不必要的资源消耗,如使用torch.no_grad避免不必要的显存占用
  • 关注模型在推理阶段的显卡内存占用,保证分配的内存资源大于波动的峰值

常见问题和思考——模型训练阶段

模型训练效果差

  • 常见问题:神经网络模型在训练过程中,没有学习到有效的信息,模型收敛慢或者不收敛
  • 原因分析:
    • 训练数据中存在脏数据(首要排除的因素)
    • 没有使用预训练模型和数据归一化
    • 训练集过大,模型容量小,出现欠拟合
    • 训练集过小,模型容量大,出现过拟合
    • 学习率设置过大,导致参数更新过快,结果出现震荡
    • 学习率设置过小,导致参数更新过慢,学习进展缓慢
    • 数值稳定性问题导致数值溢出,出现梯度爆炸
    • 网络设计问题导致梯度消失
    • 使用了不合理的批量大小
    • 使用了不合理的训练迭代次数及停止策略
    • 使用了不合理的学习率衰减策略
    • 使用了不合理的参数初始化策略
    • 使用了不合理的数据增强,比如检测人体时对图像做了上下翻转
    • 使用了不合理的损失函数,比如分类问题使用回归损失函数
    • 使用了不合理的网络架构,根据一维、二维、多维、时序数据选择合适的架构
    • 使用了较强的正则化,限制模型的学习能力,出现欠拟合
    • 使用了较弱的正则化,过渡学习训练集,出现过拟合

模型训练速度慢

  • 常见问题:神经网络模型训练慢,主要体现在收敛慢、资源使用率低等方面
  • 原因分析:
    • 没有使用GPU资源进行模型训练
    • 模型训练时使用过小的学习率,导致参数更新慢
    • 模型训练时使用了过大的batch size,导致参数更新频次少
    • 模型训练时使用了过小的batch size,没有充分利用计算资源
    • 没有使用预训练参数,重头训练效率低
    • 没有使用Batch Norm等做数据归一化
    • 训练过程中存在过多的内存、磁盘访问
    • 模型过于简单,不具备学习复杂任务的能力
    • 使用了过强的正则化,限制了模型的学习能力
    • 没有使用单机多卡、多机多卡并行训练
    • 模型复杂度高,参数量大,以及使用过大的图像分辨率
    • 训练数据量大,数据加载成为性能瓶颈

常见问题和思考——模型预测阶段

模型预测指标

  • 分类模型:准确率、召回率、PR曲线、ROC曲线(类别不均衡)
  • 检测模型:mAP、准确率、召回率
  • 特征提取模型:top 1、top K、r_precision

模型预测效果差

  • 常见问题:神经网络模型训练效果较好,但是在预测阶段模型表现较差
  • 原因分析:
    • 模型训练和测试的数据处理pipeline不一致,比如训练时做了Normalize,测试时没做Normalize
    • 模型在测试时没有切换到推理模式,如pytorch中的eval()转换
    • 输入的数据维度不正确,比如训练时使用[N, C, H, W],测试时也要使用同样的数据维度顺序,有些模型即使输入的数据尺寸和训练时不一样也不会报错
    • 模型参数加载不完全,以pytorch框架为例,加载模型时设置完全匹配的参数为False,在加载过程中即使参数和模型不匹配也不会报错,但是会使用默认的随机参数

模型预测资源占用高

  • 常见问题:神经网络模型在预测阶段GPU使用率低,CPU使用率高,或者出现显卡内存溢出
  • 原因分析:
    • 数据预处理在CPU上进行,没有充分利用GPU算力
    • 在预测阶段模型没有设置成推理模式,计算产生无用的中间结果占用资源
    • 深度学习框架如Pytorch自动搜索最优算子导致显存占用短暂飙升

模型处理未知类别

  • 常见问题:如何让分类模型对未见过的数据类别说“不知道,不认识”,提升鲁棒性
  • 解决方案:
    • 给分类模型添加一个其他类别,此种方法不适用于真实开放环境
    • 使用BCE(Binary Cross Entropy)多标签二分类损失函数,以猫狗分类为例,分别输出是猫狗的概率,如果输出既不是猫也不是狗,则表示未知
    • 使用表征学习(度量学习)方法通过特征匹配进行分类识别,将输入的数据与已知类别进行相似度匹配

模型效果提升

  • bad case分析:数据决定了模型的上限,提升模型性能首先应当从数据层面入手,避免脏数据带来的负面影响
  • 数据增强:通过给训练数据增加异常扰动提升模型效果的鲁棒性
    • 图像数据增强:随机裁剪、旋转、翻转、缩放、颜色/亮度/对比度调整、多图像融合等
    • 音频数据增强:频谱随机连续掩码、多频谱融合等
  • 选择合适的网络架构:针对具体的使用场景选择合适的网络架构,如图像使用2D卷积、音频使用1D卷积等
  • 选择合适的损失函数:针对具体任务使用合适的损失函数,常见的是分类损失、回归损失及混合损失
  • 选择合适的评价指标:如分类使用准确率、召回率,类别不均衡使用ROC,目标检测使用mAP等
  • 使用开源预训练模型:预训练模型通常具备较好的参数基础,在此基础上进行训练有助于提升性能
  • 多模型融合:使用不同的弱模型训练多个效果稍弱的模型,融合多个模型结果提升最终性能
  • 知识蒸馏:使用知识蒸馏将大模型学习到的知识迁移到小模型,在提升效果的同时还可以提升速度

常见问题与思考——模型推理加速

模型推理加速

  • 数据处理层面
    • 将数据预处理、后处理等操作在GPU上进行
    • 避免在CPU和GPU之间频繁进行数据拷贝
    • 避免保存大量的中间结果,如磁盘写入
  • 网络模型层面
    • 模型剪枝:根据具体业务场景,裁剪掉与业务不相关的计算模块
    • 算子融合:将多个算子融合成一个,减少内存访问次数,如将Conv + BN融合成Conv
    • 半精度推理:使用float16进行模型推理
    • 分组卷积:使用分组卷积降低网络模型的参数量和计算量,但是会增加内存访问次数
    • 模型蒸馏:将大模型学习到的知识迁移到小模型,在提升效果的同时还可以提升速度
  • 部署工具层面
    • 服务端:ONNX、TensorRT
    • 移动端:NCNN、MNN

影响模型速度的主要因素

  • 数据处理层面
    • 没有使用GPU资源进行推理
    • 输入数据尺寸大,如高分辨率图像
  • 模型架构层面
    • 网络模型参数量大
    • 网络模型计算量大
    • 网络模型并行度低,如存在多分支结构等
    • 内存访问次数多,如大量使用分组卷积
  • 工程实现层面
    • 没有使用批量推理,没有充分利用GPU的并行计算能力
    • 不合理的数据复用导致频繁拷贝

常见问题与思考——工程化问题

工程化常见问题

  • 算法集群扩容
    • kafka topic的partition数量要大于算法消费者数量
  • 直播场景混流和AI算法结果对齐
    • 算法解码处理直播流内容,获取每帧的dts、pts时间戳
    • 混流侧根据算法返回结果以及dts、pts时间戳对齐到原流,将算法结果压制到直播流中
    • 算法侧和混流侧使用pts解码时间戳对齐,不用dts时间戳,否则会造成画面闪烁
  • Kafka、Redis数据保存周期
    • 根据业务场景和处理性能设定保存周期,保存周期过长造成额外资源占用,保存周期过短造成数据丢失
  • Restful API高稳定性、高并发
    • 算法模型对外提供API接口需要关注高并发、高稳定性,通常使用gunicorn进行部署
    • 算法模型高并发部署需要成倍的资源,重点关注显存、内存和CPU资源占用情况
  • 容器化部署
    • 关注容器CUDA版本与主机显卡驱动,以及深度学习框架之间的匹配问题

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/17771.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

如何使用fiddler进行抓包

首先需要下载fiddler,推荐使用bing搜索引擎搜索(百度搜狗一般搜这种工具展示的前几个全都是广告),直接搜索fiddler,搜出来第一个fiddler官网 然后直接点击download下载 进入下载页面后,正确填写一个邮箱&a…

linux 动态库so相关操作

1. 查看库版本号 一般在文件名上有版本号,若文件名上没有版本号,使用如下命令查看: readelf -d libstdc.so 2. 查看库内函数 a) nm -d libstdc.so | grep 内容 b) objdump -tT libstdc.so | grep 内容 c) readelf -s libstdc.so | grep…

通用版Bubble_sort

❤博主CSDN:啊苏要学习 ▶专栏分类:C语言◀ C语言的学习,是为我们今后学习其它语言打好基础,C生万物! 开始我们的C语言之旅吧!✈ 目录 前言: 一.分析Bubble_sort 二.解决措施 三.模拟实现 前言&#xff…

【数据结构】带头+双向+循环链表(DList)(增、删、查、改)详解

一、带头双向循环链表的定义和结构 1、定义 带头双向循环链表,有一个数据域和两个指针域。一个是前驱指针,指向其前一个节点;一个是后继指针,指向其后一个节点。 // 定义双向链表的节点 typedef struct ListNode {LTDataType dat…

java判断字符串是否和空字符串(““)相等、是否和空引用(null)相等,比较顺序不同导致出现死代码(Dead code)

我在用Java实现需求的时候,用到了字符串跟空字符串(“”)比较,跟空引用null比较,两个比较语句的顺序不同,一个顺序出现了死代码(Dead code)。 下面这个代码片段,字符串li…

探秘二叉树后序遍历:从叶子到根的深度之旅

本篇博客会讲解力扣“145. 二叉树的后序遍历”的解题思路,这是题目链接。 本题的思路是: 先创建一个数组,用来存储二叉树后序遍历的结果。数组的大小跟树的结点个数有关。树的结点个数可以使用递归实现,即总个数左子树结点个数右…

如何将单体项目拆分成微服务

1、如何将单体项目拆分成微服务 如何拆分微服务?其实对不同的业务项目场景,对应有不同的拆分方案。需要项目人员详细的分析项目需求、团队现状、业务边界、业务逻辑等方方面面,拆分的粒度既不能过细,也不能过粗,需要把…

图像 检测 - FCOS: Fully Convolutional One-Stage Object Detection (ICCV 2019)

FCOS: Fully Convolutional One-Stage Object Detection - 全卷积一阶段目标检测(ICCV 2019) 摘要1. 引言2. 相关工作3. 我们的方法3.1 全卷积一阶目标检测器3.2 FCOS的FPN多级预测3.3 FCOS中心度 4. 实验4.1 消融研究4.1.1 FPN多级预测4.1.2 有无中心度…

Gis入门,根据起止点和一个控制点计算二阶贝塞尔曲线(共三个控制点组成的线段转曲线)

前言 本章讲解如何在gis地图中使用起止点和一个控制点(总共三个控制点)生成二阶贝塞尔曲线。 三阶贝塞尔曲线请参考下一章《Gis入门,使用起止点和两个控制点生成三阶贝塞尔曲线(共四个控制点)》 贝塞尔曲线(Bezier curve)介绍 贝塞尔曲线(Bezier curve)是一种数学…

Nim游戏博弈论

【模板】nim 游戏 题目描述 https://www.luogu.com.cn/problem/P2197 甲,乙两个人玩 nim 取石子游戏。 nim 游戏的规则是这样的:地上有 n n n 堆石子(每堆石子数量小于 1 0 4 10^4 104),每人每次可从任意一堆石子…

GDAL C++ API 学习之路 OGRGeometry 线类 OGRLineString

OGRLineString class "ogr_geometry.h" OGRLineString 类是 OGR 库中的一个几何对象类,用于表示线段或折线。它由多个坐标点组成,并且在坐标点之间形成线段。OGRLineString 可以包含 2D、3D 或 3DM 坐标点,其中 M 表示额外…

前端-mac初始化配置

新电脑设置: Mac三指拖动:https://support.apple.com/zh-cn/HT204609 选取苹果菜单  >“系统设置”(或“系统偏好设置”)。点按“辅助功能”。点按“指针控制”(或“鼠标与触控板”)。点按“触控板选…

ISO 7637-2 5a/5b抛负载测试保护用TVS二极管,如何选型号?

在国际标准ISO 16750-2颁布之前,全球各大汽车零部件制造商一直采用的是ISO 7637-2标准。ISO 16750-2国际标准发行之后,汽车抛负载浪涌测试中ISO 7637-2 5A和5B测试标准被ISO 16750-2测试标准取代。查看ISO 16750-2和ISO 7637-2国际标准文档资料对比会发现…

13个ChatGPT类实用AI工具汇总

在ChatGPT爆火后,各种工具如同雨后春笋一般层出不穷。以下汇总了13种ChatGPT类实用工具,可以帮助学习、教学和科研。 01 / ChatGPT for google/ 一个浏览器插件,可搭配现有的搜索引擎来使用 最大化搜索效率,对搜索体验的提升相…

多线程(JavaEE初阶系列6)

目录 前言: 1.什么是线程池 2.标准库中的线程池 3.实现线程池 结束语: 前言: 在上一节中小编带着大家了解了一下Java标准库中的定时器的使用方式并给大家实现了一下,那么这节中小编将分享一下多线程中的线程池。给大家讲解一…

MySQL主从复制配置

Mysql的主从复制至少是需要两个Mysql的服务,当然Mysql的服务是可以分布在不同的服务器上,也可以在一台服务器上启动多个服务。 (1)首先确保主从服务器上的Mysql版本相同 (2)在主服务器上,创建一个充许从数据库来访问的用户slave,密码为:123456 ,然后使用REPLICATION SLAV…

NoSQL-Redis集群

NoSQL-Redis集群 一、集群:1.单点Redis带来的问题:2.解决:3.集群的介绍:4.集群的优势:5.集群的实现方式: 二、集群的模式:1.类型:2.主从复制: 三、搭建主从复制&#xff…

在CentOS 7上挂载硬盘到系统的步骤及操作

目录 1:查询未挂载硬盘2:创建挂载目录3:检查磁盘是否被分区4:格式化硬盘5:挂载目录6:检查挂载状态7:设置开机自动挂载总结: 本文介绍了在CentOS 7上挂载硬盘到系统的详细步骤。通过确…

代码随想录算法训练营第二十八天 | Leetcode随机抽题检测

Leetcode随机抽题检测--使用题库:Leetcode热题100 1 两数之和未看解答自己编写的青春版重点题解的代码日后再次复习重新写 49 字母异位词分组未看解答自己编写的青春版重点题解的代码日后再次复习重新写 128 最长连续序列未看解答自己编写的青春版重点关于 left 和 …

C语言每日一题:12《数据结构》相交链表。

题目: 题目链接 思路一: 1.如果最后一个节点相同说明一定有交点。 2.使用两个循环获取一下长度,同时可以获取到尾节点。 3。注意初始化lenA和lenB为1,判断下一个节点是空是可以保留尾节点的。长度会少一个,尾节点没有…