BERT模型核心组件详解及其实现

摘要

BERT(Bidirectional Encoder Representations from Transformers)是一种基于Transformer架构的预训练模型,在自然语言处理领域取得了显著的成果。本文详细介绍了BERT模型中的几个关键组件及其实现,包括激活函数、变量初始化、嵌入查找、层归一化等。通过深入理解这些组件,读者可以更好地掌握BERT模型的工作原理,并在实际应用中进行优化和调整。

1. 引言

BERT模型由Google研究人员于2018年提出,通过大规模的无监督预训练和任务特定的微调,显著提升了多个自然语言处理任务的性能。本文将重点介绍BERT模型中的几个核心组件,包括激活函数、变量初始化、嵌入查找、层归一化等,并提供相应的代码实现。

2. 激活函数
2.1 Gaussian Error Linear Unit (GELU)

GELU是一种平滑的ReLU变体,其数学表达式如下:

GELU(x)=x⋅Φ(x)GELU(x)=x⋅Φ(x)

其中,Φ(x)Φ(x)是标准正态分布的累积分布函数(CDF)。在TensorFlow中,GELU可以通过以下方式实现:

def gelu(x):"""Gaussian Error Linear Unit.This is a smoother version of the RELU.Original paper: https://arxiv.org/abs/1606.08415Args:x: float Tensor to perform activation.Returns:`x` with the GELU activation applied."""cdf = 0.5 * (1.0 + tf.tanh((np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3)))))return x * cdf
2.2 激活函数映射

为了方便使用不同的激活函数,我们定义了一个映射函数get_activation,该函数根据传入的字符串返回相应的激活函数:

def get_activation(activation_string):"""Maps a string to a Python function, e.g., "relu" => `tf.nn.relu`.Args:activation_string: String name of the activation function.Returns:A Python function corresponding to the activation function. If`activation_string` is None, empty, or "linear", this will return None.If `activation_string` is not a string, it will return `activation_string`.Raises:ValueError: The `activation_string` does not correspond to a knownactivation."""if not isinstance(activation_string, six.string_types):return activation_stringif not activation_string:return Noneact = activation_string.lower()if act == "linear":return Noneelif act == "relu":return tf.nn.reluelif act == "gelu":return geluelif act == "tanh":return tf.tanhelse:raise ValueError("Unsupported activation: %s" % act)
3. 变量初始化

在深度学习中,合理的变量初始化对于模型的收敛速度和最终性能至关重要。BERT模型中使用了截断正态分布初始化器(truncated_normal_initializer),其标准差为0.02:

def create_initializer(initializer_range=0.02):"""Creates a `truncated_normal_initializer` with the given range."""return tf.truncated_normal_initializer(stddev=initializer_range)
4. 嵌入查找

嵌入查找是将输入的token id转换为向量表示的过程。BERT模型中使用了两种方法:一种是使用tf.gather(),另一种是使用one-hot编码:

def embedding_lookup(input_ids,vocab_size,embedding_size=128,initializer_range=0.02,word_embedding_name="word_embeddings",use_one_hot_embeddings=False):"""Looks up words embeddings for id tensor.Args:input_ids: int32 Tensor of shape [batch_size, seq_length] containing wordids.vocab_size: int. Size of the embedding vocabulary.embedding_size: int. Width of the word embeddings.initializer_range: float. Embedding initialization range.word_embedding_name: string. Name of the embedding table.use_one_hot_embeddings: bool. If True, use one-hot method for wordembeddings. If False, use `tf.gather()`.Returns:float Tensor of shape [batch_size, seq_length, embedding_size]."""if input_ids.shape.ndims == 2:input_ids = tf.expand_dims(input_ids, axis=[-1])embedding_table = tf.get_variable(name=word_embedding_name,shape=[vocab_size, embedding_size],initializer=create_initializer(initializer_range))flat_input_ids = tf.reshape(input_ids, [-1])if use_one_hot_embeddings:one_hot_input_ids = tf.one_hot(flat_input_ids, depth=vocab_size)output = tf.matmul(one_hot_input_ids, embedding_table)else:output = tf.gather(embedding_table, flat_input_ids)input_shape = get_shape_list(input_ids)output = tf.reshape(output,input_shape[0:-1] + [input_shape[-1] * embedding_size])return (output, embedding_table)
5. 层归一化

层归一化(Layer Normalization)是一种常用的归一化技术,用于加速训练过程并提高模型的泛化能力。BERT模型中使用了tf.contrib.layers.layer_norm来进行层归一化:

def layer_norm(input_tensor, name=None):"""Run layer normalization on the last dimension of the tensor."""return tf.contrib.layers.layer_norm(inputs=input_tensor, begin_norm_axis=-1, begin_params_axis=-1, scope=name)

为了方便使用,我们还定义了一个组合函数layer_norm_and_dropout,该函数先进行层归一化,再进行dropout操作:

def layer_norm_and_dropout(input_tensor, dropout_prob, name=None):"""Runs layer normalization followed by dropout."""output_tensor = layer_norm(input_tensor, name)output_tensor = dropout(output_tensor, dropout_prob)return output_tensor
6. Dropout

