基于PyTorch的视频分类实战

1、数据集下载

官方链接:https://serre-lab.clps.brown.edu/resource/hmdb-a-large-human-motion-database/#Downloads

百度网盘连接:

https://pan.baidu.com/s/1sSn--u_oLvTDjH-BgOAv_Q?pwd=xsri

提取码: xsri 

        官方链接有详细的数据集介绍,下载的是压缩包 ‘hmdb51_org.rar’,解压后里面是 51 个.rar 压缩包,每个压缩包名是一个类别,里面的是对应类别的视频片段(.avi 文件)。因为资源有限,这里只解压了 5 个类别的视频如图 1 所示:

图1 'hmdb5/org'

        这里新建了 ‘hmdb5’ 文件夹,并新建了 ‘org’ 子文件夹,然后把 ‘hmdb51_org’ 文件夹的 5 个子文件夹放到 ‘org’ 中。作为这次实践的源视频数据。

2、utils.py
        在这里先实现 utils.py,即取帧(get_frames)和存帧(store_frames)函数,取帧函数的功能为从视频中等间距抽取 n_frame 帧,并返回这些帧组成的列表。存帧函数的功能即为将帧列表按序存到 path 中。

import os
import cv2
import numpy as npdef get_frames(path, n_frames=1):""":param path: 视频文件路径:param n_frames: 读取的帧数:return: 读取的帧列表 frames"""frames = []# 实例化一个用于捕获视频流的对象, 若参数为整数则用于读取摄像头视频, 若参数为字符串则用于读取视频文件v_cap = cv2.VideoCapture(path)'''cv2.CAP_PROP_FRAME_COUNT 是 cv2.VideoCapture 类的一个属性标识符,用于获取视频流或视频文件中的总帧数cv2.VideoCapture 的 get 方法用于获取视频流或视频文件的属性(返回值均为实数):propId 是属性标识符,整数:cv2.CAP_PROP_FRAME_WIDTH:视频的帧宽度(以像素为单位)cv2.CAP_PROP_FRAME_HEIGHT:视频的帧高度(以像素为单位)cv2.CAP_PROP_FPS:视频的帧率(每秒的帧数)cv2.CAP_PROP_POS_FRAMES:当前读取帧的位置(基于 0 的索引)cv2.CAP_PROP_POS_AVI_RATIO:视频文件的相对位置(播放进度)cv2.CAP_PROP_FRAME_COUNT:视频文件中的总帧数'''v_len = int(v_cap.get(propId=cv2.CAP_PROP_FRAME_COUNT))'''在指定区间返回等距的数字数组:start: 区间起点stop: 区间终点num: 采样数量endpoint: 默认为 True,若为 False 则区间不包括 stop'''frame_list = np.linspace(start=0, stop=v_len - 1, num=n_frames + 1, dtype=np.int16)for fn in range(v_len):# 读取下一帧。它返回两个值:一个布尔值 success 表示是否成功读取帧和一个数组 frame 表示读取到的帧。success, frame = v_cap.read()if success is False:continueif fn in frame_list:frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)frames.append(frame)v_cap.release()return framesdef store_frames(frames, path):""":param frames: 待保存为 jpg 图片的帧列表:param path: 存储路径:return:"""for i, frame in enumerate(frames):frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)path2img = os.path.join(path, "frame" + str(i) + ".jpg")cv2.imwrite(path2img, frame)

3、数据抽帧,并划分训练集和测试集

        先在 ‘hmdb5’ 文件夹中新建子文件夹 ‘train’ 和 ‘test',再运行以下代码即可数据抽帧,并划分训练集和测试集。       

