Building a CNN in PyTorch for Video Action Classification: A Code Walkthrough

Data and the detailed original walkthrough come from:
基于PyTorch搭建CNN实现视频动作分类任务 (Building a CNN in PyTorch for video action classification)

import torch
import torch.nn as nn
import torchvision.transforms as T
import scipy.io
from torch.utils.data import DataLoader,Dataset
import os
from PIL import Image
from torch.autograd import Variable
import numpy as np

"""
Load the data
"""
# Load the label file
label_mat = scipy.io.loadmat('./datasets/q3_2_data.mat')
# Training-set labels
label_train = label_mat['trLb']
print(len(label_train))
# Validation-set labels
label_val = label_mat['valLb']
print(len(label_val))

"""
Preprocess the data with a Dataset subclass
"""
class ActionDataset(Dataset):
    def __init__(self, root_dir, labels=[], transform=None):
        """
        Args:
            root_dir: path to the data
            labels: labels of the images
            transform: preprocessing applied to each image
        """
        self.root_dir = root_dir
        self.transform = transform
        self.length = len(os.listdir(self.root_dir))
        self.labels = labels

    def __len__(self):  # number of samples
        return self.length * 3  # each video clip contains 3 frames (3 images)

    def __getitem__(self, idx):  # load one frame and return it as a sample
        folder = idx // 3 + 1    # which video clip this frame belongs to
        imidx = idx % 3 + 1      # which frame it is within that clip
        folder = format(folder, '05d')  # zero-pad the folder name to five digits
        imgname = str(imidx) + '.jpg'
        img_path = os.path.join(self.root_dir, folder, imgname)
        image = Image.open(img_path)
        """
        If labels were passed in, this is the training or validation set and the
        sample carries a label; otherwise it is the test set and has no label.
        """
        if len(self.labels) != 0:
            Label = self.labels[idx // 3][0] - 1
        # apply the transform if one was given
        if self.transform:
            image = self.transform(image)
        if len(self.labels) != 0:
            sample = {'image': image, 'img_path': img_path, 'Label': Label}
        else:
            sample = {'image': image, 'img_path': img_path}
        return sample

image_dataset = ActionDataset(root_dir='./datasets/trainClips/', labels=label_train, transform=T.ToTensor())
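To make the index arithmetic in __getitem__ concrete, here is a small worked example (the index value is made up for illustration):

# Worked example (hypothetical index): idx = 7
# folder = 7 // 3 + 1 = 3  -> formatted as '00003'
# imidx  = 7 % 3 + 1  = 2  -> image file '2.jpg'
# so the sample is read from ./datasets/trainClips/00003/2.jpg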
# torchvision.transforms defines many image preprocessing methods; the ToTensor transform
# used here maps 0~255 RGB values to a Tensor with values in the 0~1 range.
# # Quick check
# for i in range(3):
#     sample = image_dataset[i]
#     print(sample['image'].shape)
#     print(sample['Label'])
#     print(sample['img_path'])

"""
Wrap the datasets with DataLoader
Note: do not use num_workers on Windows
"""
# image_dataloader = DataLoader(image_dataset, batch_size=4, shuffle=True)
# for i, sample in enumerate(image_dataloader):
#     # enumerate(iterable, start): returns an enumerate object
#     print(i, sample['image'].shape, sample['img_path'], sample['Label'])
#     if i == 6:
#         break
# Training set
image_dataset_train = ActionDataset(root_dir='./datasets/trainClips/', labels=label_train, transform=T.ToTensor())
image_dataloader_train = DataLoader(image_dataset_train, batch_size=32, shuffle=True)
# Validation set
image_dataset_val = ActionDataset(root_dir='./datasets/valClips/', labels=label_val, transform=T.ToTensor())
image_dataloader_val = DataLoader(image_dataset_val, batch_size=32, shuffle=False)
# Test set: no labels are given
image_dataset_test = ActionDataset(root_dir='./datasets/testClips/', labels=[], transform=T.ToTensor())
image_dataloader_test = DataLoader(image_dataset_test, batch_size=32, shuffle=False)
"""
Build the model
"""
dtype = torch.FloatTensor  # the CPU floating-point tensor type supported by PyTorch
print_every = 100          # controls how often the loss is printed, since we monitor it throughout training

def reset(m):  # re-initialise the parameters of a module
    if hasattr(m, 'reset_parameters'):
        m.reset_parameters()

# Data reshaping helper and model definition
class Flatten(nn.Module):
    def forward(self, x):
        N, C, H, W = x.size()  # read the four dimensions
        return x.view(N, -1)   # -1 stands for all remaining dimensions

fixed_model_base = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=7, stride=1),   # 3*64*64 -> 8*58*58
    nn.ReLU(inplace=True),
    nn.MaxPool2d(2, stride=2),                  # 8*58*58 -> 8*29*29
    nn.Conv2d(8, 16, kernel_size=7, stride=1),  # 8*29*29 -> 16*23*23
    nn.ReLU(inplace=True),
    nn.MaxPool2d(2, stride=2),                  # 16*23*23 -> 16*11*11
    Flatten(),
    nn.ReLU(inplace=True),
    nn.Linear(1936, 10)                         # 1936 = 16*11*11
)
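The shape comments above follow the usual convolution and pooling size formulas. As a quick sanity check (a small helper written for this post, not part of the original code), the flattened feature size works out to 1936:

# Hypothetical helper: verify the spatial sizes annotated in the comments above
def conv_out(size, kernel, stride=1):
    return (size - kernel) // stride + 1

s = conv_out(64, 7)   # 58  (first conv)
s = s // 2            # 29  (2x2 max-pool)
s = conv_out(s, 7)    # 23  (second conv)
s = s // 2            # 11  (2x2 max-pool)
print(16 * s * s)     # 1936, matches nn.Linear(1936, 10)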
fixed_model = fixed_model_base.type(dtype)  # cast the model parameters to the CPU floating-point type
# # Quick check:
# x = torch.randn(32, 3, 64, 64).type(dtype)
# x_var = Variable(x.type(dtype))  # wrap the input in a Variable
# ans = fixed_model(x_var)
# print(np.array(ans.size()))  # inspect the model output shape
# np.array_equal(np.array(ans.size()), np.array([32, 10]))

"""
Training steps and helpers
"""
optimizer = torch.optim.RMSprop(fixed_model_base.parameters(), lr = 0.0001)
loss_fn = nn.CrossEntropyLoss()

def train(model, loss_fn, optimizer, dataloader, num_epoch=1):
    for epoch in range(num_epoch):
        check_accuracy(fixed_model, image_dataloader_val)  # check the model on the validation set
        model.train()  # .train() puts the model in training mode: gradients are tracked and layers such as dropout behave normally
        for t, sample in enumerate(dataloader):
            x_var = Variable(sample['image'])
            y_var = Variable(sample['Label'].long())
            scores = model(x_var)          # forward pass
            loss = loss_fn(scores, y_var)
            if (t + 1) % print_every == 0:
                print('t = %d, loss = %.4f' % (t + 1, loss.item()))
            # the usual three update steps
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

def check_accuracy(model, loader):
    num_correct = 0
    num_samples = 0
    model.eval()  # .eval() switches to evaluation mode: dropout and similar layers stop working
    for t, sample in enumerate(loader):
        x_var = Variable(sample['image'])
        y_var = Variable(sample['Label'])
        scores = model(x_var)
        _, preds = scores.data.max(1)  # take the most likely class as the prediction
        num_correct += (preds.numpy() == y_var.numpy()).sum()
        num_samples += preds.size(0)
    acc = float(num_correct) / num_samples
    print('Got %d / %d correct (%.2f)' % (num_correct, num_samples, 100 * acc))

"""
Train and validate
"""
torch.random.manual_seed(54321)
fixed_model.cpu()
fixed_model.apply(reset)
# model.apply(fn) applies fn recursively to every submodule of the model, including the model itself
fixed_model.train()
train(fixed_model, loss_fn, optimizer,image_dataloader_train, num_epoch=5)
check_accuracy(fixed_model, image_dataloader_val)

"""
Test
"""
def predict_on_test(model, loader):
    model.eval()
    results = open('results.csv', 'w')  # the predictions will be written here
    count = 0
    results.write('Id' + ',' + 'Class' + '\n')
    for t, sample in enumerate(loader):
        x_var = Variable(sample['image'])
        scores = model(x_var)
        _, preds = scores.data.max(1)
        for i in range(len(preds)):
            results.write(str(count) + ',' + str(preds[i].item()) + '\n')  # .item() writes the class index rather than a tensor repr
            count += 1
    results.close()
    return count

count = predict_on_test(fixed_model, image_dataloader_test)  # run on the loader you want to predict on, then open results.csv to inspect the output
print(count)
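To peek at the generated file from Python rather than a spreadsheet, a small illustrative snippet (the printed values are placeholders):

import csv

# Print the header and the first few predictions from results.csv
with open('results.csv') as f:
    for row in list(csv.reader(f))[:5]:
        print(row)  # e.g. ['Id', 'Class'], ['0', '3'], ...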

