基于PaddleClas的人物年龄分类项目

目录

一、任务概述

二、算法研发

2.1 下载数据集

2.2 数据集预处理

2.3 安装PaddleClas套件

2.4 算法训练

2.5 静态图导出

2.6 静态图推理

三、小结


一、任务概述

    最近遇到个需求,需要将图像中的人物区分为成人和小孩,这是一个典型的二分类问题,打算采用飞桨的图像分类套件PaddleClas来完成算法研发。本文记录相关流程。

二、算法研发

2.1 下载数据集

    本文采用MaGaAge_Asian数据集,该数据集主要由亚洲人图片组成,训练集包含40000张图像,验证集包含3495张图像,每张图像都有对应的年龄真值,所有图像均处理成了统一的大小,宽178像素,高218像素。

数据集地址下载链接。数据集部分示例如下图所示:

    该数据集本意是用来做年龄预测的,属于一个数值回归任务,本文将其变成二分类任务,以13岁年龄为界限,小于该年龄的属于小孩,大于该年龄的属于成人。这里之所以选择13岁,因为这个任务是需要筛选出长得很“像”小孩的小孩,13岁以上的青少年很多本身已经长的像成人了,因此,选择13岁作为分界线。

    下面首先对该数据集进行处理。

2.2 数据集预处理

    MaGaAge_Asian数据集每张图片对应的人物年龄存放在list文件夹的两个文件中,其中train_age.txt存放训练集对应的年龄真值,test_age.txt存放验证集对应的年龄真值。下面要写一个脚本,将所有小于13岁的图片移动到一个文件夹内,所有大于等于13岁的图片移动到另一个文件夹内。

#!/usr/bin/env python
# -*- encoding: utf-8 -*-
'''
@文件        :split_asian.py
@说明        :拆分megaage_asian数据集,将小于13岁的移动到一个文件夹,大于等于13岁的移动到另一个文件夹
@时间        :2024/07/16 09:11:16
@作者        :Bin Qian
@版本        :1.0
'''import os
import cv2thr = 13 # 年龄阈值# 读取年龄列表
agefile = 'megaage_asian/list/test_age.txt'
f=open(agefile) 
ageLst = f.read().splitlines()
f.close() # 读取图像
imgFolder = 'megaage_asian/val'
imgnames = os.listdir(imgFolder)
index = 50000
for imgname in imgnames:imgPath = os.path.join(imgFolder,imgname)img = cv2.imread(imgPath)if img is None:continueprint(imgPath)imgindex = int(imgname.split('.')[0])age = int(ageLst[imgindex-1])if age < thr:dstFolder = 'ageclas/child'else:dstFolder = 'ageclas/adult'savePath = os.path.join(dstFolder,str(index)+'_asian.jpg')cv2.imwrite(savePath,img)index += 1
print('完成')

值得注意的是MaGaAge_Asian数据集中有很多质量较差的图像,这些“脏”图像会影响学习效果,最好手工检查这些数据并将其剔除。

另外,为了能够取得更好的效果,本文从互联网和FFHQ数据集里面再挑选出一些小孩和成人的照片进行补充。部分代码如下:

import os
import cv2# 读取图像
imgFolder = 'adult'
imgnames = os.listdir(imgFolder)
index = 1
for imgname in imgnames:imgPath = os.path.join(imgFolder,imgname)img = cv2.imread(imgPath)if img is None:continueprint(imgPath)dstFolder = 'ageclas/adult'savePath = os.path.join(dstFolder,str(index)+'_data.jpg')cv2.imwrite(savePath,img)index += 1
print('完成')

补充完整后,最后对整理好的数据集进行拆分,并且获得对应的文件列表:

