[NLP] BERT模型参数量

一 BERT_Base 110M参数拆解

BERT_base模型的110M的参数具体是如何组成的呢,我们一起来计算一下:

刚好也能更深入地了解一下Transformer Encoder模型的架构细节。

借助transformers模块查看一下模型的架构:

import torch
from transformers import BertTokenizer, BertModelbertModel = BertModel.from_pretrained('bert-base-uncased', output_hidden_states=True, output_attentions=True)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
for name,param in bertModel.named_parameters():print(name, param.shape)

得到的模型参数为:

embeddings.word_embeddings.weight torch.Size([30522, 768])
embeddings.position_embeddings.weight torch.Size([512, 768])
embeddings.token_type_embeddings.weight torch.Size([2, 768])
embeddings.LayerNorm.weight torch.Size([768])
embeddings.LayerNorm.bias torch.Size([768])encoder.layer.0.attention.self.query.weight torch.Size([768, 768])
encoder.layer.0.attention.self.query.bias torch.Size([768])
encoder.layer.0.attention.self.key.weight torch.Size([768, 768])
encoder.layer.0.attention.self.key.bias torch.Size([768])
encoder.layer.0.attention.self.value.weight torch.Size([768, 768])
encoder.layer.0.attention.self.value.bias torch.Size([768])encoder.layer.0.attention.output.dense.weight torch.Size([768, 768])
encoder.layer.0.attention.output.dense.bias torch.Size([768])
encoder.layer.0.attention.output.LayerNorm.weight torch.Size([768])
encoder.layer.0.attention.output.LayerNorm.bias torch.Size([768])encoder.layer.0.intermediate.dense.weight torch.Size([3072, 768])
encoder.layer.0.intermediate.dense.bias torch.Size([3072])
encoder.layer.0.output.dense.weight torch.Size([768, 3072])
encoder.layer.0.output.dense.bias torch.Size([768])
encoder.layer.0.output.LayerNorm.weight torch.Size([768])
encoder.layer.0.output.LayerNorm.bias torch.Size([768])encoder.layer.11.attention.self.query.weight torch.Size([768, 768])
encoder.layer.11.attention.self.query.bias torch.Size([768])
encoder.layer.11.attention.self.key.weight torch.Size([768, 768])
encoder.layer.11.attention.self.key.bias torch.Size([768])
encoder.layer.11.attention.self.value.weight torch.Size([768, 768])
encoder.layer.11.attention.self.value.bias torch.Size([768])
encoder.layer.11.attention.output.dense.weight torch.Size([768, 768])
encoder.layer.11.attention.output.dense.bias torch.Size([768])
encoder.layer.11.attention.output.LayerNorm.weight torch.Size([768])
encoder.layer.11.attention.output.LayerNorm.bias torch.Size([768])
encoder.layer.11.intermediate.dense.weight torch.Size([3072, 768])
encoder.layer.11.intermediate.dense.bias torch.Size([3072])
encoder.layer.11.output.dense.weight torch.Size([768, 3072])
encoder.layer.11.output.dense.bias torch.Size([768])
encoder.layer.11.output.LayerNorm.weight torch.Size([768])
encoder.layer.11.output.LayerNorm.bias torch.Size([768])pooler.dense.weight torch.Size([768, 768])
pooler.dense.bias torch.Size([768])

其中,BERT模型的参数主要由三部分组成:

Embedding层参数

Transformer Encoder层参数

LayerNorm层参数

二 Embedding层参数

由于词向量是由Token embedding,Position embedding,Segment embedding三部分构成的,因此embedding层的参数也包括以上三部分的参数。

BERT_base英文词表大小为:30522, 隐藏层hidden_size=768,文本最大长度seq_len = 512

Token embedding参数量为:30522 * 768;

Position embedding参数量为:512 * 768;

Segment embedding参数量为:2 * 768。

因此总的参数量为:(30522 + 512 +2)* 768 = 23,835,648

 

LN层在Embedding层

norm使用的是layer normalization,每个维度有两个参数

768 * 2 = 1536

三 Transformer Encoder层参数

可以将该部分拆解成两部分:Self-attention层参数、Feed-Forward Network层参数

1.Self-attention层参数

改层主要是由Q、K、V三个矩阵运算组成,BERT模型中是Multi-head多头的Self-attention(记为SA)机制。先通过Q和K矩阵运算并通过softmax变换得到对应的权重矩阵,然后将权重矩阵与 V矩阵相乘,最后将12个头得到的结果进行concat,得到最终的SA层输出。

