基于PaddleClas的人物年龄分类项目

目录

一、任务概述

二、算法研发

2.1 下载数据集

2.2 数据集预处理

2.3 安装PaddleClas套件

2.4 算法训练

2.5 静态图导出

2.6 静态图推理

三、小结


一、任务概述

    最近遇到个需求,需要将图像中的人物区分为成人和小孩,这是一个典型的二分类问题,打算采用飞桨的图像分类套件PaddleClas来完成算法研发。本文记录相关流程。

二、算法研发

2.1 下载数据集

    本文采用MaGaAge_Asian数据集,该数据集主要由亚洲人图片组成,训练集包含40000张图像,验证集包含3495张图像,每张图像都有对应的年龄真值,所有图像均处理成了统一的大小,宽178像素,高218像素。

数据集地址下载链接。数据集部分示例如下图所示:

    该数据集本意是用来做年龄预测的,属于一个数值回归任务,本文将其变成二分类任务,以13岁年龄为界限,小于该年龄的属于小孩,大于该年龄的属于成人。这里之所以选择13岁,因为这个任务是需要筛选出长得很“像”小孩的小孩,13岁以上的青少年很多本身已经长的像成人了,因此,选择13岁作为分界线。

    下面首先对该数据集进行处理。

2.2 数据集预处理

    MaGaAge_Asian数据集每张图片对应的人物年龄存放在list文件夹的两个文件中,其中train_age.txt存放训练集对应的年龄真值,test_age.txt存放验证集对应的年龄真值。下面要写一个脚本,将所有小于13岁的图片移动到一个文件夹内,所有大于等于13岁的图片移动到另一个文件夹内。

#!/usr/bin/env python
# -*- encoding: utf-8 -*-
'''
@文件        :split_asian.py
@说明        :拆分megaage_asian数据集,将小于13岁的移动到一个文件夹,大于等于13岁的移动到另一个文件夹
@时间        :2024/07/16 09:11:16
@作者        :Bin Qian
@版本        :1.0
'''import os
import cv2thr = 13 # 年龄阈值# 读取年龄列表
agefile = 'megaage_asian/list/test_age.txt'
f=open(agefile) 
ageLst = f.read().splitlines()
f.close() # 读取图像
imgFolder = 'megaage_asian/val'
imgnames = os.listdir(imgFolder)
index = 50000
for imgname in imgnames:imgPath = os.path.join(imgFolder,imgname)img = cv2.imread(imgPath)if img is None:continueprint(imgPath)imgindex = int(imgname.split('.')[0])age = int(ageLst[imgindex-1])if age < thr:dstFolder = 'ageclas/child'else:dstFolder = 'ageclas/adult'savePath = os.path.join(dstFolder,str(index)+'_asian.jpg')cv2.imwrite(savePath,img)index += 1
print('完成')

值得注意的是MaGaAge_Asian数据集中有很多质量较差的图像,这些“脏”图像会影响学习效果,最好手工检查这些数据并将其剔除。

另外,为了能够取得更好的效果,本文从互联网和FFHQ数据集里面再挑选出一些小孩和成人的照片进行补充。部分代码如下:

import os
import cv2# 读取图像
imgFolder = 'adult'
imgnames = os.listdir(imgFolder)
index = 1
for imgname in imgnames:imgPath = os.path.join(imgFolder,imgname)img = cv2.imread(imgPath)if img is None:continueprint(imgPath)dstFolder = 'ageclas/adult'savePath = os.path.join(dstFolder,str(index)+'_data.jpg')cv2.imwrite(savePath,img)index += 1
print('完成')

补充完整后,最后对整理好的数据集进行拆分,并且获得对应的文件列表:

