[RoFormer]论文实现:ROFORMER: ENHANCED TRANSFORMER WITH ROTARY POSITION EMBEDDING

文章目录

    • 一、完整代码
    • 二、论文解读
      • 2.1 注意力机制
      • 2.2 绝对位置编码
      • 2.3 相对位置编码
      • 2.4 旋转位置编码
        • Long-term decay
        • Adaption for linear attention
      • 2.5 模型效果
    • 三、过程实现
    • 四、整体总结

论文:ROFORMER: ENHANCED TRANSFORMER WITH ROTARY POSITION EMBEDDING
作者:Jianlin Su, Yu Lu, Shengfeng Pan, Ahmed Murtadha, Bo Wen, Yunfeng Liu
时间:2021
地址:https://huggingface.co/docs/transformers/model_doc/roformer

一、完整代码

由于Transformer是老生常谈了,这里我们只简要实现RoPE

# 完整代码在这里
class RotaryEmbedding(tf.keras.layers.Layer):def __init__(self,max_wavelength=10000,scaling_factor=1.0,sequence_axis=1,feature_axis=-1,**kwargs):super().__init__(**kwargs)self.max_wavelength = max_wavelengthself.sequence_axis = sequence_axisself.feature_axis = feature_axisself.scaling_factor = scaling_factorself.built = Truedef call(self, inputs, start_index=0):rotary_dim = tf.shape(inputs)[-1]cos_emb, sin_emb = self._compute_cos_sin_embedding(inputs, rotary_dim, start_index)return self._apply_rotary_pos_emb(inputs, cos_emb, sin_emb)def _apply_rotary_pos_emb(self, tensor, cos_emb, sin_emb):x1, x2 = tf.split(tensor, 2, axis=self.feature_axis)half_rot_tensor = tf.concat((-x2, x1), axis=self.feature_axis)return (tf.matmul(tensor,cos_emb)) + (tf.matmul(half_rot_tensor, sin_emb))def _compute_cos_sin_embedding(self, x, rotary_dim, start_index):freq_range = tf.range(0, rotary_dim, 2, dtype="float32")freq_range = tf.cast(freq_range, self.compute_dtype)freq_range = freq_range / tf.cast(self.scaling_factor, self.compute_dtype)inverse_freq = 1.0 / (self.max_wavelength** (freq_range / tf.cast(rotary_dim, self.compute_dtype)))seq_len = tf.shape(x)[self.sequence_axis]tensor = tf.range(seq_len, dtype="float32") + start_indextensor = tf.cast(tensor, dtype=inverse_freq.dtype)freq = tf.einsum("i, j -> ij", tensor, inverse_freq)embedding = tf.concat((freq, freq), axis=self.feature_axis)def get_axis(axis):return axis if axis > 0 else len(x.shape) + axisfeature_axis = get_axis(self.feature_axis)sequence_axis = get_axis(self.sequence_axis)for axis in range(len(x.shape)):if axis != sequence_axis and axis != feature_axis:embedding = tf.expand_dims(embedding, axis)return tf.cos(embedding), tf.sin(embedding)

二、论文解读

RoPE通过其特性优先于现有的位置编码方法,包括序列长度的灵活性、随着相对距离的增加而减少的标记间依赖性,以及用相对位置编码装备线性自注意的能力。在各种长文本分类基准数据集上的实验结果表明,具有RoPE嵌入的Transformer,即RoFormer,具有更好的性能;

RoPE的关键思想是通过将上下文表示与一个旋转矩阵相乘来获取元素的相对位置;

2.1 注意力机制

下面是注意力机制的公式,老生常谈了,给个图就行;

2.2 绝对位置编码

这个是最普通的Transformer采取的编码方式,非常的经典;

2.3 相对位置编码

下图是Transformer-XL采取的编码方式,其目的是为了避免在循环机制中出现位置混淆;

下面两个是Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer采取的编码方式;

可以看到,这里直接把位置编码转化为一个要学习的参数 b i , j b_{i,j} bi,j进行嵌入,自由度非常大;

这里和上图的不同是这里另外添加了绝对位置编码的信息;