import os
from utils import get_frames, store_framespath = "hmdb5"
org_dir = "org"
org_path = os.path.join(path, org_dir)
categories_list = os.listdir(org_path)
# brush_hair: 0, chartwheel: 1, clap: 2, catch: 3, chew: 4# 输出每个类别的视频数量
for c in categories_list:print("category:", c)p = os.path.join(org_path, c)video_list = os.listdir(p)print("number of videos:", len(video_list))print("-" * 50)
"""
category: brush_hair
number of videos: 107
--------------------------------------------------
category: cartwheel
number of videos: 107
--------------------------------------------------
category: clap
number of videos: 130
--------------------------------------------------
category: catch
number of videos: 102
--------------------------------------------------
category: chew
number of videos: 109
--------------------------------------------------
"""extension = '.avi'
n_frames = 16
train_rate = 0.9for i, c in enumerate(categories_list):p = os.path.join(org_path, c)videos = [v for v in os.listdir(p) if v.endswith(extension)]train_size = int(len(videos) * train_rate)for j, name in enumerate(videos):video_path = os.path.join(p, name)frames = get_frames(video_path, n_frames=n_frames)path2store = os.path.join(path, "train")if j >= train_size:path2store = os.path.join(path, "test")path2store = os.path.join(path2store, str(i)+"_"+name[:-4])print(path2store)os.makedirs(path2store, exist_ok=True)store_frames(frames, path2store)

        第一段代码输出五个类别的视频数量,可以看到 brush_hair、cartwheel、clap、catch 和  chew 依次有 107、107、130、102、109 个视频。最后一段代码的功能是依次对每个类别的每个视频抽帧,并将抽帧结果存置指定路径,同时划分训练集和测试集。这里设置每个视频的抽帧数量 n_frame=16,按 9:1(498:57) 划分训练集和测试集。每个样本(视频文件夹)名都在原来的名字前拼接上 ‘类别编号_’,其中类别编号为:

brush_hair: 0, chartwheel: 1, clap: 2, catch: 3, chew: 4

        这段代码的运行结果如图 2 所示(以测试集为例)即所有类别样本都在一个文件夹中,不再有类别目录,样本名字最前面的数字即为该样本的类别。

图2 测试集部分样本

4、train.py

4.1 导包

import os
import re
import torch
from torch import nn
from tqdm import tqdm
from PIL import Image
from torchvision import transforms
from torchvision.models import video
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau

4.2 设置环境变量

os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

4.3 定义超参数

lr = 3e-5
gamma = 0.5
epochs = 20
step_size = 5
batch_size = 16
weight_decay = 1e-2

        这里定义初始学习率为 lr=3e-5,训练轮次为 epochs=20,batch_size=16,正则化系数为 weight_decay=1e-2。gamma 和 step_size 分为 torch.optim.lr_scheduler.ReduceLROnPlateau 类构造函数的入参 factor 和 patience。factor 是学习率降低的因子,新的学习率将是当前学习率乘以这个因子;patience 指观察验证指标在多少个 epoch 内没有改善后降低学习率。

4.4 定义图像变换函数

train_transform = transforms.Compose([transforms.Resize((112, 112)),transforms.RandomHorizontalFlip(p=0.5),# 用于对图像进行随机的仿射变换, degrees 为旋转角度, translate 为水平和垂直平移的最大绝对分数transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),transforms.ToTensor(),transforms.Normalize([0.4322, 0.3947, 0.3765], [0.2280, 0.2215, 0.2170])
])test_transform = transforms.Compose([transforms.Resize((112, 112)),transforms.ToTensor(),transforms.Normalize([0.4322, 0.3947, 0.3765], [0.2280, 0.2215, 0.2170]),
])

        这里定义了两个图像变换函数,即用于训练集的 train_transform 和用于测试集的 test_transform, train_transform 在训练前依次对图片进行 resize 操作,以 0.5 的概率水平镜像变换操作,随机仿射操作(随机沿 x,y 方向分别平移 (-0.1*w,0.1*w)、(-0.1*h,0.1*h)),转换为 tensor 操作和标准化操作。test_transform 相较于 train_transform 去掉了起数据增强作用的两个操作。

4.5 定义训练集和测试集路径

# 训练集(498):测试机(57)=9:1
train_dir = 'hmdb5/train'
test_dir = 'hmdb5/test'

4.6 定义数据集类

