【PyTorch实战演练】深入剖析MTCNN(多任务级联卷积神经网络)并使用30行代码实现人脸识别

文章目录

      • 0. 前言
      • 1. 级联神经网络介绍
      • 2. MTCNN介绍
        • 2.1 MTCNN提出背景
        • 2.2 MTCNN结构
      • 3. MTCNN PyTorch实战
        • 3.1 facenet_pytorch库中的MTCNN
        • 3.2 识别图像数据
        • 3.3 人脸识别
        • 3.4 关键点定位

0. 前言

按照国际惯例,首先声明:本文只是我自己学习的理解,虽然参考了他人的宝贵见解及成果,但是内容可能存在不准确的地方。如果发现文中错误,希望批评指正,共同进步。

本文详细介绍MTCNN——多任务级联卷积神经网络的结构,并通过PyTorch实例说明MTCNN在人脸识别上的应用。

MTCNN的全称是Multi-Task Cascaded Convolutional Networks,它的缩写确实是MTCNN不是MTCCN.

1. 级联神经网络介绍

级联(cascaded)神经网络是一种人工神经网络的架构设计,它指的是多个神经网络层按照特定的方式连接起来,形成一个逐层处理信息的多层结构。在级联神经网络中,前一层次网络的输出作为后一层次网络的输入,这种结构允许在网络在深度方向上对复杂性和抽象层数进行增加

级联网络的重要特点是其动态构建特性,即可以从一个小规模的基本网络开始,并随着训练过程自动添加更多的隐藏单元或子网络,逐渐扩展成一个更深层次的结构,然后通过只针对新增部分数据进行训练来更新权重,即增量式学习(Incremental Learning)。这与传统的——构建完整模型后统一进行训练更新权重的思路非常不同。

使用传统的思路,如果发现我们的模型并不适用于待解决的任务,导致要调整模型结构时,通常会意味着之前的训练模型的工作全部白费了。

总结起来级联神经网络具有以下优点:

  1. 自适应结构:级联网络设计允许根据训练数据或学习过程动态调整网络结构,比如自动增加新的层或神经元,以适应更复杂的模式识别任务。

  2. 学习效率提升:可以通过增量学习或局部训练来加快学习速度,只针对新增加的部分进行训练优化。在某些情况下,级联网络可以采用非传统的权重更新机制,不需要在整个网络上执行全局误差反向传播算法。

  3. 鲁棒性和容错性:分层结构有助于提高系统的鲁棒性,单个层次的错误可能在后续层次中得到修正。

2. MTCNN介绍

2.1 MTCNN提出背景

MTCNN是Kaipeng Zhang等人在论文——Joint Face Detection and Alignment using Multi-task Cascaded Convolutional Networks中提出的,其宗旨是通过多任务级联CNN解决两个问题:人脸检测(找出图像中人脸的位置和边界框)和人脸对齐(精确定位面部特征点)。

2.2 MTCNN结构

MTCNN的构建思路可以简单分为下面几个步骤:
在这里插入图片描述

  • 准备步骤:对图像进行缩放,建立图像金字塔;
  • 第一步Proposal-Net:快速选出若干候选框,为下一步准备;
  • 第二步Refine-Net:对第一步的众多候选框进行精选,留下置信度大的候选框;
  • 第三步Output-Net:输出最终bounding box、人脸关键特征定位和置信度。

