Meta Llama 3 RMSNorm(Root Mean Square Layer Normalization)

Meta Llama 3 RMSNorm(Root Mean Square Layer Normalization)

flyfish

目录

  • Meta Llama 3 RMSNorm(Root Mean Square Layer Normalization)
    • 先看LayerNorm和BatchNorm
    • 举个例子计算 LayerNorm
    • RMSNorm 的整个计算过程
      • 实际代码实现
      • 结果

先看LayerNorm和BatchNorm

展示计算的方向
在这里插入图片描述

  • axis=0 代表第一个轴,逐列处理数据。
  • axis=1 代表第二个轴,逐行处理数据。在二维数组中,axis=-1 等同于 axis=1。
  • axis=-1 代表最后一个轴。在二维数组中,axis=-1 等同于 axis=1,即最后一个轴。

在二维的情况 下,BatchNorm是按列算,LayerNorm按行算

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nnclass CustomLayerNorm:def __init__(self, eps=1e-5):self.eps = epsdef __call__(self, x):mean = np.mean(x, axis=-1, keepdims=True)std = np.std(x, axis=-1, keepdims=True)normalized = (x - mean) / (std + self.eps)return normalizedclass CustomBatchNorm:def __init__(self, eps=1e-5):self.eps = epsdef __call__(self, x):mean = np.mean(x, axis=0)std = np.std(x, axis=0)normalized = (x - mean) / (std + self.eps)return normalized# Original Data
data = np.array([[1.0, 2.0, 3.0],[4.0, 5.0, 6.0],[7.0, 8.0, 9.0]])# Apply Custom LayerNorm
custom_layer_norm = CustomLayerNorm()
custom_layer_norm_data = custom_layer_norm(data)# Apply Custom BatchNorm
custom_batch_norm = CustomBatchNorm()
custom_batch_norm_data = custom_batch_norm(data)# Apply PyTorch LayerNorm
data_tensor = torch.tensor(data, dtype=torch.float32)
layer_norm = nn.LayerNorm(data_tensor.size()[1:])
pytorch_layer_norm_data = layer_norm(data_tensor).detach().numpy()# Compare Custom and PyTorch LayerNorm
print("Original Data:\n", data)
print("Custom LayerNorm Data:\n", custom_layer_norm_data)
print("PyTorch LayerNorm Data:\n", pytorch_layer_norm_data)
Original Data:[[1. 2. 3.][4. 5. 6.][7. 8. 9.]]
Custom LayerNorm Data:[[-1.22472987  0.          1.22472987][-1.22472987  0.          1.22472987][-1.22472987  0.          1.22472987]]
PyTorch LayerNorm Data:[[-1.2247356  0.         1.2247356][-1.2247356  0.         1.2247356][-1.2247356  0.         1.2247356]]

举个例子计算 LayerNorm

