LLM 中的长文本问题

近期,随着大模型技术的发展,长文本问题逐渐成为热门且关键的问题,不妨简单梳理一下近期出现的典型的长文本模型:

  • 10 月上旬,Moonshot AI 的 Kimi Chat 问世,这是首个支持 20 万汉字输入的智能助手产品;

  • 10 月下旬,百川智能发布 Baichuan2-192K 长窗口大模型,相当于一次处理约35 万个汉字;

  • 11 月上旬,OpenAI 发布支持 128K 上下文窗口的 GPT-4 Turbo 模型;

  • 11 月下旬,Anthropic 发布支持 200K 上下文窗口的 Claude 2.1 模型;

  • 12 月上旬,零一万物开源了长文本模型 Yi-6B-200K和 Yi-34B-200K。

实际上,随着文本长度的提高,模型能够处理问题的边界也大大提高,因此研究并解决长文本问题就显得非常必要。本文将从长文本问题的本质出发,逐步分析和研究长文本实现的问题及解决办法。

一、长文本的核心问题与解决方向

1.1 文本长度与显存及计算量之关系

要研究清楚长文本的问题,首先应该搞清楚文本长度在模型中的地位与影响。那么我们便以 Decoder-base 的模型为例来进行分析

1.1.1 模型参数量

Decoder-base 的模型主要包括 3 个部分:embedding, decoder-layer, head。

其中最主要部分是decoder-layer,其由 lll 个层组成,每个层又分为两部分:self-attention 和 MLP。

self-attention的模型参数有、、的权重矩阵 、、及bias,输出矩阵  及bias,4个权重矩阵的形状为 ( 表示 hidden_size),4个bias的形状为  。则 self- attention 的参数量为 。

MLP由2个线性层组成,一般地,第一个线性层是先将维度从  映射到  ,第二个线性层再将维度从映射到。第一个线性层的权重矩阵  的形状为  ,偏置的形状为 。第二个线性层权重矩阵 的形状为  ,偏置形状为  。则 MLP 的参数量为  。

self-attention 和MLP各有一个layer normalization,包含了2个可训练模型参数:缩放参数γ和平移参数β,形状都是。2个layer normalization的参数量为 。

由此,每个Decoder层的参数量为。

此外,embeddinghead 的参数量相同,与词表相关,为(如果是 Tied embedding,则二者共用同一个参数)。由于位置编码多样,且参数量小,故忽略此部分。

综上, 层模型的可训练模型参数量为 。当 较大时,可以忽略一次项,模型参数量近似为。

1.1.2 计算量估计

如果说参数量是模型的固有属性,那么计算量便是由模型和输入共同决定,下面分析这一过程。假设输入数据的形状为  ( 表示batch_size,表示sequence_length)。

先分析Decoder中self-attention的计算量,计算公式如下:

图片

  1. 计算:矩阵乘法的输入和输出形状为。计算量为。

  2. 矩阵乘法的输入和输出形状为

    图片

    计算量为。

  3. 计算在上的加权,矩阵乘法的输入和输出形状为

    图片

    计算量为 。

  4. attention后的线性映射,矩阵乘法的输入和输出形状为.计算量为。

接下来分析MLP块的计算,计算公式如下:

图片

  1. 第一个线性层,矩阵乘法的输入和输出形状为 。计算量为 。

  2. 第二个线性层,矩阵乘法的输入和输出形状为 。计算量为。

将上述计算量相加,得到每个Decoder层的计算量大约为。

此外,另一个计算量的大头是logits的计算,将隐藏向量映射为词表大小。矩阵乘法的输入和输出形状为,计算量为。

因此,对于一个 lll 层的模型,输入数据形状为的情况下,一次前向计算的计算量为。

1.1.3 文本长度与计算量、参数量、显存的关系

忽略低次项,一次输入的tokens数为bs, 则计算量与参数量的关系为  在实际中通常 ,因此该项可近似认为约等于2。即在一次前向传递中,对于每个token,每个模型参数,需要进行2次浮点数运算(一次乘法法运算和一次加法运算)。考虑到后向传递的计算量是前向传递的2倍。因此一次训练迭代中,对于每个 token,每个模型参数,需要进行  次浮点数运算。

通过以上分析,我们可以得到结论:计算量主要和模型参数和 token 数相关,文本长度并不会显著增加计算量。那么这就引出另一个问题:文本长度与显存的关系。

除了模型参数、梯度、优化器状态外,占用显存的大头就是前向传递过程中计算得到的中间激活值。这里的激活(activations)指的是:前向传递过程中计算得到的,并在后向传递过程中需要用到的所有张量。

先分析 Decoder layer 中 self-attention 的中间激活:

  1. 对于  ,需要保存它们共同的输入 ,这就是中间激活。输入  的形状为 ,元素个数为  ,占用显存大小为 。

  2. 对于  矩阵乘法,需要保存中间激活 ,两个张量的形状都是 ,占用显存大小合计为 。

  3. 对于  函数,需要保存函数的输入 ,占用显存大小为 ,这里的  表示注意力头数。

  4. 计算完  函数后,会进行dropout操作。需要保存一个mask矩阵,mask矩阵的形状与  相同,占用显存大小为 。

  5. 计算在  上的attention,即 ,需要保存 score ,大小为 ;以及 ,大小为 。二者占用显存大小合计为 。

  6. 计算输出映射以及一个dropout操作。输入映射需要保存其输入,大小为 ;dropout需要保存mask矩阵,大小为 。二者占用显存大小合计为 。