详细来说P-Net、R-Net和O-Net的结构如下:在这里插入图片描述
通过Netron可以看到facenet_pytorch库中的MTCNN的结构及详细参数如下:
请添加图片描述

  1. P-Net (Proposal Network):

    • 输入是原始图像。
    • 首先通过一个卷积层(Conv2d)将3通道的输入图像转换为10通道特征图,使用3x3的卷积核(kernel_size=(3, 3))。
    • 紧接着使用PReLU激活函数(prelu1)进行非线性变换。
    • 使用最大池化层(MaxPool2d)下采样特征图(pool1),步长为2。
    • 再经过两个卷积层(Conv2d)提取更深层次的特征,并分别用PReLU激活函数(prelu2和prelu3)进行非线性处理。
    • 最后通过两个1x1卷积层(Conv2d)生成两个输出:一个是softmax4_1用于预测每个像素是否为人脸的概率分布,另一个是conv4_2用于回归bounding box的位置信息。
  2. R-Net (Refine Network):

    • 输入是P-Net的候选区域。
    • 类似于P-Net,R-Net也包含多个卷积层与激活函数,以及池化层进行特征提取和下采样。
    • 在最后,通过两个全连接层(Dense或Linear)生成两个输出:softmax5_1用于判断候选框内是否为人脸并给出置信度,dense5_2用于进一步细化人脸框的位置。
  3. O-Net (Output Network):

    • 输入同样是前一级网络(R-Net)筛选后的候选区域。
    • O-Net具有更多的卷积层以获取更精细的特征表达,同样在最后阶段通过三个全连接层生成三个输出:softmax6_1用于人脸分类,dense6_2用于人脸框回归精修,dense6_3用于估计关键点(如眼睛、嘴巴等)的位置。

整个MTCNN模型通过逐步筛选和优化候选区域,在不同尺度上定位和识别图像中的人脸,从而实现高效准确的人脸检测。

3. MTCNN PyTorch实战

3.1 facenet_pytorch库中的MTCNN

facenet_pytorch库中的MTCNN类是一个用于人脸检测的多任务级联卷积神经网络模型实现。直接使用MTCNN类的最大好处就是该模型已经训练好,可以拿来即用,其初始化时接受多个参数,以下是对这些参数的详细解释:

  1. image_size(默认值:160):输出图像的大小(像素),图像会调整为正方形。

  2. margin(默认值:0):在最终图像上添加到边界框的边距(以像素为单位)。需要注意的是,与davidsandberg/facenet库中的应用方式稍有不同,该库在调整原始图像大小之前就对原始图像应用了边距,导致边距与原始图像大小相关(这是davidsandberg/facenet的一个bug)。

  3. min_face_size(默认值:20):要搜索的人脸的最小尺寸。

  4. thresholds(默认值:[0.6, 0.7, 0.7]):MTCNN人脸检测阈值列表,分别对应P-Net、R-Net和O-Net三个阶段的阈值。

  5. factor(默认值:0.709):用于创建人脸大小缩放金字塔的比例因子。

  6. post_process(默认值:True):是否在返回前对图像张量进行后处理。

  7. select_largest(默认值:True):如果检测到多个人脸,是否选择面积最大的一个返回。若设为False,则选择概率最高的人脸返回。

  8. selection_method(默认值:None):指定使用哪种启发式方法进行选择,如果设置此参数将覆盖select_largest

    • "probability":选择概率最高的。
    • "largest":选择面积最大的框。
    • "largest_over_threshold":选择超过一定概率的最大框。
    • "center_weighted_size":基于框大小减去离图像中心加权距离平方后的结果进行选择。
  9. keep_all(默认值:False):如果设为True,则返回所有检测到的人脸,并按照select_largest参数设定的顺序排列。如果指定了保存路径,第一张人脸将被保存至该路径,其余人脸将依次保存为<save_path>1, <save_path>2等。

  10. device(默认值:None):运行神经网络前向传递时所使用的设备。图像张量和模型会在前向传递前复制到这个设备上。

3.2 识别图像数据

这块没有特殊要求,随便去网上下载,以下是我自己的识别对象数据:
在这里插入图片描述

3.3 人脸识别
  • 代码