class HMDB5Dataset(Dataset):def __init__(self, directory, transform):self.dir = directoryself.transform = transformself.names = os.listdir(directory)def __len__(self):return len(self.names)def __getitem__(self, idx):path = os.path.join(self.dir, self.names[idx])frames = []for i in range(16):frame = Image.open(os.path.join(path, 'frame' + str(i) + '.jpg'))frames.append(self.transform(frame))frames = torch.stack(frames)# 返回 input 的转置版本, 即交换 input 的 dim0 和 dim1frames = torch.transpose(input=frames, dim0=0, dim1=1)# 编译正则表达式, ^ 表示匹配字符串的开始, + 表示一个或多个pattern = re.compile(r'^(\d+)_')match = re.search(pattern, self.names[idx])return frames, int(match.group(1))

        数据集类的构造函数定义了 3 个属性:dir(数据集路径)、transform(数据预处理方式)和names(样本名列表)。

        __getitem__ 函数根据 idx 按序取出一个样本的所有帧,并对所有帧执行了 transform 操作,最后返回的样本 frames 是 shape 为(channels,n_frames,h,w)的 tensor,该函数还利用 re 库从样本名中获取该样本的标签并返回。

4.7 定义模型

def init_model(mi):m = Noneif mi == 1:m = video.r3d_18(num_classes=5)  # epochs = 20, correct = 0.754return m.to(device)

        这里使用的模型为  torchvision.models.video.r3d_18[1],原文链接:https://arxiv.org/abs/1711.11248。实现可以参考 torch 源码。

4.8 计算评价指标

def correct_loss(data_loader, desc, test):results = []correct = 0.0test_loss = 0.0for img, tag in tqdm(data_loader, desc, total=len(data_loader)):img = img.to(device)tag = tag.to(device)pre = model(img)if test:test_loss += loss_fn(pre, tag)correct += torch.sum((pre.argmax(dim=1) == tag).float())results.append(correct / len(data_loader.dataset))if test:results.append(test_loss)return results

        correct_loss 函数用于计算 model 在 data_loader 上的 correct 和 loss(如果 test=True ,即data_loader 是测试集的数据加载器)。并将结果以列表的形式返回。

4.9 训练

if __name__ == '__main__':model = init_model(1)loss_fn = nn.CrossEntropyLoss()optimizer = torch.optim.AdamW(model.parameters(), lr, weight_decay=weight_decay)train_ds = HMDB5Dataset(train_dir, train_transform)test_ds = HMDB5Dataset(test_dir, test_transform)train_dl = DataLoader(train_ds, batch_size, True, num_workers=2)test_dl = DataLoader(test_ds, batch_size, False, num_workers=2)'''在验证指标停止改善时降低学习率:mode(str): 值域为 {'min', 'max'}。指定优化器应该监视的指标是应该最小化还是最大化factor(float): 学习率降低的因子。新的学习率将是当前学习率乘以这个因子patience(int): 观察验证指标在多少个 epoch 内没有改善后降低学习率'''scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=gamma, patience=step_size, verbose=True)best_loss = float('inf')for epoch in range(epochs):s_loss = 0.0print('Epoch:', epoch + 1, '/', epochs)for x, y in tqdm(train_dl, total=len(train_dl)):x = x.to(device)y = y.to(device)pred = model(x)loss = loss_fn(pred, y)s_loss += lossloss.backward()optimizer.step()optimizer.zero_grad()model.eval()  # 将模型设置为评估模式with torch.no_grad():print("s_loss:%.3f" % s_loss)train_metrics = correct_loss(train_dl, 'compute train_metrics:', False)test_metrics = correct_loss(test_dl, 'compute test_metrics:', True)if test_metrics[1] < best_loss:best_loss = test_metrics[1]print("train_correct:%.3f,test_correct:%.3f" % (train_metrics[0], test_metrics[0]))model.train()scheduler.step(best_loss)

        这里使用交叉墒损失函数,AdamW 优化器,学习率使用 ReduceLROnPlateau scheduler,该 scheduler 监视的指标为 test loss。训练过程中得到的最高 test_correct=0.754。

5、项目目录结构

参考文献

[1] Du Tran, Heng Wang, Lorenzo Torresani, Jamie Ray, Yann LeCun, and Manohar Paluri. A closer look at spatiotemporal convolutions for action recognition. In CVPR, pages 6450–6459, 2018. 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/754774.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

网络视频播放器|基于JSP技术+ Mysql+Java+ B/S结构的网络视频播放器设计与实现(可运行源码+数据库+设计文档)

