深入理解 PyTorch 的数据加载

深入理解 PyTorch 的数据加载

在进行深度学习时,数据的加载和预处理是至关重要的步骤。PyTorch 提供了 torch.utils.data.Datasettorch.utils.data.DataLoader 这两个强大的工具来简化这一过程。

1. torch.utils.data.Dataset

Dataset 是 PyTorch 中用于定义数据集的基类。通过继承 Dataset 类,您可以创建自己的数据集。实现 __len____getitem__ 方法是创建自定义数据集的关键。

主要方法

  • __len__:返回数据集的大小,表示数据集中样本的总数。
  • __getitem__:根据索引返回数据集中的一个样本。在此方法中,您可以实现数据的读取和预处理逻辑。

注意事项

  • 数据长度一致性:在处理音频或图像数据时,确保每个样本的长度一致是非常重要的。如果样本长度不一致,DataLoader 将无法将它们堆叠成批次,可能会导致运行时错误。
  • 数据预处理:在 __getitem__ 方法中,您可以进行必要的预处理,例如归一化、数据增强等。

2. torch.utils.data.DataLoader

DataLoader 是 PyTorch 中用于加载数据的工具。它可以自动处理批次、打乱数据和多线程加载。

主要参数

  • dataset:需要加载的数据集。
  • batch_size:每个批次的样本数量。
  • shuffle:在每个 epoch 之前是否打乱数据。
  • num_workers:用于数据加载的子进程数量,可以加速数据加载。

注意事项

  • 批次大小:选择合适的批次大小可以提高训练效率。过小的批次可能导致训练不稳定,而过大的批次可能导致内存溢出。
  • 数据打乱:在每个 epoch 之前打乱数据可以帮助模型更好地泛化。

代码示例

以下是一个完整的示例代码,展示了如何使用 torch.utils.data.Datasettorch.utils.data.DataLoader 来处理音频数据。

import os
import random
import torch
from torch.utils.data import Dataset, DataLoader
import torchaudio
import torch.nn.functional as F
import librosa
import numpy as npclass AudioDataset(Dataset):def __init__(self, mix_dir, noise_dir, voice_dir, target_length=1024, num_samples=5):self.mix_dir = mix_dirself.noise_dir = noise_dirself.voice_dir = voice_dirself.target_length = target_lengthself.num_samples = num_samples  # 每个音频文件提取的样本数量# 获取所有混合音频文件名(去掉 '_mix.wav' 后缀)self.file_names = [f.replace('_mix.wav', '') for f in os.listdir(mix_dir) if f.endswith('_mix.wav')]def __len__(self):return len(self.file_names)def __getitem__(self, idx):# 获取文件名file_name = self.file_names[idx]# 构建文件路径mix_path = os.path.join(self.mix_dir, f"{file_name}_mix.wav")noise_path = os.path.join(self.noise_dir, f"{file_name}_noise.wav")voice_path = os.path.join(self.voice_dir, f"{file_name}_voice.wav")# 读取音频文件mix, _ = torchaudio.load(mix_path)noise, _ = torchaudio.load(noise_path)voice, _ = torchaudio.load(voice_path)# 随机选择多个 1024 个数据点的片段mix_samples = self._get_random_samples(mix)noise_samples = self._get_random_samples(noise)voice_samples = self._get_random_samples(voice)# 计算 STFT 并返回频谱mix_spectrogram = [self._compute_stft(sample) for sample in mix_samples]noise_spectrogram = [self._compute_stft(sample) for sample in noise_samples]voice_spectrogram = [self._compute_stft(sample) for sample in voice_samples]return mix_spectrogram, noise_spectrogram, voice_spectrogramdef _get_random_samples(self, audio):samples = []for _ in range(self.num_samples):if audio.size(1) > self.target_length:start_idx = random.randint(0, audio.size(1) - self.target_length)sample = audio[:, start_idx:start_idx + self.target_length]else:# 如果音频长度小于目标长度,则填充padding = self.target_length - audio.size(1)sample = F.pad(audio, (0, padding), value=0)  # 在末尾填充零samples.append(sample)return samples@staticmethoddef _compute_stft(audio):# 将张量转换为 NumPy 数组audio_np = audio.numpy().flatten()  # 转换为一维数组# 计算 STFTstft_result = librosa.stft(audio_np, n_fft=512, hop_length=256)# 转换为幅度谱spectrogram = np.abs(stft_result)return spectrogram  # 返回幅度谱# 使用示例
if __name__ == "__main__":# 文件夹路径mix_directory = './ds/mixtures'noise_directory = './ds/noises'voice_directory = './ds/real_voices'# 创建数据集audio_dataset = AudioDataset(mix_directory, noise_directory, voice_directory, num_samples=5)# 创建数据加载器data_loader = DataLoader(audio_dataset, batch_size=16, shuffle=True)# 遍历数据for mix_spectrogram, noise_spectrogram, voice_spectrogram in data_loader:# 将样本在第一维拼接mix_tensor = torch.cat(mix_spectrogram, dim=0)  # 拼接 mix_samplesnoise_tensor = torch.cat(noise_spectrogram, dim=0)  # 拼接 noise_samplesvoice_tensor = torch.cat(voice_spectrogram, dim=0)  # 拼接 voice_samples# 处理您的频谱数据print(f'Mix spectrogram shape: {mix_tensor.shape}')  # 打印每个频谱的形状print(f'Noise spectrogram shape: {noise_tensor.shape}')  # 打印每个频谱的形状print(f'Voice spectrogram shape: {voice_tensor.shape}')  # 打印每个频谱的形状
Mix spectrogram shape: torch.Size([80, 257, 5])
Noise spectrogram shape: torch.Size([80, 257, 5])
Voice spectrogram shape: torch.Size([80, 257, 5])
Mix spectrogram shape: torch.Size([80, 257, 5])
Noise spectrogram shape: torch.Size([80, 257, 5])
Voice spectrogram shape: torch.Size([80, 257, 5])
Mix spectrogram shape: torch.Size([40, 257, 5])
Noise spectrogram shape: torch.Size([40, 257, 5])
Voice spectrogram shape: torch.Size([40, 257, 5])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/61617.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