# 导入系统库
import os
import random
import cv2# 定义参数
img_folder = 'ageclas'
trainlst = 'train_list.txt'
vallst = 'val_list.txt'
ratio = 0.95 # 训练集占比
labellst='label.txt'def writeLst(lstpath,namelst):'''保存文件列表'''print('正在写入 '+lstpath)random.shuffle (namelst)# 写入训练样本文件f=open(lstpath, 'a', encoding='utf-8')for i in range(len(namelst)):text = namelst[i]+'\n'f.write(text)f.close()print(lstpath+ '已完成写入')def main():'''主函数'''# 查找文件夹folderlst = os.listdir(img_folder)print('共找到 %d 个文件夹' % len(folderlst))# 循环处理trainnamelst = list()valnamelst = list()labelnamelst = list()for i in range(len(folderlst)):class_name = folderlst[i]class_label = iprint('开始处理 '+class_name+' 文件夹')# 获取子文件夹文件列表filenamelst = os.listdir(os.path.join(img_folder,class_name))totalNum = len(filenamelst)print('当前文件夹图片数量为: ' + str(totalNum)) trainNum = int(ratio*totalNum)text =  str(class_label)+ ' ' + class_namelabelnamelst.append(text)# 检查并校验图像for j in range(totalNum):imgpath = os.path.join(img_folder,class_name,filenamelst[j])img = cv2.imread(imgpath, cv2.IMREAD_COLOR)if img is None:continuetext = imgpath + ' ' + str(class_label)if j <= trainNum: trainnamelst.append(text)else:valnamelst.append(text)writeLst(trainlst,trainnamelst)writeLst(vallst,valnamelst)   writeLst(labellst,labelnamelst)     print('全部完成')if __name__ == '__main__':'''程序入口'''main()

运行后会生成train_lst.txt、val_lst.txt以及label.txt三个文件,有了这三个文件就可以使用PaddleClas套件进行算法研发了。

2.3 安装PaddleClas套件

git clone https://gitee.com/paddlepaddle/PaddleClas.git
cd PaddleClas
sudo python setup.py install

2.4 算法训练

在PaddleClas目录下新建一个配置文件config_lcnet.yaml,采用PPLCNet_x0_5模型来训练,配置文件代码如下:

# global configs
Global:checkpoints: nullpretrained_model: nulloutput_dir: ./output/device: gpusave_interval: 5eval_during_train: Trueeval_interval: 5epochs: 200print_batch_step: 10use_visualdl: True# used for static mode and model exportimage_shape: [3, 224, 224]save_inference_dir: ./output/inference
# model architecture
Arch:name: PPLCNet_x0_5class_num: 2# loss function config for traing/eval process
Loss:Train:- CELoss:weight: 1.0epsilon: 0.1Eval:- CELoss:weight: 1.0Optimizer:name: Momentummomentum: 0.9lr:name: Cosinelearning_rate: 0.8warmup_epoch: 5regularizer:name: 'L2'coeff: 0.00003# data loader for train and eval
DataLoader:Train:dataset:name: ImageNetDatasetimage_root: ../process_data/cls_label_path: ../process_data/train_list.txttransform_ops:- DecodeImage:to_rgb: Truechannel_first: False- ResizeImage:size: [224,224]- RandFlipImage:flip_code: 1- NormalizeImage:scale: 1.0/255.0mean: [0.485, 0.456, 0.406]std: [0.229, 0.224, 0.225]order: ''sampler:name: DistributedBatchSamplerbatch_size: 64drop_last: Falseshuffle: Trueloader:num_workers: 4use_shared_memory: TrueEval:dataset: name: ImageNetDatasetimage_root: ../process_data/cls_label_path: ../process_data/val_list.txttransform_ops:- DecodeImage:to_rgb: Truechannel_first: False- ResizeImage:size: [224,224]- NormalizeImage:scale: 1.0/255.0mean: [0.485, 0.456, 0.406]std: [0.229, 0.224, 0.225]order: ''sampler:name: DistributedBatchSamplerbatch_size: 64drop_last: Falseshuffle: Falseloader:num_workers: 4use_shared_memory: TrueInfer:infer_imgs: "../testimgs/10.jpg"batch_size: 1transforms:- DecodeImage:to_rgb: Truechannel_first: False- ResizeImage:size: [224,224]- NormalizeImage:scale: 1.0/255.0mean: [0.485, 0.456, 0.406]std: [0.229, 0.224, 0.225]order: ''- ToCHWImage:PostProcess:name: Topktopk: 1class_id_map_file: "../process_data/label.txt"Metric:Train:- TopkAcc:topk: [1]Eval:- TopkAcc:topk: [1]

