GLM3源码学习

原文链接:chatglm源码学习

GLM3源码:https://github.com/THUDM/ChatGLM3

我们直接从openai_api_demo入手,因为api_demo一般是nlp模型后端核心功能实现的部分

openai_api_demo源码

api_server.py

api_server.py是提供web api接口的入口文件,是使用flask框架提供的一个异步接口支持

app = FastAPI(lifespan=lifespan)
class ModelCard(BaseModel):
...
class ChatCompletionResponse(BaseModel):

上面这一堆class是实现chat这个api功能的主要对象,如模型卡、请求体和响应体

@app.get("/health")
async def health() -> Response:"""Health check."""return Response(status_code=200)

这个是测试api状态函数,可以看到这个测试功能还是很直接的,没有考虑部署应用下的问题,如负载情况和安全状况,这个demo也就是一个学习的小demo项目。

@app.post("/v1/embeddings", response_model=EmbeddingResponse)
async def get_embeddings(request: EmbeddingRequest):if isinstance(request.input, str):# 判断输入是否是字符串,字符串直接编码,否则对字符串列表编码embeddings = [embedding_model.encode(request.input)]else:embeddings = [embedding_model.encode(text) for text in request.input]embeddings = [embedding.tolist() for embedding in embeddings]def num_tokens_from_string(string: str) -> int:"""Returns the number of tokens in a text string.use cl100k_base tokenizer"""encoding = tiktoken.get_encoding('cl100k_base')num_tokens = len(encoding.encode(string))return num_tokensresponse = {"data": [{"object": "embedding","embedding": embedding,"index": index}for index, embedding in enumerate(embeddings)],"model": request.model,"object": "list","usage": CompletionUsage(prompt_tokens=sum(len(text.split()) for text in request.input), completion_tokens=0,total_tokens=sum(num_tokens_from_string(text) for text in request.input),)}return response

这个函数是获取文本向量编码的,sentences_to_embeddings功能。
这里面有个函数num_tokens_from_string是统计文本的tokens数量,使用的tiktoken模块是openai开源的一个快速分词统计库,cl100k_base是和gpt4同款编码器,也就是说glm3的tokenizer实际上是使用的gpt4的tokenizer,在论文里面glm的baseline是最开始的gpt-1模型,那从理论上,glm3的性能提升肯定会受到分词的影响的(清华博士教大家的水论文小技巧hhh)。

class ModelCard(BaseModel):id: strobject: str = "model"created: int = Field(default_factory=lambda: int(time.time()))owned_by: str = "owner"root: Optional[str] = Noneparent: Optional[str] = Nonepermission: Optional[list] = None@app.get("/v1/models", response_model=ModelList)
async def list_models():model_card = ModelCard(id="chatglm3-6b")return ModelList(data=[model_card])

这个list_models直接限定了就是chatglm3-6b模型了,里面没有包括实际的模型

@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest):global model, tokenizerif len(request.messages) < 1 or request.messages[-1].role == "assistant":raise HTTPException(status_code=400, detail="Invalid request")#截这个的原因是gpt模型是允许任意角色的消息序列的包括assistant多次生成的功能。glm3则不允许if request.stream:#SSE流式响应response = generate_chatglm3(model, tokenizer, gen_params) #直接响应message = ChatMessage(role="assistant",content=response["text"],...)#创建消息体#计算使用量然后返回响应体,choice_data里面只放了一个数据return ChatCompletionResponse(model=request.model,id="",  # for open_source model, id is emptychoices=[choice_data],object="chat.completion",usage=usage)

chat最核心的响应函数了,由于函数较长就不全截了。

首先我们看到的是一个消息验证不允许assistant多次生成,原因主要是这个功能本身对助手是没有什么意义的,而且多次生成的训练效果比较差,之前我测试过gpt api的多次生成。因为他们用的训练数据基本上都是一个消息内全部回复了,上下文数据本身不存在多次生成的场景,因此这些模型多次生成并不是把问题分多次回复(和人类不同,一句话可以多方面讲,分段讲),只是把答案回答多次。

如果要实现更真实的问答AI,拥有更真实的对话体验,那对数据的要求是很高的,最好的数据集应该是QQ微信这种聊天软件的数据,但是企业是不可能拿这些隐私数据训练的。不过也有平替,如贴吧微博这些开放平台的数据也是很好的,但是这些数据看过后,上下文的逻辑性还是有问题的,并且多轮对话的人物被屏蔽了,也就是说明明是多个人的对话被训练成了二人的对话,这些模型后面肯定被高质量多轮对话微调过,不然单纯这些语料不会达到gpt的这种效果。

响应类型分为直接响应和SSE响应,其中直接响应简单,就是拿model直接推理得到message。
这里有个问题是这个chat函数是asyn异步的,但是model资源是global的单个模型,如果同时多个请求可能会报错。可以对模型封装个请求拥塞队列,比如大于3个请求就返回繁忙。