Dropout是一种常用的正则化技术,用于防止模型过拟合。在BERT模型中,dropout的概率可以通过配置参数进行设置:

def dropout(input_tensor, dropout_prob):"""Perform dropout.Args:input_tensor: float Tensor.dropout_prob: Python float. The probability of dropping out a value (NOT of*keeping* a dimension as in `tf.nn.dropout`).Returns:A version of `input_tensor` with dropout applied."""if dropout_prob is None or dropout_prob == 0.0:return input_tensoroutput = tf.nn.dropout(input_tensor, 1.0 - dropout_prob)return output
7. 从检查点加载变量

在微调过程中,通常需要从预训练的模型检查点中加载变量。get_assignment_map_from_checkpoint函数用于计算当前变量与检查点变量的映射关系:

def get_assignment_map_from_checkpoint(tvars, init_checkpoint):"""Compute the union of the current variables and checkpoint variables."""assignment_map = {}initialized_variable_names = {}name_to_variable = collections.OrderedDict()for var in tvars:name = var.namem = re.match("^(.*):\\d+$", name)if m is not None:name = m.group(1)name_to_variable[name] = varinit_vars = tf.train.list_variables(init_checkpoint)assignment_map = collections.OrderedDict()for x in init_vars:(name, var) = (x[0], x[1])if name not in name_to_variable:continueassignment_map[name] = nameinitialized_variable_names[name] = 1initialized_variable_names[name + ":0"] = 1return (assignment_map, initialized_variable_names)
8. 结论

本文详细介绍了BERT模型中的几个核心组件,包括激活函数、变量初始化、嵌入查找、层归一化等。通过深入理解这些组件,读者可以更好地掌握BERT模型的工作原理,并在实际应用中进行优化和调整。希望本文能为读者在自然语言处理领域的研究和开发提供有益的参考。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/60768.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Transformer中的算子:其中Q,K,V就是算子

目录 Transformer中的算子 其中Q,K,V就是算子 一、数学中的算子 二、计算机科学中的算子 三、深度学习中的算子 四、称呼的由来 Transformer中的算子 其中Q,K,V就是算子 “算子”这一称呼源于其在数学、计算机科学以及深度学习等多个领域中的广泛应用和特定功能。以下是…

ElementPlus el-upload上传组件on-change只触发一次

ElementPlus el-upload上传组件on-change只触发一次 主要运用了:on-exceed方法 废话不多说&#xff0c;直接上代码 <el-uploadclass"avatar-uploader"action"":on-change"getFilesj":limit"1":auto-upload"false"accep…

厦大南洋理工最新开源,一种面向户外场景的特征-几何一致性无监督点云配准方法

导读 本文提出了INTEGER&#xff0c;一种面向户外点云数据的无监督配准方法&#xff0c;通过整合高层上下文和低层几何特征信息来生成更可靠的伪标签。该方法基于教师-学生框架&#xff0c;创新性地引入特征-几何一致性挖掘&#xff08;FGCM&#xff09;模块以提高伪标签的准确…

Conda环境与Ubuntu环境移植详解

Conda环境与Ubuntu环境移植详解 在计算机科学中&#xff0c;环境迁移是一项常见的任务&#xff0c;特别是对于使用Anaconda等工具进行数据科学和机器学习的开发人员。迁移环境不仅能够帮助开发者在不同设备间无缝切换&#xff0c;还能确保项目依赖的一致性&#xff0c;从而避免…

【深度学习基础】PyCharm anaconda PYTorch python CUDA cuDNN 环境配置

这里写目录标题 PyCharm 安装anaconda安装PYTorch安装确定python版本CUDA安装cuDNN安装检验环境是否配置成功参照:PyCharm 安装 官网下载 anaconda安装 官网下载 :https://www.anaconda.com/download 配置环境变量,增加 D:\WorkSoftware\Install\Anaconda3 D:\WorkSoftw…

生产环境中AI调用的优化:AI网关高价值应用实践

随着越来越多的组织将生成式AI引入生产环境&#xff0c;他们面临的挑战已经超出了初步实施的范畴。如果管理不当&#xff0c;扩展性限制、安全漏洞和性能瓶颈可能会阻碍AI应用的推广。实际问题如用户数据的安全性、固定容量限制、成本管理和延迟优化等&#xff0c;需要创新的解…

Redis 概 述 和 安 装

安 装 r e d i s: 1. 下 载 r e dis h t t p s : / / d o w n l o a d . r e d i s . i o / r e l e a s e s / 2. 将 redis 安装包拷贝到 /opt/ 目录 3. 解压 tar -zvxf redis-6.2.1.tar.gz 4. 安装gcc yum install gcc 5. 进入目录 cd redis-6.2.1 6. 编译 make …

SpringBoot 2.2.10 无法执行Test单元测试

