Image和Video在同一个Dataloader中交错加载联合训练

单卡实现

本文主要从两个方面进行展开:
1.将两个或多个dataset组合成pytorch中的一个ConcatDataset.这个dataset将会作为pytorch中Dataloader的输入。
2.覆盖重写RandomSampler修改batch产生过程,以确保在第一个batch中产生第一个任务的数据(image),在第二个batch中产生下一个任务的数据(video)。

下述定义了一个BatchSchedulerSampler类,实现了一个新的sampler iterator。首先,通过为每一个单独的dataset创建RandomSampler;接着,在每一个dataset iter中获取对应的sample index;最后,创建新的sample index list

import math
import torch
from torch.utils.data.sampler import RandomSampler
from torch.utils.data.dataset import ConcatDataset
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from dataset.K600 import *class BatchSchedulerSampler(torch.utils.data.sampler.Sampler):"""iterate over tasks and provide a random batch per task in each mini-batch"""def __init__(self, dataset, batch_size):self.dataset = datasetself.batch_size = batch_sizeself.number_of_datasets = len(dataset.datasets)self.largest_dataset_size = max([len(cur_dataset.samples) for cur_dataset in dataset.datasets])def __len__(self):return self.batch_size * math.ceil(self.largest_dataset_size / self.batch_size) * len(self.dataset.datasets)def __iter__(self):samplers_list = []sampler_iterators = []for dataset_idx in range(self.number_of_datasets):cur_dataset = self.dataset.datasets[dataset_idx]sampler = RandomSampler(cur_dataset)samplers_list.append(sampler)cur_sampler_iterator = sampler.__iter__()sampler_iterators.append(cur_sampler_iterator)push_index_val = [0] + self.dataset.cumulative_sizes[:-1]step = self.batch_size * self.number_of_datasetssamples_to_grab = self.batch_size# for this case we want to get all samples in dataset, this force us to resample from the smaller datasetsepoch_samples = self.largest_dataset_size * self.number_of_datasetsfinal_samples_list = []  # this is a list of indexes from the combined datasetfor _ in range(0, epoch_samples, step):for i in range(self.number_of_datasets):cur_batch_sampler = sampler_iterators[i]cur_samples = []for _ in range(samples_to_grab):try:cur_sample_org = cur_batch_sampler.__next__()cur_sample = cur_sample_org + push_index_val[i]cur_samples.append(cur_sample)except StopIteration:# got to the end of iterator - restart the iterator and continue to get samples# until reaching "epoch_samples"sampler_iterators[i] = samplers_list[i].__iter__()cur_batch_sampler = sampler_iterators[i]cur_sample_org = cur_batch_sampler.__next__()cur_sample = cur_sample_org + push_index_val[i]cur_samples.append(cur_sample)final_samples_list.extend(cur_samples)return iter(final_samples_list)if __name__ == "__main__":image_dataset = ImageFolder(root='/mnt/workspace/data/imagenet/data/newtrain', transform=transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]))video_dataset = VideoFolder(root='/mnt/workspace/data/k600/train_videos', transform=transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]))joint_dataset = ConcatDataset([image_dataset, video_dataset])batch_size = 8dataloader = torch.utils.data.DataLoader(dataset=joint_dataset,sampler=BatchSchedulerSampler(dataset=joint_dataset, batch_size=batch_size),batch_size=batch_size, shuffle=False)num_epochs = 1for epoch in range(num_epochs):for inputs, labels in dataloader:print(inputs.shape)
'''
torch.Size([8, 3, 224, 224])
torch.Size([8, 3, 16, 224, 224])
torch.Size([8, 3, 224, 224])
torch.Size([8, 3, 16, 224, 224])
'''

DDP多卡实现

在多卡训练中使用分布式数据并行(DDP)时,你需要重写 DistributedSampler 而不是 RandomSampler,以确保每个进程都能正确地获取数据子集。以下是如何实现一个 BatchSchedulerDistributedSampler 类来支持多卡训练的示例

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/66259.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