# 导入系统库
import os
import random
import cv2# 定义参数
img_folder = 'ageclas'
trainlst = 'train_list.txt'
vallst = 'val_list.txt'
ratio = 0.95 # 训练集占比
labellst='label.txt'def writeLst(lstpath,namelst):'''保存文件列表'''print('正在写入 '+lstpath)random.shuffle (namelst)# 写入训练样本文件f=open(lstpath, 'a', encoding='utf-8')for i in range(len(namelst)):text = namelst[i]+'\n'f.write(text)f.close()print(lstpath+ '已完成写入')def main():'''主函数'''# 查找文件夹folderlst = os.listdir(img_folder)print('共找到 %d 个文件夹' % len(folderlst))# 循环处理trainnamelst = list()valnamelst = list()labelnamelst = list()for i in range(len(folderlst)):class_name = folderlst[i]class_label = iprint('开始处理 '+class_name+' 文件夹')# 获取子文件夹文件列表filenamelst = os.listdir(os.path.join(img_folder,class_name))totalNum = len(filenamelst)print('当前文件夹图片数量为: ' + str(totalNum)) trainNum = int(ratio*totalNum)text =  str(class_label)+ ' ' + class_namelabelnamelst.append(text)# 检查并校验图像for j in range(totalNum):imgpath = os.path.join(img_folder,class_name,filenamelst[j])img = cv2.imread(imgpath, cv2.IMREAD_COLOR)if img is None:continuetext = imgpath + ' ' + str(class_label)if j <= trainNum: trainnamelst.append(text)else:valnamelst.append(text)writeLst(trainlst,trainnamelst)writeLst(vallst,valnamelst)   writeLst(labellst,labelnamelst)     print('全部完成')if __name__ == '__main__':'''程序入口'''main()

运行后会生成train_lst.txt、val_lst.txt以及label.txt三个文件,有了这三个文件就可以使用PaddleClas套件进行算法研发了。

2.3 安装PaddleClas套件

git clone https://gitee.com/paddlepaddle/PaddleClas.git
cd PaddleClas
sudo python setup.py install

2.4 算法训练

在PaddleClas目录下新建一个配置文件config_lcnet.yaml,采用PPLCNet_x0_5模型来训练,配置文件代码如下:

# global configs
Global:checkpoints: nullpretrained_model: nulloutput_dir: ./output/device: gpusave_interval: 5eval_during_train: Trueeval_interval: 5epochs: 200print_batch_step: 10use_visualdl: True# used for static mode and model exportimage_shape: [3, 224, 224]save_inference_dir: ./output/inference
# model architecture
Arch:name: PPLCNet_x0_5class_num: 2# loss function config for traing/eval process
Loss:Train:- CELoss:weight: 1.0epsilon: 0.1Eval:- CELoss:weight: 1.0Optimizer:name: Momentummomentum: 0.9lr:name: Cosinelearning_rate: 0.8warmup_epoch: 5regularizer:name: 'L2'coeff: 0.00003# data loader for train and eval
DataLoader:Train:dataset:name: ImageNetDatasetimage_root: ../process_data/cls_label_path: ../process_data/train_list.txttransform_ops:- DecodeImage:to_rgb: Truechannel_first: False- ResizeImage:size: [224,224]- RandFlipImage:flip_code: 1- NormalizeImage:scale: 1.0/255.0mean: [0.485, 0.456, 0.406]std: [0.229, 0.224, 0.225]order: ''sampler:name: DistributedBatchSamplerbatch_size: 64drop_last: Falseshuffle: Trueloader:num_workers: 4use_shared_memory: TrueEval:dataset: name: ImageNetDatasetimage_root: ../process_data/cls_label_path: ../process_data/val_list.txttransform_ops:- DecodeImage:to_rgb: Truechannel_first: False- ResizeImage:size: [224,224]- NormalizeImage:scale: 1.0/255.0mean: [0.485, 0.456, 0.406]std: [0.229, 0.224, 0.225]order: ''sampler:name: DistributedBatchSamplerbatch_size: 64drop_last: Falseshuffle: Falseloader:num_workers: 4use_shared_memory: TrueInfer:infer_imgs: "../testimgs/10.jpg"batch_size: 1transforms:- DecodeImage:to_rgb: Truechannel_first: False- ResizeImage:size: [224,224]- NormalizeImage:scale: 1.0/255.0mean: [0.485, 0.456, 0.406]std: [0.229, 0.224, 0.225]order: ''- ToCHWImage:PostProcess:name: Topktopk: 1class_id_map_file: "../process_data/label.txt"Metric:Train:- TopkAcc:topk: [1]Eval:- TopkAcc:topk: [1]

