【适合初学者】这可能是全网最简单的Bert代码学习

最简单的Bert代码学习

  • 教程
    • 数据集说明
    • 建立词表
    • 任务一:随机mask并预测重建
      • 举例说明
        • 构建输入
        • 构建目标
        • segment_label
      • 随机mask的具体实现
    • 任务二:预测两个子句是否来自同一个句子
      • 举例说明
        • 构建输入
        • 构建目标
      • 随机选择子句的实现方式
    • BERT模型
      • 输入
      • 生成mask
      • 三种编码
        • Token编码
        • Position编码
          • 初始化
          • forward
        • Segment编码
        • 求和
      • Encoder
        • 多头自注意力
        • 前馈网络

这是一个关于2018年Google AI Language提出的BERT模型的极简入门教程,用于带领初学者以最快的速度了解BERT。

BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding
Paper Link: https://arxiv.org/abs/1810.04805

本文整理自Github项目BERT-Easy-Tutorial,完整的代码细节和详细注释都在项目中。
欢迎Star!

项目具有以下特点:

  • 超精简的数据:仅由两行文本构成数据集。
  • 超详细的注释:每行核心代码都有解释说明。
  • 超全面的文档:中英文文档详细介绍数据流水线。
  • 无冗余的代码:不需要显卡训练、配置加载、模型保存等操作。
  • 易配置的环境:只需要Python、Pytorch、Numpy即可运行。

教程

数据集说明

项目使用仅由两行句子构成的极为简单的数据集。在每行句子中,又通过\t分为两个子句,表示句子的上下文(也可认为前半句为question,后半句为answer)。
即:
Welcome to the [\t] the jungle
I can stay [\t] here all night

建立词表

为数据集建立对应的词表vocab(需要额外引入一些特殊的token),词表的本质就是一个字典,实现由英文单词到数字的映射:
{‘< pad >’: 0, ‘< unk >’: 1, ‘< eos >’: 2, ‘< sos >’: 3, ‘< mask >’: 4, ‘the’: 5, ‘I’: 6, ‘Welcome’: 7, ‘all’: 8, ‘can’: 9, ‘here’: 10, ‘jungle’: 11, ‘night’: 12, ‘stay’: 13, ‘to’: 14}
其中< pad >表示填充,< unk >表示未知,< sos >表示句子的开头,< eos >表示句子的结尾,< mask >表示被遮挡。

任务一:随机mask并预测重建

在Bert论文中,提出了两种训练任务。

第一个任务是随机将句子中的一些词给遮挡(mask)住,然后使Bert预测这些被遮挡的词原来是什么。

举例说明

构建输入
  1. 将Welcome to the the jungle转为词表的index序列:7 14 5 和 5 11
  2. 现在随机mask掉一部分,序列变为:4 14 5 和 5 10 (即将7变为mask token,将11变为随机token)
  3. 合并序列,并添加起止token,序列变为:3 4 14 5 2 5 10 2(需要注意的是,在前半句前后添加sos 3和eos 2,但是后半句只在最后添加eos 2)
  4. 填充到预定义长度:3 4 14 5 2 5 10 2 0 0 0 0 0 0 0 0 0 0 0 0(此处预设长度为20)

至此,构建完了Bert任务一的输入序列(注意,只是展示了前后两端来自同一个句子的情况,也可能任务二中使用Welcome to the和here all night进行组合,但是这并不影响整个流程)

构建目标

在构建输入中可知,将7变为mask token 4,将11变为随机token 10,这两个变化需要Bert将其重新预测为原来的值,因此构建目标target为0 7 0 0 0 0 11 0 0 0 0 0 0 0 0 0 0 0 0 0(注意,sos和eos的位置用padding的0代替)

至此,生成了Bert任务一的目标序列。

segment_label

表明当前的单词来自哪里,例如:1 1 1 1 1 2 2 2 0 0 0 0 0 0 0 0 0 0 0 0
若为1,则表示token来自前半段(前半个子句);若为2,则表示token来自后半段(后半个子句);若为0,则表示token来自padding。这一项在Bert模型的编码中需要使用。

