手把手写深度学习(23):视频扩散模型之Video DataLoader

手把手写深度学习(0):专栏文章导航

前言:训练自己的视频扩散模型的第一步就是准备数据集,而且这个数据集是text-video或者image-video的多模态数据集,这篇博客手把手教读者如何写一个这样扩散模型的的Video DataLoader。

目录

准备工作

下载数据集

视频数据打标签

代码讲解

纯视频文件夹+txt描述prompt 读取方式

CSV描述文件读取方式


准备工作

下载数据集

一般会去下载webvid数据集,但是这个数据集非常大,如果读者不做预训练的话不建议下载。

《Animating Pictures with Eulerian Motion Fields》提供了一个比较小的测试数据集:Animating Pictures with Eulerian Motion Fields

大概一个GB左右,谷歌云盘的链接如下:

https://drive.google.com/file/d/1-MKuNxO1mjopgY6UoEVGDVt5I_QvVeDn/view

下载之后的.pth文件我们暂时不用管,可以先删除掉,只保留.mp4文件。

视频数据打标签

很多数据集是没有一个比较好的文字描述的,如果我们要训练text-to-video的任务,第一步要做的事情是对视频数据打上文字标签。

如果有,那么就算了,主打一个淘气(不是)

还是下一讲专门讲一下如何用V-BLIP给视频数据打上text标签吧

代码讲解

纯视频文件夹+txt描述prompt 读取方式

第一个DataLoader只需要输入视频的文件夹路径,prompt要么是全部指定成相同的(那肯定不行),要么从同名的txt文件中读取:

        if os.path.exists(self.video_files[index].replace(".mp4", ".txt")):with open(self.video_files[index].replace(".mp4", ".txt"), "r") as f:prompt = f.read()else:prompt = self.fallback_prompt

注意这里的text我们直接用预训练的tokenizer编码了,如果不想要的话也可以把这里注释掉:

    def get_prompt_ids(self, prompt):return self.tokenizer(prompt,truncation=True,padding="max_length",max_length=self.tokenizer.model_max_length,return_tensors="pt",).input_ids

获取视频的部分需要特别注意的是,要把"f h w c"转换成"f c h w":

        video = rearrange(video, "f h w c -> f c h w")

完整代码如下:

class VideoFolderDataset(Dataset):def __init__(self,tokenizer=None,width: int = 256,height: int = 256,n_sample_frames: int = 16,fps: int = 8,path: str = "./data",fallback_prompt: str = "",use_bucketing: bool = False,**kwargs):self.tokenizer = tokenizerself.use_bucketing = use_bucketingself.fallback_prompt = fallback_promptself.video_files = glob(f"{path}/*.mp4")self.width = widthself.height = heightself.n_sample_frames = n_sample_framesself.fps = fpsdef get_frame_buckets(self, vr):h, w, c = vr[0].shape        width, height = sensible_buckets(self.width, self.height, w, h)resize = T.transforms.Resize((height, width), antialias=True)return resizedef get_frame_batch(self, vr, resize=None):n_sample_frames = self.n_sample_framesnative_fps = vr.get_avg_fps()every_nth_frame = max(1, round(native_fps / self.fps))every_nth_frame = min(len(vr), every_nth_frame)effective_length = len(vr) // every_nth_frameif effective_length < n_sample_frames:n_sample_frames = effective_lengtheffective_idx = random.randint(0, (effective_length - n_sample_frames))idxs = every_nth_frame * np.arange(effective_idx, effective_idx + n_sample_frames)video = vr.get_batch(idxs)video = rearrange(video, "f h w c -> f c h w")if resize is not None: video = resize(video)return video, vrdef process_video_wrapper(self, vid_path):video, vr = process_video(vid_path,self.use_bucketing,self.width, self.height, self.get_frame_buckets, self.get_frame_batch)return video, vrdef get_prompt_ids(self, prompt):return self.tokenizer(prompt,truncation=True,padding="max_length",max_length=self.tokenizer.model_max_length,return_tensors="pt",).input_ids@staticmethoddef __getname__(): return 'folder'def __len__(self):return len(self.video_files)def __getitem__(self, index):video, _ = self.process_video_wrapper(self.video_files[index])if os.path.exists(self.video_files[index].replace(".mp4", ".txt")):with open(self.video_files[index].replace(".mp4", ".txt"), "r") as f:prompt = f.read()else:prompt = self.fallback_promptprompt_ids = self.get_prompt_ids(prompt)return {"pixel_values": normalize_input(video[0]), "prompt_ids": prompt_ids, "text_prompt": prompt, 'dataset': self.__getname__()}

