Gemma模型论文详解(附源码)

原文链接:Gemma模型论文详解(附源码)

1. 背景介绍

Gemma模型是在2023.2.21号Google新发布的大语言模型, Gemma复用了Gemini相同的技术(Gemini也是Google发布的多模态模型),Gemma这次发布了了2B和7B两个版本的参数,不仅提供了预训练的checkpoints,还提供了用于对话、指令跟随等fine-tune的checkpoints。在QA问答、常识。在11

在这里插入图片描述

2. 模型介绍

2.1 模型结构

Gemma模型使用了transformer decoder结构进行训练,训练的上下文大小为8192个token,模型参数如下:
在这里插入图片描述

相比原始transformer结构的区别:

  • Multi-Query Attention:7B模型使用了multi-head attention,2B模型使用了multi-query attention (with 𝑛𝑢𝑚_𝑘𝑣_ℎ𝑒𝑎𝑑𝑠 = 1)。对比llama2中用了group-query attention
    在这里插入图片描述

  • RoPE Embeddings: 不使用绝对位置编码,在每一层前加下RoPE Embedding,同时共享输入与输出层的embedding权重。

  • GeGLU Activations: ReLU的激活替换为GeGLU的激活。对比llama中用了swiglu。

  • Normalizer Location: 在transformer的每一层layer的前后都进行规一化,这里使用RMSNorm做为规一化层。

2.2 训练搭建

Gemma使用TPUv5e进行训练;一个pod中有256块TPUv5e芯片,256块芯片被设计为16X16的2D拓扑;Gemma-7B使用16个pods(4096块卡)进行训练,Gemma-2B使用2个pods(512块卡)。7B模型在一个pod内使用16路模型并行和16路数据并行,2B模型在一个pod内使用256路数据并行。优化器状态使用ZeRO-3进行切分,减少显存占用。在pod外使用类似Pathways的方式减少数据复制的成本。

和Gemini模型训练一样,综合了Jax和Pathways的单控制器single controller编程范式,使用单个python进程编排整个训练; 使用GSPMD partitioner用于训练step的计算,使用XLA compiler减少中间结果的大小。

2.3 训练数据

Gemma 2B和7B分别基于2T和6T个token进行训练,token来源于纯英文的文本,内容包括网页、数学、代码等。使用SentencePiece的tokenizer,字典大小有256K个token。数据过滤使用基于模型的分类器去除有害的、低质量的内容。最后采用类似Gemini的方式进行训练数据的混合,提升高质量数据的占比。

2.4 指令微调(Instruction Tuning)

2B和7B进行有监督微调(SFT)训练中使用混合生成数据和人工标注的prompt文本对,同时进行RLHF训练。在SFT阶段,基于给定的一个prompt,通过测试模型生成多个响应的回答结果,通过一个更大更好的模型进行结果的好坏判断。基于不同的侧重方向(指令跟随/事实/创造性/安全等)构建不同的prompt。使用多种基于LM的自动判断方法,比如chain-of-thought prompting

训练和推理过程中使用相同的数据格式,格式的设计重点在于两点,一个是确定多轮对话中的角色,一个是确定一轮对话的开始结束。对应格式标记和示例的训练数据如下:

在这里插入图片描述
在这里插入图片描述

3. 源码

  • Tensorflow实现的源码在github google-deepmind/gemma中,PyTorch实现的源码在github google/gemma_pytorch。

  • 模型的配置在gemma/config.py文件中, 7B与2B区别主要在于num_hidden_layers/num_attention_heads/num_key_value_heads/hidden_size/intermediate_size

@dataclasses.dataclass
class GemmaConfig:# The number of tokens in the vocabulary.vocab_size: int = 256000# The maximum sequence length that this model might ever be used with.max_position_embeddings: int = 8192# The number of blocks in the model.num_hidden_layers: int = 28# The number of attention heads used in the attention layers of the model.num_attention_heads: int = 16# The number of key-value heads for implementing attention.num_key_value_heads: int = 16# The hidden size of the model.hidden_size: int = 3072# The dimension of the MLP representations.intermediate_size: int = 24576# The number of head dimensions.head_dim: int = 256# The epsilon used by the rms normalization layers.rms_norm_eps: float = 1e-6# The dtype of the weights.dtype: str = 'bfloat16'# Whether a quantized version of the model is used.quant: bool = False# The path to the model tokenizer.tokenizer: Optional[str] = 'tokenizer/tokenizer.model'def get_dtype(self) -> Optional[torch.dtype]:"""Gets the torch dtype from the config dtype string."""return _STR_DTYPE_TO_TORCH_DTYPE.get(self.dtype, None)def get_config_for_7b() -> GemmaConfig:return GemmaConfig()def get_config_for_2b() -> GemmaConfig:return GemmaConfig(num_hidden_layers=18,num_attention_heads=8,num_key_value_heads=1,hidden_size=2048,intermediate_size=16384)
  • 模型定义在gemma/model.py文件中,GemmaDecoderLayer的定义如下:
class GemmaDecoderLayer(nn.Module):def __init__(self,config: gemma_config.GemmaConfig,):super().__init__()self.self_attn = GemmaAttention(hidden_size=config.hidden_size,num_heads=config.num_attention_heads,num_kv_heads=config.num_key_value_heads,head_dim=config.head_dim,quant=config.quant,)self.mlp = GemmaMLP(hidden_size=config.hidden_size,intermediate_size=config.intermediate_size,quant=config.quant,)self.input_layernorm = RMSNorm(config.hidden_size,eps=config.rms_norm_eps)self.post_attention_layernorm = RMSNorm(config.hidden_size,eps=config.rms_norm_eps)
  • GeGLU的实现跟llama的swiglu不同,geglu相比glu区是采用了gelu的激活,以下是glu的计算示例图:
    在这里插入图片描述

代码参考如下,代码中self.gate_proj对应上图中的B矩阵,gate相当于 σ ( B ) \sigma(B) σ(B)self.up_proj对应上图中的A矩阵.

class GemmaMLP(nn.Module):def __init__(self,hidden_size: int,intermediate_size: int,quant: bool,):super().__init__()self.gate_proj = Linear(hidden_size, intermediate_size, quant)self.up_proj = Linear(hidden_size, intermediate_size, quant)self.down_proj = Linear(intermediate_size, hidden_size, quant)def forward(self, x):gate = self.gate_proj(x)gate = F.gelu(gate)up = self.up_proj(x)fuse = gate * upoutputs = self.down_proj(fuse)return outputs

4. 参考

  • google-deepmind/gemma
  • Gemma 开放模型
  • Gemma: Open Models Based on Gemini Research and Technology
  • gemma-open-models
  • github google/gemma_pytorch
  • github google-deepmind/gemma
  • Grouped Query Attention论文阅读
  • SwiGLU论文阅读

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/698303.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

如何快速卸载windows电脑的一些软件?

本系列是一些电脑常规操作的普及,有需要借鉴即可 注:每个电脑都会有差异,参考即可。 其实大部分软件你删除桌面上的图标不等于删除,因为桌面上的那个图标就是一个简单的快捷方式而已。 在这里插入图片描述 那如何正确的卸载软件呢…

Android 广播的基本概念

一.广播简介 Broadcast是安卓四大组件之一。安卓为了方便进行系统级别的消息通知,引入了一套广播消息机制。打个比方,记得原来在上课的时候,每个班级的教室里都会装有一个喇叭,这些喇叭都是接入到学校的广播室的,一旦…

(done) 什么是特征值和特征向量?如何求特征值的特征向量 ?如何判断一个矩阵能否相似对角化?

什么是齐次方程? https://blog.csdn.net/shimly123456/article/details/136198159 行列式和是否有解的关系? https://blog.csdn.net/shimly123456/article/details/136198215 特征值和特征向量 参考视频:https://www.bilibili.com/video/BV…

《Solidity 简易速速上手小册》第4章:智能合约的设计与开发(2024 最新版)

文章目录 4.1 合约结构和布局4.1.1 基础知识解析深入合约布局原则理解组织结构高效布局的重要性4.1.2 重点案例:构建一个在线商店合约案例 Demo:编写在线商店智能合约案例代码:OnlineStore.sol测试和验证拓展功能4.1.3 拓展案例 1:可升级的合约案例 Demo:创建可升级的智能…

【主题广范|见刊快】2024年科技,绿色能源和可持续发展国际会议(ICTGESD 2024)

【主题广范|见刊快】2024年科技,绿色能源和可持续发展国际会议(ICTGESD 2024) 2024 International Conference on Technology, Green Energy, and Sustainable Development ⊙会议简介: 2024年科技,绿色能源和可持续发…

音视频技术-声反馈啸叫的产生与消除

目录 1.均衡调节: 2.移频法: 3.移相法: 4.比较法: 在扩音系统中,产生啸叫危害很大,一方面影响会议、演出等活动的正常进行,另一方面严重的啸叫会导致音响设备的损坏。 “啸叫”是“声反馈”的俗称,形成的机制复杂,消除的手段多样,专业调音师也对

Unity中URP实现水效果(水的深度)

文章目录 前言一、搭建预备场景1、新建一个面片,使其倾斜一个角度,来模拟水底和岸边的效果2、随便创建几个物体,作为与水面接触的物体3、再新建一个面片,作为水面 二、开始编写水体的Shader效果1、新建一个URP基础Shader2、把水体…

