NLP深入学习:结合源码详解 BERT 模型(三)

文章目录

  • 1. 前言
  • 2. 预训练
    • 2.1 modeling.BertModel
      • 2.1.1 embedding_lookup
      • 2.1.2 embedding_postprocessor
      • 2.1.3 transformer_model
    • 2.2 get_masked_lm_output
    • 2.3 get_next_sentence_output
    • 2.4 训练
  • 3. 参考


1. 前言

前情提要:
《NLP深入学习:结合源码详解 BERT 模型(一)》
《NLP深入学习:结合源码详解 BERT 模型(二)》

之前已经详细说明了 BERT 模型的主要架构和思想,并且讲解了 BERT 源代码对于数据准备的流程,回顾下关键字段的含义:

# 以下是输出到文件的值,也是会作为后续预训练的输入值,重点看!
input_ids:tokens在字典的索引位置,不足max_seq_length(128)则补0
input_mask:初始化为1,不足max_seq_length(128)则补0
segment_ids: 句子A的token和句子B的token,按照0/1排列区分。不足max_seq_length(128)则补0
masked_lm_positions: 被选中 MASK 的token位置索引
masked_lm_ids:被选中 MASK 的token原始值在字典的索引位置
masked_lm_weights:初始化为1
next_sentence_labels:对应is_random_next,1表示随机选择,0表示正常语序

下面我们结合预训练代码详细讲解下 BERT 的预训练流程。

2. 预训练

预训练代码在 run_pretraing.py 文件中,注意我们需要把数据准备的结果作为预训练的输入:
在这里插入图片描述
那我们打上断点,继续开启 debug 吧!
在这里插入图片描述

2.1 modeling.BertModel

看预训练代码,大部分的核心代码集中在 modeling.BertModel 这个 class 的 __init__ 代码中:
在这里插入图片描述
解释下 modeling.BertModel 的参数:

  • config: BERT 的配置文件,后续的很多参数都来源于此。我放到路径 ./multi_cased_L-12_H-768_A-12/bert_config.json ,内容如下:
{"attention_probs_dropout_prob": 0.1, "directionality": "bidi", "hidden_act": "gelu", "hidden_dropout_prob": 0.1, "hidden_size": 768, "initializer_range": 0.02, "intermediate_size": 3072, "max_position_embeddings": 512, "num_attention_heads": 12, "num_hidden_layers": 12, "pooler_fc_size": 768, "pooler_num_attention_heads": 12, "pooler_num_fc_layers": 3, "pooler_size_per_head": 128, "pooler_type": "first_token_transform", "type_vocab_size": 2, "vocab_size": 119547
}
  • is_training:True 表示训练,False 表示评估
  • input_ids:对应于数据准备的字段 input_ids,形状 [batch_size, seq_length],即 [32, 128]
  • input_mask:对应于数据准备的字段 input_mask,形状 [batch_size, seq_length],即 [32, 128]
  • token_type_ids:对应于数据准备的字段 segment_ids,形状 [batch_size, seq_length],即 [32, 128]
  • use_one_hot_embeddings:词嵌入是否用 one_hot 模式
  • scope:变量的scope,用于 tf.variable_scope(scope, default_name="bert") 默认是 bert

2.1.1 embedding_lookup

modeling.BertModel__init__ 代码中,第一个重要的方法是 embedding_lookup
在这里插入图片描述
我们看下具体的代码,返回值有两个:

  • out_put 是根据输入的 input_ids 在字典中找到对应的词,并且返回词对应的 embedding 向量,out_put 的形状是 [batch_size, seq_length, embedding_size]
  • embedding_table 是字典每一个词对应的向量,形状是 [vocab_size, embedding_size]

在这里插入图片描述
ps: 有些同学不清楚字典是什么?字典在项目的 ./multi_cased_L-12_H-768_A-12/vocab.txt 里,每一行对应一个词,里例如id=0则表示字典第一个对应的词[PAD],字典内容如下:

[PAD]
[unused1]
[unused2]
[unused3]
[unused4]
...
[unused99]
[UNK]
[CLS]
[SEP]
[MASK]
<S>
<T>
!
"
#
$
%
...
A
B
C
D
E
F
G
H

