在自定义数据集上使用 Detectron2 和 PyTorch 进行人脸检测

本文讲讲述如何使用Python在自定义人脸检测数据集上微调预训练的目标检测模型。学习如何为Detectron2和PyTorch准备自定义人脸检测数据集,微调预训练模型以在图像中找到人脸边界。

人脸检测是在图像中找到(边界的)人脸的任务。这在以下情况下很有用:

  • 安全系统(识别人员的第一步)

  • 为拍摄出色的照片进行自动对焦和微笑检测

  • 检测年龄、种族和情感状态以用于营销

1b5f0fcaea773df53fcb08e057aff96f.png

历史上,这是一个非常棘手的问题。大量的手动特征工程、新颖的算法和方法被开发出来以改进最先进技术。

如今,人脸检测模型已经包含在几乎每个计算机视觉包/框架中。其中一些表现最佳的模型使用了深度学习方法。例如,OpenCV提供了各种工具,如级联分类器。

在本指南中,您将学习如何:

  • 准备一个用于人脸检测的自定义数据集,以用于Detectron2

  • 使用(接近)最先进的目标检测模型在图像中查找人脸

  • 您可以将这项工作扩展到人脸识别

Detectron2

Detectron2是一个用于构建最先进的目标检测和图像分割模型的框架,由Facebook Research团队开发。Detectron2是第一个版本的完全重写。Detectron2使用PyTorch(与最新版本兼容),并且允许进行超快速训练。您可以在Facebook Research的入门博客文章中了解更多信息。

Detectron2的真正强大之处在于模型动物园中提供了大量的预训练模型。但是,如果您不能在自己的数据集上对其进行微调,那又有什么好处呢?幸运的是,这非常容易!在本指南中,我们将看到如何完成这项工作。

安装Detectron2

在撰写本文时,Detectron2仍处于alpha阶段。虽然有官方版本,但我们将从主分支克隆和编译。这应该等于版本0.1。让我们首先安装一些要求:

!pip install -q cython pyyaml == 5.1 
!pip install -q -U 'git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI'

然后,下载、编译和安装Detectron2包: 

!git clone https://github.com/facebookresearch/detectron2 detectron2_repo 
!pip install -q -e detectron2_repo

此时,您需要重新启动笔记本运行时以继续!

%reload_ext watermark %watermark -v -p numpy,pandas,pycocotools,torch,torchvision,detectron2
CPython 3.6.9
IPython 5.5.0
numpy 1.17.5
pandas 0.25.3
pycocotools 2.0
torch 1.4.0
torchvision 0.5.0
detectron2 0.1
import torch, torchvision
import detectron2
from detectron2.utils.logger import setup_logger
setup_logger()import globimport os
import ntpath
import numpy as np
import cv2
import random
import itertools
import pandas as pd
from tqdm import tqdm
import urllib
import json
import PIL.Image as Imagefrom detectron2 import model_zoo
from detectron2.engine import DefaultPredictor, DefaultTrainer
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer, ColorMode
from detectron2.data import DatasetCatalog, MetadataCatalog, build_detection_test_loader
from detectron2.evaluation import COCOEvaluator, inference_on_dataset
from detectron2.structures import BoxModeimport seaborn as sns
from pylab import rcParams
import matplotlib.pyplot as plt
from matplotlib import rc%matplotlib inline
%config InlineBackend.figure_format='retina'sns.set(style='whitegrid', palette='muted', font_scale=1.2)HAPPY_COLORS_PALETTE = ["#01BEFE", "#FFDD00", "#FF7D00", "#FF006D", "#ADFF02", "#8F00FF"]sns.set_palette(sns.color_palette(HAPPY_COLORS_PALETTE))rcParams['figure.figsize'] = 12, 8RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

人脸检测数据

该数据集在公共领域免费提供。它由Dataturks提供,并托管在Kaggle上:图像中标有边界框的人脸。有大约500张图像,通过边界框手动标记了大约1100个人脸。

我已经下载了包含注释的JSON文件,并将其上传到了Google Drive。让我们获取它:

!gdown --id 1K79wJgmPTWamqb04Op2GxW0SW9oxw8KS

让我们将文件加载到Pandas数据框中:

faces_df = pd.read_json('face_detection.json', lines=True)

每行包含一个单独的人脸注释。请注意,多行可能指向单个图像(例如,每个图像有多个人脸)。

数据预处理

数据集仅包含图像URL和注释。我们将不得不下载这些图像。我们还将对注释进行标准化,以便稍后在Detectron2中更容易使用:

os.makedirs("faces", exist_ok=True)dataset = []for index, row in tqdm(faces_df.iterrows(), total=faces_df.shape[0]):img = urllib.request.urlopen(row["content"])img = Image.open(img)img = img.convert('RGB')image_name = f'face_{index}.jpeg'img.save(f'faces/{image_name}', "JPEG")annotations = row['annotation']for an in annotations:data = {}width = an['imageWidth']height = an['imageHeight']points = an['points']data['file_name'] = image_namedata['width'] = widthdata['height'] = heightdata["x_min"] = int(round(points[0]["x"] * width))data["y_min"] = int(round(points[0]["y"] * height))data["x_max"] = int(round(points[1]["x"] * width))data["y_max"] = int(round(points[1]["y"] * height))data['class_name'] = 'face'dataset.append(data)

让我们将数据放入数据框中,以便我们可以更好地查看:

df = pd.DataFrame(dataset)
print(df.file_name.unique().shape[0], df.shape[0])
409 1132

我们总共有409张图像(比承诺的500张少得多)和1132个注释。让我们将它们保存到磁盘上(以便您可以重用它们):

数据

让我们查看一些示例注释数据。我们将使用OpenCV加载图像,添加边界框并调整大小。我们将定义一个助手函数来完成所有这些操作:

def annotate_image(annotations, resize=True):file_name = annotations.file_name.to_numpy()[0]img = cv2.cvtColor(cv2.imread(f'faces/{file_name}'), cv2.COLOR_BGR2RGB)for i, a in annotations.iterrows():cv2.rectangle(img, (a.x_min, a.y_min), (a.x_max, a.y_max), (0, 255, 0), 2)if not resize:return imgreturn cv2.resize(img, (384, 384), interpolation = cv2.INTER_AREA)

让我们首先显示一些带注释的图像:

5cd702ae92b05f2ecba0f1a2e0cdfd68.png

f6740914dcb440fcd84269b6d24ebff5.png

这些都是不错的图像,注释清晰可见。我们可以使用torchvision创建一个图像网格。请注意,这些图像具有不同的大小,因此我们将对其进行调整大小:

2c212409437c164de8cc81fa25a9b218.png

您可以清楚地看到一些注释缺失(第4列)。这就是现实生活中的数据,有时您必须以某种方式处理它。

使用Detectron 2进行人脸检测

现在,我们将逐步介绍使用自定义数据集微调模型的步骤。但首先,让我们保留5%的数据进行测试:

df = pd.read_csv('annotations.csv')IMAGES_PATH = f'faces'unique_files = df.file_name.unique()train_files = set(np.random.choice(unique_files, int(len(unique_files) * 0.95), replace=False))
train_df = df[df.file_name.isin(train_files)]
test_df = df[~df.file_name.isin(train_files)]

在这里,经典的训练测试分割方法不适用,因为我们希望在文件名之间进行分割。

接下来的部分以稍微通用的方式编写。显然,我们只有一个类别-人脸。但是,添加更多类别应该就像向数据框中添加更多注释一样简单:

classes = df.class_name.unique().tolist()

接下来,我们将编写一个将我们的数据集转换为Detectron2:

def create_dataset_dicts(df, classes):dataset_dicts = []for image_id, img_name in enumerate(df.file_name.unique()):record = {}image_df = df[df.file_name == img_name]file_path = f'{IMAGES_PATH}/{img_name}'record["file_name"] = file_pathrecord["image_id"] = image_idrecord["height"] = int(image_df.iloc[0].height)record["width"] = int(image_df.iloc[0].width)objs = []for _, row in image_df.iterrows():xmin = int(row.x_min)ymin = int(row.y_min)xmax = int(row.x_max)ymax = int(row.y_max)poly = [(xmin, ymin), (xmax, ymin),(xmax, ymax), (xmin, ymax)]poly = list(itertools.chain.from_iterable(poly))obj = {"bbox": [xmin, ymin, xmax, ymax],"bbox_mode": BoxMode.XYXY_ABS,"segmentation": [poly],"category_id": classes.index(row.class_name),"iscrowd": 0}objs.append(obj)record["annotations"] = objsdataset_dicts.append(record)return dataset_dicts