具体步骤如下:

  1. 计算每行的均值
  • 对每一行,计算其均值。
  • 第1行: mean = (1 + 2 + 3) / 3 = 2
  • 第2行: mean = (4 + 5 + 6) / 3 = 5
  • 第3行: mean = (7 + 8 + 9) / 3 = 8
  1. 计算每行的标准差
  • 对每一行,计算其标准差。
  • 第1行: s t d = s q r t ( ( ( 1 − 2 ) 2 + ( 2 − 2 ) 2 + ( 3 − 2 ) 2 ) / 3 ) = s q r t ( ( 1 + 0 + 1 ) / 3 ) = s q r t ( 2 / 3 ) ≈ 0.8165 std = sqrt(((1-2)^2 + (2-2)^2 + (3-2)^2) / 3) = sqrt((1 + 0 + 1) / 3) = sqrt(2 / 3) ≈ 0.8165 std=sqrt(((12)2+(22)2+(32)2)/3)=sqrt((1+0+1)/3)=sqrt(2/3)0.8165
  • 第2行: s t d = s q r t ( ( ( 4 − 5 ) 2 + ( 5 − 5 ) 2 + ( 6 − 5 ) 2 ) / 3 ) = s q r t ( ( 1 + 0 + 1 ) / 3 ) = s q r t ( 2 / 3 ) ≈ 0.8165 std = sqrt(((4-5)^2 + (5-5)^2 + (6-5)^2) / 3) = sqrt((1 + 0 + 1) / 3) = sqrt(2 / 3) ≈ 0.8165 std=sqrt(((45)2+(55)2+(65)2)/3)=sqrt((1+0+1)/3)=sqrt(2/3)0.8165
  • 第3行: s t d = s q r t ( ( ( 7 − 8 ) 2 + ( 8 − 8 ) 2 + ( 9 − 8 ) 2 ) / 3 ) = s q r t ( ( 1 + 0 + 1 ) / 3 ) = s q r t ( 2 / 3 ) ≈ 0.8165 std = sqrt(((7-8)^2 + (8-8)^2 + (9-8)^2) / 3) = sqrt((1 + 0 + 1) / 3) = sqrt(2 / 3) ≈ 0.8165 std=sqrt(((78)2+(88)2+(98)2)/3)=sqrt((1+0+1)/3)=sqrt(2/3)0.8165
  1. 标准化每一行
  • 对每一行,使用均值和标准差进行标准化。公式为: ( x − m e a n ) / ( s t d + e p s ) (x - mean) / (std + eps) (xmean)/(std+eps)。其中 eps 是一个小常数,防止除零,通常取值为 1e-5。
  • 计算结果如下:

标准化公式: n o r m a l i z e d = ( x − m e a n ) / ( s t d + e p s ) normalized = (x - mean) / (std + eps) normalized=(xmean)/(std+eps)

第1行: 
[(1-2)/(0.8165+1e-5), (2-2)/(0.8165+1e-5), (3-2)/(0.8165+1e-5)]
= [-1.2247, 0, 1.2247]第2行: 
[(4-5)/(0.8165+1e-5), (5-5)/(0.8165+1e-5), (6-5)/(0.8165+1e-5)]
= [-1.2247, 0, 1.2247]第3行: 
[(7-8)/(0.8165+1e-5), (8-8)/(0.8165+1e-5), (9-8)/(0.8165+1e-5)]
= [-1.2247, 0, 1.2247]

最终标准化结果矩阵为:

[[-1.2247, 0, 1.2247][-1.2247, 0, 1.2247][-1.2247, 0, 1.2247]]

RMSNorm 的整个计算过程

