Whisper-AT:抗噪语音识别模型(Whisper)实现通用音频事件标记(Audio Tagger)

1.概述:

       Whisper-AT 是建立在 Whisper 自动语音识别(ASR)模型基础上的一个模型。Whisper 模型使用了一个包含 68 万小时标注语音的大规模语料库进行训练,这些语料是在各种不同条件下录制的。Whisper 模型以其在现实背景噪音(如音乐)下的鲁棒性著称。尽管如此,其音频表示并非噪音不变,而是与非语音声音高度相关。这意味着 Whisper 在识别语音时会依据背景噪音类型进行调整

主要发现:

  1. 噪音变化的表示:

    • Whisper 的音频表示编码了丰富的非语音背景声音信息,这与通常追求噪音不变表示的 ASR 模型目标不同。
    • 这一特性使得 Whisper 能够在各种噪音条件下通过识别和适应噪音来保持其鲁棒性。
  2. ASR 和音频标签的统一模型:

    • 通过冻结 Whisper 模型的骨干网络,并在其上训练一个轻量级的音频标签模型,Whisper-AT 可以在一次前向传递中同时识别音频事件和语音文本,额外的计算成本不足 1%。
    • Whisper-AT 在音频事件检测方面表现出色,同时保持了 Whisper 的 ASR 功能。

技术细节:

  1. Whisper ASR 模型:

    • Whisper 使用基于 Transformer 的编码器-解码器架构。
    • 其训练集包括从互联网上收集的 68 万小时音频-文本对,涵盖了广泛的环境、录音设置、说话人和语言。
  2. 抗噪机制:

    • Whisper 的鲁棒性并非通过噪音不变性实现,而是通过在其表示中编码噪音类型。
    • 这一机制使得 Whisper 能够根据背景噪音类型来转录文本,从而在嘈杂条件下表现优越。
  3. 构建 Whisper-AT:

    • Whisper-AT 是通过在 Whisper 模型上添加新的音频标签层而构建的,未修改其原始权重。

    • 探索了不同的音频标签层集成方法,包括:
      • Last-MLP:对 Whisper 的最后一层表示进行时间均值池化,然后应用线性层。
      • WA-MLP:对所有层的表示进行加权平均,然后应用线性层。
      • WA-Tr:用时间 Transformer 层替换线性层。
      • TL-Tr:使用时间和层次 Transformer 处理所有层的表示。
  4. 效率考量:

    • 为保持计算效率,采用了各种策略,例如减少表示的序列长度,并在应用音频标签 Transformer 之前可选地降低维度。

性能:

  • Whisper-AT 在 AudioSet 上达到了 41.5 的 mAP,略低于独立的音频标签模型,但处理速度显著更快,超过 40 倍。

意义:

  • 能够同时执行 ASR 和音频标签任务,使得 Whisper-AT 非常适合于视频转录、语音助手和助听器系统等应用场景,在这些场景中需要同时进行语音文本和声学场景分析。

2.代码:

       欲了解详细的实现和实验结果,请访问 GitHub: github.com/yuangongnd/whisper-at.下面是对 Whisper-AT 代码的详细解释。我们将逐步解析其主要组件和功能,帮助理解其工作原理。

安装和准备

首先,确保你已经安装了 Whisper 和相关的依赖项:

pip install git+https://github.com/openai/whisper.git
pip install torch torchaudio
pip install transformers datasets

代码结构

简要 Whisper-AT 的代码结构如下所示:

Whisper-AT/
│
├── whisper_at.py
├── train.py
├── dataset.py
├── utils.py
└── README.md

whisper_at.py - Whisper-AT 模型

import torch
import torch.nn as nn
import whisperclass WhisperAT(nn.Module):def __init__(self, model_name="base"):super(WhisperAT, self).__init__()self.whisper = whisper.load_model(model_name)self.audio_tagging_head = nn.Linear(self.whisper.dims, 527)  # 527 是 AudioSet 的标签数def forward(self, audio):# 获取 Whisper 的中间表示with torch.no_grad():features = self.whisper.encode(audio)# 通过音频标签头audio_tagging_output = self.audio_tagging_head(features.mean(dim=1))return audio_tagging_output

train.py - 训练脚本

