Gemma中RoPE代码详细讲解

最近在看Gemma代码感觉比LLama的代码看的方便点, 看到RoPE代码跟常规的方式不太一样(也不算常规,就是我理解的方式),特此记录一下。我的RoPE入门代码参考:Rotary Position Embedding (RoPE, 旋转式位置编码) | 原理讲解+torch代码实现
原理我就不讲了,直接贴一下图,图源自于上面的链接。
在这里插入图片描述
我们先粘贴一下代码,逐步讲解:

dim:单头维度信息
end:序列长度
theta:10000
def precompute_freqs_cis(dim: int,end: int,theta: float = 10000.0) -> torch.Tensor:"""Precomputes the frequency cis."""freqs = 1.0 / (theta**(torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))t = torch.arange(end, device=freqs.device)freqs = torch.outer(t, freqs).float()freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64return freqs_cisx:输入特征维度[batch, end, num_head, dim]
freqs_cis:上个函数获取的结果
def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:"""Applies the rotary embedding to the query and key tensors."""x_ = torch.view_as_complex(torch.stack(torch.chunk(x.transpose(1, 2).float(), 2, dim=-1),dim=-1))x_out = torch.view_as_real(x_ * freqs_cis).type_as(x)x_out = torch.cat(torch.chunk(x_out, 2, dim=-1), dim=-2)x_out = x_out.reshape(x_out.shape[0], x_out.shape[1], x_out.shape[2],-1).transpose(1, 2)return x_out

