GPT(Generative Pre-Training)论文解读及源码实现(二)

本篇为gpt2的pytorch实现,参考 nanoGPT

nanoGPT如何使用见后面第5节

1 数据准备及预处理

data/shakespeare/prepare.py 文件源码分析

1.1 数据划分

下载数据后90%作为训练集,10%作为验证集

with open(input_file_path, 'r') as f:data = f.read()
n = len(data)
train_data = data[:int(n*0.9)]
val_data = data[int(n*0.9):]

1.2 数据编码

使用tiktoken包进行gpt2编码,gpt2默认编码方式为 bpe

enc = tiktoken.get_encoding("gpt2")
train_ids = enc.encode_ordinary(train_data[:100])
val_ids = enc.encode_ordinary(val_data[:100])
print(f"train has {len(train_ids):,} tokens")
print(f"val has {len(val_ids):,} tokens")>>>train has 31 tokensval has 40 tokens

如上取了train_data100个字符,编码后为31个tokens, 100个val 编码后为40个token. 可以通过enc.decode(train_ids) 还原为原始文本数据

train_ids 输出形式为:[5962, 22307, 25, 198, 8421, 356, 5120, 597, 2252, 11, 3285, 502, 2740, 13, 198, 198, 3237, 25, 198, 5248, 461, 11, 2740, 13, 198, 198, 5962, 22307, 25, 198, 1639]

2 训练数据加工

构造训练数据X,Y,其中target数据Y为X平移一位生成,每次取batch_size个数据

    data = train_data if split == 'train' else val_data # data: 301966ix = torch.randint(len(data) - block_size, (batch_size,))x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix])

3 模型训练

3.1 GPT模型结构

3.1.1 embedding层

token embedding 和位置embedding
(batch_size 取4,句子长度取8,则输入x shape =[4,8])
embedding后维度,如下图所示

  • token embedding shape=[4,8,128]
  • 位置embeddign shape=[8,128]
wte = nn.Embedding(config.vocab_size, config.n_embd),
wpe = nn.Embedding(config.block_size, config.n_embd),

其中位置信息根据句子长度生成

pos = torch.arange(0, x.size(1), dtype=torch.long, device=device)

在这里插入图片描述

3.1.2 attention 层(带因果推断的attention,即需要上三角maske)

输入x shape=[4,8,128],
通过线性层后 shape: q=k=v=[4,8,128]
将embedding维度进行多头划分后,shape =[4,4,8,32]
(torch2 支持因果attention )

在这里插入图片描述

重点:attention 中mask实现,
即给上三角矩阵填充负无穷大数(负无穷在softmax时,值为0,即权重为0)

 L, S = query.size(-2), key.size(-2)scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scaleattn_bias = torch.zeros(L, S, dtype=query.dtype)if is_causal:assert attn_mask is Nonetemp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))attn_bias.to(query.dtype)

在这里插入图片描述

3.1.3 block层

  • gpt2 会有n_layer个block层,每个block层由layer normal层,attention层,mlp层构成(具体可以参考transformer)

  • block层,由attention层和全连接层组成,
    输入x shape为 [4,8,128]
    输出attention shape为 [4,8,128]
    输出MLP shape为 [4,8,128]

class Block(nn.Module):def __init__(self, config):super().__init__()self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)self.attn = CausalSelfAttention(config)self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)self.mlp = MLP(config)def forward(self, x):x = x + self.attn(self.ln_1(x)) # shape [4,8,128]x = x + self.mlp(self.ln_2(x)) # shape [4,8,128]return x

3.2 损失函数

输入x shape [4,8,128]
输出 logits shape [4,8, 50304],即词典中每个单词的得分
loss为交叉熵损失,为一个标量,

 logits = self.lm_head(x) #; logits shape: [4,8, 50304]loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1) # targets 即为之前训练数据的 Y数据

4 模型推理

4.1 模型加载

加载训练时保存的模型
在这里插入图片描述

4.2 定义数据处理的编解码器

数据编解码器与训练时一致

    enc = tiktoken.get_encoding("gpt2")encode = lambda s: enc.encode(s, allowed_special={"<|endoftext|>"})decode = lambda l: enc.decode(l)

4.3 数据生成(重点)

  • 第一次输入只有一个字符
    idx_cond=idx: shape =[1,1]
    logits shape =[1,50304]
    从topK个中安概率随机取一个
    和上面的idx拼接,作为第二次的输入
  • 第二次输入
    idx_cond=idx: shape =[1,2]
    logits shape =[1,50304]
    然后再从topk中安概率随机取一个进行拼接

    -直到达到最大输出活着终止字符
    @torch.no_grad()def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):"""Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and completethe sequence max_new_tokens times, feeding the predictions back into the model each time.Most likely you'll want to make sure to be in model.eval() mode of operation for this."""for _ in range(max_new_tokens):# if the sequence context is growing too long we must crop it at block_sizeidx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]# forward the model to get the logits for the index in the sequencelogits, _ = self(idx_cond)# pluck the logits at the final step and scale by desired temperaturelogits = logits[:, -1, :] / temperature# optionally crop the logits to only the top k optionsif top_k is not None:v, _ = torch.topk(logits, min(top_k, logits.size(-1)))logits[logits < v[:, [-1]]] = -float('Inf')# apply softmax to convert logits to (normalized) probabilitiesprobs = F.softmax(logits, dim=-1)# sample from the distributionidx_next = torch.multinomial(probs, num_samples=1)# append sampled index to the running sequence and continueidx = torch.cat((idx, idx_next), dim=1)return idx