Meta Llama 3 使用了RMSNorm
假设我们有以下 2D 输入张量 X X X(为了简单起见,我们假设这个张量有 2 行 3 列):
[ 1 2 3 4 5 6 ] \begin{bmatrix}1 & 2 & 3 \\ 4 & 5 & 6 \end{bmatrix} [142536]
RMSNorm 的计算过程如下:

  1. 计算每行的均方根 (RMS)
    首先,对于每一行,我们计算该行元素的平方和的均值,然后取其平方根。
    对于第 1 行:
    RMS row1 = 1 2 + 2 2 + 3 2 3 = 1 + 4 + 9 3 = 4.67 ≈ 2.16 \text{RMS}_{\text{row1}} = \sqrt{\frac{1^2 + 2^2 + 3^2}{3}} = \sqrt{\frac{1 + 4 + 9}{3}} = \sqrt{4.67} \approx 2.16 RMSrow1=312+22+32 =31+4+9 =4.67 2.16
    对于第 2 行:
    RMS row2 = 4 2 + 5 2 + 6 2 3 = 16 + 25 + 36 3 = 25.67 ≈ 5.07 \text{RMS}_{\text{row2}} = \sqrt{\frac{4^2 + 5^2 + 6^2}{3}} = \sqrt{\frac{16 + 25 + 36}{3}} = \sqrt{25.67} \approx 5.07 RMSrow2=342+52+62 =316+25+36 =25.67 5.07
  2. 使用均方根对输入进行归一化
    将每行的元素除以该行的 RMS 值。这里的 epsilon 用于防止除以零的问题,我们假设 ϵ = 1 e − 6 \epsilon = 1e-6 ϵ=1e6
    对于第 1 行: Normed row1 = [ 1 2.16 + ϵ 2 2.16 + ϵ 3 2.16 + ϵ ] ≈ [ 0.462 0.925 1.387 ] \text{Normed}_{\text{row1}} = \begin{bmatrix} \frac{1}{2.16 + \epsilon} & \frac{2}{2.16 + \epsilon} & \frac{3}{2.16 + \epsilon} \end{bmatrix} \approx \begin{bmatrix} 0.462 & 0.925 & 1.387 \end{bmatrix} Normedrow1=[2.16+ϵ12.16+ϵ22.16+ϵ3][0.4620.9251.387]
    对于第 2 行: Normed row2 = [ 4 5.07 + ϵ 5 5.07 + ϵ 6 5.07 + ϵ ] ≈ [ 0.789 0.986 1.183 ] \text{Normed}_{\text{row2}} = \begin{bmatrix} \frac{4}{5.07 + \epsilon} & \frac{5}{5.07 + \epsilon} & \frac{6}{5.07 + \epsilon} \end{bmatrix} \approx \begin{bmatrix} 0.789 & 0.986 & 1.183 \end{bmatrix} Normedrow2=[5.07+ϵ45.07+ϵ55.07+ϵ6][0.7890.9861.183]
  3. 应用可学习的缩放参数
    假设权重参数 weight \text{weight} weight 为一个向量 [ 1 , 1 , 1 ] [1, 1, 1] [1,1,1],表示每个元素的缩放因子。对于第 1 行: Output row1 = [ 0.462 ⋅ 1 0.925 ⋅ 1 1.387 ⋅ 1 ] = [ 0.462 0.925 1.387 ] \text{Output}_{\text{row1}} = \begin{bmatrix} 0.462 \cdot 1 & 0.925 \cdot 1 & 1.387 \cdot 1 \end{bmatrix} = \begin{bmatrix} 0.462 & 0.925 & 1.387 \end{bmatrix} Outputrow1=[0.46210.92511.3871]=[0.4620.9251.387]对于第 2 行: Output row2 = [ 0.789 ⋅ 1 0.986 ⋅ 1 1.183 ⋅ 1 ] = [ 0.789 0.986 1.183 ] \text{Output}_{\text{row2}} = \begin{bmatrix} 0.789 \cdot 1 & 0.986 \cdot 1 & 1.183 \cdot 1 \end{bmatrix} = \begin{bmatrix} 0.789 & 0.986 & 1.183 \end{bmatrix} Outputrow2=[0.78910.98611.1831]=[0.7890.9861.183]

实际代码实现

以下是使用 PyTorch 实现上述步骤的代码示例:

import torch
import torch.nn as nnclass RMSNorm(nn.Module):def __init__(self, dim: int, eps: float = 1e-6):super().__init__()self.eps = epsself.weight = nn.Parameter(torch.ones(dim))def _norm(self, x):return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)def forward(self, x):output = self._norm(x.float()).type_as(x)return output * self.weight# 示例数据
data = torch.tensor([[1.0, 2.0, 3.0],[4.0, 5.0, 6.0]])# 实例化 RMSNorm 层
rms_norm = RMSNorm(dim=data.size(-1))# 计算归一化后的输出
normalized_data = rms_norm(data)print("Original Data:\n", data)
print("RMSNorm Normalized Data:\n", normalized_data)

结果

运行上述代码后,我们将得到归一化后的数据:

 tensor([[1., 2., 3.],[4., 5., 6.]])
RMSNorm Normalized Data:tensor([[0.4629, 0.9258, 1.3887],[0.7895, 0.9869, 1.1843]], grad_fn=<MulBackward0>)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/24814.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Linux内核epoll

