Whisper-AT:抗噪语音识别模型(Whisper)实现通用音频事件标记(Audio Tagger)

1.概述:

       Whisper-AT 是建立在 Whisper 自动语音识别(ASR)模型基础上的一个模型。Whisper 模型使用了一个包含 68 万小时标注语音的大规模语料库进行训练,这些语料是在各种不同条件下录制的。Whisper 模型以其在现实背景噪音(如音乐)下的鲁棒性著称。尽管如此,其音频表示并非噪音不变,而是与非语音声音高度相关。这意味着 Whisper 在识别语音时会依据背景噪音类型进行调整

主要发现:

  1. 噪音变化的表示:

    • Whisper 的音频表示编码了丰富的非语音背景声音信息,这与通常追求噪音不变表示的 ASR 模型目标不同。
    • 这一特性使得 Whisper 能够在各种噪音条件下通过识别和适应噪音来保持其鲁棒性。
  2. ASR 和音频标签的统一模型:

    • 通过冻结 Whisper 模型的骨干网络,并在其上训练一个轻量级的音频标签模型,Whisper-AT 可以在一次前向传递中同时识别音频事件和语音文本,额外的计算成本不足 1%。
    • Whisper-AT 在音频事件检测方面表现出色,同时保持了 Whisper 的 ASR 功能。

技术细节:

  1. Whisper ASR 模型:

    • Whisper 使用基于 Transformer 的编码器-解码器架构。
    • 其训练集包括从互联网上收集的 68 万小时音频-文本对,涵盖了广泛的环境、录音设置、说话人和语言。
  2. 抗噪机制:

    • Whisper 的鲁棒性并非通过噪音不变性实现,而是通过在其表示中编码噪音类型。
    • 这一机制使得 Whisper 能够根据背景噪音类型来转录文本,从而在嘈杂条件下表现优越。
  3. 构建 Whisper-AT:

    • Whisper-AT 是通过在 Whisper 模型上添加新的音频标签层而构建的,未修改其原始权重。

    • 探索了不同的音频标签层集成方法,包括:
      • Last-MLP:对 Whisper 的最后一层表示进行时间均值池化,然后应用线性层。
      • WA-MLP:对所有层的表示进行加权平均,然后应用线性层。
      • WA-Tr:用时间 Transformer 层替换线性层。
      • TL-Tr:使用时间和层次 Transformer 处理所有层的表示。
  4. 效率考量:

    • 为保持计算效率,采用了各种策略,例如减少表示的序列长度,并在应用音频标签 Transformer 之前可选地降低维度。

性能:

  • Whisper-AT 在 AudioSet 上达到了 41.5 的 mAP,略低于独立的音频标签模型,但处理速度显著更快,超过 40 倍。

意义:

  • 能够同时执行 ASR 和音频标签任务,使得 Whisper-AT 非常适合于视频转录、语音助手和助听器系统等应用场景,在这些场景中需要同时进行语音文本和声学场景分析。

2.代码:

       欲了解详细的实现和实验结果,请访问 GitHub: github.com/yuangongnd/whisper-at.下面是对 Whisper-AT 代码的详细解释。我们将逐步解析其主要组件和功能,帮助理解其工作原理。

安装和准备

首先,确保你已经安装了 Whisper 和相关的依赖项:

pip install git+https://github.com/openai/whisper.git
pip install torch torchaudio
pip install transformers datasets

代码结构

简要 Whisper-AT 的代码结构如下所示:

Whisper-AT/
│
├── whisper_at.py
├── train.py
├── dataset.py
├── utils.py
└── README.md

whisper_at.py - Whisper-AT 模型

import torch
import torch.nn as nn
import whisperclass WhisperAT(nn.Module):def __init__(self, model_name="base"):super(WhisperAT, self).__init__()self.whisper = whisper.load_model(model_name)self.audio_tagging_head = nn.Linear(self.whisper.dims, 527)  # 527 是 AudioSet 的标签数def forward(self, audio):# 获取 Whisper 的中间表示with torch.no_grad():features = self.whisper.encode(audio)# 通过音频标签头audio_tagging_output = self.audio_tagging_head(features.mean(dim=1))return audio_tagging_output

train.py - 训练脚本

