用 C 语言进行大模型推理:探索 llama2.c 仓库(二)

文章目录

  • 前提
  • 如何构建一个Transformer Model
    • 模型定义
    • 模型初始化
  • 如何构建tokenzier 和 sampler
  • 如何进行推理
  • 总结

前提

上一节我们介绍了llama2.c中如何对hugging face的权重进行处理,拿到了llama2.c想要的权重格式和tokenizer.bin格式。这一节我们分析下在llama2.c如何解析这两个.bin文件。这一节的所有代码都在run.c文件里。

用 C 语言进行大模型推理:探索 llama2.c 仓库(一)

如何构建一个Transformer Model

按照一个最简单地理解,我们可以使用C语言构建一个Transformer Model,然后将两个.bin文件按照格式填进去即可。那这个Transformer Model 应该是一个什么数据结构呢,或者是一个什么样的组织架构呢?在C语言中没有class这个概念的,最多我们常见的也就是结构体了,而且结构体里只能定义变量,不能定义函数。所以那些操作Transformer Model中的那些算子又该如何实现呢?带着这些问题,或者你还有其他的问题,我们一步一步来看下llama2.c中是如何实现的。

模型定义

typedef struct {int dim;        // transformer dimensionint hidden_dim; // for ffn layersint n_layers;   // number of layersint n_heads;    // number of query headsint n_kv_heads; // number of key/value heads (can be < query heads because of// multiquery)int vocab_size; // vocabulary size, usually 256 (byte-level)int seq_len;    // max sequence length
} Config;typedef struct {// token embedding tablefloat *token_embedding_table; // (vocab_size, dim)// weights for rmsnormsfloat *rms_att_weight; // (layer, dim) rmsnorm weightsfloat *rms_ffn_weight; // (layer, dim)// weights for matmuls. note dim == n_heads * head_sizefloat *wq; // (layer, dim, n_heads * head_size)float *wk; // (layer, dim, n_kv_heads * head_size)float *wv; // (layer, dim, n_kv_heads * head_size)float *wo; // (layer, n_heads * head_size, dim)// weights for ffnfloat *w1; // (layer, hidden_dim, dim)float *w2; // (layer, dim, hidden_dim)float *w3; // (layer, hidden_dim, dim)// final rmsnormfloat *rms_final_weight; // (dim,)// (optional) classifier weights for the logits, on the last layerfloat *wcls;
} TransformerWeights;typedef struct {// current wave of activationsfloat *x;      // activation at current time stamp (dim,)float *xb;     // same, but inside a residual branch (dim,)float *xb2;    // an additional buffer just for convenience (dim,)float *hb;     // buffer for hidden dimension in the ffn (hidden_dim,)float *hb2;    // buffer for hidden dimension in the ffn (hidden_dim,)float *q;      // query (dim,)float *k;      // key (dim,)float *v;      // value (dim,)float *att;    // buffer for scores/attention values (n_heads, seq_len)float *logits; // output logits// kv cachefloat *key_cache;   // (layer, seq_len, dim)float *value_cache; // (layer, seq_len, dim)
} RunState;typedef struct {Config config; // the hyperparameters of the architecture (the blueprint)TransformerWeights weights; // the weights of the modelRunState state; // buffers for the "wave" of activations in the forward pass// some more state needed to properly clean up the memory mapping (sigh)int fd;            // file descriptor for memory mappingfloat *data;       // memory mapped data pointerssize_t file_size; // size of the checkpoint file in bytes
} Transformer;

llama2.c中的Transformer是一个结构体,其中最重要的三个成员变量configweightsstate,分别保存了网络的超参数,权重,以及网络运行过程中的中间结果。
强烈建议这里你仔细理解理解,体会一下这个写法。

模型初始化

我们要对定义的模型进行初始化,主要是两个方面:权重初始化和中间变量初始化。这里llama2.c的写法就更厉害了。请仔细欣赏下面的两个函数:

权重初始化函数:

void memory_map_weights(TransformerWeights *w, Config *p, float *ptr,int shared_weights) {int head_size = p->dim / p->n_heads;// make sure the multiplications below are done in 64bit to fit the parameter// counts of 13B+ modelsunsigned long long n_layers = p->n_layers;w->token_embedding_table = ptr;ptr += p->vocab_size * p->dim;w->rms_att_weight = ptr;ptr += n_layers * p->dim;w->wq = ptr;ptr += n_layers * p->dim * (p->n_heads * head_size);w->wk = ptr;ptr += n_layers * p->dim * (p->n_kv_heads * head_size);w->wv = ptr;ptr += n_layers * p->dim * (p->n_kv_heads * head_size);w->wo = ptr;ptr += n_layers * (p->n_heads * head_size) * p->dim;w->rms_ffn_weight = ptr;ptr += n_layers * p->dim;w->w1 = ptr;ptr += n_layers * p->dim * p->hidden_dim;w->w2 = ptr;ptr += n_layers * p->hidden_dim * p->dim;w->w3 = ptr;ptr += n_layers * p->dim * p->hidden_dim;w->rms_final_weight = ptr;ptr += p->dim;ptr += p->seq_len * head_size /2; // skip what used to be freq_cis_real (for RoPE)ptr += p->seq_len * head_size /2; // skip what used to be freq_cis_imag (for RoPE)w->wcls = shared_weights ? w->token_embedding_table : ptr;
}