CSV描述文件读取方式

这种方法每次都要打开一个txt文件去读取prompt,很不方便。而且如果读取的量级大了之后IO的开销会很大!

所以建议使用CSV方式的读取,CSV文件中存放着video-prompt的对应关系,样例如下:

video_path,prompt

...

video_path建议写成绝对路径,这样更方便读取。

完整代码如下:

class VideoCSVDataset(Dataset):def __init__(self,tokenizer=None,width: int = 256,height: int = 256,n_sample_frames: int = 16,fps: int = 8,csv_path: str = "./data",use_bucketing: bool = False,**kwargs):self.tokenizer = tokenizerself.use_bucketing = use_bucketingif not os.path.exists(csv_path):raise FileNotFoundError(f"The csv path does not exist: {csv_path}")self.csv_data = pd.read_csv(csv_path)self.width = widthself.height = heightself.n_sample_frames = n_sample_framesself.fps = fpsdef get_frame_buckets(self, vr):h, w, c = vr[0].shape        width, height = sensible_buckets(self.width, self.height, w, h)resize = T.transforms.Resize((height, width), antialias=True)return resizedef get_frame_batch(self, vr, resize=None):n_sample_frames = self.n_sample_framesnative_fps = vr.get_avg_fps()every_nth_frame = max(1, round(native_fps / self.fps))every_nth_frame = min(len(vr), every_nth_frame)effective_length = len(vr) // every_nth_frameif effective_length < n_sample_frames:n_sample_frames = effective_lengtheffective_idx = random.randint(0, (effective_length - n_sample_frames))idxs = every_nth_frame * np.arange(effective_idx, effective_idx + n_sample_frames)video = vr.get_batch(idxs)video = rearrange(video, "f h w c -> f c h w")if resize is not None: video = resize(video)return video, vrdef process_video_wrapper(self, vid_path):video, vr = process_video(vid_path,self.use_bucketing,self.width, self.height, self.get_frame_buckets, self.get_frame_batch)return video, vrdef get_prompt_ids(self, prompt):return self.tokenizer(prompt,truncation=True,padding="max_length",max_length=self.tokenizer.model_max_length,return_tensors="pt",).input_ids@staticmethoddef __getname__(): return 'csv'def __len__(self):return len(self.csv_data)def __getitem__(self, index):print(self.csv_data.iloc[index])video_path, prompt = self.csv_data.iloc[index]video, _ = self.process_video_wrapper(video_path)prompt_ids = self.get_prompt_ids(prompt)return {"pixel_values": normalize_input(video[0]), "prompt_ids": prompt_ids, "text_prompt": prompt, 'dataset': self.__getname__()}

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/741823.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Vue3】深入理解Vue3路由器的工作原理to的两种写法

&#x1f497;&#x1f497;&#x1f497;欢迎来到我的博客&#xff0c;你将找到有关如何使用技术解决问题的文章&#xff0c;也会找到某个技术的学习路线。无论你是何种职业&#xff0c;我都希望我的博客对你有所帮助。最后不要忘记订阅我的博客以获取最新文章&#xff0c;也欢…

Realsense 相机SDK学习(一)——librealsense使用方法及bug解决(不使用Ros)

一.介绍 realsense相机是一个intel开发出来的一款深度相机&#xff0c;我之前使用他来跑过slam&#xff0c;也配置过他的驱动&#xff0c;在此附上realsense的相机驱动安装方法&#xff1a;Ubuntu20.04安装Intelrealsense相机驱动&#xff08;涉及Linux内核降级&#xff09; …