5 GPT2 使用

5.1 下载git源码

  • git clone https://github.com/karpathy/nanoGPT.git
  • 安装依赖包(建议安装torch2以上版本,其他包不限制版本)
pip install torch numpy transformers datasets tiktoken wandb tqdm

(mac pytorch 安装: pip3 install --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cpu

5.1 数据下载

测试下载数据及,并编码成数字格式

python data/shakespeare/prepare.py

5.2 模型训练

参数解释:

  • device :使用的GPU 类型,可以是cuda ,cpu , mps
  • compile 是否使用编译优化,torch2版本支持(mac mps 不支持)
  • eval_iters 迭代次数
  • block_size:训练句子长度(演示最大句子长度只取了8)
  • batch_size : batch size
  • n_layer: 使用多少个transformer block
  • n_head: attention 头数
  • n_embd: embedding 维度
  • dropout:dropout 比例
config/train_shakespeare_char.py --device=mps --compile=False --eval_iters=20 --log_interval=1 --block_size=8 --batch_size=4 --n_layer=4 --n_head=4 --n_embd=128 --max_iters=2000 --lr_decay_iters=2000 --dropout=0.0

(备注: 我使用的是shakespeare数据集,因此将配置文件train_shakespeare_char.py 进行了修改 wandb_project = ‘shakespeare’ ;dataset = ‘shakespeare’)

5.3 模型推理

python sample.py --out_dir=out-shakespeare

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/603763.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

MySQL数据库主从复制和读写分离

目录 一、MySQL主从复制和读写分离理论 &#xff08;一&#xff09;读写分离 1.什么是读写分离 2.为什么要读写分离 3.什么时候要读写分离 4.读写分离原理 5.常见MySQL 读写分离 &#xff08;1&#xff09;基于程序代码内部实现 &#xff08;2&#xff09;基于中间代理…

react-hooks-kit v1 正式发布

evanpatchouli/react-hooks-kit - (npmjs.com) v1.0.0 正式发布&#xff01; 下载安装 npm i evanpatchouli/react-hooks-it -S官方文档 在 Gitee 阅读在 Github 阅读 概览 这是一个无依赖的轻量级 React Hooks 库&#xff0c;总共有 60 hooks。 它包含了一系列易于使用…

持续积累ThreadLocal技术【ThreadLocal原理 + ThreadLocal的坑 + ThreadLocal的最佳实践】

持续积累ThreadLocal技术的目录 一、先从使用ThreadLocal开始1、我看到的两种创建方式1.1 ThreadLocal<A> aThreadLocal new ThreadLocal<>();1.2 ThreadLocal<A> aThreadLocal ThreadLocal.withInitial(...)1.3 为啥需要1.2提到的创建方式&#xff1f;直接…

k8s的pod基础

pod概念 pod是k8s中最小的资源管理组件。 pod也是最小化运行容器化的应用的资源管理对象。 pod是一个抽象的概念&#xff0c;可以理解为一个或者多个容器化应用的集合。 在一个pod当中运行一个容器是最常用的方式。在一个pod当中同时运行多个容器&#xff0c;在一个pod当中…

算法练习Day29 (Leetcode/Python-动态规划)

基本概念&#xff1a; 代码随想录&#xff1a; Dynamic Programming&#xff0c;简称DP&#xff0c;如果某一问题有很多重叠子问题&#xff0c;使用动态规划是最有效的。 所以动态规划中每一个状态一定是由上一个状态推导出来的&#xff0c;这一点就区分于贪心&#xff0c;贪…

计算机网络 综合(习题)

【计算机网络习题】系列文章目录 计算机网络 第一章 绪论(习题) 计算机网络 第二章 计算机网络体系结构(习题) 计算机网络 第三章 应用层(习题) 计算机网络 第四章 运输层(习题) 计算机网络 第五章 网络层(习题) 计算机网络 第六章 数据链路层(习题) 计算机网络 第七章 物…

强化学习5——动态规划在强化学习中的应用

动态规划在强化学习中的应用 基于动态规划的算法优良 &#xff1a;策略迭代和价值迭代。 策略迭代分为策略评估和策略提升&#xff0c;使用贝尔曼期望方程得到一个策略的状态价值函数&#xff1b;价值迭代直接使用贝尔曼最优方程进行动态规划&#xff0c;得到最终的最优状态价…

Unity 一文掌握使用AddListener方法为组件事件添加监听器的方法

在Unity中&#xff0c;很多组件都带有事件&#xff0c;比如: Button组件&#xff1a;onClick() Toggle组件&#xff1a;On Value Changed(Boolean) Dropdown组件&#xff1a;On Value Changed(Int32) InputField组件&#xff1a;On Value Changed(String)、On End Edit(Stri…

CCC数字钥匙设计【NFC】--NFC通信之APDU TLV

CCC3.0&#xff0c;包含NFC、BLE、UWB技术。当采用NFC通信时&#xff0c;车端与手机端是通过APDU来进行交互的。而在APDU中的data数据段&#xff0c;又可能会嵌入TLV协议的数据&#xff0c;以完成车端与手机端的通信交互。 本文先介绍APDU及TLV的一些基础知识&#xff0c;再通…

断更后的故事1

文章目录 技术男为何开始写感悟博客&#xff1f;简单的自我介绍为什么断更了默默进化的日子琐碎的事情对阶段1的思索和总结 技术男为何开始写感悟博客&#xff1f; 其实我是一个偏感性的一个技术男&#xff0c;可能这样就有点违背技术男这个定义了&#xff0c;很多时候还是挺理…

全连接网络、卷积神经网络、递归神经网络 通俗的解释

全连接网络、卷积神经网络和递归神经网络是三种不同类型的神经网络&#xff0c;它们在结构和应用上有所不同。下面我将尽量用通俗易懂的语言来解释和对比这三种神经网络。 1.全连接网络 全连接网络是一种最常见的神经网络类型&#xff0c;它的每一层都由许多神经元组成&#…

SSH 密钥身份验证和管理

安全外壳协议&#xff08;Security Shell Protocol&#xff09;是一种应用于计算机网络的安全通信协议&#xff0c;其提供的服务可用于保护网络上的连接和数据传输安全性&#xff0c;其核心思想是为网络上的两台计算机之间搭建一个安全的外壳&#xff0c;以保护数据传输的安全性…

简单介绍Java 的内存泄漏

java最明显的一个优势就是它的内存管理机制。你只需简单创建对象&#xff0c;java的垃圾回收机制负责分配和释放内存。然而情况并不像想像的那么简单&#xff0c;因为在Java应用中经常发生内存泄漏。 本教程演示了什么是内存泄漏&#xff0c;为什么会发生内存泄漏以及如何预防…

2、C语言:控制流

控制流 语句&#xff1a;在表达式后面加上分号&#xff0c;构成语句。 程序块&#xff1a;用一对花括号“{”与“}”把一组声明和语句括在一起就构成了一个复合语句。复合语句在语法上等同于单条语句。 if-else语句else-if语句&#xff1a;从上到下依次执行&#xff0c;等同于…

视频云存储/视频智能分析平台EasyCVR在麒麟系统中无法启动该如何解决?

安防视频监控/视频集中存储/云存储/磁盘阵列EasyCVR平台可拓展性强、视频能力灵活、部署轻快&#xff0c;可支持的主流标准协议有国标GB28181、RTSP/Onvif、RTMP等&#xff0c;以及支持厂家私有协议与SDK接入&#xff0c;包括海康Ehome、海大宇等设备的SDK等。平台既具备传统安…

【docker】网络模式管理

目录 一、Docker网络实现原理 二、Docker的网络模式 1、host模式 1.1 host模式原理 1.2 host模式实操 2、Container模式 2.2 container模式实操 3、none模式 4、bridger模式 4.1 bridge模式的原理 4.2 bridge实操 5、overlay模式 6、自定义网络模式 6.1 为什么需要…

017、使用包、单元包及模块来管理日渐复杂的项目

在编写较为复杂的项目时&#xff0c;合理地对代码进行组织与管理很重要&#xff0c;因为我们不太可能记住代码中所有的细枝末节。只有按照不同的特性来组织或分割相关功能的代码&#xff0c;我们才能够清晰地找到实现指定功能的代码片段&#xff0c;或确定哪些地方需要修改。 到…

【UML】第14篇 协作图

目录 一、协作图的概述 二、协作图的主要构成 2.1 对象 2.2 消息 2.3 链 三、协作图如何画 3.1 思路 3.2 步骤 这个系列暂停了好几天了&#xff0c;适当时候再恢复一下。 UML非常经典&#xff0c;只要在这个行业&#xff0c;代码可能不会写一辈子&#xff0c;但是图肯定…

Java socket编程学习笔记

一、初步了解 1、简易代码(存在socket提前关闭问题) 服务端代码: import java.io.*; import java.net.ServerSocket; import java.net.Socket; import java.nio.charset.StandardCharsets;public class MySocketServer {public static void main(String[] args) throws IOEx…

js判断是否为数字的方法

找到一个比较好用的方法&#xff0c;记录下来&#xff0c;方便以后使用查找 function isNumber(value) {return !isNaN(parseFloat(value)) && isFinite(value); }目前测试情况&#xff1a; isNumber(123) —> true isNumber(12.3) —> true isNumber(-12.3) —…