华为开源自研AI框架昇思MindSpore应用案例:人体关键点检测模型Lite-HRNet

如果你对MindSpore感兴趣,可以关注昇思MindSpore社区

在这里插入图片描述

在这里插入图片描述

一、环境准备

1.进入ModelArts官网

云平台帮助用户快速创建和部署模型,管理全周期AI工作流,选择下面的云平台以开始使用昇思MindSpore,获取安装命令,安装MindSpore2.0.0-alpha版本,可以在昇思教程中进入ModelArts官网

在这里插入图片描述

选择下方CodeLab立即体验

在这里插入图片描述

等待环境搭建完成

在这里插入图片描述

2.使用CodeLab体验Notebook实例

下载NoteBook样例代码,Lite-HRNet实现人体关键点检测 ,.ipynb为样例代码

在这里插入图片描述

选择ModelArts Upload Files上传.ipynb文件

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

选择Kernel环境

在这里插入图片描述

切换至GPU环境,切换成第一个限时免费

在这里插入图片描述

进入昇思MindSpore官网,点击上方的安装

在这里插入图片描述

获取安装命令

在这里插入图片描述

回到Notebook中,在第一块代码前加入命令
在这里插入图片描述

conda update -n base -c defaults conda

在这里插入图片描述

安装MindSpore 2.0 GPU版本

conda install mindspore=2.0.0a0 -c mindspore -c conda-forge

在这里插入图片描述

安装mindvision

pip install mindvision

在这里插入图片描述

安装下载download

pip install download

人体关键点检测模型Lite-HRNet

人体关键点检测是计算机视觉的基本任务之一,在许多应用场景诸如自动驾驶、安防等有着重要的地位。可以发现,在这些应用场景下,深度学习模型可能需要部署在IoT设备上,这些设备算力较低,存储空间有限,无法支撑太大的模型,因此轻量但不失高性能的人体关键点检测级模型将极大降低模型部署难度。Lite-HRNet便提供了一轻量级神经网络骨干,通过接上不同的后续模型可以完成不同的任务,其中便包括人体关键点检测,在配置合理的情况下,Lite-HRNet可以以大型神经网络数十分之一的参数量及计算量达到相近的性能。

模型简介

Lite-HRNet由HRNet(High-Resolution Network)改进而来,HRNet的主要思路是在前向传播过程中通过维持不同分辨率的特征,使得最后生成的高阶特征既可以保留低分辨率高阶特征中的图像语义信息,也可以保留高分辨率高阶特征中的物体位置信息,进而提高在分辨率敏感的任务如语义分割、姿态检测中的表现。Lite-HRNet是HRNet的轻量化改进,改进了HRNet中的卷积模块,将HRNet中的参数量从28.5M降低至1.1M,计算量从7.1GFLOPS降低至0.2GFLOPS,但AP75仅下降了7%。
综上,Lite-HRNet具有计算量、参数量低,精度可观的优点,有利于部署在物联网低算力设备上服务于各个应用场景。

数据准备

本案例使用COCO2017数据集作为训练、验证数据集,请首先安装Mindspore Vision套件,并确保安装的Mindspore是GPU版本,随后请在https://cocodataset.org/ 上下载好2017 Train Images、2017 Val Images以及对应的标记2017 Train/Val Annotations,并解压至当前文件夹,文件夹结构下表所示

Lite-HRNet/├── imgs├── src├── annotations├──person_keypoints_train2017.json└──person_keypoints_train2017.json├── train2017└── val2017

训练、测试原始图片如下所示,图片中可能包含多个人体,且包含的人体不一定包含COCO2017中定义的17个关键点,标注中有每个人体的边框、关键点信息,以便处理图像后供模型训练。

数据预处理

src/mindspore_coco.py中定义了供mindspore模型训练、测试的COCO数据集接口,在加载训练数据集时只需指定所用数据集文件夹位置、输入图像的尺寸、目标热力图的尺寸、以及手动设置对训练图像采用的变换即可


