Mindspore 公开课 - BERT

BERT

BERT模型本质上是结合了 ELMo 模型与 GPT 模型的优势。

  • 相比于ELMo,BERT仅需改动最后的输出层,而非模型架构,便可以在下游任务中达到很好的效果;
  • 相比于GPT,BERT在处理词元表示时考虑到了双向上下文的信息;

BERT 结构

BERT(Bidirectional Encoder Representation from Transformers)是一个仅有 Encoder 的Transformer。

我们调用 BertModel 来搭建BERT模型,并通过 BertConfig 配置模型相关参数。

from pretrain.src.bert import BertModel
from pretrain.src.config import BertConfig

BERT 输入

Transformer:Encoder接受源句子(source sentence),Decoder接受目标句子(target sentence);

  • BERT:
    • 针对句子对相关任务,将两个句子合并为一个句子对输入到Encoder中,<cls> + 第一个句子 + <sep> + 第二个句子 + <sep>;
    • 针对单个文本相关任务,<cls> + 句子 + <sep>。

在这里插入图片描述

BERT Embedding

BERT Embedding与Transformer中的Embedding操作类似,包含词嵌入(word embedding)与位置嵌入(positional embedding)。

但与Transformer不同的是,BERT使用了可学习的位置信息,并额外增加了表示区分不同句子的段嵌入(segment embedding)。
在这里插入图片描述

import mindspore
from mindspore import nn
import mindspore.common.dtype as mstype
from mindspore.common.initializer import initializer, TruncatedNormalclass BertEmbeddings(nn.Cell):"""Embeddings for BERT, include word, position and token_type"""def __init__(self, config):super().__init__()self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, embedding_table=TruncatedNormal(config.initializer_range))self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size, embedding_table=TruncatedNormal(config.initializer_range))self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size, embedding_table=TruncatedNormal(config.initializer_range))self.layer_norm = nn.LayerNorm((config.hidden_size,), epsilon=config.layer_norm_eps)self.dropout = nn.Dropout(1 - config.hidden_dropout_prob)def construct(self, input_ids, token_type_ids=None, position_ids=None):seq_len = input_ids.shape[1]if position_ids is None:position_ids = mnp.arange(seq_len)position_ids = position_ids.expand_dims(0).expand_as(input_ids)if token_type_ids is None:token_type_ids = ops.zeros_like(input_ids)words_embeddings = self.word_embeddings(input_ids)position_embeddings = self.position_embeddings(position_ids)token_type_embeddings = self.token_type_embeddings(token_type_ids)embeddings = words_embeddings + position_embeddings + token_type_embeddingsembeddings = self.layer_norm(embeddings)embeddings = self.dropout(embeddings)return embeddings

BERT 模型构建

BERT模型的构建与上一节课程的Transformer Encoder构建类似。

分别构建multi-head attention层,feed-forward network,并在中间用add&norm连接,最后通过线性层与softmax层进行输出。

BERT 预训练

BERT通过Masked LM(masked language model)与NSP(next sentence prediction)获取词语和句子级别的特征。
在这里插入图片描述

Masked Language Model (Masked LM)

BERT模型通过Masked LM捕捉词语层面的信息。

我们随机将每个句子中15%的词语进行遮盖,替换成掩码。在训练过程中,模型会对句子进行“完形填空”,预测这些被遮盖的词语是什么,通过减小被mask词语的损失值来对模型进行优化。
在这里插入图片描述
由于<mask>仅在预训练中出现,为了让预训练和微调中的数据处理尽可能接近,我们在随机mask的时候进行如下操作:

  • 80%的概率替换为<mask>
  • 10%的概率替换为文本中的随机词
  • 10%的概率不进行替换,保持原有的词元
    在这里插入图片描述
    我们通过BERTPredictionHeadTranform实现单层感知机,对被遮盖的词元进行预测。在前向网络中,我们需要输入BERT模型的编码结果hidden_states。