import torch
from torch.utils.data import DataLoader
from dataset import AudioSetDataset
from whisper_at import WhisperAT
import torch.optim as optim
import torch.nn.functional as Fdef train():# 加载数据集train_dataset = AudioSetDataset("path/to/training/data")train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)# 初始化模型model = WhisperAT()model.train()# 定义优化器optimizer = optim.Adam(model.parameters(), lr=1e-4)for epoch in range(10):  # 假设训练10个epochfor audio, labels in train_loader:optimizer.zero_grad()# 前向传播outputs = model(audio)# 计算损失loss = F.binary_cross_entropy_with_logits(outputs, labels)# 反向传播和优化loss.backward()optimizer.step()print(f"Epoch {epoch}, Loss: {loss.item()}")if __name__ == "__main__":train()

dataset.py - 数据集处理

import torch
from torch.utils.data import Dataset
import torchaudioclass AudioSetDataset(Dataset):def __init__(self, data_path):self.data_path = data_pathself.audio_files = [...]  # 这里假设你有一个包含所有音频文件路径的列表self.labels = [...]  # 这里假设你有一个包含所有对应标签的列表def __len__(self):return len(self.audio_files)def __getitem__(self, idx):# 加载音频audio, sample_rate = torchaudio.load(self.audio_files[idx])# 获取对应标签labels = torch.tensor(self.labels[idx])return audio, labels

utils.py - 辅助功能

import torchdef save_model(model, path):torch.save(model.state_dict(), path)def load_model(model, path):model.load_state_dict(torch.load(path))model.eval()

详细解释

  1. Whisper-AT 模型 (whisper_at.py):

    • WhisperAT 类继承自 nn.Module,初始化时加载 Whisper 模型,并在其上添加一个线性层用于音频标签任务。
    • forward 方法首先调用 Whisper 模型的 encode 方法获取音频特征,然后将这些特征传递给音频标签头(线性层)以生成标签输出。
  2. 训练脚本 (train.py):

    • train 函数中,数据集被加载并传递给 DataLoader。
    • 模型实例化并设置为训练模式。
    • 定义了 Adam 优化器和二进制交叉熵损失函数。
    • 在训练循环中,音频输入通过模型生成输出,计算损失并执行反向传播和优化。
  3. 数据集处理 (dataset.py):

    • AudioSetDataset 类继承自 Dataset,实现了音频数据和标签的加载。
    • __getitem__ 方法加载音频文件并返回音频张量和对应标签。
  4. 辅助功能 (utils.py):

    • 包含保存和加载模型状态的函数,方便模型的持久化和恢复。

       通过以上代码结构和解释,可以帮助理解 Whisper-AT 的实现和训练流程。可以根据需要扩展这些代码来适应具体的应用场景和数据集。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/19557.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

探究 Meme 的金融与社交属性