import mindspore as ms
import mindspore.dataset as dataset
import mindspore.dataset.vision.py_transforms as py_vision
import mindspore.nn as nn
from mindspore.dataset.transforms.py_transforms import Composefrom src.configs.dataset_config import COCOConfig
from src.dataset.mindspore_coco import COCODatasetcfg = COCOConfig(root="./", output_dir="outputs/", image_size=[192, 256], heatmap_size=[48, 64])
trans = Compose([py_vision.ToTensor(),py_vision.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
train_ds = COCODataset(cfg, "../", "train2017", True, transform=trans)
train_loader = dataset.GeneratorDataset(train_ds, ["data", "target", "weight"])

在这里插入图片描述

构建网络

Lite-HRNet网络骨干大体结构如下图所示:
在这里插入图片描述

网络中存在不同的分辨率分支,网络主干上维持着较高分辨率、较少通道数的输出特征,网络分支上延展出较低分辨率、较多通道数的输出特征,且这些不同分辨率的特征之间通过上采样、下采样的卷积层进行交互、融合。Stage内的Cross Channel Weighting(CCW)则是网络实现轻量化的精髓,它将原HRNet中复杂度较高的1*1卷积以更低复杂度的Spatial Weighting等方法替代,从而实现降低网络参数、计算量的效果。CCW的结构如下图所示

在这里插入图片描述

值得注意的是,除了骨干网络,作者在论文中同时也给出了所使用的检测头即SimpleBaseline,为了简洁起见,在本次的Lite-HRNet的Mindspore实现中,检测头(代码中包括IterativeHeads和LiteTopDownSimpleHeatMap)已集成至骨干网络之后,作为整体模型的一部分,直接调用模型即可得到热力图预测输出。

损失函数

此处使用损失函数为JointMSELoss,即关节点的均方差误差损失函数,其源码如下所示,总体流程即计算每个关节点预测热力图与实际热力图的均方差,其中target是根据关节点的人工标注坐标,通过二维高斯分布生成的热力图,target_weight用于指定参与计算的关节点,若某关节点对应target_weight取值为0,则表明该关节点在输入图像中未出现,不参与计算。

"""JointMSELoss"""
import mindspore.nn as nn
import mindspore.ops as opsclass JointsMSELoss(nn.Cell):"""Joint MSELoss"""def __init__(self, use_target_weight):"""JointMSELoss"""super(JointsMSELoss, self).__init__()self.criterion = nn.MSELoss(reduction='mean')self.use_target_weight = use_target_weightdef construct(self, output, target, weight):"""construct"""target = targettarget_weight = weightbatch_size = output.shape[0]num_joints = output.shape[1]spliter = ops.Split(axis=1, output_num=num_joints)mul = ops.Mul()heatmaps_pred = spliter(output.reshape((batch_size, num_joints, -1)))heatmaps_gt = spliter(target.reshape((batch_size, num_joints, -1)))loss = 0for idx in range(num_joints):heatmap_pred = heatmaps_pred[idx].squeeze()heatmap_gt = heatmaps_gt[idx].squeeze()if self.use_target_weight:heatmap_pred = mul(heatmap_pred, target_weight[:, idx])heatmap_gt = mul(heatmap_gt, target_weight[:, idx])loss += 0.5 * self.criterion(heatmap_pred,heatmap_gt)else:loss += 0.5 * self.criterion(heatmap_pred, heatmap_gt)return loss/num_joints

模型实现与训练

在实现模型时,需指定模型内部结构,在src/net_configs中已指定原论文中10种结构配置,在训练样例种取Lite_18_coco作为模型结构,此处作为案例,仅设置epoch数量为1,在实际训练中可以设置为200,并且可以加入warmup。由于mindspore的训练接口默认数据集中每条数据只有两列(即训练数据和标签),所以这里需自定义Loss Cell。值得注意的是loss在训练前后变化并不会十分大,训练好的模型的loss为0.0004左右

class CustomWithLossCell(nn.Cell):def __init__(self,net: nn.Cell,loss_fn: nn.Cell):super(CustomWithLossCell, self).__init__()self.net = netself._loss_fn = loss_fndef construct(self, img, target, weight):""" build network """heatmap_pred = self.net(img)return self._loss_fn(heatmap_pred,target,weight)
from src.configs.net_configs import get_netconfig
from mindspore.train.callback import  LossMonitor
from src.backbone import LiteHRNetext = get_netconfig("extra_lite_18_coco")
net = LiteHRNet(ext)
criterion = JointsMSELoss(use_target_weight=True)train_loader = train_loader.batch(64)
optim = nn.Adam(net.trainable_params(), learning_rate=2e-3)
loss = JointsMSELoss(use_target_weight=True)
net_with_loss = CustomWithLossCell(net, loss)model = ms.Model(network=net_with_loss, optimizer=optim)
epochs = 1
#Start Training
model.train(epochs, train_loader, callbacks=[LossMonitor(100)], dataset_sink_mode=False)

在这里插入图片描述

模型评估

模型评估过程中使用AP、AP50、AP75以及AR50、AR75作为评价指标,val2017作为评价数据集,pycocotool包中已实现根据评价函数,且src/mindspore_coco.py中的evaluate函数也实现了调用该评价函数的接口,只需提供预测关键点坐标等信息即可获得评价指标。此处载入Lite_18_coco的预训练模型进行评价。

from mindspore import load_checkpoint
from mindspore import load_param_into_netfrom src.utils.utils import get_final_preds
import numpy as npdef evaluate_model(model, dataset, output_path):"""Evaluate"""num_samples = len(dataset)all_preds = np.zeros((num_samples, 17, 3),dtype=np.float32)all_boxes = np.zeros((num_samples, 6))image_path = []for i, data in enumerate(dataset):input_data, target, meta = data[0], data[1], data[3]input_data = ms.Tensor(input_data[0], ms.float32).reshape(1, 3, 256, 192)shit = model(input_data).asnumpy()target = target.reshape(shit.shape)c = meta['center'].reshape(1, 2)s = meta['scale'].reshape(1, 2)score = meta['score']preds, maxvals = get_final_preds(shit, c, s)all_preds[i:i + 1, :, 0:2] = preds[:, :, 0:2]all_preds[i:i + 1, :, 2:3] = maxvals# double check this all_boxes partsall_boxes[i:i + 1, 0:2] = c[:, 0:2]all_boxes[i:i + 1, 2:4] = s[:, 0:2]all_boxes[i:i + 1, 4] = np.prod(s*200, 1)all_boxes[i:i + 1, 5] = scoreimage_path.append(meta['image'])dataset.evaluate(0, all_preds, output_path, all_boxes, image_path)net_dict = load_checkpoint("./ckpt/litehrnet_18_coco_256x192.ckpt")
load_param_into_net(net, net_dict)eval_ds = COCODataset(cfg, "./", "val2017", False, transform=trans)
evaluate_model(net, eval_ds, "./result")

在这里插入图片描述

模型推理

  1. Lite-HRNet是关键点检测模型,所以输入待推理图像应为包含单个人体的图像,作者在论文中提及在coco test 2017测试前已使用SimpleBaseline生成的目标检测Bounding Box处理图像,所以待推理图像应仅包含单个人体。
  2. 网络的输入为(1,3,256,192),所以在输入图像前应先将其变换成网络可处理的形式。
import cv2
from src.utils.utils import get_max_preds
origin_img = cv2.imread("./imgs/man.jpg")
origin_h, origin_w, _ = origin_img.shape
scale_factor = [origin_w/192, origin_h/256]# resize to (112 112 3) and convert to tensor
img = cv2.resize(origin_img, (192, 256))
print(img.shape)
img = trans(img)
# img = np.expand_dims(img, axis=0)
img = ms.Tensor(img)
print(img.shape)# Infer
heatmap_pred = net(img).asnumpy()
pred, _ = get_max_preds(heatmap_pred)# Postprocess
pred = pred.reshape(pred.shape[0], -1, 2)
print(pred[0])
pre_landmark = pred[0] * 4 * scale_factor
# Draw points
for (x, y) in pre_landmark.astype(np.int32):cv2.circle(origin_img, (x, y), 3, (255, 255, 255), -1)# Save image
cv2.imwrite("./imgs/man_infer.jpg", origin_img)

在这里插入图片描述

可以看到模型基本正确标注出了关键点的位置\

在这里插入图片描述

算法基本流程

  1. 获取原始数据
  2. 从数据集的标注json文件中得到各个图像bbox以及关键点坐标信息
  3. 根据bbox裁剪图像,并放缩至指定尺寸,如果是训练还可以作适当数据增强,生成指定尺寸的目标热力图
  4. 指定尺寸的输入经过网络前向传播后得到预测的关键点热力图
  5. 经过处理后取热力图中的最大值坐标作为关键点的预测坐标

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/61079.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Cellebrite VS IOS18Rebooting

Cellebrite VS IOS18Rebooting我们想分享一些有关 iOS 18 重启“功能”的信息。在过去一周左右的时间里,人们对 iOS 18 中一项新的未记录功能产生了极大关注,该功能会导致设备在一段时间不活动后重新启动。 这意味着,如果设备在一定时间不活…

使用 Axios 拦截器优化 HTTP 请求与响应的实践

目录 前言1. Axios 简介与拦截器概念1.1 Axios 的特点1.2 什么是拦截器 2. 请求拦截器的应用与实践2.1 请求拦截器的作用2.2 请求拦截器实现 3. 响应拦截器的应用与实践3.1 响应拦截器的作用3.2 响应拦截器实现 4. 综合实例:一个完整的 Axios 配置5. 使用拦截器的好…

【最大子矩阵——双指针 / 二分】

题目 双指针&#xff1a; 代码 #include <bits/stdc.h> using namespace std; const int N 85, M 1e510; int g[N][M]; int n, m, lim; int ans 1; int main() {ios::sync_with_stdio(0);cin.tie(0);cin >> n >> m;for(int i 1; i < n; i)for(int …

内网渗透-隧道判断-SSH-DNS-icmp-smb-上线linux-mac

1.通道判断 #SMB 隧道&通讯&上线 判断&#xff1a;445 通讯 上线&#xff1a;借助通讯后绑定上线 通讯&#xff1a;直接 SMB 协议通讯即可 #ICMP 隧道&通讯&上线 判断&#xff1a;ping 命令 上线&#xff1a;见前面课程 通讯&#xff1a;其他项…

【优选算法篇】分治乾坤,万物归一:在重组中窥见无声的秩序

文章目录 分治专题&#xff08;二&#xff09;&#xff1a;归并排序的核心思想与进阶应用前言、第二章&#xff1a;归并排序的应用与延展2.1 归并排序&#xff08;medium&#xff09;解法&#xff08;归并排序&#xff09;C 代码实现易错点提示时间复杂度和空间复杂度 2.2 数组…

【微软:多模态基础模型】(3)视觉生成

欢迎关注【youcans的AGI学习笔记】原创作品 【微软&#xff1a;多模态基础模型】&#xff08;1&#xff09;从专家到通用助手 【微软&#xff1a;多模态基础模型】&#xff08;2&#xff09;视觉理解 【微软&#xff1a;多模态基础模型】&#xff08;3&#xff09;视觉生成 【微…

netcore Kafka

一、新建项目KafakDemo <ItemGroup><PackageReference Include"Confluent.Kafka" Version"2.6.0" /></ItemGroup> 二、Program.cs using Confluent.Kafka; using System; using System.Threading; using System.Threading.Tasks;names…

工业生产安全-安全帽第一篇-opencv及java开发环境搭建

一.背景 公司是非煤采矿业&#xff0c;核心业务是采选&#xff0c;大型设备多&#xff0c;安全风险因素多。当下政府重视安全&#xff0c;头部技术企业的安全解决方案先进但价格不低&#xff0c;作为民营企业对安全投入的成本很敏感。利用我本身所学&#xff0c;准备搭建公司的…

fastadmin多个表crud连表操作步骤

1、crud命令 php think crud -t xq_user_credential -u 1 -c credential -i voucher_type,nickname,user_id,voucher_url,status,time --forcetrue2、修改控制器controller文件 <?phpnamespace app\admin\controller;use app\common\controller\Backend;/*** 凭证信息…

【英特尔IA-32架构软件开发者开发手册第3卷:系统编程指南】2001年版翻译,2-26

文件下载与邀请翻译者 学习英特尔开发手册&#xff0c;最好手里这个手册文件。原版是PDF文件。点击下方链接了解下载方法。 讲解下载英特尔开发手册的文章 翻译英特尔开发手册&#xff0c;会是一件耗时费力的工作。如果有愿意和我一起来做这件事的&#xff0c;那么&#xff…

Essential Cell Biology--Fifth Edition--Chapter one (8)

1.1.4.6 The Cytoskeleton [细胞骨架] Is Responsible for Directed Cell Movements 细胞质基液不仅仅是一种无结构的化学物质和细胞器的混合物[soup]。在电子显微镜下&#xff0c;我们可以看到真核细胞的细胞质基液是由长而细的丝交叉而成的。通常[Frequently]&#xff0c;可…

RK3568 Linux 系统加系统运行指示灯

一、dts配置 gpio-leds {status = "okay";compatible = "gpio-leds";work-led {gpios = <&gpio0 RK_PB7 GPIO_ACTIVE_HIGH>

C++11(六)----包装器function和bind

文章目录 包装器&#xff1a;function包装器&#xff1a;bind 包装器&#xff1a;function function接口介绍 在头文件<functional>中 语法&#xff1a;function的语法比较特殊 function<返回值(参数)> 自定义变量名 要被包装的可调用对象 class Plus { public:…

店铺推推-项目测试用例设计(Xmind)

项目介绍&#xff1a; 技术栈: Spring BootMyBatisRedis项目描述&#xff1a; 项目旨在为消费者提供一个公平、公开、透明的平台&#xff0c;让消费者能够基于真实的消费体验对店铺进行评价和 推荐&#xff0c;并为其他潜在消费者提供参考。同时&#xff0c;店铺推推也是为商家…

c++--------《set 和 map》

c--------《set 和 map》 1 set系列的使⽤1.1 set类的介绍1.2 set的构造和迭代器1.3 set重要接口 2 实现样例2.1: insert和迭代器遍历使⽤样例&#xff1a;2.2: find和erase使⽤样例&#xff1a; 练习3.map系列的使用3.1 map类的介绍3.1.1 pair类型介绍 3.2 map的数据修改3.3mu…

GIS融合之路(八)-如何用Cesium直接加载OSGB文件(不用转换成3dtiles)

系列传送门&#xff1a; 山海鲸可视化&#xff1a;GIS融合之路&#xff08;一&#xff09;技术选型CesiumJS/loaders.gl/iTowns? 山海鲸可视化&#xff1a;GIS融合之路&#xff08;二&#xff09;CesiumJS和ThreeJS深度缓冲区整合 山海鲸可视化&#xff1a;GIS融合之路&…

QQ 小程序已发布,但无法被搜索的解决方案

前言 我的 QQ 小程序在 2024 年 8 月就已经审核通过&#xff0c;上架后却一直无法被搜索到。打开后&#xff0c;再在 QQ 上下拉查看 “最近使用”&#xff0c;发现他出现一下又马上消失。 上线是按正常流程走的&#xff0c;开发、备案、审核&#xff0c;没有任何违规&#xf…

word 中长公式换行 / 对齐 | Mathtype 中长公式换行拆分 | latex 中长公式换行

注&#xff1a;本文为 “word 中长公式换行 / 对齐 | Mathtype 中长公式换行拆分 | latex 中长公式换行” 相关专题文章合辑。 未整理去重。 “公式较长时最好在等号 “&#xff1d;” 处转行&#xff0c;如难实现&#xff0c;则可在&#xff0b;、&#xff0d;、、 运算符号处…

【优选算法 — 滑动窗口】串联所有单词的子串 最小覆盖子串

串联所有单词的子串 串联所有单词的子串 题目描述 题目解析 算法原理 以示例一为例&#xff0c;一定要记得&#xff0c;words中的每一个字符串长度相同&#xff0c;所以我们可以根据 words 中的每一个字符串的长度length&#xff0c;将 s 这个字符串以 length 个为一组来…

WEB攻防-通用漏洞SQL注入sqlmapOracleMongodbDB2等

SQL注入课程体系&#xff1a; 1、数据库注入-access mysql mssql oracle mongodb postgresql 2、数据类型注入-数字型 字符型 搜索型 加密型&#xff08;base64 json等&#xff09; 3、提交方式注入-get post cookie http头等 4、查询方式注入-查询 增加 删除 更新 堆叠等 …