activation_map = {'relu': nn.ReLU(),'gelu': nn.GELU(False),'gelu_approximate': nn.GELU(),'swish':nn.SiLU()
}class BertPredictionHeadTransform(nn.Cell):def __init__(self, config):super().__init__()self.dense = nn.Dense(config.hidden_size, config.hidden_size, weight_init=TruncatedNormal(config.initializer_range))self.transform_act_fn = activation_map.get(config.hidden_act, nn.GELU(False))self.layer_norm = nn.LayerNorm((config.hidden_size,), epsilon=config.layer_norm_eps)def construct(self, hidden_states):hidden_states = self.dense(hidden_states)hidden_states = self.transform_act_fn(hidden_states)hidden_states = self.layer_norm(hidden_states)return hidden_states

根据被遮盖的词元位置masked_lm_positions,获得这些词元的预测输出。

import mindspore.ops as ops
import mindspore.numpy as mnp
from mindspore import Parameter, Tensorclass BertLMPredictionHead(nn.Cell):def __init__(self, config):super(BertLMPredictionHead, self).__init__()self.transform = BertPredictionHeadTransform(config)self.decoder = nn.Dense(config.hidden_size, config.vocab_size, has_bias=False, weight_init=TruncatedNormal(config.initializer_range))self.bias = Parameter(initializer('zeros', config.vocab_size), 'bias')def construct(self, hidden_states, masked_lm_positions):batch_size, seq_len, hidden_size = hidden_states.shapeif masked_lm_positions is not None:flat_offsets = mnp.arange(batch_size) * seq_lenflat_position = (masked_lm_positions + flat_offsets.reshape(-1, 1)).reshape(-1)flat_sequence_tensor = hidden_states.reshape(-1, hidden_size)hidden_states = ops.gather(flat_sequence_tensor, flat_position, 0)hidden_states = self.transform(hidden_states)hidden_states = self.decoder(hidden_states) + self.biasreturn hidden_states

Next Sentence Prediction (NSP)

BERT通过NSP捕捉句子级别的信息,使其可以理解句子与句子之间的联系,从而能够应用于问答或者推理任务。

NSP本质上是一个二分类任务,通过输入一个句子对,判断两句话是否为连续句子。输入的两个句子A和B中,B有50%的概率是A的下一句。

另外,输入的内容最好是document-level的语料,而非sentence-level的语料,这样训练出的模型可以具备抓取长序列特征的能力。

在这里,我们使用一个单隐藏层的多层感知机 BERTPooler 进行二分类预测。因为特殊占位符 <cls>在预训练中对应了句子级别的特征信息,所以多层感知机分类器只需要输出 <cls>对应的隐藏层输出。

class BertPooler(nn.Cell):def __init__(self, config):super(BertPooler, self).__init__()self.dense = nn.Dense(config.hidden_size, config.hidden_size, activation='tanh', weight_init=TruncatedNormal(config.initializer_range))def construct(self, hidden_states):first_token_tensor = hidden_states[:, 0]pooled_output = self.dense(first_token_tensor)return pooled_output

最后,多层感知机分类器的输出通过一个线性层self.seq_relationship,输出对nsp的预测。

在BERTPreTrainingHeads中,我们对以上提到的两种方式进行整合。最终输出Maked LM(prediction scores)和NSP(seq_realtionship_score)的预测结果

class BertPreTrainingHeads(nn.Cell):def __init__(self, config):super(BertPreTrainingHeads, self).__init__()self.predictions = BertLMPredictionHead(config)self.seq_relationship = nn.Dense(config.hidden_size, 2, weight_init=TruncatedNormal(config.initializer_range))def construct(self, sequence_output, pooled_output, masked_lm_positions):prediction_scores = self.predictions(sequence_output, masked_lm_positions)seq_relationship_score = self.seq_relationship(pooled_output)return prediction_scores, seq_relationship_score

BERT预训练代码整合

我们将上述的类进行实例化,并借此回顾一下BERT预训练的整体流程。

  1. BertModel构建BERT模型;
  2. BertPretrainingHeads整合了Masked LM与NSP两个训练任务, 输出预测结果;
    • BertLMPredictionHead:输入BERT编码与的位置,输出对应位置词元的预测;
    • BERTPooler:输入BERT编码,输出对的隐藏状态,并在BertPretrainingHeads中通过线性层输出预测结果;
