基于PyTorch搭建CNN实现视频动作分类任务代码详解

在这里插入图片描述
数据及具体讲解来源:
基于PyTorch搭建CNN实现视频动作分类任务

import torch
import torch.nn as nn
import torchvision.transforms as T
import scipy.io
from torch.utils.data import DataLoader,Dataset
import os
from PIL import Image
from torch.autograd import Variable
import numpy as np"""
加载数据
"""
#获得标签
label_mat = scipy.io.loadmat('./datasets/q3_2_data.mat')
#获得训练集标签
label_train = label_mat['trLb']
print(len(label_train))
#获得验证集标签
label_val = label_mat['valLb']
print(len(label_val))"""
通过Dataset类进行数据预处理
"""
class ActionDataset(Dataset):def __init__(self,root_dir,labels = [],transform=None):"""Args::param root_dir: 数据路径:param labels: 图片标签:param transform: 数据处理函数"""self.root_dir = root_dirself.transform = transformself.length = len(os.listdir(self.root_dir))self.labels = labelsdef __len__(self):  #返回数据数量return self.length*3    #一个视频片段包含3帧(3个图片)def __getitem__(self, idx): #图片处理及返回数据folder = idx//3+1   #判断该帧属于第几个视频中imidx = idx%3 + 1   #判断该帧在该视频中是第几帧folder = format(folder,'05d')   #将folder格式化,05d代表五位数,若不到五位用0填充imgname = str(imidx) + '.jpg'img_path = os.path.join(self.root_dir,folder,imgname)image = Image.open(img_path)"""当输入标签有值时,说明是训练集和验证集,输出的样本也是有标签的,若没有值,说明是测试集,输出的样本是没有标签的"""if len(self.labels)!=0:Label = self.labels[idx//3][0]-1#如果有对数据的处理先对数据进行处理if self.transform:image = self.transform(image)if len(self.labels)!=0:sample = {'image':image,'img_path':img_path,'Label':Label}else:sample = {'image':image,'img_path':img_path}return sampleimage_datast = ActionDataset(root_dir='./datasets/trainClips/',labels=label_train,transform=T.ToTensor())
# torchvision.transforms中定义了非常多对图像的预处理方法,这里使用的ToTensor方法为将0~255的RGB值映射到0~1的Tensor类型。
# #测试一下
# for i in range(3):
#     sample = image_datast[i]
#     print(sample['image'].shape)
#     print(sample['Label'])
#     print(sample['img_path'])"""
Dataloader类进行封装
注意:Windows不要用num_works
"""
#image_dataloader = DataLoader(image_datast,batch_size=4,shuffle=True)
# for i , sample in enumerate(image_dataloader):
#     #enumerate(iteration, start):返回一个枚举的对象
#     sample['image'] = sample['image']
#     print(sample[i,sample['image'].shape,sample['img_path'],'Label'])
#     if i == 6:
#         break
#训练集
image_dataset_train=ActionDataset(root_dir='./datasets/trainClips/',labels=label_train,transform=T.ToTensor())
image_dataloader_train = DataLoader(image_dataset_train, batch_size=32,shuffle=True)
#验证集
image_dataset_val=ActionDataset(root_dir='./datasets/valClips/',labels=label_val,transform=T.ToTensor())
image_dataloader_val = DataLoader(image_dataset_val, batch_size=32,shuffle=False)
#测试集:没有给定labels
image_dataset_test=ActionDataset(root_dir='./datasets/testClips/',labels=[],transform=T.ToTensor())
image_dataloader_test = DataLoader(image_dataset_test, batch_size=32,shuffle=False)"""
搭建模型
"""dtype = torch.FloatTensor # 这是pytorch所支持的cpu数据类型中的浮点数类型。print_every = 100   # 这个参数用于控制loss的打印频率,因为我们需要在训练过程中不断的对loss进行检测。def reset(m):   # 这是模型参数的初始化if hasattr(m, 'reset_parameters'):m.reset_parameters()#数据解释和处理
class Flatten(nn.Module):def forward(self, x):N, C, H, W = x.size() # 读取各个维度。return x.view(N, -1)  # -1代表除了特殊声明过的以外的全部维度。fixed_model_base = nn.Sequential(nn.Conv2d(3,8,kernel_size=7,stride=1),   ##3*64*64 -> 8*58*58nn.ReLU(inplace=True),nn.MaxPool2d(2, stride = 2),    # 8*58*58 -> 8*29*29nn.Conv2d(8, 16, kernel_size=7, stride=1), # 8*29*29 -> 16*23*23nn.ReLU(inplace=True),nn.MaxPool2d(2, stride = 2), # 16*23*23 -> 16*11*11Flatten(),nn.ReLU(inplace=True),nn.Linear(1936, 10)     # 1936 = 16*11*11
)
fixed_model = fixed_model_base.type(dtype)  #将模型数据转换成pytorch所支持的cpu数据类型中的浮点数类型。
# #测试:
# x = torch.randn(32, 3, 64, 64).type(dtype)
# x_var = Variable(x.type(dtype)) # 需要将其封装为Variable类型。
# ans = fixed_model(x_var)
# print(np.array(ans.size())) # 检查模型输出。
# np.array_equal(np.array(ans.size()), np.array([32, 10]))"""
训练步骤及模块
"""
optimizer = torch.optim.RMSprop(fixed_model_base.parameters(), lr = 0.0001)
loss_fn = nn.CrossEntropyLoss()def train(model,loss_fn,optimizer,dataloader,num_epoch = 1):for epoch in range(num_epoch):check_accuracy(fixed_model,image_dataloader_val)    #在验证集验证模型效果model.train()   #模型的.train()方法让模型进入训练模式,参数保留梯度,dropout层等部分正常工作for t,sample in enumerate(dataloader):x_var = Variable(sample['image'])y_var = Variable(sample['Label'].long())scores = model(x_var)   #得到输出loss = loss_fn(scores,y_var)if (t+1)%print_every ==0:print('t = %d, loss = %.4f' % (t + 1, loss.item()))#三步更新optimizer.zero_grad()loss.backward()optimizer.step()def check_accuracy(model,loader):num_correct = 0num_samples = 0model.eval()    # 模型的.eval()方法切换进入评测模式,对应的dropout等部分将停止工作。for t,sample in enumerate(loader):x_var = Variable(sample['image'])y_var = Variable(sample['Label'])scores = model(x_var)_,preds = scores.data.max(1)    # 找到可能最高的标签作为输出。num_correct += (preds.numpy() == y_var.numpy()).sum()num_samples += preds.size(0)acc = float(num_correct)/num_samplesprint('Got %d / %d correct (%.2f)' % (num_correct, num_samples, 100 * acc))"""
训练并验证
"""
torch.random.manual_seed(54321)
fixed_model.cpu()
fixed_model.apply(reset)
#pytorch中的model.apply(fn)会递归地将函数fn应用到父模块的每个子模块submodule,也包括model这个父模块自身
fixed_model.train()
train(fixed_model, loss_fn, optimizer,image_dataloader_train, num_epoch=5)
check_accuracy(fixed_model, image_dataloader_val)"""
测试
"""def predict_on_test(model, loader):model.eval()results = open('results.csv', 'w')  # 模型预测结果会被放在这里。count = 0results.write('Id' + ',' + 'Class' + '\n')for t, sample in enumerate(loader):x_var = Variable(sample['image'])scores = model(x_var)_, preds = scores.data.max(1)for i in range(len(preds)):results.write(str(count) + ',' + str(preds[i]) + '\n')count += 1results.close()return countcount = predict_on_test(fixed_model, image_dataloader_test)  # 放入你想要测试的训练集,然后打开文件去看一看结果吧。
print(count)

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/389283.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

missforest_missforest最佳丢失数据插补算法

missforestMissing data often plagues real-world datasets, and hence there is tremendous value in imputing, or filling in, the missing values. Unfortunately, standard ‘lazy’ imputation methods like simply using the column median or average don’t work wel…

华硕猛禽1080ti_F-22猛禽动力回路的视频分析

华硕猛禽1080tiThe F-22 Raptor has vectored thrust. This means that the engines don’t just push towards the front of the aircraft. Instead, the thrust can be directed upward or downward (from the rear of the jet). With this vectored thrust, the Raptor can …

聊天常用js代码

<script languagejavascript>//转意义字符与替换图象以及字体HtmlEncode(text)function HtmlEncode(text){return text.replace(//"/g, &quot;).replace(/</g, <).replace(/>/g, >).replace(/#br#/g,<br>).replace(/IMGSTART/g,<IMG style…

温故而知新:柯里化 与 bind() 的认知

什么是柯里化?科里化是把一个多参数函数转化为一个嵌套的一元函数的过程。&#xff08;简单的说就是将函数的参数&#xff0c;变为多次入参&#xff09; const curry (fn, ...args) > fn.length < args.length ? fn(...args) : curry.bind(null, fn, ...args); // 想要…

OPENVAS运行

https://www.jianshu.com/p/382546aaaab5转载于:https://www.cnblogs.com/diyunpeng/p/9258163.html

Memory-Associated Differential Learning论文及代码解读

Memory-Associated Differential Learning论文及代码解读 论文来源&#xff1a; 论文PDF&#xff1a; Memory-Associated Differential Learning论文 论文代码&#xff1a; Memory-Associated Differential Learning代码 论文解读&#xff1a; 1.Abstract Conventional…

大数据技术 学习之旅_如何开始您的数据科学之旅?

大数据技术 学习之旅Machine Learning seems to be fascinating to a lot of beginners but they often get lost into the pool of information available across different resources. This is true that we have a lot of different algorithms and steps to learn but star…

纯API函数实现串口读写。

以最后决定用纯API函数实现串口读写。 先从网上搜索相关代码&#xff08;关键字&#xff1a;C# API 串口&#xff09;&#xff0c;发现网上相关的资料大约来源于一个版本&#xff0c;那就是所谓的msdn提供的样例代码&#xff08;msdn的具体出处&#xff0c;我没有考证&#xff…

数据可视化工具_数据可视化

数据可视化工具Visualizations are a great way to show the story that data wants to tell. However, not all visualizations are built the same. My rule of thumb is stick to simple, easy to understand, and well labeled graphs. Line graphs, bar charts, and histo…

Android Studio调试时遇见Install Repository and sync project的问题

我们可以看到&#xff0c;报的错是“Failed to resolve: com.android.support:appcompat-v7:16.”&#xff0c;也就是我们在build.gradle中最后一段中的compile项内容。 AS自动生成的“com.android.support:appcompat-v7:16.”实际上是根据我们的最低版本16来选择16.x.x及以上编…

Apache Ignite 学习笔记(二): Ignite Java Thin Client

前一篇文章&#xff0c;我们介绍了如何安装部署Ignite集群&#xff0c;并且尝试了用REST和SQL客户端连接集群进行了缓存和数据库的操作。现在我们就来写点代码&#xff0c;用Ignite的Java thin client来连接集群。 在开始介绍具体代码之前&#xff0c;让我们先简单的了解一下Ig…

VGAE(Variational graph auto-encoders)论文及代码解读

一&#xff0c;论文来源 论文pdf Variational graph auto-encoders 论文代码 github代码 二&#xff0c;论文解读 理论部分参考&#xff1a; Variational Graph Auto-Encoders&#xff08;VGAE&#xff09;理论参考和源码解析 VGAE&#xff08;Variational graph auto-en…

IIS7设置

IIS 7.0和IIS 6.0相比改变很大谁都知道&#xff0c;而且在IIS 7.0中用VS2005来调试Web项目也不是什么新鲜的话题&#xff0c;但是我还是第一次运用这个东东&#xff0c;所以在此记下我的一些过程&#xff0c;希望能给更多的后来者带了一点参考。其实我写这篇文章时也参考了其他…

tableau大屏bi_Excel,Tableau,Power BI ...您应该使用什么?

tableau大屏biAfter publishing my previous article on data visualization with Power BI, I received quite a few questions about the abilities of Power BI as opposed to those of Tableau or Excel. Data, when used correctly, can turn into digital gold. So what …

python 可视化工具_最佳的python可视化工具

python 可视化工具Disclaimer: I work for Datapane免责声明&#xff1a;我为Datapane工作 动机 (Motivation) There are amazing articles on data visualization on Medium every day. Although this comes at the cost of information overload, it shouldn’t prevent you …

网络编程 socket介绍

Socket介绍 Socket是应用层与TCP/IP协议族通信的中间软件抽象层&#xff0c;它是一组接口。在设计模式中&#xff0c;Socket其实就是一个门面模式&#xff0c;它把复杂的TCP/IP协议族隐藏在Socket接口后面&#xff0c;对用户来说&#xff0c;一组简单的接口就是全部。 Socket通…

猿课python 第三天

字典 字典是python中唯一的映射类型,字典对象是可变的&#xff0c;但是字典的键是不可变对象&#xff0c;字典中可以使用不同的键值字典功能> dict.clear()          -->清空字典 dict.keys()          -->获取所有key dict.values()      …

在C#中使用代理的方式触发事件

事件&#xff08;event&#xff09;是一个非常重要的概念&#xff0c;我们的程序时刻都在触发和接收着各种事件&#xff1a;鼠标点击事件&#xff0c;键盘事件&#xff0c;以及处理操作系统的各种事件。所谓事件就是由某个对象发出的消息。比如用户按下了某个按钮&#xff0c;某…

BP神经网络反向传播手动推导

BP神经网络过程&#xff1a; 基本思想 BP算法是一个迭代算法&#xff0c;它的基本思想如下&#xff1a; 将训练集数据输入到神经网络的输入层&#xff0c;经过隐藏层&#xff0c;最后达到输出层并输出结果&#xff0c;这就是前向传播过程。由于神经网络的输出结果与实际结果…

使用python和pandas进行同类群组分析

背景故事 (Backstory) I stumbled upon an interesting task while doing a data exercise for a company. It was about cohort analysis based on user activity data, I got really interested so thought of writing this post.在为公司进行数据练习时&#xff0c;我偶然发…