【人工智能学习之STGCN训练自己的数据集】

STGCN训练自己的数据集

  • 准备事项
  • 数据集制作
    • 视频转json
    • jsons转json
    • json转npy&pkl
  • 训练STGCN
    • 添加图结构
    • 修改训练参数
    • 开始训练
    • 测试

准备事项

  1. st-gcn代码下载与环境配置
git clone https://github.com/yysijie/st-gcn.git
cd st-gcn
pip install -r requirements.txt
cd torchlight
python setup.py install
cd ..
  1. 数据集结构
    可以使用open pose制作数据集,制作过程见下章。
    (参考openpose环境搭建和利用openpose提取自建数据集)

我的数据集结构如下:

dataset/
├── stgcn_data/ 				# 最终使用的数据集
│   ├── train  					# 训练数据集
│   └── val/  					# 验证数据集
├── video/  					# 视频数据集
│   └── fall  					# 分类0视频文件夹
│	└── normal					# 分类1视频文件夹
│	└── resized					# 视频缩放与json
│		└── data				# 视频对应的json
│		└── fall				# 分类0缩放视频
│		└── normal				# 分类1缩放视频
│		└── snippets			# 视频每一帧的json
│	└── label0.txt				# 分类0标签文本
│	└── label1.txt				# 分类1标签文本
└── pose_demo/  				# openpose目录└── bin/  					# 目录└── openpose_demo.exe/  # openpose的exe可执行文件└── XXX.dll/  			# 各种依赖项

数据集制作

视频转json

首先需要将视频放到对应的目录下,目录名称就是你的类名。

提示一下:用于st-gcn训练的数据集视频帧数不要超过300帧,5~6s的视频时长比较好,不要10几s的视频时长。要不然会报:index 300 is out of bounds for axis 1 with size 300 这种错误。很多博主使用的用FFmpeg对视频进行缩放,但我FFmpeg老是出问题,索性直接自己用cv2实现了。
【win64中FFmpegReader报错:Cannot find installation of real FFmpeg (which comes with ffprobe).】

以下是我video2json的代码,存于dataset目录下运行:

#!/usr/bin/env python
import os
import argparse
import json
import shutilimport numpy as np
import torch
# import skvideo
# ffmpeg_path = r"/dataset/ffmpeg-master-latest-win64-gpl/ffmpeg-master-latest-win64-gpl/bin"
# skvideo.setFFmpegPath(ffmpeg_path)
import skvideo.io# from .io import IO
import tools
import tools.utils as utilsimport cv2
import osdef resize_video(input_path, output_path, size=(340, 256), fps=30):# 打开视频文件cap = cv2.VideoCapture(input_path)# 获取视频帧率if fps == 0:fps = cap.get(cv2.CAP_PROP_FPS)# 获取视频编码格式fourcc = cv2.VideoWriter_fourcc(*'mp4v')# 创建 VideoWriter 对象out = cv2.VideoWriter(output_path, fourcc, fps, size)while True:ret, frame = cap.read()if not ret:break# 调整帧大小resized_frame = cv2.resize(frame, size)# 写入输出文件out.write(resized_frame)# 释放资源cap.release()out.release()def resize_all_video(originvideo_file,resizedvideo_file):# 获取视频文件名列表videos_file_names = [f for f in os.listdir(originvideo_file) if f.endswith('.mp4')]# 遍历并处理每个视频文件for file_name in videos_file_names:video_path = os.path.join(originvideo_file, file_name)outvideo_path = os.path.join(resizedvideo_file, file_name)resize_video(video_path, outvideo_path)print(f'{file_name} resize success')def get_video_frames(video_path):"""读取视频文件并返回所有帧:param video_path: 视频文件路径:return: 帧列表"""frames = []cap = cv2.VideoCapture(video_path)while True:ret, frame = cap.read()if not ret:breakframes.append(frame)cap.release()return framesclass PreProcess():"""利用openpose提取自建数据集的骨骼点数据"""def start(self):###########################修改处################type_number = 2gongfu_filename_list = ['fall','normal']#################################################for process_index in range(type_number):gongfu_filename = gongfu_filename_list[process_index]# 标签信息# labelgongfu_name = 'xxx_{}'.format(process_index)labelAction_name = '{}'.format(gongfu_filename)label_no = process_index# 视频所在文件夹originvideo_file = './video/{}/'.format(gongfu_filename)# resized视频输出文件夹resizedvideo_file = './video/resized/{}/'.format(gongfu_filename)'''videos_file_names = os.listdir(originvideo_file)# 1. Resize文件夹下的视频到340x256 30fpsfor file_name in videos_file_names:video_path = '{}{}'.format(originvideo_file, file_name)outvideo_path = '{}{}'.format(resizedvideo_file, file_name)writer = skvideo.io.FFmpegWriter(outvideo_path,outputdict={'-f': 'mp4', '-vcodec': 'libx264', '-s': '340x256','-r': '30'})reader = skvideo.io.FFmpegReader(video_path)for frame in reader.nextFrame():writer.writeFrame(frame)writer.close()print('{} resize success'.format(file_name))
'''# 1. Resize文件夹下的视频到340x256 30fps cv处理resize_all_video(originvideo_file,resizedvideo_file)# 2. 利用openpose提取每段视频骨骼点数据resizedvideos_file_names = os.listdir(resizedvideo_file)for file_name in resizedvideos_file_names:outvideo_path = '{}{}'.format(resizedvideo_file, file_name)# openpose = '{}/examples/openpose/openpose.bin'.format(self.arg.openpose)openpose = 'C:/WorkFiles/company_server_SSH/st-gcn-master/dataset/pose_demo/bin/OpenPoseDemo.exe'video_name = file_name.split('.')[0]output_snippets_dir = './video/resized/snippets/{}'.format(video_name)print(f"Output snippets directory: {output_snippets_dir}")output_sequence_dir = 'video/resized/data'output_sequence_path = '{}/{}.json'.format(output_sequence_dir, video_name)label_name_path = 'video/label{}.txt'.format(process_index)with open(label_name_path) as f:label_name = f.readlines()label_name = [line.rstrip() for line in label_name]# pose estimationopenpose_args = dict(video=outvideo_path,write_json=output_snippets_dir,display=0,render_pose=0,model_pose='COCO')command_line = openpose + ' 'command_line += ' '.join(['--{} {}'.format(k, v) for k, v in openpose_args.items()])print(f"Running command: {command_line}")shutil.rmtree(output_snippets_dir, ignore_errors=True)os.makedirs(output_snippets_dir)os.system(command_line)# pack openpose ouputs# video = utils.video.get_video_frames(outvideo_path)video = get_video_frames(outvideo_path)height, width, _ = video[0].shape# 这里可以修改label, label_indexvideo_info = utils.openpose.json_pack(output_snippets_dir, video_name, width, height, labelAction_name , label_no)if not os.path.exists(output_sequence_dir):os.makedirs(output_sequence_dir)with open(output_sequence_path, 'w') as outfile:json.dump(video_info, outfile)if len(video_info['data']) == 0:print('{} Can not find pose estimation results.'.format(file_name))returnelse:print('{} pose estimation complete.'.format(file_name))if __name__ == '__main__':p=PreProcess()p.start()

注意:
如果你的openpose装的是cpu版本,会比较慢,跑很多视频生成json文件的时间就会很长。这个时候晚上如果去睡觉了一定一定一定要设置为从不熄屏休眠!!!从不!!!
不然就会这样:
在这里插入图片描述
经过漫长的等待之后。。。。终于运行完成了,我们可以得到

  1. 缩放过后的视频。

  2. 每一帧的json:在snippets目录下,以视频名称命名的文件下的每一帧json文件:在这里插入图片描述打开是这样的:在这里插入图片描述

  3. 每一个视频的json:在data目录下,以视频名称命名的json文件:打开是这样的:在这里插入图片描述
    到这里,得到了每个视频的json文件,数据集的制作就完成了一半了。

jsons转json

得到了每个视频的json文件之后,需要我们手动将data目录下的json文件分配到train和val下,一般按照9:1划分。放好之后运行下面这段jsons2json.py:

以下是我jsons2json的代码,存于stgcn_data目录下运行:

import json
import osif __name__ == '__main__':train_json_path = './train'val_json_path = './val'test_json_path = './test'output_train_json_path = './train_label.json'output_val_json_path = './val_label.json'output_test_json_path = './test_label.json'train_json_names = os.listdir(train_json_path)val_json_names = os.listdir(val_json_path)test_json_names = os.listdir(test_json_path)train_label_json = dict()val_label_json = dict()test_label_json = dict()for file_name in train_json_names:name = file_name.split('.')[0]json_file_path = '{}/{}'.format(train_json_path, file_name)json_file = json.load(open(json_file_path))file_label = dict()if len(json_file['data']) == 0:file_label['has_skeleton'] = Falseelse:file_label['has_skeleton'] = Truefile_label['label'] = json_file['label']file_label['label_index'] = json_file['label_index']train_label_json['{}'.format(name)] = file_labelprint('{} success'.format(file_name))with open(output_train_json_path, 'w') as outfile:json.dump(train_label_json, outfile)for file_name in val_json_names:name = file_name.split('.')[0]json_file_path = '{}/{}'.format(val_json_path, file_name)json_file = json.load(open(json_file_path))file_label = dict()if len(json_file['data']) == 0:file_label['has_skeleton'] = Falseelse:file_label['has_skeleton'] = Truefile_label['label'] = json_file['label']file_label['label_index'] = json_file['label_index']val_label_json['{}'.format(name)] = file_labelprint('{} success'.format(file_name))with open(output_val_json_path, 'w') as outfile:json.dump(val_label_json, outfile)for file_name in test_json_names:name = file_name.split('.')[0]json_file_path = '{}/{}'.format(test_json_path, file_name)json_file = json.load(open(json_file_path))file_label = dict()if len(json_file['data']) == 0:file_label['has_skeleton'] = Falseelse:file_label['has_skeleton'] = Truefile_label['label'] = json_file['label']file_label['label_index'] = json_file['label_index']test_label_json['{}'.format(name)] = file_labelprint('{} success'.format(file_name))with open(output_test_json_path, 'w') as outfile:json.dump(test_label_json, outfile)

运行完成后,我们就可以得到总的训练和验证的json文件了,里面包含了所有的视频动作关键点和标签信息。
还差一步就完成啦!
在这里插入图片描述

json转npy&pkl

第三步需要用到官方的工具文件kinetics_gendata.py。
所在的路径应该是:./st-gcn-master/tools/kinetics_gendata.py

这里有两处地方需要修改:

  1. 修改考虑参数
def gendata(data_path,label_path,data_out_path,label_out_path,num_person_in=3,  #observe the first 5 persons#这个参数指定了在每个视频帧中考虑的最大人数。#例如,如果设置为5,则脚本会尝试从每个视频帧中获取最多5个人的骨架信息。#这并不意味着每个视频帧都会有5个人,而是说脚本最多会处理5个人的数据。num_person_out=1,  #then choose 2 persons with the highest score#这个参数指定了最终每个序列(或视频)中保留的人数。#即使输入中有更多的人员,也只有得分最高的两个人会被选择出来用于训练。#这里的“得分”通常指的是骨架检测的置信度分数,即算法对某个点确实是人体某部位的信心程度。max_frame=300):
  1. 修改路径位置
if __name__ == '__main__':parser = argparse.ArgumentParser(description='Kinetics-skeleton Data Converter.')parser.add_argument('--data_path', default='./st-gcn-master/dataset/stgcn_data')parser.add_argument('--out_folder', default='./st-gcn-master/dataset/stgcn_data')arg = parser.parse_args()part = ['train', 'val']for p in part:data_path = '{}/{}'.format(arg.data_path, p)label_path = '{}/{}_label.json'.format(arg.data_path, p)data_out_path = '{}/{}_data.npy'.format(arg.out_folder, p)label_out_path = '{}/{}_label.pkl'.format(arg.out_folder, p)if not os.path.exists(arg.out_folder):os.makedirs(arg.out_folder)gendata(data_path, label_path, data_out_path, label_out_path)

ok,检查一下是否得到了npy和pkl文件
在这里插入图片描述

yes,到这里就可以开始训练啦。

训练STGCN

添加图结构

在net/utils/graph.py文件里面get_edge函数中保存的是不同的图结构。

注意这里的默认的layout如果符合自己定义的姿态就不用修改,否则需要自定义一个,本文采用的openpose即默认的openpose的18个关键点,不需要修改。其中num_node为关键点的个数,neighbor_link为关键点连接关系。如果自己的数据集是新定义的姿态点数不为18,在后续转换中可能还有修改需要保持一致。

修改训练参数

  1. 将config/st_gcn/kinetics-skeleton/train.yaml复制一份到根目录,重命名为mytrain.yaml,并修改其中参数。
  2. data_path和label_path修改为之前生成的文件路径;
  3. num_class改为自建数据集的行为类别个数;
  4. layout参数修改为之前添加的layout类别;
  5. strategy设置为spatial;
  6. 修改使用的GPU数量,单个设置device: [0];
  7. optim部分适当调整,base_lr: 0.1是基础学习率,step: [80, 120, 160, 200]:这表明使用了一种学习率衰减策略,num_epoch: 200:指定了整个训练过程将进行多少个周期(epochs)
  8. 不知道是不是我搞错了,居然不会自动保存best?而是每10轮自动保存一次;

以下是我的yaml:

work_dir: ./work_dir/recognition/kinetics_skeleton/ST_GCN# feeder
feeder: feeder.feeder.Feeder
train_feeder_args:random_choose: Truerandom_move: Truewindow_size: 150 data_path: C:/WorkFiles/company_server_SSH/st-gcn-master/dataset/stgcn_data/train_data.npylabel_path: C:/WorkFiles/company_server_SSH/st-gcn-master/dataset/stgcn_data/train_label.pkl
test_feeder_args:data_path: C:/WorkFiles/company_server_SSH/st-gcn-master/dataset/stgcn_data/val_data.npylabel_path: C:/WorkFiles/company_server_SSH/st-gcn-master/dataset/stgcn_data/val_label.pkl# model
model: net.st_gcn.Model
model_args:in_channels: 3num_class: 2edge_importance_weighting: Truegraph_args:layout: 'openpose'strategy: 'spatial'# training
device: [0]
batch_size: 32
test_batch_size: 32#optim
base_lr: 0.1
step: [80, 120, 160, 200]
num_epoch: 200

开始训练

训练指令(记得先激活环境,或者在pychar的终端运行)
yaml文件可以直接放在根目录下;也可以放在自己喜欢的位置(指令需要加上路径)

python main.py recognition -c mytrain.yaml

在这里插入图片描述
这里我的验证集俩分类各取了3个,总训练集200多个,所以top很高。

测试

可以像训练集那样仿照一个测试集出来。
也可以像我一样直接使用视频进行测试,不过我的代码涉及了一些实际的工业工作内容,无法提供,但可以给出一些思路:

  1. 通过其他关键点网络输出关键点(我用的YoloPose)
  2. 将某自定义时长内的所有关键点拼接在一起重塑为 (N, in_channels, T_in, V_in, M_in) 形状
  3. 通过网络输出获取这个时间段内的行为分类

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/61248.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Dify+Docker

1. 获取代码 直接下载 (1)访问 langgenius/dify: Dify is an open-source LLM app development platform. Difys intuitive interface combines AI workflow, RAG pipeline, agent capabilities, model management, observability features and more, …

Android so库的编译

在没弄明白so库编译的关系前,直接看网上博主的博文,常常会觉得云里雾里的,为什么一会儿通过Android工程cmake编译,一会儿又通过NDK命令去编译。两者编译的so库有什么区别? android版第三方库编译总体思路: 对于新手小白来说搞明白上面的总体思路图很有必…

Java函数式编程+Lambda表达式

文章目录 函数式编程介绍纯函数Lambda表达式基础Lambda的引入传统方法1. 顶层类2. 内部类3. 匿名类 Lambda 函数式接口(Functional Interface)1. **函数式接口的定义**示例: 2. **函数式接口与Lambda表达式的关系**关联逻辑:示例&…

Linux操作系统2-进程控制3(进程替换,exec相关函数和系统调用)

上篇文章:Linux操作系统2-进程控制2(进程等待,waitpid系统调用,阻塞与非阻塞等待)-CSDN博客 本篇代码Gitee仓库:Linux操作系统-进程的程序替换学习 d0f7bb4 橘子真甜/linux学习 - Gitee.com 本篇重点:进程替换 目录 …

文件上传漏洞:你的网站安全吗?

文章目录 文件上传漏洞攻击方式:0x01绕过前端限制0x02黑名单绕过1.特殊解析后缀绕过2..htaccess解析绕过3.大小写绕过4.点绕过5.空格绕过6.::$DATA绕过7.配合中间件解析漏洞8.双后缀名绕过9.短标签绕过 0x03白名单绕过1.MIME绕过(Content-Type绕过)2.%00截断3.0x00截…

设计模式-适配器模式-注册器模式

设计模式-适配器模式-注册器模式 适配器模式 如果开发一个搜索中台,需要适配或接入不同的数据源,可能提供的方法参数和平台调用的方法参数不一致,可以使用适配器模式 适配器模式通过封装对象将复杂的转换过程隐藏于幕后。 被封装的对象甚至…

springboot341+vue校园求职招聘系统设计和实现pf(论文+源码)_kaic

毕 业 设 计(论 文) 校园求职招聘系统设计与实现 摘 要 传统办法管理信息首先需要花费的时间比较多,其次数据出错率比较高,而且对错误的数据进行更改也比较困难,最后,检索数据费事费力。因此,…

基于java web的网上书店系统设计

摘 要 随着互联网的越发普及,网上购物成为了当下流行的热门行为。网络上开店创业有许多的优势:投入少,启动 资金低,交易便捷。网上书店与传统的线下书店比起来优势巨大,网上书店的经营方式和销售渠道是不同与线下书 店…

Java设计模式——职责链模式:解锁高效灵活的请求处理之道

嘿,各位 Java 编程大神和爱好者们!今天咱们要一同深入探索一种超厉害的设计模式——职责链模式。它就像一条神奇的“处理链”,能让请求在多个对象之间有条不紊地传递,直到找到最合适的“处理者”。准备好跟我一起揭开它神秘的面纱…

Android 设备使用 Wireshark 工具进行网络抓包

背景 电脑和手机连接同一网络,想使用wireshark抓包工具抓取Android手机网络日志,有以下两种连接方法: Wi-Fi 网络抓包。USB 网络共享抓包。需要USB 数据线将手机连接到电脑,并在开发者模式中启用 USB 网络共享。 查看设备连接信…

redis大key和热key

redis中大key、热key 什么是大key大key可能产生的原因大key可能会造成什么影响如何检测大key如何优化删除大key时可能的问题删除大key的策略 热key热key可能导致的问题解决热key的方法 什么是大key 大key通常是指占用内存空间过大或包含大量元素的键值对。 数据量大&#xff…

SpringBoot源码-spring boot启动入口ruan方法主线分析(二)

12.刷新前操作 // 刷新前操作prepareContext(context, environment, listeners, applicationArguments, printedBanner);进入prepareContext private void prepareContext(ConfigurableApplicationContext context, ConfigurableEnvironment environment,SpringApplicationRun…

使用 VLC 在本地搭建流媒体服务器 (详细版)

提示:详细流程 避坑指南 Hi~!欢迎来到碧波空间,平时喜欢用博客记录学习的点滴,欢迎大家前来指正,欢迎欢迎~~ ✨✨ 主页:碧波 📚 📚 专栏:音视频 目录 借助VLC media pl…

【单片机毕业设计12-基于stm32c8t6的智能称重系统设计】

【单片机毕业设计12-基于stm32c8t6的智能称重系统设计】 前言一、功能介绍二、硬件部分三、软件部分总结 前言 🔥这里是小殷学长,单片机毕业设计篇12-基于stm32c8t6的智能称重系统设计 🧿创作不易,拒绝白嫖可私 一、功能介绍 ----…

51单片机快速入门之中断的应用 2024/11/23 串口中断

51单片机快速入门之中断的应用 基本函数: void T0(void) interrupt 1 using 1 { 这里放入中断后需要做的操作 } void T0(void): 这是一个函数声明,表明函数 T0 不接受任何参数,并且不返回任何值。 interrupt 1: 这是关键字和参…

输入json 达到预览效果

下载 npm i vue-json-pretty2.4.0 <template><div class"newBranchesDialog"><t-base-dialogv-if"addDialogShow"title"Json数据配置"closeDialog"closeDialog":dialogVisible"addDialogShow":center"…

ML 系列:第 32节 — 机器学习中的统计简介

文章目录 一、说明二、统计概述三、描述性统计与推断性统计3.1 描述统计学3.2 推论统计 四、描述性统计中的均值、中位数和众数 一、说明 机器学习中的统计 随着我们深入研究机器学习领域&#xff0c;了解统计学在该领域的作用至关重要。统计学是机器学习的支柱&#xff0c;它…

大数据新视界 -- Hive 数据分区:精细化管理的艺术与实践(上)(7/ 30)

&#x1f496;&#x1f496;&#x1f496;亲爱的朋友们&#xff0c;热烈欢迎你们来到 青云交的博客&#xff01;能与你们在此邂逅&#xff0c;我满心欢喜&#xff0c;深感无比荣幸。在这个瞬息万变的时代&#xff0c;我们每个人都在苦苦追寻一处能让心灵安然栖息的港湾。而 我的…

VTK的基本概念(一)

文章目录 三维场景的基本要素1.灯光2.相机3.颜色4.纹理映射 三维场景的基本要素 1.灯光 在三维渲染场景中&#xff0c;可以有多个灯光的存在&#xff0c;灯光和相机是三维渲染场景的必备要素&#xff0c;如果没有指定的话&#xff0c;vtkRenderer会自动创建默认的灯光和相机。…

简单好用的折线图绘制!

折线图的概念及作用&#xff1a; 折线图&#xff08;Line Chart&#xff09;是一种常见的图表类型&#xff0c;用于展示数据的变化趋势或时间序列数据。它通过一系列的数据点&#xff08;通常表示为坐标系中的点&#xff09;与这些点之间的线段相连&#xff0c;直观地展示变量…