因此,将上述中间激活相加得到,self-attention的中间激活占用显存大小为 。接下来分析分析Decoder layer中MLP的中间激活:

  1. 第一个线性层需要保存其输入,占用显存大小为 。

  2. 激活函数需要保存其输入,占用显存大小为 。

  3. 第二个线性层需要保存其输入,占用显存大小为 。

  4. 最后有一个dropout操作,需要保存mask矩阵,占用显存大小为 。

对于MLP块,需要保存的中间激活值为 。

另外,self-attention块和MLP块分别对应了一个layer normalization。每个layer norm需要保存其输入,大小为 。2个layer norm需要保存的中间激活为 。

综上,每个层需要保存的中间激活占用显存大小为 。对于  层transformer模型,还有embedding层、最后的输出层。embedding层不需要中间激活。总的而言,当隐藏维度  比较大,层数  较深时,这部分的中间激活是很少的,可以忽略。因此,对于  层模型,中间激活占用的显存大小可以近似为 ,这个结果与文本长度关系密切。

下面以GPT3-175B为例,对比下文本长度对模型参数与中间激活的显存大小的影响。假设数据类型为 FP16 。

模型名参数量层数隐藏维度注意力头数
GPT3175B961228896

GPT3的模型参数量为175B,占用的显存大小为 。GPT3 模型需要占用350GB的显存。

假设 GPT3 输入的 。对比不同的文本长度下占用的中间激活:

当  时,中间激活占用显存为

,大约是模型参数显存的0.79倍;

当  时,中间激活占用显存为

,大约是模型参数显存的2.68倍。

可以看到长度仅仅到 4K,显存占用就出现了剧烈增加,同时 GPU onchip 的 memory 就显得更加捉襟见肘(因此也就出现了 FlashAttention 这类算法)。因此如何解决长文本带来的巨量显存开销成为关键及核心问题。

1.2 长文本问题的解决思路

当前,为了实现更长长文本的支持,解决思路主要可以分为两个阶段:

  • 阶段一:在预训练阶段尽可能支持更长的文本长度
    为实现这一阶段目标,通常采用并行化 (parallelism) 方法将显存占用分摊到多个 device,或者改造 attention 结构,避免显存占用与文本长度成二次关系。

  • 阶段二:在 SFT 或推理阶段尽可能外推到更大长度
    为实现这一阶段目标,通常也是需要在两个方面进行考虑:

    • 对位置编码进行外推

    • 优化 At

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/596561.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

江苏事业单位计算机岗复习备考经验(2023年)

一、考情分析:根据历年考试分析统计,我们江苏事业单位计算机岗考试题型为前百分之四十的行测和时政加上后百分之六十的计算机专业知识;其中前百分之四十为单项选择题,后面的计算机专业知识为单选题、多选题、简答题和实务题。由于…

ssrf之gopher协议的使用和配置,以及需要注意的细节

gopher协议 目录 gopher协议 (1)安装一个cn (2)使用Gopher协议发送一个请求,环境为:nc起一个监听,curl发送gopher请求 (3)使用curl发送http请求,命令为 …

黑马程序员Java项目实战《瑞吉外卖》,轻松掌握springboot + mybatis plus开发核心技术的真java实战项目——第二部分

黑马程序员Java项目实战《瑞吉外卖》,轻松掌握springboot mybatis plus开发核心技术的真java实战项目——第二部分 1.员工管理模块1.1 完善登陆功能1.2 新增员工1.2.1 全局异常捕获 1.3 员工信息分页查询1.4 启用/禁用员工账号1.4.1 使用自定义消息转换器 1.5 编辑…

springboot整合gateway网关

一、网关基本概念 1、API网关介绍 API 网关出现的原因是微服务架构的出现,不同的微服务一般会有不同的网络地址,而外部客户端可能需要调用多个服务的接口才能完成一个业务需求,如果让客户端直接与各个微服务通信,会有以下的问题…

React Admin 前端脚手架之ant-design-pro

文章目录 一、React Admin 前端脚手架选型二、React Admin 前端脚手架之ant-design-pro三、ant-design-pro使用步骤四、常用总结(持续更新)EditableProTable组件 常用组件EditableProTable组件 编辑某行后,保存时候触发发送请求EditableProTa…

linux 系统 kill 指令笔记

kill 名称 kill - send a signal to a process 向指定的线程或进程发送信号 描述 The default signal for kill is TERM. Use -l or -L to list availablesignals. Particularly useful signals include HUP, INT, KILL, STOP,CONT, and 0. Alternate signals …

k8s笔记1- 初步认识k8s

k8s简介: kubernetes,俗称k8是,用于自动部署,扩缩和管理容器化应用程序的开源系统,它将组成应用程序的容器,组合成逻辑单元,便于管理和服务发现。 k8s的作用 自动化上线和回滚、存储编排…

