Image和Video在同一个Dataloader中交错加载联合训练

单卡实现

本文主要从两个方面进行展开:
1.将两个或多个dataset组合成pytorch中的一个ConcatDataset.这个dataset将会作为pytorch中Dataloader的输入。
2.覆盖重写RandomSampler修改batch产生过程,以确保在第一个batch中产生第一个任务的数据(image),在第二个batch中产生下一个任务的数据(video)。

下述定义了一个BatchSchedulerSampler类,实现了一个新的sampler iterator。首先,通过为每一个单独的dataset创建RandomSampler;接着,在每一个dataset iter中获取对应的sample index;最后,创建新的sample index list

import math
import torch
from torch.utils.data.sampler import RandomSampler
from torch.utils.data.dataset import ConcatDataset
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from dataset.K600 import *class BatchSchedulerSampler(torch.utils.data.sampler.Sampler):"""iterate over tasks and provide a random batch per task in each mini-batch"""def __init__(self, dataset, batch_size):self.dataset = datasetself.batch_size = batch_sizeself.number_of_datasets = len(dataset.datasets)self.largest_dataset_size = max([len(cur_dataset.samples) for cur_dataset in dataset.datasets])def __len__(self):return self.batch_size * math.ceil(self.largest_dataset_size / self.batch_size) * len(self.dataset.datasets)def __iter__(self):samplers_list = []sampler_iterators = []for dataset_idx in range(self.number_of_datasets):cur_dataset = self.dataset.datasets[dataset_idx]sampler = RandomSampler(cur_dataset)samplers_list.append(sampler)cur_sampler_iterator = sampler.__iter__()sampler_iterators.append(cur_sampler_iterator)push_index_val = [0] + self.dataset.cumulative_sizes[:-1]step = self.batch_size * self.number_of_datasetssamples_to_grab = self.batch_size# for this case we want to get all samples in dataset, this force us to resample from the smaller datasetsepoch_samples = self.largest_dataset_size * self.number_of_datasetsfinal_samples_list = []  # this is a list of indexes from the combined datasetfor _ in range(0, epoch_samples, step):for i in range(self.number_of_datasets):cur_batch_sampler = sampler_iterators[i]cur_samples = []for _ in range(samples_to_grab):try:cur_sample_org = cur_batch_sampler.__next__()cur_sample = cur_sample_org + push_index_val[i]cur_samples.append(cur_sample)except StopIteration:# got to the end of iterator - restart the iterator and continue to get samples# until reaching "epoch_samples"sampler_iterators[i] = samplers_list[i].__iter__()cur_batch_sampler = sampler_iterators[i]cur_sample_org = cur_batch_sampler.__next__()cur_sample = cur_sample_org + push_index_val[i]cur_samples.append(cur_sample)final_samples_list.extend(cur_samples)return iter(final_samples_list)if __name__ == "__main__":image_dataset = ImageFolder(root='/mnt/workspace/data/imagenet/data/newtrain', transform=transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]))video_dataset = VideoFolder(root='/mnt/workspace/data/k600/train_videos', transform=transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]))joint_dataset = ConcatDataset([image_dataset, video_dataset])batch_size = 8dataloader = torch.utils.data.DataLoader(dataset=joint_dataset,sampler=BatchSchedulerSampler(dataset=joint_dataset, batch_size=batch_size),batch_size=batch_size, shuffle=False)num_epochs = 1for epoch in range(num_epochs):for inputs, labels in dataloader:print(inputs.shape)
'''
torch.Size([8, 3, 224, 224])
torch.Size([8, 3, 16, 224, 224])
torch.Size([8, 3, 224, 224])
torch.Size([8, 3, 16, 224, 224])
'''

DDP多卡实现

在多卡训练中使用分布式数据并行(DDP)时,你需要重写 DistributedSampler 而不是 RandomSampler,以确保每个进程都能正确地获取数据子集。以下是如何实现一个 BatchSchedulerDistributedSampler 类来支持多卡训练的示例

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/66259.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

springboot实战纪实-课程介绍

教程介绍 Spring Boot是由Pivotal团队提供的一套开源框架,可以简化spring应用的创建及部署。它提供了丰富的Spring模块化支持,可以帮助开发者更轻松快捷地构建出企业级应用。 Spring Boot通过自动配置功能,降低了复杂性,同时支持…

BBP飞控板中的坐标系变换

一般飞控板中至少存在以下坐标系: 陀螺Gyro坐标系加速度计Acc坐标系磁强计Mag坐标系飞控板坐标系 在BBP飞控板采用的IMU为同时包含了陀螺(Gyro)及加速度计(Acc)的6轴传感器,故Gyro及Acc为同一坐标系。同时…

数据表中的索引详解

文章目录 一、索引概述二、普通索引三、唯一索引四、全文索引五、多列索引六、索引的设计原则七、隐藏和删除索引 一、索引概述 日常生活中,我们经常会在电话号码簿中查阅“某人”的电话号码,按姓查询或者按字母排序查询;在字典中查阅“某个…

大模型系列17-RAGFlow搭建本地知识库

大模型系列17-RAGFlow搭建本地知识库 安装ollama安装open-wehui安装并运行ragflowRAG(检索、增强、生成)RAG是什么RAG三过程RAG问答系统构建步骤向量库构建检索模块生成模块 RAG解决LLM的痛点 使用ragflow访问ragflow配置ollama模型添加Embedding模型添加…

