用code去探索理解Llama架构的简单又实用的方法

除了白月光我们也需要朱砂痣

      我最近也在反思,可能有时候算法和论文也不是每个读者都爱看,我也会在今后的文章中加点code或者debug模型的内容,也许还有一些好玩的应用demo,会提升这部分在文章类型中的比例

      今天带着大家通过代码角度看一下Llama,或者说看一下Casual-LLM的Transfomer到底长啥样

     对Transfomer架构需要更了解的读者,可以先看这个系列

小周带你读论文-2之"草履虫都能看懂的Transformer老活儿新整"Attention is all you need(1) (qq.com)

小周带你读论文-2之"草履虫都能看懂的Transformer老活儿新整"Attention is all you need(2) (qq.com)

小周带你读论文-2之"草履虫都能看懂的Transformer老活儿新整"Attention is all you need(3) (qq.com)

小周带你读论文-2之"草履虫都能看懂的Transformer老活儿新整"Attention is all you need(4) (qq.com)

       友情提示,看代码和debug都不需要GPU,在你的PC上就可以做

       首先先安装transfomer库

       pip install transfomers

       然后进入到库下面,一般在这

图片

       进去就能找到Transfomers的库,往下拉到models,可以发现各种模型都在里面

图片

找到Llama,包含以下文件

图片

      点进modeling_llama.py发现1200多行,根本没法看(一般人没那么多耐心,但是其实仔细看两遍还是很有收获的),然后点击左边outline 大纲,(或者ctrl+shift+o),就可以有选择的看你想要研究的具体网络层,这样感觉压力瞬间小了百分之80以上

图片

      我想查某个函数的代码块,想对它加深了解,举个例子,看起来比较怪异命名的,这个线性扩展RoPE embedding的函数

图片

       然后ctl按住函数名字就能查到它作用于attetion的机制里面,在这种情况下即使我不知道它到底干啥,也能猜个89不离10,至少和什么主模块相关我清楚了

图片

     本章的话,我们先从模型主体部分看起,点LlamaModel,就能看到非常清晰的逻辑

    

图片

    主体部分包含3个子模块:

  • 先要embedding token

  • 再要包含一个通过for循环,不断的持续经过的decoder层

  • 还要包含一个Normal(RMSNorm)

     

     从大面上看也就这么3个操作

     初始化部分,初始了哪些模块我们看完了,再看看细节,从forward看起(大家看任何网络都要重点看forward)

图片

 forward输入部分,要求输入的参数:

  • input_ids:输入序列标记的索引,形状为(batch_size, sequence_length)的torch.LongTensor。

  • attention_mask:避免对填充标记进行注意力计算的掩码,形状为(batch_size, sequence_length)的torch.Tensor

  • position_ids:输入序列标记在位置嵌入中的索引,形状为(batch_size, sequence_length)的torch.LongTensor

  • past_key_values:包含预先计算的隐藏状态(自注意力块和交叉注意力块中的键和值),用于加速顺序解码的tuple,推理用的,当use_cache=True时,会返回这个参数,训练不用管

  • inputs_embeds:直接传入embedding表示而不是input_ids,形状为(batch_size, sequence_length, hidden_size)的torch.FloatTensor

  • use_cache:是否使用缓存加速解码的布尔值,当设置为True时,past_key_values的键值状态将被返回,用于加速解码

  • output_attentions:是否返回所有注意力层的注意力张量,布尔值,当设置为True时,会在返回的结果中包含注意力张量

  • output_hidden_states:是否返回所有层的隐藏状态,布尔值,当设置为True时,会在返回的结果中包含隐藏状态

  • return_dict:是否返回utils.ModelOutput对象而不是普通的元组,布尔值,当设置为True时,会返回一个ModelOutput对象

         我们继续看,下面就是一些操作步骤,输入的input_id,会被向量化,生成hidden_states

图片

    hidden_states然后就被扔进了若干个hidden_layer被for循环来回的操作,比如Llama7B的32层

   

图片

     我们简单写一段逻辑描述上述的代码

     比如在把"我爱你"已经分词的情况下 我=100,爱=200,你=300

  input_ids = [100,200,300]

  input_ids  -> nn.Emebdding(dims=3) -> hidden_states

  hidden_states = [[0.1,0.2,0.3],[0.4,0.5,0.6],[0.7.0.8.1.1]]

  hidden_states ->layer1 ->layer2 ------>layer32

  最终还是hidden_states

  也就是最终的shape和初始的hidden_states的shape的相同的

   hidden_states -> Norm -> hidden_states

   其实没必要把hidden_states理解的那么悬,就当它是个中间变量就可以了

     Layer里面都有什么呢?我们点进去Layer里面就能看到

图片

     包含了我们在Transfomer里学到的attention层,MLP层,LayerNorm这些层

     继续往下