推荐阅读100套最新项目 最新ssmjava项目文档视频演示可运行源码分享 最新jspjava项目文档视频演示可运行源码分享 最新Spring Boot项目文档视频演示可运行源码分享 2024年56套包含java&#xff0c;ssm&#xff0c;springboot的平台设计与实现项目系统开发资源&#xff08;可…

Windows server 2008 R2 在VMware虚拟机上的安装

Windows server 2008 R2 在VMware虚拟机上的安装 准备工作VMware 新建并配置虚拟机安装和启动Windows server 2008 R2 准备工作 Windows server 2008 R2 ISO镜像的下载&#xff1a;Windows server 2008 R2 ISO VMware 新建并配置虚拟机 第一步&#xff0c;点击新建虚拟机 第…

【洛谷 P9242】[蓝桥杯 2023 省 B] 接龙数列 题解(线性DP)

[蓝桥杯 2023 省 B] 接龙数列 题目描述 对于一个长度为 K K K 的整数数列&#xff1a; A 1 , A 2 , … , A K A_{1},A_{2},\ldots,A_{K} A1​,A2​,…,AK​&#xff0c;我们称之为接龙数列当且仅当 A i A_{i} Ai​ 的首位数字恰好等于 A i − 1 A_{i-1} Ai−1​ 的末位数字…

ASP.NET通过Appliaction和Session统计在人数和历史访问量

目录 背景: Appliaction&#xff1a; Session&#xff1a; 过程&#xff1a; 数据库&#xff1a; Application_Start&#xff1a; Session_Start&#xff1a; Session_End&#xff1a; Application_End&#xff1a; 背景: 事件何时激发Application_Start在调用当前应用…

200W-300W厚膜电阻-SOT227小方块封装功率负载电阻器

SOT-227 型电阻器是许多电流监测和精密控制应用的理想选择&#xff0c;其电阻值低至 0.5 mΩ。这些高度可靠的无感厚膜功率电阻器采用四端子开尔文连接&#xff0c;可将测量路径与电流路径隔离&#xff0c;当与适当的散热器一起使用时&#xff0c;同样适用于高功率电流监测。电…

C# Selenium Edge 驱动下的常见用法

using OpenQA.Selenium; using OpenQA.Selenium.Edge; using OpenQA.Selenium.Support.UI; //添加缩放属性 将浏览器缩放设为100% EdgeOptions optionsnew EdgeOptions(); options.AddArgument("force-device-scale-factor1"); //不需添加额外属性 options可不写…

若依jar包运行脚本,从零到一:用Bash脚本实现JAR应用的启动、停止与监控

脚本使用说明&#xff1a; 启动应用&#xff1a;sh app.sh start停止应用&#xff1a;sh app.sh stop检查应用状态&#xff1a;sh app.sh status重启应用&#xff1a;sh app.sh restart 注意事项&#xff1a; 请确保你的系统上安装了 Java 环境&#xff0c;并且 ruoyi-admin…

Android11实现能同时开多个录屏应用(或者共享屏幕或投屏时录屏)

1.概述 Android原生对MediaProjection的管理逻辑&#xff0c;是如果服务端已经保存有MediaProjection的实例&#xff0c;那么再次创建的时候&#xff0c;之前的MediaProjection实例就会被暂停&#xff0c;并且引用指向新的实例&#xff0c;也就导致了当开启后一个录屏应用时&a…

Cookie 信息泄露 Cookie未设置http only属性 原理以及修复方法

漏洞名称&#xff1a;Cookie信息泄露、Cookie安全性漏洞、Cookie未设置httponly属性 漏洞描述&#xff1a; cookie的属性设置不当可能会造成系统用户安全隐患&#xff0c;Cookie信息泄露是Cookiehttp only配置缺陷引起的&#xff0c;在设置Cookie时&#xff0c;可以设置的一个…

Visual Studio .NET 中常用的文件类型

Visual Studio .NET 中常用的文件类型 扩展名名称描述.slnVisual studio .NET解决方案文件.sln文件为解决方案资源管理器提供显示管理文件的图形接口所需的信息。打开.sln文件能快捷地打开整个项目的所有文件.csprojVisual C# 项目文件一个特殊的XML文档&#xff0c;主要用来控…