SSE响应部分和直接响应不同,SSE没有提供使用量这些信息,仅返回了响应文本,SSE还对前端的响应方法有要求,因此如果是仅学习开发和小规模应用没有必要追求SSE

    predict_stream_generator = predict_stream(request.model, gen_params)output = next(predict_stream_generator)if not contains_custom_function(output):return EventSourceResponse(predict_stream_generator, media_type="text/event-stream")

通过predict_stream创建一个生成器,next生成下一个字符然后返回。

utils.py

utils.py提供了响应的实现函数generate_stream_chatglm3和generate_chatglm3

def generate_chatglm3(model: PreTrainedModel, tokenizer: PreTrainedTokenizer, params: dict):for response in generate_stream_chatglm3(model, tokenizer, params):passreturn response

循环调用generate_stream_chatglm3后返回响应

generate_stream_chatglm3函数

def generate_stream_chatglm3(model: PreTrainedModel, tokenizer: PreTrainedTokenizer, params: dict):messages = params["messages"] #消息tools = params["tools"] #工具temperature = float(params.get("temperature", 1.0)) #温度参数repetition_penalty = float(params.get("repetition_penalty", 1.0))#惩罚参数,transformer有个问题就是高概率文本会重复生成,在有的论文中提出了惩罚参数,即对已经生成的token的概率乘上惩罚参数让这个token的概率变小,减小重复概率。top_p = float(params.get("top_p", 1.0)) #top_p top_k是采样的一个过滤方法,p是按概率阈值过滤,k是按排序过滤max_new_tokens = int(params.get("max_tokens", 256)) #最大允许新生成的tokensecho = params.get("echo", True)messages = process_chatglm_messages(messages, tools=tools)#消息处理query, role = messages[-1]["content"], messages[-1]["role"]#最后一个消息内容inputs = tokenizer.build_chat_input(query, history=messages[:-1], role=role) #把历史和问题构建输入inputs = inputs.to(model.device)input_echo_len = len(inputs["input_ids"][0])#输入编码序列长度if input_echo_len >= model.config.seq_length: #输入序列长度限制print(f"Input length larger than {model.config.seq_length}")eos_token_id = [ #结束tokentokenizer.eos_token_id,tokenizer.get_command("<|user|>"),tokenizer.get_command("<|observation|>")]gen_kwargs = { #控制参数"max_new_tokens": max_new_tokens,"do_sample": True if temperature > 1e-5 else False,"top_p": top_p,"repetition_penalty": repetition_penalty,"logits_processor": [InvalidScoreLogitsProcessor()],}if temperature > 1e-5:gen_kwargs["temperature"] = temperaturetotal_len = 0for total_ids in model.stream_generate(**inputs, eos_token_id=eos_token_id, **gen_kwargs):total_ids = total_ids.tolist()[0]total_len = len(total_ids)if echo: #没看懂echo什么意思 input_echo_len应该是生成的total_ids中echo控制是否对问题重复一遍,重复了就减掉output_ids = total_ids[:-1]  else:output_ids = total_ids[input_echo_len:-1]#反正output_ids是stream_generate的idsresponse = tokenizer.decode(output_ids)if response and response[-1] != "�": #乱码了就跳出response, stop_found = apply_stopping_strings(response, ["<|observation|>"]) #判断是否结束yield { #yield作为一个生成器每次调用生成output_ids然后返回"text": response,"usage": {"prompt_tokens": input_echo_len,#输入tokens"completion_tokens": total_len - input_echo_len,#总tokens-重复的输入tokens"total_tokens": total_len,#总tokens},"finish_reason": "function_call" if stop_found else None,}if stop_found:break#最后一个字符跳出返回结束# Only last stream result contains finish_reason, we set finish_reason as stopret = {"text": response,"usage": {"prompt_tokens": input_echo_len,"completion_tokens": total_len - input_echo_len,"total_tokens": total_len,},"finish_reason": "stop",}yield ret#内存显存收下垃圾gc.collect()torch.cuda.empty_cache() 

其中里面有个函数很关键process_chatglm_messages:消息处理函数

def process_chatglm_messages(messages, tools=None):_messages = messagesmessages = []msg_has_sys = Falseif tools:messages.append({"role": "system","content": "Answer the following questions as best as you can. You have access to the following tools:","tools": tools})msg_has_sys = Truefor m in _messages:role, content, func_call = m.role, m.content, m.function_callif role == "function":messages.append({"role": "observation","content": content})elif role == "assistant" and func_call is not None:for response in content.split("<|assistant|>"):metadata, sub_content = response.split("\n", maxsplit=1)messages.append({"role": role,"metadata": metadata,"content": sub_content.strip()})else:if role == "system" and msg_has_sys:msg_has_sys = Falsecontinuemessages.append({"role": role, "content": content})return messages