图片

     我们可以看到首先定义了残差,然后hidden_states在没过layer之前就先被Normal了一下(这个知识点以前也讲过,Llama RMS LN是前LN),然后过attetion,过MLP,最后hidden_states=residual+hidden_states, 这样就把位置编码啥的也都带过来了

     然后我们点击进入attetion这块,看看它是咋做的

图片

      这块其实感觉都不用讲了,我觉得全部代码就这块写的逻辑是最清楚的

图片

,QKV矩阵就是下面前三个线性层,然后分别主管生成qkv,上面的参数也解释了多头,和多头dim是咋算出来的,GQA咋算,这块代码写的实在清晰,属实没啥可说的

       最后看一眼MLP层,就是FFN

图片

      这一层是由3个线性层组成的gate+up+down,和标准transfomer的FFN是真不一样,主要有两点不同:

1- 原始transfomer是由两个线性层组成,这里有3个,原因就在于 SwiGLU 激活函数比类似 ReLU 的激活函数,需要多一个 Linear 层(gate)进行门控,所以你说SwiGLU是比ReLU效率高了,还是低了呢?

图片

(当然这个门控可以并行操作)

2- 原始Transfomer第一个线性层先将维度映射为4h维,第二个线性层再映射回h维,接着进行激活函数操作。而llama则是将原有4h变成一个常量作为输入,且计算方式也略有不同,可能是因为4h这个说法如果模型太大会罩不住,就用一个常量来代替了(我瞎猜的)

        这基本网络就讲完了,其实大家看看也没啥玩意,比较简单

      再比如说我们像看一下Llama是咋做下游任务的,modeling里面有2个,我们挑CLM任务(主管生成的任务,CasualLM)来看

图片

     如上图所示,就是在Llama的基础模型之上,加了一个线性层

     然后我们展开LlamaForCausalLM

图片

      可以看到它包含很多内容,这里面重点肯定还是forward(前向),我们就看一下它的前向是怎么进行的,输入是啥,输出是啥?

图片

      我们先看forward输入部分,要求输入的参数和前面都一样唯一的区别就是Label

  • Labels:没法解释,就是字面上的Labels

    图片

    图片

     输出这块我们就关注hidden_states就行了,这也是最重要的

图片

      然后hiden_states和你的labels进行匹配求交叉熵损失,最后损失最小的就是最高概率生成的token 

      下面那个文本分类的其实也是一样的,只不过它计算Loss的时候会有说明,单标签就是回归任务,多标签就变分类任务了

图片

     通过这些简单的方法,就能把读者们把以前理解不太透彻的网络模块都理一遍,加深印象      

     本章完

图片

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/680008.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

数据库管理-第149期 Oracle Vector DB AI-01(20240210)