from facenet_pytorch import MTCNN
import torch
from torch.utils.data import DataLoader
from torchvision import datasets
import osdevice = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('the device is:{}'.format(device))model = MTCNN(image_size=160, margin=0, min_face_size=10,thresholds=[0.7,0.7,0.7],factor=0.7, post_process=True, device=device)
path = os.path.abspath('face_img')  #在face_img文件夹下面还要再加一个class_folder文件夹
dataset = datasets.ImageFolder(path)
imgs_list = list(sorted(os.listdir(os.path.join(path,'class_folder'))))def collate_fn(x):return x[0]loader = DataLoader(dataset, collate_fn = collate_fn, num_workers=0)
index = 0
#detected_faces = []for pic,_ in loader:aligned, confidence = model(pic , return_prob=True)if confidence is not None:print('Confidence of {} containing human face is {:.8f}'.format(imgs_list[index], confidence))detected_faces.append(aligned)else:print('No human face detected in {}'.format(imgs_list[index]))index += 1# 以下是人脸对齐的还原实现
#face_numpy = (detected_faces[0] + 1) * 127.5  # 由于是 [-1, 1] 范围,将其映射到 [0, 255]
#face_numpy = face_numpy.numpy().astype(np.uint8)
#face_image = Image.fromarray(face_numpy.transpose(1, 2, 0)) # 将 Numpy 数组转为 PIL 图像格式,并注意调整通道顺序为 (H, W, C)#plt.imshow(face_image)
#plt.show()
  • 输出
Confidence of art.png containing human face is 0.99512947
No human face detected in ironman.png
Confidence of man.png containing human face is 0.99643928
No human face detected in ogre.png
Confidence of thanos.png containing human face is 0.96726525
Confidence of woman.png containing human face is 0.99991846

可见MTCNN不认为钢铁侠和食人魔魔法师算“人脸”。MTCNN的输出有2部分:

  1. 对齐后的人脸张量:其范围是[-1, 1],可以将其线性还原到[0, 255]并输出对其后的人脸,例如下图:在这里插入图片描述
  2. 包含人脸的置信度:即上面的0.99512947等置信度数值。
3.4 关键点定位

也可以使用mtcnn.detect()得到人脸得关键点(眼睛、鼻子、嘴角)定位,代码如下:

from facenet_pytorch import MTCNN
import torch
from torch.utils.data import DataLoader
from torchvision import datasets
import os
import numpy
import matplotlib
import matplotlib.pyplot as pltdevice = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('the device is:{}'.format(device))model = MTCNN(image_size=160, margin=0, min_face_size=10,thresholds=[0.7,0.7,0.7],factor=0.7, post_process=True, device=device)path = os.path.abspath('face_img')  #在face_img文件夹下面还要再加一个class_folder文件夹
dataset = datasets.ImageFolder(path)
imgs_list = list(sorted(os.listdir(os.path.join(path,'class_folder'))))def collate_fn(x):return x[0]loader = DataLoader(dataset, collate_fn = collate_fn, num_workers=0)
index = 0
detected_faces = []for pic,_ in loader:aligned, confidence = model(pic , return_prob=True)if confidence is not None:print('Confidence of {} containing human face is {:.8f}'.format(imgs_list[index], confidence))detected_faces.append(aligned)boxes, probs, points = model.detect(pic, landmarks=True)points = points.squeeze(0)for x,y in points:plt.scatter(x,y,s=10,c='r')plt.imshow(pic)plt.savefig('{}_aligned.jpg'.format(imgs_list[index]))plt.close()else:print('No human face detected in {}'.format(imgs_list[index]))index += 1

最终保存的图像为:
在这里插入图片描述

在这里插入图片描述

在这里插入图片描述
在这里插入图片描述

可以看出MTCNN的关键点定位也是很准确的。上面代码的boxs即为人脸边界框,这里不再画出效果。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/736530.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

DenseNet笔记

&#x1f4d2;from ©实现pytorch实现DenseNet&#xff08;CNN经典网络模型详解&#xff09; - 知乎 (zhihu.com) 是什么之 DenseBlock 读图&#xff1a; x0是inputH1的输入是x0 (input)H2的输入是x0和x1 (x1是H1的输出) Summary&#xff1a; 传统卷积网&#xff0c;网…