rpm包详解

一、rpm包 1、过滤系统rpm包,查询已安装的包 rpm -qa | grep htop2、rpm包导出 yumdownnloader htop-2.2.0.33、查看rpm包信息 rpm -qi 包名二、rpm包列表 1、查看软件包列表 yum list available *docker*2、查看软件包依赖 # rpl仓库 yum install epel-rel…

【Adobe Acrobat PDF】Acrobat failed to connect to a DDE server.是怎么回事?

【Adobe Acrobat PDF】Acrobat failed to connect to a DDE server.是怎么回事? 【Adobe Acrobat PDF】Acrobat failed to connect to a DDE server.是怎么回事? 文章目录 【Adobe Acrobat PDF】Acrobat failed to connect to a DDE server.是怎么回事&…

Rabbitmq 业务异常与未手动确认场景及解决方案

消费端消费异常,业务异常 与 未手动确认是不是一个场景,因为执行完业务逻辑,再确认。解决方案就一个,就是重试一定次数,然后加入死信队列。还有就是消费重新放入队列,然后重新投递给其他消费者,…

每日一题 380. O(1) 时间插入、删除和获取随机元素

380. O(1) 时间插入、删除和获取随机元素 最复杂的部分最简单来思考&#xff0c;其他的部分来弥补 class RandomizedSet { public:vector<int> nums;unordered_map<int,int> mp;RandomizedSet() {}bool insert(int val) {if(mp.count(val)){return false;}else{m…

MongoDB-文章目录

MongoDB学习总结1&#xff08;服务安装&#xff09; MongoDB学习总结2&#xff08;常用命令&#xff09; MongoDB学习总结3&#xff08;js文件中写命令&#xff09; MongoDB学习总结4&#xff08;数据插入、修改&#xff09; MongoDB学习总结5&#xff08;数据查询1&#x…

HBase Cassandra的部署和操作

目录 一&#xff0e;数据库的部署与配置 二&#xff0e;使用命令访问数据库 三&#xff0e;数据库的设计 四&#xff0e;编程实现数据库的访问 一&#xff0e;数据库的部署与配置 1.在单个节点上对进行数据库的单机部署 &#xff08;1&#xff09;下载apache-cassandra-4.1.7-…

springboot实战纪实-课程介绍

教程介绍 Spring Boot是由Pivotal团队提供的一套开源框架&#xff0c;可以简化spring应用的创建及部署。它提供了丰富的Spring模块化支持&#xff0c;可以帮助开发者更轻松快捷地构建出企业级应用。 Spring Boot通过自动配置功能&#xff0c;降低了复杂性&#xff0c;同时支持…

BBP飞控板中的坐标系变换

一般飞控板中至少存在以下坐标系&#xff1a; 陀螺Gyro坐标系加速度计Acc坐标系磁强计Mag坐标系飞控板坐标系 在BBP飞控板采用的IMU为同时包含了陀螺&#xff08;Gyro&#xff09;及加速度计&#xff08;Acc&#xff09;的6轴传感器&#xff0c;故Gyro及Acc为同一坐标系。同时…

数据表中的索引详解

文章目录 一、索引概述二、普通索引三、唯一索引四、全文索引五、多列索引六、索引的设计原则七、隐藏和删除索引 一、索引概述 日常生活中&#xff0c;我们经常会在电话号码簿中查阅“某人”的电话号码&#xff0c;按姓查询或者按字母排序查询&#xff1b;在字典中查阅“某个…

大模型系列17-RAGFlow搭建本地知识库

大模型系列17-RAGFlow搭建本地知识库 安装ollama安装open-wehui安装并运行ragflowRAG&#xff08;检索、增强、生成&#xff09;RAG是什么RAG三过程RAG问答系统构建步骤向量库构建检索模块生成模块 RAG解决LLM的痛点 使用ragflow访问ragflow配置ollama模型添加Embedding模型添加…

C++如何遍历数组vector