2.1.2 embedding_postprocessor

后续的该方法是用于加上位置编码!
在这里插入图片描述
我们进到函数内部查看具体细节:
在这里插入图片描述
上面代码中,token_type_ids 对应的是 segment_ids,即句子的表示(用0/1来表示),细节见《NLP深入学习:结合源码详解 BERT 模型(二)》 的 2.3章节。token_type_table 和上一节的 embedding_table 是一样的含义,这里就是向量化 segment_ids。由于 segment_ids 只用 0和1来表示,所以token_type_vocab_size=2,并且最终将 out_put 加上了 segment_ids 向量化的结果,就是图中的 TokenEmbeddings + SegmentEmbeddings
在这里插入图片描述
那么显而易见,下一段代码就是再加上 PositionEmbeddings 了!
在这里插入图片描述
注意,这里的 position_embeddings 实际就是词在句子中的位置对应的 embedding~

最后将输出加上了 layer_norm_and_dropout ,即层归一和dropout。

2.1.3 transformer_model

顺着代码debug下去,在准备好了数据之后,就是经典的 Transformer 模型了:
在这里插入图片描述
希望深入了解 Transformer 的,建议参考:
《NLP深入学习:大模型背后的Transformer模型究竟是什么?(一)》
《NLP深入学习:大模型背后的Transformer模型究竟是什么?(二)》

我们先回忆下 Transformer 的结构,因为下面的代码完全是对论文的编码器实现:
在这里插入图片描述
为了方便查看,我把代码的结构和论文的结构对比在一起:
在这里插入图片描述
transformer 结构构建完成之后,下面的self.sequence_out 是把最后一层的输出作为 transformer 的 encoding 结果输出。
在这里插入图片描述
此外,first_token_tensor 是取第一个 token 的输出结果,即 [CLS] 的结果。因为 [CLS] 已经带有上下文信息了,因此对于分类而言,用 [CLS] 的输出即可。这个论文中也有说明:
在这里插入图片描述
以上就是 BERT 模型的构建整体流程,下面来看 BERT 模型的评估流程,包含 Masked Language Model(MLM)和 Next Sentence Prediction(NSP)。

2.2 get_masked_lm_output

先来看 Masked Language Model(MLM)的评估,对应代码中的 get_masked_lm_out ,见下图:

首先看下 get_masked_lm_out 的输入参数:

  • bert_config : BERT 的配置文件,对应我的路径 ./multi_cased_L-12_H-768_A-12/bert_config.json
  • input_tensor:BERT 模型的输出,即上文的 self.sequence_out
  • output_weights:对应上文 embedding_lookup 的第二个输出,即字典每一个词对应的向量,形状是 [vocab_size, embedding_size]
  • positions:对应 features["masked_lm_positions"] ,即被选中 MASK 的 token 位置索引
  • label_ids:对应 features["masked_lm_ids"],即被选中 MASK 的 token 原始值在字典的索引位置
  • label_weights:对应 features["masked_lm_weights"]

下面是整体的代码,代码有些地方需要细细品味:

在这里插入图片描述
要看懂这里的代码,首先我们要知道 BERT 在 Masked Language Model(MLM)上要干啥。BERT 首先给句子的词打上了 [MASK] ,后续就要对 [MASK] 的词进行预测。预测,就是在词典中出现的词给出一个概率,看属于哪个词,本质上就是多分类问题。那么对于多分类问题,通常的做法是计算交叉熵。

