ChatGLM3 源码解析(五)

PrefixEncoder

# 根据前缀 ID 获取前缀嵌入
# 前缀嵌入将连接到分头之后的 K 和 V 上
class PrefixEncoder(torch.nn.Module):"""The torch.nn model to encode the prefixInput shape: (batch-size, prefix-length)Output shape: (batch-size, prefix-length, 2*layers*hidden)"""def __init__(self, config: ChatGLMConfig):super().__init__()# 控制是否开启前缀投影,即用两层 MLP 处理前缀嵌入self.prefix_projection = config.prefix_projectionif self.prefix_projection:# KVSize = NLayer * 2 * NGroup * HeadSizekv_size = config.num_layers * config.kv_channels * config.multi_query_group_num * 2# 将 ID 变为嵌入的嵌入层,[PreSeqLen, KVSize]self.embedding = torch.nn.Embedding(config.pre_seq_len, kv_size)# 处理嵌入的 MLP# 映射到 HidSize, 计算 tanh,在映射到 KVSizeself.trans = torch.nn.Sequential(torch.nn.Linear(kv_size, config.hidden_size),torch.nn.Tanh(),torch.nn.Linear(config.hidden_size, kv_size))else:# 将 ID 变为嵌入的嵌入层self.embedding = torch.nn.Embedding(config.pre_seq_len,config.num_layers * config.kv_channels * config.multi_query_group_num * 2)def forward(self, prefix: torch.Tensor):# 前缀 ID 尺寸为 [BatchSize, PreSeqLen]# 根据前缀 ID 获取嵌入,尺寸为 [BatchSize, PreSeqLen, KVSize]# 如果设定了需要投影,就用两层 MLP 处理嵌入if self.prefix_projection:prefix_tokens = self.embedding(prefix)past_key_values = self.trans(prefix_tokens)else:past_key_values = self.embedding(prefix)return past_key_values

ChatGLMPreTrainedModel

class ChatGLMPreTrainedModel(PreTrainedModel):"""An abstract class to handle weights initialization anda simple interface for downloading and loading pretrained models."""is_parallelizable = Falsesupports_gradient_checkpointing = Trueconfig_class = ChatGLMConfigbase_model_prefix = "transformer"_no_split_modules = ["GLMBlock"]def _init_weights(self, module: nn.Module):"""Initialize the weights."""return# 从输入单词 ID,KVCache生成默认的(上三角)掩码矩阵def get_masks(self, input_ids, past_key_values, padding_mask=None):# 单词 ID 尺寸为 [BatchSize, SeqLen]batch_size, seq_length = input_ids.shape# 掩码矩阵初始化为全 1,形状为 [BatchSize, SeqLen, SeqLen],每个输入序列一个full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_ids.device)# 保留其下三角元素,其余设为 9full_attention_mask.tril_()# CacheLen:KVCache 中序列长度# 如果没有提供则设为 0,如果提供了,从中获取长度past_length = 0if past_key_values:past_length = past_key_values[0][0].shape[0]# 如果提供了 KVCache,在每个掩码矩阵的上方填充 1,形状为 [BatchSize, SeqLen, CacheSeqLen]if past_length:full_attention_mask = torch.cat((torch.ones(batch_size, seq_length, past_length,device=input_ids.device), full_attention_mask), dim=-1)# 如果提供了掩码数组([BatchSize, (Cache)SeqLen])# 将其变形为 [BatchSize, 1, (Cache)SeqLen]# 然后与掩码矩阵相乘# 将掩码数组为0的列设为0if padding_mask is not None:full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1)# 如果提供了掩码数组,并且没有提供 KVCache# 将其变形为 [BatchSize, SeqLen, 1]# 然后将掩码数组为 0 的行设为 1if not past_length and padding_mask is not None:full_attention_mask -= padding_mask.unsqueeze(-1) - 1# 小于 0.5 变成 true,大于 0.5 变成 false,相当于将其翻转,上三角不为 0full_attention_mask = (full_attention_mask < 0.5).bool()# 分头,变形为 [BatchSize, 1, SeqLen, SeqLen]full_attention_mask.unsqueeze_(1)return full_attention_mask# 从输入单词 ID 生成默认的(从零开始的)序列 IDdef get_position_ids(self, input_ids, device):# 单词 ID 尺寸为 [BatchSize, SeqLen]batch_size, seq_length = input_ids.shape# 序列 ID 创建为 0~(SeqLen-1)的一维数组# 变形为 [1, SeqLen],之后重复第一维 BatchSize 次,得到 [BatchSize, SeqLen]position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)return position_idsdef _set_gradient_checkpointing(self, module, value=False):if isinstance(module, GLMTransformer):module.gradient_checkpointing = value