然后使用下面的命令进行训练:

export CUDA_VISIBLE_DEVICES=0,1
python3 -m paddle.distributed.launch \--gpus="0,1" \tools/train.py \-c config_lcnet.yaml 

训练完成后可以使用下面的命令可视化查看训练结果:

visualdl --logdir results/vdl

运行效果如下:

可以看到,基本在epoch=100以后就收敛了,最高top1准确率达到97.5%,准确率还是比较高的。

下面可以使用动态图对单张图像进行测试,命令如下:

python3 tools/infer.py -c config_lcnet.yaml -o Global.pretrained_model=output/PPLCNet_x0_5/best_model

输出如下:

[{'class_ids': [1], 'scores': [0.93522], 'file_name': '../testimgs/10.jpg', 'label_names': ['adult']}]

2.5 静态图导出

为了方便后面进行模型部署,将训练好的最佳模型进行静态图导出。具体命令如下:

python3 tools/export_model.py \-c config_lcnet.yaml \-o Global.pretrained_model=output/PPLCNet_x0_5/best_model \-o Global.save_inference_dir=output/inference

导出的静态图模型存放在output/inference文件夹下面,整个模型参数加起来不超过3M,因此可以看出这个训练好的PPLCNet_x0_5模型是一个非常轻量级的模型。

2.6 静态图推理

下面使用静态图来进行推理。在推理前先使用visualdl工具查看下静态图模型的输入和输出,这将为编写推理脚本奠定基础。

可以看到,输入是[batch,3,224,224]的float型图像数据,输出是[batch,2]的float型数据。尤其是输出的两个值,代表的是两个类别的概率。

有了上面的分析,下面可以用PaddleInference写一个推理脚本infer.py:

import cv2
import numpy as np
from paddle.inference import create_predictor
from paddle.inference import Config as PredictConfig# 加载静态图模型
model_path = "./output/inference/inference.pdmodel"
params_path = "./output/inference/inference.pdiparams"
pred_cfg = PredictConfig(model_path, params_path)
pred_cfg.enable_memory_optim()  # 启用内存优化
pred_cfg.switch_ir_optim(True)
pred_cfg.enable_use_gpu(500, 0)  # 启用GPU推理
predictor = create_predictor(pred_cfg)  # 创建PaddleInference推理器# 解析模型输入输出
input_names = predictor.get_input_names()
input_handle = {}
for i in range(len(input_names)):input_handle[input_names[i]] = predictor.get_input_handle(input_names[i])
output_names = predictor.get_output_names()
output_handle = predictor.get_output_handle(output_names[0])# 图像预处理
img = cv2.imread("../testimgs/10.jpg", flags=cv2.IMREAD_COLOR)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = cv2.resize(img, (224, 224), interpolation=cv2.INTER_AREA)
img = img.astype(np.float32)
PIXEL_MEANS =(0.485, 0.456, 0.406)    # RGB格式的均值和方差
PIXEL_STDS = (0.229, 0.224, 0.225)
img/=255.0
img-=np.array(PIXEL_MEANS)
img/=np.array(PIXEL_STDS)
img = np.transpose(img[np.newaxis, :, :, :], (0, 3, 1, 2))# 预测
input_handle["x"].copy_from_cpu(img)
predictor.run()
results = output_handle.copy_to_cpu()# 后处理
results = results.squeeze(0)
if results[0]>results[1]:print('小孩'+"  "+str(results[0]))
else:print('大人'+"  "+str(results[1]))

从网上随便找两张照片,运行效果如下:

输出结果:

小孩  0.7256172

输出结果:

大人  0.9533998

可以看到,推理效果还是比较满意的。

三、小结

本文以项目为主线,使用了PaddleClas算法套件解决了年龄分类问题。后续读者如果想要深入学习PaddlePaddle(飞桨)及相关算法套件,可以关注我的书籍(链接)。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/875598.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Leetcode3219. 切蛋糕的最小总开销 II