DeBERTa: Decoding-enhanced BERT with Disentangled Attention这篇论文中认为常规注意力机制中的 p m T ⋅ W q T ⋅ W k ⋅ p n p_m^T·W_q^T·W_k·p_n pmTWqTWkpn并没有表达相对信息,只是做一个bias的作用,而bias q , k , v q,k,v q,k,v时就已经体现,不需要bias,采取删除的方法,然后把绝对位置信息转化为相对位置信息;

论文说Radford and Narasimhan(这两货是GPT模型的提出者)在2018年的时候对这四种变体进行了比较,发现第四个相对位置编码即删除了bias的相对位置编码最为合理;但让我纳闷的是这不是2020年的论文吗?

2.4 旋转位置编码

旋转位置编码RoPE的关键思想是通过将上下文表示与一个旋转矩阵相乘来编码相对位置;

所以RoPE本质上也是一种相对位置编码,那么其目标肯定 q m T k n q_m^Tk_n qmTkn 只与 x m x_m xm x n x_n xn 以及其相对位置 m − n m-n mn 有关;公式如下:

但凡提到旋转Rotary,肯定是离不开三角函数的,这种方法是把一串序列绕成一个圆,如图所示:

这是我随便从网上下载的图片,简单了解方式即可;第一个位置从3点钟方向开始,把所有的序列逆时针打满一圈,这就是旋转位置编码,论文中有一张图很形象,如图所示:

下面便是上图的公式化表达;

论文中得出这一公式有一个推导,有意思但同时有点长,我把他贴在下面;

不得不感慨,还是咱们中国人把文章写得明白和透彻;

这样做有什么优势呢?

Long-term decay

这里的推理其实很简单,最后一个公式是由图像说明的,

∑ i = 1 d / 2 ∣ S i ∣ \sum_{i=1}^{d/2}|S_i| i=1d/2Si n − m n-m nm上虽然不是单调递减,但是其总体趋势是递减的

Adaption for linear attention

其相对位置不需要学习,不需要训练参数,只需要乘以一个旋转矩阵,类似于绝对编码,但是其实质有相对性;

2.5 模型效果

从下图中可以看到RoPE的效果要比Sinusoidal positional encoding要好;

三、过程实现

class RotaryEmbedding(tf.keras.layers.Layer):def __init__(self,max_wavelength=10000,scaling_factor=1.0,sequence_axis=1,feature_axis=-1,**kwargs):super().__init__(**kwargs)self.max_wavelength = max_wavelengthself.sequence_axis = sequence_axisself.feature_axis = feature_axisself.scaling_factor = scaling_factorself.built = Truedef call(self, inputs, start_index=0):rotary_dim = tf.shape(inputs)[-1]cos_emb, sin_emb = self._compute_cos_sin_embedding(inputs, rotary_dim, start_index)return self._apply_rotary_pos_emb(inputs, cos_emb, sin_emb)def _apply_rotary_pos_emb(self, tensor, cos_emb, sin_emb):x1, x2 = tf.split(tensor, 2, axis=self.feature_axis)half_rot_tensor = tf.concat((-x2, x1), axis=self.feature_axis)return (tf.matmul(tensor,cos_emb)) + (tf.matmul(half_rot_tensor, sin_emb))def _compute_cos_sin_embedding(self, x, rotary_dim, start_index):freq_range = tf.range(0, rotary_dim, 2, dtype="float32")freq_range = tf.cast(freq_range, self.compute_dtype)freq_range = freq_range / tf.cast(self.scaling_factor, self.compute_dtype)inverse_freq = 1.0 / (self.max_wavelength** (freq_range / tf.cast(rotary_dim, self.compute_dtype)))seq_len = tf.shape(x)[self.sequence_axis]tensor = tf.range(seq_len, dtype="float32") + start_indextensor = tf.cast(tensor, dtype=inverse_freq.dtype)freq = tf.einsum("i, j -> ij", tensor, inverse_freq)embedding = tf.concat((freq, freq), axis=self.feature_axis)def get_axis(axis):return axis if axis > 0 else len(x.shape) + axisfeature_axis = get_axis(self.feature_axis)sequence_axis = get_axis(self.sequence_axis)for axis in range(len(x.shape)):if axis != sequence_axis and axis != feature_axis:embedding = tf.expand_dims(embedding, axis)return tf.cos(embedding), tf.sin(embedding)

