Llama3:全模型GQA与tiktoken分词的新突破

        在本篇文章中,我们将介绍Llama3模型,并且对比它与Llama2在模型层面上的主要区别。Llama3 相较于Llama2的最显著变化是引入了全模型GQA(Grouped Query Attention)机制,并且在分词阶段使用了与GPT一致的 tiktoken 分词方式。

Llama3 和 Llama2 的模型层面区别

        Llama3 相较于 Llama2 的主要区别在于其全模型使用了 GQA(Grouped Query Attention),这使得多头注意力机制中的键值对变得更加高效,减少了计算和内存开销。

模型参数定义

        在实现 Llama3 时,我们使用了 Python 的 @dataclass 装饰器来定义模型的超参数。@dataclass 能够简化类的定义过程,自动生成构造函数 __init__(),打印方法 __repr__(),以及判断两个类是否相等的 __eq__()

代码示例:
@dataclass
class ModelArgs:dim: int = 4096  # 模型维度n_layers: int = 6  # 层数n_heads: int = 6  # 注意力头数n_group: Optional[int] = 3  # GQA组数vocab_size: int = 4096  # 词表大小hidden_dim: Optional[int] = None  # 隐藏层维度multiple_of: int = 256  # MLP层隐层维度的计算因子norm_eps: float = 1e-5  # 正则化epsmax_seq_len: int = 2048  # 最大序列长度dropout: float = 0.0  # Dropout比率

RMS正则化

        RMS正则化的原理已经在之前的 Qwen 文章中讲解过,Llama3 采用了同样的 RMSNorm 来实现层的标准化。

代码示例:
class RMSNorm(torch.nn.Module):def __init__(self, dim: int, eps: float):super().__init__()self.eps = epsself.weight = nn.Parameter(torch.ones(dim))def _norm(self, x):return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)def forward(self, x):output = self._norm(x.float()).type_as(x)return output * self.weight

ROPE 相对位置嵌入

        ROPE(Rotary Positional Embedding)的实现与 Qwen 模型类似,负责对自注意力的查询(Query)和键(Key)进行位置编码。

代码示例:
def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:xq_r, xq_i = xq.float().reshape(xq.shape[:-1] + (-1, 2)).unbind(-1)xk_r, xk_i = xk.float().reshape(xk.shape[:-1] + (-1, 2)).unbind(-1)freqs_cos = reshape_for_broadcast(freqs_cos, xq_r)freqs_sin = reshape_for_broadcast(freqs_sin, xq_r)xq_out_r = xq_r * freqs_cos - xq_i * freqs_sinxq_out_i = xq_r * freqs_sin + xq_i * freqs_cosxk_out_r = xk_r * freqs_cos - xk_i * freqs_sinxk_out_i = xk_r * freqs_sin + xk_i * freqs_cosxq_out = torch.stack([xq_out_r, xq_out_i], dim=-1).flatten(3)xk_out = torch.stack([xk_out_r, xk_out_i], dim=-1).flatten(3)return xq_out.type_as(xq), xk_out.type_as(xk)

Grouped Query Attention (GQA)

        在 Llama3 中,Attention 模块使用了 GQA 机制,这意味着每组注意力头共享相同的键和值,这种方法减少了计算开销。

代码示例:
class Attention(nn.Module):def __init__(self, args: ModelArgs):super().__init__()self.group = args.n_groupself.heads = args.n_headsself.kv_heads = args.n_heads // args.n_groupself.head_dim = args.dim // args.n_headsself.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)self.wk = nn.Linear(args.dim, self.kv_heads * self.head_dim, bias=False)self.wv = nn.Linear(args.dim, self.kv_heads * self.head_dim, bias=False)self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)self.attn_dropout = nn.Dropout(args.dropout)self.resid_dropout = nn.Dropout(args.dropout)def forward(self, x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor):xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)xq = xq.view(-1, xq.size(1), self.heads, self.head_dim)xk = xk.view(-1, xk.size(1), self.kv_heads, self.head_dim)xv = xv.view(-1, xv.size(1), self.kv_heads, self.head_dim)xq, xk = apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)xk = repeat_kv(xk, self.group)xv = repeat_kv(xv, self.group)scores = torch.matmul(xq, xk.transpose(-1, -2)) / math.sqrt(self.head_dim)scores = torch.softmax(scores, dim=-1)output = torch.matmul(scores, xv)output = output.transpose(1, 2).contiguous().view(-1, x.size(1), self.heads * self.head_dim)return self.wo(output)

