LLaMA长度外推高性价比trick:线性插值法及相关改进源码阅读及相关记录

前言

最近,开源了可商用的llama2,支持长度相比llama1的1024,拓展到了4096长度,然而,相比GPT-4、Claude-2等支持的长度,llama的长度外推显得尤为重要,本文记录了三种网络开源的RoPE改进方式及相关源码的阅读。

关于长度外推性:https://kexue.fm/archives/9431

关于RoPE:https://kexue.fm/archives/8265

1、线性插值法

论文:EXTENDING CONTEXT WINDOW OF LARGE LANGUAGE MODELS VIA POSITION INTERPOLATION

链接:https://arxiv.org/pdf/2306.15595.pdf

思想:不进行长度外推,而是直接缩小位置索引。即:将4096的位置编码通过线性插值法压缩到2048内,这样只需在少量的4096长度的数据上继续预训练,便可达到不错的效果。

在这里插入图片描述

源码阅读(附注释)

class LlamaLinearScaledRotaryEmbedding(torch.nn.Module):def __init__(self, dim, max_position_embeddings=2048, base=10000, scale=1, device=None):super().__init__()# 相比RoPE增加scale参数self.scale = scale# inv_freq为基值向量inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))self.register_buffer("inv_freq", inv_freq)# Build here to make `torch.jit.trace` work.self.max_seq_len_cached = max_position_embeddings# 构建max_seq_len_cached大小的张量tt = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)# 张量t归一化,RoPE没有这一步t /= self.scale# einsum计算频率矩阵# 'i, j->i j’表示分别输入尺寸为[i]、[j]的向量,做笛卡尔运算得到尺寸为[i, j]的矩阵。freqs = torch.einsum("i,j->ij", t, self.inv_freq)# Different from paper, but it uses a different permutation in order to obtain the same calculation# 在-1维做一次拷贝、拼接emb = torch.cat((freqs, freqs), dim=-1)dtype = torch.get_default_dtype()# 注册为模型的缓冲区cos_cached和sin_cachedself.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)def forward(self, x, seq_len=None):# x: [bs, num_attention_heads, seq_len, head_size]# This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.# seq_len为序列长度,seq_len大于max_seq_len_cached,则重新计算频率矩阵,并更新cos_cached和sin_cached的缓冲区if seq_len > self.max_seq_len_cached:self.max_seq_len_cached = seq_lent = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)t /= self.scalefreqs = torch.einsum("i,j->ij", t, self.inv_freq)# Different from paper, but it uses a different permutation in order to obtain the same calculationemb = torch.cat((freqs, freqs), dim=-1).to(x.device)self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False)self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False)# 长度裁剪:返回cos_cached和sin_cached中与seq_len(序列长度)return (self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),)

线性插值法的相关实验效果:https://lmsys.org/blog/2023-06-29-longchat/

2、NTK插值法

NTK插值改进llama中使用的RoPE插值方法,同样,对于RoPE代码改动更小,其他地方与线性插值法实现一致。

reddit原帖:NTK-Aware Scaled RoPE allows LLaMA models to have extended (8k+) context size without any fine-tuning and minimal perplexity degradation

链接:https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/?rdt=58346

源码阅读:

class LlamaNTKScaledRotaryEmbedding(torch.nn.Module):def __init__(self, dim, max_position_embeddings=2048, base=10000, alpha=1, device=None):super().__init__()# 与线性插值法相比,实现更简单,alpha仅用来改变basebase = base * alpha ** (dim / (dim-2))inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))self.register_buffer("inv_freq", inv_freq)# Build here to make `torch.jit.trace` work.self.max_seq_len_cached = max_position_embeddingst = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)freqs = torch.einsum("i,j->ij", t, self.inv_freq)# Different from paper, but it uses a different permutation in order to obtain the same calculationemb = torch.cat((freqs, freqs), dim=-1)dtype = torch.get_default_dtype()self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)def forward(self, x, seq_len=None):# x: [bs, num_attention_heads, seq_len, head_size]# This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.if seq_len > self.max_seq_len_cached:self.max_seq_len_cached = seq_lent = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)freqs = torch.einsum("i,j->ij", t, self.inv_freq)# Different from paper, but it uses a different permutation in order to obtain the same calculationemb = torch.cat((freqs, freqs), dim=-1).to(x.device)self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False)self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False)return (self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),)

3、动态插值法

动态插值法又是对NTK插值法和线性插值法的改进,可以看作是上述两者的一种结合思想,旨在减少困惑度损失并实现更大的缩放。

reddit原帖:Dynamically Scaled RoPE further increases performance of long context LLaMA with zero fine-tuning

链接:https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/

源码阅读