随机mask的具体实现

以15%的概率随机mask,具体逻辑为:

import random
sentence='Welcome to the'
toekns=[] # 转为index
output_label=[] # Bert需要预测的targetfor i,word in enumerate(sentence):prob = random.random() # 随机一个0~1的数if prob < 0.15:prob /= 0.15if prob < 0.8: # 80%概率使用mask填充tokens[i] = vocab.mask_indexelif prob < 0.9: # 10%概率使用完全的随机填充tokens[i] = random.randrange(len(vocab))else: # 10%概率不变,即取消mask,注意虽然这里没有mask,但是output_label仍需要做出预测toekns[i]=去vocal中查找word对应的indexoutput_label[i]=去vocal中查找word对应的indexelse: # 不masktoekns[i]=去vocal中查找word对应的indexoutput_label[i]=0 # 表示当前没有mask

通过这种方式构造出tokens,即英文单词在vocab字典中的位置索引序列;和output_label,即Bert的任务一需要预测的target。

任务二:预测两个子句是否来自同一个句子

通过\t可以将数据集的某一行句子拆分为两个子句。第二个任务有50%的概率选择不是来自同一个句子的子句,让Bert进行预测其来源。

举例说明

构建输入
  1. 根据传入的行号,读取数据集中的某一整行,并通过\t分为两个子句,如t1=Welcome to the,t2=the jungle
  2. 有50%的概率将t2换为来自其他行的子句,如可以将t2更换为here all night
构建目标

若bert_input由来自同一个句子的t1和t2构成,则目标为1;否则为0。

随机选择子句的实现方式

t1, t2 = get_corpus_line(index) # 根据index(行号)读取一个完整的句子,通过\t分为t1和t2
if random.random() > 0.5:return t1, t2, 1 # 50%的概率返回来自同一个句子的两个子句,并标记为1
else:return t1, get_random_line(), 0 # 50%的概率返回来自不同句子的两个子句,并标记为0

BERT模型

从主要结构上来看,Bert对输入首先进行三种编码:

  1. token编码,将输入的字典索引序列编码为稠密特征嵌入
  2. position编码,生成序列的位置编码,区分token的位置
  3. segment编码,对前文提到的segment_label进行编码,区分句子来源

然后进行多层的Transformer-Encoder结构,用于提取特征

下面我们只关注Bert模型的forward部分

输入

  • x:shape为[batch_size, seq_len]的序列,seq_len默认设置为了20,表示随机mask和padding后的index序列
  • segment_info:shape为[batch_size, seq_len]的序列,seq_len默认设置为了20,值为1表示当前token来自前半句,值为2表示当前token来自后半句,值为0表示当前token为padding的

生成mask

  1. 由于padding用0填充,因此(x > 0)表示生成一个shape同样为[batch_size, seq_len]的bool类型的序列,>0的为True,否则填充的位置为False,表示不可见
  2. unsqueeze(1)在维度1扩展,生成[batch_size, 1, seq_len]的序列
  3. repeat(1, x.size(1), 1),在扩充出来的维度重复seq_len次,生成[batch_size, seq_len, seq_len]的序列
mask = (x > 0).unsqueeze(1).repeat(1, x.size(1), 1)

三种编码

Token编码

TokenEmbedding,即单层的nn.Embedding(num_embeddings=len(vocab), embedding_dim=hidden, padding_idx=0)

将x [batch_size, seq_len]]的整数序列编码为转换为密集向量表示[batch_size, seq_len, hidden]

注意,padding_idx=0表示指定整数序列值0是padding填充,对应的向量编码(无论随着训练参数更新了多少次)也全部为0。

Position编码

Bert使用表示绝对位置的正余弦编码,可以在init方法中提前预处理出一个shape为[max_len, hidden(即d_model)]的位置编码map,然后在forward中只截取前seq_len个返回即可。因此,传参初始化时,必须传递hidden,可以不传递max_len

初始化

核心公式