Linux网络IO模型 同步和异步&#xff0c;阻塞和非阻塞 Linux下的五种IO模型 同步和异步&#xff0c;阻塞和非阻塞 Linux 下的五种I/O模型&#xff1a; 阻塞IO&#xff08;Blocking IO&#xff09; BIO 非阻塞IO&#xff08;No Blocking IO&#xff09; IO复用&#xff08;se…

手把手教你实现条纹结构光三维重建(1)——多频条纹生成

关于条纹结构光三维重建的多频相移、格雷码、格雷码相移、互补格雷码等等编码方法&#xff0c;我们在大多数平台上&#xff0c;包括现在使用语言大模型提问&#xff0c;都可以搜到相关的理论&#xff0c;本人重点是想教会你怎么快速用代码实现。 首先说下硬件要求&#xff0c;…

从0到1:企业办公审批小程序开发笔记

可行性分析 企业办公审批小程序&#xff0c;适合各大公司&#xff0c;企业&#xff0c;机关部门办公审批流程&#xff0c;适用于请假审批&#xff0c;报销审批&#xff0c;外出审批&#xff0c;合同审批&#xff0c;采购审批&#xff0c;入职审批&#xff0c;其他审批等规划化…

云计算期末复习(3)

Amazon云计算 习题 私有IP、公有IP和弹性IP的区别在哪里? EC2的实例一旦被创建就会动态地分配公共IP地址和私有IP地址。私有IP地址由动态主机配置协议(DHCP)分配产生。 私有IP、公有IP和弹性IP的主要区别在于它们的使用场景、可达性和管理方式&#xff1a; 私有IP&#xff1a…

46-1 护网溯源 - 钓鱼邮件溯源

一、客户提供钓鱼邮件样本 二、行为分析 三、样本分析 对钓鱼邮件中的木马程序1111.exe文件进行了分析,提交了360安全大脑沙箱云和微步在线云沙箱。 360安全大脑沙箱云显示,该1111.exe文件存在危险,因此在解压时需要谨慎操作,以免触发木马程序。 建议使用360压缩软件进行…

面试(02)————Java集合篇

目录 一、为什么数组索引是从0开始&#xff1f;如果从1开始不行吗&#xff1f; 二、ArrayList底层的实现原理是什么&#xff1f; ​编辑三、ArrayList list new ArrayList(10)中的list扩容几次&#xff1f; 四、如何实现数组与List之间的转换&#xff1f; 五、ArrayList…

Swift 序列(Sequence)排序面面俱到 - 从过去到现在(三)

概述 在上一篇 Swift 序列(Sequence)排序面面俱到 - 从过去到现在(二) 博文中,我们介绍了如何构建一个自定义类型中“多属性”排序的通用实现。 而在本课中我们将再接再厉介绍 iOS 15+ 中新的排序机制,并简要剖析就地排序(In-place sorting)对运行性能有着怎样的显著影…

基础乐理入门

基础概念 乐音&#xff1a;音高&#xff08;频率&#xff09;固定&#xff0c;振动规则的音。钢琴等乐器发出的是乐音&#xff0c;听起来悦耳、柔和。噪音&#xff1a;振动不规则&#xff0c;音高也不明显的音。风声、雨声、机器轰鸣声是噪音&#xff0c;大多数打击乐器&#…

【RK3568】制作Android11开机动画

Android 开机 logo 分为两种&#xff1a;静态显示和动态显示。静态显示就是循环显示一张图片&#xff1b;动态显示就是以特定帧率顺序显示多张图片 1.准备 android logo 图片 Android logo最好是png格式的&#xff0c;因为同一张图片的情况下&#xff0c;png 格式的比 jpg和b…

线性表和链表

一&#xff0c;线性结构 1.Array Array文档&#xff1a;可以自行阅读相关文档来了解Array class array.array(typecode[, initializer]) array.append(x)&#xff1a;添加元素到数组末尾 array.count(x)&#xff1a;计算元素出现次数 array.extend(iterable)&#xff1a;将迭代…

shell编程(二)——字符串与数组