FeedForward 模块

        Llama3 的 MLP 模块通过线性变换、激活函数和 Dropout 组成,与 Qwen 模型一致。

代码示例:
class FeedForward(nn.Module):def __init__(self, dim: int, hidden_dim: int, multiple_of: int, dropout: float):super().__init__()self.w1 = nn.Linear(dim, hidden_dim, bias=False)self.w2 = nn.Linear(hidden_dim, dim, bias=False)self.w3 = nn.Linear(dim, hidden_dim, bias=False)self.dropout = nn.Dropout(dropout)def forward(self, x):return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))

TransformerBlock: 将模块组合成完整层

        Llama3 的 Transformer 层由 Attention、FeedForward、RMSNorm 等模块组成,通过多层堆叠构建模型。

代码示例:
class TransformerBlock(nn.Module):def __init__(self, layer_id: int, args: ModelArgs):super().__init__()self.attention = Attention(args)self.feed_forward = FeedForward(dim=args.dim, hidden_dim=args.hidden_dim, multiple_of=args.multiple_of, dropout=args.dropout)self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)def forward(self, x, freqs_cos, freqs_sin):h = x + self.attention.forward(self.attention_norm(x), freqs_cos, freqs_sin)out = h + self.feed_forward.forward(self.ffn_norm(h))return out

Transformer模型:完整的 Llama3 实现

