whisper 模型源码解读

在这里插入图片描述

whisper官方源码

whisper 模型官方代码:https://github.com/openai/whisper/blob/main/whisper/model.py ;注释如下

import base64
import gzip
from dataclasses import dataclass
from typing import Dict, Iterable, Optionalimport numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor, nn# 从其他模块导入必要的函数
from .decoding import decode as decode_function
from .decoding import detect_language as detect_language_function
from .transcribe import transcribe as transcribe_function@dataclass
class ModelDimensions:"""该类用于存储模型的各项参数"""n_mels: int  # Mel谱图的频带数量n_audio_ctx: int  # 音频上下文窗口大小n_audio_state: int  # 音频状态维度n_audio_head: int  # 音频注意力头数量n_audio_layer: int  # 音频层数量n_vocab: int  # 词汇表大小n_text_ctx: int  # 文本上下文窗口大小n_text_state: int  # 文本状态维度n_text_head: int  # 文本注意力头数量n_text_layer: int  # 文本层数量class LayerNorm(nn.LayerNorm):def forward(self, x: Tensor) -> Tensor:"""重写 forward 方法,确保输入张量的类型在归一化前后保持一致"""return super().forward(x.float()).type(x.dtype)class Linear(nn.Linear):def forward(self, x: Tensor) -> Tensor:"""重写 forward 方法,确保权重和偏置与输入张量的类型一致"""return F.linear(x,self.weight.to(x.dtype),None if self.bias is None else self.bias.to(x.dtype),)class Conv1d(nn.Conv1d):def _conv_forward(self, x: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor:"""重写 _conv_forward 方法,确保卷积操作中的权重和偏置与输入张量的类型一致"""return super()._conv_forward(x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype))def sinusoids(length, channels, max_timescale=10000):"""生成用于位置嵌入的正弦曲线"""assert channels % 2 == 0log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)class MultiHeadAttention(nn.Module):def __init__(self, n_state: int, n_head: int):"""初始化多头注意力层"""super().__init__()self.n_head = n_headself.query = Linear(n_state, n_state)self.key = Linear(n_state, n_state, bias=False)self.value = Linear(n_state, n_state)self.out = Linear(n_state, n_state)def forward(self,x: Tensor,xa: Optional[Tensor] = None,mask: Optional[Tensor] = None,kv_cache: Optional[dict] = None,):"""多头注意力的前向传播"""q = self.query(x)if kv_cache is None or xa is None or self.key not in kv_cache:# 如果没有缓存键和值,则正常计算k = self.key(x if xa is None else xa)v = self.value(x if xa is None else xa)else:# 如果有缓存,则使用缓存的键和值k = kv_cache[self.key]v = kv_cache[self.value]wv, qk = self.qkv_attention(q, k, v, mask)return self.out(wv), qkdef qkv_attention(self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None):"""计算 QKV 注意力"""n_batch, n_ctx, n_state = q.shapescale = (n_state // self.n_head) ** -0.25q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scalek = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scalev = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)qk = q @ kif mask is not None:qk = qk + mask[:n_ctx, :n_ctx]qk = qk.float()w = F.softmax(qk, dim=-1).to(q.dtype)return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()class ResidualAttentionBlock(nn.Module):def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):"""初始化残差注意力块"""super().__init__()self.attn = MultiHeadAttention(n_state, n_head)self.attn_ln = LayerNorm(n_state)self.cross_attn = (MultiHeadAttention(n_state, n_head) if cross_attention else None)self.cross_attn_ln = LayerNorm(n_state) if cross_attention else Nonen_mlp = n_state * 4self.mlp = nn.Sequential(Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state))self.mlp_ln = LayerNorm(n_state)def forward(self,x: Tensor,xa: Optional[Tensor] = None,mask: Optional[Tensor] = None,kv_cache: Optional[dict] = None,):"""残差注意力块的前向传播"""x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0]if self.cross_attn:x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0]x = x + self.mlp(self.mlp_ln(x))return xclass AudioEncoder(nn.Module):def __init__(self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):"""初始化音频编码器"""super().__init__()self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList([ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)])self.ln_post = LayerNorm(n_state)def forward(self, x: Tensor):"""前向传播,处理音频输入x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)音频的Mel谱图"""x = F.gelu(self.conv1(x))x = F.gelu(self.conv2(x))x = x.permute(0, 2, 1)assert x.shape[1:] == self.positional_embedding.shape, "音频形状不正确"x = (x + self.positional_embedding).to(x.dtype)for block in self.blocks:x = block(x)x = self.ln_post(x)return xclass TextDecoder(nn.Module):def __init__(self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):"""初始化文本解码器"""super().__init__()self.token_embedding = nn.Embedding(n_vocab, n_state)self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList([ResidualAttentionBlock(n_state, n_head, cross_attention=True)for _ in range(n_layer)])self.ln = LayerNorm(n_state)mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)self.register_buffer("mask", mask, persistent=False)def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):"""前向传播,处理文本输入并结合音频特征x : torch.LongTensor, shape = (batch_size, <= n_ctx)文本的标记序列xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)编码后的音频特征"""offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0x = (self.token_embedding(x)+ self.positional_embedding[offset : offset + x.shape[-1]])x = x.to(xa.dtype)for block in self.blocks:x = block(x, xa, mask=self.mask, kv_cache=kv_cache)x = self.ln(x)logits = (x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)).float()return logitsclass Whisper(nn.Module):def __init__(self, dims: ModelDimensions):"""初始化 Whisper 模型"""super().__init__()self.dims = dimsself.encoder = AudioEncoder(self.dims.n_mels,self.dims.n_audio_ctx,self.dims.n_audio_state,self.dims.n_audio_head,self.dims.n_audio_layer,)self.decoder = TextDecoder(self.dims.n_vocab,self.dims.n_text_ctx,self.dims.n_text_state,self.dims.n_text_head,self.dims.n_text_layer,)# 默认情况下,使用解码器层的后一半进行时间对齐;# 若要使用特定的注意力头,可以使用 `set_alignment_heads()` 方法。all_heads = torch.zeros(self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool)all_heads[self.dims.n_text_layer // 2 :] = Trueself.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)def set_alignment_heads(self, dump: bytes):"""设置对齐的注意力头"""array = np.frombuffer(gzip.decompress(base64.b85decode(dump)), dtype=bool).copy()mask = torch.from_numpy(array).reshape(self.dims.n_text_layer, self.dims.n_text_head)self.register_buffer("alignment_heads", mask.to_sparse(), persistent=False)def embed_audio(self, mel: torch.Tensor):"""编码音频特征"""return self.encoder(mel)def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):"""获取预测的logits"""return self.decoder(tokens, audio_features)def forward(self, mel: torch.Tensor, tokens: torch.Tensor) -> Dict[str, torch.Tensor]:"""前向传播"""return self.decoder(tokens, self.encoder(mel))@propertydef device(self):"""获取模型所在的设备"""return next(self.parameters()).device@propertydef is_multilingual(self):"""判断模型是否支持多语言"""return self.dims.n_vocab >= 51865@propertydef num_languages(self):"""获取模型支持的语言数量"""return self.dims.n_vocab - 51765 - int(self.is_multilingual)def install_kv_cache_hooks(self, cache: Optional[dict] = None):"""为键和值的投影模块安装缓存钩子返回-------cache : Dict[nn.Module, torch.Tensor]映射键/值投影模块到其缓存的字典对象hooks : List[RemovableHandle]用于停止调用钩子的 PyTorch RemovableHandle 对象列表"""cache = {**cache} if cache is not None else {}hooks = []def save_to_cache(module, _, output):if module not in cache or output.shape[1] > self.dims.n_text_ctx:# 第一次标记或交叉注意时保存原始值cache[module] = outputelse:cache[module] = torch.cat([cache[module], output], dim=1).detach()return cache[module]def install_hooks(layer: nn.Module):if isinstance(layer, MultiHeadAttention):hooks.append(layer.key.register_forward_hook(save_to_cache))hooks.append(layer.value.register_forward_hook(save_to_cache))self.decoder.apply(install_hooks)return cache, hooksdetect_language = detect_language_function  # 语言检测函数transcribe = transcribe_function  # 转录函数decode = decode_function  # 解码函数

