GPT 系列笔记

open ai 出品, 与 google 的 bert 系列 是不同的任务, NLGeneration vs. NLUnderstanding.

二. GPT-2

hugging-face 的 transformers 库中有模型源码, 为 from transformers.models.gpt2 import GPT2Model.

成员字段

  • wte, Word Token Emb
  • wpe, Word Position Emb
  • h, 模型主体. nn.ModuleList([GPT2Block(config, layer_idx=i) for i in range(config.num_hidden_layers)])

Conv1D

并非 torch cv 中的 Conv1D, 而是 gpt 系列自己的.

class Conv1D(nn.Module):"""1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2).Basically works like a linear layer but the weights are transposed.Args:nf (`int`): The number of output features.nx (`int`): The number of input features."""def __init__(self, nf, nx):super().__init__()self.nf = nfw = torch.empty(nx, nf)nn.init.normal_(w, std=0.02)self.weight = nn.Parameter(w)self.bias = nn.Parameter(torch.zeros(nf))def forward(self, x):size_out = x.size()[:-1] + (self.nf,)x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)x = x.view(size_out)return x

addmm() 方法是方程 beta*mat + alpha*(mat1 @ mat2) 的优化版本, 其实就是 kx+b.

GPT2Attention

class GPT2Attention(nn.Module):def __init__(self, config, is_cross_attention=False, layer_idx=None):self.embed_dim = config.hidden_sizeself.num_heads = config.num_attention_headsself.head_dim = self.embed_dim // self.num_headsself.split_size = self.embed_dimself.is_cross_attention = is_cross_attentionself.c_proj = Conv1D(self.embed_dim, self.embed_dim)def _attn(self, query, key, value, attention_mask=None, head_mask=None):attn_weights = torch.matmul(query, key.transpose(-1, -2))attn_weights = nn.functional.softmax(attn_weights, dim=-1)attn_output = torch.matmul(attn_weights, value)return attn_output, attn_weightsdef _split_heads(self, tensor, num_heads, attn_head_size):"""Splits hidden_size dim into attn_head_size and num_heads"""new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)tensor = tensor.view(new_shape)return tensor.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)def forward(...):query = self._split_heads(query, self.num_heads, self.head_dim)key = self._split_heads(key, self.num_heads, self.head_dim)value = self._split_heads(value, self.num_heads, self.head_dim)attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)attn_output = self.c_proj(attn_output)outputs = (attn_output, present)return outputs

因为 attention 是 multi_head 的, 所以有四组 qkv 的权重. 但为了计算性能, 用到了 split_heads 等方法, 但可读性急剧下降.

GPT2MLP

class GPT2MLP(nn.Module):def __init__(self, intermediate_size, config):super().__init__()embed_dim = config.hidden_sizeself.c_fc = Conv1D(intermediate_size, embed_dim)self.c_proj = Conv1D(embed_dim, intermediate_size)self.act = ACT2FN[config.activation_function]self.dropout = nn.Dropout(config.resid_pdrop)def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:hidden_states = self.c_fc(hidden_states)hidden_states = self.act(hidden_states)hidden_states = self.c_proj(hidden_states)hidden_states = self.dropout(hidden_states)return hidden_states

GPT2Block

class GPT2Block(nn.Module):def __init__(self, config, layer_idx=None):super().__init__()hidden_size = config.hidden_sizeinner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_sizeself.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)self.attn = GPT2Attention(config, layer_idx=layer_idx)self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)if config.add_cross_attention:self.crossattention = GPT2Attention(config, is_cross_attention=True, layer_idx=layer_idx)self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)self.mlp = GPT2MLP(inner_dim, config)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/61668.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

运维数据(2):全新解说运维的价值和场景

在一个项目中,得知客户高层一直在纠结运维的价值,质疑运维投入的合理性和必要性,要求澄清其具体价值及其与业务的关系。同时,当今IT界流行“场景”一词,不论是解决方案,还是具体产品都在拉关系,…

亚马逊云科技 re:Inforce 大会云安全合规与技术实践及 Security Jam 大赛,快来报名吧!...

‍‍ 2023年8月31日在北京 亚马逊云科技 re:Inforce 大会 首次登陆中国! 我们期待您的莅临, 并与您一起迎接 AI 时代, 开启全面智能的安全旅程! 在13:00-17:00的 培训与动手实验环节中 云安全合规与技术实践 及 Security Jam 大赛…

Ansible学习笔记14

实现多台的分离实现: [rootlocalhost playbook]# cat example3.yaml --- - hosts: 192.168.17.105remote_user: roottasks:- name: create test1 directoryfile: path/test1/ statedirectory- hosts: 192.168.17.106remote_user: roottasks:- name: create test2 d…

Jenkins测试报告样式优化