class LlamaDynamicScaledRotaryEmbedding(torch.nn.Module):def __init__(self, dim, max_position_embeddings=2048, base=10000, ntk=False, device=None):super().__init__()# 是否开启NTK(Neural Tangent Kernel)self.ntk = ntkself.base = baseself.dim = dimself.max_position_embeddings = max_position_embeddings# inv_freq为基值向量inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))self.register_buffer("inv_freq", inv_freq)# Build here to make `torch.jit.trace` work.self.max_seq_len_cached = max_position_embeddingst = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)freqs = torch.einsum("i,j->ij", t, self.inv_freq)# Different from paper, but it uses a different permutation in order to obtain the same calculation# emb:[max_seq_len_cached, dim]emb = torch.cat((freqs, freqs), dim=-1)dtype = torch.get_default_dtype()self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)def forward(self, x, seq_len=None):# x: [bs, num_attention_heads, seq_len, head_size]# This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.if seq_len > self.max_seq_len_cached:self.max_seq_len_cached = seq_lenif self.ntk:base = self.base * ((self.ntk * seq_len / self.max_position_embeddings) - (self.ntk - 1)) ** (self.dim / (self.dim-2))# 计算新的inv_freqinv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(x.device) / self.dim))self.register_buffer("inv_freq", inv_freq)t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)if not self.ntk:# 缩放t *= self.max_position_embeddings / seq_len# 得到新的频率矩阵freqsfreqs = torch.einsum("i,j->ij", t, self.inv_freq)# Different from paper, but it uses a different permutation in order to obtain the same calculation# freqs与自身拼接得到新的embemb = torch.cat((freqs, freqs), dim=-1).to(x.device)# 注册为模型的缓冲区cos_cached和sin_cachedself.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False)self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False)# 长度裁剪return (self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),)

网友对于困惑度的实验并取得了一定的效果:https://github.com/turboderp/exllama/pull/118

总结

本文介绍了llama通过线性插值法及相关改进方案进行长度外推的trcik,并对相关源码阅读及网络资源进行记录,个人粗浅认为,相比LongLLaMA,基于线性插值法+Finetune的方式,是一种高性价比的长度外推方案。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/38287.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Vue-打印组件页面

场景: 需要将页面的局部信息打印出来&#xff0c;只在前端实现&#xff0c;不要占用后端的资源。经过百度经验&#xff0c;决定使用 print-js和html2canvas组件。 1. 下载包 npm install print-js --save npm install --save html2canvas 2. 组件内引用 <script>impo…

C语言之数组指针和指针数组

C语言之数组指针和指针数组 一、含义二、定义2.1 指针数组2.2 数组指针 三、使用3.1 指针数组在参数传递时的使用3.1.1 指针数组的排序3.2 数组指针在参数传递时的使用 一、含义 指针数组&#xff1a;顾名思义&#xff0c;其为一个数组&#xff0c;数组里面存放着多个指针&…

C#生成随机验证码

以下是一个简单的C#验证码示例&#xff1a; private void GenerateCaptcha() {// 生成随机字符串string chars "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789";Random random new Random();string captchaString new string(Enumerable.Repe…

TPAMI, 2023 | 用压缩隐逆向神经网络进行高精度稀疏雷达成像