语音识别自回归解码过程分析和举例说明

分析

语音识别自回归解码过程通常涉及以下步骤:

  1. 音频预处理:首先将输入的音频信号转换为Mel谱图。这一步骤在实际应用中通常由音频前端处理模块完成。

  2. 音频编码:将预处理后的Mel谱图输入到音频编码器中,生成音频特征表示。这些特征表示将作为后续文本解码器的输入。

  3. 文本解码:文本解码器通过自回归方式生成文本序列。具体来说,文本解码器在每个时间步上根据前一步生成的文本标记以及音频特征生成下一个文本标记。

  4. 语言检测和转录:在生成的文本序列基础上,可以进行语言检测,确认文本所使用的语言。此外,转录过程将生成的文本序列转换为最终的文本输出。

具体步骤

以下代码展示了上述过程的具体实现:

import torch# 初始化模型参数
dims = ModelDimensions(n_mels=80,n_audio_ctx=1500,n_audio_state=512,n_audio_head=8,n_audio_layer=6,n_vocab=51865,n_text_ctx=448,n_text_state=512,n_text_head=8,n_text_layer=6,
)# 创建模型实例
model = Whisper(dims)# 假设我们有一个Mel谱图输入
mel_spectrogram = torch.randn(1, 80, 1500)  # (batch_size, n_mels, n_audio_ctx)# 编码音频特征
audio_features = model.embed_audio(mel_spectrogram)# 假设我们有一个初始的文本标记序列
initial_tokens = torch.tensor([[1, 2, 3]])  # (batch_size, seq_len)# 自回归解码过程
for _ in range(10):  # 假设生成长度为10的序列logits = model.logits(initial_tokens, audio_features)next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)initial_tokens = torch.cat([initial_tokens, next_token], dim=-1)# 最终生成的文本标记序列
final_tokens = initial_tokens# 打印生成的文本标记序列
print("Generated tokens:", final_tokens)