然后使用下面的命令进行训练:

export CUDA_VISIBLE_DEVICES=0,1
python3 -m paddle.distributed.launch \--gpus="0,1" \tools/train.py \-c config_lcnet.yaml 

训练完成后可以使用下面的命令可视化查看训练结果:

visualdl --logdir results/vdl

运行效果如下:

可以看到,基本在epoch=100以后就收敛了,最高top1准确率达到97.5%,准确率还是比较高的。

下面可以使用动态图对单张图像进行测试,命令如下:

python3 tools/infer.py -c config_lcnet.yaml -o Global.pretrained_model=output/PPLCNet_x0_5/best_model

输出如下:

[{'class_ids': [1], 'scores': [0.93522], 'file_name': '../testimgs/10.jpg', 'label_names': ['adult']}]

2.5 静态图导出

为了方便后面进行模型部署,将训练好的最佳模型进行静态图导出。具体命令如下:

python3 tools/export_model.py \-c config_lcnet.yaml \-o Global.pretrained_model=output/PPLCNet_x0_5/best_model \-o Global.save_inference_dir=output/inference

导出的静态图模型存放在output/inference文件夹下面,整个模型参数加起来不超过3M,因此可以看出这个训练好的PPLCNet_x0_5模型是一个非常轻量级的模型。

2.6 静态图推理

下面使用静态图来进行推理。在推理前先使用visualdl工具查看下静态图模型的输入和输出,这将为编写推理脚本奠定基础。

可以看到,输入是[batch,3,224,224]的float型图像数据,输出是[batch,2]的float型数据。尤其是输出的两个值,代表的是两个类别的概率。

有了上面的分析,下面可以用PaddleInference写一个推理脚本infer.py:

import cv2
import numpy as np
from paddle.inference import create_predictor
from paddle.inference import Config as PredictConfig# 加载静态图模型
model_path = "./output/inference/inference.pdmodel"
params_path = "./output/inference/inference.pdiparams"
pred_cfg = PredictConfig(model_path, params_path)
pred_cfg.enable_memory_optim()  # 启用内存优化
pred_cfg.switch_ir_optim(True)
pred_cfg.enable_use_gpu(500, 0)  # 启用GPU推理
predictor = create_predictor(pred_cfg)  # 创建PaddleInference推理器# 解析模型输入输出
input_names = predictor.get_input_names()
input_handle = {}
for i in range(len(input_names)):input_handle[input_names[i]] = predictor.get_input_handle(input_names[i])
output_names = predictor.get_output_names()
output_handle = predictor.get_output_handle(output_names[0])# 图像预处理
img = cv2.imread("../testimgs/10.jpg", flags=cv2.IMREAD_COLOR)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = cv2.resize(img, (224, 224), interpolation=cv2.INTER_AREA)
img = img.astype(np.float32)
PIXEL_MEANS =(0.485, 0.456, 0.406)    # RGB格式的均值和方差
PIXEL_STDS = (0.229, 0.224, 0.225)
img/=255.0
img-=np.array(PIXEL_MEANS)
img/=np.array(PIXEL_STDS)
img = np.transpose(img[np.newaxis, :, :, :], (0, 3, 1, 2))# 预测
input_handle["x"].copy_from_cpu(img)
predictor.run()
results = output_handle.copy_to_cpu()# 后处理
results = results.squeeze(0)
if results[0]>results[1]:print('小孩'+"  "+str(results[0]))
else:print('大人'+"  "+str(results[1]))

从网上随便找两张照片,运行效果如下:

输出结果:

小孩  0.7256172

输出结果:

大人  0.9533998

可以看到,推理效果还是比较满意的。

三、小结

本文以项目为主线,使用了PaddleClas算法套件解决了年龄分类问题。后续读者如果想要深入学习PaddlePaddle(飞桨)及相关算法套件,可以关注我的书籍(链接)。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/875598.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