数据库管理149期 2024-02-10 数据库管理-第149期 Oracle Vector DB & AI-01(20240210)1 机器学习2 向量3 向量嵌入4 向量检索5 向量数据库5 专用向量数据库的问题总结 数据库管理-第149期 Oracle Vector DB & AI-01(20240210&#xf…

每日五道java面试题之java基础篇(六)

第一题:Java 创建对象有哪⼏种⽅式? Java 中有以下四种创建对象的⽅式: new 创建新对象通过反射机制采⽤ clone 机制通过序列化机制 前两者都需要显式地调⽤构造⽅法。对于 clone 机制,需要注意浅拷⻉和深拷⻉的区别,对于序列化机制需要明…

书生·浦语大模型第五课作业

基础作业: 使用 LMDeploy 以本地对话、网页Gradio、API服务中的一种方式部署 InternLM-Chat-7B 模型,生成 300 字的小故事(需截图) 这里 /share/conda_envs 目录下的环境是官方未大家准备好的基础环境,因为该目录是共…

leetcode 算法 69.x的平方根(python版)

需求 给你一个非负整数 x ,计算并返回 x 的 算术平方根 。 由于返回类型是整数,结果只保留 整数部分 ,小数部分将被 舍去 。 注意:不允许使用任何内置指数函数和算符,例如 pow(x, 0.5) 或者 x ** 0.5 。 示例 1&#…

基于Java (spring-boot)的宿舍管理系统

一、项目介绍 基于Java (spring-boot)的宿舍管理系统功能:登录界面、宿舍管理、学生管理、班级管理、宿舍楼管理、维修记录、晚归记录、请假记录、用户管理、角色管理、菜单管理、日志管理、我收到的、退宿审核,等等等 二、作品包含 三、项目技术 后端语…

六轴机器人奇异点

1 奇异点说明 有着6个自由度的KUKA机器人具有3个不同的奇点位置。即便在给定状态和步骤顺序的情况下,也无法通过逆向变换(将笛卡尔坐标转换成极坐标值)得出唯一数值时,即可认为是一个奇点位置。这种情况下,或者当最小的笛卡尔变化也能导致非常大的轴角度变化时,即为奇点位置…

fast.ai 机器学习笔记(三)

机器学习 1:第 8 课 原文:medium.com/hiromi_suenaga/machine-learning-1-lesson-8-fa1a87064a53 译者:飞龙 协议:CC BY-NC-SA 4.0 来自机器学习课程的个人笔记。随着我继续复习课程以“真正”理解它,这些笔记将继续更…

【初中生讲机器学习】8. KNN 算法原理 实践一篇讲清!

创建时间:2024-02-11 最后编辑时间:2024-02-12 作者:Geeker_LStar 你好呀~这里是 Geeker_LStar 的人工智能学习专栏,很高兴遇见你~ 我是 Geeker_LStar,一名初三学生,热爱计算机和数学,我们一起加…

Nature Machine Intelligence 使用机器学习驱动的可拉伸智能纺织手套捕捉复杂的手部动作和物体交互

研究背景 对灵巧手运动的精确实时跟踪在人机交互、元宇宙、机器人和远程医疗等领域有着广泛的应用。当前的可穿戴设备中的大多数仅用于检测精度有限的特定手势,并且没有解决与设备的可靠性、准确性和可清洗相关的挑战。对传感器直接放置在用户的手上有严格的要求&am…

第四节 zookeeper集群与分布式锁

目录 1. Zookeeper集群操作 1.1 客户端操作zk集群 1.2 模拟集群异常操作 1.3 curate客户端连接zookeeper集群 2. Zookeeper实战案例 2.1 创建项目引入依赖 2.2 获取zk客户端对象 2.3 常用API 2.4 客户端向服务端写入数据流程 2.5 服务器动态上下线、客户端动态监听 2…

JUC-并发面试题

一、基础 1、为什么要并发编程 充分利用多核CPU的资源2、并发编程存在的问题 上下文切换:PU通过时间片分配算法来循环执行任务,当前任务执行一个时间片后会切换到下一个任务。在切换前会保存上一个任务的状态,以便下次切换回这个任务时,可以再加载这个任务的状态。任务从保…

一个三极管引脚识别的小技巧,再也不用对照手册啦

三极管是一个非常常用的器件,时不时的就需要用到他们,有些时候当我们拿到一颗三极管时 ,对于常用的友来说,三极管的引脚可能早已烂熟于心,而对于不常用或者初学者来说,三极管的引脚可以说是今天记下明天忘&…

Pytorch卷积层原理和示例 nn.Conv1d卷积 nn.Conv2d卷积

内容列表 一,前提 二,卷积层原理 1.概念 2.作用 3. 卷积过程 三,nn.conv1d 1,函数定义: 2, 参数说明: 3,代码: 4, 分析计算过程 四,nn.conv2d 1, 函数定义 2, 参数: 3, 代码 4, 分析计算过程 …

面了滴滴的数据分析师(实习),几道面试题都是原题啊。。。

年前,技术群组织了一场数据类的技术&面试讨论会,邀请了一些同学分享他们的面试经历,讨论会会定期召开,如果你想加入我们的讨论群或者希望要更详细的资料,文末加入。 喜欢本文记得收藏、关注、点赞 。技术交流文末…

C++之继承

一,概念及用法 1)概念 首先我们来了解一下官方的概念:继承(inheritance)机制是面向对象程序设计使代码可以复用的最重要的手段,它允许程序员在保持原有类特性的基础上进行扩展,增加功能,这样产生新的类&a…

leetcode:131.分割回文串

树形结构: 切割到字符串的尾部,就是叶子节点。 回溯算法三部曲: 1.递归的参数和返回值: 参数字符串s和startIndex切割线 2.确定终止条件: 当分割线到字符串末尾时到叶子节点,一种方案出现 3.单层搜索…

使用R语言建立回归模型并分割训练集和测试集

通过简单的回归实例&#xff0c;可以说明数据分割为训练集和测试集的必要性。以下先建立示例数据: set.seed(123) #设置随机种子 x <- rnorm(100, 2, 1) # 生成100个正态分布的随机数&#xff0c;均值为2&#xff0c;标准差为1 y exp(x) rnorm(5, 0, 2) # 生成一个新的变…

【C语言】assert断言:保护程序的利器

在软件开发过程中&#xff0c;我们经常会遇到一些假设条件或者预期行为。例如&#xff0c;我们可能假设一个函数的输入参数必须在某个范围内&#xff0c;或者某个变量的值应该满足特定的条件。当这些假设或预期行为被打破时&#xff0c;程序可能会出现异常行为&#xff0c;甚至…

Java+SpringBoot:高校竞赛管理新篇章

✍✍计算机编程指导师 ⭐⭐个人介绍&#xff1a;自己非常喜欢研究技术问题&#xff01;专业做Java、Python、微信小程序、安卓、大数据、爬虫、Golang、大屏等实战项目。 ⛽⛽实战项目&#xff1a;有源码或者技术上的问题欢迎在评论区一起讨论交流&#xff01; ⚡⚡ Java实战 |…