whisper官方源码
whisper 模型官方代码:https://github.com/openai/whisper/blob/main/whisper/model.py ;注释如下
import base64
import gzip
from dataclasses import dataclass
from typing import Dict, Iterable, Optionalimport numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor, nn# 从其他模块导入必要的函数
from .decoding import decode as decode_function
from .decoding import detect_language as detect_language_function
from .transcribe import transcribe as transcribe_function@dataclass
class ModelDimensions:"""该类用于存储模型的各项参数"""n_mels: int # Mel谱图的频带数量n_audio_ctx: int # 音频上下文窗口大小n_audio_state: int # 音频状态维度n_audio_head: int # 音频注意力头数量n_audio_layer: int # 音频层数量n_vocab: int # 词汇表大小n_text_ctx: int # 文本上下文窗口大小n_text_state: int # 文本状态维度n_text_head: int # 文本注意力头数量n_text_layer: int # 文本层数量class LayerNorm(nn.LayerNorm):def forward(self, x: Tensor) -> Tensor:"""重写 forward 方法,确保输入张量的类型在归一化前后保持一致"""return super().forward(x.float()).type(x.dtype)class Linear(nn.Linear):def forward(self, x: Tensor) -> Tensor:"""重写 forward 方法,确保权重和偏置与输入张量的类型一致"""return F.linear(x,self.weight.to(x.dtype),None if self.bias is None else self.bias.to(x.dtype),)class Conv1d(nn.Conv1d):def _conv_forward(self, x: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor:"""重写 _conv_forward 方法,确保卷积操作中的权重和偏置与输入张量的类型一致"""return super()._conv_forward(x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype))def sinusoids(length, channels, max_timescale=10000):"""生成用于位置嵌入的正弦曲线"""assert channels % 2 == 0log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)class MultiHeadAttention(nn.Module):def __init__(self, n_state: int, n_head: int):"""初始化多头注意力层"""super().__init__()self.n_head = n_headself.query = Linear(n_state, n_state)self.key = Linear(n_state, n_state, bias=False)self.value = Linear(n_state, n_state)self.out = Linear(n_state, n_state)def forward(self,x: Tensor,xa: Optional[Tensor] = None,mask: Optional[Tensor] = None,kv_cache: Optional[dict] = None,):"""多头注意力的前向传播"""q = self.query(x)if kv_cache is None or xa is None or self.key not in kv_cache:# 如果没有缓存键和值,则正常计算k = self.key(x if xa is None else xa)v = self.value(x if xa is None else xa)else:# 如果有缓存,则使用缓存的键和值k = kv_cache[self.key]v = kv_cache[self.value]wv, qk = self.qkv_attention(q, k, v, mask)return self.out(wv), qkdef qkv_attention(self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None):"""计算 QKV 注意力"""n_batch, n_ctx, n_state = q.shapescale = (n_state // self.n_head) ** -0.25q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scalek = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scalev = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)qk = q @ kif mask is not None:qk = qk + mask[:n_ctx, :n_ctx]qk = qk.float()w = F.softmax(qk, dim=-1).to(q.dtype)return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()class ResidualAttentionBlock(nn.Module):def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):"""初始化残差注意力块"""super().__init__()self.attn = MultiHeadAttention(n_state, n_head)self.attn_ln = LayerNorm(n_state)self.cross_attn = (MultiHeadAttention(n_state, n_head) if cross_attention else None)self.cross_attn_ln = LayerNorm(n_state) if cross_attention else Nonen_mlp = n_state * 4self.mlp = nn.Sequential(Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state))self.mlp_ln = LayerNorm(n_state)def forward(self,x: Tensor,xa: Optional[Tensor] = None,mask: Optional[Tensor] = None,kv_cache: Optional[dict] = None,):"""残差注意力块的前向传播"""x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0]if self.cross_attn:x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0]x = x + self.mlp(self.mlp_ln(x))return xclass AudioEncoder(nn.Module):def __init__(self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):"""初始化音频编码器"""super().__init__()self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList([ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)])self.ln_post = LayerNorm(n_state)def forward(self, x: Tensor):"""前向传播,处理音频输入x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)音频的Mel谱图"""x = F.gelu(self.conv1(x))x = F.gelu(self.conv2(x))x = x.permute(0, 2, 1)assert x.shape[1:] == self.positional_embedding.shape, "音频形状不正确"x = (x + self.positional_embedding).to(x.dtype)for block in self.blocks:x = block(x)x = self.ln_post(x)return xclass TextDecoder(nn.Module):def __init__(self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):"""初始化文本解码器"""super().__init__()self.token_embedding = nn.Embedding(n_vocab, n_state)self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList([ResidualAttentionBlock(n_state, n_head, cross_attention=True)for _ in range(n_layer)])self.ln = LayerNorm(n_state)mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)self.register_buffer("mask", mask, persistent=False)def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):"""前向传播,处理文本输入并结合音频特征x : torch.LongTensor, shape = (batch_size, <= n_ctx)文本的标记序列xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)编码后的音频特征"""offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0x = (self.token_embedding(x)+ self.positional_embedding[offset : offset + x.shape[-1]])x = x.to(xa.dtype)for block in self.blocks:x = block(x, xa, mask=self.mask, kv_cache=kv_cache)x = self.ln(x)logits = (x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)).float()return logitsclass Whisper(nn.Module):def __init__(self, dims: ModelDimensions):"""初始化 Whisper 模型"""super().__init__()self.dims = dimsself.encoder = AudioEncoder(self.dims.n_mels,self.dims.n_audio_ctx,self.dims.n_audio_state,self.dims.n_audio_head,self.dims.n_audio_layer,)self.decoder = TextDecoder(self.dims.n_vocab,self.dims.n_text_ctx,self.dims.n_text_state,self.dims.n_text_head,self.dims.n_text_layer,)# 默认情况下,使用解码器层的后一半进行时间对齐;# 若要使用特定的注意力头,可以使用 `set_alignment_heads()` 方法。all_heads = torch.zeros(self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool)all_heads[self.dims.n_text_layer // 2 :] = Trueself.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)def set_alignment_heads(self, dump: bytes):"""设置对齐的注意力头"""array = np.frombuffer(gzip.decompress(base64.b85decode(dump)), dtype=bool).copy()mask = torch.from_numpy(array).reshape(self.dims.n_text_layer, self.dims.n_text_head)self.register_buffer("alignment_heads", mask.to_sparse(), persistent=False)def embed_audio(self, mel: torch.Tensor):"""编码音频特征"""return self.encoder(mel)def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):"""获取预测的logits"""return self.decoder(tokens, audio_features)def forward(self, mel: torch.Tensor, tokens: torch.Tensor) -> Dict[str, torch.Tensor]:"""前向传播"""return self.decoder(tokens, self.encoder(mel))@propertydef device(self):"""获取模型所在的设备"""return next(self.parameters()).device@propertydef is_multilingual(self):"""判断模型是否支持多语言"""return self.dims.n_vocab >= 51865@propertydef num_languages(self):"""获取模型支持的语言数量"""return self.dims.n_vocab - 51765 - int(self.is_multilingual)def install_kv_cache_hooks(self, cache: Optional[dict] = None):"""为键和值的投影模块安装缓存钩子返回-------cache : Dict[nn.Module, torch.Tensor]映射键/值投影模块到其缓存的字典对象hooks : List[RemovableHandle]用于停止调用钩子的 PyTorch RemovableHandle 对象列表"""cache = {**cache} if cache is not None else {}hooks = []def save_to_cache(module, _, output):if module not in cache or output.shape[1] > self.dims.n_text_ctx:# 第一次标记或交叉注意时保存原始值cache[module] = outputelse:cache[module] = torch.cat([cache[module], output], dim=1).detach()return cache[module]def install_hooks(layer: nn.Module):if isinstance(layer, MultiHeadAttention):hooks.append(layer.key.register_forward_hook(save_to_cache))hooks.append(layer.value.register_forward_hook(save_to_cache))self.decoder.apply(install_hooks)return cache, hooksdetect_language = detect_language_function # 语言检测函数transcribe = transcribe_function # 转录函数decode = decode_function # 解码函数
语音识别自回归解码过程分析和举例说明
分析
语音识别自回归解码过程通常涉及以下步骤:
-
音频预处理:首先将输入的音频信号转换为Mel谱图。这一步骤在实际应用中通常由音频前端处理模块完成。
-
音频编码:将预处理后的Mel谱图输入到音频编码器中,生成音频特征表示。这些特征表示将作为后续文本解码器的输入。
-
文本解码:文本解码器通过自回归方式生成文本序列。具体来说,文本解码器在每个时间步上根据前一步生成的文本标记以及音频特征生成下一个文本标记。
-
语言检测和转录:在生成的文本序列基础上,可以进行语言检测,确认文本所使用的语言。此外,转录过程将生成的文本序列转换为最终的文本输出。
具体步骤
以下代码展示了上述过程的具体实现:
import torch# 初始化模型参数
dims = ModelDimensions(n_mels=80,n_audio_ctx=1500,n_audio_state=512,n_audio_head=8,n_audio_layer=6,n_vocab=51865,n_text_ctx=448,n_text_state=512,n_text_head=8,n_text_layer=6,
)# 创建模型实例
model = Whisper(dims)# 假设我们有一个Mel谱图输入
mel_spectrogram = torch.randn(1, 80, 1500) # (batch_size, n_mels, n_audio_ctx)# 编码音频特征
audio_features = model.embed_audio(mel_spectrogram)# 假设我们有一个初始的文本标记序列
initial_tokens = torch.tensor([[1, 2, 3]]) # (batch_size, seq_len)# 自回归解码过程
for _ in range(10): # 假设生成长度为10的序列logits = model.logits(initial_tokens, audio_features)next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)initial_tokens = torch.cat([initial_tokens, next_token], dim=-1)# 最终生成的文本标记序列
final_tokens = initial_tokens# 打印生成的文本标记序列
print("Generated tokens:", final_tokens)
举例说明
假设我们有一段音频,其Mel谱图表示如下:
mel_spectrogram = torch.randn(1, 80, 1500)
我们希望通过自回归解码生成对应的文本表示。首先,我们将Mel谱图输入到音频编码器中,得到音频特征表示:
audio_features = model.embed_audio(mel_spectrogram)
然后,我们使用一个初始的文本标记序列(例如,序列开始标记)开始自回归解码过程:
initial_tokens = torch.tensor([[1]]) # 序列开始标记
在每个时间步,我们根据当前的文本标记序列和音频特征生成下一个文本标记:
logits = model.logits(initial_tokens, audio_features)
next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
initial_tokens = torch.cat([initial_tokens, next_token], dim=-1)
这个过程重复若干次(例如10次)直到生成完整的文本序列:
for _ in range(10):logits = model.logits(initial_tokens, audio_features)next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)initial_tokens = torch.cat([initial_tokens, next_token], dim=-1)
最终得到的文本标记序列为:
final_tokens = initial_tokens
print("Generated tokens:", final_tokens)
以上示例展示了从音频输入到文本输出的完整自回归解码过程。