1. multi-head因为分成12份, 单个head的参数是 768 * (768/12) * 3,  紧接着将多个head进行concat再进行变换,此时W的大小是768 * 768

    12个head就是  768 * (768/12) * 3 * 12  + 768 * 768 = 1,769,472 + 589,824 = 2359296

3. LN层在Self-attention层

norm使用的是layer normalization,每个维度有两个参数

768 * 2 = 1536

2.Feed-Forward Network层参数

由FFN(x)=max(0, xW1+b1)W2+b2可知,前馈网络FFN主要由两个全连接层组成,且W1和W2的形状分别是(768,3072),(3072,768),因此该层的参数量为:

feed forward的参数主要由两个全连接层组成,intermediate_size为3072(原文中4H长度) ,那么参数为12*(768*3072+3072*768)= 56623104

LN层在FFN

norm使用的是layer normalization,每个维度有两个参数

768 * 2 = 1536

layer normalization

layer normalization有两个参数,分别是gamma和beta。有三个地方用到了layer normalization,分别是embedding层后、multi-head attention后、feed forward后,这三部分的参数为768*2+12*(768*2+768*2)=38400

四 总结

综上,BERT模型的参数总量为:

23835648 + 12*2359296(28311552)   + 56623104 +  38400  = 108808704  ≈103.7M

Embedding层约占参数总量的20%,Transformer层约占参数总量的80%。

注:本文介绍的参数仅是BERT模型的Transformer Encoder部分的参数,涉及的bias由于参数很少,本文也未计入。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/45085.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Linux 线程库中的接口介绍

1.pthread_create()创建线程 pthread_create()的语法形式: 参数解释: 第一个参数thread:事先创建好的pthread_t类型的参数。成功时thread指向的内存单元被设置为新创建线程的线程ID。 第二个参数attr:用于定制各种不同的线程属性…

SQL Monitor Crack,PostgreSQL监控的传入复制图表

SQL Monitor Crack,PostgreSQL监控的传入复制图表  现在,您可以在从Estate页面导出的Microsoft Excel报告的摘要标题中看到UTC偏移量。 添加了PostgreSQL监控的传入复制图表。 Microsoft PowerShell API现在支持将使用New-SqlMonitorWindowsHost和New-SqlMonitorin…

【AI大模型】训练Al大模型

大模型超越AI 前言 洁洁的个人主页 我就问你有没有发挥! 知行合一,志存高远。 目前所指的大模型,是“大规模深度学习模型”的简称,指具有大量参数和复杂结构的机器学习模型,可以处理大规模的数据和复杂的问题&#x…

MybatisPlus整合p6spy组件SQL分析

目录 p6spy java为什么需要 如何使用 其他配置 p6spy p6spy是一个开源项目,通常使用它来跟踪数据库操作,查看程序运行过程中执行的sql语句。 p6spy将应用的数据源给劫持了,应用操作数据库其实在调用p6spy的数据源,p6spy劫持到…

uniapp配置添加阿里巴巴图标icon流程步骤

文章目录 下载复制文件到项目文件夹里项目配置目录结构显示图标 下载 阿里巴巴icon官网 https://www.iconfont.cn/ 复制文件到项目文件夹里 项目配置目录结构 显示图标

华为将收取蜂窝物联网专利费,或将影响LPWAN市场发展

近日,华为正式公布了其4G和5G手机、Wi-Fi6设备和物联网产品的专利许可费率,其中包含了长距离通信技术蜂窝物联网。作为蜂窝物联网技术的先驱,华为是LTE Category NB (NB-IoT)、LTE Category M和其他4G物联网标准的主要贡献者。 在NB-IoT领域…

基于traccar快捷搭建gps轨迹应用

0. 环境 - win10 虚拟机ubuntu18 - i5 ubuntu22笔记本 - USB-GPS模块一台,比如华大北斗TAU1312-232板 - 双笔记本组网设备:路由器,使得win10笔记本ip:192.168.123.x,而i5笔记本IP是192.168.123.215。 - 安卓 手机 1.…

React2023电商项目实战 - 1.项目搭建

古人学问无遗力,少壮工夫老始成。 纸上得来终觉浅,绝知此事要躬行。 —— 陆游《《冬夜读书示子聿》》 系列文章目录 项目搭建App登录及网关App文章自媒体平台(博主后台)内容审核(自动) 文章目录 系列文章目录一、项目介绍1.页面…

ubuntu安装Microsoft Edge并设置为中文

1、下载 edge.deb 版本并安装 sudo dpkg -i microsoft-edg.deb 2. 设置默认中文显示 如果是通过.deb方式安装的: 打开默认安装路径下的microsoft-edge-dev文件,在文件最开头加上: export LANGUAGEZH-CN.UTF-8 ,保存退出。 cd /opt/micr…