四、整体总结

中国人牛逼!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/194629.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Java 使用itextpdf创建Pdf文件

DOM文件添加Maven依赖 <dependency><groupId>com.itextpdf</groupId><artifactId>itext7-core</artifactId><version>7.2.0</version><type>pom</type></dependency> 主要代码&#xff1a; PdfFont font PdfFo…

【数据结构】拆分详解 - 二叉树的链式存储结构

文章目录 一、前置说明二、二叉树的遍历  1. 前序、中序以及后序遍历   1.1 前序遍历   1.2 中序遍历   1.3 后序遍历 2. 层序遍历 三、常见接口实现  0. 递归中的分治思想  1. 查找与节点个数   1.1 节点个数   1.2 叶子节点个数   1.3 第k层节…

yo!这里是智能指针相关介绍

目录 前言 内存泄漏 RAII 智能指针原理 智能指针分类 auto_ptr unique_ptr shared_ptr 两个问题 线程安全 循环引用 后记 前言 对于智能指针&#xff0c;听起来很高大上&#xff0c;其实本质上就是一个类。为什么叫指针呢&#xff1f;因为可以像指针一样管理一块资…

linux 应用开发笔记---【I/O文件/基础篇 】

文章笔记来自于【速学Linux】手把手教你学嵌入式Linux C应用编程_哔哩哔哩_bilibili 一&#xff0c;什么是linux应用程序 1.运行在linux操作系统用户空间的程序 2.内核程序运行在内核空间&#xff0c;应用程序运行在用户空间 在终端执行的命令ls,ps。。。。。。都是运行在用…

使用gdb调试QEMU模拟的RISC-V平台程序

我们跑一个裸核程序&#xff0c;也就是不带操作系统的程序&#xff0c;然后使用gdb调试该程序。 首先编译目标程序&#xff0c;然后使用QEMU的kernel参数进行加载 qemu-system-riscv64 -s -S -bios opensbi.elf -m 4G -smp 4 -kernel my_program.x -nographic -s 让QEMU在12…

【MySQL的DQL查询语句】

MySQL的DQL查询语句-----在Navicat下 将学生表导入Navicat中查询语句查询一整张表查询年龄大于22年龄大于22的女生查找文科的学生查找六班的学生计算学生的总分 &#xff08;group by&#xff09;合并两表 &#xff08;join on xxxx&#xff09;合并两张表 并求总分先合并在聚合…

Java+springboot+avue医院绩效考核系统源码支持二次开发

公立医院改革要求建立公立医疗卫生机构绩效考核体系&#xff0c;借助绩效考核来引导各级公立医院把社会效益摆在首位&#xff0c;提高医疗服务质量&#xff0c;规范医疗服务行为&#xff0c;加强医院内部管理&#xff0c;促进医院高质量发展 医院绩效考核系统&#xff0c;建立以…

python 运用pandas 库处理excel 表格数据

文章目录 读取文件查看数据数据选择数据筛选创建新列计算并总结数据分组统计 读取文件 Pandas 是一个强大的数据分析库&#xff0c;它提供了丰富的数据结构和数据分析工具&#xff0c;其中之一是用于读取不同格式文件的 read_* 函数系列。以下是一个简单介绍如何使用 Pandas 读…

Siemens-NXUG二次开发-C/C++/Python环境配置[20231204]

Siemens-NXUG二次开发-C/C/Python运行方式[20231204] 1.NX/UG C/C/Python API官方开发文档2.运行方式2.1内部模式2.2 外部模式2.3 许可证书服务器启动 3.C/C环境配置4.Python环境配置5.第三方环境配置 1.NX/UG C/C/Python API官方开发文档 西门子NX/UG Python api开发文档&…

Spring学习笔记:Day2