R shiny app | 网页应用 空格分隔的文本文件在线转csv

shiny 能快速把R程序以web app的形式提供出来,方便使用,降低技术使用门槛。 本文提供的示例:把空格分隔的txt文件转为逗号分隔的csv文件。 前置依赖:需要有R环境(v4.2.0),安装shiny包(v1.9.1)。括号内是我使用的版本…

SocraticLM: Exploring Socratic Personalized Teaching with Large Language Models

题目 苏格拉底式教学:用大型语言模型探索苏格拉底式个性化教学 论文地址:https://openreview.net/pdf?idqkoZgJhxsA 项目地址:https://github.com/Ljyustc/SocraticLM 摘要 大型语言模型(LLM)被认为是推进智能教育的一项关键技术,因为它们展…

第一节:电路连接【51单片机+A4988+步进电机教程】

摘要:本节介绍如何搭建一个51单片机A4988步进电机控制电路,所用材料均为常见的模块,简单高效的方式搭建起硬件环境 一、硬件清单 ①51单片机最小控制模块 ②开关电源 ③A4988模块转接座 ④二相四线步进电机 ⑤电线若干 二、接线 三、A49…

Outlook2024版如何回到经典Outlook

Outlook2024版如何回到经典Outlook 如果新加入一家公司,拿到的电脑,大概率是最新版的Windows, 一切都是新的。 如果不coding, 使用国产的foxmail大概就可以解决一切问题了。可惜老程序员很多Coding都是基于传统Outlook的,科技公司所有人都是I…

网关如何识别和阻止网络攻击

网关在识别和阻止网络攻击方面扮演着关键角色,它通过多种技术和机制来确保网络的安全。以下是网关如何识别和阻止网络攻击的一些主要方法: 1.深度包检测(DPI) 网关可以对经过的数据包进行深度分析,检查数据包的头部、负…

操作系统复习(理论版)

目录 只会在选择填空出现类型 第一章:操作系统导论 操作系统介绍 不得不知道的概念 可能出现在答题的类型 第二章:进程调度 进程管理: 处理机调度: 进程同步: 死锁: 预防死锁: 避免死…

概述(讲讲python基本语法和第三方库)

我是北子,这是我自己写的python教程,主要是记录自己的学习成果方便自己日后复习, 我先学了C/C,所以这套教程中可能会将很多概念和C/C去对比,所以该教程大概不适合零基础的人。 it seems that python nowadays 只在人工…

Linux(Centos 7.6)命令详解:ls

1.命令作用 列出目录内容(list directory contents) 2.命令语法 Usage: ls [OPTION]... [FILE]... 3.参数详解 OPTION: -l,long list 使用长列表格式-a,all 不忽略.开头的条目(打印所有条目,包括.开头的隐藏条目&#xff09…

改善 Kibana 中的 ES|QL 编辑器体验

作者:来自 Elastic Marco Liberati 随着新的 ES|QL 语言正式发布,Kibana 中开发了一种新的编辑器体验,以帮助用户编写更快、更好的查询。实时验证、改进的自动完成和快速修复等功能将简化 ES|QL 体验。 我们将介绍改进 Kibana 中 ES|QL 编辑器…

基于Spring Boot的紧急物资管理系统

基于Spring Boot的紧急物资管理系统是一个非常实用的应用,特别是在应对自然灾害、公共卫生事件等情况下。该系统可以帮助管理者有效地追踪和分配物资,确保资源能够及时到达需要的地方。以下是一个基本的实现思路和一些关键组件: 项目规划 需…

机器学习基础-概率图模型

(一阶)马尔科夫模型的基本概念 状态、状态转换概率、初始概率 状态转移矩阵的基本概念 隐马尔可夫模型(HMM)的基本概念 条件随机场(CRF)的基本概念 实际应用中的马尔科夫性 自然语言处理: 在词…

Qt打包为exe文件

个人学习笔记 选择release 进入项目文件夹,查看releas生成的文件 releas文件路径 进入release看到exe文件,但是无法执行 将exe文件单独放到一个文件夹内 选择MinGW 用CD 进入存放exe文件的路径,输入下面指令 cd J:\C\Qt\test4-3-1 windeploy…

VScode怎么重启

原文链接:【vscode】vscode重新启动 键盘按下 Ctrl Shift p 打开命令行,如下图: 输入Reload Window,如下图:

Web安全 - “Referrer Policy“ Security 头值不安全

文章目录 概述原因分析风险说明Referrer-Policy 头配置选项1. 不安全的策略no-referrer-when-downgradeunsafe-url 2. 安全的策略no-referreroriginorigin-when-cross-originsame-originstrict-originstrict-origin-when-cross-origin 推荐配置Nginx 配置示例 在 Nginx 中配置 …

Hyperbolic dynamics

http://www.scholarpedia.org/article/Hyperbolic_dynamics#:~:textAmong%20smooth%20dynamical%20systems%2C%20hyperbolic%20dynamics%20is%20characterized,semilocal%20or%20even%20global%20information%20about%20the%20dynamics. 什么是双曲动力系统? A hy…

基于SpringBoot在线竞拍平台系统功能实现十五

一、前言介绍: 1.1 项目摘要 随着网络技术的飞速发展和电子商务的普及,竞拍系统作为一种新型的在线交易方式,已经逐渐深入到人们的日常生活中。传统的拍卖活动需要耗费大量的人力、物力和时间,从组织拍卖、宣传、报名、竞拍到成…