原文标题:《A Social and Financial Study of Memecoins》撰文:Andrew Hong编译:Chris,Techub News 每一个市场周期都伴随着 Meme 代币的出现。一群人围绕着某个 Meme 集结起来,暂时抬高了某个资产的价格(从…

Github Copilot登录账号,完美支持chat

Github Copilot 代码补全等功能,提高写代码的效率 https://web.52shizhan.cn/activity/copilot 登录授权后,已经可以使用,完美。如图

大话设计模式学习笔记

目录 工厂模式策略模式备忘录模式(快照模式)代理模式单例模式迭代器模式访问者模式观察者模式解释器模式命令模式模板方法模式桥接模式适配器模式外观模式享元模式原型模式责任链模式中介者模式装饰模式状态模式 工厂模式 策略模式 核心:封装…

03.k8s常用的资源

3.k8s常用的资源 3.1 创建pod资源 k8s yaml的主要组成 apiVersion: v1 api版本 kind: pod 资源类型 metadata: 属性 spec: 详细上传nginx镜像文件,并且上传私有仓库里面 k8s_pod.yaml apiVersion: v1 kind: Pod metadata:name: nginxlabels:app: we…

Tuxera Ntfs For Mac 2023的具体使用方法

大家都知道由于操作系统的原因,在苹果电脑上不能够读写NTFS磁盘,但是,今天小编带来的这款tuxera ntfs 2024 mac 破解版,完美的解决了这个问题。这是一款在macOS平台上使用的磁盘读写软件,能够实现苹果Mac OS X系统读写…

Docker的数据管理(数据卷+数据卷容器)

文章目录 一、Docker的数据管理1、概述2、主要的技术(三种数据挂载方式)2.1、数据卷(Volumes)2.2、绑定挂载(Bind mounts)2.3、tmpfs挂载(Tmpfs mounts)2.4、之间的关系(…

示例丨医学、医药类查新点填写参考案例

根据《科技查新技术规范》GB/T 32003-2015,科学技术要点是必须要包含查新点内容的,而查新点就是科学技术要点中能够体现查新项目新颖性和技术进步的技术特征点。 在日常查新工作的接待中,我们发现医学、医药类查新合同上查新点的书写&#x…

计算机tcp/ip网络通信过程

目录 (1)同一网段两台计算机通信过程 (2)不同网段的两台计算机通信过程 (3)目的主机收到数据包后的解包过程 (1)同一网段两台计算机通信过程 如果两台计算机在同一个局域网中的同…

算法(九)希尔排序

文章目录 希尔排序简介代码实现 希尔排序简介 希尔排序(shell sort)选定一个小于N(数列长度)的整数gap作为第一增量,然后将所有距离为gap的元素分成一组,然后对每一组的元素进行插入排序。然后再取一个比前…

(1+X)Java程序设计高级(一)

Throwable:异常的基类,所有异常都继承自 java.lang.Throwable 类,Throwable 类有两个直接子类:Error 类和 Exception 类。Error:是 Java 应用程序本身无法恢复的严重错误,应用程序不需要捕获、处理这些严重…

7.1 Go 错误的概念

💝💝💝欢迎莅临我的博客,很高兴能够在这里和您见面!希望您在这里可以感受到一份轻松愉快的氛围,不仅可以获得有趣的内容和知识,也可以畅所欲言、分享您的想法和见解。 推荐:「stormsha的主页」…

【SQL每日一练】查询二进制树节点

文章目录 题目一、题析二、题解1.MySQL/SqlServer2.Oracle 题目 有一个表BST,其中包含两列:N和P,其中N表示二进制树中节点的值,P是N的父级。 编写一个查询,以查找按节点值排序的二进制树的节点类型。为每个节点输出以…

迅狐跨境电商系统源码:技术栈与多端集成

随着全球化贸易的不断深入,跨境电商系统源码成为了连接不同国家和地区消费者与商家的重要桥梁。本文将探讨跨境电商系统源码的技术栈以及如何通过多端集成来提升用户体验。 技术栈概览 跨境电商系统源码的技术栈是构建高效、稳定平台的基础。以下是构建跨境电商系…

OpenCV中的圆形标靶检测——背景概述

圆形标靶 如下图所示,相机标定中我们使用带有固定间距图案阵列的平板,来得到高精度的标靶像素坐标,进而计算得到相机的内参、畸变系数,相机之间的变换关系,和相机与世界坐标系的变换关系(即外参)。 不过标靶的形式多样,从图案类型来看常见的有棋盘格、圆形标靶…

音视频开发13 FFmpeg 音频 相关格式分析 -- AAC ADTS格式分析

这一节,我们学习常用的音频的格式 AAC,重点是掌握 AAC的传输格式 ADTS 头部的信息,目的是 : 当音频数据有问题的时候,如果是AAC的编码,在分析 头部信息的时候能够根据头部信息 判断问题是否出现在 头部。 A…

今天来讲讲,抖音小店商品的上架流程以及优化细节~

大家好,我是喷火龙。 做抖音小店选品选好之后,优化上架商品也是很重要的,也有很多需要注意的细节,今天就来给大家讲讲。 首先,软件采集,大致分为七步。 1. 以抖精灵为例,注册账号登录&#x…

到无穷大和更远,用分形更好

文章目录 一、说明二、分形到底是什么?三、更多更深刻的四、引进无穷小会产生什么样的怪事?五、希尔伯特曲线六、还有什么有趣的要补充的吗? 一、说明 ​​​​​​​数学领域有太多有趣的领域,领域我特别感兴趣。这是一个奇妙的…

怎么看自己电脑的配置?提升电脑的使用效率

了解自己电脑的配置是非常重要的,它可以帮助您了解电脑的性能水平,从而更好地选择适合的软件和游戏,或者进行系统升级和维护。然而,许多用户可能不知道怎么看自己电脑的配置信息。本文将介绍三种简单的方法,帮助您轻松…

android studio修改字体大小

android studio修改菜单栏、工具栏字体大小 android studio修改编辑框字体大小

常见制氮机的规格的及其特点介绍

制氮机根据其产气量、应用领域和设计特点,可以分为多种规格,满足不同行业的具体需求。以下是一些常见制氮机的规格的及其特点介绍: 制氮机的规格通常以其每小时制氮量进行分类。常见的规格有10L制氮机、50L制氮机、100L制氮机、500L制氮机以及…