AI大模型:无需训练让LLM支持超长输入

显式搜索: 知识库外挂

  • paper: Unleashing Infinite-Length Input Capacity for Large-scale Language Models with Self-Controlled Memory System

  • 看到最无敌的应用,文本和表格解析超厉害https://chatdoc.com/?viaurl=ainavpro.com

  • ChatGPT代码实现: GitHub - arc53/DocsGPT: GPT-powered chat for documentation, chat with your documents

  • ChatGLM代码实现: GitHub - chatchat-space/Langchain-Chatchat: Langchain-Chatchat(原Langchain-ChatGLM)基于 Langchain 与 ChatGLM 等语言模型的本地知识库问答 | Langchain-Chatchat (formerly langchain-ChatGLM), local knowledge based LLM (like ChatGLM) QA app with langchain

  • 适用于大规模知识问答场景

这块可能是GPT后比较火的方向,有一阵每天都能看到类似的新应用,从GPT读论文,再到百科问答,搭配langchain框架,在DocQA,KBQA的场景简直无往不利, 以上分别给出了基于ChatGPT和ChatGLM的两个实现方案。

实现的步骤基本可以被下图概括

img

  1. 长文本解析切分成chunk: 实际使用过程中发现文本解析竟然是最核心的部分,能否把需要保留语义完整性的段落拆成整段,能否高质量的解析表格,和结构化数据,对后续QA的影响最大

  2. 文本向量化:中文可用的embedding模型有不少,也可以基于simcse,consert在垂直领域做进一步的微调。在向量化阶段主要的问题是文本截断带来的上下文损失会影响召回,因此可以尝试重叠切分,拼接摘要/标题等方式

  3. 向量入库:需要高效向量检索的数据库,Milvus、Pinecone,这块最近也火了一波初创公司

  4. 用户问题改写:在多轮QA的场景,对话历史有两种使用方式,其一使用历史对话对当前query进行改写再召回,其二种是使用原始用户query去召回文本,在回复阶段引入对话历史

  5. 召回:基于用户query或改写query进行向量化检索,topK或者阈值召回。除了考虑相关性,在部分场景也要考虑时效性,文本质量等等

  6. 答案生成:使用召回文档拼接用户query进行答案生成,这一步往往还需要用到模型摘要,Refine等能力,核心是对以上召回的长文本进行压缩

搜索法最大的优点是实现简单,不过也有许多限制就是只能支持NLU任务,以及会破坏输入文本的上下文连续性,和文本顺序。但在大规模知识问答这块算是现在看到最好的方案。

隐式搜索:Unlimiformer

  • Unlimiformer: Long-Range Transformers with Unlimited Length Input

  • GitHub - abertsch72/unlimiformer: Public repo for the NeurIPS 2023 paper "Unlimiformer: Long-Range Transformers with Unlimited Length Input"

  • 适用于Encoder-Decoder模型,长文本摘要等场景

特意起了个隐式搜索的标题,是因为和上面的文本搜索实现有异曲同工之妙,本质的差异只是以上是离散文本块的搜索。而Unlimiformer是在解码阶段对超长输入,token粒度的输出层embedding进行检索,选择最相关的Top Token计算Attention。

img

首先对于超长输入,unlimiformr采用以上提到的重叠切分的方法,重叠率50%,这样可以更好保留上文和文本连贯性,例如第一段文本是1-500字,第二段重叠250字取250-750字。然后使用Encoder对每段文本进行独立编码,绕过Attention的平方复杂度问题。最后输出每段文本的Embedding,注意这里不是文本整体embedidng, 而是后半部分(250~500字)每个Token最上层的Embedding,并写入向量索引,这里用的是Faiss。

在解码层,每一步解码,query都会检索注意力最高的Top-k个输入Token,作为编码器部分的信息用于解码器的解码。这里简单回忆下Attention计算, Top-K个Token就是让以下注意力取值最高的key。