CoIR: Compressive Implicit Radar | IEEE TPAMI, 2023 | 用压缩隐逆向神经网络进行高精度稀疏雷达成像 注1:本文系“无线感知论文速递”系列之一,致力于简洁清晰完整地介绍、解读无线感知领域最新的顶会/顶刊论文(包括但不限于Nature/Science及其子刊;MobiCom, Sigcom, MobiSy…

Java【算法 04】HTTP的认证方式之DIGEST认证详细流程说明及举例

HTTP的认证方式之DIGEST 1.是什么2.认值流程2.1 客户端发送请求2.2 服务器返回质询信息2.2.1 质询参数2.2.2 质询举例 2.3 客户端生成响应2.4 服务器验证响应2.5 服务器返回响应 3.算法3.1 SHA-2563.1.1 Response3.1.2 A13.1.3 A2 3.2 MD53.2.1 Request-Digest3.2.2 A13.2.3 A2…

CSS3 中新增了哪些常见的特性?

聚沙成塔每天进步一点点 ⭐ 专栏简介⭐ 圆角&#xff08;Border Radius&#xff09;⭐ 渐变&#xff08;Gradients&#xff09;⭐ 阴影&#xff08;Box Shadow&#xff09;⭐ 文本阴影&#xff08;Text Shadow&#xff09;⭐ 透明度&#xff08;Opacity&#xff09;⭐ 过渡&…

Spring boot与Spring cloud 之间的关系

Spring boot与Spring cloud 之间的关系 Spring boot 是 Spring 的一套快速配置脚手架&#xff0c;可以基于spring boot 快速开发单个微服务&#xff0c;Spring Boot&#xff0c;看名字就知道是Spring的引导&#xff0c;就是用于启动Spring的&#xff0c;使得Spring的学习和使用…

MATLAB中xlsread函数用法

目录 语法 说明 示例 将工作表读取到数值矩阵 读取元胞的范围 读取列 请求数值、文本和原始数据 对工作表执行函数 请求自定义输出 局限性 xlsread函数的功能是读取Microsoft Excel 电子表格文件 语法 num xlsread(filename) num xlsread(filename,sheet) num x…

Nacos和GateWay路由转发NotFoundException: 503 SERVICE_UNAVAILABLE “Unable to find

问题再现&#xff1a; 2023-08-15 16:51:16,151 DEBUG [reactor-http-nio-2][CompositeLog.java:147] - [dc73b32c-1] Encoding [{timestampTue Aug 15 16:51:16 CST 2023, path/content/course/list, status503, errorService Unavai (truncated)...] 2023-08-15 16:51:16,17…

leetcode27—移除元素

思路&#xff1a; 参考26题目双指针的思想&#xff0c;只不过这道题不是快慢指针。 看到示例里面数组是无序的&#xff0c;也就是说后面的元素也是可能跟给定 val值相等的&#xff0c;那么怎么处理呢。就想到了从前往后遍历&#xff0c;如果left对应的元素 val时&#xff0c…

汽车制造业上下游协作时 外发数据如何防泄露?

数据文件是制造业企业的核心竞争力&#xff0c;一旦发生数据外泄&#xff0c;就会给企业造成经济损失&#xff0c;严重的&#xff0c;可能会带来知识产权剽窃损害、名誉伤害等。汽车制造业&#xff0c;会涉及到重要的汽车设计图纸&#xff0c;像小米发送汽车设计图纸外泄事件并…

[足式机器人]Part5 机械设计 Ch00/01 绪论+机器结构组成与连接 ——【课程笔记】

本文仅供学习使用 本文参考&#xff1a; 《机械设计》 王德伦 马雅丽课件与日常作业可登录网址 http://edu.bell-lab.com/manage/#/login&#xff0c;选择观摩登录&#xff0c;查看2023机械设计2。 机械设计-Ch00Ch01——绪论机器结构组成与连接 Ch00-绪论0.1 何为机械设计——…

12.Eclipse导入Javaweb项目

同事复制一份他的项目给我ekp.rar (懒得从SVN上拉取代码了)放在workspace1目录下 新建一个文件夹 workspace2&#xff0c;Eclipse切换到workspace2工作空间 选择Import导入 选择导入的项目(这里是放到workspace1里面) 拷贝一份到workspace2里面 例子 所有不是在自己电脑上开发…

可白嫖的4家免费CDN,并测试其网络加速情况(2023版)

网站加载速度优化过程中&#xff0c;不可避免的会用上CDN来加速资源的请求速度。但是市面上的CDN资源几乎都是要收费的&#xff0c;而且价格还不便宜&#xff0c;对于小公司站长来讲&#xff0c;这将是一笔不小的开销。不过还是有一些良心公司给我们提供了免费的资源&#xff0…

ZooKeeper的基本概念

集群角色 通常在分布式系统中&#xff0c;构成一个集群的每一台机器都有自己的角色&#xff0c;最典型的集群模式就是Master/Slave模式(主备模式)。在这种模式中&#xff0c;我们把能够处理所有写操作的机器称为Master机器&#xff0c;把所有通过异步复制方式获取最新数据&…

Redis_亿级访问量数据处理

11. 亿级访问量数据处理 11.1 场景表述 手机APP用户登录信息&#xff0c;一天用户登录ID或设备ID电商或者美团平台&#xff0c;一个商品对应的评论文章对应的评论APP上有打卡信息网站上访问量统计统计新增用户第二天还留存商品评论的排序月活统计统计独立访客(Unique Vistito…

【BEV】3D视觉 PRELIMINARY

这里的知识来自于论文 Delving into the Devils of Bird’s-eye-view Perception: A Review, Evaluation and Recipe 的 Appendix B.1 部分来自 这篇文章 从透视图转向鸟瞰图。&#xff08;Xw、Yw、Zw&#xff09;、&#xff08;Xc、Yc、Zc&#xff09;表示世界World坐标和相…

Android学习之路(4) UI控件之Button (按钮)与 ImageButton (图像按钮)

本节引言&#xff1a; 今天给大家介绍的Android基本控件中的两个按钮控件&#xff0c;Button普通按钮和ImageButton图像按钮&#xff1b; 其实ImageButton和Button的用法基本类似&#xff0c;至于与图片相关的则和后面ImageView相同&#xff0c;所以本节 只对Button进行讲解&am…

vue自定义穿梭框支持远程滚动加载

分享-2023年资深前端进阶&#xff1a;前端登顶之巅-最全面的前端知识点梳理总结&#xff0c;前端之巅 *分享一个使用比较久的&#x1fa9c; 技术框架公司的选型(老项目)&#xff1a;vue2 iview-ui 方案的实现思路是共性的&#xff0c;展现UI样式需要你们自定义进行更改&#…

【注解使用】使用@Autowired后提示:Field injection is not recommended(Spring团队不推荐使用Field注入)

问题发生场景&#xff1a; 在使用 IDEA 开发 SpringBoot 项目时&#xff0c;在 Controller 类中使用注解 Autowired 注入一个依赖出现了警告提示&#xff0c;查看其他使用该注解的地方同样出现了警告提示。这是怎么回事&#xff1f;由于先去使用了SpringBoot并没有对Spring进行…