使用的格式的函数:我们将每个注释行转换为一个具有注释列表的单个记录。您可能还会注意到,我们正在构建一个与边界框完全相同形状的多边形。这对于Detectron2中的图像分割模型是必需的。

您将不得不将数据集注册到数据集和元数据目录中:

for d in ["train", "val"]:DatasetCatalog.register("faces_" + d, lambda d=d: create_dataset_dicts(train_df if d == "train" else test_df, classes))MetadataCatalog.get("faces_" + d).set(thing_classes=classes)statement_metadata = MetadataCatalog.get("faces_train")

不幸的是,默认情况下不包含测试集的评估器。我们可以通过编写自己的训练器轻松修复它:

class CocoTrainer(DefaultTrainer):@classmethoddef build_evaluator(cls, cfg, dataset_name, output_folder=None):if output_folder is None:os.makedirs("coco_eval", exist_ok=True)output_folder = "coco_eval"return COCOEvaluator(dataset_name, cfg, False, output_folder)

如果未提供文件夹,则评估结果将存储在coco_eval文件夹中。

在Detectron2模型上微调与编写PyTorch代码完全不同。我们将加载配置文件,更改一些值,然后启动训练过程。但是嘿,如果您知道自己在做什么,这真的会有所帮助。在本教程中,我们将使用Mask R-CNN X101-FPN模型。它在COCO数据集上进行了预训练,并且表现非常好。缺点是训练速度较慢。

让我们加载配置文件和预训练的模型权重:

cfg = get_cfg()cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x.yaml")
)cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x.yaml"
)

指定我们将用于训练和评估的数据集(我们注册了这些数据集):

cfg.DATASETS.TRAIN = ("faces_train",)
cfg.DATASETS.TEST = ("faces_val",)
cfg.DATALOADER.NUM_WORKERS = 4

至于优化器,我们将进行一些魔法以收敛到某个好的值:

cfg.SOLVER.IMS_PER_BATCH = 4
cfg.SOLVER.BASE_LR = 0.001
cfg.SOLVER.WARMUP_ITERS = 1000
cfg.SOLVER.MAX_ITER = 1500
cfg.SOLVER.STEPS = (1000, 1500)
cfg.SOLVER.GAMMA = 0.05

除了标准的内容(批量大小、最大迭代次数和学习率)外,我们还有几个有趣的参数:

  • WARMUP_ITERS - 学习率从0开始,并在此次数的迭代中逐渐增加到预设值

  • STEPS - 学习率将在其检查点(迭代次数)降低的次数

最后,我们将指定类别的数量以及我们将在测试集上进行评估的周期:

cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 64
cfg.MODEL.ROI_HEADS.NUM_CLASSES = len(classes)cfg.TEST.EVAL_PERIOD = 500

是时候开始训练了,使用我们自定义的训练器:

os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)trainer = CocoTrainer(cfg)
trainer.resume_or_load(resume=False)
trainer.train()

评估目标检测模型

与评估标准分类或回归模型相比,评估目标检测模型有点不同。您需要了解的主要指标是IoU(交并比)。它测量两个边界之间的重叠程度-预测的和真实的。它可以在0和1之间获得值。

a363c6b48bd9d29d9f2e4509c5cc8eb5.png

使用IoU,可以定义阈值(例如> 0.5)来分类预测是否为真阳性(TP)或假阳性(FP)。现在,您可以通过获取精度-召回曲线下的区域来计算平均精度(AP)现在,AP@X(例如AP50)只是某个IoU阈值下的AP。这应该让您对如何评估目标检测模型有一个工作的了解。

我已经准备了一个预训练模型,因此不必等待训练完成。下载它:

!gdown --id 18Ev2bpdKsBaDufhVKf0cT6RmM3FjW3nL 
!mv face_detector.pth output/model_final.pth