P E ( p o s , 2 i ) = s i n ( p o s / 1000 0 2 i / d i m ) PE(pos,2i)=sin(pos/10000^{2i/dim}) PE(pos,2i)=sin(pos/100002i/dim)

P E ( p o s , 2 i + 1 ) = c o s ( p o s / 1000 0 2 i / d i m ) PE(pos,2i+1)=cos(pos/10000^{2i/dim}) PE(pos,2i+1)=cos(pos/100002i/dim)

其中pos表示单词在token序列中的位置,取值范围0 ~ seq_len;i表示维度的位置,取值0 ~ dim;dim表示维度长度

  1. 首先创建[max_len, d_model]的全0 tensor,后续在这上面做修改
pe = torch.zeros(max_len, d_model).float()
  1. 生成0 ~ max_len-1的序列,然后扩充后的维度,变成[max_len, 1],即0~max_len-1中每个数字一行,这表示单词在token序列中的位置,unsqueeze是为了后续的广播
position = torch.arange(0, max_len).float().unsqueeze(1)
  1. 借助log和exp计算分数部分,div_term的shape为[d_model//2]
div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))

公式推理:
e ( 2 i ) ⋅ − ( l o g ( 10000 ) / d i m ) = 1 e 2 i ⋅ l o g ( 10000 ) d i m e^{(2i) \cdot -(log(10000)/ dim)}=\frac{1}{e^{\frac{2i \cdot log(10000)}{dim}}} e(2i)⋅−(log(10000)/dim)=edim2ilog(10000)1
分母单独拿出来
e 2 i ⋅ l o g ( 10000 ) d i m = e l o g ( 1000 0 2 i d i m ) = 1000 0 2 i d i m e^{\frac{2i \cdot log(10000)}{dim}}=e^{log(10000^{\frac{2i}{dim}})}=10000^{\frac{2i}{dim}} edim2ilog(10000)=elog(10000dim2i)=10000dim2i

  1. 计算正余弦。pe的为[max_len, d_model]的全0tensor;position为[max_len, 1]的序列位置;div_term为[d_model//2]。首先position * div_term后,shape变为[max_len, d_model//2],即0~max_len-1每个数都乘以div_term,通过广播完成;然后经过sin或cos,shape不变;后使用切片,对于pe的偶数位置使用sin编码,奇数位置使用cos编码
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
  1. 继续增加维度,由[max_len, d_model]变为[1, max_len, d_model],增加batch size维度,方便forward广播
pe = pe.unsqueeze(0)
  1. pe不需要计算梯度,因此注册到缓冲区
self.register_buffer('pe', pe)
forward

输入x,shape为[batch_size, seq_len]。直接从[1, max_len, d_model]的max_len个中截取前seq_len个返回

return self.pe[:, :x.size(1)]
Segment编码

单层的nn.Embedding(num_embeddings=3, embedding_dim=hidden, padding_idx=0)

对segment_info进行编码,即段编码,区分当前的单词是来自句子前半段、后半段、还是padding

求和

分别获取上述三种编码并求和

x = self.token(x) + self.position(x) + self.segment(segment_info)

Encoder

每层Encoder由多头自注意力、前馈网络再搭配LayerNorm、Dropout构成,核心为前两者。Encoder层的好处是输出的数据维度与输入的数据维度完全一致,因此多层Encoder可以堆叠,这就构成了Bert的主体结构。

多头自注意力

自注意力相较于交叉注意力,区别在于前者的QKV来自同一个输入,而后者的Q来自一个输入KV来自另一个输入。

QKV分别通过三个线性层对输入x进行变换得到。

自注意力计算的核心公式为

o u t p u t = s o f t m a x ( Q ⋅ K T d i m ) ⋅ V output=softmax(\frac{Q \cdot K^T}{\sqrt{dim}}) \cdot V output=softmax(dim QKT)V

前馈网络

由两个线性层并配合激活函数和LayerNorm构成。

核心要点是第一个线性层之后,特征通道数增加;第二个线性层后,特征通道数又变回为原来的样子。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/48463.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

配置服务器

参考博客 1. https://blog.csdn.net/qq_31278903/article/details/83146031 2. https://blog.csdn.net/u014374826/article/details/134093409 3. https://blog.csdn.net/weixin_42728126/article/details/88887350 4. https://blog.csdn.net/Dreamhai/article/details/109…

JS设计模式(一)单例模式

注释很详细&#xff0c;直接上代码 本文建立在已有JS面向对象基础的前提下&#xff0c;若无&#xff0c;请移步以下博客先行了解 JS面向对象&#xff08;一&#xff09;类与对象写法 特点和用途&#xff1a; 全局访问点&#xff1a;通过单例模式可以在整个应用程序中访问同一个…

了解 JSON Web 令牌 (JWT)

一. 介绍 在现代 Web 开发领域&#xff0c;确保客户端和服务器之间的通信安全至关重要。JSON Web Tokens (JWT) 已成为确保安全高效数据交换的流行解决方案。本文将深入探讨 JWT 是什么、它们的工作原理以及它们为何对于 Web 应用程序中的安全身份验证和授权至关重要。 二. 什…

AWS全服务历史年表:发布日期、GA和服务概述一览(二)

我一直在尝试从各种角度撰写关于Amazon Web Services&#xff08;AWS&#xff09;的信息和魅力。由于我喜欢技术历史&#xff0c;这次我总结了AWS服务发布的历史年表。 虽然AWS官方也通过“Whats New”发布了官方公告&#xff0c;但我一直希望能有一篇文章将公告日期、GA日期&…

javac详解 idea maven内部编译原理 自制编译器

起因 不知道大家在开发中&#xff0c;有没有过下面这些疑问。有的话&#xff0c;今天就一次解答清楚。 如何使用javac命令编译一个项目&#xff1f;java或者javac的一些参数到底有什么用&#xff1f;idea或者maven是如何编译java项目的&#xff1f;&#xff08;你可能猜测底层…

【一刷《剑指Offer》】面试题 47:不用加减乘除做加法

力扣对应题目链接&#xff1a;LCR 190. 加密运算 - 力扣&#xff08;LeetCode&#xff09; 牛客对应题目链接&#xff1a;不用加减乘除做加法_牛客题霸_牛客网 (nowcoder.com) 一、《剑指Offer》对应内容 二、分析题目 sumdataA⊕dataB 非进位和&#xff1a;异或运…

Unity UGUI 之 Graphic Raycaster

本文仅作学习笔记与交流&#xff0c;不作任何商业用途 本文包括但不限于unity官方手册&#xff0c;唐老狮&#xff0c;麦扣教程知识&#xff0c;引用会标记&#xff0c;如有不足还请斧正 首先手册连接如下&#xff1a; Unity - Manual: Graphic Raycaster 笔记来源于&#xff…

mqtt协议有哪些机制

MQTT协议提供了一些关键机制来确保消息传递的可靠性、效率和灵活性。这些机制使得MQTT非常适用于物联网&#xff08;IoT&#xff09;和其他需要高效、低带宽通信的应用。以下是MQTT协议的主要机制&#xff1a; 1. 发布/订阅&#xff08;Pub/Sub&#xff09;模型 发布/订阅模型…

无人车技术浪潮真的挡不住了~

正文 无人驾驶汽车其实也不算是新鲜玩意了&#xff0c;早在十年前大家都开始纷纷投入研发&#xff0c;在那时就已经蠢蠢欲动&#xff0c;像目前大部分智驾系统和辅助驾驶系统都是无人驾驶系统的一个中间过度版本&#xff0c;就像手机进入智能机时代的中间版本。 然而前段时间突…

SpringBoot 介绍和使用(详细)

使用SpringBoot之前,我们需要了解Maven,并配置国内源(为什么要配置这些,下面会详细介绍),下面我们将创建一个SpringBoot项目"输出Hello World"介绍. 1.环境准备 ⾃检Idea版本: 社区版: 2021.1 -2022.1.4 专业版: ⽆要求 如果个⼈电脑安装的idea不在这个范围, 需要…

LeetCode 热题 HOT 100 (001/100)【宇宙最简单版】

【链表】 No. 0160 相交链表 【简单】&#x1f449;力扣对应题目指路 希望对你有帮助呀&#xff01;&#xff01;&#x1f49c;&#x1f49c; 如有更好理解的思路&#xff0c;欢迎大家留言补充 ~ 一起加油叭 &#x1f4a6; 欢迎关注、订阅专栏 【力扣详解】谢谢你的支持&#x…

搜维尔科技:【产品推荐】Euleria Health Riablo 运动功能训练与评估系统

Euleria Health Riablo 运动功能训练与评估系统 Riablo提供一种创新的康复解决方案&#xff0c;将康复和训练变得可激励、可衡量和可控制。Riablo通过激活本体感觉&#xff0c;并通过视听反馈促进神经肌肉的训练。 得益于其技术先进和易用性&#xff0c;Riablo是骨科、运动医…

centos软件安装

安装方式 一、二进制安装 --解压即用&#xff0c;只针对特殊平台 --jdk tomcat 二、RPM&#xff1a;按照一定规范安装软件&#xff0c;无法安装依赖的文件 --mysql 三、yum&#xff1a;远程安装基于RPM&#xff0c;把依赖的文件安装上去&#xff0c;需要联网 四、源码安装 jdk安…

jmeter部署

一、windows环境下部署 1、安装jdk并配置jdk的环境变量 (1) 安装jdk jdk下载完成后双击安装包&#xff1a;无限点击"下一步"直到完成&#xff0c;默认路径即可。 (2) jdk安装完成后配置jdk的环境变量 找到环境变量中的系统变量&#xff1a;此电脑 --> 右键属性 …

C语言:温度转换

1.题目&#xff1a;实现摄氏度&#xff08;Celsius&#xff09;和华氏度&#xff08;Fahrenheit&#xff09;之间的转换。 输入一个华氏温度&#xff0c;输出摄氏温度&#xff0c;结果保留两位小数。 2.思路&#xff1a;&#xff08;这是固定公式&#xff0c;其中 F 是华氏度&a…

【C语言】详解结构体(下)(位段)

文章目录 前言1. 位段的含义2. 位段的声明3. 位段的内存分配&#xff08;重点&#xff09;3.1 存储方向的问题3.2 剩余空间利用的问题 4. 位段的跨平台问题5. 位段的应用6. 总结 前言 相信大部分的读者在学校或者在自学时结构体的知识时&#xff0c;可能很少会听到甚至就根本没…

STM32实战篇:按键(外部输入信号)触发中断

功能要求 将两个按键分别与引脚PA0、PA1相连接&#xff0c;通过按键按下&#xff0c;能够触发中断响应程序&#xff08;不需明确功能&#xff09;。 代码流程如下&#xff1a; 实现代码 #include "stm32f10x.h" // Device headerint main() {//开…

一种Android系统双屏异显的两路音频实现方法

技术领域 [0001] 本发明涉及一种Android系统双屏异显的两路音频实现方法。 背景技术 [0002] 关于Android系统的双屏异显两路音频的实现目前还没有通用的方法&#xff0c;Android系 统的双屏异显两路音频的需求是&#xff1a;主屏的声音从主屏对应的声卡输出、副屏的声音从…

Nougat - 学术文档PDF解析(LaTeX数学、表格)

文章目录 一、关于 Nougat二、安装三、获取PDF的预测1、CLI2、API 四、数据集生成数据集 五、训练六、评估七、其它1、常见问题解答2、引文3、致谢4、许可证 一、关于 Nougat Nougat (Neural Optical Understanding for Academic Documents) Nougat是理解LaTeX数学和表格的 学…

Dockerfile相关命令

Dockerfile Dockerfile 是一个用来构建Docker镜像的文本文件&#xff0c;包含了一系列构建镜像所需的指令和参数。 指令详解 Dockerfile 指令说明FROM指定基础镜像&#xff0c;用于后续的指令构建&#xff0c;必须为第一个命令MAINTAINER指定Dockerfile的作者/维护者。&…