这个函数就是把message对象转化为dict对象,我们可以看到这里面有system、observation、assistant、user。

在generate_stream_chatglm3:inputs = tokenizer.build_chat_input(query, history=messages[:-1], role=role)
这个函数中把dict对象对应的history转换成了文本格式,例如:

<|system|>
You are ChatGLM3, a large language model trained by Zhipu.AI. Follow the user's instructions carefully. Respond using markdown.
<|user|>
Hello
<|assistant|>
Hello, I'm ChatGLM3. What can I assist you today?

也就是说我们以为的多轮对话实际上就是把历史记录拼起来的。

这部分想到了个idea,这种拼起来的实际上有历史限制,如果让模型生成每个对话的重要性,然后按照重要性+时间权重排序选择性记忆能不能增强长期记忆能力?感觉这部分应该有人在做或者做出来了。

main

最后回来看下api_server.py的main函数

if __name__ == "__main__":# Load LLMtokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH, trust_remote_code=True)model = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True, device_map="auto").eval()# load Embeddingembedding_model = SentenceTransformer(EMBEDDING_PATH, device="cuda")uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)

main函数从transformers加载模型然后作为global对象推理。
transformers模型就和传统的bert、t5类似了

chatglm的改进主要包括:

  • 二维位置编码+GELU+残差、层归一化重排序
  • 文档级+句子 NLG预训练
  • NLG+NLU两种任务都进行训练,同时微调的时候还使用了slot填空的NLU方法

总结

之前没怎么看过这种有上下文模型响应的完整流程,这趟下来解决了我之前好几个疑惑:

  1. transformer的重复问题我遇到了好几次,可以通过惩罚参数控制
  2. 上下文实现方法-实际上还是把历史对话融在一起
  3. 模型推理资源占用问题,请求队列感觉是一定要有的,web框架本身是异步请求响应的,不对临界资源管理感觉没啥可靠性

加上这个,目前已经把带上下文的文本生成+知识库扩展永久记忆解决了,后面再对模型结构魔改下,然后集成一些动作指令,就可以实现本地部署家用AI了hhh。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/46022.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【面试题】Golang 之Channel底层原理 (第三篇)

目录 1.常见channel三大坑&#xff1a;死锁、内存泄漏、panic 1.死锁 1.只有生产者&#xff0c;没有消费者&#xff0c;或者反过来 2 生产者和消费者出现在同一个 goroutine 中 3 buffered channel 已满&#xff0c;且在同一个goroutine中 2.内存泄露 1 如何实现 gorout…

CSS学习碎碎念之卡片展示

效果展示&#xff1a; 代码展示 <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>图片展示</title…

Android C++系列:Linux网络(三)协议格式

1. 数据包封装 传输层及其以下的机制由内核提供,应用层由用户进程提供(后面将介绍如何使用 socket API编写应用程序),应用程序对通讯数据的含义进行解释,而传输层及其以下 处理通讯的细节,将数据从一台计算机通过一定的路径发送到另一台计算机。应用层 数据通过协议栈发到…

《Linux系统编程篇》vim的使用 ——基础篇

引言 上节课我们讲了&#xff0c;如何将虚拟机的用户目录映射到自己windows的z盘&#xff0c;虽然这样之后我们可以用自己的编译器比如说Visual Studio Code&#xff0c;或者其他方式去操作里面的文件&#xff0c;但是这是可搭建的情况下&#xff0c;在一些特殊情况下&#xf…

C# 使用 NPOI 处理Excel,导入单元格内容是公式的处理

在C#中使用NPOI库处理Excel文件时&#xff0c;如果单元格内容包含公式&#xff0c;NPOI能够读取这些公式以及它们计算后的值。NPOI是一个开源的.NET库&#xff0c;用于处理Microsoft Office文档&#xff0c;特别是Excel文件&#xff08;.xls和.xlsx&#xff09;。 要处理包含公…

【小超嵌入式】C++猜数字游戏详细分析

一、程序源码 #include <iostream> #include <cstdlib> #include <ctime>using namespace std;int main() {srand(static_cast<unsigned int>(time(0))); // 随机数种子int targetNumber rand() % 100 1; // 生成 1 到 100 之间的随机数int guess…

helm系列之-构建自己的Helm Chart

构建自己的Helm Chart 一般常见的应用&#xff08;nginx、wordpress等&#xff09;公有的helm仓库都提供了chart&#xff0c;可以直接安装或者自定义安装。下面实践从零构建自己的helm chart应用。 准备工作 准备一个用于部署测试的应用镜像并推送到镜像仓库。 应用代码 这…

Linux 命令个人学习笔记