precompute_freqs_cis

  • freqs = 1.0 / (theta**(torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
    这个代码主要是实现一下公式在这里插入图片描述
    torch.arange(0, dim, 2),生成列表 [0,2, …d_model//2]
  • t = torch.arange(end, device=freqs.device)
    生成序列长度, [0, 1, …, end(也就是序列长度)]
  • freqs = torch.outer(t, freqs).float()
    进行笛卡尔积,维度变成[end, dim//2]
    在这里插入图片描述
  • freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    通过polar函数生成cos和sin值,为什么要使用torch.ones_like(freqs), 下面公式,abs为1,不就是cos值和sin值了
    在这里插入图片描述

apply_rotary_emb

  • x_ = torch.view_as_complex(
    torch.stack(torch.chunk(x.transpose(1, 2).float(), 2, dim=-1),
    dim=-1))
    x.transpose(1, 2).float()将输入维度变为[batch, num_head, end, dim]
    torch.chunk将数据前dim//2 和后dim//2分开,我理解的是[q0, q1, …qn]是奇偶分开,而不是前后分开,可能无所谓吧。
    torch.stack则是对维度进行合并,产生[batch, num_head, end, dim//2, 2]这种维度。
    我简单举例子验证一下:
import torch
a = torch.arange(10)
print(a)
b = torch.chunk(a, 2, dim=-1)
print(b)
c = torch.stack(b, dim = -1)
print(c)

在这里插入图片描述
并且用torch.view_as_complex转为复述的形式

  • x_out = torch.view_as_real(x_ * freqs_cis).type_as(x)
    x_ * freqs_cis实现复述的预算
    举例子解释一下:
    x_为:q0+q1 i
    freqs_cis:cos+sin i
    实部为:q0cos -q1sin
    虚部为:q1cos + q1sin
    torch.view_as_real函数则是把转为实数的形式,a+bi->[a, b]形式
  • x_out = torch.cat(torch.chunk(x_out, 2, dim=-1), dim=-2)
    则是把维度转为[batch, num_head, end, dim]的形式
  • x_out = x_out.reshape(x_out.shape[0], x_out.shape[1], x_out.shape[2],
    -1).transpose(1, 2)
    转为输入的形式

下面为LLama的RoPE实现:

def apply_rotary_emb(xq: torch.Tensor,xk: torch.Tensor,freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))freqs_cis = reshape_for_broadcast(freqs_cis, xq_)xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)return xq_out.type_as(xq), xk_out.type_as(xk)

LLama我理解,他则是采用类似于奇偶分开的方式,我简单尝试了一下:

import torch
a = torch.arange(10)
print(a)
b = a.reshape(5,2)
print(b)

在这里插入图片描述

总结:

以上就是我对RoPE代码实现的理解,相比原来理解的方式,这种相对更加简洁,但是略有一些绕

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/745369.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

自然语言处理实验2 字符级RNN分类实验

实验2 字符级RNN分类实验 必做题: (1)数据准备:academy_titles.txt为“考硕考博”板块的帖子标题,job_titles.txt为“招聘信息”板块的帖子标题,将上述两个txt进行划分,其中训练集为70%&#xf…

服务器Debian 12.x中安装Jupyer并配置远程访问

服务器系统:Debian 12.x;IP地址:10.100.2.138 客户端:Windows 10;IP地址:10.100.2.38 利用ssh登录服务器: 1.安装python3 #apt install python3 2.安装pip #apt install python3-pip … 3.安装virtualen…

Unity Timeline学习笔记(3) - SignalTrack信号轨道和自定义带参数的Marker信号和轨道

信号轨道,顾名思义就是运行到某处发送一个信号。 普通用法 普通用法就是没有任何封装的,个人感觉特别难用,但是有必要理解一下工作原理。 添加信号 我们添加一个信号资源 生成后可以看到资源文件,这个是可以拖到SignalTrack上…

【Python数据结构与判断7/7】数据结构小结

目录 序言 整体回忆 定义方式 访问元素 访问单个元素 访问多个与元素 修改元素 添加元素 列表里添加元素 字典里添加元素 删除元素 in运算符 实战案例 总结 序言 今天将对前面学过的三种数据结构:元组(tuple)、列表(…

微前端框架 qiankun 配置使用【基于 vue/react脚手架创建项目 】

qiankun官方文档:qiankun - qiankun 一、创建主应用: 这里以 vue 为主应用,vue版本:2.x // 全局安装vue脚手架 npm install -g vue/clivue create main-app 省略 vue 创建项目过程,若不会可以自行百度查阅教程 …

java垃圾回收-三色标记法

三色标记法 引言什么是三色标记法白色灰色黑色 三色标记过程三色标记带来的问题多标问题漏标问题 如何弥补漏标问题增量更新原始快照总结 引言 在CMS,G1这种并发的垃圾收集器收集对象时,假如一个对象A被GC线程标记为不可达对象,但是用户线程又把A对象做…

数字化经济的前沿:深入了解 Web3 的商业模式

随着区块链技术的迅速发展,Web3作为一种新型的互联网范式,正逐渐引起人们的关注。它不仅仅是一种技术革新,更是一种商业模式和价值观的转变。本文将深入探讨Web3的商业模式,以及它对数字化经济的影响。 1. 理解Web3的商业模式 We…

算法---滑动窗口练习-4(无重复字符的最长子串)

无重复字符的最长子串 1. 题目解析2. 讲解算法原理3. 编写代码 1. 题目解析 题目地址:点这里 2. 讲解算法原理 算法的主要思想是使用滑动窗口来维护一个不含重复字符的子串。定义两个指针 left 和 right 分别表示窗口的左边界和右边界。还定义了一个数组 hash 来记…

Apache Paimon 的 CDC Ingestion 概述

CDC Ingestion 1)概述 Paimon支持schema evolution将数据插入到Paimon表中,添加的列将实时同步到Paimon表,并且无需重启同步作业。 目前支持的同步方式如下: MySQL Synchronizing Table: 将MySQL中的一个或多个表同步到一个Pa…

【算法与数据结构】深入解析二叉树(一)

文章目录 📝数概念及结构🌠 树的概念🌉树的表示🌠 树在实际中的运用(表示文件系统的目录树结构) 🌉二叉树概念及结构🌠概念🌉数据结构中的二叉树🌠特殊的二叉…

Spring web MVC(2)

1、RequestMapping称为路由映射(既是类注解也是方法注解提供访问路径) 2、RequestParam起到重命名的作用,也起到绑定的作用,传递集合list时会用到,多个值绑定给list,默认是必传参数如果不传参数需要设置re…

如何在Windows 10上打开和关闭平板模式?这里提供详细步骤

前言 默认情况下,当你将可翻转PC重新配置为平板模式时,Windows 10会自动切换到平板模式。如果你希望手动打开或关闭平板模式,有几种方法可以实现。​ 自动平板模式在Windows 10上如何工作 如果你使用的是二合一可翻转笔记本电脑&#xff0…

Spring, SpringBoot, SpringCloud,微服务

1,SSM (Spring+SpringMVC+MyBatis) SSM框架集由Spring、MyBatis两个开源框架整合而成(SpringMVC是Spring中的部分内容),常作为数据源较简单的web项目的框架。 Spring MVC 是 Spring 提供的一个基于 MVC 设计模式的轻量级 Web 开发框架,本质上相当于 Servlet,Controlle…

vue 基于elementUI/antd-vue, h函数实现message中嵌套链接跳转到指定路由 (h函数点击事件的写法)

效果如图: 点击message 组件中的 工单管理, 跳转到工单管理页面。 以下是基于vue3 antd-vue 代码如下: import { message } from ant-design-vue; import { h, reactive, ref, watch } from vue; import { useRouter } from vue-router; c…

PY32离线烧录器功能介绍,可批量烧录,支持PY32系列多款单片机

PY32离线烧录器可以对PY系列单片机进行批量烧录,现支持PY32F002A/002B/002/003/030/071/072/040/403/303芯片各封装和XL2409,XL32F001/003等芯片。PY32离线烧录器需要搭配上位机软件才能使用,上位机软件在我们官网(www.xinlinggo.…

【软考】UML中的图之对象图

目录 1. 说明2. 图示3. 特性 1. 说明 1.对象图即object diagram2.展现了某一时刻一组对象以及它们之间的关系3.描述了在类图中所建立的事物的实例的静态快照4.对象图一般包括对象和链5.对象图展示的是对象之间关系,不存在交互,所以不是交互图 2. 图示 …

#微信小程序(一个emo文案界面)

1.IDE:微信开发者工具 2.实验:一个emo文案界面 (1)最好使用rpx (2)图片宽度占不满,在CSS中设置width为100% (3)imag图片全部为网页链接图片 3.记录 4.代码 index.htm…

Jmeter+ant,ant安装与配置

1.ant含义 ant:Ant翻译过来是蚂蚁的意思,在我们做接口测试的时候,是可以用来做JMeter接口测试生成测试报告的工具 2.ant下载 下载地址:Apache Ant - Ant Manual Distributions download中选择ant 下载安装最新版zip文件 3.…

阿里云国际放行DDoS高防回源IP

如果源站服务器上设置了IP白名单访问控制(如安全软件、安全组),由于设置了DDoS高防后,回源IP是高防回源IP段,您需要将DDoS高防的回源IP段的地址加入安全软件和安全组的白名单中,避免DDoS高防的回源流量被误…

导入fetch_california_housing 加州房价数据集报错解决(HTTPError: HTTP Error 403: Forbidden)

报错 HTTPError Traceback (most recent call last) Cell In[3], line 52 from sklearn.datasets import fetch_california_housing3 from sklearn.model_selection import train_test_split ----> 5 X, Y fetch_california_housing(retu…