IDEA管理Git + Gitee 常用操作

文章目录 IDEA管理Git Gitee 常用操作1.Gitee创建代码仓库1.创建仓库1.点击新建仓库2.完成仓库信息填写3.创建成功4.管理菜单可以修改这个项目的设置 2.设置SSH公钥免密登录基本介绍1.找到.ssh目录2.执行指令 ssh-keygen3.将公钥信息添加到码云账户1.点击设置2.ssh公钥3.复制.…

[力扣 Hot100]Day50 二叉树中的最大路径和

题目描述 二叉树中的 路径 被定义为一条节点序列&#xff0c;序列中每对相邻节点之间都存在一条边。同一个节点在一条路径序列中 至多出现一次 。该路径 至少包含一个 节点&#xff0c;且不一定经过根节点。 路径和 是路径中各节点值的总和。 给你一个二叉树的根节点 root &…

ETL与抖音数据同步,让数据流动无阻

在当今数字化时代&#xff0c;数据的价值日益凸显&#xff0c;企业需要从各种渠道获取有关用户行为、市场趋势和竞争对手活动的数据。作为一家专注于数据集成和转换的领先平台&#xff0c;ETLCloud为企业提供了强大的数据同步和转换功能。而与此同时&#xff0c;抖音作为一款热…

Java中常见的“类”大全

Java 中有很多常见的类&#xff0c;它们提供了各种功能&#xff0c;从基本数据类型的封装到复杂的数据结构和算法。以下是一些常见的 Java 类&#xff1a; 1.Object 类&#xff1a; 所有类的超类&#xff0c;提供了一些通用的方法&#xff0c;如 toString()、equals()、hashCod…

论文解读:Meta-Baseline: Exploring Simple Meta-Learning for Few-Shot Learning

文章汇总 总体问题 通过对整体分类的训练(文章结构图中ClassifierBaseline)&#xff0c;即在整个标签集上进行分类&#xff0c;它可以得到与许多元学习算法相当甚至更好的嵌入。这两种工作之间的界限尚未得到充分的探索&#xff0c;元学习在少样本学习中的有效性仍然不清楚。…

Visual C++ 2010学习版安装教程

1. 创建项目 点击 “创建新项目”&#xff0c;创建一个项目。 2. 创建 helloworld.c ⽂件 3. 在弹出的编辑框中&#xff0c;选中 “C文件(.cpp)”&#xff0c;将 下方 “源.cpp” 手动改为要新创建的文件名。 如&#xff1a;helloWorld.c 。注意&#xff0c;默认 cpp 后缀名&am…

java SSM旅游景点与公交线路查询系统myeclipse开发mysql数据库springMVC模式java编程计算机网页设计

一、源码特点 java SSM旅游景点与公交线路查询系统是一套完善的web设计系统&#xff08;系统采用SSM框架进行设计开发&#xff0c;springspringMVCmybatis&#xff09;&#xff0c;对理解JSP java编程开发语言有帮助&#xff0c;系统具有完整的源代码和数据库&#xff0c;系…

趣学前端 | Taro迁移完成之后,总结了一些踩坑经验

背景 四月份的时候&#xff0c;尝试将老的移动端项目改造成多端。因为老项目使用的React框架&#xff0c;综合考量&#xff0c;保障当前业务开发的进度同时&#xff0c;进行项目迁移&#xff0c;所以最后选择了Taro框架。迁移成本会低一些&#xff0c;上手快一些。 上个月&am…

CAN一致性测试:物理层测试之终端电阻测试

从本周开始结合工作实践&#xff0c;给大家总结CAN一致性相关的测试 包括&#xff1a;物理层、数据链路层、应用层三大块知识点 CAN一致性测试:物理层测试之终端电阻测试 试验目的&#xff1a; 测试控制器的 CANH 对地、CANL 对地、CANH 对 CANL 的内阻是否符合 ISO11898-2的…