自我感觉这个仓库很经典得一段代码就是这里了,我没有加载权重吧,我只是拿到了它的地址,然后映射给我结构体中的变量。然后等我真正推理计算的时候,用到哪一段权重就将哪一段权重加载到内存中参与计算。

中间变量初始化:

void malloc_run_state(RunState *s, Config *p) {// we calloc instead of malloc to keep valgrind happyint kv_dim = (p->dim * p->n_kv_heads) / p->n_heads;s->x = calloc(p->dim, sizeof(float));s->xb = calloc(p->dim, sizeof(float));s->xb2 = calloc(p->dim, sizeof(float));s->hb = calloc(p->hidden_dim, sizeof(float));s->hb2 = calloc(p->hidden_dim, sizeof(float));s->q = calloc(p->dim, sizeof(float));s->key_cache = calloc(p->n_layers * p->seq_len * kv_dim, sizeof(float));s->value_cache = calloc(p->n_layers * p->seq_len * kv_dim, sizeof(float));s->att = calloc(p->n_heads * p->seq_len, sizeof(float));s->logits = calloc(p->vocab_size, sizeof(float));// ensure all mallocs went fineif (!s->x || !s->xb || !s->xb2 || !s->hb || !s->hb2 || !s->q ||!s->key_cache || !s->value_cache || !s->att || !s->logits) {fprintf(stderr, "malloc failed!\n");exit(EXIT_FAILURE);}
}

如果不太理解权重初始化和中间变量初始化时为什么要申请那么大的空间,可以自己手动地将网络地数据流从头到尾推一遍。

如何构建tokenzier 和 sampler

对于这两个模块地构建我们不多介绍,感兴趣地可以自己去看看源码。

如何进行推理

这部分是我最感兴趣的地方。

  // forward all the layersfor (unsigned long long l = 0; l < p->n_layers; l++) {// attention rmsnormrmsnorm(s->xb, x, w->rms_att_weight + l * dim, dim);// key and value point to the kv cacheint loff = l * p->seq_len * kv_dim; // kv cache layer offset for conveniences->k = s->key_cache + loff + pos * kv_dim;s->v = s->value_cache + loff + pos * kv_dim;// qkv matmuls for this positionmatmul(s->q, s->xb, w->wq + l * dim * dim, dim, dim);matmul(s->k, s->xb, w->wk + l * dim * kv_dim, dim, kv_dim);matmul(s->v, s->xb, w->wv + l * dim * kv_dim, dim, kv_dim);// RoPE relative positional encoding: complex-valued rotate q and k in each// headfor (int i = 0; i < dim; i += 2) {int head_dim = i % head_size;float freq = 1.0f / powf(10000.0f, head_dim / (float)head_size);float val = pos * freq;float fcr = cosf(val);float fci = sinf(val);int rotn = i < kv_dim ? 2 : 1; // how many vectors? 2 = q & k, 1 = q onlyfor (int v = 0; v < rotn; v++) {float *vec =v == 0 ? s->q : s->k; // the vector to rotate (query or key)float v0 = vec[i];float v1 = vec[i + 1];vec[i] = v0 * fcr - v1 * fci;vec[i + 1] = v0 * fci + v1 * fcr;}}// multihead attention. iterate over all headsint h;