我们可以通过加载模型并设置最低的85%的置信度阈值来开始进行预测,以此来将预测视为正确:

cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, "model_final.pth")
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.85
predictor = DefaultPredictor(cfg)

运行评估器与训练好的模型:

evaluator = COCOEvaluator("faces_val", cfg, False, output_dir="./output/")
val_loader = build_detection_test_loader(cfg, "faces_val")
inference_on_dataset(trainer.model, val_loader, evaluator)

在图像中查找人脸

接下来,让我们创建一个文件夹,并保存测试集中所有带有预测注释的图像:

os.makedirs("annotated_results", exist_ok=True)test_image_paths = test_df.file_name.unique()
for clothing_image in test_image_paths:file_path = f'{IMAGES_PATH}/{clothing_image}'im = cv2.imread(file_path)outputs = predictor(im)v = Visualizer(im[:, :, ::-1],metadata=statement_metadata,scale=1.,instance_mode=ColorMode.IMAGE)instances = outputs["instances"].to("cpu")instances.remove('pred_masks')v = v.draw_instance_predictions(instances)result = v.get_image()[:, :, ::-1]file_name = ntpath.basename(clothing_image)write_res = cv2.imwrite(f'annotated_results/{file_name}', result)

eaef04c00ad6f46682d79e21682c449c.png

·  END  ·

HAPPY LIFE

c941539f67c96091161bb2db5bd893b7.png

本文仅供学习交流使用,如有侵权请联系作者删除

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/50951.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

基于前端技术原生HTML、JS、CSS 电子病历编辑器源码

电子病历系统采取结构化与自由式录入的新模式,自由书写,轻松录入。实现病人医疗记录(包含有首页、病程记录、检查检验结果、医嘱、手术记录、护理记录等等。)的保存、管理、传输和重现,取代手写纸张病历。不仅实现了纸…

centos7.9和redhat6.9 离线升级OpenSSH和openssl (2023年的版本)

升级注意事项! 1、多开几个连接窗口(xshell),避免升级openssh失败无法再次连接终端,否则要跑机房了。 2、可开启telnet服务、vnc服务、打快照。多几个“保命”的路数。一、centos7.9的信息 [rootnode2 ~]# openssl v…

Socket通信与WebSocket协议

文章目录 目录 文章目录 前言 一、Socket通信 1.1 BIO 1.2 NIO 1.3 AIO 二、WebSocket协议 总结 前言 一、Socket通信 Socket是一种用于网络通信的编程接口(API),它提供了一种机制,使不同主机之间可以通过网络进行数据传输和通信…

HQL解决连续三天登陆问题

1.背景 统计连续登录天数超过3天的用户,输出信息包括:用户id,登录天数,起始时间,结束时间; 2.准备数据 -- 建表 create table if not exists user_login_3days(user_id STRING,login_date date );--插入…

14、缓存预热+缓存雪崩+缓存击穿+缓存穿透

缓存预热缓存雪崩缓存击穿缓存穿透 ● 缓存预热、雪崩、穿透、击穿分别是什么?你遇到过那几个情况? ● 缓存预热你是怎么做到的? ● 如何避免或者减少缓存雪崩? ● 穿透和击穿有什么区别?它两一个意思还是截然不同&am…

JDBC详解

文章目录 一、引言1.1 如何操作数据库1.2 实际开发中,会采用客户端操作数据库吗? 二、JDBC(Java Database Connectivity)2.1 什么是 JDBC?2.2 JDBC 核心思想2.2.1 MySQL 数据库驱动2.2.2 JDBC API 2.3 环境搭建 三、JD…

面试官:请手写一个Promise

前端面试题库 (面试必备) 推荐:★★★★★ 地址:前端面试题库 前言 面试官:请手写一个Promise?(开门见山) 我:既然说到Promise,那我肯定得先介…

RabbitMQ介绍

RabbitMQ的概念 RabbitMQ 是一个消息中间件:它接受并转发消息。你可以把它当做一个快递站点,当你要发送一个包裹时,你把你的包裹放到快递站,快递员最终会把你的快递送到收件人那里,按照这种逻辑 RabbitMQ 是 一个快递…