ChatGLMForConditionalGeneration.stream_generate()

    @torch.inference_mode()def stream_generate(self,input_ids,generation_config: Optional[GenerationConfig] = None,logits_processor: Optional[LogitsProcessorList] = None,stopping_criteria: Optional[StoppingCriteriaList] = None,prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,return_past_key_values=False,**kwargs,):batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]if generation_config is None:generation_config = self.generation_configgeneration_config = copy.deepcopy(generation_config)model_kwargs = generation_config.update(**kwargs)model_kwargs["use_cache"] = generation_config.use_cachebos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_idif isinstance(eos_token_id, int):eos_token_id = [eos_token_id]eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else Nonehas_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not Noneif has_default_max_length and generation_config.max_new_tokens is None:warnings.warn(f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. ""This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"" recommend using `max_new_tokens` to control the maximum length of the generation.",UserWarning,)elif generation_config.max_new_tokens is not None:generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_lengthif not has_default_max_length:logger.warn(f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. ""Please refer to the documentation for more information. ""(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)",UserWarning,)# 如果 SeqLen 大于等于配置里设定的 MaxSeqLen,发出警告if input_ids_seq_length >= generation_config.max_length:input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"logger.warning(f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"" increasing `max_new_tokens`.")# 如果没有提供 logits 处理器,初始化为空列表logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()# 没有提供停止标准,初始化为空列表stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()# 根据生成配置等对象获取 logits 处理器logits_processor = self._get_logits_processor(generation_config=generation_config,input_ids_seq_length=input_ids_seq_length,encoder_input_ids=input_ids,prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,logits_processor=logits_processor,)# 根据生成配置等对象获取停止标准stopping_criteria = self._get_stopping_criteria(generation_config=generation_config, stopping_criteria=stopping_criteria)# 根据生成配置获取 logits 包装器logits_warper = self._get_logits_warper(generation_config)# 未完成标志,表示每个序列是否生成完毕的数组# 初始化为 [BatchSize] 尺寸的全 1 数组unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)scores = Nonewhile True:# 根据传入参数组装成字典,请见该方法定义model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)# 将单词 ID 传入模型,得到(所有前缀)下一个单词的 logits# [BatchSize, SeqLen, VocabSize]outputs = self(**model_inputs,return_dict=True,output_attentions=False,output_hidden_states=False,)# 截取 SeqLen 维度的最后一维,得到整句话下一个单词的 logits# [BatchSize, VocabSize]next_token_logits = outputs.logits[:, -1, :]# 传入 logits 处理器和包装器,修正 logitsnext_token_scores = logits_processor(input_ids, next_token_logits)next_token_scores = logits_warper(input_ids, next_token_scores)# 计算 softmax 得到概率值probs = nn.functional.softmax(next_token_scores, dim=-1)# 如果设定了需要采样,对其进行多项式采样,样本容量为 1# 否则直接取最大的# 得到下个单词 ID,尺寸为 [BatchSize]if generation_config.do_sample:next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)else:next_tokens = torch.argmax(probs, dim=-1)# 下个单词 ID 变形为 [BatchSize, 1],然后和输入单词 ID 拼接input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)# 根据当前输出更新KVCache、注意力掩码和位置IDmodel_kwargs = self._update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder)# `next_tokens` 变形为 [1, BatchSize],再将第一维重复 NEOS 次,[NEOS, BatchSize]# `eos_token_id_tensor` 变形为 [NEOS, 1],将广播第二维变成 [NEOS, BatchSize]# 之后二者逐元素比较是否不相等,形成一个比较结果,尺寸为 [NEOS, BatchSize]# 之后按照 BatchSize 维度计算乘积,得到未完成标志,[BatchSize]# 如果某个序列等于终止符集合里面的任意一个,那么比较结果就会出现一个 0,未完成标志将会是 0。unfinished_sequences = unfinished_sequences.mul(next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0))# 如果指定了返回 KVCache# 产生输入ID和已生成的输出ID# 和 KVCache# 否则只产生第一个if return_past_key_values:yield input_ids, outputs.past_key_valueselse:yield input_ids# 如果未完成标志全为零(表示序列都已生成完毕),或者达到了停止标准,就停止生成if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):break

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/742840.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