读写算杂志《读写算》杂志社读写算杂志社2024年第7期目录

教育资讯 全国学生心理健康工作咨询委员会第一次全体会议召开 1 扩优提质 区域先行——基础教育高质量发展现场会在福州晋安召开 1-2 河北唐山曹妃甸&#xff1a;新学期抓好四项工作 2-3 崇红立志——江苏盐城亭湖7万学生争做新时代红色少年 3 习作选登 秋…

ubuntu設定QGC獲取pixhawk Mini4(PX4 Mini 4) 的imu信息

ubuntu20.04 QGC使用v4.3.0的版本 飛控pixhawk Mini4 飛控上只使用一條micro USB連接電腦&#xff0c;沒有其他線 安裝命令 sudo apt-get remove modemmanager -y sudo apt install gstreamer1.0-plugins-bad gstreamer1.0-libav gstreamer1.0-gl -y sudo apt install libf…

简单了解不同行业下4a的定义

工作中我们经常会听见4a这个词语&#xff0c;但大部分人对于4a的定义不是很了解&#xff0c;今天我们就来简单了解下不同行业下4a的定义。 简单了解不同行业下4a的定义 1、网络安全领域 4A指的是认证&#xff08;Authentication&#xff09;、授权&#xff08;Authorization…

ElasticSearch集群的备份和恢复

备份方式 官方建议采用snapshot方式进行备份与恢复。 单节点案例 单节点备份 首先我们看下单节点的情况下&#xff0c;我们首先需要在配置文件中配置好本地磁盘&#xff1a; path.repo:["/opt/elasticsearch-cluster/snapshot_repo"] 可以配置多个仓库&#xf…

python之数组,链表,栈,队列

1.数组 优点&#xff1a; 索引操作速度快&#xff1a;通过索引可以直接访问元素&#xff0c;因此索引操作的时间复杂度是 $O(1)$&#xff0c;即常数级 缺点&#xff1a; 插入、删除元素慢&#xff1a; 如果需要在中间或开始位置插入或删除元素&#xff0c;可能需要移动大量…

加密 / MD5算法 /盐值

目录 加密的介绍 MD5算法 盐值 加密的介绍 加密介绍&#xff1a;在MySQL数据库中, 我们常常需要对密码, 身份证号, 手机号等敏感信息进行加密, 以保证数据的安全性。 如果使用明文存储, 当黑客入侵了数据库时, 就可以轻松获取到用户的相关信息, 从而对用户或者企业造成信息…

程序员职业并不会彻底消失

目录 程序员职业在技术革新背景下面临着怎样的冲击与挑战? 程序员职业的核心能力及价值是否能被AI完全取代? 程序员的核心能力是什么?

跨域问题总结

文章目录 概要web应用整体请求流程技术名词解释跨域问题产生的原理解决方案前端代码角度前端服务器角度后端代码角度后端服务器角度 小结 概要 在不成熟的前后端开发过程中&#xff0c;经常遇到跨域问题&#xff1b; 在前后端分离的模式下的开发过程中&#xff0c;经常遇到跨域…

【全开源】国际版JAVA多商户运营版商城系统源码支持Android+IOS+H5博纳软云

本系统开发使用JAVA技术栈开发 使用uniapp技术栈 支持H5AndroidIOS 一、功能介绍 精准分类、我的团队、开通会员 我的返利、我的订单、快速购买 邀请返利、购物车、我的提现 二、演示说明 多商户体验方式&#xff1a; 请私信客服获取体验地址 多商户自营商城商户端 : 请…

Arm MMU深度解读

文章目录 一、MMU概念介绍二、虚拟地址空间和物理地址空间2.1、(虚拟/物理)地址空间的范围2.2、物理地址空间有效位(范围) 三、Translation regimes四、地址翻译/几级页表&#xff1f;4.1、思考&#xff1a;页表到底有几级&#xff1f;4.2、以4KB granule为例&#xff0c;页表的…