Spring中的工厂类、bean的作用范围和生命周期

1.Spring中的工厂类 1.1ApplicationContext ClassPathXmlApplicationContext:加载类路径下 Spring 的配置文件 FileSystemXmlApplicationContext:加载本地磁盘下 Spring 的配置文件 1.1.1service ApplicationContext:只要一读取配置文件…

PyTorch|PyTorch张量解释

神经网络中的输入、输出和转换都使用张量表示,因此,神经网络编程大量使用张量,张量是我们在 PyTorch 中编程神经网络时将使用的数据结构。 关于张量及其维数的简要说明,以及术语: 你有时会看到一个称为向量的一维张量…

[论文分享]TimesURL:通用时间序列表示学习的自监督对比学习

论文题目:TimesURL: Self-supervised Contrastive Learning for Universal Time Series Representation Learning 论文地址:https://arxiv.org/abs/2312.15709 代码地址:暂无 摘要 学习适用于各种下游任务的通用时间序列表示具有挑战性&…

Springboot整合RocketMQ 基本消息处理

目录 1. 同步消息 2. 异步消息 3. 单向消息 4. 延迟消息 5. 批量消息 6. 顺序消息 7. Tag过滤 导入依赖 <dependency><groupId>org.apache.rocketmq</groupId><artifactId>rocketmq-spring-boot-starter</artifactId></dependency> …

14:00面试,14:08就出来了,问的问题过于变态了。。。

从小厂出来&#xff0c;没想到在另一家公司又寄了。 到这家公司开始上班&#xff0c;加班是每天必不可少的&#xff0c;看在钱给的比较多的份上&#xff0c;就不太计较了。没想到10月一纸通知&#xff0c;所有人不准加班&#xff0c;加班费不仅没有了&#xff0c;薪资还要降40…

机器学习原理到Python代码实现之LinearRegression

Linear Regression 线性回归模型 该文章作为机器学习的第一篇文章&#xff0c;主要介绍线性回归模型的原理和实现方法。 更多相关工作请参考&#xff1a;Github 算法介绍 线性回归模型是一种常见的机器学习模型&#xff0c;用于预测一个连续的目标变量&#xff08;也称为响应变…

Spring的bean的生命周期!!!

一.单例模式 单例&#xff1a;[启动容器]--->通过构造方法&#xff08;创建对象&#xff09;---->调用set方法&#xff08;注入&#xff09;--->调用init方法&#xff08;初始化&#xff09;----[容器关闭]----->调用destroy方法&#xff08;销毁&#xff09; app…

死锁的处理策略“检测和解除”-第三十九天

目录 前言 死锁的检测 数据结构资源分配图 基于“图”检测死锁 可以消除所有边 不能消除所有边 结论 死锁定理 死锁的解除 本节思维导图 前言 如果系统中既不采取预防死锁的措施&#xff0c;也不采取避免死锁的措施&#xff0c;系统就很可能发生死锁&#xff0c;在这种…

西电期末1019.校验和计算

一.题目 二.分析与思路 难点在于逐个取出数据的每一位&#xff0c;我们编写f函数&#xff0c;使用了一个while函数&#xff0c;每次循环中用取余的运算符找到数据的个位累加&#xff0c;再将n/10&#xff0c;如此n便被去除了个位&#xff0c;十位就成了新的个位&#xff0c;最…

案例精选|淄博绿能燃气工程有限公司日志审计系统建设方案

淄博绿能燃气工程有限公司&#xff0c;成立于1994年&#xff0c;前身为淄博市煤气公司管道液化气分公司。公司业务主要涉及天然气、液化气等市政工程施工及城镇燃气供应等领域&#xff0c;具有市政公用工程施工总承包二级资质&#xff0c;《压力管道安装许可证》压力管道安装GB…

利用Embedding优化搜索功能

我们继续用Gemini学习LLM编程之旅。 Embedding是一种自然语言处理 (NLP) 技术&#xff0c;可将文本转换为数值向量。Embedding捕获语义含义和上下文&#xff0c;从而导致具有相似含义的文本具有更接近的Embedding。例如&#xff0c;句子“我带我的狗去看兽医”和“我带我的猫去…

LeetCode---378周赛

题目列表 2980. 检查按位或是否存在尾随零 2981. 找出出现至少三次的最长特殊子字符串 I 2982. 找出出现至少三次的最长特殊子字符串 II 2983. 回文串重新排列查询 一、检查按位或是否存在尾随零 这题和位运算有关&#xff0c;不是很难&#xff0c;题目要求至少有两个数的…

案例073:基于微信小程序的智慧旅游平台开发

文末获取源码 开发语言&#xff1a;Java 框架&#xff1a;SSM JDK版本&#xff1a;JDK1.8 数据库&#xff1a;mysql 5.7 开发软件&#xff1a;eclipse/myeclipse/idea Maven包&#xff1a;Maven3.5.4 小程序框架&#xff1a;uniapp 小程序开发软件&#xff1a;HBuilder X 小程序…