AI学习指南机器学习篇-SOM在数据聚类和可视化中的应用

AI学习指南机器学习篇 - SOM在数据聚类和可视化中的应用 引言 在机器学习领域&#xff0c;数据聚类和可视化是非常重要的任务。传统的聚类算法如K-means、DBSCAN等在一些场景下表现良好&#xff0c;但对于高维数据的聚类和可视化而言&#xff0c;它们的效果会受到限制。Self-…

Leetcode3219. 切蛋糕的最小总开销 II

Every day a Leetcode 题目来源&#xff1a;3219. 切蛋糕的最小总开销 II 解法1&#xff1a;贪心 谁的开销更大&#xff0c;就先切谁&#xff0c;并且这个先后顺序与切的次数无关。 代码&#xff1a; /** lc appleetcode.cn id3219 langcpp** [3219] 切蛋糕的最小总开销 I…

ubuntu20.04服务器搭建mongodb7

安装参考自mongo官网&#xff1a;在 Ubuntu 上安装 MongoDB Community Edition - MongoDB 手册 v7.0 MongoDB 版本 本教程安装的是 MongoDB 7.0 Community Edition。想要安装不同版本的 MongoDB Community Edition&#xff0c;请移步本页面左上角的版本下拉菜单&#xff0c;选…

ubuntu递归下载deb安装包,解决离线依赖问题

ubuntu递归下载安装包 主要针对离线环境的电脑安装deb包。 将下面的build-essential换成自己需要安装的包&#xff0c;虽然下面代码会递归下载依赖安装包&#xff0c;但是在离线环境下仍然可能会出现依赖包为配置问题。 因此&#xff0c;根据报错&#xff0c;手动递归下载报错…

【SQL 新手教程 1/20】SQL语言MySQL数据库 简介

&#x1f497; 什么是SQL&#xff1f;⭐ (Structured Query Language) 结构化查询语言&#xff0c;是访问和处理关系数据库的计算机标准语言 无论用什么编程语言&#xff08;Java、Python、C……&#xff09;编写程序&#xff0c;只要涉及到操作关系数据库都必须通过SQL来完成 …

4招清洁法,清理电脑无死角,焕然一新效率高

随着时间的积累&#xff0c;电脑内部可能会堆积起大量的垃圾文件、缓存数据和无用程序。因此&#xff0c;定期清理电脑是很有必要的。为了让你的电脑重新焕发生机&#xff0c;提高工作效率&#xff0c;本文将为你介绍4招实用的清洁法&#xff0c;助你轻松清理电脑死角&#xff…

js 数组常用函数总结

目录 1、push 2、unshif 3、pop 4、shift 5、concat 6、slice 7、splice 8、join 9、indexOf 10、lastIndexOf 11、forEach 12、map 13、filter 14、reduce 15、sort 16、reverse 17、includes 18、some 19、every 20、toString 21.、find 22、findLast 23、…

JavaWeb学习——请求响应、分层解耦

目录 一、请求响应学习 1、请求 简单参数 实体参数 数组集合参数 日期参数 Json参数 路径参数 总结 2、响应 ResponseBody&统一响应结果 二、分层解耦 1、三层架构 三层架构含义 架构划分 2、分层解耦 引入概念 容器认识 3、IOC&DI入门 4、IOC详解 …

Cadence23学习笔记(十四)

ARC就是圆弧走线的意思&#xff1a; 仅打开网络的话可以只针对net进行修改走线的属性&#xff1a; 然后现在鼠标左键点那个走线&#xff0c;那个走线就会变为弧形&#xff1a; 添加差分对&#xff1a; 之后&#xff0c;分别点击两条线即可分配差分对&#xff1a; 选完差分对之后…

微服务实践和总结

H5原生组件web Component Web Component 是一种用于构建可复用用户界面组件的技术&#xff0c;开发者可以创建自定义的 HTML 标签&#xff0c;并将其封装为包含逻辑和样式的独立组件&#xff0c;从而在任何 Web 应用中重复使用。 <!DOCTYPE html> <html><head…

css in js 相比较 css modules 有什么好处?

