【HuggingFace Transformers】BertSelfOutput 和 BertOutput源码解析

BertSelfOutput 和 BertOutput源码解析

  • 1. 介绍
    • 1.1 共同点
      • (1) 残差连接 (Residual Connection)
      • (2) 层归一化 (Layer Normalization)
      • (3) Dropout
      • (4) 线性变换 (Linear Transformation)
    • 1.2 不同点
      • (1) 处理的输入类型
      • (2) 线性变换的作用
      • (3) 输入的特征大小
  • 2. 源码解析
    • 2.1 BertSelfOutput 源码解析
    • 2.2 BertOutput 源码解析

1. 介绍

BertSelfOutputBertOutputBERT 模型中两个相关但不同的模块。它们在功能上有许多共同点,但也有一些关键的不同点。以下通过共同点和不同点来介绍它们。

1.1 共同点

BertSelfOutputBertOutput 都包含残差连接、层归一化、Dropout 和线性变换,并且这些操作的顺序相似。

(1) 残差连接 (Residual Connection)

两个模块都应用了残差连接,即将模块的输入直接与经过线性变换后的输出相加。这种结构可以帮助缓解深层神经网络中的梯度消失问题,使信息更直接地传递,保持梯度流动顺畅。

(2) 层归一化 (Layer Normalization)

在应用残差连接后,两个模块都使用层归一化 (LayerNorm) 来规范化输出。这有助于加速训练,稳定网络性能,并减少内部分布变化的问题。

(3) Dropout

两个模块都包含一个 Dropout 层,用于随机屏蔽一部分神经元的输出,增强模型的泛化能力,防止过拟合。

(4) 线性变换 (Linear Transformation)

两个模块都包含一个线性变换 (dense 层)。这个线性变换用于调整数据的维度,并为后续的残差连接和层归一化做准备。

1.2 不同点

BertSelfOutput 专注于处理自注意力机制的输出,而 BertOutput 则处理前馈神经网络的输出。它们的输入特征维度也有所不同,线性变换的作用在两个模块中也略有差异。

(1) 处理的输入类型

  • BertSelfOutput:处理自注意力机制 (BertSelfAttention) 的输出。它关注的是如何将注意力机制生成的特征向量与原始输入结合起来。
  • BertOutput:处理的是前馈神经网络的输出。它将经过注意力机制处理后的特征进一步加工,并整合到当前层的最终输出中。

(2) 线性变换的作用

  • BertSelfOutput:线性变换的作用是对自注意力机制的输出进行进一步的变换和投影,使其适应后续的处理流程。
  • BertOutput:线性变换的作用是对前馈神经网络的输出进行变换,使其与前一层的输出相结合,并准备传递到下一层。

(3) 输入的特征大小

  • BertSelfOutput:输入和输出的特征维度保持一致,都是 BERT 模型的隐藏层大小 (hidden_size)。
  • BertOutput:输入的特征维度是中间层大小 (intermediate_size),输出则是 BERT 模型的隐藏层大小 (hidden_size)。这意味着 BertOutput 的线性变换需要将中间层的维度转换回隐藏层的维度。

2. 源码解析

源码地址:transformers/src/transformers/models/bert/modeling_bert.py

2.1 BertSelfOutput 源码解析

# -*- coding: utf-8 -*-
# @time: 2024/7/15 14:27import torch
from torch import nnclass BertSelfOutput(nn.Module):def __init__(self, config):super().__init__()self.dense = nn.Linear(config.hidden_size, config.hidden_size)  # 定义线性变换层,将自注意力输出映射到 hidden_size 维度self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)  # 层归一化self.dropout = nn.Dropout(config.hidden_dropout_prob)  # Dropout层def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:hidden_states = self.dense(hidden_states)  # 对自注意力机制的输出进行线性变换hidden_states = self.dropout(hidden_states)  # Dropout操作hidden_states = self.LayerNorm(hidden_states + input_tensor)  # 残差连接后进行层归一化return hidden_states