基于 Debian 12 的MX Linux 23 正式发布!

导读MX Linux 是基于 Debian 稳定分支的面向桌面的 Linux 发行,它是 antiX 及早先的 MEPIS Linux 社区合作的产物。它采用 Xfce 作为默认桌面环境,是一份中量级操作系统,并被设计为优雅而高效的桌面与如下特性的结合:配置简单、高…

微信开发之一键修改群聊备注的技术实现

修改群备注 修改群名备注后,如看到群备注未更改,是手机缓存问题,可以连续点击进入其他群,在点击进入修改的群,再返回即可看到修改后的群备注名,群名称的备注仅自己可见 请求URL: http://域名地…

ctfshow-红包题第二弹

0x00 前言 CTF 加解密合集CTF Web合集 0x01 题目 0x02 Write Up 同样,先看一下有没有注释的内容,可以看到有一个cmd的入参 执行之后可以看到文件代码,可以看到也是eval,但是中间对大部分的字符串都进行了过滤,留下了…

lvs实现DR模型搭建

目录 一,实现DR模型搭建 1, 负载调度器配置 1.1调整ARP参数 1.2 配置虚拟IP地址重启网卡 1.3 安装ipvsadm 1.4 加载ip_vs模块 1.5 启动ipvsadm服务 1.6 配置负载分配策略 1.7 保存策略 2, web节点配置 1.1 调整ARP参数 1.2 配置虚拟I…

Element Plus <el-table> 组件之展开行Table在项目中使用

目录 官方样式: 展开前: 展开: 原始代码: 代码详解: 项目使用场景: 完成效果: 具体实现范本: 1.调整数据结构 2. 修改标签和数据绑定 3. JavaScript 部分导入和创建对象 …

综合能源系统(8)——综合能源系统支撑技术

综合能源系统关键技术与典型案例  何泽家,李德智主编 1、大数据技术 1.1、大数据技术概述 大数据是指无法在一定时间范围内用常规软件工具进行捕捉、管理和处理的数据集合,是需要新处理模式才能具有更强的决策力、洞察发现力和流程优化能力的海量、高…

wps设置其中几页为横版

问题:写文档的时候,有些表格列数太多,页面纵向显示内容不完整,可以给它改成横向显示。 将鼠标放在表格上一页的底部,点击‘插入-分页-下一页分节符’。 将鼠标放在表格页面的底部,点击‘插入-分页-下一页分…

【Docker入门第一篇】

Docker简介 Docker 是一个开源的应用容器引擎,基于 Go 语言 并遵从 Apache2.0 协议开源。 Docker 可以让开发者打包他们的应用以及依赖包到一个轻量级、可移植的容器中,然后发布到任何流行的 Linux 机器上,也可以实现虚拟化。 容器是完全使…

《数字图像处理-OpenCV/Python》连载(2)目录

《数字图像处理-OpenCV/Python》连载(2)目录 本书京东优惠购书链接:https://item.jd.com/14098452.html 本书CSDN独家连载专栏:https://blog.csdn.net/youcans/category_12418787.html 第一部分 OpenCV-Python的基本操作 第1章 …

Redis多机实现

Background 为啥要有多机--------------1.容错 2.从服务器分担读压力。 主从结构一大难题------------如何保障一致性,对这个一致性要求不是很高,因为redis是用来做缓存的 同时我们要自动化进行故障转移-------哨兵机制,同时哨兵也可能cra…

江西南昌电气机械三维测量仪机械零件3d扫描-CASAIM中科广电

精密机械零部件是指机械设备中起到特定功能的零件,其制造精度要求非常高。这些零部件通常由金属、塑料或陶瓷等材料制成,常见的精密机械零部件包括齿轮、轴承、螺丝、活塞、阀门等。精密机械零部件的制造需要高精度的加工设备和工艺,以确保其…

HJ31 单词倒排 题解

题目描述:单词倒排_牛客题霸_牛客网 (nowcoder.com) 对字符串中的所有单词进行倒排。 1、构成单词的字符只有26个大写或小写英文字母; 2、非构成单词的字符均视为单词间隔符; 3、要求倒排后的单词间隔符以一个空格表示;如果原字符…