代码示例:
class Transformer(nn.Module):def __init__(self, params: ModelArgs):super().__init__()self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)self.layers = nn.ModuleList([TransformerBlock(i, params) for i in range(params.n_layers)])self.norm = RMSNorm(params.dim, eps=params.norm_eps)self.output = nn.Linear(params.dim, params.vocab_size, bias=False)self.tok_embeddings.weight = self.output.weightfreqs_cos, freqs_sin = precompute_freqs_cis(params.dim // params.n_heads, params.max_seq_len)self.register_buffer("freqs_cos", freqs_cos, persistent=False)self.register_buffer("freqs_sin", freqs_sin, persistent=False)def forward(self, tokens: torch.Tensor, targets: Optional[torch.Tensor] = None) -> torch.Tensor:h = self.tok_embeddings(tokens)for layer in self.layers:h = layer(h, self.freqs_cos[:h.size(1)], self.freqs_sin[:h.size(1)])h = self.norm(h)return self.output(h)

结语

        通过本篇文章,我们学习了如何从零开始预训练Llama3模型,并认识了它与Llama2在模型结构上的主要区别。Llama3的引入GQA机制大幅提升了模型的推理效率,同时结合 tiktoken 的分词方式,使其在处理文本任务时更具优势。后续,我们将进一步更新关于数据预处理和模型优化的相关教程。

如果你觉得这篇博文对你有帮助,请点赞、收藏、关注我,并且可以打赏支持我!

欢迎关注我的后续博文,我将分享更多关于人工智能、自然语言处理和计算机视觉的精彩内容。

谢谢大家的支持!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/53460.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

AI大模型日报#0923:李飞飞创业之后首个专访、华为云+腾讯音乐发布昇腾适配方案

导读:AI大模型日报,爬虫LLM自动生成,一文览尽每日AI大模型要点资讯!目前采用“文心一言”(ERNIE-4.0-8K-latest)、“智谱AI”(glm-4-0520)生成了今日要点以及每条资讯的摘要。欢迎阅…

基于单片机无线智能报警系统的设计

文章目录 前言资料获取设计介绍功能介绍设计程序具体实现截图设计获取 前言 💗博主介绍:✌全网粉丝10W,CSDN特邀作者、博客专家、CSDN新星计划导师,一名热衷于单片机技术探索与分享的博主、专注于 精通51/STM32/MSP430/AVR等单片机设计 主要对…

计算机毕业设计 基于Python的荣誉证书管理系统 Django+Vue 前后端分离 附源码 讲解 文档

🍊作者:计算机编程-吉哥 🍊简介:专业从事JavaWeb程序开发,微信小程序开发,定制化项目、 源码、代码讲解、文档撰写、ppt制作。做自己喜欢的事,生活就是快乐的。 🍊心愿:点…

2024全球超模大赛(北京|山东|内蒙三城联动)顺利举办

近日,2024 全球超模大赛(北京|山东|内蒙)三城联动暨新国潮文化赛事主题发布会在紫薇美力集团国贸鲁采赋盛大举行。此次发布会旨在鼓励优质模特共同传播中国传统文化,让其在全球范围内绽放光彩,展现中国人的骄傲与风采&…

用Python提取PowerPoint演示文稿中的音频和视频

将多种格式的媒体内容进行重新利用(如PowerPoint演示中的音频和视频)是非常有价值的。无论是创建独立的音频文件、提取视频以便在线分发,还是为了未来的使用需求进行资料归档,从演示文稿中提取这些媒体文件可以为多媒体内容的多次…

基于STM32的温度、电流、电压检测proteus仿真系统(OLED、DHT11、继电器、电机)

目录 一、主要功能 二、硬件资源 三、程序编程 四、实现现象 一、主要功能 基于STM32F103C8T6 采用DHT11读取温度、滑动变阻器模拟读取电流、电压。 通过OLED屏幕显示,设置电流阈值为80,电流小阈值为50,电压阈值为60,温度阈值为30 随便哪个超过预祝,则继电器切断,LE…

【sgCreateCallAPIFunctionParam】自定义小工具:敏捷开发→调用接口方法参数生成工具

<template><div :class"$options.name" class"sgDevTool"><sgHead /><div class"sg-container"><div class"sg-start"><div style"margin-bottom: 10px">参数列表[逗号模式]<el-too…

9.23作业

仿照string类&#xff0c;自己手动实现 My_string 代码如下 MyString.h #ifndef MYSTRING_H #define MYSTRING_H #include <iostream> #include <cstring>using namespace std;class My_string { private:char *ptr; //指向字符数组的指针int size; …

十大常用加密软件排行榜|2024年好用的加密软件推荐【精选】

在信息安全日益重要的时代&#xff0c;加密软件成为保护个人和企业数据的关键工具。选择合适的加密软件可以有效防止数据泄露和未授权访问。以下是2024年值得推荐的十大加密软件&#xff0c;帮助你找到适合的解决方案。 1. Ping32加密软件 Ping32是一款功能强大的加密软件&…

Linux C# Day4

作业&#xff1a; 1.统计家目录下.c文件的个数 #!/bin/bash num0 for filename in ls ~/*.c do((num)) done echo $num2.定义一个稀疏数组(下标不连续)&#xff0c;写一个函数&#xff0c;求该稀疏数组的和&#xff0c;要求稀疏数组中的数值通过参数传递到函数中arr([2]9 [4…

Android轻量级RTSP服务使用场景分析和设计探讨

技术背景 好多开发者&#xff0c;对我们Android平台轻量级RTSP服务模块有些陌生&#xff0c;不知道这个模块具体适用于怎样的场景&#xff0c;有什么优缺点&#xff0c;实际上&#xff0c;我们的Android平台轻量级RTSP服务模块更适用于内网环境下、对并发要求不高的场景&#…

基于深度学习的药品三期OCR字符识别

在药品生产线上,药品三期的喷码与条形码识别是保证药品追溯和安全管理的重要环节。传统的识别方法依赖于人工操作,不仅效率低下且容易出错。随着深度学习技术的不断发展,基于OCR(Optical Character Recognition,光学字符识别)的自动化识别系统逐渐成为主流。本文将以哪吒…

DataOps:解决数字化转型中数据价值挖掘挑战的最佳方案

云计算de小白 随着数字化转型的普及与深入&#xff0c;大数据技术在各行业被广泛应用&#xff0c;企业生产、营销、运营等各个环节的数据将被广泛采集&#xff0c;数据应用开发需求的增长、数据使用者角色的复杂度导致企业数据开发、数据运维的工作量、数据应用交付协同难度大…

电子看板实时监控数据可视化助力工厂精细化管理

在当今竞争激烈的制造业领域&#xff0c;工厂的精细化管理成为提高竞争力的关键。而电子看板实时监控数据可视化作为一种先进的管理工具&#xff0c;正为工厂的精细化管理带来巨大的助力。 一、工厂精细化管理的挑战 随着市场需求的不断变化和客户对产品质量要求的日益提高&am…

VMware ESXi 8.0U3b macOS Unlocker OEM BIOS 2.7 集成网卡驱动和 NVMe 驱动 (集成驱动版)

VMware ESXi 8.0U3b macOS Unlocker & OEM BIOS 2.7 集成网卡驱动和 NVMe 驱动 (集成驱动版) 发布 ESXi 8.0U3 集成驱动版&#xff0c;在个人电脑上运行企业级工作负载 请访问原文链接&#xff1a;https://sysin.org/blog/vmware-esxi-8-u3-sysin/&#xff0c;查看最新版…

CSP-J 2019 入门级 第一轮(初赛) 完善程序(1)

【题目】 CSP-J 2019 入门级 第一轮&#xff08;初赛&#xff09; 完善程序&#xff08;1&#xff09; 1.&#xff08;矩阵变幻&#xff09;有一个奇幻的矩阵&#xff0c;在不停的变幻&#xff0c;其变幻方式为&#xff1a; 数字 0 变成矩阵 0 0 0 1 数字 1 变成矩阵 1 1 1 0 …

云南自闭症康复寄宿学校:帮助孩子重塑美好未来

在云南这片充满希望的土地上&#xff0c;自闭症儿童的康复教育一直是社会各界关注的焦点。家长们渴望为孩子找到一所能够提供全面支持和专业指导的康复寄宿学校&#xff0c;帮助他们重塑美好未来。而当我们跨越地域的界限&#xff0c;将目光投向广州&#xff0c;星贝育园自闭症…

1网络安全的基本概念

文章目录 网络安全的基本概念可以总结为以下几个方面&#xff1a; 网络安全的需求&#xff1a; 信息安全的重要性&#xff1a;信息安全是计算机、通信、物理、数学等领域的交叉学科&#xff0c;对于社会的发展至关重要。信息安全的目标&#xff1a;主要包括保密性、完整性、可用…

萃取硫酸镍萃取槽技改离心萃取机

将硫酸镍萃取工艺中的萃取槽技改为离心萃取机&#xff0c;是一个旨在提高生产效率、降低能耗、改善产品质量的技术升级过程。以下是对这一技改过程的详细分析&#xff1a; 一、技改背景 传统萃取槽在硫酸镍萃取过程中存在分相效果差、澄清时间长、有夹带等问题&#xff0c;这些…

mat (Eclipse Memory Analyzer Tool)使用以及详解

前言 在Java开发中&#xff0c;内存问题往往不易被发现&#xff0c;但它们可能导致应用性能下降甚至崩溃。Eclipse Memory Analyzer Tool&#xff08;MAT&#xff09;是一个强大的开源工具&#xff0c;专门用于分析Java堆转储&#xff08;heap dumps&#xff09;文件&#xff…