.NET 9 全面上线:开启开发新纪元

微软最新发布的.NET 9为开发者带来了翻天覆地的变化,这次升级不仅仅是一次普通的版本迭代,更像是为开发者打开了一扇通往未来的大门。 性能革新:AOT编译的突破性进展 原生提前编译(AOT)是此次更新最耀眼的明珠。过去&…

蓝桥杯每日真题 - 第19天

题目:(费用报销) 题目描述(13届 C&C B组F题) 解题思路: 1. 问题抽象 本问题可以看作一个限制条件较多的优化问题,核心是如何在金额和时间约束下选择最优方案: 动态规划是理想…

数据结构及算法--排序篇

在 C 语言中,可以通过嵌套循环和比较运算符来实现常见的排序算法,比如冒泡排序、选择排序或插入排序 目录 基础算法: 1.冒泡排序(Bubble Sort) 2.选择排序(Selection Sort) 3.插入排序&…

科研实验室的数字化转型:Spring Boot系统

1系统概述 1.1 研究背景 随着计算机技术的发展以及计算机网络的逐渐普及,互联网成为人们查找信息的重要场所,二十一世纪是信息的时代,所以信息的管理显得特别重要。因此,使用计算机来管理实验室管理系统的相关信息成为必然。开发合…

【Redis】持久化机制RDB与AOF

一、RDB RDB模式是就是将内存中的数据存储到磁盘中,等到连接断开的时候会进行持久化操作。但是如果服务器宕机,会导致这个持久化机制不会执行,但是内存中的文件会直接丢失。所以可以设置一个触发机制,save 60 1000 就是代表60秒 执…

Excel——宏教程(精简版)

一、宏的简介 1、什么是宏? Excel宏是一种自动化工具,它允许用户录制一系列操作并将其转换为VBA(Visual Basic for Applications)代码。这样,用户可以在需要时执行这些操作,以自动化Excel任务。 2、宏的优点 我们可以利用宏来…

【MyBatisPlus·最新教程】包含多个改造案例,常用注解、条件构造器、代码生成、静态工具、类型处理器、分页插件、自动填充字段

文章目录 一、MyBatis-Plus简介二、快速入门1、环境准备2、将mybatis项目改造成mybatis-plus项目(1)引入MybatisPlus依赖,代替MyBatis依赖(2)配置Mapper包扫描路径(3)定义Mapper接口并继承BaseM…

Git 多仓库提交用户信息动态设置

Git 多仓库提交用户信息动态设置 原文地址:dddhl.cn 前言 在日常开发中,我们可能需要同时管理多个远程仓库(如 GitHub、Gitee、GitLab),而每个仓库使用不同的邮箱和用户名。比如,GitHub 和 Gitee 使用相…

【spring】spring单例模式与锁对象作用域的分析