在C中&#xff0c;vector是一个可变数组。那么怎么遍历它呢&#xff1f;我们以for循环为例&#xff08;while循环&#xff0c;大家自己脑补&#xff09;。 方法一&#xff1a; 基于范围的for循环&#xff0c;这是C11新引入的。 std::vector<int> v {1, 2, 3, 4, 5, 6…

华为交换机---自动备份配置到指定ftp/sftp服务器

华为交换机—自动备份配置到指定ftp服务器 需求 交换机配置修改后及时备份相关配置,每次配置变化后需要在1分钟后自动进行保存,并且将配置上传至FTP服务器;每隔30分钟,交换机自动把配置上传到FTP服务器。 1、定时保存新配置的时间间隔为*分钟(1天=1440),默认为30分钟(…

深入解析-正则表达式

学习正则&#xff0c;我们到底要学什么&#xff1f; 正则表达式&#xff08;RegEx&#xff09;是一种强大的文本匹配工具&#xff0c;广泛应用于数据验证、文本搜索、替换和解析等领域。学习正则表达式&#xff0c;我们不仅要掌握其语法规则&#xff0c;还需要学会如何高效地利…

R shiny app | 网页应用 空格分隔的文本文件在线转csv

shiny 能快速把R程序以web app的形式提供出来&#xff0c;方便使用&#xff0c;降低技术使用门槛。 本文提供的示例&#xff1a;把空格分隔的txt文件转为逗号分隔的csv文件。 前置依赖&#xff1a;需要有R环境(v4.2.0)&#xff0c;安装shiny包(v1.9.1)。括号内是我使用的版本…

SocraticLM: Exploring Socratic Personalized Teaching with Large Language Models

题目 苏格拉底式教学:用大型语言模型探索苏格拉底式个性化教学 论文地址&#xff1a;https://openreview.net/pdf?idqkoZgJhxsA 项目地址&#xff1a;https://github.com/Ljyustc/SocraticLM 摘要 大型语言模型(LLM)被认为是推进智能教育的一项关键技术&#xff0c;因为它们展…

第一节:电路连接【51单片机+A4988+步进电机教程】

摘要&#xff1a;本节介绍如何搭建一个51单片机A4988步进电机控制电路&#xff0c;所用材料均为常见的模块&#xff0c;简单高效的方式搭建起硬件环境 一、硬件清单 ①51单片机最小控制模块 ②开关电源 ③A4988模块转接座 ④二相四线步进电机 ⑤电线若干 二、接线 三、A49…

C++并发:并发操作的同步

有时我们不仅要共享数据&#xff0c;也要让独立线程上的行为同步。例如&#xff0c;某线程只有先等待另一线程的任务完成&#xff0c;才可以执行自己的任务。 C提供了处理工具&#xff1a;条件变量和future 并且进行了扩充&#xff1a;线程闩&#xff08;latch&#xff09;&a…

Outlook2024版如何回到经典Outlook

Outlook2024版如何回到经典Outlook 如果新加入一家公司&#xff0c;拿到的电脑&#xff0c;大概率是最新版的Windows, 一切都是新的。 如果不coding, 使用国产的foxmail大概就可以解决一切问题了。可惜老程序员很多Coding都是基于传统Outlook的&#xff0c;科技公司所有人都是I…

【大模型】7 天 AI 大模型学习

因为想先快速把 llama 模型学习了&#xff0c;所以跳了两次课&#xff0c;这是这两次课的主要内容&#xff0c;后面有时间会补充上的 &#xff5e; 主要内容有&#xff1a;一些微调技术&#xff08;Alpaca、AdaLoRA、QLoRA&#xff09;、Prefix Tuning、Quantization 1. Alpaca…

网关如何识别和阻止网络攻击

网关在识别和阻止网络攻击方面扮演着关键角色&#xff0c;它通过多种技术和机制来确保网络的安全。以下是网关如何识别和阻止网络攻击的一些主要方法&#xff1a; 1.深度包检测&#xff08;DPI&#xff09; 网关可以对经过的数据包进行深度分析&#xff0c;检查数据包的头部、负…