Every day a Leetcode 题目来源&#xff1a;3219. 切蛋糕的最小总开销 II 解法1&#xff1a;贪心 谁的开销更大&#xff0c;就先切谁&#xff0c;并且这个先后顺序与切的次数无关。 代码&#xff1a; /** lc appleetcode.cn id3219 langcpp** [3219] 切蛋糕的最小总开销 I…

【SQL 新手教程 1/20】SQL语言MySQL数据库 简介

&#x1f497; 什么是SQL&#xff1f;⭐ (Structured Query Language) 结构化查询语言&#xff0c;是访问和处理关系数据库的计算机标准语言 无论用什么编程语言&#xff08;Java、Python、C……&#xff09;编写程序&#xff0c;只要涉及到操作关系数据库都必须通过SQL来完成 …

4招清洁法,清理电脑无死角,焕然一新效率高

随着时间的积累&#xff0c;电脑内部可能会堆积起大量的垃圾文件、缓存数据和无用程序。因此&#xff0c;定期清理电脑是很有必要的。为了让你的电脑重新焕发生机&#xff0c;提高工作效率&#xff0c;本文将为你介绍4招实用的清洁法&#xff0c;助你轻松清理电脑死角&#xff…

JavaWeb学习——请求响应、分层解耦

目录 一、请求响应学习 1、请求 简单参数 实体参数 数组集合参数 日期参数 Json参数 路径参数 总结 2、响应 ResponseBody&统一响应结果 二、分层解耦 1、三层架构 三层架构含义 架构划分 2、分层解耦 引入概念 容器认识 3、IOC&DI入门 4、IOC详解 …

Cadence23学习笔记(十四)

ARC就是圆弧走线的意思&#xff1a; 仅打开网络的话可以只针对net进行修改走线的属性&#xff1a; 然后现在鼠标左键点那个走线&#xff0c;那个走线就会变为弧形&#xff1a; 添加差分对&#xff1a; 之后&#xff0c;分别点击两条线即可分配差分对&#xff1a; 选完差分对之后…

微服务实践和总结

H5原生组件web Component Web Component 是一种用于构建可复用用户界面组件的技术&#xff0c;开发者可以创建自定义的 HTML 标签&#xff0c;并将其封装为包含逻辑和样式的独立组件&#xff0c;从而在任何 Web 应用中重复使用。 <!DOCTYPE html> <html><head…

【vue3|第18期】Vue-Router路由的三种传参方式

日期:2024年7月17日 作者:Commas 签名:(ง •_•)ง 积跬步以致千里,积小流以成江海…… 注释:如果您觉得有所帮助,帮忙点个赞,也可以关注我,我们一起成长;如果有不对的地方,还望各位大佬不吝赐教,谢谢^ - ^ 1.01365 = 37.7834;0.99365 = 0.0255 1.02365 = 1377.408…

unity2D游戏开发01项目搭建

1新建项目 选择2d模板,设置项目名称和存储位置 在Hierarchy面板右击&#xff0c;create Empty 添加组件 在Project视图中右键新建文件夹 将图片资源拖进来&#xff08;图片资源在我的下载里面&#xff09; 点击Player 修改属性&#xff0c;修好如下 点击Sprite Editor 选择第二…

Angular由一个bug说起之八:实践中遇到的一个数据颗粒度的问题

互联网产品离不开数据处理&#xff0c;数据处理有一些基本的原则包括&#xff1a;准确性、‌完整性、‌一致性、‌保密性、‌及时性。‌ 准确性&#xff1a;是数据处理的首要目标&#xff0c;‌确保数据的真实性和可靠性。‌准确的数据是进行分析和决策的基础&#xff0c;‌因此…

【数据结构】建堆算法复杂度分析及TOP-K问题

【数据结构】建堆算法复杂度分析及TOP-K问题 &#x1f525;个人主页&#xff1a;大白的编程日记 &#x1f525;专栏&#xff1a;数据结构 文章目录 【数据结构】建堆算法复杂度分析及TOP-K问题前言一.复杂度分析1.1向下建堆复杂度1.2向上建堆复杂度1.3堆排序复杂度 二.TOP-K问…

深度学习1-简介