昨天定的学习计划发现通过文心4.0来实现不靠谱&#xff0c;坑太多&#xff0c;今天开始跟随B站进行学习&#xff0c;争取10-15天学习一遍&#xff0c;冲啊&#xff01; 地址&#xff1a;001-课程介绍_哔哩哔哩_bilibili 今日规划&#xff1a; pt 001 - pt 018&#xff0c;提到…

【苍穹外卖】——第一天

第一天学习目标&#xff1a; 本系列只是对于学习苍穹外卖的一个学习总结和问题记录&#xff0c;学习的话还是照着黑马的视频学习 对内容有一个整体把握 搭建项目环境 对一些基础的名词理解 了解nginx反向代理和负载均衡 能使用Swagger测试后端接口 学习内容&#xff1a; pojo分…

小心处理 C++ 静态变量中的陷阱

小心处理 C 静态变量中的陷阱 函数中的 static 变量 static 变量的作用 C 中 static 关键字的最后一个用途是在函数内创建局部变量&#xff0c;这些变量在其作用域内退出和进入时保持其值。函数内的 static 变量类似于只能从该函数访问的全局变量。static 变量的一个常见用途…

【UGUI】实现背包的常用操作

1. 添加物品 首先&#xff0c;你需要一个包含物品信息的类&#xff0c;比如 InventoryItem&#xff1a; using UnityEngine;[CreateAssetMenu(fileName "NewInventoryItem", menuName "Inventory/Item")] public class InventoryItem : ScriptableObje…

网工学习7-配置 GVRP 协议

7.1GARP概述 GARP(Generic Attribute Registration Protocol)是通用属性注册协议的应用&#xff0c;提供 802.1Q 兼容的 VLAN 裁剪 VLAN pruning 功能和在 802.1Q 干线端口 trunk port 上建立动态 VLAN 的功能。 GARP 作为一个属性注册协议的载体&#xff0c;可以用来传播属性…

Java 原子操作类

一、原子类 1.1 基本原子类 AtomicBooleanAtomicIntegerAtomicLong 1.1.1 常用API public final int get() //获取当前的值public final int getAndSet(int newValue)//获取当前的值&#xff0c;并设置新的值public final int getAndIncrement()//获取当前的值&#xff0c;…

游泳馆会员服务预约管理系统预约小程序效果如何

游泳馆在各地每天都有大量用户前往&#xff0c;夏季室外、冬季室内也是学习游泳技术和休闲娱乐的好地方&#xff0c;而消费者大多是年轻人和家长带的孩子&#xff0c;这部分群体更显年轻化&#xff0c;因此在如今互联网环境下&#xff0c;传统商家需要进一步赋能客户消费路径。…

【Vue】Vue CLI 脚手架(Vue Command Line Interface)安装教程(通过npm)

前言 Vue CLI&#xff08;Vue Command Line Interface&#xff09;是一个基于Vue.js的官方脚手架工具&#xff0c;用于快速搭建和管理Vue.js项目。它提供了一套完整的开发工具和配置&#xff0c;包括项目初始化、开发服务器、热重载、构建和打包等功能。 Vue CLI使用了Webpac…

Doris 数据导出方式总结

1 Export导出 数据导出是Doris提供的一种将数据导出的功能。该功能可以将用户指定的表或分区的数据以文本的格式,通过Broker进程导出到远端存储上,如HDFS/BOS等。 1.1 基本原理 用户提交一个 Export 作业后。Doris 会统计这个作业涉及的所有 Tablet。然后对这些 Tablet 进行分…

自动驾驶学习笔记(十三)——感知基础

#Apollo开发者# 学习课程的传送门如下&#xff0c;当您也准备学习自动驾驶时&#xff0c;可以和我一同前往&#xff1a; 《自动驾驶新人之旅》免费课程—> 传送门 《Apollo Beta宣讲和线下沙龙》免费报名—>传送门 文章目录 前言 传感器 测距原理 坐标系 标定 同…

2023/12/3总结

RabbitMq 消息队列 下载地址RabbitMQ: easy to use, flexible messaging and streaming — RabbitMQ 使用详情RabbitMQ使用教程(超详细)-CSDN博客 实现延迟队列&#xff08;为了实现订单15分钟后修改状态&#xff09; 1 死信队列 当一个队列中的消息满足下列情况之一时&…