interLM的Transformer架构,重要模块的实现解析
Decoder架构
class InternLMDecoderLayer(nn.Module):def __init__(self, config: InternLMXComposerConfig):super().__init__()self.hidden_size = config.hidden_sizeif hasattr(config,'intern_converted_llm') and config.intern_converted_llm:self.self_attn = InternConvertedInternLMAttention(config=config)else:self.self_attn = InternLMAttention(config=config)self.mlp = InternLMMLP(hidden_size=self.hidden_size,intermediate_size=config.intermediate_size,hidden_act=config.hidden_act,config=config,)self.input_layernorm = InternLMRMSNorm(config.hidden_size,eps=config.rms_norm_eps)self.post_attention_layernorm = InternLMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)def forward(self,hidden_states: torch.Tensor,attention_mask: Optional[torch.Tensor] = None,position_ids: Optional[torch.LongTensor] = None,past_key_value: Optional[Tuple[torch.Tensor]] = None,output_attentions: Optional[bool] = False,use_cache: Optional[bool] = False,) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor,torch.FloatTensor]]]:"""Args:hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`attention_mask (`torch.FloatTensor`, *optional*): attention mask of size`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.output_attentions (`bool`, *optional*):Whether or not to return the attentions tensors of all attention layers. See `attentions` underreturned tensors for more detail.use_cache (`bool`, *optional*):If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding(see `past_key_values`).past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states"""residual = hidden_stateshidden_states = self.input_layernorm(hidden_states)# Self Attentionhidden_states, self_attn_weights, present_key_value = self.self_attn(hidden_states=hidden_states,attention_mask=attention_mask,position_ids=position_ids,past_key_value=past_key_value,output_attentions=output_attentions,use_cache=use_cache,)hidden_states = residual + hidden_states# Fully Connectedresidual = hidden_stateshidden_states = self.post_attention_layernorm(hidden_states)hidden_states = self.mlp(hidden_states)hidden_states = residual + hidden_statesoutputs = (hidden_states, )if output_attentions:outputs += (self_attn_weights, )if use_cache:outputs += (present_key_value, )return outputs
MLP
- 两个MLP层+一个门控激活函数
class InternLMMLP(nn.Module):def __init__(self, hidden_size: int, intermediate_size: int,hidden_act: str, config: InternLMXComposerConfig):super().__init__()self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)if config.lora_cfg is not None and 'ffn' in config.lora_cfg['learn_param']:lora_cfg = config.lora_cfgself.down_proj = LoRALinear(intermediate_size,hidden_size,bias=False,**lora_cfg)self.up_proj = LoRALinear(hidden_size,intermediate_size,bias=False,**lora_cfg)else:self.down_proj = nn.Linear(intermediate_size,hidden_size,bias=False)self.up_proj = nn.Linear(hidden_size,intermediate_size,bias=False)self.act_fn = ACT2FN[hidden_act]def forward(self, x):return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
CausalAttention Mask
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(input_ids_shape: torch.Size,dtype: torch.dtype,device: torch.device,past_key_values_length: int = 0):"""Make causal mask used for bi-directional self-attention."""# 获取输入的形状,包括批量大小和目标长度bsz, tgt_len = input_ids_shape# 初始化一个形状为(目标长度, 目标长度)的tensor,用极小值填充. 即mask矩阵mask = torch.full((tgt_len, tgt_len),torch.tensor(torch.finfo(dtype).min, device=device),device=device)# 创建一个mask_cond张量,其范围是[0, tgt_len-1]mask_cond = torch.arange(mask.size(-1), device=device)# 根据条件进行填充,下三角为0,上三角为1mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)# 转换mask的数据类型为dtypemask = mask.to(dtype)# 如果过去键值的长度大于0,则将其拼接到mask的前面if past_key_values_length > 0:mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device),mask],dim=-1)# 返回形状为[bsz, 1, tgt_len, tgt_len + past_key_values_length]的maskreturn mask[None, None, :, :].expand(bsz, 1, tgt_len,tgt_len + past_key_values_length)
past_key_values_length
在Transformer中,past_key_values_length是指用于存储前一次计算的注意力键值对(key-value pairs)的长度。Transformer模型在处理较长的序列时,为了提高效率会使用存储,以避免重复计算。
-
当输入序列长度增加时,前一次的键值对会被缓存以供后续的注意力计算使用。这样可以节省计算时间,特别是在生成式任务中,如机器翻译或文本生成。
-
为什么用zeros?
如果past_key_values_length
大于0,即存在过去的键值对需要存储,我们需要将这些过去的键值对所对应的掩码(mask)拼接到当前的掩码中。
在这里,我们首先创建了一个与当前mask形状相同的全零张量,用于表示过去的掩码。然后,通过使用torch.cat
函数将这个全零张量和当前的mask进行拼接,以便将过去的信息与当前的信息合并在一起,形成一个更大的掩码张量。
详细解释一下如何创建CasualMask矩阵
当调用masked_fill_
函数时,我们传入了一个条件(mask_cond < (mask_cond + 1).view(mask.size(-1), 1))和一个填充值(0)。
这个条件 mask_cond < (mask_cond + 1).view(mask.size(-1), 1) 创建了一个下三角为True,上三角为False的条件掩码。
当我们执行 (mask_cond + 1).view(mask.size(-1), 1)
时,我们将 mask_cond
中的每个元素增加 1,并且重新塑造成一个列向量。假设 mask_cond
最初是一个长度为 4 的向量 [0, 1, 2, 3]
,那么在执行 +1
和 view
操作后得到的列向量就是:
[1]
[2]
[3]
[4]
现在,我们比较 mask_cond
和 (mask_cond+1).view(mask.size(-1), 1)
。我们发现,如果 mask_cond
中的元素小于列向量中对应位置的元素,这意味着该位置处于下三角区域。例如,在这个例子中,当我们比较原始向量和列向量时:
[0, 1, 2, 3] < [1]
[1, 2, 3, 4] < [2]
[2, 3, 4, 5] < [3]
[3, 4, 5, 6] < [4]
这将生成一个下三角为 True,上三角为 False 的布尔掩码,可以用于创建Mask。
masked_fill_
函数用条件掩码来填充张量。在这里,如果条件为True,对应位置将被填充为0。这样就实现了对角线以下的元素被填充为0,对角线以上的元素保持不变。
Attention Mask
def _expand_mask(mask: torch.Tensor,dtype: torch.dtype,tgt_len: Optional[int] = None):"""Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`."""bsz, src_len = mask.size()# 如果未提供目标序列长度,默认使用源序列的长度tgt_len = tgt_len if tgt_len is not None else src_len# 对输入的掩码进行扩展expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)# 创建一个反转的掩码inverted_mask = 1.0 - expanded_mask# 使用反转的掩码来填充掩码张量中的元素return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
- 使用反转的掩码来填充掩码张量中的元素的目的是将掩码中原本为0的位置填充为负无穷小。
在注意力计算中,当掩码中某个位置的元素为负无穷小时,经过softmax计算后,该位置对应的注意力权重会趋近于0,即忽略该位置的信息。这样做的目的是,在计算注意力时,我们希望掩码的位置能够有效地抑制相关位置的注意力权重,从而确保模型在处理序列时不会受到未来信息的影响,比如在解码阶段不会看到未来时刻的标记。因此,使用反转的掩码来填充掩码张量中的元素是为了在注意力计算中实现对未来信息的屏蔽。
RoPE
class InternLMRotaryEmbedding(torch.nn.Module):def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):super().__init__()# 计算频率,根据RoPE公式 1.0 / (base **(2 * (i // 2) / dim))inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))self.register_buffer("inv_freq", inv_freq) # 将频率注册为缓冲张量# 构建sin和cos缓存self.max_seq_len_cached = max_position_embeddings# t是位置索引t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)freqs = torch.einsum("i,j->ij", t, self.inv_freq) # 通过张量乘法计算频率emb = torch.cat((freqs, freqs), dim=-1) # 按照最后一个维度拼接sin和cosself.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False) # 将cos缓存注册为缓冲张量self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False) # 将sin缓存注册为缓冲张量def forward(self, x, seq_len=None):# x: [bs, num_attention_heads, seq_len, head_size]# 这个if块不太可能在构建sin/cos后运行。保持逻辑在这里以防万一。if seq_len > self.max_seq_len_cached:self.max_seq_len_cached = seq_lent = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)freqs = torch.einsum("i,j->ij", t, self.inv_freq) # 通过张量乘法计算频率emb = torch.cat((freqs, freqs), dim=-1).to(x.device) # 按照最后一个维度拼接sin和cosself.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False) # 更新注册cos缓存self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False) # 更新注册sin缓存# 返回缓存中的sin和cos张量,截取到指定的序列长度return (self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype))
def rotate_half(x):"""Rotates half the hidden dims of the input."""# 将输入张量沿最后一个维度分成两部分,执行旋转操作x1 = x[..., :x.shape[-1] // 2]x2 = x[..., x.shape[-1] // 2:]# 拼接结果返回return torch.cat((-x2, x1), dim=-1)def apply_rotary_pos_emb(q, k, cos, sin, position_ids):"""Applies rotary positional embeddings to input queries and keys.Args:q: 输入的查询张量k: 输入的键张量cos: cos缓存张量sin: sin缓存张量position_ids: 位置编码张量Returns:q_embed: 应用了旋转位置嵌入后的查询张量k_embed: 应用了旋转位置嵌入后的键张量"""# 根据position_ids创建索引张量gather_indices = position_ids[:, None, :, None] # [bs, 1, seq_len, 1]gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3])# 通过gather_indices选择对应的cos和sin张量cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)# 应用旋转位置嵌入公式得到新的查询张量和键张量q_embed = (q * cos) + (rotate_half(q) * sin)k_embed = (k * cos) + (rotate_half(k) * sin)return q_embed, k_embed
torch.gather
函数的参数包括:
input
:这是输入张量,从这个张量中收集值。dim
:这是一个整数值,表示在input
张量中收集数据的维度。index
:这是包含了索引的张量。根据这些索引,函数将从input
张量中收集对应的值。
基本语法为:torch.gather(input, dim, index)
。
cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
可以分解为以下几个步骤:
-
cos.repeat(gather_indices.shape[0], 1, 1, 1)
: 这一步是将cos
张量沿着每个维度进行复制以匹配gather_indices
的形状。repeat
函数会根据指定的次数沿着各个维度对原始张量进行复制。在这里,它会根据gather_indices.shape[0]
的值在第一个维度上进行复制,而不在其他维度进行复制。 -
torch.gather(repeated_cos, 2, gather_indices)
: 紧接着,我们使用torch.gather
函数根据gather_indices
中指定的索引从repeated_cos
中收集对应的值。对于序列中的每个位置,gather_indices
指定了从repeated_cos
张量中选择哪个值。
torch.gather
操作主要用于根据索引张量从源张量中收集对应的值。通过上述操作,我们能够根据gather_indices
为序列中的每个位置选择正确的cos值,并将其应用于后续的计算中。这是PyTorch中的常见技术,用于根据索引张量从张量中提取值。
LoRA
- 有意思的是,对LoRA做了改动
- 有点残差连接和RoPE的思想糅合到一起的操作
- x += res
- 中间断开,奇偶分开
class ConvertedLoRALinear(nn.Linear):def __init__(self,in_features: int,out_features: int,bias: bool = True,device=None,dtype=None,lora_r=8,lora_alpha=16,lora_dropout=0.05,**kwargs) -> None:super().__init__(in_features, out_features, bias, device, dtype)self.lora_r = lora_rself.lora_alpha = lora_alphaif lora_dropout > 0.:self.lora_dropout = nn.Dropout(p=lora_dropout)else:self.lora_dropout = lambda x: xself.lora_scaling = self.lora_alpha / self.lora_rself.lora_A = nn.Linear(in_features,self.lora_r,bias=False,device=device,dtype=dtype)self.lora_B = nn.Linear(self.lora_r,out_features,bias=False,device=device,dtype=dtype)self.reset_parameters()def reset_parameters(self):if hasattr(self, 'lora_A'):# initialize A the same way as the default for nn.Linear and B to zeronn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))nn.init.zeros_(self.lora_B.weight)# print ("lora weight init {} {}".format(torch.mean(self.lora_A.weight), torch.mean(self.lora_B.weight)))def forward(self, x):orig_type = x.dtyperes = super().forward(x)dim = int(res.shape[-1] // 2)r1 = res[..., :dim]r2 = res[..., dim:]r1 = r1.float()r2 = r2.float()x_ = x.float()tmp = self.lora_B(self.lora_A(self.lora_dropout(x_))) * self.lora_scalingtmp1 = tmp[..., ::2]tmp2 = tmp[..., 1::2]r1 += tmp1r2 += tmp2r1 = r1.to(orig_type)r2 = r2.to(orig_type)res = torch.cat([r1, r2], -1)# res += self.lora_B(self.lora_A(# self.lora_dropout(x))) * self.lora_scalingreturn res
关于生成是模型的Loss计算
outputs = self.model(input_ids=input_ids,attention_mask=attention_mask,position_ids=position_ids,past_key_values=past_key_values,inputs_embeds=inputs_embeds,query_embeds=query_embeds,use_cache=use_cache,output_attentions=output_attentions,output_hidden_states=output_hidden_states,return_dict=return_dict,)hidden_states = outputs[0]logits = self.lm_head(hidden_states)loss = Noneif labels is not None:# Shift so that tokens < n predict nshift_logits = logits[..., :-1, :].contiguous()shift_labels = labels[..., 1:].contiguous()# Flatten the tokensloss_fct = CrossEntropyLoss(reduce=False)loss_reduce = CrossEntropyLoss()shift_logits = shift_logits.view(-1, self.config.vocab_size)shift_labels = shift_labels.view(-1)shift_labels = shift_labels.to(shift_logits.device)###if self.sp_id >= 0:ori_mask = (shift_labels != self.sp_id).float()ori_mask = ori_mask * (shift_labels >= 0).float()local_mask = (shift_labels == self.sp_id).float()else:ori_mask = (shift_labels <self.config.vocab_size - self.ex_size).float()ori_mask = ori_mask * (shift_labels >= 0).float()local_mask = (shift_labels >=self.config.vocab_size - self.ex_size).float()# Enable model parallelismloss = loss_reduce(shift_logits, shift_labels)loss_all = loss_fct(shift_logits, shift_labels)loss_o = (loss_all * ori_mask).sum() / ori_mask.sum()if torch.sum(local_mask) == 0:loss_l = loss_o * 0else:loss_l = (loss_all * local_mask).sum() / local_mask.sum()
代码中loss计算的逐步解释:
1. 首先检查是否有标签(labels),如果有则继续计算loss,否则将loss保持为None。2. 在标签存在的情况下,对logits进行了一个向左的位移,这是因为模型中的输入数据和输出标签之间需要进行一定的位移。即把logits中的每个位置的预测,对应到相应位置期待的标签。3. 之后对logits和labels进行view操作,将其形状改变为2D的张量,以便进行交叉熵损失的计算。4. 根据self.sp_id的不同取值,计算了ori_mask和local_mask。ori_mask为了确保不计算特殊token(sp_id)的loss,local_mask则是用于计算特殊token(sp_id)的loss。5. 调用`CrossEntropyLoss`设置了两个不同的loss,loss_reduce用于在整个批次上计算损失,loss_fct则是用于对每个位置的损失值进行计算。6. 最后,计算了不同的部分的损失。loss_o计算了非特殊token的损失,而loss_l计算了特殊token的损失。如果local_mask全为0,则loss_l为0.
总结:该段代码进行了交叉熵损失的计算,但根据输入token是否为特殊token(sp_id),它分别计算了不同的loss值,即ori_mask用于过滤掉特殊token本身的loss,local_mask用于计算特殊token的loss。
- 这个loss的计算实际上是基于给定的vocabulary的多分类交叉熵损失。
在语言模型中,通常需要将模型的输出与词汇表中的token进行比较,以根据模型的预测计算损失。因此,将模型输出的logits与标签进行比较,并计算交叉熵损失,这通常用于语言模型中的训练过程。
如何修改Transformer模块
模仿GPT2,改为文本二分类任务
@add_start_docstrings("""The InternLM Model transformer with a sequence classification head on top (linear layer).[`InternLMForSequenceClassification`] uses the last token in order to do the classification, as other causal models(e.g. GPT-2) do.Since it does classification on the last token, it requires to know the position of the last token. If a`pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. Ifno `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess thepadding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value ineach row of the batch).""",INTERNLM_START_DOCSTRING,
)
class InternLMForSequenceClassification(InternLMPreTrainedModel):_keys_to_ignore_on_load_missing = [r"lm_head.weight"]def __init__(self, config):super().__init__(config)self.num_labels = config.num_labelsself.model = InternLMModel(config)self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)# Initialize weights and apply final processingself.post_init()def get_input_embeddings(self):return self.model.embed_tokensdef set_input_embeddings(self, value):self.model.embed_tokens = value@add_start_docstrings_to_model_forward(INTERNLM_INPUTS_DOCSTRING)def forward(self,input_ids: torch.LongTensor = None,attention_mask: Optional[torch.Tensor] = None,position_ids: Optional[torch.LongTensor] = None,past_key_values: Optional[List[torch.FloatTensor]] = None,inputs_embeds: Optional[torch.FloatTensor] = None,labels: Optional[torch.LongTensor] = None,use_cache: Optional[bool] = None,output_attentions: Optional[bool] = None,output_hidden_states: Optional[bool] = None,return_dict: Optional[bool] = None,) -> Union[Tuple, SequenceClassifierOutputWithPast]:r"""labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If`config.num_labels > 1` a classification loss is computed (Cross-Entropy)."""return_dict = return_dict if return_dict is not None else self.config.use_return_dicttransformer_outputs = self.model(input_ids,attention_mask=attention_mask,position_ids=position_ids,past_key_values=past_key_values,inputs_embeds=inputs_embeds,use_cache=use_cache,output_attentions=output_attentions,output_hidden_states=output_hidden_states,return_dict=return_dict,)hidden_states = transformer_outputs[0]logits = self.score(hidden_states)if input_ids is not None:batch_size = input_ids.shape[0]else:batch_size = inputs_embeds.shape[0]if self.config.pad_token_id is None and batch_size != 1:raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")if self.config.pad_token_id is None:sequence_lengths = -1else:if input_ids is not None:sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)else:sequence_lengths = -1pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]loss = Noneif labels is not None:labels = labels.to(logits.device)if self.config.problem_type is None:if self.num_labels == 1:self.config.problem_type = "regression"elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):self.config.problem_type = "single_label_classification"else:self.config.problem_type = "multi_label_classification"if self.config.problem_type == "regression":loss_fct = MSELoss()if self.num_labels == 1:loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())else:loss = loss_fct(pooled_logits, labels)elif self.config.problem_type == "single_label_classification":loss_fct = CrossEntropyLoss()loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))elif self.config.problem_type == "multi_label_classification":loss_fct = BCEWithLogitsLoss()loss = loss_fct(pooled_logits, labels)if not return_dict:output = (pooled_logits,) + transformer_outputs[1:]return ((loss,) + output) if loss is not None else outputreturn SequenceClassifierOutputWithPast(loss=loss,logits=pooled_logits,past_key_values=transformer_outputs.past_key_values,hidden_states=transformer_outputs.hidden_states,attentions=transformer_outputs.attentions,)