class BertForPretraining(nn.Cell):def __init__(self, config, *args, **kwargs):super().__init__(config, *args, **kwargs)self.bert = BertModel(config)self.cls = BertPreTrainingHeads(config)self.vocab_size = config.vocab_sizeself.cls.predictions.decoder.weight = self.bert.embeddings.word_embeddings.embedding_tabledef construct(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, masked_lm_positions=None):outputs = self.bert(input_ids,attention_mask=attention_mask,token_type_ids=token_type_ids,position_ids=position_ids,head_mask=head_mask)sequence_output, pooled_output = outputs[:2]prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output, masked_lm_positions)outputs = (prediction_scores, seq_relationship_score,) + outputs[2:]return outputs

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/625122.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

微服务架构设计核心理论:掌握微服务设计精髓

文章目录 一、微服务与服务治理1、概述2、Two Pizza原则和微服务团队3、主链路规划4、服务治理和微服务生命周期5、微服务架构的网络层搭建6、微服务架构的部署结构7、面试题 二、配置中心1、为什么要配置中心2、配置中心高可用思考 三、服务监控1、业务埋点的技术选型2、用户行…

2023年总结:雄关漫道真如铁,而今迈步从头越,今朝得失

2023年悄然离去&#xff0c;感谢大家的帮助、鼓励和陪伴&#xff0c;感谢家人的理解和支持&#xff0c;祝大家新年快乐&#xff0c;阖家幸福&#xff0c;身体健康。像往常一样&#xff0c;今年也会写一篇年终总结&#xff0c;也是自己的第11篇年终总结&#xff0c;题目就叫《雄…

32 二叉树的定义

之前的通用树结构 采用双亲孩子表示法模型 孩子兄弟表示法模型 引出二叉树 二叉树的定义&#xff1a; 满二叉树和完全二叉树 对此图要有印象 满二叉树一定是完全二叉树&#xff0c;但是完全二叉树不一定是满二叉树 小结

Javaweb之SpringBootWeb案例员工管理分页查询的详细解析

3. 员工管理 完成了部门管理的功能开发之后&#xff0c;我们进入到下一环节员工管理功能的开发。 基于以上原型&#xff0c;我们可以把员工管理功能分为&#xff1a; 分页查询&#xff08;今天完成&#xff09; 带条件的分页查询&#xff08;今天完成&#xff09; 删除员工&…

HNU-算法设计与分析-实验4

算法设计与分析实验4 计科210X 甘晴void 202108010XXX 目录 文章目录 算法设计与分析<br>实验41 回溯算法求解0-1背包问题问题重述想法代码验证算法分析 2 回溯算法实现题5-4运动员最佳配对问题问题重述想法代码验证算法分析 3 分支限界法求解0-1背包问题问题重述想法…

gogs git创建仓库步骤

目录 引言创建仓库clone 仓库推送代码 引言 Gogs 是一款类似GitHub的开源文件/代码管理系统&#xff08;基于Git&#xff09;&#xff0c;Gogs 的目标是打造一个最简单、最快速和最轻松的方式搭建自助 Git 服务。 创建仓库 git中的组织可以把它看成是相关仓库的集合&#xff0c…

DNS主从服务器配置

主从服务器配置&#xff1a; &#xff08;1&#xff09;完全区域传送&#xff1a;复制整个区域文件 #主DNS服务器的配置【主dns服务器的ip地址为192.168.168.129】 #编辑DNS系统配置信息&#xff08;我这里写的增加的信息&#xff0c;源文件里面有很多内容&#xff09; [root…

python中小数据池和编码

嗨喽&#xff0c;大家好呀~这里是爱看美女的茜茜呐 ⼀. 小数据池 在说小数据池之前. 我们先看⼀个概念. 什么是代码块: 根据提示我们从官⽅⽂档找到了这样的说法&#xff1a; A Python program is constructed from code blocks. A block is a piece of Python program text…

大电流直流恒温控制电路

一个电子制冷器控制芯片 实物照片 驱动芯片 使用环境12V直流&#xff0c;电流10A 特此记录 anlog 2024年1月15日

2.1.2 一个关于y=ax+b的故事

跳转到根目录&#xff1a;知行合一&#xff1a;投资篇 已完成&#xff1a; 1、投资&技术   1.1.1 投资-编程基础-numpy   1.1.2 投资-编程基础-pandas   1.2 金融数据处理   1.3 金融数据可视化 2、投资方法论   2.1.1 预期年化收益率   2.1.2 一个关于yaxb的…