【四】【算法分析与设计】贪心算法的初见

455. 分发饼干 假设你是一位很棒的家长&#xff0c;想要给你的孩子们一些小饼干。但是&#xff0c;每个孩子最多只能给一块饼干。 对每个孩子 i&#xff0c;都有一个胃口值 g[i]&#xff0c;这是能让孩子们满足胃口的饼干的最小尺寸&#xff1b;并且每块饼干 j&#xff0c;都有…

AI时代Python金融大数据分析实战:ChatGPT让金融大数据分析插上翅膀【文末送书-38】

文章目录 Python驱动的金融智能&#xff1a;数据分析、交易策略与风险管理Python在金融数据分析中的应用 实战案例&#xff1a;基于ChatGPT的金融事件预测AI时代Python金融大数据分析实战&#xff1a;ChatGPT让金融大数据分析插上翅膀【文末送书-38】 Python驱动的金融智能&…

eVTOL适航领先新构型,沃飞长空布局空中交通新局面

汽车、火车、飞机……人类对于出行方式的探索从未停止。随着沃飞长空旗下首款自研eVTOL(飞行汽车)AE200适航技术验证机一阶段顺利试飞,eVTOL(飞行汽车)这种面向空中交通的新型交通工具进入了我们的视野,那么eVTOL(飞行汽车)是什么?eVTOL(飞行汽车)前景怎么样? eVTOL(飞行汽车…

Power Apps 学习笔记 -- Action

文章目录 1. Action 简介2. Action 配置3. 待补充 1. Action 简介 Action基础教程 : Action概述 操作Action: 1. 操作Action类似于工作流Workflow&#xff0c;提供一些重用性的操作&#xff0c;允许工作流或其他Web服务端点调用(例如javascript). 2. Action 类似于c#当中的一个…

专题二 -滑动窗口 - leetcode 209. 长度最小的子数组 | 中等难度

leetcode 209. 长度最小的子数组 leetcode 209. 长度最小的子数组 | 中等难度1. 题目详情1. 原题链接2. 基础框架 2. 解题思路1. 题目分析2. 算法原理3. 时间复杂度 3. 代码实现4. 知识与收获 leetcode 209. 长度最小的子数组 | 中等难度 1. 题目详情 给定一个含有 n 个正整数…

Android14音频进阶:AudioTrack如何巧妙衔接AudioFlinger(五十七)

简介: CSDN博客专家,专注Android/Linux系统,分享多mic语音方案、音视频、编解码等技术,与大家一起成长! 优质专栏:Audio工程师进阶系列【原创干货持续更新中……】🚀 优质专栏:多媒体系统工程师系列【原创干货持续更新中……】🚀 人生格言: 人生从来没有捷径,只…

人工智能迷惑行为大赏!

目录 人工智能迷惑行为大赏 一&#xff1a;人工智能的“幽默”瞬间 1. 图像识别出现AI的极限 2. 小批量梯度下降优化器 3. 智能聊天机器人的冰雹问题 4. 大语言模型-3经典语录 二&#xff1a;技术原理探究 1. 深度学习 2. 机器学习 3. 自然语言处理 4. 计算机视觉 三…

博士推荐 | 拥有10多年纺织工程经验,纤维与聚合物科学博士

编辑 / 木子 审核 / 朝阳 伟骅英才 伟骅英才致力于以大数据、区块链、AI人工智能等前沿技术打造开放的人力资本生态&#xff0c;用科技解决职业领域问题&#xff0c;提升行业数字化服务水平&#xff0c;提供创新型的产业与人才一体化服务的人力资源解决方案和示范平台&#x…

什么是架构?架构设计原则是哪些?什么是设计模式?设计模式有哪些?

什么是架构?架构设计原则是哪些?什么是设计模式?设计模式有哪些? 架构的本质 架构本身是一种抽象的、来自建筑学的体系结构,其在企业及IT系统中被广泛应用。 架构的本质是对事物复杂性的管理,是对一个企业、一个公司、一个系统复杂的内部关系进行结构化、体系化的抽象,…

骨传导游泳耳机哪个牌子好?四款实力扛鼎的游泳耳机推荐

游泳是一项全身性的运动&#xff0c;能够有效锻炼身体、释放压力。然而&#xff0c;在水下欣赏音乐却成为了一项难题。普通的耳机在水中无法使用&#xff0c;而骨传导技术的出现&#xff0c;让游泳与音乐完美结合。今天&#xff0c;我们将为大家推荐四款超强的的骨传导游泳耳机…

分享一个国内可用的AIGC网站,PC/手机端通用|免费无限制,支持Claude3 Claude2

背景 AIGC作为一种基于人工智能技术的自然语言处理工具&#xff0c;近期的热度直接沸腾&#x1f30b;。 作为一个AI爱好者&#xff0c;翻遍了各大基于AIGC的网站&#xff0c;终于找到一个免费&#xff01;免登陆&#xff01;手机电脑通用&#xff01;国内可直接对话的AIGC&am…

EasyRecovery恢复电脑丢失数据怎么样?

电脑是我们大家熟悉并且常用的数据存储设备&#xff0c;也是综合性非常强的数据处理设备。对于电脑设备来讲&#xff0c;最主要的数据存储介质是硬盘&#xff0c;电脑硬盘被划分成多个分区&#xff0c;在电脑上表现为C盘&#xff0c;E盘等&#xff0c;用来保存系统文件以及其他…

记OnlyOffice的两个大坑

开发版&#xff0c;容器部署&#xff0c;试用许可已安装。 word&#xff0c;ppt&#xff0c;excel均能正常浏览。 自带的下载菜单按钮能用。 但config里自定义的downloadAs方法却不一而足。 word能正常下载&#xff0c;excel和ppt都不行。 仔细比对调试了代码。发现app.js…

yolov5-v6.0详细解读

yolov5-v6.0详细解读 一、yolov5版本介绍二、网络结构2.1 Backbone特征提取部分2.1.1 ConvBNSiLU模块2.1.2 C3模块2.1.2.1 BottleNeck模块 2.1.3 SPPF模块 2.2 Neck特征融合部分2.2.1 FPN2.2.2 PANet 2.3Head模块 三、目标框回归3.1 yolo标注格式3.2 yolov4目标回归框3.3 yolov…

《行业指标体系白皮书》重磅发布,剖析指标建设困境,构建前瞻性的指标体系(附下载)

正处于企业指标建设过程中的你&#xff0c;是否经常遇到这样的问题&#xff1a; • 各个部门独立建设信息系统&#xff0c;由此产生的指标定义和计算方式各异&#xff0c;导致管理层无法快速准确地掌握整体业务运行状况 • 缺乏对指标的统一管理和规范&#xff0c;产生重复的指…

IO复用之select

目录 一.select方法介绍 2.1 select 系统调用的原型 2.2 集合的数据结构 2.2.1 fd_set 结构如下: 2.2.2 关于集合fd_set的解析 2.3 select第一个参数 2.4 select方法之超时时间timeout 2.5 select方法的用法简述及返回值 2.6 如何检测集合中有哪些描述符有事件就绪 三…

【建议收藏】大气颗粒物与VOCs PMF源解析

查看原文>>>最新大气颗粒物与VOCs PMF源解析实践技术应用 目前&#xff0c;大气颗粒物和臭氧污染成为我国亟待解决的环境问题。颗粒物和臭氧污染不仅对气候和环境有重要影响&#xff0c;而且对人体健康有严重损害。而臭氧的前体物之一为挥发性有机物&#xff08;VOCs&…

Orange3数据预处理(离散化组件)

离散化&#xff1a;将数值属性转换为分类属性。 输出 数据&#xff1a;具有离散化值的数据集 设置离散化的默认方法。 选择变量以为每个变量设置特定的离散化方法。将鼠标悬停在变量上显示区间。 离散化方法Keep numeric(保持数值)&#xff1a;保持变量不变。Remove (移除)&a…