import torch
from torch.utils.data import DataLoader
from dataset import AudioSetDataset
from whisper_at import WhisperAT
import torch.optim as optim
import torch.nn.functional as Fdef train():# 加载数据集train_dataset = AudioSetDataset("path/to/training/data")train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)# 初始化模型model = WhisperAT()model.train()# 定义优化器optimizer = optim.Adam(model.parameters(), lr=1e-4)for epoch in range(10):  # 假设训练10个epochfor audio, labels in train_loader:optimizer.zero_grad()# 前向传播outputs = model(audio)# 计算损失loss = F.binary_cross_entropy_with_logits(outputs, labels)# 反向传播和优化loss.backward()optimizer.step()print(f"Epoch {epoch}, Loss: {loss.item()}")if __name__ == "__main__":train()

dataset.py - 数据集处理

import torch
from torch.utils.data import Dataset
import torchaudioclass AudioSetDataset(Dataset):def __init__(self, data_path):self.data_path = data_pathself.audio_files = [...]  # 这里假设你有一个包含所有音频文件路径的列表self.labels = [...]  # 这里假设你有一个包含所有对应标签的列表def __len__(self):return len(self.audio_files)def __getitem__(self, idx):# 加载音频audio, sample_rate = torchaudio.load(self.audio_files[idx])# 获取对应标签labels = torch.tensor(self.labels[idx])return audio, labels

utils.py - 辅助功能

import torchdef save_model(model, path):torch.save(model.state_dict(), path)def load_model(model, path):model.load_state_dict(torch.load(path))model.eval()

详细解释

  1. Whisper-AT 模型 (whisper_at.py):

    • WhisperAT 类继承自 nn.Module,初始化时加载 Whisper 模型,并在其上添加一个线性层用于音频标签任务。
    • forward 方法首先调用 Whisper 模型的 encode 方法获取音频特征,然后将这些特征传递给音频标签头(线性层)以生成标签输出。
  2. 训练脚本 (train.py):

    • train 函数中,数据集被加载并传递给 DataLoader。
    • 模型实例化并设置为训练模式。
    • 定义了 Adam 优化器和二进制交叉熵损失函数。
    • 在训练循环中,音频输入通过模型生成输出,计算损失并执行反向传播和优化。
  3. 数据集处理 (dataset.py):

    • AudioSetDataset 类继承自 Dataset,实现了音频数据和标签的加载。
    • __getitem__ 方法加载音频文件并返回音频张量和对应标签。
  4. 辅助功能 (utils.py):

    • 包含保存和加载模型状态的函数,方便模型的持久化和恢复。

       通过以上代码结构和解释,可以帮助理解 Whisper-AT 的实现和训练流程。可以根据需要扩展这些代码来适应具体的应用场景和数据集。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/19557.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

探究 Meme 的金融与社交属性