这里就不详细阐述交叉熵的来龙去脉了,直接说明交叉熵如何计算。我们假设真实分布为 y,而模型输出分布为 y ^ \widehat{y} y ,总的类别数为 n,交叉熵损失函数的计算方法为:
l o s s = ∑ i = 1 n [ − y l o g y ^ i − ( 1 − y ) l o g ( 1 − y ^ i ) ] loss = \sum_{i=1}^{n}[-ylog\widehat{y}_i-(1-y)log(1-\widehat{y}_i)] loss=i=1n[ylogy i(1y)log(1y i)]
好,我们来看代码中关键的几个步骤:

  • log_probs = tf.nn.log_softmax(logits, axis=-1) ,这个方法实际上计算的是:
    l o g _ p r o b s = [ l o g y ^ 1 , l o g y ^ 2 , . . . , l o g y ^ n ] log\_probs = [log\widehat{y}_1, log\widehat{y}_2,...,log\widehat{y}_n] log_probs=[logy 1,logy 2,...,logy n]
    其中 l o g y ^ i log\widehat{y}_i logy i 表达的是属于词典第 i 个词的概率的对数值。

  • one_hot_labels = tf.one_hot(label_ids, depth=bert_config.vocab_size, dtype=tf.float32),计算每个词的在字典的 one_hot 结果,形状是 [batch_size*seq_len, vocab_size]。例如,“animal” 在字典第18883位置,那么"animal"对应的 one_hot 就是 [0,0,…0,1,0,…,0],其中向量长度就是字典的大小,1排在向量的18883个。

  • per_example_loss = -tf.reduce_sum(log_probs * one_hot_labels, axis=[-1]) ,这个方法是用于交叉熵的。因为我们知道真实的分布情况,就是 one_hot_labels 对应的结果,那么对于某一个具体的词,其交叉熵的计算就是 − y l o g y ^ i − ( 1 − y ) l o g ( 1 − y ^ i ) -ylog\widehat{y}_i-(1-y)log(1-\widehat{y}_i) ylogy i(1y)log(1y i),将 y=1(即事先知道一定属于某个词)代入,即交叉熵为 − l o g y ^ i -log\widehat{y}_i logy i。所以事先计算了 log_probsper_example_loss 可以直接得到每个词的交叉熵的结果。

  • lossper_example_loss 得到的结果赋予权重进行加权平均,得到一个最终的 loss,实际上就相当于 l o s s = ∑ i = 1 n w i [ − y l o g y ^ i − ( 1 − y ) l o g ( 1 − y ^ i ) ] loss = \sum_{i=1}^{n}w_i[-ylog\widehat{y}_i-(1-y)log(1-\widehat{y}_i)] loss=i=1nwi[ylogy i(1y)log(1y i)]

2.3 get_next_sentence_output

再来看 Next Sentence Prediction(NSP)评估,预测句子的下一句:
在这里插入图片描述
首先看下 get_next_sentence_output 的输入参数:

  • bert_config: BERT 的配置文件,对应我的路径 ./multi_cased_L-12_H-768_A-12/bert_config.json
  • input_tensor[CLS] 的输出线性变换后的结果,简单理解为 [CLS] 的输出作为当前函数的输入
  • labels:对应 features["next_sentence_labels"] ,1表示下一个句子是随机选择的,0表示正常语序

由于下一个句子只有两种选择,要么是随机的,要么是原先正常的句子,所以其实就是一个二分类问题:
在这里插入图片描述
二分类的交叉熵:
l o s s = ∑ i = 1 n − y l o g y ^ i loss = \sum_{i=1}^{n}-ylog\widehat{y}_i loss=i=1nylogy i
上面的核心逻辑跟 get_masked_lm_output 一模一样。只不过这里的 loss 用的是平均值,没有用加权平均

2.4 训练

计算了 masked_lm_loss 以及 next_sentence_loss 之后,将两种 loss 相加,即是总的 loss
在这里插入图片描述
后续就训练模型降低 loss

3. 参考

《NLP深入学习:结合源码详解 BERT 模型(一)》
《NLP深入学习:结合源码详解 BERT 模型(二)》
《BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding》
《NLP深入学习:大模型背后的Transformer模型究竟是什么?(一)》
《NLP深入学习:大模型背后的Transformer模型究竟是什么?(二)》

欢迎关注本人,我是喜欢搞事的程序猿; 一起进步,一起学习;

欢迎关注知乎:SmallerFL;

也欢迎关注我的wx公众号:一个比特定乾坤

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/776383.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

PyQt5开发——QCheckBox 复选框用法与代码示例