SQLiteC/C++接口详细介绍sqlite3_stmt类简介

返回&#xff1a;SQLite—系列文章目录 上一篇&#xff1a;SQLiteC/C接口详细介绍之sqlite3类&#xff08;十八&#xff09; 下一篇&#xff1a;SQLiteC/C接口详细介绍sqlite3_stmt类&#xff08;一&#xff09; 预准备语句对象 typedef struct sqlite3_stmt sqlite3_stmt…

【洛谷 P9232】[蓝桥杯 2023 省 A] 更小的数 题解(字符串+区间DP)

[蓝桥杯 2023 省 A] 更小的数 题目描述 小蓝有一个长度均为 n n n 且仅由数字字符 0 ∼ 9 0 \sim 9 0∼9 组成的字符串&#xff0c;下标从 0 0 0 到 n − 1 n-1 n−1&#xff0c;你可以将其视作是一个具有 n n n 位的十进制数字 n u m num num&#xff0c;小蓝可以从 n…

java 程序连接 redis 集群 的时候报错 MUTLI is currently not supported in cluster mode

找了半天找不到,为什么国内文章环境是真的差&#xff0c; redis 集群不支持事务&#xff0c;而你的方法上面估计使用了 spring 的事务导致错误具体解决&#xff1a; Transactional(propagation Propagation.NOT_SUPPORTED)public <T> void removeMultiCacheMapValue…

内置泵电源,热保护电路等功能的场扫描电路D78040,偏转电流可达1.7Ap-p,可用于中小型显示器。

D78040是一款场扫描电路&#xff0c;偏转电流可达1.7Ap-p&#xff0c;可用于中小型显示器。 二 特 点 1、有内置泵电源 2、垂直输出电路 3、热保护电路 4、偏转电流可达1.7Ap-p 三 基本参数 四 应用电路图 1、应用线路 2、PIN5脚输出波形如下&#xff1a;

6-高维空间:机器如何面对越来越复杂的问题

声明 本文章基于哔哩哔哩付费课程《小白也能听懂的人工智能原理》。仅供学习记录、分享&#xff0c;严禁他用&#xff01;&#xff01;如有侵权&#xff0c;请联系删除 目录 一、知识引入 &#xff08;一&#xff09;二维输入数据 &#xff08;二&#xff09;数据特征维度 …

一级指针和二级指针

一级指针 形式&#xff1a;int a 2; int *p &a; 解释&#xff1a; int*p &a表示一级指针p指向变量a的值。此时一级指针p存放的是a的地址&#xff0c;*p解引用是a的值。 作用&#xff1a; c中随处可见。不多言。 二级指针 形式&#xff1a;int a 2; int *p &…

C语言calloc函数的特点,效率低。但是进行初始化操作

#define _CRT_SECURE_NO_WARNINGS 1 #include<stdlib.h> #include<string.h> #include<errno.h> #include<stdio.h> int main() { int *p (int *)calloc(10,sizeof(int)); //初始化&#xff0c;效率低&#xff0c;然而malloc函数相反&#xf…

Linux/Ubuntu/Debian的终端中和的区别

下边举例说明&#xff1a; “cd /home & wine ps.exe”和“cd /home && wine ps.exe”之间的区别在于命令在类 Unix shell 环境&#xff08;例如 Linux 或 macOS&#xff09;中执行的方式&#xff1a; ‘cd /home & wine ps.exe’: 在此命令中&#xff0c;“…

最细致最简单的 Arm 架构搭建 Harbor

更好的阅读体验&#xff1a;点这里 &#xff08; www.doubibiji.com &#xff09; ARM离线版本安装 官方提供了一个 arm 版本&#xff0c;但是好久都没更新了&#xff0c;地址&#xff1a;https://github.com/goharbor/harbor-arm 。 也不知道为什么不更新&#xff0c;我看…

数据机构-2

线性表 概念 顺序表 示例&#xff1a;创建一个存储学生信息的顺序表 表头&#xff08;Tlen总长度&#xff0c; Clen当前长度&#xff09; 函数 #include <seqlist.c> #include <stdio.h> #include <stdlib.h> #include "seqlist.h" #include &…