原文标题:《A Social and Financial Study of Memecoins》撰文:Andrew Hong编译:Chris,Techub News 每一个市场周期都伴随着 Meme 代币的出现。一群人围绕着某个 Meme 集结起来,暂时抬高了某个资产的价格(从…

Github Copilot登录账号,完美支持chat

Github Copilot 代码补全等功能,提高写代码的效率 https://web.52shizhan.cn/activity/copilot 登录授权后,已经可以使用,完美。如图

flutter 自动生成静态资源的引用

flutter_gen库的使用 第一步、项目yarml中dev_dependencies 新增一下flutter_gen_runner 和build_runner dev_dependencies:build_runner: nullflutter_gen_runner: null # flutter packages pub run build_runner build 第二步、新增配置信息 和(dev_dependencies 同级的) …

大话设计模式学习笔记

目录 工厂模式策略模式备忘录模式(快照模式)代理模式单例模式迭代器模式访问者模式观察者模式解释器模式命令模式模板方法模式桥接模式适配器模式外观模式享元模式原型模式责任链模式中介者模式装饰模式状态模式 工厂模式 策略模式 核心:封装…

03.k8s常用的资源

3.k8s常用的资源 3.1 创建pod资源 k8s yaml的主要组成 apiVersion: v1 api版本 kind: pod 资源类型 metadata: 属性 spec: 详细上传nginx镜像文件,并且上传私有仓库里面 k8s_pod.yaml apiVersion: v1 kind: Pod metadata:name: nginxlabels:app: we…

prometheus 标签选择器 正则表达式 = 、=~

Prometheus expression是一种用于查询和操作Prometheus时间序列数据的查询语言。它具有一套丰富的函数和运算符,可以用于提取、聚合和转换时间序列数据。 正则表达式在Prometheus expresion中也被广泛使用,可以用于匹配和过滤时间序列。 Prometheus ex…

Tuxera Ntfs For Mac 2023的具体使用方法

大家都知道由于操作系统的原因,在苹果电脑上不能够读写NTFS磁盘,但是,今天小编带来的这款tuxera ntfs 2024 mac 破解版,完美的解决了这个问题。这是一款在macOS平台上使用的磁盘读写软件,能够实现苹果Mac OS X系统读写…

CSS实验性功能及CSS4特性

CSS4目前仍然是一个宽泛的概念,因为CSS的发展通常是通过一系列逐步完善的模块来进行的,而不是一次性推出一个全新的“第四代”。许多所谓的“CSS4”特性实际上是正在开发或已经草案阶段的CSS模块,它们可能在未来的CSS规范中被正式采纳。 选择器4: :is() 和 :where() 伪类允…

Docker的数据管理(数据卷+数据卷容器)

文章目录 一、Docker的数据管理1、概述2、主要的技术(三种数据挂载方式)2.1、数据卷(Volumes)2.2、绑定挂载(Bind mounts)2.3、tmpfs挂载(Tmpfs mounts)2.4、之间的关系(…

偏微分方程算法之二阶双曲型方程交替方向隐格式(变形一)

目录 一、研究目标 二、变形 三、算例实现 四、计算结果 本专栏介绍了二阶双曲型偏微分方程的交替方向隐格式的介绍和推导(链接如下),本节将进一步研究二维双曲型方程初边值问题其它的交替方向隐格式。

示例丨医学、医药类查新点填写参考案例

根据《科技查新技术规范》GB/T 32003-2015,科学技术要点是必须要包含查新点内容的,而查新点就是科学技术要点中能够体现查新项目新颖性和技术进步的技术特征点。 在日常查新工作的接待中,我们发现医学、医药类查新合同上查新点的书写&#x…

计算机tcp/ip网络通信过程

目录 (1)同一网段两台计算机通信过程 (2)不同网段的两台计算机通信过程 (3)目的主机收到数据包后的解包过程 (1)同一网段两台计算机通信过程 如果两台计算机在同一个局域网中的同…

算法(九)希尔排序

文章目录 希尔排序简介代码实现 希尔排序简介 希尔排序(shell sort)选定一个小于N(数列长度)的整数gap作为第一增量,然后将所有距离为gap的元素分成一组,然后对每一组的元素进行插入排序。然后再取一个比前…

(1+X)Java程序设计高级(一)

Throwable:异常的基类,所有异常都继承自 java.lang.Throwable 类,Throwable 类有两个直接子类:Error 类和 Exception 类。Error:是 Java 应用程序本身无法恢复的严重错误,应用程序不需要捕获、处理这些严重…

7.1 Go 错误的概念

💝💝💝欢迎莅临我的博客,很高兴能够在这里和您见面!希望您在这里可以感受到一份轻松愉快的氛围,不仅可以获得有趣的内容和知识,也可以畅所欲言、分享您的想法和见解。 推荐:「stormsha的主页」…

【SQL每日一练】查询二进制树节点

文章目录 题目一、题析二、题解1.MySQL/SqlServer2.Oracle 题目 有一个表BST,其中包含两列:N和P,其中N表示二进制树中节点的值,P是N的父级。 编写一个查询,以查找按节点值排序的二进制树的节点类型。为每个节点输出以…

迅狐跨境电商系统源码:技术栈与多端集成

随着全球化贸易的不断深入,跨境电商系统源码成为了连接不同国家和地区消费者与商家的重要桥梁。本文将探讨跨境电商系统源码的技术栈以及如何通过多端集成来提升用户体验。 技术栈概览 跨境电商系统源码的技术栈是构建高效、稳定平台的基础。以下是构建跨境电商系…

IP65 IP45 IP68等等数字防护等级

第一个数字的代表意义 : 0 表示无防护 ,对外界的人或物无特殊之防护 1. 表示防止大于50mm的固体物体侵入 ,防止人体(如手掌)因意外而接触,内部之零件。防止较大尺寸(直径大于50mm)的…

Oracle数据块如何存储真实数据

上周休假了几天,颓废了,没有输出。今天写一点内容。 先抛出一个问题。表中的数据在Oracle数据块中是如何存储的呢?今天简单说一下这个问题。通常数据库中的表会存储字符,数字,日期 这3种常见的数据类型。下面的例子就用这3种数据类型作说明 首先,Oracle数据块底层存储这…

Github 2024-05-31开源项目日报 Top10

根据Github Trendings的统计,今日(2024-05-31统计)共有10个项目上榜。根据开发语言中项目的数量,汇总情况如下: 开发语言项目数量Python项目4TypeScript项目3Jupyter Notebook项目2Vue项目1Cuda项目1Elixir项目1简单、纯净的C/CUDA中的LLM培训 创建周期:3 天开发语言:Cuda…