python绘制谷歌地图

谷歌地图 更多好看的图片见pyecharts官网 import pyecharts.options as opts from pyecharts.charts import MapGlobe from pyecharts.faker import POPULATIONdata [x for _, x in POPULATION[1:]] low, high min(data), max(data)c (MapGlobe().add_schema().add(mapty…

【100天精通python】Day42:python网络爬虫开发_HTTP请求库requests 常用语法与实战

目录 1 HTTP协议 2 HTTP与HTTPS 3 HTTP请求过程 3.1 HTTP请求过程 3.2 GET请求与POST请求 3.3 常用请求报头 3.4 HTTP响应 4 HTTP请求库requests 常用语法 4.1 发送GET请求 4.2 发送POST请求 4.3 请求参数和头部 4.4 编码格式 4.5 requests高级操作-文件上传 4.6 …

Spring Boot 统一功能处理

目录 1.用户登录权限效验 1.1 Spring AOP 用户统一登录验证的问题 1.2 Spring 拦截器 1.2.1 自定义拦截器 1.2.2 将自定义拦截器加入到系统配置 1.3 拦截器实现原理 1.3.1 实现原理源码分析 2. 统一异常处理 2.1 创建一个异常处理类 2.2 创建异常检测的类和处理业务方法 3. 统一…

【Spring系列篇--关于IOC的详解】

目录 面试经典题目: 1. 什么是spring?你对Spring的理解?简单来说,Spring是一个轻量级的控制反转(IoC)和面向切面(AOP)的容器框架。 2.什么是IoC?你对IoC的理解?IoC的重要性?将实例化对象的权利从程序员…

Centos 8 网卡connect: Network is unreachable错误解决办法

现象1、ifconfig没有ens160配置 [testlocalhost ~]$ ifconfig lo: flags73<UP,LOOPBACK,RUNNING> mtu 65536 inet 127.0.0.1 netmask 255.0.0.0 inet6 ::1 prefixlen 128 scopeid 0x10<host> loop txqueuelen 1000 (Local Loopba…

基于深度学习的指针式仪表倾斜校正方法——论文解读

中文论文题目:基于深度学习的指针式仪表倾斜校正方法 英文论文题目&#xff1a;Tilt Correction Method of Pointer Meter Based on Deep Learning 周登科、杨颖、朱杰、王库.基于深度学习的指针式仪表倾斜校正方法[J].计算机辅助设计与图形学学报, 2020, 32(12):9.DOI:10.3724…

【Java】智慧工地SaaS平台源码:AI/云计算/物联网/智慧监管

智慧工地是指运用信息化手段&#xff0c;围绕施工过程管理&#xff0c;建立互联协同、智能生产、科学管理的施工项目信息化生态圈&#xff0c;并将此数据在虚拟现实环境下与物联网采集到的工程信息进行数据挖掘分析&#xff0c;提供过程趋势预测及专家预案&#xff0c;实现工程…

《强化学习:原理与Python实战》——可曾听闻RLHF

前言&#xff1a; RLHF&#xff08;Reinforcement Learning with Human Feedback&#xff0c;人类反馈强化学习&#xff09;是一种基于强化学习的算法&#xff0c;通过结合人类专家的知识和经验来优化智能体的学习效果。它不仅考虑智能体的行为奖励&#xff0c;还融合了人类专家…

kafka安装说明以及在项目中使用

一、window 安装 1.1、下载安装包 下载kafka 地址&#xff0c;其中官方版内置zk&#xff0c; kafka_2.12-3.4.0.tgz其中这个名称的意思是 kafka3.4.0 版本 &#xff0c;所用语言 scala 版本为 2.12 1.2、安装配置 1、解压刚刚下载的配置文件&#xff0c;解压后如下&#x…

【机器学习】处理不平衡的数据集

一、介绍 假设您在一家给定的公司工作&#xff0c;并要求您创建一个模型&#xff0c;该模型根据您可以使用的各种测量来预测产品是否有缺陷。您决定使用自己喜欢的分类器&#xff0c;根据数据对其进行训练&#xff0c;瞧&#xff1a;您将获得96.2%的准确率&#xff01; …

Integer中缓存池讲解

文章目录 一、简介二、实现原理三、修改缓存范围 一、简介 Integer缓存池是一种优化技术&#xff0c;用于提高整数对象的重用和性能。在Java中&#xff0c;对于整数值在 -128 到 127 之间的整数对象&#xff0c;会被放入缓存池中&#xff0c;以便重复使用。这是因为在这个范围…