很早之前的项目今天clone现在&#xff0c;想执行一个业务订单的检查&#xff0c;该检查的代码放在test单元测试中&#xff0c;启动也是好好的&#xff0c;当点击对应的方法执行Test的时候就报错 tip&#xff1a;已添加spring-boot-test-starter 所以本身就引入了junit5的库 No…

Dubbo 3.2 源码导读

Dubbo 是一个高性能的 Java RPC 框架&#xff0c;广泛用于构建分布式服务。Dubbo 3.2 版本引入了一些新的特性和改进&#xff0c;是一个值得深入研究的版本。以下是对 Dubbo 3.2 源码的导读&#xff0c;帮助你理解其架构和设计。 1. 源码获取 从 GitHub 上获取 Dubbo 3.2 的源…

[项目代码] YOLOv5 铁路工人安全帽安全背心识别 [目标检测]

YOLOv5是一种单阶段&#xff08;one-stage&#xff09;检测算法&#xff0c;它将目标检测问题转化为一个回归问题&#xff0c;能够在一次前向传播过程中同时完成目标的分类和定位任务。相较于两阶段检测算法&#xff08;如Faster R-CNN&#xff09;&#xff0c;YOLOv5具有更高的…

Flutter:Widget生命周期

StatelessWidget&#xff1a;无状态部件的生命周期 import package:flutter/material.dart;void main() {runApp(App()); }class App extends StatelessWidget {overrideWidget build(BuildContext context) {return MaterialApp(home: MyHomePage(title: MyHome),);} }class M…

SIM Jacker攻击分析

简介&#xff1a; 2019年9月12日&#xff0c;AdaptiveMobile Security公布了一种针对SIM卡ST Browser的远程攻击方式&#xff1a;Simjacker。攻击者使用普通手机发送特殊构造的短信即可远程定位目标&#xff0c;危害较大 。sim卡的使用在手机上的使用非常普遍&#xff0c;所以…

Python 操作 Elasticsearch 全指南:从连接到数据查询与处理

文章目录 Python 操作 Elasticsearch 全指南&#xff1a;从连接到数据查询与处理引言安装 elasticsearch-py连接到 Elasticsearch创建索引插入数据查询数据1. 简单查询2. 布尔查询 更新文档删除文档和索引删除文档删除索引 批量插入数据处理分页结果总结 Python 操作 Elasticse…

【linux】centos7 换阿里云源

相关文章 【linux】CentOS 的软件源&#xff08;Repository&#xff09;学习-CSDN博客 查看yum配置文件 yum的配置文件通常位于/etc/yum.repos.d/目录下。你可以使用以下命令查看这些文件&#xff1a; ls /etc/yum.repos.d/ # 或者 ll /etc/yum.repos.d/备份当前的yum配置文…

AI 写作(八)实战项目一:自动写作助手(8/10)

一、项目背景与需求分析 &#xff08;一&#xff09;写作需求的多样化 在互联网普及的今天&#xff0c;人们对写作的需求呈现出前所未有的多样化态势。无论是学术论文、新闻报道&#xff0c;还是社交媒体的动态更新、网络小说的创作&#xff0c;都离不开高质量的写作。以学术研…

微信小程序内嵌h5页面(uniapp写的),使用uni.openLocation无法打开页面问题

1.问题 微信小程序内嵌h5页面(uniapp写的),使用uni.openLocation打开地图页面后,点击该页面下方“到这里”按钮,显示无法打开。如下图: 3.解决方案 在内嵌h5中不使用uniapp的api打开地图,而在h5页面事件处理程序中去跳转新的小程序页面,在该新页面去使用微信小程序…

SpringCloud核心组件(五)

文章目录 Gateway一. 概述简介1. Gateway 是什么2. 什么是网关?3.Gateway 和 Nginx 两个网关的区别什么是流量入口&#xff1f; 4.Gateway 能干嘛5.gateway 三大核心概念6.运行方式 二. 入门案例a.创建gateway模块&#xff0c;在pom.xml中引入依赖b.创建启动类GatewayApplicat…

1+X应急响应(网络)系统备份:

系统备份&#xff1a; 系统备份概述&#xff1a; 备份种类&#xff1a; 灾难恢复等级划分&#xff1a; 执行一次备份&#xff1a; 创建备份计划&#xff1a; 恢复备份&#xff1a;

Python学习26天

集合 # 定义集合 num {1, 2, 3, 4, 5} print(f"num&#xff1a;{num}\nnum数据类型为&#xff1a;{type(num)}") # 求集合中元素个数 print(f"num中元素个数为&#xff1a;{len(num)}") # 增加集合中的元素 num.add(6) print(num) # {1,2,3,4,5,6} # 删除…

git撤销、回退某个commit的修改

文章目录 撤销某个特定的commit方法 1&#xff1a;使用 git revert方法 2&#xff1a;使用 git rebase -i方法 3&#xff1a;使用 git reset 撤销某个特定的commit 如果你要撤销某个很早之前的 commit&#xff0c;比如 7461f745cfd58496554bd672d52efa8b1ccf0b42&#xff0c;可…