举例说明

假设我们有一段音频,其Mel谱图表示如下:

mel_spectrogram = torch.randn(1, 80, 1500)

我们希望通过自回归解码生成对应的文本表示。首先,我们将Mel谱图输入到音频编码器中,得到音频特征表示:

audio_features = model.embed_audio(mel_spectrogram)

然后,我们使用一个初始的文本标记序列(例如,序列开始标记)开始自回归解码过程:

initial_tokens = torch.tensor([[1]])  # 序列开始标记

在每个时间步,我们根据当前的文本标记序列和音频特征生成下一个文本标记:

logits = model.logits(initial_tokens, audio_features)
next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
initial_tokens = torch.cat([initial_tokens, next_token], dim=-1)

这个过程重复若干次(例如10次)直到生成完整的文本序列:

for _ in range(10):logits = model.logits(initial_tokens, audio_features)next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)initial_tokens = torch.cat([initial_tokens, next_token], dim=-1)

最终得到的文本标记序列为:

final_tokens = initial_tokens
print("Generated tokens:", final_tokens)

以上示例展示了从音频输入到文本输出的完整自回归解码过程。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/28574.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

java设计模式和面向对象编程思想

Java设计模式和面向对象编程思想是软件开发中的核心概念&#xff0c;对于构建可维护、可扩展的软件系统至关重要。下面是对这两个主题的知识点总结&#xff1a; 面向对象编程&#xff08;OOP&#xff09;思想 封装&#xff1a;将数据&#xff08;属性&#xff09;和操作这些数据…

享元和代理模式