人工智能&#xff08;AI&#xff09;旨在打造模仿智能行为的系统。它覆盖了众多方法&#xff0c;涵盖了基于逻辑、搜索和概率推理的技术。机器学习是 AI 的一个分支&#xff0c;它通过对观测数据进行数学模型拟合来学习决策制定。这个领域近年来迅猛发展&#xff0c;现在几乎&a…

​ ​【Linux】-----工具篇(多模式编辑器vim介绍及配置)

目录 认识常用三种模式 基本操作 Ⅰ、进入/打开vim Ⅱ、模式转换 Ⅲ、退出vim 命令集 Ⅰ、命令模式下 移动光标 删除文字 复制 替换 撤销 批量化注释 批量化去注释 Ⅱ、底行模式下 列出行号 跳转至指定行 查找字符 保存文件 退出vim 查看文件 分屏操作 vim的简…

论文阅读:面向自动驾驶场景的多目标点云检测算法

论文地址:面向自动驾驶场景的多目标点云检测算法 概要 点云在自动驾驶系统中的三维目标检测是关键技术之一。目前主流的基于体素的无锚框检测算法通常采用复杂的二阶段修正模块,虽然在算法性能上有所提升,但往往伴随着较大的延迟。单阶段无锚框点云检测算法简化了检测流程,…

微信小程序:vant-weapp 组件库、css 变量

vant-weapp 组件库 前往 vant-weapp 官网 npm 使用限制&#xff1a;不支持依赖于 Node.js 内置库、浏览器内置对象、C 插件 的包。 安装 vant-weapp # 通过 npm 安装 npm i vant/weapp -S --production# 通过 yarn 安装 yarn add vant/weapp --production# 安装 0.x 版本 npm i…

Mac环境报错 error: symbol(s) not found for architecture x86_64

Mac 环境Qt Creator报错 error: symbol(s) not found for architecture x86_64 错误信息 "symbol(s) not found for architecture x86_64" 通常是在编译或链接过程中出现的问题。这种错误提示通常涉及到符号未找到或者是因为编译器没有找到适当的库文件或函数定义。 …

powe bi界面认识及矩阵表基本操作 - 1

powe bi界面认识及矩阵表操作 1. 界面认识1.1 选择数据源1.2 选择相关表及点击加载1.3 表字段显示位置1.4 表属性按钮位置1.5 界面布局按钮认识 2. 矩阵表基本操作2.1 选择矩阵表2.2 创建矩阵表2.3 设置字体大小2.4 行填充&#xff1a;修改高度2.5 列宽&#xff1a;设置列的宽度…

【Python实战因果推断】55_因果推理概论5

目录 Consistency and Stable Unit Treatment Values Violations Causal Quantities of Interest Consistency and Stable Unit Treatment Values 在上述方程中&#xff0c;隐含着两个基本假设。第一个假设意味着潜在结果与处理是一致的&#xff1a;当时&#xff0c;。换句…

Vue3相比于Vue2进行了哪些更新

1、响应式原理 vue2 vue2中采用 defineProperty 来劫持整个对象&#xff0c;然后进行深度遍历所有属性&#xff0c;给每个属性添加getter和setter&#xff0c;结合发布订阅模式实现响应式。 存在的问题&#xff1a; 检测不到对象属性的添加和删除数组API方法无法监听到需要对…

Shader笔记1:基础概念

有相当一部分来自shader圣经 Base of CG Concepts Tangent, Normal and Binormal N&#xff1a;法线&#xff08;Normal, N&#xff09;垂直于表面 T&#xff1a;切线&#xff08;Tangent, T&#xff09;与U方向同向 B&#xff1a;副切线&#xff08;BiTangent, B&#xff09…

ADS 使用教程(十六)Using Sliders for Data Processing

上一篇&#xff1a;ADS 使用教程&#xff08;十五&#xff09;Multi-Dimensional Data Processing in ADS 在这一节&#xff0c;我们来谈论一下如何在进行多维数据处理时使用滑块&#xff08;Sliders&#xff09;来进行数据处理和分析。通过该方法&#xff0c;我们可以通过拖动…