本文为shell 编程的第二篇&#xff0c;介绍shell中的字符串和数组相关内容。 一、字符串 shell 字符串可以用单引号 ‘’&#xff0c;也可以用双引号 “”&#xff0c;也可以不用引号。 单引号的特点 单引号里不识别变量单引号里不能出现单独的单引号&#xff08;使用转义符…

ChatTTS增强版V2,批量导出srt,语速控制,情感控制,支持朗读数字,问题修复

ChatTTS增强版最新版本已经发布&#xff0c;本次更新我主要增加了多文本批量、SRT导出、语速控制、情感控制、停顿控制等新功能&#xff0c;并针对上一版本中存在的数字读音异常、随机uv_break等问题进行了修复。 视频版本 【ChatTTS增强版V2&#xff0c;批量导出srt&#xff…

Android AAudio——C API控制音频流(四)

上一篇文章我们介绍了 C API 中音频流的创建流程,以及打开音频流操作,这里我们再来看一下音频流的其他操作流程 一、音频流操作介绍 1、操作流程图 下图是状态变化流程图,虚线框表示瞬时状态,实线框表示稳定状态。 2、操作函数 上图中主要包含下面几个操作函数: aaudio…

2022 hnust 湖科大 javaweb课设 数据库课设 报告+源代码+流程图文件+课设指导书+附赠数据库课堂实验指导书

2022 hnust 湖科大 javaweb课设 数据库课设 报告源代码流程图文件课设指导书附赠数据库课堂实验指导书 描述 湖南科技大学大二下学期先后开展java web和数据库课程设计&#xff0c;两个课设项目可以通用&#xff0c;老师一般会允许自拟选题&#xff0c;所以在此统一打包&…

批量高效调整图片像素:自定义缩小bmp图片,画质优先,一键实现高效优化

图片已经成为我们生活中不可或缺的一部分。无论是社交媒体分享&#xff0c;还是工作文件传输&#xff0c;图片总是扮演着重要的角色。然而&#xff0c;有时候&#xff0c;我们可能会面临一个问题&#xff1a;图片像素过大&#xff0c;不仅占用过多的存储空间&#xff0c;还可能…

Linux编译器-gcc或g++的使用

一.安装gcc/g 在linux中是不会自带gcc/g的&#xff0c;我们需要编译程序就自己需要安装gcc/g。 很简单我们使用简单的命令安装gcc&#xff1a;sudo yum install -y gcc。 g安装&#xff1a;sudo yum install -y gcc-c。 我们知道Windows上区分文件&#xff0c;都是使用文件…

如何使用Python的Turtle模块绘制小猪

一、前置条件 在开始学习如何使用Python的Turtle模块进行绘画之前&#xff0c;请确保你的电脑已安装Python环境。如果尚未安装Python&#xff0c;你可以从Python官网下载并安装最新版本。 Turtle模块是Python内置的一个用于绘图的库&#xff0c;通常不需要额外安装。如果你发…

反转链表 (oj题)

一、题目链接 https://leetcode.cn/problems/reverse-linked-list/submissions/538124207 二、题目思路 1.定义三个指针&#xff0c;p1先指向NULL p2指向头结点 p3指向第二个结点 2.p2的next指向p1。然后移动指针&#xff0c;p1来到p2的位置&#xff0c;p2来到p3的位置&…

中缀表达式和前缀后缀

在中缀表达式中&#xff0c;操作数可能与两个操作符相结合 但是&#xff0c;想要不带括号无歧义&#xff0c;且不需要考虑运算符优先级和结合性 所以考虑 前缀表达式&#xff0c;波兰表达式 后缀表达式 逆波兰表达式 对于人来说&#xff0c;中缀表达式是最容易读懂的。但是对于…

基于JSP技术的网络视频播放器

你好呀&#xff0c;我是计算机学长猫哥&#xff01;如果有相关需求&#xff0c;文末可以找到我的联系方式。 开发语言&#xff1a;Java 数据库&#xff1a;MySQL 技术&#xff1a;JSP技术 工具&#xff1a;IDEA/Eclipse、Navicat、Maven 系统展示 首页 管理员界面 用户界…