BertModel源码解析
- 1. BertModel 介绍
- 2. BertModel 源码逐行注释
1. BertModel 介绍
BertModel 是 transformers 库中的核心模型之一,它实现了 BERT(Bidirectional Encoder Representations from Transformers)模型的架构。BERT 是基于 Transformer 编码器的堆叠模块来构建的。以下是 BertModel 内部包含的主要模块和组件的详细介绍:
- BertEmbeddings (BertEmbeddings 源码解析)
将词嵌入(Token Embeddings) 、位置嵌入(Position Embeddings) 和 标记类型嵌入(Segment Embeddings) 组合起来,为每个输入token生成最终的嵌入表示
- BertEncoder (BertEncoder 源码解析)
BERT 模型的核心部分,包含了多个堆叠的 Transformer 编码器层(Layer)。每一层都是一个自注意力机制与前馈神经网络的组合。即:
--------- Self-Attention Heads (BertSelfAttention)
--------- Feedforward Neural Network (BertIntermediate & BertOutput)
- BertPooler
负责将编码器的输出转化为单一的全局表示。
通常使用第一个 token([CLS])的表示作为整个序列的表示,并通过一个线性层加上 tanh 激活函数生成最终的句子向量。
这个句子向量可以用于分类或其他需要整体序列表示的任务。
BertModel 是由多个模块组合而成的复杂架构,这些模块协同工作,共同实现了强大的文本表示能力。通过这些模块,BertModel 能够捕捉句子中深层次的语义信息,并应用于广泛的 NLP 任务。
2. BertModel 源码逐行注释
源码地址:transformers/src/transformers/models/bert/modeling_bert.py
# -*- coding: utf-8 -*-
# @time: 2024/7/11 10:43"""PyTorch BERT model."""
import torchfrom typing import List, Optional, Tuple, Union
from transformers import BertPreTrainedModel
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask_for_sdpa, _prepare_4d_attention_mask_for_sdpa
from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions
from transformers.models.bert.modeling_bert import BertEmbeddings, BertEncoder, BertPooler
from transformers.utils import add_start_docstrings_to_model_forward, add_code_sample_docstrings_CHECKPOINT_FOR_DOC = "google-bert/bert-base-uncased"
_CONFIG_FOR_DOC = "BertConfig"BERT_INPUTS_DOCSTRING = r"""Args:input_ids (`torch.LongTensor` of shape `({0})`):Indices of input sequence tokens in the vocabulary.Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and[`PreTrainedTokenizer.__call__`] for details.[What are input IDs?](../glossary#input-ids)attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:- 1 for tokens that are **not masked**,- 0 for tokens that are **masked**.[What are attention masks?](../glossary#attention-mask)token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,1]`:- 0 corresponds to a *sentence A* token,- 1 corresponds to a *sentence B* token.[What are token type IDs?](../glossary#token-type-ids)position_ids (`torch.LongTensor` of shape `({0})`, *optional*):Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,config.max_position_embeddings - 1]`.[What are position IDs?](../glossary#position-ids)head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:- 1 indicates the head is **not masked**,- 0 indicates the head is **masked**.inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. Thisis useful if you want more control over how to convert `input_ids` indices into associated vectors than themodel's internal embedding lookup matrix.output_attentions (`bool`, *optional*):Whether or not to return the attentions tensors of all attention layers. See `attentions` under returnedtensors for more detail.output_hidden_states (`bool`, *optional*):Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors formore detail.return_dict (`bool`, *optional*):Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""class BertModel(BertPreTrainedModel):"""模型可以作为编码器(仅使用自注意力)或解码器。在作为解码器时,自注意力层之间会添加一层交叉注意力层,这遵循 Ashish Vaswani、Noam Shazeer、Niki Parmar、Jakob Uszkoreit、Llion Jones、Aidan N. Gomez、Lukasz Kaiser 和 Illia Polosukhin 所描述的 [Attention is all you need](https://arxiv.org/abs/1706.03762) 架构。要使模型作为解码器,需要在初始化时将配置中的 `is_decoder` 参数设置为 `True`。要在 Seq2Seq 模型中使用,需要将 `is_decoder` 和 `add_cross_attention` 参数都设置为 `True`;此时在前向传播中需要提供 `encoder_hidden_states` 作为输入。"""def __init(self, config, add_pooling_layer=True):super().__init__(config)self.config = config # 保存传入的配置对象self.embeddings = BertEmbeddings(config) # 初始化 BERT 的嵌入层self.encoder = BertEncoder(config) # 初始化 BERT 的编码器层self.pooler = BertPooler(config) if add_pooling_layer else None # 如果 add_pooling_layer 为 True,则初始化池化层self.attn_implementation = config._attn_implementation # 保存注意力机制的实现细节self.position_embedding_type = config.position_embedding_type # 保存位置嵌入的类型self.post_init() # 执行一些初始化后的操作def get_input_embeddingss(self):return self.embeddings.word_embeddingsdef set_input_embeddings(self, value):self.embeddings.word_embeddings = valuedef _prune_heads(self, heads_to_prune):"""修剪模型的注意力头。heads_to_prune: 是一个字典,包含 {layer_num: 该层中要修剪的头的列表}。详见基类 PreTrainedModel。"""for layer, heads in heads_to_prune.items():self.encoder.layer[layer].attention.prune_heads(heads)# 为模型的前向传递函数添加文档字符串和代码示例"""add_start_docstrings_to_model_forward装饰器会将 BERT_INPUTS_DOCSTRING 中的文档字符串添加到模型前向传递函数的开头部分。BERT_INPUTS_DOCSTRING 是一个格式化字符串,其中包含有关输入张量形状的信息。在这里,它被格式化为 "batch_size, sequence_length",描述了输入的批量大小和序列长度。add_code_sample_docstrings装饰器会为模型的前向传递函数添加代码示例文档字符串。checkpoint 参数指定了用于文档的检查点名称。output_type 参数指定了模型前向传递输出的类型,这里是 BaseModelOutputWithPoolingAndCrossAttentions。config_class 参数指定了模型配置的类,这里是 _CONFIG_FOR_DOC。"""@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))@add_code_sample_docstrings(checkpoint=_CHECKPOINT_FOR_DOC,output_type=BaseModelOutputWithPoolingAndCrossAttentions,config_class=_CONFIG_FOR_DOC,)def forward(self,input_ids: Optional[torch.Tensor] = None,attention_mask: Optional[torch.Tensor] = None,token_type_ids: Optional[torch.Tensor] = None,position_ids: Optional[torch.Tensor] = None,head_mask: Optional[torch.Tensor] = None,inputs_embeds: Optional[torch.Tensor] = None,encoder_hidden_states: Optional[torch.Tensor] = None,encoder_attention_mask: Optional[torch.Tensor] = None,past_key_values: Optional[List[torch.FloatTensor]] = None,use_cache: Optional[bool] = None,output_attentions: Optional[bool] = None,output_hidden_states: Optional[bool] = None,return_dict: Optional[bool] = None,) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:r"""encoder_hidden_states (`torch.FloatTensor`,形状为 `(batch_size, sequence_length, hidden_size)`,*可选*):编码器最后一层输出的隐藏状态序列。如果模型配置为解码器,则在交叉注意力中使用。encoder_attention_mask (`torch.FloatTensor`,形状为 `(batch_size, sequence_length)`,*可选*):用于避免对编码器输入中的填充标记索引执行注意力操作的掩码。如果模型配置为解码器,则在交叉注意力中使用。掩码值选择 `[0, 1]`:- 1 表示**未被掩码**的标记,- 0 表示**被掩码**的标记。past_key_values (`tuple(tuple(torch.FloatTensor))`,长度为 `config.n_layers`,每个元组有4个形状为 `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)` 的张量):包含注意力块的预计算键和值的隐藏状态。可用于加速解码。如果使用 `past_key_values`,用户可以选择仅输入形状为 `(batch_size, 1)` 的最后一个 `decoder_input_ids`(那些没有其过去键值状态的输入)而不是形状为 `(batch_size, sequence_length)` 的所有 `decoder_input_ids`。use_cache (`bool`,*可选*):如果设置为 `True`,则返回 `past_key_values` 键值状态,并可用于加速解码(参见 `past_key_values`)。"""# ------------------------------1. 关于参数的配置---------------------------"""最后得到的参数有:output_attentions(是否返回所有注意力层的注意力张量), output_hidden_states(是否返回所有层的隐藏状态), return_dict(是否返回ModelOutput而不是普通元组), use_cache(如果设置为True, past_key_values则返回键值状态并可用于加快解码速度), batch_size, seq_length, device, past_key_values_length(包含注意力块的预计算键和值隐藏状态, 可用于加速解码), token_type_ids,"""# 如果 output_attentions 不为 None,则使用其值,否则使用配置中的默认值 self.config.output_attentionsoutput_attentions = output_attentions if output_attentions is not None else self.config.output_attentions# 如果 output_hidden_states 不为 None,则使用其值,否则使用配置中的默认值 self.config.output_hidden_statesoutput_hidden_states = (output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states)# 如果 return_dict 不为 None,则使用其值,否则使用配置中的默认值 self.config.use_return_dictreturn_dict = return_dict if return_dict is not None else self.config.use_return_dict# 如果模型配置为解码器,设置 use_cache 参数if self.config.is_decoder:use_cache = use_cache if use_cache is not None else self.config.use_cacheelse:use_cache = False # 如果不是解码器,强制将 use_cache 设置为 False# 检查 input_ids 和 inputs_embeds,不能同时指定if input_ids is not None and inputs_embeds is not None:raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")elif input_ids is not None:# 如果指定了 input_ids,检查填充和注意力掩码self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)input_shape = input_ids.size() # 获取 input_ids 的形状elif inputs_embeds is not None:input_shape = inputs_embeds.size()[:-1] # 获取 inputs_embeds 的形状(除去最后一维)else:raise ValueError("You have to specify either input_ids or inputs_embeds")# 从 input_shape 中获取 batch_size 和 seq_lengthbatch_size, seq_length = input_shape# 如果 input_ids 不为 None,则设备为 input_ids 的设备. 否则,设备为 inputs_embeds 的设备device = input_ids.device if input_ids is not None else inputs_embeds.device# 如果 past_key_values 不为 None,则获取 past_key_values 中第一个元素的形状的第三维长度作为 past_key_values_length. 否则,将 past_key_values_length 设置为 0past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0if token_type_ids is None:# 如果模型的嵌入层有 token_type_ids 属性if hasattr(self.embeddings, "token_type_ids"):# 从嵌入层的 token_type_ids 中获取前 seq_length 的部分buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]# 扩展 token_type_ids 以匹配 batch_size 和 seq_lengthbuffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)# 将扩展后的 token_type_ids 赋值给 token_type_idstoken_type_ids = buffered_token_type_ids_expandedelse:# 如果嵌入层没有 token_type_ids 属性,则创建一个全零的 token_type_idstoken_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)# ------------------------------ *2. 输入嵌入层(Input Embeddings)---------------------------# 计算嵌入层的输出embedding_output = self.embeddings(input_ids=input_ids,position_ids=position_ids,token_type_ids=token_type_ids,inputs_embeds=inputs_embeds,past_key_values_length=past_key_values_length,)# ------------------------------3. 注意力掩码的配置------------------------------------------if attention_mask is None:attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=device)# 判断是否使用 SDPA 注意力掩码的条件use_sdpa_attention_masks = (self.attn_implementation == "sdpa" # 判断注意力实现是否为 SDPAand self.position_embedding_type == "absolute" # 判断位置嵌入类型是否为绝对位置and head_mask is None # 判断是否没有指定 head_maskand not output_attentions # 判断是否不需要输出注意力)# 根据条件 use_sdpa_attention_masks 进行扩展注意力掩码if use_sdpa_attention_masks:# 为 SDPA 扩展注意力掩码# [batch_size, seq_length] -> [batch_size, 1, seq_length, seq_length]if self.config.is_decoder:# 如果是解码器,准备 4D 因果注意力掩码: 这种掩码确保解码器在生成下一个词时只能看到当前词及其之前的词,而不能看到未来的词。"""attention_mask:输入的注意力掩码,通常是一个二维张量,表示每个词的位置是否应该被注意力机制关注。input_shape:输入的形状,通常是 (batch_size, seq_length)。embedding_output:嵌入层的输出,包含输入词的嵌入表示。past_key_values_length:过去键值的长度,用于支持缓存机制(如在解码器中)。"""extended_attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(attention_mask,input_shape,embedding_output,past_key_values_length,)else:# 如果不是解码器,准备 4D 注意力掩码extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, embedding_output.dtype, tgt_len=seq_length)else:# 提供一个维度为 [batch_size, from_seq_length, to_seq_length] 的自注意力掩码# 只需使其可广播到所有注意力头extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)# ================================ 获得了extended_attention_mask: [batch_size, 1, seq_length, seq_length]=======================================# 如果为交叉注意力提供了 2D 或 3D 注意力掩码# 需要使其可广播到 [batch_size, num_heads, seq_length, seq_length]if self.config.is_decoder and encoder_hidden_states is not None: # 解码器的配置# 获取编码器隐藏状态的形状encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)# 如果没有提供编码器注意力掩码,创建一个全为 1 的编码器注意力掩码if encoder_attention_mask is None:encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)if use_sdpa_attention_masks:# 为 SDPA 扩展注意力掩码# [batch_size, seq_length] -> [batch_size, 1, seq_length, seq_length]encoder_extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(encoder_attention_mask, embedding_output.dtype, tgt_len=seq_length)else:# 否则,反转编码器注意力掩码encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)else:# 如果不需要编码器扩展注意力掩码,设置为 Noneencoder_extended_attention_mask = None# ================================ 获得了encoder_extended_attention_mask: [batch_size, 1, seq_length, seq_length]或 None====================================# 准备注意力头掩码(如果需要)# 1.0 表示保留该注意力头# attention_probs 的形状为 bsz x n_heads x N x N# 输入的 head_mask 的形状为 [num_heads] 或 [num_hidden_layers x num_heads]# head_mask 被转换为形状 [num_hidden_layers x batch x num_heads x seq_length x seq_length]head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)# ================================ 获得了head_mask, 一般情况下是None ====================================# ------------------------------ *4. 编码器层(Encoder Layers)-----------------------------# 传递输入到编码器,并获取编码器的输出encoder_outputs = self.encoder(embedding_output, # 嵌入层的输出attention_mask=extended_attention_mask, # 扩展的注意力掩码head_mask=head_mask, # 注意力头掩码,一般Noneencoder_hidden_states=encoder_hidden_states, # 编码器的隐藏状态,一般Noneencoder_attention_mask=encoder_extended_attention_mask, # 编码器的注意力掩码past_key_values=past_key_values, # 过去的键值对use_cache=use_cache, # 是否使用缓存output_attentions=output_attentions, # 是否输出注意力output_hidden_states=output_hidden_states, # 是否输出隐藏状态return_dict=return_dict, # 是否返回字典)# 从编码器的输出中获取序列输出,这里的encoder_outputs[0]值其实就是last_hidden_statesequence_output = encoder_outputs[0]# ------------------------------ *5. 池化层(Pooling Layer)--------------------------------# 如果存在池化层,则对序列输出进行池化,池化就是加了一层线性变换和tanh激活函数pooled_output = self.pooler(sequence_output) if self.pooler is not None else None# ------------------------------6. 返回输出结果--------------------------------------------# 如果 return_dict 为 False,返回一个元组,其中包含序列输出、池化输出和编码器输出中的其他部分if not return_dict:return (sequence_output, pooled_output) + encoder_outputs[1:]# 如果 return_dict 为 True,返回一个 BaseModelOutputWithPoolingAndCrossAttentions 对象return BaseModelOutputWithPoolingAndCrossAttentions(last_hidden_state=sequence_output, # 最后隐藏状态pooler_output=pooled_output, # 池化输出past_key_values=encoder_outputs.past_key_values, # 过去的键值对hidden_states=encoder_outputs.hidden_states, # 隐藏状态attentions=encoder_outputs.attentions, # 注意力cross_attentions=encoder_outputs.cross_attentions, # 交叉注意力)