CSS-in-JS 和 CSS Modules 都是用于管理 React 组件样式的流行方案&#xff0c;它们各有优势。相比 CSS Modules&#xff0c;CSS-in-JS 的主要好处包括: 动态样式&#xff1a;CSS-in-JS 可以轻松创建基于 props 或状态的动态样式&#xff0c;更灵活地处理复杂的样式逻辑。 无需…

【vue3|第18期】Vue-Router路由的三种传参方式

日期:2024年7月17日 作者:Commas 签名:(ง •_•)ง 积跬步以致千里,积小流以成江海…… 注释:如果您觉得有所帮助,帮忙点个赞,也可以关注我,我们一起成长;如果有不对的地方,还望各位大佬不吝赐教,谢谢^ - ^ 1.01365 = 37.7834;0.99365 = 0.0255 1.02365 = 1377.408…

EtherNet/IP网络基础

EtherNet/IP&#xff08;Ethernet Industrial Protocol&#xff09;是一种用于工业自动化的通信协议&#xff0c;基于以太网技术。它允许设备和控制系统之间进行高效的数据交换和通信。以下是EtherNet/IP网络的基础知识。 1. 什么是EtherNet/IP&#xff1f; EtherNet/IP是由O…

ctfshow SSTI注入 web369--web372

web369 这把request过滤了&#xff0c;只能自己拼字符了 ""[[__clas,s__]|join] 或者 ""[(__clas,s__)|join] 相当于 ""["__class__"]举个例子&#xff0c;chr(97) 返回的是字符 a&#xff0c;因为 97 是小写字母 a 的 Unicode 编码…

go操作aws s3

v2 官方推荐版本&#xff0c;需要go版本>1.20 安装 go get github.com/aws/aws-sdk-go-v2 go get github.com/aws/aws-sdk-go-v2/config go get github.com/aws/aws-sdk-go-v2/service/s3必要参数 bucket: 存储桶的名称 Region: 存储桶所在区域,例us-east-1 accessKey…

PHP运算符

PHP 运算符是用于执行各种操作&#xff08;如算术运算、比较、逻辑运算、字符串连接等&#xff09;的符号。在 PHP 中&#xff0c;运算符的命名主要是基于它们的功能和用法&#xff0c;而不是像变量或函数那样可以自定义名称。以下是一个关于 PHP 运算符的详细教程&#xff0c;…

unity2D游戏开发01项目搭建

1新建项目 选择2d模板,设置项目名称和存储位置 在Hierarchy面板右击&#xff0c;create Empty 添加组件 在Project视图中右键新建文件夹 将图片资源拖进来&#xff08;图片资源在我的下载里面&#xff09; 点击Player 修改属性&#xff0c;修好如下 点击Sprite Editor 选择第二…

Angular由一个bug说起之八:实践中遇到的一个数据颗粒度的问题

互联网产品离不开数据处理&#xff0c;数据处理有一些基本的原则包括&#xff1a;准确性、‌完整性、‌一致性、‌保密性、‌及时性。‌ 准确性&#xff1a;是数据处理的首要目标&#xff0c;‌确保数据的真实性和可靠性。‌准确的数据是进行分析和决策的基础&#xff0c;‌因此…

【目标检测】非极大值抑制(Non-Maximum Suppression, NMS)步骤与实现

步骤 置信度排序&#xff1a;首先根据预测框的置信度&#xff08;即预测框包含目标物体的概率&#xff09;对所有预测框进行降序排序。选择最佳预测框&#xff1a;选择置信度最高的预测框作为参考框。计算IoU&#xff1a;计算其他所有预测框与参考框的交并比&#xff08;Inter…

【数据结构】建堆算法复杂度分析及TOP-K问题

【数据结构】建堆算法复杂度分析及TOP-K问题 &#x1f525;个人主页&#xff1a;大白的编程日记 &#x1f525;专栏&#xff1a;数据结构 文章目录 【数据结构】建堆算法复杂度分析及TOP-K问题前言一.复杂度分析1.1向下建堆复杂度1.2向上建堆复杂度1.3堆排序复杂度 二.TOP-K问…