无论是工作还是上学,使用Bert、RoBERTa或者Bert系列模型作为基底,在此基础之上构建神经网络分支进行微调,是非常简单、常见的一种任务实现方式。既然是基于别人的工作进行展开,那就有必要了解Bert系列模型的输出结果和结构。(不同版本的torch、tensorflow在Bert输出结构上可能略有不同)。
Bert模型实现的框架中,输出主要有3个:last_hidden_state、pooler_output 以及hidden_state。注:当然还有其他的一些输出,比如attentions、cross_attentions等,但是这些输出结果一般不用于任务)
1. last_hidden_state
这是模型的最后一层隐藏层状态,本身是一个三维张量,其形状为 [batch_size, sequence_length, hidden_size],它是模型的主要输出。对于每一个输入的token,last_hidden_state都会输出一个向量,形状为[1, hidden_size],也就是最后一维。
绝大多数任务都可以直接基于这个张量进行训练。
举一个实体匹配的例子:有两个实体text1和text2,输入Bert的形式一般为[CLS,text1,SEP,text2,SEP],其中 CLS 和 SEP 为特殊字符。Bert模型输出last_hidden_state,在这里有几种很直观的利用方式。
(1)不做任何处理直接将last_hidden_state,过Layer_norm、全连接等结构,最终经过sigmoid 或者 softmax 完成不同的任务。
(2)基于last_hidden_state,做cross_attention,然后经过基础网络层输出结果
(3)将text1和text2对应的token抽取并使用avgpool1()求平均(得到两个[1, hidden_size]的张量),这两个张量与后面描述的CLS特征向量做拼接。基于这个拼接后的结果做后续的操作(基础网络层)。
实验结果表明:(2)和(3)的效果更好一点。
2. pooler_output
它是Bert模型的输出中最后一层隐藏状态(last_hidden_state)的第一个token(通常是[CLS]标记)经过一个线性层和一个tanh激活函数得到的,形状为[batch, hidden_size]
我们通常将pooler_output理解为整个句子 or 整个输入的总结,因此这个输出通常会被用作句子级别的任务,比如句子分类。
在实际使用过程中,将其与last_hidden_state或者其他隐藏层向量进行拼接,是一种效果不错的用法。
3. hidden_states
这是模型的所有隐藏层的输出,它包含了模型的所有中间层的信息,形状为[13, batch_size, sequence_length, hidden_size](13 = 1层embedding layer 和 12层transfomer block)。
hidden_states的使用主要在于微调阶段,根据任务类型以及数据量的大小,可以手动控制哪一层的权重需要冻结,哪一层的权重需要重新训练/微调。
if config.freeze_layer_count: # 个人自定义的参数,控制是否需要启用参数冻结logger.info(f"frozen layers count of {str(config.freeze_layer_count)}")# We freeze here the embeddings of the model# for param in qm_sim_model.bert.embeddings.parameters(): # embedding层# param.requires_grad = Falseif config.freeze_layer_count != -1:# if freeze_layer_count == -1, we only freeze the embedding layer# otherwise we freeze the first `freeze_layer_count` encoder layersfor layer in qm_sim_model.bert.encoder.layer[:config.freeze_layer_count]:for param in layer.parameters():param.requires_grad = False# param.requires_grad = False,意味着不会计算该参数的梯度,也就是不会更新这个参数
根据需要,微调部分transformer block的参数
参考资料:
NLP系列(3)文本分类(Bert+TextCNN)pytorch - 知乎
https://zhuanlan.zhihu.com/p/654396722
Bert模型输出:last_hidden_state转换为pooler_output_获取bert中间层的输出-CSDN博客