OCR经典神经网络(三)LayoutLM v2算法原理及其在发票数据集上的应用(NER及RE)

OCR经典神经网络(三)LayoutLM v2算法原理及其在发票数据集上的应用(NER及RE)

  • LayoutLM系列模型是微软发布的、文档理解多模态基础模型领域最重要和有代表性的工作:
    • LayoutLM v2:在一个单一的多模态框架中对文本(text)、布局(layout)和图像(image)之间的交互进行建模。
    • LayoutXLM:LayoutXLM是 LayoutLMv2的多语言扩展版本。
    • LayoutLM v3:借鉴了ViLT和BEIT,不需要经过预训练的视觉backbone,通过MLM、MIM和WPA进行预训练的多模态Transformer。在以视觉为中心的任务上(如文档图像分类和文档布局分析)和以文本为中心的任务上(表单理解、收据理解、文档问答)都表现很好。
  • 今天,我们来了解下LayoutLM v2模型。
    • 论文链接:https://arxiv.org/pdf/2012.14740
    • 同样,百度开源的paddleocr中,在关键信息抽取中集成了此算法。
    • paddleocr中集成的算法列表:https://github.com/PaddlePaddle/PaddleOCR/blob/main/docs/algorithm/overview.md

1 LayoutLM v2算法原理

  • LayoutLM v2是一种多模态Transformer模型,该模型在预训练阶段整合了文档文本、版式及视觉信息,实现了在一个框架内端到端地学习跨模态交互。同时,将一种空间感知的自注意力机制融入到了Transformer架构中。
  • 除了掩码视觉语言模型(MVLM)预训练策略外,LayoutLM v2还新增了文本-图像对齐(TIA)和文本-图像匹配(TIM) 作为预训练策略,以强化不同模态间的对齐。
  • LayoutLMv2不仅在传统的富视觉文档理解(VrDU)任务上取得了显著的性能提升并达到当时新的最优水平,还在文档图像的视觉问题回答(VQA)任务上实现了新突破,这证明了多模态预训练在富视觉文档理解领域的巨大潜力。