前言:spring默认是单例模式,这句话大家应该都不陌生;因为绝大多数都是使用单例模式,避免了某些问题,可能导致对某些场景缺乏思考。本文通过结合lock锁将单例模式、静态变量、锁对象等知识点串联起来。 文章目录 synchr…

Cyberchef使用功能之-多种压缩/解压缩操作对比

cyberchef的compression操作大类中有大量的压缩和解压缩操作,每种操作的功能和区别是什么,本章将进行讲解,作为我的专栏《Cyberchef 从入门到精通教程》中的一篇,详见这里。 关于文件格式和压缩算法的理论部分在之前的文章《压缩…

Elasticsearch开启认证及kibana密码登陆

Elasticsearch不允许root用户运行,使用root用户为其创建一个用户es,为用户es配置密码,并切换到es用户。 adduser elastic passwd elastic su elasticElasticsearch(简称ES)是一个基于Lucene的搜索服务器。它提供了一个分布式、多用户能力的全文搜索引擎,基于RESTful web…

C++初阶学习第十一弹——list的用法和模拟实现

目录 一、list的使用 二.list的模拟实现 三.总结 一、list的使用 list的底层是双向链表结构&#xff0c;双向链表中每个元素存储在互不相关的独立节点中&#xff0c;在节点中通过指针指向 其前一个元素和后一个元素。 常见的list的函数的使用 std::list<int> It {1,…

Postman之安装及汉化基本使用介绍

系列文章目录 1.Postman之安装及汉化基本使用介绍 2.Postman之变量操作 3.Postman之数据提取 4.Postman之pm.test断言操作 5.Postman之newman Postman之安装及汉化 1.安装及汉化postman2.基本使用介绍2.1.基本功能&#xff1a;2.2.编辑、查看、设置环境、全局、集合变量2.3.复制…

VUE 指令 事件绑定,.stop阻止冒泡

1、VUE 的模板语法和指令 目的增强html的功能 所有的指令以自定义属性的方式去写 v-xxx ,指令就是vue提供给我们能够更方便将数据和页面展示出来的操作&#xff0c;具体就是以数据去驱动DOM ,简化DOM操作的行为。 2、内容渲染指令 ① {{}} 模板渲染&#xff08;模板引擎&am…

Android Java 中Lambda 表达式

学习笔记 在 Java 和 Android 中&#xff0c;除了你提到的 setOnClickListener() 使用 Lambda 表达式和匿名类的写法&#xff0c;还有许多类似的场景可以通过 Lambda 表达式来简化代码。以下是一些常见的类似用法以及它们如何通过 Lambda 表达式和匿名类实现。 1. setOnClick…

Android 文件分段上传和下载方案

一、背景 Android 中的大文件下载需要使用分段下载&#xff0c;下载通常是在线程中进行的&#xff0c;假如有5段&#xff0c;那同时5个线程去执行下载&#xff0c;请求http返回文件流后&#xff0c;需要将多个文件流同时写进同一个文件&#xff0c;这里用到 RandomAccessFile…

算法.图论-习题全集(Updating)

文章目录 本节设置的意义并查集篇并查集简介以及常见技巧并查集板子(洛谷)情侣牵手问题相似的字符串组岛屿数量(并查集做法)省份数量移除最多的同行或同列石头最大的人工岛找出知晓秘密的所有专家 建图及其拓扑排序篇链式前向星建图板子课程表 本节设置的意义 主要就是为了复习…

易语言学习-cnblog

易语言数据类型 数值转换命令&#xff08;自己学&#xff09; 数值到大写&#xff08;&#xff09;将一个数值转换到中文读法&#xff0c;第二个参数为是否为简体。 数值到大写&#xff08;123.44&#xff0c;假&#xff09; 猜测结果 数值到金额&#xff08;&#xff09;将双…

atob()为啥明明表示base64toASCII却叫atob?(2)

上篇谈到JavaScript中的atob()函数实际是表示ASCII to binary而非ASCII to base64&#xff0c;那既然函数底层产生的是二进制内容&#xff0c;那为什么咱们在JavaScript环境中通过atob()解码可以直接得到字符串&#xff1f;答案下文揭晓ෆ( ˶ᵕ˶)ෆ 在JavaScript中&#x…

树莓派的开机自启

前言 很多比赛你的装置是不能动的,就是你直接拿上去,不允许用电脑启动. 树莓派开机自启的三种方式 1,快捷方式自启动 就是在我们用户的目录下(他这里是/home/pi,我的是/home/zw),ctrlh可以显示隐藏文件价, #没有就找你自己的用户目录 cd /.config#在这里面建一个autostart文…