考虑Decoder的每一层(N层)中的每一个head(L个头)都需要和Encoder的输出层进行交互, 检索Top Key,如果存储每一层每个head的Key,需要构建O(L∗N∗seqlen)的向量存储。对此作者进行了优化,改变了以下QK的计算顺序,用每一层每个头Key的映射矩阵对Q进行映射,这样只需要存储一份seq_len的编码向量(hencoderℎ,在每一层检索时用映射后的Q进行检索既可,其实就是时间换空间

unlimiformer提供了代码实现,核心代码抽出来看下有两块

  1. 超长文本编码:对文本进行切块,分别编码,取后半部分

for context_start_ind, context_end_ind, update_start_ind, update_end_ind in window_indices:chunk = input_ids[:, context_start_ind:context_end_ind]chunk_attention_mask = attention_mask[:, context_start_ind:context_end_ind]hidden_states = self.model(chunk, attention_mask=chunk_attention_mask, labels=dummy_labels, return_dict=True)last_hidden = hidden_states.encoder_last_hidden_state # (batch, chunked_source_len, dim)to_add = last_hidden[:, update_start_ind:update_end_ind].detach()to_apply_mask = chunk_attention_mask[:, update_start_ind:update_end_ind]
  1. 向前计算检索Top-key用于Attention矩阵的计算

def attention_forward_hook(self, module, input, output):# output: (batch, time, 3 * heads * attention_dim)with torch.no_grad():query = self.process_query(output)[:,-1] # (batch * beam, head, dim)query = query[:, self.head_nums] # (batch * beam, head, dim)
​#这是前面提到的计算优化使用每层每个head的Key映射矩阵对Query进行映射用于搜索attention_layer_list = self.attention_layer_to_capture(self.layer_begin, self.layer_end)k_proj_layer = [layers[0] for layers in attention_layer_list][self.cur_decoder_layer_index]# modify query by k_projs k_proj = k_proj_layer.weightk_proj = k_proj.view(1, self.num_heads, query.shape[-1], k_proj.shape[0]) # (1, num_heads, attn_dim, embed_dim)datastore_query = query.unsqueeze(-2) # (batch * beam, num_heads, 1, attn_dim)datastore_query = torch.matmul(datastore_query, k_proj) # (batch * beam, num_heads, 1, embed_dim)datastore_query = datastore_query.squeeze(-2)  # (batch * beam, num_heads, embed_dim)datastore_query = datastore_query.view((self.datastore.batch_size, -1, datastore_query.shape[2])) # (batch, beam * num_heads, embed_dim)# 这里进行Top Key的检索:得到Key的索引,Embedding和得分top_search_key_scores, top_search_key_indices = self.datastore.search(datastore_query, k=self.actual_model_window_size)embeddings = torch.take_along_dim(input=self.embeddings.unsqueeze(1), indices=top_search_key_indices.unsqueeze(-1).to(self.embeddings.device), dim=-2)
​##后面就是常规的对Embedding进行Key和Value的映射然后做Attention了

和前面的文本检索对比,unlimiformer的存储成本会更高,因为要存储token粒度的Embedding信息,更适用于on-the-fly的长文本推理使用,例如针对单一文档的QA,只存储当前文档,而前面文本块检索方案更适合一些大规模知识,批量的文档的存储。

但其实unlimiformer直接对Token进行离散召回,这一点我让我有些困惑,这样单一token的检索召回,真的不会破坏上文连续性么?还是说Encoder编码方式已经保证了检索召回大概率会召回成段的Token,又或者说每个Token的Embedding内已经充分编码了连续上下文的信息,召回离散Token也不会出现割裂的语义信息?哈哈考虑unlimiformer只支持Encoder-Decoder的框架,和我们用的Decoder框架不适配,我决定不细纠结了!有在中文尝试过效果的童鞋可以分享下~

并行输入:PCW

  • Parallel Context Windows for Large Language Models

  • GitHub - AI21Labs/Parallel-Context-Windows

  • 适用于Decoder模型,以及小规模内容理解场景

同样是对超长文本进行切块,然后独立编码,PCW使用的是Decoder框架。和unlimiformer只使用Top-Key进行解码,PCW在解码过程中对全部输入上文进行Attention。对比Encoder-Decoder框架,因为输入和输出都在Decoder侧,PCW需要解决两个问题:位置编码和注意力矩阵如何调整, 下图基本概括了这两个细节

img

\1. 位置编码:输入文本截断后,每段文本的位置编码相同。考虑所最长的文本长度为C,则输入文本最大的位置编码id是PC,则解码器第一个字的位置编码id是PC+1+1,然后顺序向后编码。其实就是丢弃了上文多段文本之间的位置关系,解码时只知道上文多段文本都是在解码器之前,但无法区分文本之间的位置。不过因为上文每段文本复用了相同的位置编码,因此位置编码的长度大幅降低,也就降低了对位置编码外推性的需求。

position_ids = attention_mask.long().cumsum(-1) - 1
n_task_tokens = position_ids.shape[1] - sum_windows_size
# 保证解码器的位置编码比最长上文要长度+1
position_ids[0, -n_task_tokens:] = torch.arange(max_window_size, max_window_size + n_task_tokens, 1)
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:  # i.e., first token is already generatedposition_ids = position_ids[:, -1].unsqueeze(-1)
elif windows_key_values:  # i.e., we are in the first token generation #其实就是取-n_task_tokens:position_ids = position_ids[:, sum_windows_size:]
  1. 注意力矩阵

  • 输入文本进行截断后各自独立通过Decoder进行编码。因此每一段输入的文本的注意力矩阵是相互独立的。这块不需要修改注意力矩阵的实现,只需要文本chunk后分别过模型即可。得到每段文本的past-key-values直接进行拼接

def combine_past_key_values(past_lst: List[Tuple[Tuple[torch.Tensor]]],contains_bos_token: bool = True) -> Tuple[Tuple[torch.Tensor]]:# 这里past_lst是每段文本的past-key-value# GPT是n_layer * 2(key+value) * tensor(seq_len,batch,n_head,n_hidden)# 注意不同模型past-key-value的shape不同# Chatglm是n_layer * 2(key+value) * tensor(seq_len,batch, n_head, n_hidden)return tuple((torch.cat([c[i][0] for c in past_lst], dim=2), torch.cat([c[i][1] for c in past_lst], dim=2))for i in range(len(past_lst[0])))
  • 解码器对全部上文进行Attention计算:这里需要修改Attention把上文的全部Attention进行拼接,让解码器的每一步可以对全部上文计算Attention

res['past_attention_mask'] = torch.cat([window['attention_mask'] for window in windows], dim=1)
combined_attention_mask = torch.cat((cache['past_attention_mask'], encoded_task_text['attention_mask']), dim=1)

考虑ChatGLM本身是二维的Attention矩阵和位置编码,特殊的BOS和GMASK,我重写了PCW,但是在长文本QA问题上表现比较一般,表现在当上文多段文本无明显关系的时候例如多个完全无关的新闻,在进行问答的时候,正确答案中会混杂很多无关的文本变短,以及这个问题当上文片段变多,或者指令问题变多的时候会变得越来越严重,直到开始完全胡说八道。当然不排除我写bug了哈哈哈,但我自己是真的没查出来。

不过也有一种可能,是PCW是在输入层就开始对超长上文进行Attention,因为不同上文的位置编码相同,一定程度上会让解码注意力变得非常分散,导致注意力的熵值变高,解码的不确定性变大,更容易出现乱码。

并行解码:NBCE

  • 苏剑林. (May. 23, 2023). 《NBCE:使用朴素贝叶斯扩展LLM的Context处理长度 》[Blog post]. Retrieved from NBCE:使用朴素贝叶斯扩展LLM的Context处理长度 - 科学空间|Scientific Spaces

  • 苏剑林. (May. 31, 2023). 《关于NBCE方法的一些补充说明和分析 》[Blog post]. Retrieved from 关于NBCE方法的一些补充说明和分析 - 科学空间|Scientific Spaces

  • GitHub - bojone/NBCE: Naive Bayes-based Context Extension

  • 适用于Encoder-Decoder模型,长文本内容理解如摘要问答等场景

压轴的必须是苏神的NBCE!这里我把看完博客后的理解进行简单的总结,详细推理请看去苏神的科学空间!答应我一定要去看!每次看苏神推导,都会觉得数学之魂在燃烧!

NBCE的原理简单解释如下图,和PCW相同是对每段上文进行独立编码,但差异在于PCW是在输入层进行融合,而NBCE是在输出层对每一个Step输出的预测token的概率矩阵进行融合,更大程度上避免了注意力被分散,保证了解码的合理性。

img

这里我们简单说下如何在输出层进行融合,把找超长文本chunk成多段文本后(s1,s2,...sk1,2,...,基于朴素贝叶斯的简化假设, 基于多段文本进行并行解码的预测概率可以简化如下,也就是每段文本条件解码概率之和减去无条件解码概率

既然说了是简化假设,因此可以对上式进行一些调优,核心是让模型对上文的解码更加准确,降低无关上文带来的解码噪声,比较重要的优化包括

  1. 准确率优化解码

以上解码概率求和,其实是对k段文本生成的vocab∗K的概率矩阵,沿K做AvergePooling,得到最终vocab∗1∗1的解码概率。但考虑LM训练其实是拟合one-hot(出现概率最高的词),也就是除了概率最高的几个token之外其余token的预测概率都不靠谱。如果直接取平均的多路打分,很容易投出一个在各段文本上打分都不高不低的token,上文越多这个问题越明显。但其实在阅读理解例如抽取,QA问题的解码策略上我们要的是在某段文本上打分置信度最高的token,因为答案往往只来自一个上文片段。

因此苏神给出了两种准确率更高的解码方案,一个是MaxPooling+GreedySearch,其实就是对vocab∗k的概率矩阵取全局概率最高的token,另一个是最小熵+RandomSampling,也就是从多段上文中取1个预测置信度最高的上文进行解码。这里其实是和PCW最大的差异,也就是在解码层进行融合,并通过熵值较低的融合策略来保证解码的准确率。

以及后面苏神还通过Top-P来进一步过滤尾部的噪声,以及通过控制每一步解码的转移概率,来让解码器不会在不同上文片段之间反复切换,而是保证连续的解码片段大概率来自相同的上文片段。

  1. Context-aware解码

基于上文来进行解码的一个核心是为了降低模型回答胡说八道的概率。例如在金融场景我们直接问chatgpt基金赎回费用是多少 vs 我们基于某个基金的介绍问模型该基金的赎回费用是多少,后者得到的答案一定是更准确的。而其实以上二者的差异在于条件(上文)解码和无条件解码, 因此可以通过diff无条件编码的方式来提高解码对上文的依赖程度(reliablity)。如下图

img

因此苏神把把n变成超参Beta, 控制条件概率和无条件概率的占比,Beta越高解码和上文的关联度越高,QA等场景的解码准确率越高,生成自由度越低。

当前NBCE的局限性在于无法处理上文片段之间的位置关系,以及无法处理解码需要依赖多个上文片段的场景。后者感觉可以通过预测概率矩阵的相关性修改Pooling方式,而前者

基于苏神提供的代码,在chatglm上做了尝试,只需要简单调整下输入输出的部分就可以直接使用。我在论文,书籍,和新闻上进行摘要,实体抽取和QA问答后发现,INT8量化的模型效果似乎要略优于FP16, 显著优于INT4。INT8量化下,10K左右的输入,显存占用基本可以限制在单卡A100(40g),大家可以自行尝试下~

@torch.inference_mode()
def generate(max_tokens):device = torch.device('cuda')"""Naive Bayes-based Context Extension 演示代码"""inputs = tokenizer(batch, padding='longest', return_tensors='pt').to(device)input_ids = inputs.input_idsn = input_ids.shape[0]
​with torch.no_grad():for i in range(max_tokens):# 模型输出model_input = model.prepare_inputs_for_generation(input_ids)
​outputs = model(**model_input, return_dict=True,use_cache=True)"""中间代码不变"""
​# 把唯一的回答扩充到每一个batch进行下一轮的解码next_tokens = next_tokens.unsqueeze(-1).tile(n, 1)input_ids = torch.cat([input_ids, next_tokens], dim=-1)# 更新past-key-values, 更新attention_mask, 更新position_idsmodel_kwargs = model._update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=model.config.is_encoder_decoder)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/586323.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【neo4j】desktop下载

【neo4j】desktop下载 https://neo4j.com/download/ 点击download,填写表格 之后就可以正常使用了

智慧园区物联综合管理平台感知对象管理能力简述

物联感知对象管理, 不局限于物理传感设备, 还包括物联业务对象, 平台提供标准的设备建模能力以及标准的物联设备、 第三方物联系统SDK接入方案等; 实现对感知对象运行、 报警、 故障状态的反馈以及物联感知对象全生命周期信息管理。 基础定义配置 平台提供物联网目感知对…

Halcon纹理分析texture_laws/trans_from_rgb

Halcon纹理分析 文章目录 Halcon纹理分析1. 纹理滤波器2. 织物折痕检测 纹理是图像表面的一种灰度变化。有的纹理很规则,会以局部小区域为单元重复出现,而有的纹理则呈现出随机性。对于规则的纹理,可以很容易地从中分辨出重复的区域&#xff…

最新版 BaseRecyclerViewAdapterHelper4:4.1.2 最简单的QuickViewHolder用法,最简洁的代码,复制可用

为了照顾新手,尽量详细,高手勿喷!!! 怕麻烦的话可以直接下载源码:https://download.csdn.net/download/ERP_LXKUN_JAK/88678044?spm1001.2014.3001.5503 先看文件结构,是不是很简单 AndroidSt…

【Pytorch】学习记录分享10——PyTorchTextCNN用于文本分类处理

【Pytorch】学习记录分享10——PyTorchTextCNN用于文本分类处理 1. TextCNN用于文本分类2. 代码实现 1. TextCNN用于文本分类 具体流程: 2. 代码实现 # coding: UTF-8 import torch import torch.nn as nn import torch.nn.functional as F import numpy as np…

【C++】STL 容器 - set 集合容器 ⑦ ( 查找元素 - set#find 函数 | 获取元素个数 - set#count 函数 )

文章目录 一、查找元素 - set#find 函数1、函数原型 简介2、代码示例 - set#find 函数 二、获取元素个数 - set#count 函数1、函数原型 简介2、代码示例 - set#find 函数 一、查找元素 - set#find 函数 1、函数原型 简介 在 C 语言的 STL 标准模板库 , std::set 集合容器 是一个…

数据分析硬核工具Origin各版本安装指南

下载链接 https://pan.baidu.com/s/12mENFtRFdNaLzVKmE6w_Uw?pwd0531 1.鼠标右击【Origin 2022(64bit)】压缩包(win11及以上系统需先点击显示更多“选项”)选择【解压到 Origin 2022(64bit)】。 2.双击打开解压后的【Origin 2022(64bit)】文件夹。 3.…

Python学习 - 爬虫系统架构设计

主要业务流程 初始请求请求过滤器请求队列响应下载器数据解析器数据清洗器存储器 设计图 master slave:master控制队列,过滤,传递任务;slave负责执行 缺点:master和slave端交互数据频繁,slave的数据进出…

图文证明 牛顿-莱布尼茨公式

牛顿-莱布尼茨公式 牛顿-莱布尼茨公式是微积分中的基本定理之一,它描述了函数的导数和不定积分之间的关系。 该公式通常用来计算定积分。设函数f(x)在区间[a, b]上连续,且F(x)是f(x)在该区间上的一个原函数 即F’(x) f(x)。则牛顿-莱布尼茨公式表示为&…

【AIGC-图片生成视频系列-2】八仙过海,各显神通:AI生成视频相关汇总剖析

最近「图片生成视频系列」层出不穷,我拜读并结合实践(对,就是手撕代码,有开源就撕),并对以下几篇文章的相似点以及关键点稍微做个总结: 一. 生成视频中图像的一致性 在图像生成视频的这个过程…

提升CSC加分项|高职教师赴新西兰惠灵顿维多利亚大学访学交流

S老师科研背景条件一般,担心无法获得邀请函及通过CSC审批。我们建议:1.以加强国际合作和跨学科合作的方式,增强高职院校的影响力,为CSC评审提供加分项;2.同时申报4月份的国家公派和5月份的西部/地方合作项目&#xff0…

Java进阶(第八期): Java中递归的的使用和递归解决一些算法问题 Java中的异常机制、异常的处理逻辑 自定义异常

文章目录 一、递归1.1 递归的介绍1.2 递归的简单练习1.3 图解递归执行流程:1.4 使用递归完成悲波那契数列1.5 猴子吃桃子问题 二、异常三 、异常的处理逻辑3.1 try catch 捕获异常3.2 throws抛出异常 四、自定义异常 Java进阶(第八期) 一、递…

2、gdb常用功能2

1.4、线程 程序避免不了涉及到多线程.常用指令如下. 命令简写形式说明info thread显示当前进程内所有线程信息thread 切换到num线程thread find 寻找regexp在gdb中的idinfo address 结合上述图片理解,第一列的id是gdb内部为线程排序的一个id,第三列中…

行人重识别(ReID)基础知识入门

这里写目录标题 1、ReID技术概述1.1 基本原理1.2 实现流程1.3 重识别存在的技术挑战 2、训练数据格式介绍 1、ReID技术概述 1.1 基本原理 ReID,全称Re-identification,目的是利用各种智能算法在图像数据库中找到与要搜索的目标相似的对象。ReID是图像检…

Eureka服务注册与发现

1. Eureka简介 Eureka采用了CS的设计架构,Eureka Server 作为服务注册功能的服务器,它是服务注册中心。而系统中的其他微服务,使用 Eureka的客户端连接到 Eureka Server并维持心跳连接。这样系统的维护人员就可以通过 Eureka Server 来监控系…

三个故事,谈谈小米汽车技术发布会

都说新年新气象,随着年末消费旺季到来,汽车市场越来越热闹了。 继蔚来12月23日公布旗舰车型ET9,华为26日发布问界M9,小米汽车首款量产车型SU7终于正式亮相。 12月28日,在小米汽车技术发布会上,小米创办人…

CCNP课程实验-Route_Path_Control_CFG

目录 实验条件网络拓朴需求 配置实现基础配置需求实现1.A---F所有区用Loopback模拟,地址格式为:XX.XX.XX.XX/32,其中X为路由器编号。根据拓扑宣告进对应协议。A1和A2区为特例,A1:55.55.55.0/24,A2&#xff…

【数学建模美赛M奖速成系列】Matplotlib绘图技巧(二)

Matplotlib绘图技巧(二) 写在前面2. 函数间区域填充函数fill_between()和fill()参数: 3. 散点图 scatter4. 直方图 hist5. 条形图 bar5.1 一个数据样本的条形图参数: 5.2 多个数据样本进行对比的直方图5.3 水平条形图参数 5.4 绘制…

Vue3-31-路由-RouterView的name属性的作用

作用描述 <router-view> 标签是用来渲染路由对应的组件的位置&#xff1b; 默认情况下&#xff0c;一个路由是只对应一个组件的。 但是&#xff0c;可以通过给 <router-view> 指定 name 属性的方式&#xff0c;实现同时渲染多个组件的效果。 这也叫做 命名视图。 注…

《企业数据资源相关会计处理暂行规定》学习笔记

附&#xff1a;2023年数据资源入表白皮书下载&#xff1a; 关注WX公众号&#xff1a; commindtech77&#xff0c; 获得数据资产相关白皮书下载地址 1. 回复关键字&#xff1a;数据资源入表白皮书 下载 《2023数据资源入表白皮书》 2. 回复关键字&#xff1a;光大银行 下载 光…