文章目录 享元模式1.引出享元模式1.展示网站项目需求2.传统方案解决3.问题分析 2.享元模式1.基本介绍2.原理类图3.外部状态和内部状态4.类图5.代码实现1.AbsWebSite.java 抽象的网站2.ConcreteWebSite.java 具体的网站&#xff0c;type属性是内部状态3.WebSiteFactory.java 网站…

CSS从入门到精通——动画:CSS3动画执行次数和逆向播放

目录 任务描述 相关知识 动画执行次数 动画反向播放 编程要求 任务描述 本关任务&#xff1a;用 CSS3 实现loading效果。效果图如下&#xff1a; 相关知识 为了完成本关任务&#xff0c;你需要掌握&#xff1a;1.动画执行次数&#xff0c;2.动画反向播放。 需要实现的效…

R调用Taxonkit展示系统发育信息

Introduction TaxonKit是一个用于处理生物分类学数据的命令行工具。 它的主要功能是处理NCBI的生物分类学数据&#xff0c;包括对分类单元&#xff08;如物种、属、科等&#xff09;的查找、分类单元的上下位关系查询、分类单元名称的标准化等。 为了方便R社区用户&#xff0…

【计算机组成原理】指令系统考研真题详解之拓展操作码!

计算机组成原理&#xff1a;指令系统概述与深入解析 1. 指令系统概述 计算机软硬件界面的概念 在计算机组成原理中&#xff0c;指令系统扮演着至关重要的角色&#xff0c;它是计算机软硬件界面的核心。软件通过指令与硬件进行通信&#xff0c;硬件根据指令执行相应的操作。指…

如何解决javadoc一直找不到路径的问题?

目录 一、什么是javadoc二、javadoc为什么会找不到路径三、如何解决javadoc一直找不到路径的问题 一、什么是javadoc Javadoc是一种用于生成Java源代码文档的工具&#xff0c;它可以帮助开发者生成易于阅读和理解的文档。Javadoc通过解析Java源代码中的注释&#xff0c;提取其…

【Python】理解『下采样』:原理与应用

是你多么温馨的目光 教我坚毅望着前路 叮嘱我跌倒不应放弃 没法解释怎可报尽亲恩 爱意宽大是无限 请准我说声真的爱你 &#x1f3b5; Beyond《真的爱你》 在数字信号处理、图像处理和机器学习中&#xff0c;下采样&#xff08;Downsampling&#xff09;是…

42 mysql “+“ 操作符的实现

前言 问题来自于 chinaunix, mysql select 子句的小白问题 mysql 的一些基础的 算术运算符 的计算的实现 这里 整理如下 case, 执行之前 设置如下变量 set a 2; set b 3;select a b; select a b; select 1 3; select 1 3; select a b; select a b; select a b; …

【Quartus 13.0】NIOS II 部署UART 和 PWM

打算在 EP1C3T144I7 芯片上部署 nios ii 做 uart & pwm控制 这个芯片或许不够做 QT 部署 这个芯片好老啊&#xff0c;但是做控制足够了&#xff0c;我只是想装13写 leader给的接口代码是用VHDL写的&#xff0c;我不会 当然verilog我也不太会 就这样&#xff0c;随便写吧 co…

element-plus表单组件之自动补全组件el-autocomplete和级联选择器组件el-cascader

el-autocomplete 自动补全组件 自补全组件的功能和可以根据输入过滤的el-select组件有些类似。 fetch-suggestions 根据输入框的输入获取建议的内容&#xff0c;其接受值是一个函数&#xff0c;有2个参数&#xff0c;querystring:输入的内容&#xff0c;callback内置函数&…

数据结构C语言版:顺序表基本操作的实现

参考教材&#xff1a;数据结构C语言版&#xff08;严蔚敏&#xff0c;吴伟民编著&#xff09; 目录 线性表的基本操作&#xff1a; 1&#xff1a;线性表L的初始化(参数用引用) 2&#xff1a;销毁线性表L 3&#xff1a;清空线性表L 4&#xff1a;求线性表L的长度 5&#xf…