#pragma omp parallel for private(h)for (h = 0; h < p->n_heads; h++) {// get the query vector for this headfloat *q = s->q + h * head_size;// attention scores for this headfloat *att = s->att + h * p->seq_len;// iterate over all timesteps, including the current onefor (int t = 0; t <= pos; t++) {// get the key vector for this head and at this timestepfloat *k = s->key_cache + loff + t * kv_dim + (h / kv_mul) * head_size;// calculate the attention score as the dot product of q and kfloat score = 0.0f;for (int i = 0; i < head_size; i++) {score += q[i] * k[i];}score /= sqrtf(head_size);// save the score to the attention bufferatt[t] = score;}// softmax the scores to get attention weights, from 0..pos inclusivelysoftmax(att, pos + 1);// weighted sum of the values, store back into xbfloat *xb = s->xb + h * head_size;memset(xb, 0, head_size * sizeof(float));for (int t = 0; t <= pos; t++) {// get the value vector for this head and at this timestepfloat *v =s->value_cache + loff + t * kv_dim + (h / kv_mul) * head_size;// get the attention weight for this timestepfloat a = att[t];// accumulate the weighted value into xbfor (int i = 0; i < head_size; i++) {xb[i] += a * v[i];}}}// final matmul to get the output of the attentionmatmul(s->xb2, s->xb, w->wo + l * dim * dim, dim, dim);// residual connection back into xfor (int i = 0; i < dim; i++) {x[i] += s->xb2[i];}// ffn rmsnormrmsnorm(s->xb, x, w->rms_ffn_weight + l * dim, dim);// Now for FFN in PyTorch we have: self.w2(F.silu(self.w1(x)) * self.w3(x))// first calculate self.w1(x) and self.w3(x)matmul(s->hb, s->xb, w->w1 + l * dim * hidden_dim, dim, hidden_dim);matmul(s->hb2, s->xb, w->w3 + l * dim * hidden_dim, dim, hidden_dim);// SwiGLU non-linearityfor (int i = 0; i < hidden_dim; i++) {float val = s->hb[i];// silu(x)=x*σ(x), where σ(x) is the logistic sigmoidval *= (1.0f / (1.0f + expf(-val)));// elementwise multiply with w3(x)val *= s->hb2[i];s->hb[i] = val;}// final matmul to get the output of the ffnmatmul(s->xb, s->hb, w->w2 + l * dim * hidden_dim, hidden_dim, dim);// residual connectionfor (int i = 0; i < dim; i++) {x[i] += s->xb[i];}}

for循环所有的layers进行推理,有三个主要的子函数,分别是:rmsnormmatmulsoftmax,分别对应着三个算子,其他的算子则是直接在for循环内实现的。所有的layer都计算一遍后,再加上后处理即可完成一个token的推理。

总结

总得来说,这个库还是有很多的东西值得我们去学习的,学习下大神的编码思维和编码方式。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/9477.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

vue原生div做触底加载

第一种&#xff1a; 触底加载和图片懒加载的思路一样&#xff0c;屏幕的高度加上滚动的高度快要大于最后一个元素距离顶部的高度的时候就开始加载数据&#xff1b; &#xff08;1&#xff09;clientHeight&#xff1a;屏幕的高度&#xff1b; &#xff08;2&#xff09;scro…

漫威争锋Marvel Rivals怎么搜索 锁区怎么搜 游戏搜不到怎么办

即将问世的《漫威争锋》&#xff08;Marvel Rivals&#xff09;作为一款万众期待的PvP射击游戏新星&#xff0c;荣耀携手漫威官方网站共同推出。定档5月11日清晨9时&#xff0c;封闭Alpha测试阶段将正式揭开序幕&#xff0c;持续时间长达十天之久。在此首轮测试窗口&#xff0c…

一个开源即时通讯源码

一个开源即时通讯源码 目前已经含服务端、PC、移动端即时通讯解决方案&#xff0c;主要包含以下内容。 服务端简介 不要被客户端迷惑了&#xff0c;真正值钱的是服务端&#xff0c; 服务是采用Java语言开发&#xff0c;基于spring cloud微服务体系开发的一套即时通讯服务端。…

栈结构(c语言)

1.栈的概念 栈&#xff1a;一种特殊的线性表&#xff0c;其只允许在固定的一端进行插入和删除元素操作。进行数据插入和删除操作的一端称为栈顶&#xff0c;另一端称为栈底。栈中的数据元素遵守后进先出LIFO&#xff08;Last In First Out&#xff09;的原则。 压栈&am…

STM32的ADC详解

ADC即模拟数字转换器&#xff0c;通常用于将外部的模拟量信号转换为数字信号。STM32的ADC是12位逐次逼近型的模拟数字转换器&#xff0c;最大可以计数到4095&#xff0c;有18个通道&#xff0c;16个外部通道和2个内部通道。 ADC框图 ADC的功能框图可以分为七个部分&#xff1a…

RabbitMQ php amqp

Linux debian 安装 Windows php amqp 扩展 PECL :: Package :: amqp 将 php_amqp.dll 复制到 php 的 ext 目录下 将 rabbitmq.4.dll 复制到 c:\windows\system32 目录下 php.ini extensionamqp

记一次springboot jpa更新复杂几何类型报错Only simple geometries should be used

问题&#xff1a; 更新数据时&#xff0c; 几何字段MultiPolygon类型时报错&#xff1b; java.lang.IllegalStateException: Only simple geometries should be used 几何字段Point类型时不报错&#xff1b; 新增时字段存在MultiPolygon不报错。 查看日志可知&#xff0c;…

vscode 使用正则搜索

ctrl c 复制&#xff0c;内容如下&#xff1a; Vue3简介创建Vue3工程Vue3核心语法路由pinia组件通信其它 APIVue3新组件

Go 单元测试完全指南(一)- 基本测试流程