2.2 BertOutput 源码解析

# -*- coding: utf-8 -*-
# @time: 2024/8/22 15:41import torch
from torch import nnclass BertOutput(nn.Module):def __init__(self, config):super().__init__()self.dense = nn.Linear(config.intermediate_size, config.hidden_size)  # 定义线性变换层,将前馈神经网络输出从 intermediate_size 映射到 hidden_sizeself.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)  # 层归一化self.dropout = nn.Dropout(config.hidden_dropout_prob)  # Dropout层def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:hidden_states = self.dense(hidden_states)  # 对前馈神经网络的输出进行线性变换hidden_states = self.dropout(hidden_states)  # Dropout操作hidden_states = self.LayerNorm(hidden_states + input_tensor)  # 残差连接后进行层归一化return hidden_states

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/52425.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Facebook的AI助手:如何提升用户社交体验的智能化

在现代社交媒体平台中,人工智能(AI)的应用正逐渐改变人们的社交体验。Facebook作为全球最大的社交媒体平台之一,已在AI技术的开发与应用上投入了大量资源,并通过其AI助手为用户提供了更加个性化、智能化的互动体验。这…

vagrant 创建虚拟机

创建一个名为 “Vagrantfile” 的文件,修改如下内容: Vagrant.configure("2") do |config|(1..3).each do |i|config.vm.define "k8s-node#{i}" do |node|# 设置虚拟机的Boxnode.vm.box "centos/7"# 设置虚拟机的主机名…

逆向中的游戏-入土为安的第二十五天

逆向中的游戏 CE的介绍 Cheat Engine ,简称CE,是逆向工程师常用的几大神器之一,也是游戏汉化、破解以及外挂编写中常用的工具,其功能包括:内存扫描、十六进制编辑器、调试工具,可以进行反汇编调试、断点跟…

代码随想录算法训练营_day28

题目信息 122. 买卖股票的最佳时机 II 题目链接: https://leetcode.cn/problems/best-time-to-buy-and-sell-stock-ii/题目描述: 给你一个整数数组 prices ,其中 prices[i] 表示某支股票第 i 天的价格。 在每一天,你可以决定是否购买和/或出售股票。你…

Springboot-RequestContextHolder

RequestContextHolder 是 Spring 框架中的一个类,主要用于在多线程环境中存储和访问 HTTP 请 求的上下文信息。它允许在 Spring 应用程序中从任何位置访问当前请求的相关信息,比如 HTTP 头部、会话数据等,而无需将请求对象直接传递到每个方法中。 主要用途 存储请求上下…

Seata 学习

简介 我们都知道 Seata 是一个分布式事务的解决方案,今天我们就来带大家了解一下什么是分布式事务,首先我们先来了解一下基础的知识——事务,我们先来了解一下事务的概念是什么。 基本概念 事务四部分构成— ACID: A(Atomic)&…

小程序路由传参和获取页面栈方法

路由方法 navigateTo, redirectTo 只能打开非 tabBar 页面。switchTab 只能打开 tabBar 页面。reLaunch 可以打开任意页面。页面底部的 tabBar 由页面决定,即只要是定义为 tabBar 的页面,底部都有 tabBar。调用页面路由带的参数可以在目标页面的onLoad中…

matlab 计算复共轭

目录 一、概述1、算法概述2、主要函数二、代码示例1、求复数的复共轭2、求矩阵中复数值的复共轭三、参考链接本文由CSDN点云侠翻译,原文链接。如果你不是在点云侠的博客中看到该文章,那么此处便是不要脸的抄袭狗。 一、概述 1、算法概述 2、主要函数 Zc = conj(Z)返回 Z …

【python】Python中小巧的异步web框架Sanic快速上手实战

✨✨ 欢迎大家来到景天科技苑✨✨ 🎈🎈 养成好习惯,先赞后看哦~🎈🎈 🏆 作者简介:景天科技苑 🏆《头衔》:大厂架构师,华为云开发者社区专家博主,…