比亚迪智驾技术震撼登场!L3级自动驾驶领跑全国,无图导航、夜间挑战轻松应对!

作为新能源汽车领域的翘楚&#xff0c;比亚迪在电池技术与智能驾驶方面都有着卓越的表现。近日&#xff0c;比亚迪凭借其领先的智驾技术&#xff0c;成功入选全国首批L3级自动驾驶上路及行驶试点名单&#xff0c;这无疑将推动智驾技术的普及速度。 你知道吗&#xff1f;比亚迪智…

单目标应用:基于三角拓扑聚合优化算法TTAO的微电网优化(MATLAB代码)

一、微电网模型介绍 微电网多目标优化调度模型简介_vmgpqv-CSDN博客 参考文献&#xff1a; [1]李兴莘,张靖,何宇,等.基于改进粒子群算法的微电网多目标优化调度[J].电力科学与工程, 2021, 37(3):7 二、三角拓扑聚合优化算法求解微电网 2.1算法简介 三角拓扑聚合优化算法&…

如何连接达梦数据库?

连接达梦数据库&#xff08;DM Database&#xff09;可以通过多种方式进行&#xff0c;包括使用 JDBC&#xff08;Java Database Connectivity&#xff09;驱动程序&#xff0c;这是最常见的方式之一。以下是使用 Java 通过 JDBC 连接达梦数据库的详细步骤&#xff1a; 1. 准备…

梦想编织者Luna:COZE从童话绘本到乐章的奇妙转化

前言 Coze是什么&#xff1f; Coze扣子是字节跳动发布的一款AI聊天机器人构建平台&#xff0c;能够快速创建、调试和优化AI聊天机器人的应用程序。只要你有想法&#xff0c;无需有编程经验&#xff0c;都可以用扣子快速、低门槛搭建专属于你的 Chatbot&#xff0c;并一键发布…

gbase8s数据库的逻辑日志、物理日志和两种特殊情形的学习

(一) 日志的介绍 1. 日志的类别 数据库日志主要是分为记录日志、逻辑日志和物理日志。 记录日志&#xff1a;记录日志包括了数据库的报错日志、连接日志、sql执行等信息&#xff0c;这些日志不存储在dbspace上&#xff0c;而是保存在操作系统的文件内逻辑日志和物理日志&…

Kali之metasploit学习

目标&#xff1a;尝试使用metasploit制作一个windows 后门&#xff08;exe文件&#xff09; 一&#xff1a;使用metasploit生成一个exe安装包。 二、将对应的可执行文件放入到目标机 python3 -m http.server 端口号&#xff1a; 模块化启动一个端口。 windows 证书管理工具&…

Python(二)---数据类型与变量、以及运算符

文章目录 前言1.Python程序的构成1.1.代码的组织和缩进1.2.使用\行连接符 2.对象和引用、标识符规则2.1.对象2.2.引用2.3.标识符规则 3.变量和简单赋值语句3.1.变量的声明和赋值3.2.删除变量和垃圾回收机制3.3.常量3.4.链式赋值3.5.系列解包赋值 4.最基本内置数据类型4.1.数字和…

使用了代理IP怎么还会被封?代理IP到底有没有效果

代理IP作为一种网络工具&#xff0c;被广泛应用于各种场景&#xff0c;例如网络爬虫、海外购物、规避地区限制等。然而&#xff0c;很多用户在使用代理IP的过程中却发现自己的账号被封禁&#xff0c;这让他们不禁产生疑问&#xff1a;使用了代理IP怎么还会被封&#xff1f;代理…

芯片验证分享8 —— 代码审查2

大家好&#xff0c;我是谷公子&#xff0c;上节课给大家讲了代码审查中的代码正向检查&#xff0c;今天我们来讲代码审查的其他方法。 今天介绍的检查方法有&#xff1a; 代码反向检查 桌面检查 同行评审 可用性验证 这些验证方法可以应用在芯片开发的任何阶段。代码审查…