为什么写单元测试&#xff1f; 关于测试&#xff0c;有一张很经典的图&#xff0c;如下&#xff1a; 说明&#xff1a; 测试类型成本速度频率E2E 测试高慢低集成测试中中中单元测试低快高 也就是说&#xff0c;单元测试是最快、最便宜的测试方式。这不难理解&#xff0c;单元…

Baidu Comate:让编码实现无限可能

目录 1 背景介绍2 快速入门2.1 智能推荐功能2.2 智能生成功能2.2.1 智能注释2.2.2 智能生成2.2.3 智能调优2.2.4 代码解释 3 高兼容性4 即刻体验 1 背景介绍 Baidu Comate&#xff08;智能代码助手&#xff09;是基于文心大模型&#xff0c;结合百度积累多年的编程现场大数据和…

【MySQL数据库】丨一文详解 JdbcTemplate(Spring中的CRUD)

前言 JdbcTemplate 是 Spring框架 中提供的一个对象&#xff0c;用于简化JDBC操作。它使得数据库操作变得更为简单和方便&#xff0c;大大提高了开发效率。 文章目录 前言为何要使用JdbcTemplate在JdbcTemplate中执行SQL语句的方法大致分为3类&#xff1a;案例代码 JdbcTemplat…

word 毕业论文格式调整

添加页眉页脚 页眉 首先在页面上端页眉区域双击&#xff0c;即可出现“页眉和页脚”设置页面&#xff1a; 页眉左右两端对齐 如果想要页眉页脚左右两端对齐&#xff0c;可以选择添加三栏页眉&#xff0c;然后将中间那一栏删除&#xff0c;即可自动实现左右两端对齐&#x…

Linux 操作系统TCP、UDP

1、TCP服务器编写流程 头文件&#xff1a; #include <sys/socket.h> 1.1 创建套接字 函数原型&#xff1a; int socket(int domain, int type, int protocol); 参数&#xff1a; domain: 网域 AF_INET &#xff1a; IPv4 AF_INET6 &a…

fswatch工具:跟踪Linux中的文件和目录更改

fswatch是一个跨平台的文件更改监视器&#xff0c;当指定文件或目录的内容被更改或修改时&#xff0c;它会收到通知警报。 fswatch在不同的操作系统上执行多种类型的监视器&#xff0c;例如&#xff1a; 基于 Apple OS X 的文件系统事件 API 构建的监视器。基于kqueue的监视器…

WPF之DataGird应用

1&#xff0c;DataGrid相关属性 GridLinesVisibility&#xff1a;DataGrid网格线是否显示或者显示的方式。HorizontalGridLinesBrush&#xff1a;水平网格线画刷。VerticalGridLinesBrush&#xff1a;垂直网格线画刷。HorizontalScrollBarVisibility&#xff1a;水平滚动条可见…

ASP.NET MVC 如何使用 Form Authentication?

前言 .NET 的 Form Authentication 是一种基于表单的简单且灵活的身份验证机制&#xff0c;用户通过输入用户名和密码来登录应用程序&#xff0c;并且通过配置来控制用户访问权限。 在使用 Form Authentication 时&#xff0c;我们需要在 web.config 文件中配置身份验证和授权…

Spring Cloud Consul 4.1.1

该项目通过自动配置和绑定到 Spring 环境和其他 Spring 编程模型习惯用法&#xff0c;为 Spring Boot 应用程序提供 Consul 集成。通过一些简单的注释&#xff0c;您可以快速启用和配置应用程序内的常见模式&#xff0c;并使用基于 Consul 的组件构建大型分布式系统。提供的模式…

Spark云计算平台Databricks使用,第一个Spark应用程序WordCount

1 上传文件 上传words.txt文件&#xff1a;Spark云计算平台Databricks使用&#xff0c;上传文件-CSDN博客 上传的文件的路径是/FileStore/tables/words.txt&#xff0c;保存在AWS的S3 hello world hello hadoop hello world hello databricks hadoop hive hbase yarn spark …

利用BACnet分布式IO控制器优化Niagara楼宇自动化系统

在智能建筑领域&#xff0c;随着物联网技术的飞速发展&#xff0c;如何实现高效、灵活且安全的楼宇自动化控制成为了行业关注的焦点。BACnet IP分布式远程I/O模块&#xff0c;作为这一领域的创新成果&#xff0c;正逐渐成为连接智能建筑各子系统的关键桥梁&#xff0c;尤其在与…

短效http代理ip和动态http代理有什么联系?

http代理 是指在客户端和服务器放一个代理服务器进行http协议传输&#xff0c;代理服务器将客户端的请求转发给目标服务器&#xff0c;将响应的信息通过代理服务器返回给客户端。代理服务器可以做到缓存、转发等经过的请求或者响应的信息。从而保护用户的个人信息。 一、概念…