方式一:修改Content Security Policy(临时解决,Jenkins重启后失效) 1、jenkins首页—>ManageJenkins—>Tools and Actions标题下—>Script Console 2、粘贴脚本输入框中:System.setProperty("hudson.model.Directo…

Ubuntu学习---跟着绍发学linux课程记录(第一部分)

文章目录 1、启动、关闭、挂起、恢复(电源)2、更多虚拟机操作2.1 电源设置2.2 硬件参数设置2.3 状态栏2.4 全屏显示 3、快照与系统恢复4、桌面环境5、文件系统6、用户目录7、创建目录和文件8、命令行:文件列表ls 9、命令行:切换目…

3.使用IDE的优点

IDE是集成开发环境:Integrated Development Environment的缩写。 1、优点 使用IDE的好处在于,可以把编写代码、组织项目、编译、运行、调试等放到一个环境中运行,能极大地提高开发效率。 IDE提升开发效率主要靠以下几点: 编辑器…

Docker容器:docker consul的注册与发现及consul-template

Docker容器:docker consul的注册与发现及consul-template守护进程 一.docker consul的注册与发现介绍 1.什么是服务注册与发现 (1)服务注册与发现是微服务架构中不可或缺的重要组件。 (2)为解决服务都是单节点的&a…

XSS结合CSRF

假设我们获得了目标CMS的源码&#xff0c;搭建了一个相同的网站&#xff0c;我们在自己的网站执行添加用户的操作&#xff0c;并且用bp抓包 如图&#xff0c;这是我们抓到的添加用户的数据包 接下来&#xff0c;我们可以根据数据包构造js代码 <script> xmlhttp new XML…

PPPoE连接无法建立的排查和修复

嗨&#xff0c;亲爱的读者朋友们&#xff01;你是否曾经遇到过PPPoE连接无法建立的问题&#xff1f;今天我将为你详细解析排查和修复这个问题的步骤。 检查物理连接 首先&#xff0c;我们需要确保物理连接没有问题。请按照以下步骤进行检查&#xff1a; - 检查网线是否插好&…

LeetCode--HOT100题(46)

目录 题目描述&#xff1a;114. 二叉树展开为链表&#xff08;中等&#xff09;题目接口解题思路代码 PS: 题目描述&#xff1a;114. 二叉树展开为链表&#xff08;中等&#xff09; 给你二叉树的根结点 root &#xff0c;请你将它展开为一个单链表&#xff1a; 展开后的单链…

vue自定义键盘

<template><div class"mark" click"isOver"></div><div class"mycar"><div class"mycar_list"><div class"mycar_list_con"><p class"mycar_list_p">车牌号</p>…

云化背景下的接口测试覆盖率自动化检查

一、问题来源 在云化场景下&#xff0c;API的测试覆盖是一项重要评估与考察指标。除了开发者自测试外&#xff08;UT&#xff09;&#xff0c;还可以利用云化测试平台、流水线等方法进行相关指标的检查与考核。利用这种方法既可以减轻开发者测试工作量&#xff0c;不必在本地做…

一、Mycat2介绍与下载安装

第一章 入门概述 1.1 是什么 Mycat 是数据库中间件。 1、数据库中间件 中间件&#xff1a;是一类连接软件组件和应用的计算机软件&#xff0c;以便于软件各部件之间的沟 通。 例子&#xff1a;Tomcat&#xff0c;web中间件。 数据库中间件&#xff1a;连接java应用程序和数据库…

CSS学习笔记02

CSS笔记02 美化网页元素 为什么要美化网页 目的&#xff1a; 有效的传递页面信息美化网页、页面漂亮、才能吸引用户突显页面的主题提高用户的体验 span标签 span标签是短语内容的通用行内容器&#xff0c;它本身并没有任何特殊语义。 通常我们使用span标签来把我们想要重…

【Nginx】负载均衡当其中一台服务器宕机之后

搭建一个简单的负载均衡&#xff0c;然后关闭其中一台再来访问&#xff0c;会发现我们的浏览器卡住一直转圈圈&#xff0c;过了很久才会显示结果。由此我们可以得出结论Nginx负载的时候如果其中一台服务挂掉了&#xff0c;它会把请求转发到另一个可以提供服务的机器&#xff0c…

【MongoDB系列】3. MongoDB 安全策略:验证和授权

前言 前面文章中通过客户端工具&#xff08;MongoDB Shell、Robo 3T&#xff09;连接 MongoDB 服务时&#xff0c;只要有 IP 地址和端口号&#xff0c;就能连接到数据库&#xff0c;之后就能操作数据库。这是因为默认安装的 MongoDB 没有启用身份验证&#xff0c;也没有设置初…

【webpack】HMR热更新原理

本文&#xff1a;参考文章 一、HMR是什么&#xff0c;为什么出现 1、出现的原因 之前&#xff0c;应用的加载、更新都是一个页面级别的操作&#xff0c;即使单个代码文件更新&#xff0c;整个页面都要刷新&#xff0c;才能拿到最新的代码同步到浏览器&#xff0c;导致会丢失…

创作纪念日-我的第1024天

机缘 不知不觉已经成为创作者的第1024天啦… … 刚开始接触博客的初衷就是为了记笔记&#x1f4d2;、记总结&#x1f4dd;&#xff0c;或许对于当时就等同于是为了找工作。坚持学习并持续输出博客一年后&#xff0c;这时我发现再写博客&#xff0c;不在是为了找一份工作&…

python-华为云modelarts的免费codelab运行chatglm2-6b-int4

前提&#xff1a;当前提供 了8核64G的免费体验规格&#xff0c;每天三个小时限额 地址&#xff1a;https://console.huaweicloud.com/modelarts/?regioncn-north-4#/dashboard 下载模型&#xff1a;请参考另一个文章 创建环境&#xff08;自带环境是pytorch1.8的&#xff0c;…

大数据精准营销怎么满足用户的个性化需求?

近年来在AI和媒体的带动下&#xff0c;大数据分析不断介入&#xff0c;各行各业都开始陆续依仗大数据营销这棵大树&#xff0c;以此来更加高效、便捷、智能、精准的服务于用户。 这就像追求恋人一样&#xff0c;投其所好方能成为眷属。 大数据精准营销的好处&#xff1a; 相…