OCR经典神经网络(三)LayoutLM v2算法原理及其在发票数据集上的应用(NER及RE)
- LayoutLM系列模型是微软发布的、文档理解多模态基础模型领域最重要和有代表性的工作:
- LayoutLM v2:在一个单一的多模态框架中对文本(text)、布局(layout)和图像(image)之间的交互进行建模。
- LayoutXLM:LayoutXLM是 LayoutLMv2的多语言扩展版本。
- LayoutLM v3:借鉴了ViLT和BEIT,不需要经过预训练的视觉backbone,通过MLM、MIM和WPA进行预训练的多模态Transformer。在以视觉为中心的任务上(如文档图像分类和文档布局分析)和以文本为中心的任务上(表单理解、收据理解、文档问答)都表现很好。
- 今天,我们来了解下LayoutLM v2模型。
- 论文链接:https://arxiv.org/pdf/2012.14740
- 同样,百度开源的paddleocr中,在
关键信息抽取
中集成了此算法。 - paddleocr中集成的算法列表:https://github.com/PaddlePaddle/PaddleOCR/blob/main/docs/algorithm/overview.md
1 LayoutLM v2算法原理
- LayoutLM v2是一种多模态Transformer模型,该模型在预训练阶段整合了文档文本、版式及视觉信息,实现了在一个框架内端到端地学习跨模态交互。同时,将一种空间感知的自注意力机制融入到了Transformer架构中。
- 除了掩码视觉语言模型(MVLM)预训练策略外,LayoutLM v2还新增了文本-图像对齐(TIA)和文本-图像匹配(TIM) 作为预训练策略,以强化不同模态间的对齐。
- LayoutLMv2不仅在传统的富视觉文档理解(VrDU)任务上取得了显著的性能提升并达到当时新的最优水平,还在文档图像的视觉问题回答(VQA)任务上实现了新突破,这证明了多模态预训练在富视觉文档理解领域的巨大潜力。
1.1 模型结构
-
模型结构如下图所示,可以看到LayoutLM v2接收文本、视觉及版式信息作为输入,以建立深度的跨模态交互。另外,将spatial-aware的自注意力机制整合到了transformer中。
-
这里,我们主要看下Embedding层:
-
文本嵌入
- 文本嵌入包含三种嵌入:词嵌入代表词本身,一维位置嵌入表示词的位置索引,而段落嵌入用于区分不同的文本段落。
t i = T o k E m b ( w i ) + P o s E m b 1 D ( i ) + S e g E m b ( s i ) t_i= TokEmb(w_i)+PosEmb1D(i)+SegEmb(s_i) ti=TokEmb(wi)+PosEmb1D(i)+SegEmb(si)
- 使用WordPiece对OCR文本序列进行分词,并将每个分词(token)分配给特定的段落。接着,在序列的开始添加[CLS]标记,在每个文本段落的末尾添加[SEP]标记。为了使最终序列的长度恰好等于最大序列长度L,在序列末尾额外添加[PAD]填充符。
-
视觉嵌入
- 给定一个文档页面图像I,将其调整大小至224×224像素后输入到视觉主干网络中。之后,输出的特征图通过平均池化到固定尺寸,宽度为W,高度为H。接下来,它被展平为长度为W×H(例如:7×7)的视觉嵌入序列,此序列被称为VisTokEmb(I)。然后
对每个视觉token嵌入应用线性投影层,以统一其维度与文本嵌入的维度
。 - 由于基于CNN的视觉主干无法捕获位置信息,因此添加一维位置嵌入。
- 对于段落嵌入,将所有视觉令牌附属于视觉段[C]。
v i = P r o j ( V i s T o k E m b ( I ) i + P o s E m b 1 D ( i ) + S e g E m b ( [ C ] ) v_i= Proj(VisTokEmb(I)_i+PosEmb1D(i)+SegEmb([C]) vi=Proj(VisTokEmb(I)i+PosEmb1D(i)+SegEmb([C])
- 给定一个文档页面图像I,将其调整大小至224×224像素后输入到视觉主干网络中。之后,输出的特征图通过平均池化到固定尺寸,宽度为W,高度为H。接下来,它被展平为长度为W×H(例如:7×7)的视觉嵌入序列,此序列被称为VisTokEmb(I)。然后
-
布局嵌入(2D Position Embeddings)
- 将所有的坐标标准化并离散化为[0, 1000]范围内的整数,并使用两个嵌入层分别嵌入x轴特征和y轴特征。
- 给定第i个( 0 ≤ i < W × H + L 0 ≤ i < W×H + L 0≤i<W×H+L)文本/视觉token的标准化边界框 b o x i = ( x m i n , x m a x , y m i n , y m a x , w i d t h , h e i g h t ) box_i = (x_{min}, x_{max}, y_{min}, y_{max}, width, height) boxi=(xmin,xmax,ymin,ymax,width,height),布局嵌入层将这六个边界框特征连接起来构建一个token级的2D位置嵌入,即布局嵌入。
-
- 由于卷积神经网络(CNNs)执行局部变换,因此视觉token嵌入可以一一映射回图像区域,既没有重叠也没有遗漏。
在计算边界框时,视觉token可以被视为均匀划分的网格。
- 对于特殊token [CLS]、[SEP]和[PAD],会附加一个空边界框boxPAD = (0, 0, 0, 0, 0, 0)。这意味着这些特殊符号在空间布局上不占用实际区域,但通过这样的空边界框嵌入,模型能够将它们整合到序列中的相应位置上,同时保持空间信息的一致性。
1.2 预训练目标及数据
1.2.1 MVLM
- 采用了掩码视觉-语言建模(Masked Visual-Language Modeling, MVLM)方法,以便模型在跨模态线索的帮助下更好地学习语言方面。
- 随机掩蔽一些文本token,并要求模型恢复这些被掩蔽的token。
- 与此同时,布局信息保持不变,这意味着模型了解每个被掩蔽token在页面上的位置。
- 为了避免视觉线索泄露,在将原始页面图像输入到视觉编码器之前,会先对应掩蔽掉与被掩蔽文本token相对应的图像区域。
1.2.2 TIA
- Text-Image Alignment(TIA):随机遮盖图像,然后识别文本对应图像是否被遮盖了。
- 为了帮助模型学习图像与边界框坐标的空間位置对应关系,提出了细粒度的跨模态对齐任务——文本-图像对齐(Text-Image Alignment, TIA)。
- 在TIA任务中,随机选择一些文本行,并在其文档图像上的对应图像区域进行遮盖, 称此操作为“遮盖”,以避免与MVLM中的“掩码”操作混淆。
- 预训练期间,在编码器输出之上构建了一个分类层。该层根据文本令牌是否被遮盖(即,[Covered]或[Not Covered])预测每个文本令牌的标签,并计算二元交叉熵损失。
- 考虑到输入图像的分辨率有限,且某些文档元素(如图表中的符号和线条)可能看起来像被遮盖的文本区域,寻找单词大小的遮盖图像区域的任务可能会存在噪声。因此,遮盖操作是在行级别进行的。
- 当MVLM和TIA同时执行时,MVLM中被掩蔽的令牌的TIA损失不予考虑。这防止了模型学习从[MASK]到[Covered]这种无用但直观的对应关系。
1.2.3 TIM
- Text-Image Matching(TIM):使用[CLS]来判断给出的图片特征与文本特征是否属于同一个页面。
- 为了帮助模型学习文档图像与文本内容之间的对应关系,采用了较为粗粒度的跨模态对齐任务,即文本-图像匹配(Text-Image Matching, TIM)。
- 将[CLS]位置的输出表示送入一个分类器,以预测图像和文本是否来自同一文档页面。正常的配对输入被视为正样本。
- 为了构建负样本,图像要么被另一文档的页面图像替换,要么被移除。
- 为防止模型通过寻找任务特定特征来作弊,对负面样本中的图像也执行相同的掩码和遮盖操作。在负面样本中,TIA的目标标签全部设置为[Covered]。
1.2.4 预训练数据
-
为了预训练和评估LayoutLMv2模型,作者从富含视觉元素的文档理解领域中选择了广泛的数据集。
-
使用IIT-CDIP作为预训练数据集。
1.3 模型微调
- 在文档级别分类任务RVL-CDIP中,使用[CLS]输出以及池化的视觉令牌表示作为全局特征。
- 对于提取式问答任务DocVQA及其他四个实体提取任务,在LayoutLMv2输出的文本部分上
构建特定任务的头部层
。在DocVQA论文中,实验结果显示,在SQuAD数据集上微调过的BERT模型比原始BERT模型表现更优。受此启发,增加了一个额外的设置:首先在问题生成(Question Generation, QG)数据集上微调LayoutLMv2,随后再在DocVQA数据集上微调。这个QG数据集包含近百万对由训练于SQuAD数据集的生成模型产生的问题-答案对。
1.4 LayoutXLM模型结构
- LayoutXLM是 LayoutLMv2的多语言扩展版本。为了准确评估LayoutXLM,论文中还引入了一个多语言表单理解基准数据集,名为XFUND,该数据集包含了7种语言(中文、日语、西班牙语、法语、意大利语、德语、葡萄牙语)的表单理解样本,并为每种语言的手工标注了键值对。
- 论文链接:https://arxiv.org/pdf/2104.08836
- LayoutXLM预训练策略,同LayoutLMv2
- 该框架如下图所示:
- 模型接收来自三种不同模态的信息,即文本、布局和图像,分别使用文本嵌入、布局嵌入和视觉嵌入层进行编码。文本和图像嵌入被连接在一起,然后加上布局嵌入以获得输入嵌入。
- 输入嵌入通过带有空间感知自注意力机制的多模态Transformer进行编码。
- 最后,输出的上下文表示可以用于后续的任务特定层。
1.5 VI-LayoutXLM
-
百度在PP-StructureV2中,针对 LayoutXLM 进行改进,得到了VI-LayoutXLM。
-
论文链接:https://arxiv.org/pdf/2210.05391
-
模型部分改进如下:
- LayoutLMv2 以及 LayoutXLM 中引入视觉骨干网络,用于提取视觉特征,并与后续的 text embedding 进行联合,作为多模态的输入 embedding。但是该模块为基于 ResNet_x101_64x4d 的特征提取网络,特征抽取阶段耗时严重。
- 因此,
移除视觉特征提取模块
,同时仍然保留文本、位置以及布局等信息,最终发现针对 LayoutXLM 进行改进,下游 SER 任务精度无损,针对 LayoutLMv2 进行改进,下游 SER 任务精度仅降低2.1%,而模型大小减小了约340M。
2 VI-LayoutXLM在发票数据集上的应用
-
关键信息抽取 (Key Information Extraction, KIE)
指的是是从文本或者图像中,抽取出关键的信息。- 针对文档图像的关键信息抽取任务作为OCR的下游任务,存在非常多的实际应用场景,如表单识别、车票信息抽取、身份证信息抽取等。
- 文档图像中的KIE一般包含2个子任务,示意图如下图所示。
SER: 语义实体识别 (Semantic Entity Recognition)
,对每一个检测到的文本进行分类,如将其分为姓名,身份证。如下图中的黑色框和红色框。RE: 关系抽取 (Relation Extraction)
,对每一个检测到的文本进行分类,如将其分为问题 (key) 和答案 (value) 。然后对每一个问题找到对应的答案,相当于完成key-value的匹配过程。如下图中的红色框和黑色框分别代表问题和答案,黄色线代表问题和答案之间的对应关系。
-
除了
视觉特征无关的多模态预训练模型结构
,paddleocr中在KIE任务上,还有两个主要的优化策略:TB-YX:考虑阅读顺序的文本行排序逻辑
- 文本阅读顺序对于信息抽取与文本理解等任务至关重要,传统多模态模型中,没有考虑不同 OCR 工具可能产生的不正确阅读顺序,而模型输入中包含位置编码,阅读顺序会直接影响预测结果
- 在预处理中,对文本行按照从上到下,从左到右(YX)的顺序进行排序,为防止文本行位置轻微干扰带来的排序结果不稳定问题,在排序的过程中,引入位置偏移阈值 Th,对于 Y 方向距离小于 Th 的2个文本内容,使用 X 方向的位置从左到右进行排序。
UDML:联合互学习知识蒸馏策略
- UDML(Unified-Deep Mutual Learning)联合互学习是 PP-OCRv2 与 PP-OCRv3 中采用的对于文本识别非常有效的提升模型效果的策略。
- 在训练时,引入2个完全相同的模型进行互学习,计算2个模型之间的互蒸馏损失函数(DML loss),同时对 transformer 中间层的输出结果计算距离损失函数(L2 loss)。
- 使用该策略,最终 XFUND 数据集上,SER 任务 F1 指标提升0.6%,RE 任务 F1 指标提升5.01%。
-
KIE常用思路有如下两种:
-
一种是SER:
- 直接使用SER,获取关键信息的类别;常用于关键信息类别固定的场景。
- 以身份证场景为例, 关键信息一般包含
姓名
、性别
、民族
等,我们直接将对应的字段标注为特定的类别即可,如下图所示:
-
注意:
- 标注过程中,对于无关于KIE关键信息的文本内容,均需要将其标注为
other
类别,相当于背景信息。如在身份证场景中,如果我们不关注性别信息,那么可以将“性别”与“男”这2个字段的类别均标注为other
。 - 标注过程中,需要以文本行为单位进行标注,无需标注单个字符的位置信息。
数据量方面,一般来说,对于比较固定的场景,50张左右的训练图片即可达到可以接受的效果,可以使用PPOCRLabel完成KIE的标注过程。
- 标注过程中,对于无关于KIE关键信息的文本内容,均需要将其标注为
-
一种是SER+RE:
- 联合使用SER+RE,先利用SER找到key和value,然后再利用RE进行匹配;常用于关系类别不固定的场景。
- 以身份证场景为例, 关键信息一般包含
姓名
、性别
、民族
等关键信息。在SER阶段,我们需要识别所有的question (key) 与answer (value) 。每个字段的类别信息(label
字段)可以是question、answer或者other(与待抽取的关键信息无关的字段)
- 在RE阶段,需要标注每个字段的的id与连接信息,如下图所示:
- 标注过程中,如果value是多个字符,那么linking中可以新增一个key-value对,如
[[0, 1], [0, 2]]
- 数据量方面,一般来说,对于比较固定的场景,50张左右的训练图片即可达到可以接受的效果,可以使用PPOCRLabel完成KIE的标注过程。
- 标注过程中,如果value是多个字符,那么linking中可以新增一个key-value对,如
-
我们参考案例:https://aistudio.baidu.com/projectdetail/4823162(
项目里提供了发票数据集
),来对VI-LayoutXLM模型有更深的认识。
-
2.1 语义实体识别 (SER)
2.1.1 模型构建
-
我这里不用命令行执行,在
paddleocr\tests
目录下创建一个py文件执行训练过程 -
我们复制一份
paddleocr\configs\kie\vi_layoutxlm\ser_vi_layoutxlm_xfund_zh_udml.yml
文件到paddleocr\tests\configs进行修改(参考上面项目链接进行修改),发票数据集在上面项目中已提供,模型部分的配置文件如下:Architecture:model_type: &model_type "kie"name: DistillationModelalgorithm: DistillationModels:Teacher:pretrained:freeze_params: falsereturn_all_feats: truemodel_type: *model_typealgorithm: &algorithm "LayoutXLM"Transform:Backbone:name: LayoutXLMForSerpretrained: True # 会利用paddle-nlp加载预训练模型# one of base or vimode: vicheckpoints:num_classes: &num_classes 5 # 采用BIO的标注,训练需要修改Student:pretrained:freeze_params: falsereturn_all_feats: truemodel_type: *model_typealgorithm: *algorithmTransform:Backbone:name: LayoutXLMForSerpretrained: True# one of base or vimode: vicheckpoints:num_classes: *num_classes
-
通过下面的py文件,我们就可以愉快的查看源码了。
def train_kie_token_ser_demo():from tools.train import program, set_seed, main# 配置文件的源地址地址: paddleocr\configs\kie\vi_layoutxlm\ser_vi_layoutxlm_xfund_zh_udml.ymlconfig, device, logger, vdl_writer = program.preprocess(is_train=True)###############修改配置(也可在yml文件中修改)################### 评估频率config["Global"]["eval_batch_step"] = [0, 200]# log的打印频率config["Global"]["print_batch_step"] = 50# 训练的epochsconfig["Global"]["epoch_num"] = 1# 随机种子seed = config["Global"]["seed"] if "seed" in config["Global"] else 1024set_seed(seed)###############模型训练##################main(config, device, logger, vdl_writer, seed)def train_kie_token_re_demo():from tools.train import program, set_seed, main# 配置文件的源地址地址: paddleocr\configs\kie\vi_layoutxlm\re_vi_layoutxlm_xfund_zh_udml.ymlconfig, device, logger, vdl_writer = program.preprocess(is_train=True)###############修改配置(也可在yml文件中修改)################### 评估频率config["Global"]["eval_batch_step"] = [0, 200]# log的打印频率config["Global"]["print_batch_step"] = 50# 训练的epochsconfig["Global"]["epoch_num"] = 1# 随机种子seed = config["Global"]["seed"] if "seed" in config["Global"] else 1024set_seed(seed)###############模型训练##################main(config, device, logger, vdl_writer, seed)if __name__ == '__main__':train_kie_token_ser_demo()# train_kie_token_re_demo()
LayoutXLMForTokenClassification
- 首先,利用LayoutXLMModel提取特征(文本、布局信息)
- 然后,利用文本部分的特征进行BIO多分类
# paddleocr.ppocr.modeling.backbones.vqa_layoutlm.py
class LayoutXLMForTokenClassification(LayoutXLMPretrainedModel):def __init__(self, config: LayoutXLMConfig):super(LayoutXLMForTokenClassification, self).__init__(config)self.num_classes = config.num_labelsself.layoutxlm = LayoutXLMModel(config)self.dropout = nn.Dropout(config.hidden_dropout_prob)self.classifier = nn.Linear(config.hidden_size, self.num_classes)......def forward(self,input_ids=None,bbox=None,image=None,attention_mask=None,token_type_ids=None,position_ids=None,head_mask=None,labels=None,):# 1、经过12层的Transformer Block Encoderoutputs = self.layoutxlm(input_ids=input_ids,bbox=bbox,image=image,attention_mask=attention_mask,token_type_ids=token_type_ids,position_ids=position_ids,head_mask=head_mask,)seq_length = input_ids.shape[1]# sequence out and image out# 2、进行BIO多分类# sequence_output: (bs, 561, 768) -> (bs, 512, 768) -> (bs, 512, 5)sequence_output = outputs[0][:, :seq_length]sequence_output = self.dropout(sequence_output)logits = self.classifier(sequence_output)hidden_states = {f"hidden_states_{idx}": outputs[2][f"{idx}_data"] for idx in range(self.layoutxlm.config.num_hidden_layers)}if self.training:outputs = (logits, hidden_states)else:outputs = (logits,)......return outputs
LayoutXLMModel
这里我们主要看下LayoutXLMModel模型中,文本的embedding和视觉部分的embedding。
-
文本的embedding:
-
word_embeddings:对tokenizer后的input_ids进行word_embeddings,shape变化:(bs, 512) -> (bs, 512, 768)
-
position_embeddings(1D position embedding):对文本部分的position_ids进行embeding,shape变化:(bs, 512) -> (bs, 512, 768)。这里,文本和视觉的position_embeddings是共享的。
-
spatial_position_embeddings:这里shape变化为(bs, 512, 4) -> (bs, 512, 768),是将每一个bbox信息的(x_min, y_min, x_max, y_max, h, w)编码,然后concat得到,代码如下所示。注意:如果一个bbox内的文字,被切分为多个token,那么这些token的bbox信息是一致的。
# paddlenlp.transformers.layoutxlm.modeling.pydef _cal_spatial_position_embeddings(self, bbox):try:# (bs, embdedding_dim) -> (bs, embdedding_dim, 128)left_position_embeddings = self.x_position_embeddings(bbox[:, :, 0])# (bs, embdedding_dim) -> (bs, embdedding_dim, 128)upper_position_embeddings = self.y_position_embeddings(bbox[:, :, 1])# (bs, embdedding_dim) -> (bs, embdedding_dim, 128)right_position_embeddings = self.x_position_embeddings(bbox[:, :, 2])# (bs, embdedding_dim) -> (bs, embdedding_dim, 128)lower_position_embeddings = self.y_position_embeddings(bbox[:, :, 3])except IndexError as e:raise IndexError("The :obj:`bbox`coordinate values should be within 0-1000 range.") from e# (bs, embdedding_dim) -> (bs, embdedding_dim, 128)h_position_embeddings = self.h_position_embeddings(bbox[:, :, 3] - bbox[:, :, 1])# (bs, embdedding_dim) -> (bs, embdedding_dim, 128)w_position_embeddings = self.w_position_embeddings(bbox[:, :, 2] - bbox[:, :, 0])# [x_min, y_min, x_max, y_max, h, w] concat -> (bs, embdedding_dim, 128*6)spatial_position_embeddings = paddle.concat([left_position_embeddings,upper_position_embeddings,right_position_embeddings,lower_position_embeddings,h_position_embeddings,w_position_embeddings,],axis=-1,)return spatial_position_embeddings
-
token_type_embeddings:这里的token_type_ids全为0,shape变化为(bs, 512) -> (bs, 512, 768)
-
-
视觉部分的embedding:
- position_embeddings(1D position embedding):
shape变化为(bs, 49) -> (bs, 49, 768)
。视觉部分的position ids为:[0, 1, 2, …, 48] -> (bs, 49)。这里虽然去除了视觉提取,但是position ids按照图像224×224经过降采样32倍后的feature map:7×7进行生成。这里,文本和视觉的position_embeddings是共享的; - spatial_position_embeddings:视觉部分布局信息,即bbox的生成的核心逻辑是:7×7网格中,每一个小的正方形的坐标(x_min, y_min, x_max, y_max)即为一个视觉token。
shape变化为(bs, 49, 4) -> (bs, 49, 768)
; - visual_segment_embedding
- position_embeddings(1D position embedding):
-
最终,将文本的embedding和视觉部分的embedding送入到12层的Transformer Encoder Block提取特征。
# paddlenlp.transformers.layoutxlm.modeling.py
@register_base_model
class LayoutXLMModel(LayoutXLMPretrainedModel):def __init__(self, config: LayoutXLMConfig):super(LayoutXLMModel, self).__init__(config)self.config = configself.use_visual_backbone = config.use_visual_backboneself.has_visual_segment_embedding = config.has_visual_segment_embeddingself.embeddings = LayoutXLMEmbeddings(config)if self.use_visual_backbone is True:self.visual = VisualBackbone(config)self.visual.stop_gradient = Trueself.visual_proj = nn.Linear(config.image_feature_pool_shape[-1], config.hidden_size)if self.has_visual_segment_embedding:self.visual_segment_embedding = self.create_parameter(shape=[config.hidden_size,],dtype=paddle.float32,)self.visual_LayerNorm = nn.LayerNorm(config.hidden_size, epsilon=config.layer_norm_eps)self.visual_dropout = nn.Dropout(config.hidden_dropout_prob)self.encoder = LayoutXLMEncoder(config)self.pooler = LayoutXLMPooler(config)def _calc_visual_bbox(self, image_feature_pool_shape, bbox, visual_shape):"""视觉部分布局信息,即bbox的生成:- image_feature_pool_shape:(7, 7, 256)- 文字token的bbox信息:(bs, 512, 4)- visual_shape:[bs, 49]"""# 首先,生成一个序列[0, 1000, 2000, 3000, 4000, 5000, 6000, 7000]# 然后,离散化为[0, 1000],即[0, 142, 285, 428, 571, 714, 857, 1000]visual_bbox_x = (paddle.arange(0,1000 * (image_feature_pool_shape[1] + 1),1000,dtype=bbox.dtype,)// image_feature_pool_shape[1])visual_bbox_y = (paddle.arange(0,1000 * (image_feature_pool_shape[0] + 1),1000,dtype=bbox.dtype,)// image_feature_pool_shape[0])expand_shape = image_feature_pool_shape[0:2] # (7, 7)# 7×7网格中,每一个小的正方形的坐标(x_min, y_min, x_max, y_max)即为一个视觉token# visual_bbox shape = (7×7, 4)visual_bbox = paddle.stack([visual_bbox_x[:-1].expand(expand_shape),visual_bbox_y[:-1].expand(expand_shape[::-1]).transpose([1, 0]),visual_bbox_x[1:].expand(expand_shape),visual_bbox_y[1:].expand(expand_shape[::-1]).transpose([1, 0]),],axis=-1,).reshape([expand_shape[0] * expand_shape[1], paddle.shape(bbox)[-1]])# 扩展到bs个样本, (7×7, 4) -> (bs, 7×7, 4)visual_bbox = visual_bbox.expand([visual_shape[0], visual_bbox.shape[0], visual_bbox.shape[1]])return visual_bboxdef _calc_text_embeddings(self, input_ids, bbox, position_ids, token_type_ids):"""文本部分进行embeddings:word_embeddings+ position_embeddings(文本和视觉的position_embeddings是共享的)+ spatial_position_embeddings+ token_type_embeddings"""# (bs, 512) -> (bs, 512, 768)words_embeddings = self.embeddings.word_embeddings(input_ids)# (bs, 512) -> (bs, 512, 768)position_embeddings = self.embeddings.position_embeddings(position_ids)# (bs, 512, 4) -> (bs, 512, 768)spatial_position_embeddings = self.embeddings._cal_spatial_position_embeddings(bbox)# (bs, 512) -> (bs, 512, 768)token_type_embeddings = self.embeddings.token_type_embeddings(token_type_ids)# 4种embedding相加embeddings = words_embeddings + position_embeddings + spatial_position_embeddings + token_type_embeddings# LayerNorm + dropoutembeddings = self.embeddings.LayerNorm(embeddings)embeddings = self.embeddings.dropout(embeddings)return embeddingsdef _calc_img_embeddings(self, image, bbox, position_ids):"""视觉部分进行embedding:position_embeddings(文本和视觉的position_embeddings是共享的)+ spatial_position_embeddings+ visual_segment_embedding"""use_image_info = self.use_visual_backbone and image is not None# (bs, 49) -> (bs, 49, 768)position_embeddings = self.embeddings.position_embeddings(position_ids)# (bs, 49, 4) -> (bs, 49, 768)spatial_position_embeddings = self.embeddings._cal_spatial_position_embeddings(bbox)if use_image_info is True:visual_embeddings = self.visual_proj(self.visual(image.astype(paddle.float32)))embeddings = visual_embeddings + position_embeddings + spatial_position_embeddingselse:# embedding相加embeddings = position_embeddings + spatial_position_embeddingsif self.has_visual_segment_embedding:# self.visual_segment_embedding shape = (768)embeddings += self.visual_segment_embedding# visual_LayerNorm + visual_dropoutembeddings = self.visual_LayerNorm(embeddings)embeddings = self.visual_dropout(embeddings)return embeddingsdef forward(self,input_ids=None,bbox=None,image=None,token_type_ids=None,position_ids=None,attention_mask=None,head_mask=None,output_hidden_states=False,output_attentions=False,):input_shape = paddle.shape(input_ids)visual_shape = list(input_shape)visual_shape[1] = self.config.image_feature_pool_shape[0] * self.config.image_feature_pool_shape[1]# 视觉部分的bbox的生成# 视觉token被视为均匀划分的网格# 生成的bbox信息:feature_map(7×7)网格中,每一个小的正方形的坐标(x_min, y_min, x_max, y_max)即为一个视觉tokenvisual_bbox = self._calc_visual_bbox(self.config.image_feature_pool_shape, bbox, visual_shape)# 1、2D position embedding(文本部分bbox+视觉部分bbox)# (bs, 512, 4) + (bs, 49, 4) -> (bs, 561, 4)final_bbox = paddle.concat([bbox, visual_bbox], axis=1)if attention_mask is None:attention_mask = paddle.ones(input_shape)if self.use_visual_backbone is True:# 使用视觉部分的backbonevisual_attention_mask = paddle.ones(visual_shape)else:# 移除视觉特征提取模块,mask全设置为0visual_attention_mask = paddle.zeros(visual_shape)attention_mask = attention_mask.astype(visual_attention_mask.dtype)# concat后attention_mask:(bs, 512) + (bs, 49) -> (bs, 561)final_attention_mask = paddle.concat([attention_mask, visual_attention_mask], axis=1)if token_type_ids is None:token_type_ids = paddle.zeros(input_shape, dtype=paddle.int64)# 2、1D position embedding(文本部分+视觉部分) (bs, 512) + (bs, 49) -> (bs, 561)if position_ids is None:# 文本部分的position embeddingseq_length = input_shape[1]position_ids = self.embeddings.position_ids[:, :seq_length]position_ids = position_ids.expand(input_shape)# 视觉部分的position embedding# [0, 1, 2, ..., 48] -> (bs, 49)visual_position_ids = paddle.arange(0, visual_shape[1]).expand([input_shape[0], visual_shape[1]])final_position_ids = paddle.concat([position_ids, visual_position_ids], axis=1)if bbox is None:bbox = paddle.zeros(input_shape + [4])# 3、 text embedding & visual (bs, 512, 768) + (bs, 49, 768) -> (bs, 561, 768)# 文本部分进行embdedding (bs, 512) -> (bs, 512, 768)text_layout_emb = self._calc_text_embeddings(input_ids=input_ids,bbox=bbox,token_type_ids=token_type_ids,position_ids=position_ids,)# 视觉部分进行embedding(注意此时没有image,仅有视觉的bbox以及position_ids)visual_emb = self._calc_img_embeddings(image=image,bbox=visual_bbox,position_ids=visual_position_ids,)final_emb = paddle.concat([text_layout_emb, visual_emb], axis=1)# (bs, 561) -> (bs, 1, 1, 561)extended_attention_mask = final_attention_mask.unsqueeze(1).unsqueeze(2)extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0if head_mask is not None:if head_mask.dim() == 1:head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)elif head_mask.dim() == 2:head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)else:head_mask = [None] * self.config.num_hidden_layers# 经过Transformer Encoder Block(12层)encoder_outputs = self.encoder(final_emb, # 文本&视觉部分的embedding , shape=(bs, 561, 768)extended_attention_mask, # attention_mask , shape=(bs, 1, 1, 561)bbox=final_bbox, # 2D position embedding【如果需要相对位置位置编码,加在attention_score上,这里为False】, shape=(bs, 561, 4)position_ids=final_position_ids, # 1D position embedding【如果需要相对位置位置编码,加在attention_score上,这里为False】, shape=(bs, 561)head_mask=head_mask,output_attentions=output_attentions,output_hidden_states=output_hidden_states,)# sequence_output shape = (bs, 561, 768)sequence_output = encoder_outputs[0]# pooled_output shape = (bs, 768)pooled_output = self.pooler(sequence_output)return sequence_output, pooled_output, encoder_outputs[1]
2.1.2 损失计算
- 由于使用了
UDML:联合互学习知识蒸馏策略
,损失计算的配置如下:
Loss:name: CombinedLoss # ppocr.losses.combined_loss.CombinedLossloss_config_list:- DistillationVQASerTokenLayoutLMLoss: # GT loss ppocr.losses.distillation_loss.DistillationVQASerTokenLayoutLMLossweight: 1.0model_name_list: ["Student", "Teacher"]key: backbone_outnum_classes: *num_classes- DistillationSERDMLLoss: # DML loss ppocr.losses.distillation_loss.DistillationSERDMLLossweight: 1.0act: "softmax"use_log: truemodel_name_pairs:- ["Student", "Teacher"]key: backbone_out- DistillationVQADistanceLoss: # S5 loss ppocr.losses.distillation_loss.DistillationVQADistanceLossweight: 0.5mode: "l2"model_name_pairs:- ["Student", "Teacher"]key: hidden_states_5name: "loss_5"- DistillationVQADistanceLoss: # S8 loss ppocr.losses.distillation_loss.DistillationVQADistanceLossweight: 0.5mode: "l2"model_name_pairs:- ["Student", "Teacher"]key: hidden_states_8name: "loss_8"
- 如下所示,在DistillationModel中,Teacher和Student模型分别进行前向过程
# paddleocr.ppocr.modeling.architectures.distillation_model.py
class DistillationModel(nn.Layer):def __init__(self, config):"""the module for OCR distillation.args:config (dict): the super parameters for module."""super().__init__()self.model_list = []self.model_name_list = []for key in config["Models"]:model_config = config["Models"][key]freeze_params = Falsepretrained = Noneif "freeze_params" in model_config:freeze_params = model_config.pop("freeze_params")if "pretrained" in model_config:pretrained = model_config.pop("pretrained")model = BaseModel(model_config)if pretrained is not None:load_pretrained_params(model, pretrained)if freeze_params:for param in model.parameters():param.trainable = Falseself.model_list.append(self.add_sublayer(key, model))self.model_name_list.append(key)def forward(self, x, data=None):result_dict = dict()# 执行所有模型的前向过程, 例如:Teacher和Student模型for idx, model_name in enumerate(self.model_name_list):result_dict[model_name] = self.model_list[idx](x, data)return result_dict
- 在CombinedLoss中遍历配置的损失函数,分别计算损失,最后相加最为总损失
# paddleocr.ppocr.losses.combined_loss.py
class CombinedLoss(nn.Layer):"""CombinedLoss:a combionation of loss function"""def __init__(self, loss_config_list=None):super().__init__()self.loss_func = []self.loss_weight = []assert isinstance(loss_config_list, list), "operator config should be a list"......def forward(self, input, batch, **kargs):# input包含Teacher模型以及Student模型的输出结果# batch是批次数据,里面包含labelloss_dict = {}loss_all = 0.0# 遍历配置的所有的损失函数,计算损失for idx, loss_func in enumerate(self.loss_func):loss = loss_func(input, batch, **kargs)if isinstance(loss, paddle.Tensor):loss = {"loss_{}_{}".format(str(loss), idx): loss}weight = self.loss_weight[idx]loss = {key: loss[key] * weight for key in loss}if "loss" in loss:loss_all += loss["loss"]else:loss_all += paddle.add_n(list(loss.values()))loss_dict.update(loss)loss_dict["loss"] = loss_allreturn loss_dict
-
我们看下具体配置的损失函数:
-
DistillationVQASerTokenLayoutLMLoss的实质就是每个模型分别计算NER任务的CrossEntropyLoss,即GT loss:
class DistillationVQASerTokenLayoutLMLoss(VQASerTokenLayoutLMLoss):def __init__(self, num_classes, model_name_list=[], key=None, name="loss_ser"):super().__init__(num_classes=num_classes)self.model_name_list = model_name_listself.key = keyself.name = namedef forward(self, predicts, batch):loss_dict = dict()# 遍历Teacher模型、Student模型for idx, model_name in enumerate(self.model_name_list):# 先从predicts取出相关模型的预测结果out = predicts[model_name]# 然后,从out中取出key(即配置文件中配置的backbone_out)的值if self.key is not None:out = out[self.key]# 调用父类,计算损失loss = super().forward(out, batch)loss_dict["{}_{}".format(self.name, model_name)] = loss["loss"]return loss_dict# DistillationVQASerTokenLayoutLMLoss的父类 class VQASerTokenLayoutLMLoss(nn.Layer):def __init__(self, num_classes, key=None):super().__init__()self.loss_class = nn.CrossEntropyLoss()self.num_classes = num_classesself.ignore_index = self.loss_class.ignore_indexself.key = keydef forward(self, predicts, batch):if isinstance(predicts, dict) and self.key is not None:predicts = predicts[self.key]labels = batch[5] # (bs, 512)attention_mask = batch[2] # (bs, 512)if attention_mask is not None:active_loss = (attention_mask.reshape([-1,])== 1)# active_output_shape = (bs, 512, 5) -> (bs*512, 5)active_output = predicts.reshape([-1, self.num_classes])[active_loss]# active_label_shape = bs*512active_label = labels.reshape([-1,])[active_loss]# 交叉熵损失函数loss = self.loss_class(active_output, active_label)else:loss = self.loss_class(predicts.reshape([-1, self.num_classes]),labels.reshape([-1,]),)return {"loss": loss}
-
DistillationSERDMLLoss实质是计算Techaer和Student模型之间的互蒸馏损失函数,即KL散度。
class DistillationSERDMLLoss(DMLLoss):""" """def __init__(self,act="softmax",use_log=True,num_classes=7,model_name_pairs=[],key=None,name="loss_dml_ser",):super().__init__(act=act, use_log=use_log)assert isinstance(model_name_pairs, list)self.key = keyself.name = nameself.num_classes = num_classesself.model_name_pairs = model_name_pairsdef forward(self, predicts, batch):loss_dict = dict()# 遍历Teacher模型、Student模型for idx, pair in enumerate(self.model_name_pairs):# 取出Teacher模型以及Student模型中的结果out1 = predicts[pair[0]]out2 = predicts[pair[1]]if self.key is not None:# 取出backbone_outout1 = out1[self.key]out2 = out2[self.key]out1 = out1.reshape([-1, out1.shape[-1]])out2 = out2.reshape([-1, out2.shape[-1]])attention_mask = batch[2]if attention_mask is not None:active_output = (attention_mask.reshape([-1,])== 1)out1 = out1[active_output]out2 = out2[active_output]# 调用父类的方法loss_dict["{}_{}".format(self.name, idx)] = super().forward(out1, out2)return loss_dict# DistillationSERDMLLoss的父类 class DMLLoss(nn.Layer):"""DMLLoss"""def __init__(self, act=None, use_log=False):super().__init__()if act is not None:assert act in ["softmax", "sigmoid"]if act == "softmax":self.act = nn.Softmax(axis=-1)elif act == "sigmoid":self.act = nn.Sigmoid()else:self.act = Noneself.use_log = use_logself.jskl_loss = KLJSLoss(mode="kl")def _kldiv(self, x, target):"""计算两个概率分布之间的KL散度:KL散度的公式是 KL(P||Q) = ΣP(x) * log(P(x)/Q(x)),这里将其重写为ΣP(x)*(log(P(x)) - log(Q(x)))即target * (paddle.log(target + eps) - x)"""eps = 1.0e-10loss = target * (paddle.log(target + eps) - x)# batch mean lossloss = paddle.sum(loss) / loss.shape[0]return lossdef forward(self, out1, out2):if self.act is not None:out1 = self.act(out1) + 1e-10out2 = self.act(out2) + 1e-10if self.use_log:# 计算KL散度# for recognition distillation, log is needed for feature maplog_out1 = paddle.log(out1)log_out2 = paddle.log(out2)loss = (self._kldiv(log_out1, out2) + self._kldiv(log_out2, out1)) / 2.0else:# for detection distillation log is not neededloss = self.jskl_loss(out1, out2)return loss
-
DistillationVQADistanceLoss,本质是对 transformer 中间层的输出结果计算距离损失函数(L2 loss)
# DistillationVQADistanceLoss的父类 class DistanceLoss(nn.Layer):"""DistanceLoss:mode: loss mode"""def __init__(self, mode="l2", **kargs):super().__init__()assert mode in ["l1", "l2", "smooth_l1"]if mode == "l1":self.loss_func = nn.L1Loss(**kargs)elif mode == "l2":self.loss_func = nn.MSELoss(**kargs)elif mode == "smooth_l1":self.loss_func = nn.SmoothL1Loss(**kargs)def forward(self, x, y):return self.loss_func(x, y)
其他部分,诸如数据集的加载、构建优化器、创建评估函数、加载预训练模型、模型训练等,大家可以查看源码,不再赘述。
-
2.2 关系抽取(RE)
- 我们这里,看下模型的构建部分代码,其他代码,大家可以查看源码,不再赘述。
# paddlenlp.transformers.layoutxlm.modeling.py
class LayoutXLMForRelationExtraction(LayoutXLMPretrainedModel):def __init__(self, config: LayoutXLMConfig):super(LayoutXLMForRelationExtraction, self).__init__(config)self.layoutxlm = LayoutXLMModel(config)self.extractor = REDecoder(config.hidden_size, config.hidden_dropout_prob)self.dropout = nn.Dropout(config.hidden_dropout_prob)......def forward(self,input_ids,bbox,image=None,attention_mask=None,entities=None,relations=None,token_type_ids=None,position_ids=None,head_mask=None,labels=None,):# 1、经过12层的Transformer Block Encoderoutputs = self.layoutxlm(input_ids=input_ids, # (bs, 512)bbox=bbox, # (bs, 512, 4)image=image, # Noneattention_mask=attention_mask, # (bs, 512)token_type_ids=token_type_ids, # (bs. 512)position_ids=position_ids, # Nonehead_mask=head_mask, # None)seq_length = input_ids.shape[1]# 最后一层输出# sequence_output_shape = (bs, 512, 768)sequence_output = outputs[0][:, :seq_length]sequence_output = self.dropout(sequence_output)# 2、计算loss和预测关系loss, pred_relations = self.extractor(sequence_output, entities, relations)hidden_states = [outputs[2][f"{idx}_data"] for idx in range(self.layoutxlm.config.num_hidden_layers)]hidden_states = paddle.stack(hidden_states, axis=1)# 3、返回结果res = dict(loss=loss, pred_relations=pred_relations, hidden_states=hidden_states)return res
-
主要代码在REDecoder中
- 首先,构建构建关系对的正负样本
- 然后,获取关系头(question)、关系尾(answer)对应的特征信息
- 获取关系头(即question)在input_ids中开始的索引对应token的hidden_states(shape=(100, 768))和关系头(question)经过Embedding后的结果(shape=(100, 768))进行concat
- 获取关系尾(即answer)在input_ids中开始的索引对应token的hidden_states(shape=(100, 768))和关系尾(answer)经过Embedding后的结果(shape=(100, 768))进行concat
- 利用提取到的head_repr、tail_repr特征信息进行关系分类
- 最后,利用预测结果,计算交叉熵损失等
- 下面,给出一个relations和entities示例,方便理解。
class REDecoder(nn.Layer):def __init__(self, hidden_size=768, hidden_dropout_prob=0.1):super(REDecoder, self).__init__()self.entity_emb = nn.Embedding(3, hidden_size)# 100代表:100个关系对# (100, 1536) -> (100, 768) -> (100, 384)projection = nn.Sequential(nn.Linear(hidden_size * 2, hidden_size),nn.ReLU(),nn.Dropout(hidden_dropout_prob),nn.Linear(hidden_size, hidden_size // 2),nn.ReLU(),nn.Dropout(hidden_dropout_prob),)self.ffnn_head = copy.deepcopy(projection)self.ffnn_tail = copy.deepcopy(projection)# (100, 384) -> (100, 2)self.rel_classifier = BiaffineAttention(hidden_size // 2, 2)self.loss_fct = CrossEntropyLoss()def build_relation(self, relations, entities):"""relations_shape = (bs, 262145, 2)entities_shape = (bs, 513, 3)注:relations第1个数组代表实际长度,例如:[10, 10],代表:关系对(QUESTION-ANSWER)只有10个,其他为填充entities第1个数组代表实际长度,例如:[20, 20, 20],代表:实例(QUESTION或ANSWER)只有20个,其他为填充"""batch_size, max_seq_len = paddle.shape(entities)[:2]# new_relations_shape = (bs, 513*513, 3), 初始化为-1new_relations = paddle.full(shape=[batch_size, max_seq_len * max_seq_len, 3], fill_value=-1, dtype=relations.dtype)for b in range(batch_size):if entities[b, 0, 0] <= 2:entitie_new = paddle.full(shape=[512, 3], fill_value=-1, dtype=entities.dtype)entitie_new[0, :] = 2entitie_new[1:3, 0] = 0 # startentitie_new[1:3, 1] = 1 # endentitie_new[1:3, 2] = 0 # labelentities[b] = entitie_new# 实体label_shape为: [2, 1, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 1, 1, 2, 2, 2, 2, 1]# all_possible_relations1为: [1 , 2 , 4 , 6 , 8 , 10, 12, 13, 14, 19] QUESTION# all_possible_relations2为: [0 , 3 , 5 , 7 , 9 , 11, 15, 16, 17, 18] ANSWERentitie_label = entities[b, 1 : entities[b, 0, 2] + 1, 2]all_possible_relations1 = paddle.arange(0, entities[b, 0, 2], dtype=entities.dtype)all_possible_relations1 = all_possible_relations1[entitie_label == 1]all_possible_relations2 = paddle.arange(0, entities[b, 0, 2], dtype=entities.dtype)all_possible_relations2 = all_possible_relations2[entitie_label == 2]# 所有可能的关系:all_possible_relations_shape:(100, 2)# [# [1, 0], [1, 3], ... , [1, 18],# [2, 0], [2, 3], ... , [2, 18],# ......# [19, 0], [19, 3], ... , [19, 18]# ]all_possible_relations = paddle.stack(paddle.meshgrid(all_possible_relations1, all_possible_relations2), axis=2).reshape([-1, 2])if len(all_possible_relations) == 0:all_possible_relations = paddle.full_like(all_possible_relations, fill_value=-1, dtype=entities.dtype)all_possible_relations[0, 0] = 0all_possible_relations[0, 1] = 1# relation_head: [1 , 2 , 4 , 6 , 8 , 10, 12, 13, 14, 19]# relation_tail: [0 , 3 , 5 , 7 , 9 , 11, 17, 15, 16, 18]relation_head = relations[b, 1 : relations[b, 0, 0] + 1, 0]relation_tail = relations[b, 1 : relations[b, 0, 1] + 1, 1]# positive_relations_shape: (10, 2)positive_relations = paddle.stack([relation_head, relation_tail], axis=1)# (100, 2) -> (100, 10, 2)all_possible_relations_repeat = all_possible_relations.unsqueeze(axis=1).tile([1, len(positive_relations), 1])# (100, 2) -> (100, 10, 2)positive_relations_repeat = positive_relations.unsqueeze(axis=0).tile([len(all_possible_relations), 1, 1])# mask shape = (100, 10)mask = paddle.all(all_possible_relations_repeat == positive_relations_repeat, axis=2)# 获取关系对负样本# negative_mask = paddle.any(mask, axis=1) is Falsenegative_mask = ~paddle.any(mask, axis=1)negative_relations = all_possible_relations[negative_mask]# 获取关系对正样本# positive_mask = paddle.any(mask, axis=0) is Truepositive_mask = paddle.any(mask, axis=0)positive_relations = positive_relations[positive_mask]if negative_mask.sum() > 0:# positive_relations_shape = (10, 2)# negative_relations_shape = (90, 2)# reordered_relations_shape = (100, 2)reordered_relations = paddle.concat([positive_relations, negative_relations])else:reordered_relations = positive_relationsrelation_per_doc_label = paddle.zeros([len(reordered_relations), 1], dtype=reordered_relations.dtype)relation_per_doc_label[: len(positive_relations)] = 1# relation_per_doc shape: (100, 3)"""relation_per_doc = [[1 , 0 , 1 ],# 正样本[2 , 3 , 1 ],[4 , 5 , 1 ],......[19, 18, 1 ],[1 , 3 , 0 ],# 负样本[1 , 5 , 0 ],......]"""relation_per_doc = paddle.concat([reordered_relations, relation_per_doc_label], axis=1)assert len(relation_per_doc[:, 0]) != 0# 第1个元素记录正负样本的长度信息,例如:[100, 100, 100]new_relations[b, 0] = paddle.shape(relation_per_doc)[0].astype(new_relations.dtype)# 将正负样本放到new_relations中new_relations[b, 1 : len(relation_per_doc) + 1] = relation_per_doc# new_relations.append(relation_per_doc)return new_relations, entitiesdef get_predicted_relations(self, logits, relations, entities):"""logits: 预测得到的关系概率, 例如:shape = (100, 2)relations: shape = (100, 3)entities: shape = (513, 3)"""pred_relations = []for i, pred_label in enumerate(logits.argmax(-1)):if pred_label != 1:continuerel = paddle.full(shape=[7, 2], fill_value=-1, dtype=relations.dtype)rel[0, 0] = relations[:, 0][i]rel[1, 0] = entities[:, 0][relations[:, 0][i] + 1]rel[1, 1] = entities[:, 1][relations[:, 0][i] + 1]rel[2, 0] = entities[:, 2][relations[:, 0][i] + 1]rel[3, 0] = relations[:, 1][i]rel[4, 0] = entities[:, 0][relations[:, 1][i] + 1]rel[4, 1] = entities[:, 1][relations[:, 1][i] + 1]rel[5, 0] = entities[:, 2][relations[:, 1][i] + 1]rel[6, 0] = 1pred_relations.append(rel)return pred_relationsdef forward(self, hidden_states, entities, relations):"""hidden_states_shape:(bs, 512, 768)entities_shape: (bs, 513, 3) , 其中:513 = 512 + 1,第一个元素记录长度信息relations_shape: (bs, 262145, 2),其中:262145 = 512*512 + 1,第一个元素记录长度信息"""batch_size, max_length, _ = paddle.shape(entities)# 1、构建关系的正负样本# relations_shape: (bs, 263169, 3) , 其中: 263169 = 513 * 513# entities_shape: (bs, 513, 3)relations, entities = self.build_relation(relations, entities)loss = 0# 所有预测关系结果all_pred_relations = paddle.full(shape=[batch_size, max_length * max_length, 7, 2], fill_value=-1, dtype=entities.dtype)for b in range(batch_size):# 2、获取关系头(question)、关系尾(answer)对应的特征信息# 取出正负样本关系对, relation_shape = (100, 3)relation = relations[b, 1 : relations[b, 0, 0] + 1]# 获取关系头(question)、关系尾(answer)、以及关系标签(1表示question和answer是一对,即正样本, 0表示负样本)head_entities = relation[:, 0]tail_entities = relation[:, 1]relation_labels = relation[:, 2]# 每一个实体(question或answer)在input_ids中开始的索引# 例如: [0 , 3 , 4 , 8 , 14 , 16 , 23 , 29 , 34 , 37 , 60 , 65 , 82 , 84 ,# 87 , 90 , 91 , 96 , 102, 106]entities_start_index = paddle.to_tensor(entities[b, 1 : entities[b, 0, 0] + 1, 0])# 获取每个实体类型编号,1表示question,2表示answer# 例如:[2, 1, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 1, 1, 2, 2, 2, 2, 1]entities_labels = paddle.to_tensor(entities[b, 1 : entities[b, 0, 2] + 1, 2])# 获取关系头(即question)在input_ids中开始的索引,为了后面获取对应token的hidden_stateshead_index = entities_start_index[head_entities]# 获取关系头(question)对应的实体类型编号head_label = entities_labels[head_entities]# 关系头(question)经过Embedding, head_label_repr_shape = (100, 768)head_label_repr = self.entity_emb(head_label)# 获取关系尾(即answer)在input_ids中开始的索引,为了后面获取对应token的hidden_statestail_index = entities_start_index[tail_entities]# 获取关系尾(answer)对应的实体类型编号tail_label = entities_labels[tail_entities]# 关系尾(answer)经过Embedding, tail_label_repr_shape = (100, 768)tail_label_repr = self.entity_emb(tail_label)# 获取关系头(question)开始token的hidden_states, tmp_hidden_states shape: (100, 768)tmp_hidden_states = hidden_states[b][head_index]if len(tmp_hidden_states.shape) == 1:tmp_hidden_states = paddle.unsqueeze(tmp_hidden_states, axis=0)# concat, head_repr_shape = (100, 1536)head_repr = paddle.concat((tmp_hidden_states, head_label_repr), axis=-1)# 获取关系尾(answer)开始token的hidden_states, tmp_hidden_states shape: (100, 768)tmp_hidden_states = hidden_states[b][tail_index]if len(tmp_hidden_states.shape) == 1:tmp_hidden_states = paddle.unsqueeze(tmp_hidden_states, axis=0)# concat, tail_repr_shape = (100, 1536)tail_repr = paddle.concat((tmp_hidden_states, tail_label_repr), axis=-1)# 3、利用提取到的head_repr、tail_repr进行关系分类# heads_shape = (100, 1536) -> (100, 384)# tails_shape = (100, 1536) -> (100, 384)heads = self.ffnn_head(head_repr)tails = self.ffnn_tail(tail_repr)# 结合双线性层和线性层,实现对两个输入向量的复杂交互建模# logits_shape = (100, 2)logits = self.rel_classifier(heads, tails)# 4、计算交叉熵损失loss += self.loss_fct(logits, relation_labels)pred_relations = self.get_predicted_relations(logits, relation, entities[b])if len(pred_relations) > 0:pred_relations = paddle.stack(pred_relations)all_pred_relations[b, 0, :, :] = paddle.shape(pred_relations)[0].astype(all_pred_relations.dtype)all_pred_relations[b, 1 : len(pred_relations) + 1, :, :] = pred_relationsreturn loss, all_pred_relations
- 关于模型的预测代码(使用OCR结果进行预测等),可以参考https://aistudio.baidu.com/projectdetail/4823162。