【C初阶——内存函数】鹏哥C语言系列文章,基本语法知识全面讲解

本文由睡觉待开机原创&#xff0c;转载请注明出处。 本内容在csdn网站首发 欢迎各位点赞—评论—收藏 如果存在不足之处请评论留言&#xff0c;共同进步&#xff01; 这里写目录标题 1.memcpy使用和模拟实现2.memmove的使用和模拟实现3.memset函数的使用4.memcpy函数的使用 1.m…

linux安装MySQL5.7(安装、开机自启、定时备份)

一、安装步骤 我喜欢安装在/usr/local/mysql目录下 #切换目录 cd /usr/local/ #下载文件 wget https://dev.mysql.com/get/Downloads/MySQL-5.7/mysql-5.7.38-linux-glibc2.12-x86_64.tar.gz #解压文件 tar -zxvf mysql-5.7.38-linux-glibc2.12-x86_64.tar.gz -C /usr/local …

ERP和MES对接的几种接口方式

在数字化工厂的规划建设中&#xff0c;信息化系统的集成&#xff0c;既是重点&#xff0c;但同时也是难点。ERP和MES对接时&#xff0c;ERP主要负责下达生产计划&#xff0c;MES是执行生产计划&#xff0c;二套系统在数据交互时&#xff0c;需要确保基础数据的一致性&#xff0…

SpringBoot源码分析

一&#xff1a;简介 由Pivotal团队提供的全新框架其设计目的是用来简化新Spring应用的初始搭建以及开发过程使用了特定的方式来进行配置快速应用开发领域 二&#xff1a;运行原理以及特点 运行原理&#xff1a; SpringBoot为我们做的自动配置&#xff0c;确实方便快捷&#…

STC8H8K蓝牙智能巡线小车——2. 点亮左右转弯灯与危险报警灯

任务调用示例 RTX 51 TNY 可做多任务调度&#xff0c;API较为简单。 /* 接口API */// 创建任务 extern unsigned char os_create_task (unsigned char task_id); // 结束任务 extern unsigned char os_delete_task (unsigned char task_id);// 等待 extern unsig…

RTKlib操作手册--使用样例数据演示

简介 RTKLIB&#xff08;Real-Time Kinematic Library&#xff09;是一款开源的实时差分全球导航卫星系统&#xff08;GNSS&#xff09;软件库。它旨在提供高精度的位置解算&#xff0c;特别是在实时应用中&#xff0c;如精密农业、测绘、无人机导航等领域。 RTKLIB支持多种G…

目标检测数据集 - 人脸检测数据集下载「包含VOC、COCO、YOLO三种格式」

数据集介绍&#xff1a;行人检测数据集&#xff0c;真实场景高质量图片数据&#xff0c;涉及场景丰富&#xff0c;比如校园行人、街景行人、道路行人、遮挡行人、严重遮挡行人数据&#xff1b;适用实际项目应用&#xff1a;公共场所监控场景下行人检测项目&#xff0c;以及作为…

如何写好年终总结?

前面有读者留言问年终总结要怎么写&#xff0c;我一听你要聊这个我可不困了&#xff0c;这活我熟啊&#xff0c;谁不知道我厂是 PPT 之王。先来一套打法闭环方法论&#xff0c;再来一套赋能抓手组合拳&#xff0c;如此这般&#xff0c;便可笑傲于江湖。 玩笑归玩笑&#xff0c…

常用界面设计组件 —— 字符串与输入输出组件(QT)

2.2 字符串与输入输出组件2.2.1 字符串与数值之间的转换2.2.2 QString的常用功能 2.2 字符串与输入输出组件 2.2.1 字符串与数值之间的转换 界面设计时使用最多的组件恐怕就是QLabel和 QLineEdit了&#xff0c;QLabel用于显示字符串&#xff0c;QLineEdit用于 显示和输入字符…

ioDraw在线图表工具 - 轻松制作专业图表,只需3步!

还在花大量时间手动画图表&#xff1f;还在为图表样式而烦恼&#xff1f;ioDraw为你提供一站式解决方案&#xff01;ioDraw在线图表工具实现了AI自动生成图表&#xff0c;让你轻松制作专业图表&#xff0c;只需3步&#xff01; 1. 录入数据 只需将你的数据告诉ioDraw AI助手&…