理解torch.argmax() ,我是错误的

torch.max() import torch# 定义张量 b b torch.tensor([[1, 3, 5, 7],[2, 4, 6, 8],[11, 12, 13, 17]])# 使用 torch.max() 找到最大值 max_indices torch.max(b, dim0)print(max_indices) 输出:>>> print(max_indices) torch.return_types.max( valu…

【13.3 python中的高级文件操作】

python中的高级文件操作 在Python中,除了基本的文件读写和目录操作外,还有一些高级的文件和目录操作,如删除文件、重命名文件和目录、以及获取文件的基本信息等。这些操作通常通过os模块和pathlib模块来实现。下面我将详细介绍这些操作&#…

Git在IDEA中的集成操作(附步骤图)

1.先做适配操作,将安装的Git软件关联到IDEA中 点击Test之后若成功会显示出Git版本: 2.创建版本仓库 3.创建新的版本 3.1将文件提交到暂存区(不重要) 第一种方式:菜单栏提交 第二种方式:项目右键提交 4.查看历史版本信息 目…

整合sentinel遇到的小问题

1.运行jar包 ,端口为默认8080 正确命令 java -Dserver.port8090 -Dcsp.sentinel.dashboard.server127.0.0.1:8090 -Dproject.namesentinel-dashboard -jar sentinel-dashboard-1.8.6.jar -D这些指令要在 -jar前面 在宝塔部署时,直接复制到运行命令后…

acl2的安装和vescmul的运行

使用sat一类的求解器验证乘法器,通常无法收敛,Mertcan Temel博士基于acl2实现了数字电路中乘法器的验证。下面简单介绍一下acl2的安装、vescmul的运行。 1.下载acl2 git clone https://github.com/acl2/acl2.git 2.acl2基于lisp编程语言实现&#xff…

Sparse Kernel Canonical Correlation Analysis

论文链接:https://arxiv.org/pdf/1701.04207 看这篇论文终于看懂核函数了。。谢谢作者

基于无人机边沿相关 ------- IBUS、SBUS协议和PPM信号

文章目录 一、IBUS协议二、SBUS协议三、PPM信号 一、IBUS协议 IBUS(Intelligent Bus)是一种用于电子设备之间通信的协议,采用串行通信方式,允许多设备通过单一数据线通信,较低延迟,支持多主机和从机结构&a…

SpringBoot集成kafka-监听器注解

SpringBoot集成kafka-监听器注解 1、application.yml2、生产者3、消费者4、测试类5、测试 1、application.yml #自定义配置 kafka:topic:name: helloTopicconsumer:group: helloGroup2、生产者 package com.power.producer;import com.power.model.User; import com.power.uti…

C++11详解 (右值引用、可变参数模板、emplace_back、lambda表达式、function、bind)

目录 简介 左值引用与右值引用 左值引用与右值引用是什么 左值引用与右值引用的比较 右值引用的使用场景与C11中STL的新变化 完美转发 新的类功能 可变参数模板 可变参数模板的应用——emplace_back lambda表达式 包装器 function bind 结语 简介 在过往&#xf…

UGUI空白可点击组件,减少重绘

如果使用image alpha 0,会导致overDraw,直接清空mesh,不绘制即可避免 #if UNITY_EDITOR using UnityEditor; #endif using UnityEngine; using UnityEngine.UI; namespace UnityGameFramework { [AddComponentMenu("Game/UI/GameEmpty4Raycast")] [Requir…

基于vue框架的毕业设计选题系统bqx47(程序+源码+数据库+调试部署+开发环境)系统界面在最后面。

系统程序文件列表 项目功能:学生,指导老师,课题信息,类型,选题信息 开题报告内容 基于Vue框架的毕业设计选题系统 开题报告 一、引言 毕业设计选题是高等教育中极为关键的一环,它不仅关乎学生未来研究的方向与深度,也是培养其创新思维和实…