1.概述:
Whisper-AT 是建立在 Whisper 自动语音识别(ASR)模型基础上的一个模型。Whisper 模型使用了一个包含 68 万小时标注语音的大规模语料库进行训练,这些语料是在各种不同条件下录制的。Whisper 模型以其在现实背景噪音(如音乐)下的鲁棒性著称。尽管如此,其音频表示并非噪音不变,而是与非语音声音高度相关。这意味着 Whisper 在识别语音时会依据背景噪音类型进行调整。
主要发现:
-
噪音变化的表示:
- Whisper 的音频表示编码了丰富的非语音背景声音信息,这与通常追求噪音不变表示的 ASR 模型目标不同。
- 这一特性使得 Whisper 能够在各种噪音条件下通过识别和适应噪音来保持其鲁棒性。
-
ASR 和音频标签的统一模型:
- 通过冻结 Whisper 模型的骨干网络,并在其上训练一个轻量级的音频标签模型,Whisper-AT 可以在一次前向传递中同时识别音频事件和语音文本,额外的计算成本不足 1%。
- Whisper-AT 在音频事件检测方面表现出色,同时保持了 Whisper 的 ASR 功能。
技术细节:
-
Whisper ASR 模型:
- Whisper 使用基于 Transformer 的编码器-解码器架构。
- 其训练集包括从互联网上收集的 68 万小时音频-文本对,涵盖了广泛的环境、录音设置、说话人和语言。
-
抗噪机制:
- Whisper 的鲁棒性并非通过噪音不变性实现,而是通过在其表示中编码噪音类型。
- 这一机制使得 Whisper 能够根据背景噪音类型来转录文本,从而在嘈杂条件下表现优越。
-
构建 Whisper-AT:
- Whisper-AT 是通过在 Whisper 模型上添加新的音频标签层而构建的,未修改其原始权重。
- 探索了不同的音频标签层集成方法,包括:
- Last-MLP:对 Whisper 的最后一层表示进行时间均值池化,然后应用线性层。
- WA-MLP:对所有层的表示进行加权平均,然后应用线性层。
- WA-Tr:用时间 Transformer 层替换线性层。
- TL-Tr:使用时间和层次 Transformer 处理所有层的表示。
-
效率考量:
- 为保持计算效率,采用了各种策略,例如减少表示的序列长度,并在应用音频标签 Transformer 之前可选地降低维度。
性能:
- Whisper-AT 在 AudioSet 上达到了 41.5 的 mAP,略低于独立的音频标签模型,但处理速度显著更快,超过 40 倍。
意义:
- 能够同时执行 ASR 和音频标签任务,使得 Whisper-AT 非常适合于视频转录、语音助手和助听器系统等应用场景,在这些场景中需要同时进行语音文本和声学场景分析。
2.代码:
欲了解详细的实现和实验结果,请访问 GitHub: github.com/yuangongnd/whisper-at.下面是对 Whisper-AT 代码的详细解释。我们将逐步解析其主要组件和功能,帮助理解其工作原理。
安装和准备
首先,确保你已经安装了 Whisper 和相关的依赖项:
pip install git+https://github.com/openai/whisper.git
pip install torch torchaudio
pip install transformers datasets
代码结构
简要 Whisper-AT 的代码结构如下所示:
Whisper-AT/
│
├── whisper_at.py
├── train.py
├── dataset.py
├── utils.py
└── README.md
whisper_at.py
- Whisper-AT 模型
import torch
import torch.nn as nn
import whisperclass WhisperAT(nn.Module):def __init__(self, model_name="base"):super(WhisperAT, self).__init__()self.whisper = whisper.load_model(model_name)self.audio_tagging_head = nn.Linear(self.whisper.dims, 527) # 527 是 AudioSet 的标签数def forward(self, audio):# 获取 Whisper 的中间表示with torch.no_grad():features = self.whisper.encode(audio)# 通过音频标签头audio_tagging_output = self.audio_tagging_head(features.mean(dim=1))return audio_tagging_output
train.py
- 训练脚本
import torch
from torch.utils.data import DataLoader
from dataset import AudioSetDataset
from whisper_at import WhisperAT
import torch.optim as optim
import torch.nn.functional as Fdef train():# 加载数据集train_dataset = AudioSetDataset("path/to/training/data")train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)# 初始化模型model = WhisperAT()model.train()# 定义优化器optimizer = optim.Adam(model.parameters(), lr=1e-4)for epoch in range(10): # 假设训练10个epochfor audio, labels in train_loader:optimizer.zero_grad()# 前向传播outputs = model(audio)# 计算损失loss = F.binary_cross_entropy_with_logits(outputs, labels)# 反向传播和优化loss.backward()optimizer.step()print(f"Epoch {epoch}, Loss: {loss.item()}")if __name__ == "__main__":train()
dataset.py
- 数据集处理
import torch
from torch.utils.data import Dataset
import torchaudioclass AudioSetDataset(Dataset):def __init__(self, data_path):self.data_path = data_pathself.audio_files = [...] # 这里假设你有一个包含所有音频文件路径的列表self.labels = [...] # 这里假设你有一个包含所有对应标签的列表def __len__(self):return len(self.audio_files)def __getitem__(self, idx):# 加载音频audio, sample_rate = torchaudio.load(self.audio_files[idx])# 获取对应标签labels = torch.tensor(self.labels[idx])return audio, labels
utils.py
- 辅助功能
import torchdef save_model(model, path):torch.save(model.state_dict(), path)def load_model(model, path):model.load_state_dict(torch.load(path))model.eval()
详细解释
-
Whisper-AT 模型 (
whisper_at.py
):WhisperAT
类继承自nn.Module
,初始化时加载 Whisper 模型,并在其上添加一个线性层用于音频标签任务。forward
方法首先调用 Whisper 模型的encode
方法获取音频特征,然后将这些特征传递给音频标签头(线性层)以生成标签输出。
-
训练脚本 (
train.py
):train
函数中,数据集被加载并传递给 DataLoader。- 模型实例化并设置为训练模式。
- 定义了 Adam 优化器和二进制交叉熵损失函数。
- 在训练循环中,音频输入通过模型生成输出,计算损失并执行反向传播和优化。
-
数据集处理 (
dataset.py
):AudioSetDataset
类继承自Dataset
,实现了音频数据和标签的加载。__getitem__
方法加载音频文件并返回音频张量和对应标签。
-
辅助功能 (
utils.py
):- 包含保存和加载模型状态的函数,方便模型的持久化和恢复。
通过以上代码结构和解释,可以帮助理解 Whisper-AT 的实现和训练流程。可以根据需要扩展这些代码来适应具体的应用场景和数据集。