1. 复选框 QCheckBox 是 Qt 框架中的一个控件&#xff0c;用于在界面中表示一个可以被选中或取消选中的复选框。它通常用于允许用户在多个选项之间进行选择。在 Python 中使用 PyQt 或 PySide 开发 GUI 应用程序时&#xff0c;可以使用 QCheckBox 控件来实现复选框。 2.基本用…

[ Linux ] git工具的基本使用(仓库的构建,提交)

1.安装git yum install -y git 2.打开Gitee&#xff0c;创建你的远程仓库&#xff0c;根据提示初始化本地仓库&#xff08;这里以我的仓库为例&#xff09; 新建好仓库之后跟着网页的提示初始化便可以了 3.add、commit、push三板斧 git add . //add仓库新增&#xff08;变…

企业数字化转型:聊聊数据思维!

笔者曾在《深入聊一聊企业数字化转型这个事儿》 一文中给出了数字化转型的定义&#xff0c;即&#xff1a;通过应用数字化技术来重塑企业的信息化环境和业务过程。本质上来讲&#xff0c;企业数字化转型&#xff0c;不仅是技术方面的升级&#xff0c;更是企业文化、思维方式的转…

【计算机考研】408到底有多难?

你真以为大家是学不会408吗&#xff1f; 不是&#xff01;单纯是因为时间不够&#xff01;&#xff01;&#xff01; 再准确一些就是不会分配时间 408的知识其实并不难&#xff0c;要说想上130那确实有难度&#xff0c;但是100在时间充裕的情况下还是可以做到的 我本人是双…

非wpf应用程序项目【类库、用户控件库】中使用HandyControl

文章速览 前言参考文章实现方法1、添加HandyControl包;2、添加资源字典3、修改资源字典内容坚持记录实属不易,希望友善多金的码友能够随手点一个赞。 共同创建氛围更加良好的开发者社区! 谢谢~ 前言 wpf应用程序中,在入口项目中存在App.xaml文件,在这个文件中加上对各个…

Linux之进程控制进程终止进程等待进程的程序替换替换函数实现简易shell

文章目录 一、进程创建1.1 fork的使用 二、进程终止2.1 终止是在做什么&#xff1f;2.2 终止的3种情况&&退出码的理解2.3 进程常见退出方法 三、进程等待3.1 为什么要进行进程等待&#xff1f;3.2 取子进程退出信息status3.3 宏WIFEXITED和WEXITSTATUS&#xff08;获取…

全球首位AI程序员Devin诞生,以此谈谈AI对程序员的影响

一、简介 全球首位 AI 程序员 Devin 是由初创公司 Cognition AI 创造的。这家公司成立仅四个月&#xff0c;却已经引起了广泛关注。 Devin作为人工智能的代表&#xff0c;将展示出人工智能在编程领域的潜力和能力&#xff0c;激发程序员探索和应用人工智能技术的兴趣。这将可…

NanoMQ的安装与部署

本文使用docker进行安装&#xff0c;因此安装之前需要已经安装了docker 拉取镜像 docker pull emqx/nanomq:latest 相关配置及密码认证 创建目录/usr/local/nanomq/conf以及配置文件nanomq.conf、pwd.conf # # # # MQTT Broker # # mqtt {property_size 32max_packet_siz…

6、ChatGLM3-6B 部署实践

一、ChatGLM3-6B介绍与快速入门 ChatGLM3 是智谱AI和清华大学 KEG 实验室在2023年10月27日联合发布的新一代对话预训练模型。ChatGLM3-6B 是 ChatGLM3 系列中的开源模型&#xff0c;免费下载&#xff0c;免费的商业化使用。 该模型在保留了前两代模型对话流畅、部署门槛低等众多…

官网怎么发布新文章,怎么在官方网站上发布新内容

随着企业和组织越来越重视官方网站的建设和更新&#xff0c;发布新内容成为了官方网站管理的重要一环。本文将探讨在官方网站上发布新内容的步骤和方法&#xff0c;以及如何确保发布的内容质量和效果。 1. 确定发布内容 在发布新内容之前&#xff0c;首先需要确定发布的内容。…