2024 前端javaScript+ES6

JavaScript 基础 1、基本数据类型&#xff1a; 1.1 基本数据类型&#xff1a; Number&#xff08;数值&#xff09;&#xff1a;表示数字&#xff0c;包括整数和浮点数。例如&#xff1a;5、3.14。 String&#xff08;字符串&#xff09;&#xff1a;表示文本数据&#xff…

视觉图像处理和FPGA实现第三次作业--实现一个加法器模块

一、adder模块 module adder(ina, inb, outa); input [5:0] ina ; input [5:0] inb ; output [6:0] outa ;assign outa ina inb; endmodule二、add模块 module add(a,b,c,d,e); input [5:0] a ; input [5:0] b ; input [5:…

阿里云国际修改域名绑定的DDoS高防服务器

本文九河云介绍当您的业务需要绑定多个DDoS高防实例&#xff0c;或者已有的DDoS高防实例已过期需要更换时&#xff0c;如何修改域名接入的配置&#xff0c;才能在业务不中断的前提下平滑迁移。 需绑定多个DDoS高防实例的场景 当网站业务有如下业务需求时&#xff0c;可以为域…

PTA L2-020 功夫传人

一门武功能否传承久远并被发扬光大&#xff0c;是要看缘分的。一般来说&#xff0c;师傅传授给徒弟的武功总要打个折扣&#xff0c;于是越往后传&#xff0c;弟子们的功夫就越弱…… 直到某一支的某一代突然出现一个天分特别高的弟子&#xff08;或者是吃到了灵丹、挖到了特别的…

netty服务器监听和接收数据

1.pom依赖 <dependency><groupId>io.netty</groupId><artifactId>netty-all</artifactId><!-- 根据需要选择版本 --><version>4.1.86.Final</version> </dependency>2.配置属性 application.properties #启动端口 ser…

从零开始,一步步构建服务网格istio

一、环境情况 环境&#xff1a;Ubuntu20.04 机器数量&#xff1a;单机1台 IP&#xff1a;10.9.2.83 二、准备知识 为什么使用 Istio&#xff1f; Istio提供了一种更高级别的服务网格解决方案&#xff0c;它可以简化和加强 Kubernetes 集群中的服务间通信、流量管理、安全…

Git操作指南:子模块、用户名修改和Subtree

引言 在软件开发中&#xff0c;版本控制是一个至关重要的环节。Git 作为目前最流行的版本控制工具之一&#xff0c;提供了丰富的功能和灵活的操作方式。本文将介绍一些常用的 Git 操作&#xff0c;包括管理子模块、修改用户名、使用 Git Subtree 合并项目以及其他一些常见操作…

基于R语言APSIM模型应用

随着数字农业和智慧农业的发展&#xff0c;基于过程的农业生产系统模型在模拟作物对气候变化的响应与适应、农田管理优化、作物品种和株型筛选、农田固碳和温室气体排放等领域扮演着越来越重要的作用。APSIM (Agricultural Production Systems sIMulator)模型是世界知名的作物生…

鸿蒙开发之MPChart图表开发

一、简介 随着移动应用的不断发展,数据可视化成为提高用户体验和数据交流的重要手段之一,因此需要经常使用图表,如折线图、柱形图等。OpenHarmony提供了一个强大而灵活的图表库是实现这一目标的关键。 在 ohpm 中心仓(https://ohpm.openharmony.cn/)中,汇聚了众多开发者…

ubuntu如何添加快捷方式到收藏夹、桌面

一、背景 有时候单独下载的软件包需要在特定路径里启动&#xff0c;这样使用起来非常不方便。因此需要在桌面和收藏夹里创建启动快捷方式。 二、具体步骤 这里以下载的zotero软件&#xff08;一款用于文献管理的软件&#xff09;为例。官网地址: Zotero | Your personal res…

python控制语句-2.1

目录 while循环 while循环练习-1 while 循环 - break 语法 while 循环 - continue 语法 while 循环 - else 语法 while循环练习-2 while循环 while循环练习-1 求1到n的交错和输入正整数 n&#xff0c;求 1 到 n 的交错和&#xff1a;即 -12-34-56-7...((-1)^n)*nn eval(…

shell脚本中数组元素赋值

在Shell&#xff08;特别是Bash&#xff09;脚本中定义和赋值数组有几种不同的方法。基本的数组赋值语句如下&#xff1a; # 无索引数组的赋值 array_name(element1 element2 element3)其中 element1 element2 element3 是数组 array_name 的元素。 如果你想要更新现有数组的…

【gpt实践】同时让chatgpt和claude开发俄罗斯方块

最近chatgpt和claude都在使用&#xff0c;其实大部分日常使用场景表现都没有相差太多&#xff0c;想搞一个有趣的小实验&#xff0c;如果同时让chatgpt和claude开发俄罗斯方块谁会表现的更好呢&#xff0c;说干就干&#xff01; prompt 我选择了用英文描述&#xff0c;毕竟英…

Unity中计算两个三维坐标点的各种方法

1、 根据勾股定理计算两点的距离 /// <summary>/// 根据勾股定理计算两点的距离/// </summary>/// <param name"point1"></param>/// <param name"point2"></param>/// <returns></returns>private float…

《如何使用C语言去下三子棋?》

目录 一、环境配置 二、功能模块 1.打印菜单 2.初始化并打印棋盘 3、行棋 3.1玩家行棋 3.2电脑行棋 4、判断是否和棋 5.判赢 三、代码实现 1、test.c文件 2、game.c文件 3、game.h文件 一、环境配置 本游戏用到三个文件&#xff0c;分别是两个源文件test.c game.c 和…

JWT令牌校验是什么东西?举个例子

JWT&#xff08;JSON Web Token&#xff09;令牌校验是验证JWT令牌的有效性和真实性的过程。JWT是一种用于在网络应用间安全传递信息的开放标准&#xff08;RFC 7519&#xff09;&#xff0c;它由三部分组成&#xff1a;头部&#xff08;header&#xff09;、载荷&#xff08;p…

zabbix-server-pgsql docker镜像备忘

Environment Variables 基本变量 When you start the zabbix-server-pgsql image, you can adjust the configuration of the Zabbix server by passing one or more environment variables on the docker run command line. DB_SERVER_HOST This variable is IP or DNS nam…

cad转shp再转3dtiles生成白模

1、准备CAD数据 2、arcgis中添加cad数据 添加面 cad中的标高字段是能带进arcgis中的&#xff0c;如果这个数据是建筑高度&#xff0c;可以直接用了 3、转shp 4、shp转3dtiles白模 cesiumlab中shp转3dtiles白模效果一

【智能硬件、大模型、LLM 智能音箱】Emo:基于树莓派 4B DIY 能笑会动的桌面机器人

简介 Emo 是一款个人伴侣机器人,集时尚与创新于一身。他的诞生离不开最新的树莓派 4 技术和先进的设计。他不仅仅是一款机器人,更是一个活生生的存在。与其他机器人不同,他拥有独特的个性和情感,能够俘获你的心灵。 硬件部分 – 树莓派 4B – 微雪 2 英寸 IPS LCD 显示屏…

Spring Cloud Alibaba微服务从入门到进阶(三)

Spring Cloud Alibaba是spring Cloud的子项目 Spring Cloud Alibaba的主要组件&#xff08;红框内是开源的&#xff09; Spring Cloud是快速构建分布式系统的工具集&#xff0c; Spring Cloud提供了很多分布式功能 Spring Cloud常用子项目 项目整合 Spring Cloud Alibaba …