1.1 模型结构

  • 模型结构如下图所示,可以看到LayoutLM v2接收文本、视觉及版式信息作为输入,以建立深度的跨模态交互。另外,将spatial-aware的自注意力机制整合到了transformer中。

  • 这里,我们主要看下Embedding层:

    • 文本嵌入

      • 文本嵌入包含三种嵌入:词嵌入代表词本身,一维位置嵌入表示词的位置索引,而段落嵌入用于区分不同的文本段落。

      t i = T o k E m b ( w i ) + P o s E m b 1 D ( i ) + S e g E m b ( s i ) t_i= TokEmb(w_i)+PosEmb1D(i)+SegEmb(s_i) ti=TokEmb(wi)+PosEmb1D(i)+SegEmb(si)

      • 使用WordPiece对OCR文本序列进行分词,并将每个分词(token)分配给特定的段落。接着,在序列的开始添加[CLS]标记,在每个文本段落的末尾添加[SEP]标记。为了使最终序列的长度恰好等于最大序列长度L,在序列末尾额外添加[PAD]填充符。
    • 视觉嵌入

      • 给定一个文档页面图像I,将其调整大小至224×224像素后输入到视觉主干网络中。之后,输出的特征图通过平均池化到固定尺寸,宽度为W,高度为H。接下来,它被展平为长度为W×H(例如:7×7)的视觉嵌入序列,此序列被称为VisTokEmb(I)。然后对每个视觉token嵌入应用线性投影层,以统一其维度与文本嵌入的维度
      • 由于基于CNN的视觉主干无法捕获位置信息,因此添加一维位置嵌入。
      • 对于段落嵌入,将所有视觉令牌附属于视觉段[C]。

      v i = P r o j ( V i s T o k E m b ( I ) i + P o s E m b 1 D ( i ) + S e g E m b ( [ C ] ) v_i= Proj(VisTokEmb(I)_i+PosEmb1D(i)+SegEmb([C]) vi=Proj(VisTokEmb(I)i+PosEmb1D(i)+SegEmb([C])

    • 布局嵌入(2D Position Embeddings)

      • 将所有的坐标标准化并离散化为[0, 1000]范围内的整数,并使用两个嵌入层分别嵌入x轴特征和y轴特征
      • 给定第i个( 0 ≤ i < W × H + L 0 ≤ i < W×H + L 0i<W×H+L)文本/视觉token的标准化边界框 b o x i = ( x m i n , x m a x , y m i n , y m a x , w i d t h , h e i g h t ) box_i = (x_{min}, x_{max}, y_{min}, y_{max}, width, height) boxi=(xmin,xmax,ymin,ymax,width,height)布局嵌入层将这六个边界框特征连接起来构建一个token级的2D位置嵌入,即布局嵌入

在这里插入图片描述
在这里插入图片描述

  • 由于卷积神经网络(CNNs)执行局部变换,因此视觉token嵌入可以一一映射回图像区域,既没有重叠也没有遗漏。
    • 在计算边界框时,视觉token可以被视为均匀划分的网格。
    • 对于特殊token [CLS]、[SEP]和[PAD],会附加一个空边界框boxPAD = (0, 0, 0, 0, 0, 0)。这意味着这些特殊符号在空间布局上不占用实际区域,但通过这样的空边界框嵌入,模型能够将它们整合到序列中的相应位置上,同时保持空间信息的一致性

1.2 预训练目标及数据

1.2.1 MVLM

  • 采用了掩码视觉-语言建模(Masked Visual-Language Modeling, MVLM)方法,以便模型在跨模态线索的帮助下更好地学习语言方面。
    • 随机掩蔽一些文本token,并要求模型恢复这些被掩蔽的token。
    • 与此同时,布局信息保持不变,这意味着模型了解每个被掩蔽token在页面上的位置。
    • 为了避免视觉线索泄露,在将原始页面图像输入到视觉编码器之前,会先对应掩蔽掉与被掩蔽文本token相对应的图像区域。

1.2.2 TIA

  • Text-Image Alignment(TIA):随机遮盖图像,然后识别文本对应图像是否被遮盖了。
    • 为了帮助模型学习图像与边界框坐标的空間位置对应关系,提出了细粒度的跨模态对齐任务——文本-图像对齐(Text-Image Alignment, TIA)。
    • 在TIA任务中,随机选择一些文本行,并在其文档图像上的对应图像区域进行遮盖, 称此操作为“遮盖”,以避免与MVLM中的“掩码”操作混淆。
    • 预训练期间,在编码器输出之上构建了一个分类层。该层根据文本令牌是否被遮盖(即,[Covered]或[Not Covered])预测每个文本令牌的标签,并计算二元交叉熵损失
    • 考虑到输入图像的分辨率有限,且某些文档元素(如图表中的符号和线条)可能看起来像被遮盖的文本区域,寻找单词大小的遮盖图像区域的任务可能会存在噪声。因此,遮盖操作是在行级别进行的
    • 当MVLM和TIA同时执行时,MVLM中被掩蔽的令牌的TIA损失不予考虑。这防止了模型学习从[MASK]到[Covered]这种无用但直观的对应关系。

1.2.3 TIM

  • Text-Image Matching(TIM):使用[CLS]来判断给出的图片特征与文本特征是否属于同一个页面。
  • 为了帮助模型学习文档图像与文本内容之间的对应关系,采用了较为粗粒度的跨模态对齐任务,即文本-图像匹配(Text-Image Matching, TIM)。
  • 将[CLS]位置的输出表示送入一个分类器,以预测图像和文本是否来自同一文档页面。正常的配对输入被视为正样本
  • 为了构建负样本,图像要么被另一文档的页面图像替换,要么被移除。
  • 为防止模型通过寻找任务特定特征来作弊,对负面样本中的图像也执行相同的掩码和遮盖操作。在负面样本中,TIA的目标标签全部设置为[Covered]

1.2.4 预训练数据

  • 为了预训练和评估LayoutLMv2模型,作者从富含视觉元素的文档理解领域中选择了广泛的数据集。

  • 使用IIT-CDIP作为预训练数据集。

1.3 模型微调

  • 文档级别分类任务RVL-CDIP中,使用[CLS]输出以及池化的视觉令牌表示作为全局特征
  • 对于提取式问答任务DocVQA及其他四个实体提取任务,在LayoutLMv2输出的文本部分上构建特定任务的头部层。在DocVQA论文中,实验结果显示,在SQuAD数据集上微调过的BERT模型比原始BERT模型表现更优。受此启发,增加了一个额外的设置:首先在问题生成(Question Generation, QG)数据集上微调LayoutLMv2,随后再在DocVQA数据集上微调。这个QG数据集包含近百万对由训练于SQuAD数据集的生成模型产生的问题-答案对。

1.4 LayoutXLM模型结构

  • LayoutXLM是 LayoutLMv2的多语言扩展版本。为了准确评估LayoutXLM,论文中还引入了一个多语言表单理解基准数据集,名为XFUND,该数据集包含了7种语言(中文、日语、西班牙语、法语、意大利语、德语、葡萄牙语)的表单理解样本,并为每种语言的手工标注了键值对。
  • 论文链接:https://arxiv.org/pdf/2104.08836
  • LayoutXLM预训练策略,同LayoutLMv2
  • 该框架如下图所示:
    • 模型接收来自三种不同模态的信息,即文本、布局和图像,分别使用文本嵌入、布局嵌入和视觉嵌入层进行编码。文本和图像嵌入被连接在一起,然后加上布局嵌入以获得输入嵌入。
    • 输入嵌入通过带有空间感知自注意力机制的多模态Transformer进行编码。
    • 最后,输出的上下文表示可以用于后续的任务特定层。

在这里插入图片描述

1.5 VI-LayoutXLM

  • 百度在PP-StructureV2中,针对 LayoutXLM 进行改进,得到了VI-LayoutXLM。

  • 论文链接:https://arxiv.org/pdf/2210.05391

  • 模型部分改进如下:

    在这里插入图片描述

    • LayoutLMv2 以及 LayoutXLM 中引入视觉骨干网络,用于提取视觉特征,并与后续的 text embedding 进行联合,作为多模态的输入 embedding。但是该模块为基于 ResNet_x101_64x4d 的特征提取网络,特征抽取阶段耗时严重。
    • 因此,移除视觉特征提取模块,同时仍然保留文本、位置以及布局等信息,最终发现针对 LayoutXLM 进行改进,下游 SER 任务精度无损,针对 LayoutLMv2 进行改进,下游 SER 任务精度仅降低2.1%,而模型大小减小了约340M。

在这里插入图片描述

2 VI-LayoutXLM在发票数据集上的应用

  • 关键信息抽取 (Key Information Extraction, KIE)指的是是从文本或者图像中,抽取出关键的信息。

    • 针对文档图像的关键信息抽取任务作为OCR的下游任务,存在非常多的实际应用场景,如表单识别、车票信息抽取、身份证信息抽取等。
    • 文档图像中的KIE一般包含2个子任务,示意图如下图所示。
      • SER: 语义实体识别 (Semantic Entity Recognition),对每一个检测到的文本进行分类,如将其分为姓名,身份证。如下图中的黑色框和红色框。
      • RE: 关系抽取 (Relation Extraction),对每一个检测到的文本进行分类,如将其分为问题 (key) 和答案 (value) 。然后对每一个问题找到对应的答案,相当于完成key-value的匹配过程。如下图中的红色框和黑色框分别代表问题和答案,黄色线代表问题和答案之间的对应关系。

    在这里插入图片描述

  • 除了视觉特征无关的多模态预训练模型结构,paddleocr中在KIE任务上,还有两个主要的优化策略:

    • TB-YX:考虑阅读顺序的文本行排序逻辑
      • 文本阅读顺序对于信息抽取与文本理解等任务至关重要,传统多模态模型中,没有考虑不同 OCR 工具可能产生的不正确阅读顺序,而模型输入中包含位置编码,阅读顺序会直接影响预测结果
      • 在预处理中,对文本行按照从上到下,从左到右(YX)的顺序进行排序,为防止文本行位置轻微干扰带来的排序结果不稳定问题,在排序的过程中,引入位置偏移阈值 Th,对于 Y 方向距离小于 Th 的2个文本内容,使用 X 方向的位置从左到右进行排序。
    • UDML:联合互学习知识蒸馏策略
      • UDML(Unified-Deep Mutual Learning)联合互学习是 PP-OCRv2 与 PP-OCRv3 中采用的对于文本识别非常有效的提升模型效果的策略。
      • 在训练时,引入2个完全相同的模型进行互学习,计算2个模型之间的互蒸馏损失函数(DML loss),同时对 transformer 中间层的输出结果计算距离损失函数(L2 loss)。
      • 使用该策略,最终 XFUND 数据集上,SER 任务 F1 指标提升0.6%,RE 任务 F1 指标提升5.01%。

    在这里插入图片描述

  • KIE常用思路有如下两种:

    • 一种是SER:

      • 直接使用SER,获取关键信息的类别;常用于关键信息类别固定的场景。
      • 以身份证场景为例, 关键信息一般包含姓名性别民族等,我们直接将对应的字段标注为特定的类别即可,如下图所示:

      在这里插入图片描述

      • 注意:

        • 标注过程中,对于无关于KIE关键信息的文本内容,均需要将其标注为other类别,相当于背景信息。如在身份证场景中,如果我们不关注性别信息,那么可以将“性别”与“男”这2个字段的类别均标注为other
        • 标注过程中,需要以文本行为单位进行标注,无需标注单个字符的位置信息。

        数据量方面,一般来说,对于比较固定的场景,50张左右的训练图片即可达到可以接受的效果,可以使用PPOCRLabel完成KIE的标注过程。

    • 一种是SER+RE:

      • 联合使用SER+RE,先利用SER找到key和value,然后再利用RE进行匹配;常用于关系类别不固定的场景。
      • 以身份证场景为例, 关键信息一般包含姓名性别民族等关键信息。在SER阶段,我们需要识别所有的question (key) 与answer (value) 。每个字段的类别信息(label字段)可以是question、answer或者other(与待抽取的关键信息无关的字段)

      在这里插入图片描述

      • 在RE阶段,需要标注每个字段的的id与连接信息,如下图所示:
        • 标注过程中,如果value是多个字符,那么linking中可以新增一个key-value对,如[[0, 1], [0, 2]]
        • 数据量方面,一般来说,对于比较固定的场景,50张左右的训练图片即可达到可以接受的效果,可以使用PPOCRLabel完成KIE的标注过程。

      在这里插入图片描述

    • 我们参考案例:https://aistudio.baidu.com/projectdetail/4823162(项目里提供了发票数据集),来对VI-LayoutXLM模型有更深的认识。

2.1 语义实体识别 (SER)

2.1.1 模型构建

  • 我这里不用命令行执行,在paddleocr\tests目录下创建一个py文件执行训练过程

  • 我们复制一份paddleocr\configs\kie\vi_layoutxlm\ser_vi_layoutxlm_xfund_zh_udml.yml文件到paddleocr\tests\configs进行修改(参考上面项目链接进行修改),发票数据集在上面项目中已提供,模型部分的配置文件如下:

    Architecture:model_type: &model_type "kie"name: DistillationModelalgorithm: DistillationModels:Teacher:pretrained:freeze_params: falsereturn_all_feats: truemodel_type: *model_typealgorithm: &algorithm "LayoutXLM"Transform:Backbone:name: LayoutXLMForSerpretrained: True             # 会利用paddle-nlp加载预训练模型# one of base or vimode: vicheckpoints:num_classes: &num_classes 5  # 采用BIO的标注,训练需要修改Student:pretrained:freeze_params: falsereturn_all_feats: truemodel_type: *model_typealgorithm: *algorithmTransform:Backbone:name: LayoutXLMForSerpretrained: True# one of base or vimode: vicheckpoints:num_classes: *num_classes
    
  • 通过下面的py文件,我们就可以愉快的查看源码了。

def train_kie_token_ser_demo():from tools.train import program, set_seed, main# 配置文件的源地址地址: paddleocr\configs\kie\vi_layoutxlm\ser_vi_layoutxlm_xfund_zh_udml.ymlconfig, device, logger, vdl_writer = program.preprocess(is_train=True)###############修改配置(也可在yml文件中修改)################### 评估频率config["Global"]["eval_batch_step"] = [0, 200]# log的打印频率config["Global"]["print_batch_step"] = 50# 训练的epochsconfig["Global"]["epoch_num"] = 1# 随机种子seed = config["Global"]["seed"] if "seed" in config["Global"] else 1024set_seed(seed)###############模型训练##################main(config, device, logger, vdl_writer, seed)def train_kie_token_re_demo():from tools.train import program, set_seed, main# 配置文件的源地址地址: paddleocr\configs\kie\vi_layoutxlm\re_vi_layoutxlm_xfund_zh_udml.ymlconfig, device, logger, vdl_writer = program.preprocess(is_train=True)###############修改配置(也可在yml文件中修改)################### 评估频率config["Global"]["eval_batch_step"] = [0, 200]# log的打印频率config["Global"]["print_batch_step"] = 50# 训练的epochsconfig["Global"]["epoch_num"] = 1# 随机种子seed = config["Global"]["seed"] if "seed" in config["Global"] else 1024set_seed(seed)###############模型训练##################main(config, device, logger, vdl_writer, seed)if __name__ == '__main__':train_kie_token_ser_demo()# train_kie_token_re_demo()

LayoutXLMForTokenClassification

  • 首先,利用LayoutXLMModel提取特征(文本、布局信息)
  • 然后,利用文本部分的特征进行BIO多分类
# paddleocr.ppocr.modeling.backbones.vqa_layoutlm.py
class LayoutXLMForTokenClassification(LayoutXLMPretrainedModel):def __init__(self, config: LayoutXLMConfig):super(LayoutXLMForTokenClassification, self).__init__(config)self.num_classes = config.num_labelsself.layoutxlm = LayoutXLMModel(config)self.dropout = nn.Dropout(config.hidden_dropout_prob)self.classifier = nn.Linear(config.hidden_size, self.num_classes)......def forward(self,input_ids=None,bbox=None,image=None,attention_mask=None,token_type_ids=None,position_ids=None,head_mask=None,labels=None,):# 1、经过12层的Transformer Block Encoderoutputs = self.layoutxlm(input_ids=input_ids,bbox=bbox,image=image,attention_mask=attention_mask,token_type_ids=token_type_ids,position_ids=position_ids,head_mask=head_mask,)seq_length = input_ids.shape[1]# sequence out and image out# 2、进行BIO多分类# sequence_output: (bs, 561, 768) -> (bs, 512, 768) -> (bs, 512, 5)sequence_output = outputs[0][:, :seq_length]sequence_output = self.dropout(sequence_output)logits = self.classifier(sequence_output)hidden_states = {f"hidden_states_{idx}": outputs[2][f"{idx}_data"] for idx in range(self.layoutxlm.config.num_hidden_layers)}if self.training:outputs = (logits, hidden_states)else:outputs = (logits,)......return outputs

LayoutXLMModel

这里我们主要看下LayoutXLMModel模型中,文本的embedding和视觉部分的embedding。

  • 文本的embedding:

    • word_embeddings:对tokenizer后的input_ids进行word_embeddings,shape变化:(bs, 512) -> (bs, 512, 768)

    • position_embeddings(1D position embedding):对文本部分的position_ids进行embeding,shape变化:(bs, 512) -> (bs, 512, 768)。这里,文本和视觉的position_embeddings是共享的。

    • spatial_position_embeddings:这里shape变化为(bs, 512, 4) -> (bs, 512, 768),是将每一个bbox信息的(x_min, y_min, x_max, y_max, h, w)编码,然后concat得到,代码如下所示。注意:如果一个bbox内的文字,被切分为多个token,那么这些token的bbox信息是一致的。

          # paddlenlp.transformers.layoutxlm.modeling.pydef _cal_spatial_position_embeddings(self, bbox):try:# (bs, embdedding_dim) -> (bs, embdedding_dim, 128)left_position_embeddings = self.x_position_embeddings(bbox[:, :, 0])# (bs, embdedding_dim) -> (bs, embdedding_dim, 128)upper_position_embeddings = self.y_position_embeddings(bbox[:, :, 1])# (bs, embdedding_dim) -> (bs, embdedding_dim, 128)right_position_embeddings = self.x_position_embeddings(bbox[:, :, 2])# (bs, embdedding_dim) -> (bs, embdedding_dim, 128)lower_position_embeddings = self.y_position_embeddings(bbox[:, :, 3])except IndexError as e:raise IndexError("The :obj:`bbox`coordinate values should be within 0-1000 range.") from e# (bs, embdedding_dim) -> (bs, embdedding_dim, 128)h_position_embeddings = self.h_position_embeddings(bbox[:, :, 3] - bbox[:, :, 1])# (bs, embdedding_dim) -> (bs, embdedding_dim, 128)w_position_embeddings = self.w_position_embeddings(bbox[:, :, 2] - bbox[:, :, 0])#  [x_min, y_min, x_max, y_max, h, w] concat -> (bs, embdedding_dim, 128*6)spatial_position_embeddings = paddle.concat([left_position_embeddings,upper_position_embeddings,right_position_embeddings,lower_position_embeddings,h_position_embeddings,w_position_embeddings,],axis=-1,)return spatial_position_embeddings
      
    • token_type_embeddings:这里的token_type_ids全为0,shape变化为(bs, 512) -> (bs, 512, 768)

  • 视觉部分的embedding:

    • position_embeddings(1D position embedding):shape变化为(bs, 49) -> (bs, 49, 768)。视觉部分的position ids为:[0, 1, 2, …, 48] -> (bs, 49)。这里虽然去除了视觉提取,但是position ids按照图像224×224经过降采样32倍后的feature map:7×7进行生成。这里,文本和视觉的position_embeddings是共享的;
    • spatial_position_embeddings:视觉部分布局信息,即bbox的生成的核心逻辑是:7×7网格中,每一个小的正方形的坐标(x_min, y_min, x_max, y_max)即为一个视觉token。shape变化为(bs, 49, 4) -> (bs, 49, 768)
    • visual_segment_embedding
  • 最终,将文本的embedding和视觉部分的embedding送入到12层的Transformer Encoder Block提取特征。

# paddlenlp.transformers.layoutxlm.modeling.py
@register_base_model
class LayoutXLMModel(LayoutXLMPretrainedModel):def __init__(self, config: LayoutXLMConfig):super(LayoutXLMModel, self).__init__(config)self.config = configself.use_visual_backbone = config.use_visual_backboneself.has_visual_segment_embedding = config.has_visual_segment_embeddingself.embeddings = LayoutXLMEmbeddings(config)if self.use_visual_backbone is True:self.visual = VisualBackbone(config)self.visual.stop_gradient = Trueself.visual_proj = nn.Linear(config.image_feature_pool_shape[-1], config.hidden_size)if self.has_visual_segment_embedding:self.visual_segment_embedding = self.create_parameter(shape=[config.hidden_size,],dtype=paddle.float32,)self.visual_LayerNorm = nn.LayerNorm(config.hidden_size, epsilon=config.layer_norm_eps)self.visual_dropout = nn.Dropout(config.hidden_dropout_prob)self.encoder = LayoutXLMEncoder(config)self.pooler = LayoutXLMPooler(config)def _calc_visual_bbox(self, image_feature_pool_shape, bbox, visual_shape):"""视觉部分布局信息,即bbox的生成:- image_feature_pool_shape:(7, 7, 256)- 文字token的bbox信息:(bs, 512, 4)- visual_shape:[bs, 49]"""# 首先,生成一个序列[0, 1000, 2000, 3000, 4000, 5000, 6000, 7000]# 然后,离散化为[0, 1000],即[0, 142, 285, 428, 571, 714, 857, 1000]visual_bbox_x = (paddle.arange(0,1000 * (image_feature_pool_shape[1] + 1),1000,dtype=bbox.dtype,)// image_feature_pool_shape[1])visual_bbox_y = (paddle.arange(0,1000 * (image_feature_pool_shape[0] + 1),1000,dtype=bbox.dtype,)// image_feature_pool_shape[0])expand_shape = image_feature_pool_shape[0:2] # (7, 7)# 7×7网格中,每一个小的正方形的坐标(x_min, y_min, x_max, y_max)即为一个视觉token# visual_bbox shape = (7×7, 4)visual_bbox = paddle.stack([visual_bbox_x[:-1].expand(expand_shape),visual_bbox_y[:-1].expand(expand_shape[::-1]).transpose([1, 0]),visual_bbox_x[1:].expand(expand_shape),visual_bbox_y[1:].expand(expand_shape[::-1]).transpose([1, 0]),],axis=-1,).reshape([expand_shape[0] * expand_shape[1], paddle.shape(bbox)[-1]])# 扩展到bs个样本, (7×7, 4) -> (bs, 7×7, 4)visual_bbox = visual_bbox.expand([visual_shape[0], visual_bbox.shape[0], visual_bbox.shape[1]])return visual_bboxdef _calc_text_embeddings(self, input_ids, bbox, position_ids, token_type_ids):"""文本部分进行embeddings:word_embeddings+ position_embeddings(文本和视觉的position_embeddings是共享的)+ spatial_position_embeddings+ token_type_embeddings"""# (bs, 512) -> (bs, 512, 768)words_embeddings = self.embeddings.word_embeddings(input_ids)# (bs, 512) -> (bs, 512, 768)position_embeddings = self.embeddings.position_embeddings(position_ids)# (bs, 512, 4) -> (bs, 512, 768)spatial_position_embeddings = self.embeddings._cal_spatial_position_embeddings(bbox)# (bs, 512) -> (bs, 512, 768)token_type_embeddings = self.embeddings.token_type_embeddings(token_type_ids)# 4种embedding相加embeddings = words_embeddings + position_embeddings + spatial_position_embeddings + token_type_embeddings# LayerNorm + dropoutembeddings = self.embeddings.LayerNorm(embeddings)embeddings = self.embeddings.dropout(embeddings)return embeddingsdef _calc_img_embeddings(self, image, bbox, position_ids):"""视觉部分进行embedding:position_embeddings(文本和视觉的position_embeddings是共享的)+   spatial_position_embeddings+   visual_segment_embedding"""use_image_info = self.use_visual_backbone and image is not None# (bs, 49) -> (bs, 49, 768)position_embeddings = self.embeddings.position_embeddings(position_ids)# (bs, 49, 4) -> (bs, 49, 768)spatial_position_embeddings = self.embeddings._cal_spatial_position_embeddings(bbox)if use_image_info is True:visual_embeddings = self.visual_proj(self.visual(image.astype(paddle.float32)))embeddings = visual_embeddings + position_embeddings + spatial_position_embeddingselse:# embedding相加embeddings = position_embeddings + spatial_position_embeddingsif self.has_visual_segment_embedding:# self.visual_segment_embedding shape = (768)embeddings += self.visual_segment_embedding#  visual_LayerNorm + visual_dropoutembeddings = self.visual_LayerNorm(embeddings)embeddings = self.visual_dropout(embeddings)return embeddingsdef forward(self,input_ids=None,bbox=None,image=None,token_type_ids=None,position_ids=None,attention_mask=None,head_mask=None,output_hidden_states=False,output_attentions=False,):input_shape = paddle.shape(input_ids)visual_shape = list(input_shape)visual_shape[1] = self.config.image_feature_pool_shape[0] * self.config.image_feature_pool_shape[1]# 视觉部分的bbox的生成# 视觉token被视为均匀划分的网格# 生成的bbox信息:feature_map(7×7)网格中,每一个小的正方形的坐标(x_min, y_min, x_max, y_max)即为一个视觉tokenvisual_bbox = self._calc_visual_bbox(self.config.image_feature_pool_shape, bbox, visual_shape)# 1、2D position embedding(文本部分bbox+视觉部分bbox)# (bs, 512, 4) + (bs, 49, 4) -> (bs, 561, 4)final_bbox = paddle.concat([bbox, visual_bbox], axis=1)if attention_mask is None:attention_mask = paddle.ones(input_shape)if self.use_visual_backbone is True:# 使用视觉部分的backbonevisual_attention_mask = paddle.ones(visual_shape)else:# 移除视觉特征提取模块,mask全设置为0visual_attention_mask = paddle.zeros(visual_shape)attention_mask = attention_mask.astype(visual_attention_mask.dtype)# concat后attention_mask:(bs, 512) + (bs, 49) -> (bs, 561)final_attention_mask = paddle.concat([attention_mask, visual_attention_mask], axis=1)if token_type_ids is None:token_type_ids = paddle.zeros(input_shape, dtype=paddle.int64)# 2、1D position embedding(文本部分+视觉部分) (bs, 512) + (bs, 49) -> (bs, 561)if position_ids is None:# 文本部分的position embeddingseq_length = input_shape[1]position_ids = self.embeddings.position_ids[:, :seq_length]position_ids = position_ids.expand(input_shape)# 视觉部分的position embedding# [0, 1, 2, ..., 48] -> (bs, 49)visual_position_ids = paddle.arange(0, visual_shape[1]).expand([input_shape[0], visual_shape[1]])final_position_ids = paddle.concat([position_ids, visual_position_ids], axis=1)if bbox is None:bbox = paddle.zeros(input_shape + [4])# 3、 text embedding & visual  (bs, 512, 768) + (bs, 49, 768) -> (bs, 561, 768)# 文本部分进行embdedding (bs, 512) -> (bs, 512, 768)text_layout_emb = self._calc_text_embeddings(input_ids=input_ids,bbox=bbox,token_type_ids=token_type_ids,position_ids=position_ids,)# 视觉部分进行embedding(注意此时没有image,仅有视觉的bbox以及position_ids)visual_emb = self._calc_img_embeddings(image=image,bbox=visual_bbox,position_ids=visual_position_ids,)final_emb = paddle.concat([text_layout_emb, visual_emb], axis=1)# (bs, 561) -> (bs, 1, 1, 561)extended_attention_mask = final_attention_mask.unsqueeze(1).unsqueeze(2)extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0if head_mask is not None:if head_mask.dim() == 1:head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)elif head_mask.dim() == 2:head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)else:head_mask = [None] * self.config.num_hidden_layers# 经过Transformer Encoder Block(12层)encoder_outputs = self.encoder(final_emb,                        # 文本&视觉部分的embedding , shape=(bs, 561, 768)extended_attention_mask,          # attention_mask        , shape=(bs, 1, 1, 561)bbox=final_bbox,                  # 2D position embedding【如果需要相对位置位置编码,加在attention_score上,这里为False】, shape=(bs, 561, 4)position_ids=final_position_ids,  # 1D position embedding【如果需要相对位置位置编码,加在attention_score上,这里为False】, shape=(bs, 561)head_mask=head_mask,output_attentions=output_attentions,output_hidden_states=output_hidden_states,)# sequence_output shape = (bs, 561, 768)sequence_output = encoder_outputs[0]# pooled_output shape = (bs, 768)pooled_output = self.pooler(sequence_output)return sequence_output, pooled_output, encoder_outputs[1]

2.1.2 损失计算

  • 由于使用了UDML:联合互学习知识蒸馏策略,损失计算的配置如下:

在这里插入图片描述

Loss:name: CombinedLoss                     # ppocr.losses.combined_loss.CombinedLossloss_config_list:- DistillationVQASerTokenLayoutLMLoss: # GT loss   ppocr.losses.distillation_loss.DistillationVQASerTokenLayoutLMLossweight: 1.0model_name_list: ["Student", "Teacher"]key: backbone_outnum_classes: *num_classes- DistillationSERDMLLoss:              # DML loss  ppocr.losses.distillation_loss.DistillationSERDMLLossweight: 1.0act: "softmax"use_log: truemodel_name_pairs:- ["Student", "Teacher"]key: backbone_out- DistillationVQADistanceLoss:         # S5 loss  ppocr.losses.distillation_loss.DistillationVQADistanceLossweight: 0.5mode: "l2"model_name_pairs:- ["Student", "Teacher"]key: hidden_states_5name: "loss_5"- DistillationVQADistanceLoss:         # S8 loss  ppocr.losses.distillation_loss.DistillationVQADistanceLossweight: 0.5mode: "l2"model_name_pairs:- ["Student", "Teacher"]key: hidden_states_8name: "loss_8"
  • 如下所示,在DistillationModel中,Teacher和Student模型分别进行前向过程
# paddleocr.ppocr.modeling.architectures.distillation_model.py
class DistillationModel(nn.Layer):def __init__(self, config):"""the module for OCR distillation.args:config (dict): the super parameters for module."""super().__init__()self.model_list = []self.model_name_list = []for key in config["Models"]:model_config = config["Models"][key]freeze_params = Falsepretrained = Noneif "freeze_params" in model_config:freeze_params = model_config.pop("freeze_params")if "pretrained" in model_config:pretrained = model_config.pop("pretrained")model = BaseModel(model_config)if pretrained is not None:load_pretrained_params(model, pretrained)if freeze_params:for param in model.parameters():param.trainable = Falseself.model_list.append(self.add_sublayer(key, model))self.model_name_list.append(key)def forward(self, x, data=None):result_dict = dict()# 执行所有模型的前向过程, 例如:Teacher和Student模型for idx, model_name in enumerate(self.model_name_list):result_dict[model_name] = self.model_list[idx](x, data)return result_dict
  • 在CombinedLoss中遍历配置的损失函数,分别计算损失,最后相加最为总损失
# paddleocr.ppocr.losses.combined_loss.py
class CombinedLoss(nn.Layer):"""CombinedLoss:a combionation of loss function"""def __init__(self, loss_config_list=None):super().__init__()self.loss_func = []self.loss_weight = []assert isinstance(loss_config_list, list), "operator config should be a list"......def forward(self, input, batch, **kargs):# input包含Teacher模型以及Student模型的输出结果# batch是批次数据,里面包含labelloss_dict = {}loss_all = 0.0# 遍历配置的所有的损失函数,计算损失for idx, loss_func in enumerate(self.loss_func):loss = loss_func(input, batch, **kargs)if isinstance(loss, paddle.Tensor):loss = {"loss_{}_{}".format(str(loss), idx): loss}weight = self.loss_weight[idx]loss = {key: loss[key] * weight for key in loss}if "loss" in loss:loss_all += loss["loss"]else:loss_all += paddle.add_n(list(loss.values()))loss_dict.update(loss)loss_dict["loss"] = loss_allreturn loss_dict
  • 我们看下具体配置的损失函数:

    • DistillationVQASerTokenLayoutLMLoss的实质就是每个模型分别计算NER任务的CrossEntropyLoss,即GT loss:

      class DistillationVQASerTokenLayoutLMLoss(VQASerTokenLayoutLMLoss):def __init__(self, num_classes, model_name_list=[], key=None, name="loss_ser"):super().__init__(num_classes=num_classes)self.model_name_list = model_name_listself.key = keyself.name = namedef forward(self, predicts, batch):loss_dict = dict()# 遍历Teacher模型、Student模型for idx, model_name in enumerate(self.model_name_list):# 先从predicts取出相关模型的预测结果out = predicts[model_name]# 然后,从out中取出key(即配置文件中配置的backbone_out)的值if self.key is not None:out = out[self.key]# 调用父类,计算损失loss = super().forward(out, batch)loss_dict["{}_{}".format(self.name, model_name)] = loss["loss"]return loss_dict# DistillationVQASerTokenLayoutLMLoss的父类
      class VQASerTokenLayoutLMLoss(nn.Layer):def __init__(self, num_classes, key=None):super().__init__()self.loss_class = nn.CrossEntropyLoss()self.num_classes = num_classesself.ignore_index = self.loss_class.ignore_indexself.key = keydef forward(self, predicts, batch):if isinstance(predicts, dict) and self.key is not None:predicts = predicts[self.key]labels = batch[5]           # (bs, 512)attention_mask = batch[2]   # (bs, 512)if attention_mask is not None:active_loss = (attention_mask.reshape([-1,])== 1)# active_output_shape = (bs, 512, 5) -> (bs*512, 5)active_output = predicts.reshape([-1, self.num_classes])[active_loss]# active_label_shape = bs*512active_label = labels.reshape([-1,])[active_loss]# 交叉熵损失函数loss = self.loss_class(active_output, active_label)else:loss = self.loss_class(predicts.reshape([-1, self.num_classes]),labels.reshape([-1,]),)return {"loss": loss}
      
    • DistillationSERDMLLoss实质是计算Techaer和Student模型之间的互蒸馏损失函数,即KL散度。

      class DistillationSERDMLLoss(DMLLoss):""" """def __init__(self,act="softmax",use_log=True,num_classes=7,model_name_pairs=[],key=None,name="loss_dml_ser",):super().__init__(act=act, use_log=use_log)assert isinstance(model_name_pairs, list)self.key = keyself.name = nameself.num_classes = num_classesself.model_name_pairs = model_name_pairsdef forward(self, predicts, batch):loss_dict = dict()# 遍历Teacher模型、Student模型for idx, pair in enumerate(self.model_name_pairs):# 取出Teacher模型以及Student模型中的结果out1 = predicts[pair[0]]out2 = predicts[pair[1]]if self.key is not None:# 取出backbone_outout1 = out1[self.key]out2 = out2[self.key]out1 = out1.reshape([-1, out1.shape[-1]])out2 = out2.reshape([-1, out2.shape[-1]])attention_mask = batch[2]if attention_mask is not None:active_output = (attention_mask.reshape([-1,])== 1)out1 = out1[active_output]out2 = out2[active_output]# 调用父类的方法loss_dict["{}_{}".format(self.name, idx)] = super().forward(out1, out2)return loss_dict# DistillationSERDMLLoss的父类   
      class DMLLoss(nn.Layer):"""DMLLoss"""def __init__(self, act=None, use_log=False):super().__init__()if act is not None:assert act in ["softmax", "sigmoid"]if act == "softmax":self.act = nn.Softmax(axis=-1)elif act == "sigmoid":self.act = nn.Sigmoid()else:self.act = Noneself.use_log = use_logself.jskl_loss = KLJSLoss(mode="kl")def _kldiv(self, x, target):"""计算两个概率分布之间的KL散度:KL散度的公式是 KL(P||Q) = ΣP(x) * log(P(x)/Q(x)),这里将其重写为ΣP(x)*(log(P(x)) - log(Q(x)))即target * (paddle.log(target + eps) - x)"""eps = 1.0e-10loss = target * (paddle.log(target + eps) - x)# batch mean lossloss = paddle.sum(loss) / loss.shape[0]return lossdef forward(self, out1, out2):if self.act is not None:out1 = self.act(out1) + 1e-10out2 = self.act(out2) + 1e-10if self.use_log:# 计算KL散度# for recognition distillation, log is needed for feature maplog_out1 = paddle.log(out1)log_out2 = paddle.log(out2)loss = (self._kldiv(log_out1, out2) + self._kldiv(log_out2, out1)) / 2.0else:# for detection distillation log is not neededloss = self.jskl_loss(out1, out2)return loss    
      
    • DistillationVQADistanceLoss,本质是对 transformer 中间层的输出结果计算距离损失函数(L2 loss)

      # DistillationVQADistanceLoss的父类
      class DistanceLoss(nn.Layer):"""DistanceLoss:mode: loss mode"""def __init__(self, mode="l2", **kargs):super().__init__()assert mode in ["l1", "l2", "smooth_l1"]if mode == "l1":self.loss_func = nn.L1Loss(**kargs)elif mode == "l2":self.loss_func = nn.MSELoss(**kargs)elif mode == "smooth_l1":self.loss_func = nn.SmoothL1Loss(**kargs)def forward(self, x, y):return self.loss_func(x, y)
      

      其他部分,诸如数据集的加载、构建优化器、创建评估函数、加载预训练模型、模型训练等,大家可以查看源码,不再赘述。

2.2 关系抽取(RE)

  • 我们这里,看下模型的构建部分代码,其他代码,大家可以查看源码,不再赘述。
# paddlenlp.transformers.layoutxlm.modeling.py
class LayoutXLMForRelationExtraction(LayoutXLMPretrainedModel):def __init__(self, config: LayoutXLMConfig):super(LayoutXLMForRelationExtraction, self).__init__(config)self.layoutxlm = LayoutXLMModel(config)self.extractor = REDecoder(config.hidden_size, config.hidden_dropout_prob)self.dropout = nn.Dropout(config.hidden_dropout_prob)......def forward(self,input_ids,bbox,image=None,attention_mask=None,entities=None,relations=None,token_type_ids=None,position_ids=None,head_mask=None,labels=None,):# 1、经过12层的Transformer Block Encoderoutputs = self.layoutxlm(input_ids=input_ids,            # (bs, 512)bbox=bbox,                      # (bs, 512, 4)image=image,                    # Noneattention_mask=attention_mask,  # (bs, 512)token_type_ids=token_type_ids,  # (bs. 512)position_ids=position_ids,      # Nonehead_mask=head_mask,            # None)seq_length = input_ids.shape[1]# 最后一层输出# sequence_output_shape = (bs, 512, 768)sequence_output = outputs[0][:, :seq_length]sequence_output = self.dropout(sequence_output)# 2、计算loss和预测关系loss, pred_relations = self.extractor(sequence_output, entities, relations)hidden_states = [outputs[2][f"{idx}_data"] for idx in range(self.layoutxlm.config.num_hidden_layers)]hidden_states = paddle.stack(hidden_states, axis=1)# 3、返回结果res = dict(loss=loss, pred_relations=pred_relations, hidden_states=hidden_states)return res
  • 主要代码在REDecoder中

    • 首先,构建构建关系对的正负样本
    • 然后,获取关系头(question)、关系尾(answer)对应的特征信息
      • 获取关系头(即question)在input_ids中开始的索引对应token的hidden_states(shape=(100, 768))和关系头(question)经过Embedding后的结果(shape=(100, 768))进行concat
      • 获取关系尾(即answer)在input_ids中开始的索引对应token的hidden_states(shape=(100, 768))和关系尾(answer)经过Embedding后的结果(shape=(100, 768))进行concat
    • 利用提取到的head_repr、tail_repr特征信息进行关系分类
    • 最后,利用预测结果,计算交叉熵损失等
    • 下面,给出一个relations和entities示例,方便理解。

    在这里插入图片描述

class REDecoder(nn.Layer):def __init__(self, hidden_size=768, hidden_dropout_prob=0.1):super(REDecoder, self).__init__()self.entity_emb = nn.Embedding(3, hidden_size)# 100代表:100个关系对# (100, 1536) -> (100, 768) -> (100, 384)projection = nn.Sequential(nn.Linear(hidden_size * 2, hidden_size),nn.ReLU(),nn.Dropout(hidden_dropout_prob),nn.Linear(hidden_size, hidden_size // 2),nn.ReLU(),nn.Dropout(hidden_dropout_prob),)self.ffnn_head = copy.deepcopy(projection)self.ffnn_tail = copy.deepcopy(projection)# (100, 384) -> (100, 2)self.rel_classifier = BiaffineAttention(hidden_size // 2, 2)self.loss_fct = CrossEntropyLoss()def build_relation(self, relations, entities):"""relations_shape = (bs, 262145, 2)entities_shape  = (bs, 513, 3)注:relations第1个数组代表实际长度,例如:[10, 10],代表:关系对(QUESTION-ANSWER)只有10个,其他为填充entities第1个数组代表实际长度,例如:[20, 20, 20],代表:实例(QUESTION或ANSWER)只有20个,其他为填充"""batch_size, max_seq_len = paddle.shape(entities)[:2]# new_relations_shape = (bs, 513*513, 3), 初始化为-1new_relations = paddle.full(shape=[batch_size, max_seq_len * max_seq_len, 3], fill_value=-1, dtype=relations.dtype)for b in range(batch_size):if entities[b, 0, 0] <= 2:entitie_new = paddle.full(shape=[512, 3], fill_value=-1, dtype=entities.dtype)entitie_new[0, :] = 2entitie_new[1:3, 0] = 0  # startentitie_new[1:3, 1] = 1  # endentitie_new[1:3, 2] = 0  # labelentities[b] = entitie_new# 实体label_shape为: [2, 1, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 1, 1, 2, 2, 2, 2, 1]# all_possible_relations1为: [1 , 2 , 4 , 6 , 8 , 10, 12, 13, 14, 19]  QUESTION# all_possible_relations2为: [0 , 3 , 5 , 7 , 9 , 11, 15, 16, 17, 18]  ANSWERentitie_label = entities[b, 1 : entities[b, 0, 2] + 1, 2]all_possible_relations1 = paddle.arange(0, entities[b, 0, 2], dtype=entities.dtype)all_possible_relations1 = all_possible_relations1[entitie_label == 1]all_possible_relations2 = paddle.arange(0, entities[b, 0, 2], dtype=entities.dtype)all_possible_relations2 = all_possible_relations2[entitie_label == 2]# 所有可能的关系:all_possible_relations_shape:(100, 2)# [#   [1, 0],  [1, 3], ... , [1, 18],#   [2, 0],  [2, 3], ... , [2, 18],#           ......#   [19, 0], [19, 3], ... , [19, 18]# ]all_possible_relations = paddle.stack(paddle.meshgrid(all_possible_relations1, all_possible_relations2), axis=2).reshape([-1, 2])if len(all_possible_relations) == 0:all_possible_relations = paddle.full_like(all_possible_relations, fill_value=-1, dtype=entities.dtype)all_possible_relations[0, 0] = 0all_possible_relations[0, 1] = 1# relation_head: [1 , 2 , 4 , 6 , 8 , 10, 12, 13, 14, 19]# relation_tail: [0 , 3 , 5 , 7 , 9 , 11, 17, 15, 16, 18]relation_head = relations[b, 1 : relations[b, 0, 0] + 1, 0]relation_tail = relations[b, 1 : relations[b, 0, 1] + 1, 1]# positive_relations_shape: (10, 2)positive_relations = paddle.stack([relation_head, relation_tail], axis=1)# (100, 2) -> (100, 10, 2)all_possible_relations_repeat = all_possible_relations.unsqueeze(axis=1).tile([1, len(positive_relations), 1])# (100, 2) -> (100, 10, 2)positive_relations_repeat = positive_relations.unsqueeze(axis=0).tile([len(all_possible_relations), 1, 1])# mask shape = (100, 10)mask = paddle.all(all_possible_relations_repeat == positive_relations_repeat, axis=2)# 获取关系对负样本# negative_mask = paddle.any(mask, axis=1) is Falsenegative_mask = ~paddle.any(mask, axis=1)negative_relations = all_possible_relations[negative_mask]# 获取关系对正样本# positive_mask = paddle.any(mask, axis=0) is Truepositive_mask = paddle.any(mask, axis=0)positive_relations = positive_relations[positive_mask]if negative_mask.sum() > 0:# positive_relations_shape = (10, 2)# negative_relations_shape = (90, 2)# reordered_relations_shape = (100, 2)reordered_relations = paddle.concat([positive_relations, negative_relations])else:reordered_relations = positive_relationsrelation_per_doc_label = paddle.zeros([len(reordered_relations), 1], dtype=reordered_relations.dtype)relation_per_doc_label[: len(positive_relations)] = 1# relation_per_doc shape: (100, 3)"""relation_per_doc = [[1 , 0 , 1 ],# 正样本[2 , 3 , 1 ],[4 , 5 , 1 ],......[19, 18, 1 ],[1 , 3 , 0 ],# 负样本[1 , 5 , 0 ],......]"""relation_per_doc = paddle.concat([reordered_relations, relation_per_doc_label], axis=1)assert len(relation_per_doc[:, 0]) != 0# 第1个元素记录正负样本的长度信息,例如:[100, 100, 100]new_relations[b, 0] = paddle.shape(relation_per_doc)[0].astype(new_relations.dtype)# 将正负样本放到new_relations中new_relations[b, 1 : len(relation_per_doc) + 1] = relation_per_doc# new_relations.append(relation_per_doc)return new_relations, entitiesdef get_predicted_relations(self, logits, relations, entities):"""logits: 预测得到的关系概率, 例如:shape = (100, 2)relations: shape = (100, 3)entities:  shape = (513, 3)"""pred_relations = []for i, pred_label in enumerate(logits.argmax(-1)):if pred_label != 1:continuerel = paddle.full(shape=[7, 2], fill_value=-1, dtype=relations.dtype)rel[0, 0] = relations[:, 0][i]rel[1, 0] = entities[:, 0][relations[:, 0][i] + 1]rel[1, 1] = entities[:, 1][relations[:, 0][i] + 1]rel[2, 0] = entities[:, 2][relations[:, 0][i] + 1]rel[3, 0] = relations[:, 1][i]rel[4, 0] = entities[:, 0][relations[:, 1][i] + 1]rel[4, 1] = entities[:, 1][relations[:, 1][i] + 1]rel[5, 0] = entities[:, 2][relations[:, 1][i] + 1]rel[6, 0] = 1pred_relations.append(rel)return pred_relationsdef forward(self, hidden_states, entities, relations):"""hidden_states_shape:(bs, 512, 768)entities_shape: (bs, 513, 3)    , 其中:513 = 512 + 1,第一个元素记录长度信息relations_shape: (bs, 262145, 2),其中:262145 = 512*512 + 1,第一个元素记录长度信息"""batch_size, max_length, _ = paddle.shape(entities)# 1、构建关系的正负样本# relations_shape: (bs, 263169, 3) , 其中: 263169 = 513 * 513# entities_shape: (bs, 513, 3)relations, entities = self.build_relation(relations, entities)loss = 0# 所有预测关系结果all_pred_relations = paddle.full(shape=[batch_size, max_length * max_length, 7, 2], fill_value=-1, dtype=entities.dtype)for b in range(batch_size):# 2、获取关系头(question)、关系尾(answer)对应的特征信息# 取出正负样本关系对, relation_shape = (100, 3)relation = relations[b, 1 : relations[b, 0, 0] + 1]# 获取关系头(question)、关系尾(answer)、以及关系标签(1表示question和answer是一对,即正样本, 0表示负样本)head_entities = relation[:, 0]tail_entities = relation[:, 1]relation_labels = relation[:, 2]# 每一个实体(question或answer)在input_ids中开始的索引# 例如:  [0  , 3  , 4  , 8  , 14 , 16 , 23 , 29 , 34 , 37 , 60 , 65 , 82 , 84 ,#         87 , 90 , 91 , 96 , 102, 106]entities_start_index = paddle.to_tensor(entities[b, 1 : entities[b, 0, 0] + 1, 0])# 获取每个实体类型编号,1表示question,2表示answer# 例如:[2, 1, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 1, 1, 2, 2, 2, 2, 1]entities_labels = paddle.to_tensor(entities[b, 1 : entities[b, 0, 2] + 1, 2])# 获取关系头(即question)在input_ids中开始的索引,为了后面获取对应token的hidden_stateshead_index = entities_start_index[head_entities]# 获取关系头(question)对应的实体类型编号head_label = entities_labels[head_entities]# 关系头(question)经过Embedding, head_label_repr_shape = (100, 768)head_label_repr = self.entity_emb(head_label)# 获取关系尾(即answer)在input_ids中开始的索引,为了后面获取对应token的hidden_statestail_index = entities_start_index[tail_entities]# 获取关系尾(answer)对应的实体类型编号tail_label = entities_labels[tail_entities]# 关系尾(answer)经过Embedding, tail_label_repr_shape = (100, 768)tail_label_repr = self.entity_emb(tail_label)# 获取关系头(question)开始token的hidden_states, tmp_hidden_states shape: (100, 768)tmp_hidden_states = hidden_states[b][head_index]if len(tmp_hidden_states.shape) == 1:tmp_hidden_states = paddle.unsqueeze(tmp_hidden_states, axis=0)#  concat, head_repr_shape = (100, 1536)head_repr = paddle.concat((tmp_hidden_states, head_label_repr), axis=-1)# 获取关系尾(answer)开始token的hidden_states, tmp_hidden_states shape: (100, 768)tmp_hidden_states = hidden_states[b][tail_index]if len(tmp_hidden_states.shape) == 1:tmp_hidden_states = paddle.unsqueeze(tmp_hidden_states, axis=0)#  concat, tail_repr_shape = (100, 1536)tail_repr = paddle.concat((tmp_hidden_states, tail_label_repr), axis=-1)# 3、利用提取到的head_repr、tail_repr进行关系分类# heads_shape = (100, 1536) -> (100, 384)# tails_shape = (100, 1536) -> (100, 384)heads = self.ffnn_head(head_repr)tails = self.ffnn_tail(tail_repr)# 结合双线性层和线性层,实现对两个输入向量的复杂交互建模# logits_shape = (100, 2)logits = self.rel_classifier(heads, tails)# 4、计算交叉熵损失loss += self.loss_fct(logits, relation_labels)pred_relations = self.get_predicted_relations(logits, relation, entities[b])if len(pred_relations) > 0:pred_relations = paddle.stack(pred_relations)all_pred_relations[b, 0, :, :] = paddle.shape(pred_relations)[0].astype(all_pred_relations.dtype)all_pred_relations[b, 1 : len(pred_relations) + 1, :, :] = pred_relationsreturn loss, all_pred_relations
  • 关于模型的预测代码(使用OCR结果进行预测等),可以参考https://aistudio.baidu.com/projectdetail/4823162。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/57301.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

eQEP正交解码

目录 基本介绍 整体框架 关键模块 编译问题 实验效果 基本介绍 编码器是一种将角位移或者角速度转换成一连串电数字脉冲的旋转式传感器&#xff0c;我们可以通过编码器测量到位移或者速度信息。编码器从输出数据类型上分&#xff0c;可以分为增量式编码器和绝对式编码器。…

深入浅出MySQL:概述与体系结构解析

目录 1. 初识MySQL1.1. 数据库1.1.1. OLTP&#xff08;联机事务处理&#xff09;1.1.2. OLAP&#xff08;联机分析处理&#xff09; 2. SQL2.1. 定义2.2. DQL&#xff08;数据查询语言&#xff09;2.3. DML&#xff08;数据操纵语言&#xff09;2.4. DDL&#xff08;数据定义语…

Python基于OpenCV的实时疲劳检测

2.检测方法 1&#xff09;方法 与用于计算眨眼的传统图像处理方法不同&#xff0c;该方法通常涉及以下几种组合&#xff1a; 1、眼睛定位。 2、阈值找到眼睛的白色。 3、确定眼睛的“白色”区域是否消失了一段时间&#xff08;表示眨眼&#xff09;。 相反&#xff0c;眼睛长…

Python网络请求库requests的10个基本用法

大家好&#xff01;今天我们要聊聊Python中非常实用的一个库——requests。这个库让发送HTTP请求变得超级简单。无论你是想抓取网页数据还是测试API接口&#xff0c;requests都能派上大用场。下面我们就一起来看看如何使用requests完成一些常见的任务。 引言 随着互联网技术的…

队列(数据结构)——C语言

目录 1.概念与结构 2.队列的实现 初始化QueueInit 申请新节点BuyNode 入队QueuePush 判断队为空QueueEmpty 出队QueuePop 读取队头数据QueueFront 读取队尾数据QueueBack 元素个数QueueSize 销毁队列QueueDestroy 3.整体代码 (文章中结点和节点是同一个意思) 1.概…

keil兼容C51和ARM,C251

三合一 C51,AEM,C251获取STC32的包 将 C51,AEM,C251安装到一块。 C51,AEM,C251 将三个软件分别下载到不同的文件夹KEIL,MDK,KEIL2里。 然后打开KEIL,MDK,KEIL2文件夹&#xff0c;复制KEIL文件夹里的C51和KEIL2文件夹里的C251的文件夹到MDK文件夹里。 打开KEIL和KEIL2文件夹里…

单链表的经典算法OJ

目录 1.反转链表 2.链表的中间节点 3.移除链表元素 ——————————————————————————————————————————— 正文开始 1.反转链表 typedef struct ListNode ListNode; struct ListNode* reverseList(struct ListNode* head) {//判空if(…

运行kafka查看所有主题Topic报错zookeeper is not a recognized option

执行命令查看&#xff1a;./kafka-topics.sh --list --zookeeper localhost:2181 报错 zookeeper is not a recognized option joptsimple.UnrecognizedOptionException: zookeeper is not a recognized optionat joptsimple.OptionException.unrecognizedOption(OptionExcept…

000010 - Mapreduce框架原理

Mapreduce框架原理 1. InputFormat 数据输入1.1 切片与 MapTask 并行度决定机制1.2 Job 提交流程源码和切片源码详解1.2.1 Job 提交流程源码详解1.2.2 FileInputFormat 切片源码解析&#xff08;input.getSplits(job)&#xff09; 1.3 FileInputFormat 切片机制1.3.1 切片机制1…

二、PyCharm基本设置

PyCharm基本设置 前言一、设置中文汉化二、设置代码字体颜色三、设置鼠标滚轮调整字体大小四、修改 PyCharm 运行内存4.1 方式一4.1 方式二 五、显示 PyCharm 运行时内存六、设置代码模版配置的参数有&#xff1a; 七、PyCharm设置背景图总结 前言 为了让我们的 PyCharm 更好用…

一家射频芯片企业终止,报告期持续亏损,高端产品占比不足

飞骧科技终止原因如下&#xff1a;飞骧科技从事的射频芯片行业如今竞争激烈&#xff0c;飞骧科技的产品主要应用于中低端手机&#xff0c;如摩托罗拉、传音&#xff0c;相比同行业上市公司已经退出的低集成度市场&#xff0c;相关产品展飞骧科技业务比重仍然不低。交易所质疑其…

【Matlab】基于Prandtl−Ishlinskii的迟滞模型-RLS辨识

PI模型 PI迟滞模型的输出公式&#xff1a; 代码记录 此代码为根据PI模型&#xff0c;已知输入&#xff08;正弦函数幅值为3.5&#xff09;、阈值以及权重值&#xff0c;利用matlab生成迟滞回线。 %% The Prandtl-Ishlinskii Hysteresis Model- %% 20241021 clc;clear; close…

数据结构与算法——Java实现 44.翻转二叉树

目录 226. 翻转二叉树 思路 代码 本地代码测试 不管前方的路有多苦 只要走的方向正确 不管多么崎岖不平 都比站在原地更接近幸福 —— 24.10.21 226. 翻转二叉树 给你一棵二叉树的根节点 root &#xff0c;翻转这棵二叉树&#xff0c;并返回其根节点。 示例 1&#xff1a; 输…

Unity AnimationClip详解(2)——动画数据的优化

【内存优化】 首先要意识到运行时和编辑时的区别&#xff0c;当运行时和编辑时所需的数据相差不大时&#xff0c;我们用同一套数据结构即可&#xff0c;当两者差异较多或者数据量很大时&#xff0c;需要有各自的数据结构&#xff0c;这意味着在打包或构建时需要将编辑时数据转…

【Linux探索学习】第七弹——Linux的工具(二):Linux下vim编辑器的使用详解

Linux的工具&#xff08;一&#xff09;&#xff1a;【Linux探索学习】第六弹——Linux的工具&#xff08;一&#xff09;&#xff1a;Ubuntu系统下的软件包管理器_ubuntu软件管理器-CSDN博客 前言&#xff1a; 在学习Linux之前&#xff0c;相信大家都或多或少的学习过一些计算…

微信小程序用开发工具在本地真机调试可以正常访问摄像头,发布了授权后却无法访问摄像头,解决方案

今天开发上线了一个拍照的微信小程序&#xff0c;用uniapp的Vue3开发的&#xff0c;调用的camera组件&#xff0c;相关代码如下&#xff1a; <!-- 微信小程序相机组件 --><view v-if"showCamera" class"camera-container"><camera :device…

Adobe Acrobat DC 打印PDF文件,没有打印出注释的解决方法

adobe acrobat在打印的时候&#xff0c;打印不出来注释内容&#xff08;之前一直可以&#xff0c;突然就不行&#xff09;&#xff0c;升级版本、嵌入字体等等都试过&#xff0c;也在Google找了半天和问了GPT也么找着办法。 无奈之下&#xff0c;自己通过印前检查&#xff0c;…

免费开源AI助手,颠覆你的数字生活体验

Apt Full作为一款开源且完全免费的软件&#xff0c;除了强大的自然语言处理能力&#xff0c;Apt Full还能够对图像和视频进行一系列复杂的AI增强处理&#xff0c;只需简单几步即可实现专业级的效果。 在图像处理方面&#xff0c;Apt Full提供了一套全面的AI工具&#xff0c;包…

Windows环境下Qt Creator调试模式下qDebug输出中文乱码问题

尝试修改系统的区域设置的方法&#xff1a; 可以修复问题。但会出现其它问题&#xff1a; 比如某些软件打不开&#xff0c;或者一些软件界面的中文显示乱码&#xff01; 暂时没有找到其它更好的办法。

《YOLO目标检测》—— YOLO的简单介绍及Map评估指标

文章目录 一、简单概述二、YOLO中的Map指标1.定义与计算2.应用与意义3.注意事项 一、简单概述 YOLO&#xff08;You Only Look Once&#xff09;是一种目标检测算法&#xff0c;由Redmon等人在2016年提出。它的主要特点是速度快且准确性高&#xff0c;非常适合用于实时目标检测…