1. 操作目录的命令 (1) ls : 查看指定目录中, 都有哪些内容 直接输入 ls 是查看当前目录中的内容. 还可以给 ls 后面加上一个路径(绝对/相对), 就可以查看指定目录中的内容 比如看根目录(刚安装Centos下) ls / 根目录的地位类似于Java中的Object ls -l 详细查看当前文件的内容…

(十一) Docker compose 部署 Mysql 和 其它容器

文章目录 1、前言1.1、部署 MySQL 容器的 3 种类型1.2、M2芯片类型问题 2、具体实现2.1、单独部署 mysql 供宿主机访问2.1.1、文件夹结构2.1.2、docker-compose.yml 内容2.1.3、运行 2.2、单独部署 mysql 容器供其它容器访问&#xff08;以 apollo 为例&#xff09;2.2.1、文件…

pyinstaller教程(二)-快速使用(打包python程序为exe)

1.介绍 PyInstaller 是一个强大的 Python 打包工具&#xff0c;可以将 Python 程序打包成独立的可执行文件。以下会基于如何在win系统上将python程序打包为exe可执行程序为例&#xff0c;介绍安装方式、快速使用、注意事项以及特别用法。 2.安装方式 通过 pip 安装 PyInstal…

万界星空科技MES系统:食品加工安全的实时监控与智能管理

万界星空科技MES系统通过集成多种技术和功能&#xff0c;能够实时监控食品加工过程中各环节的安全风险。以下是对该系统如何实现实时监控的详细分析&#xff1a; 一、集成传感器和数据分析技术 万界星空科技MES系统利用集成的传感器和数据分析技术&#xff0c;实时监控生产过程…

基于SSM的校园一卡通管理系统的设计与实现

摘 要 本报告全方位、深层次地阐述了校园一卡通管理系统从构思到落地的整个设计与实现历程。此系统凭借前沿的 SSM&#xff08;Spring、Spring MVC、MyBatis&#xff09;框架精心打造而成&#xff0c;旨在为学校构建一个兼具高效性、便利性与智能化的一卡通管理服务平台。 该系…

数学建模入门

目录 文章目录 前言 一、数学建模是什么&#xff1f; 1、官方概念&#xff1a; 2、具体过程 3、适合哪一类人参加&#xff1f; 4、需要有哪些学科基础呢&#xff1f; 二、怎样准备数学建模&#xff08;必备‘硬件’&#xff09; 1.组队 2.资料搜索 3.常用算法总结 4.论文撰写的…

微前端解决方案

在实施微前端架构时&#xff0c;前端框架和技术的选型是非常重要的。不同的框架和技术有着不同的优缺点&#xff0c;需要结合具体的应用场景进行选择。一、常见的微前端解决方案 Web Components Web Components&#xff08;包括Custom Elements、Shadow DOM和HTML Imports&…

数据建设实践之大数据平台(一)准备环境

大数据组件版本信息 zookeeper-3.5.7hadoop-3.3.5mysql-5.7.28apache-hive-3.1.3spark-3.3.1dataxapache-dolphinscheduler-3.1.9大数据技术架构 大数据组件部署规划 node101node102node103node104node105datax datax datax ZK ZK ZK RM RM NM

HTML网页大设计-家乡普宁德安里

代码地址: https://pan.quark.cn/s/57e48c3b3292

Layer2是什么?为什么需要Layer2?

目录 什么是Layer1需要Layer2的原因概念结构图Layer2有哪些风险 什么是Layer1 要了解Layer2前&#xff0c;需要先了解下Layer1。 一层网络&#xff08;Layer 1 Network&#xff09;通常指的是区块链技术中的主链或基础层&#xff0c;它提供了区块链的核心功能和特性。以下是一…

二分图——AcWing 257. 关押罪犯

目录 二分图 定义 运用情况 注意事项 解题思路 AcWing 257. 关押罪犯 题目描述 运行代码 代码思路 改进思路 二分图 定义 二分图&#xff08;Bipartite Graph&#xff09;是一种特殊的图&#xff0c;在这种图中&#xff0c;顶点可以被分成两个互不相交的集合&…

C语言 | Leetcode C语言题解之第233题数字1的个数

题目&#xff1a; 题解&#xff1a; int countDigitOne(int n) {// mulk 表示 10^k// 在下面的代码中&#xff0c;可以发现 k 并没有被直接使用到&#xff08;都是使用 10^k&#xff09;// 但为了让代码看起来更加直观&#xff0c;这里保留了 klong long mulk 1;int ans 0;f…

硬盘HDD:AI时代的战略金矿?

在这个AI如火如荼的时代&#xff0c;你可能以为硬盘HDD已经像那些过时的诺基亚手机一样&#xff0c;被闪存和云存储淘汰到历史的尘埃里。但&#xff0c;别急着给HDD们举行退休派对&#xff0c;因为根据Finis Conner这位硬盘界的传奇人物的说法&#xff0c;它们非但没退场&#…