精品凉拌菜系列热卤系列课程

这一系列课程涵盖精美凉拌菜和美味热卤菜的制作技巧。学员将学习如何选材、调味和烹饪&#xff0c;打造口感丰富、色香俱佳的菜肴。通过实践训练&#xff0c;掌握独特的烹饪技能&#xff0c;为家庭聚餐或职业厨艺提升增添亮点。 课程大小&#xff1a;6.6G 课程下载&#xff1…

windows安装R4.3.3

官网地址The Comprehensive R Archive Network 下载后得到exe安装&#xff0c;默认安装到了C:\Program Files\R&#xff0c; 因为之前已经安装了4.2.3&#xff0c;所以新建了文件夹为4.3.3&#xff0c;两者互不干扰 安装完毕后&#xff0c;打开rstudio&#xff0c;设置 然后重…

基于springboot+vue+Mysql的酒店管理系统

开发语言&#xff1a;Java框架&#xff1a;springbootJDK版本&#xff1a;JDK1.8服务器&#xff1a;tomcat7数据库&#xff1a;mysql 5.7&#xff08;一定要5.7版本&#xff09;数据库工具&#xff1a;Navicat11开发软件&#xff1a;eclipse/myeclipse/ideaMaven包&#xff1a;…

C++入门:类和对象(上)

类和对象重点解析 1.类的定义1.类的访问限定符及封装1.C实现封装的方式2.访问限定符注意 3.封装 2.类对象模型2.1类对象存储方式2.2类对象的大小2.2.1结构体内存对齐原则2.2.2为什么要内存对齐 3.this指针3.1this指针的引出3.2this指针的特性3.3this指针的存储3.4this指针可以为…

如何安全地添加液氮到液氮罐中

液氮是一种极低温的液体&#xff0c;它在许多领域广泛应用&#xff0c;但在处理液氮时需谨慎小心。添加液氮到液氮罐中是一个常见的操作&#xff0c;需要遵循一些安全准则以确保操作人员的安全和设备的完整性。 选择合适的液氮容器 选用专业设计用于存储液氮的容器至关重要。…

UEDITOR WORD图片转存交互

1.下载示例&#xff1a; Word一键粘贴控件-示例-泽优软件 2.复制WordPaster插件目录 3.引入插件文件 注意&#xff1a;不要重复引入jquery&#xff0c;如果您的项目已经引入了jq&#xff0c;则不用再引入jq-1.4 4.在工具栏中增加插件按钮 6.初始化控件 注意&#xff1a; 1.如…

专业文件翻译,笔译翻译公司推荐!

在全球化的大潮中&#xff0c;文件翻译已然成为了商业、法律、科技、文化等诸多领域的核心纽带。特别是在商业交往、合同签订、技术交流等方面&#xff0c;一份高质量的译文往往关乎着合作的成败。而在这其中&#xff0c;专业的文件翻译公司更是扮演着至关重要的角色。它们不仅…

C语言例4-33:求调和级数中第多少项的值大于10

代码如下&#xff1a; //求调和级数中第多少项的值大于10 //调和级数的第n项为11/21/3...1/n #include<stdio.h> #define LIMIT 10 int main(void) {int n1;float sum0.0;for(;;) //死循环&#xff0c;或者while&#xff08;1&#xff09;{sumsum1.0/n;if(sum&g…

软件测试工作规范、流程规范

1. 制定规则 为了规范测试工作、减少开发与测试之前的沟通成本、保证项目进度、提高软件质量&#xff0c;测试组起草了这份软件测试工作规范。 1.1. 编码规范 软件程序开发需要遵守编码规范&#xff0c;一是可以减少代码的维护成本&#xff0c;提高开发工作效率&#xff1b;…

Chrome 插件 storage API 解析

Chrome.storage API 解析 使用 chrome.storage API 存储、检索和跟踪用户数据的更改 一、各模块中的 chrome.storage 内容 1. Service worker 中 runtime 内容 2. Action 中 runtime 内容 3. Content 中 runtime 内容 二、权限&#xff08;Permissions&#xff09; 如果需使…