最优传输(Optimal Transport)

最优传输(Optimal Transport)是一种数学理论和计算方法,用于描述两个概率分布之间的距离或者对应关系。它的核心概念是如何以最佳方式将一组资源(如质量、能量等)从一个位置传输到另一个位置。 基本概念: …

Java编程实战:构建医疗信息管理新平台

✍✍计算机编程指导师 ⭐⭐个人介绍:自己非常喜欢研究技术问题!专业做Java、Python、微信小程序、安卓、大数据、爬虫、Golang、大屏等实战项目。 ⛽⛽实战项目:有源码或者技术上的问题欢迎在评论区一起讨论交流! ⚡⚡ Java实战 |…

sql注入 [极客大挑战 2019]FinalSQL1

打开题目 点击1到5号的结果 1号 2号 3号 4号 5号 这里直接令传入的id6 传入id1^1^1 逻辑符号|会被检测到,而&感觉成了注释符,&之后的内容都被替换掉了。 传入id1|1 直接盲注比较慢,还需要利用二分法来编写脚本 这里利用到大佬的脚…

英伟达推出免训练,可生成连贯图片的文生图模型

目前,多数文生图模型皆使用的是随机采样模式,使得每次生成的图像效果皆不同,在生成连贯的图像方面非常差。 例如,想通过AI生成一套图像连环画,即便使用同类的提示词也很难实现。虽然DALLE 3和Midjourney可以对图像实现…

linux0.11 源码阅读 head.s setup.s bootsect.s加载位置

从github上下载linux0.11源码 linux0.11源码 将0x10000处的代码往下复制到0开始的地址处。 移动后的内存布局如下 setup中存在gdt和idt的相关数据。此时需要用gdtr和idtr寄存器指向对应的数据。 实模式下,访问内存方式。最多访问1M内存。

有哪些适合程序员的副业?

如果你经常玩知乎、看公众号(软件、工具、互联网这几类的)你就会发现,好多资源连接都变成了夸克网盘、迅雷网盘的资源链接。 例如:天涯神贴,基本上全是夸克、UC、迅雷网盘的资源链接。 有资源的前提下,迅雷…

人工智能 — 图像滤波器

目录 一、图像噪声1、高斯噪声2、椒盐噪声3、泊松噪声4、乘性噪声5、瑞利噪声6、伽马噪声 二、图像滤波三、各种滤波器1、均值滤波2、中值滤波3、最大最小值滤波4、引导滤波 四、图像增强1、点处理1、线性变换2、分段线性变换3、对数变换4、幂律变换/伽马变换 2、领域处理3、图…

2006-2021年地级市资本存量数据(含原始数据+计算过程+计算结果)(以2006年为基期)

2006-2021年地级市资本存量数据(含原始数据计算过程计算结果)(以2006年为基期) 1、时间:2006-2021年 2、来源:城市年鉴、统计年鉴、各省年鉴、各市年鉴和公报、2017-2021年利用固定资产投资增速计算获取 …

【C语言】内存操作,内存函数篇---memcpy,memmove,memset和memcmp内存函数的使用和模拟实现【图文详解】

欢迎来CILMY23的博客喔,本篇为​【C语言】内存操作,内存函数篇---memcpy,memmove,memset和memcmp内存函数的使用和模拟实现【图文详解】,图文讲解四种内存函数,带大家更深刻理解C语言中内存函数的操作&…

WooCommerce商品采集与发布插件

如何采集商品或产品信息,并自动发布到Wordpress系统的WooCommerce商品? 推荐使用简数采集器,操作简单方便,且无缝衔接WooCommerce插件,快速完成商品的采集与发布。 简数采集器的智能自动生成采集规则和可视化操作功能…

Pytorch学习(杂知识)

Mini-batch Mii-batch是一种在机器学习中常用的训练算法。它是将大的数据集分成一些小的数据集,每次只用一个小的数据集来训练模型。通常情况下,训练数据集中的数据越多,训练出的模型越准确,但是如果数据集太大,就会导…

【EI会议征稿通知】第四届生物医学与生物信息工程国际学术会议(ICBBE 2024)

第四届生物医学与生物信息工程国际学术会议(ICBBE 2024) The 4th International Conference on Biomedicine and Bioinformatics Engineering 由河南大学主办,中州实验室、河南大学基础医学院、河南大学郑州校区学术发展部共同承办的第四届生…

微信小程序 --- 微信原生 API

微信原生 API 1. API 基础 小程序开发框架提供丰富的微信原生 API,可以方便的调起微信提供的能力,如获取用户信息,本地存储,支付功能等,几乎所有小程序的 API 都挂载在 wx 对象底下,例如:wx.c…