手撸nano-gpt

nano GPT

跟着youtube上AndrejKarpathy大佬复现一个简单GPT

1.数据集准备

很小的莎士比亚数据集

wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

1.1简单的tokenize

数据和等下的模型较简单,所以这里用了个很简单的直接按照字母去分割的tokenize

复杂些的可以用**tiktoken**: openai在gpt2上用的。

with open('input.txt', 'r', encoding='utf-8') as f:text = f.read()print(len(text))
#>>> 1115394chars = sorted(list(set(text)))
vocab_size = len(chars)stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}string = 'hii there'
decode(encode(string)) == string
#>>> True

1.2切分训练集

import torchdata = torch.tensor(encode(text), dtype=torch.long)
print(data.shape, data.dtype)
#>>> torch.Size([1115394]) torch.int64n = int(0.9 * len(data))
train_data = data[:n]
val_data = data[n:]

1.3获取小批量

注意,target在切分的时候错开了一个位置。原因是如果原串是[1,2,3,4,5]。当我们的input是[1,2,3]的时候应该生成一个[1,2,3,4]。

实现如下

torch.manual_seed(1337)
batch_size, block_size = 4, 8def get_batch(split):data = train_data if split == 'train' else val_dataix = torch.randint(len(data) - block_size, (batch_size, ))x = torch.stack([data[i: i+block_size] for i in ix])y = torch.stack([data[i+1:i+1+block_size] for i in ix])return x, yxb, yb = get_batch('train')
print(xb.shape)
print(xb)print('target:')
print(yb.shape)
print(yb)for b in range(batch_size):for t in range(block_size):# print(xb[b: :t+1])context = xb[b, :t+1].tolist()target = yb[b, t]print(f'when input is {context}, target is {target}')

image-20240308150934834

2.模型定义

2.1模型代码

这里具体解释一下为什么inputs, target送入模型前要做reshape。因为F.cross_entropy规定了 input 的shape必须是 [N, C] 其中N是样本数C是类别数这里也就是我们的vocab_size。与之对应,我们的 target 的shape就应该是[N]。input 送入模型后我们会得到input中每一个位置的下一个位置的预测,如果原文本是 [1,2,3],input : [1,2] ,target : [2,3]。那么送入 input 后我们可能会得到[2, 2.7]然后用这个和target计算损失。

import torch
import torch.nn as nn
from torch.nn import functional as Ftorch.manual_seed(1337)class BigramLanguageModel(nn.Module):def __init__(self, vocab_size):super().__init__()self.token_embedding = nn.Embedding(vocab_size, vocab_size)def forward(self, inputs, target=None):# inputs: [B,L], target: [B,1]logits = self.token_embedding(inputs) #[B,L,C]if target is None: loss = Noneelse:B, T, C = logits.shapelogits = logits.reshape(B*T, C)target = target.reshape(-1)loss = F.cross_entropy(logits, target)return logits, lossdef generate(self, idx, max_new_tokens):# idx is [B, T] array of indices in the current contextfor _ in range(max_new_tokens):logits, loss = self(idx)# 关注最后一个位置logits = logits[:, -1, :] # [B, C]probs = F.softmax(logits, dim=-1) # [B, C]idx_next = torch.multinomial(probs, num_samples=1) # [B, 1]idx = torch.cat([idx, idx_next], dim=1)return idxm = BigramLanguageModel(vocab_size)
logits, loss = m(xb, yb)
print(logits.shape)
print(loss)idx = torch.zeros((1, 1), dtype=torch.long)
print(decode(m.generate(idx, max_new_tokens=100)[0].tolist()))

2.2优化器及训练

optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)batch_size = 32
for steps in range(1000):xb, yb = get_batch('train')logits, loss = m(xb, yb)optimizer.zero_grad(set_to_none=True)loss.backward()optimizer.step()
print(loss.item())

2.3 生成

print(decode(m.generate(idx, max_new_tokens=300)[0].tolist()))

3.加入注意力机制

3.1单头注意力

class Head(nn.Module):def __init__(self, head_size):super().__init__()self.key = nn.Linear(n_embd, head_size)self.query = nn.Linear(n_embd, head_size)self.value = nn.Linear(n_embd, head_size)self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))self.dropout = nn.Dropout(dropout)def forward(self, x):B, T, C  = x.shapek = self.key(x)q = self.query(x)v = self.value(x)wei = q @ k.transpose(-2, -1) * C ** -0.5wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))wei = F.softmax(wei, dim=-1)wei = self.dropout(wei)out = wei @ vreturn out

3.2多头注意力

class MultiHeadAttention(nn.Module):def __init__(self, num_heads, head_size):super().__init__()self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])self.proj = nn.Linear(n_embd, n_embd)self.dropout = nn.Dropout(dropout)def forward(self, x):out = torch.cat([head(x) for head in self.heads], dim=-1)out = self.dropout(self.proj(out))return out

3.3前馈神经网络

class FeedForward(nn.Module):def __init__(self, n_embed):super().__init__()self.net = nn.Sequential(nn.Linear(n_embed, 4 * n_embed),nn.ReLU(),nn.Linear(4 * n_embed, n_embed))self.dropout = nn.Dropout(dropout)def forward(self, x):return self.dropout(self.net(x))

3.4transformerBlock

这里实现的是一个简易版的,如果n_embd是32, n_head=4, 那么每个单独的头只会产生 [B, T, 8] 这个尺寸的信息,然后将4个头的信息在dim=-1这个维度拼接起来即可。

class Block(nn.Module):def __init__(self, n_embd, n_head):super().__init__()head_size = n_embd // n_headself.sa = MultiHeadAttention(n_head, head_size)self.ffwd = FeedForward(n_embd)self.ln1 = nn.LayerNorm(n_embd)self.ln2 = nn.LayerNorm(n_embd)def forward(self, x):x = x + self.sa(self.ln1(x))x = x + self.ffwd(self.ln2(x))return x

4.最终训练

加入注意力机制和扩大模型后我们得到了这样的模型以及超参数

参数

batch_size = 64
block_size = 256
max_iters = 5000
eval_interval = 500
lr = 3e-4
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
n_embd = 384
n_head = 6
n_layer = 6
dropout = 0.2

模型

class BigramLanguageModel(nn.Module):def __init__(self, n_embd):super().__init__()self.token_embedding = nn.Embedding(vocab_size, n_embd)self.lm_head = nn.Linear(n_embd, vocab_size)self.position_embedding_table = nn.Embedding(block_size, n_embd)self.blocks = nn.Sequential(Block(n_embd, n_head=n_head),Block(n_embd, n_head=n_head),Block(n_embd, n_head=n_head),Block(n_embd, n_head=n_head),Block(n_embd, n_head=n_head),Block(n_embd, n_head=n_head),nn.LayerNorm(n_embd))self.ffwd = FeedForward(n_embd)def forward(self, inputs, target=None):# inputs: [B,L], target: [B,L]B, T = inputs.shapetok_emb = self.token_embedding(inputs)  # [B,T,C]pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # [T, C]x = tok_emb + pos_embx = self.blocks(x)x = self.ffwd(x)logits = self.lm_head(x) # [B, T, C] C = vocab_sizeif target is None:loss = Noneelse:B, T, C = logits.shapelogits = logits.reshape(B * T, C)target = target.reshape(-1)loss = F.cross_entropy(logits, target)return logits, lossdef generate(self, idx, max_new_tokens):# idx is [B, T] array of indices in the current contextfor _ in range(max_new_tokens):idx_cond = idx[:, -block_size:]logits, loss = self(idx_cond)# 关注最后一个位置logits = logits[:, -1, :]  # [B, C]probs = F.softmax(logits, dim=-1)  # [B, C]idx_next = torch.multinomial(probs, num_samples=1)  # [B, 1]idx = torch.cat([idx, idx_next], dim=1)return idx

在A800上训练可以得到如下结果

可以看到loss已经降的不错了,只不过说出来的话还不太合理hhh

step 0: train loss 4.1744, val loss 4.1743
step 500: train loss 1.9218, val loss 2.0212
step 1000: train loss 1.5678, val loss 1.7493
step 1500: train loss 1.4277, val loss 1.6303
step 2000: train loss 1.3384, val loss 1.5647
step 2500: train loss 1.2810, val loss 1.5380
step 3000: train loss 1.2325, val loss 1.5121
step 3500: train loss 1.1924, val loss 1.5010
step 4000: train loss 1.1506, val loss 1.4956
step 4500: train loss 1.1204, val loss 1.5051Havingly made me been's wife.
Thy father's name be heard he will not say
Your undoubter'd prift, that's that sympirate.KING RICHARD III:
Those palasion most pallars, these measures
Shame laceling may be invenged by my breast.DUKE VINCENTIO:
Then, I think it, is approach'd lip.PRINCENTIUS:
The

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/740014.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

解决mybatis-plus新增数据自增ID与之前数据不匹配问题

实体类的例子 Data public class User {TableId(value "id", type IdType.AUTO)private Integer id;private String username;// 忽略,不传到前端JsonIgnoreprivate String password;private String nickname;private String email;private String phone;private …

css---定位

定位 1. 相对定位1.1 如何设置相对定位?1.2 相对定位的参考点在哪里?1.3 相对定位的特点: 2. 绝对定位2.1 如何设置绝对定位?2.2 绝对定位的参考点在哪里?2.3 绝对定位元素的特点: 3. 固定定位3.1 如何设置…

PostgreSQL教程(三十六):服务器管理(十八)之回归测试

回归测试是PostgreSQL中对于 SQL 实现的一组综合测试集。它们测试标准 SQL 操作以及PostgreSQL的扩展能力。 一、运行测试 回归测试可以在一个已经安装并运行的服务器上运行,或者在编译树中的一个临时安装上运行。此外,还有运行该测试的“并行”和“顺…

游戏免费下载平台模板源码

功能介绍 此游戏网站模板源码是专门为游戏下载站而设计的,旨在为网站开发者提供一个高效、易于维护和扩展的解决方案。 特点: 响应式设计:我们的模板可以自适应不同设备屏幕大小,从而为不同平台的用户提供最佳的浏览体验。 …

算法---滑动窗口练习-1(长度最小的子数组)

长度最小的子数组 1. 题目解析2. 讲解算法原理3. 编写代码 1. 题目解析 题目地址:长度最小的子数组 2. 讲解算法原理 首先,定义变量n为数组nums的长度,sum为当前子数组的和,len为最短子数组的长度,初始值为INT_MAX&am…

javascript中的structuredClone()克隆方法

前言: structuredClone 是 JavaScript 的方法之一,用于深拷贝一个对象。它的语法是 structuredClone(obj),其中 obj 是要拷贝的对象。structuredClone 方法将会创建一个与原始对象完全相同但是独立的副本。 案例: 当使用Web Work…

阿里巴巴按关键字搜索商品 API 返回值说明

item_search-按关键字搜索商品API测试工具 alibaba.item_search 公共参数 名称类型必须描述keyString是调用key(必须以GET方式拼接在URL中)secretString是调用密钥api_nameString是API接口名称(包括在请求地址中)[item_search,…

Shadertoy内置函数系列 - mod 取模运算

mod函数返回x % 3的结果 先看一个挑战问题题目: Create a pattern of alternating black and red columns, with 9 columns of each color. Then, hide every third column that is colored red.The shader should avoid using branching or conditional statemen…

2024年最新阿里云和腾讯云云服务器价格租用对比

2024年阿里云服务器和腾讯云服务器价格战已经打响,阿里云服务器优惠61元一年起,腾讯云服务器61元一年,2核2G3M、2核4G、4核8G、4核16G、8核16G、16核32G、16核64G等配置价格对比,阿腾云atengyun.com整理阿里云和腾讯云服务器详细配…

C语言分支和循环总结

文章目录 概要结构介绍不同结构的语句简单运用小结 概要 C语言中分为三种结构:顺序结构,选择结构,循环结构 结构介绍 顺序结构就是从上到下,从左到右等等;选择结构可以想象是Y字路口就是到了一个地方会有不同的道路…

Redis事务为什么不支持原子性

Redis事务提供了一种将多个命令打包,然后按顺序执行的机制。使用MULTI命令开始事务,接着输入需要队列化的命令,最后使用EXEC命令提交整个事务。尽管Redis事务可以保证一系列命令被连续执行,没有其他客户端命令插入其中执行&#x…

每日OJ题_路径dp②_力扣63. 不同路径 II

目录 力扣63. 不同路径 II 解析代码 力扣63. 不同路径 II 63. 不同路径 II 难度 中等 一个机器人位于一个 m x n 网格的左上角 (起始点在下图中标记为 “Start” )。 机器人每次只能向下或者向右移动一步。机器人试图达到网格的右下角(…

多媒体技术2-颜色空间

颜色空间是一种用于表示和描述颜色的数学模型。它是由颜色分量和坐标系组成的。常见的颜色空间有RGB、CMYK、HSV等。 RGB颜色空间:RGB是红、绿、蓝三个颜色分量的缩写。在RGB颜色空间中,每个颜色分量的取值范围是0到255,表示了红、绿、蓝三个…

鸿蒙Harmony应用开发—ArkTS声明式开发(基础手势:Select)

提供下拉选择菜单&#xff0c;可以让用户在多个选项之间选择。 说明&#xff1a; 该组件从API Version 8开始支持。后续版本如有新增内容&#xff0c;则采用上角标单独标记该内容的起始版本。 子组件 无 接口 Select(options: Array<SelectOption>) 参数&#xff1a;…

golang中fallthrough简介及用法

什么是fallthrough&#xff1f; fallthrough是golang中的一个关键字&#xff0c;它用于在switch语句中控制代码的执行流程。通常情况下&#xff0c;当一个case分支匹配成功后&#xff0c;switch语句就会结束&#xff0c;不会继续执行后面的case分支。但是&#xff0c;如果在一…

git撤回代码提交commit或者修改commit提交注释

执行commit后&#xff0c;还没执行push时&#xff0c;想要撤销之前的提交commit 撤销提交 使用命令&#xff1a; git reset --soft HEAD^命令详解&#xff1a; HEAD^ 表示上一个版本&#xff0c;即上一次的commit&#xff0c;也可以写成HEAD~1 如果进行两次的commit&#xf…

爬虫(六)

复习回顾: 01.浏览器一个网页的加载全过程1. 服务器端渲染html的内容和数据在服务器进行融合.在浏览器端看到的页面源代码中. 有你需要的数据2. 客户端(浏览器)渲染html的内容和数据进行融合是发生在你的浏览器上的.这个过程一般通过脚本来完成(javascript)我们通过浏览器可以…

算法打卡day15|二叉树篇04|110.平衡二叉树、257. 二叉树的所有路径、404.左叶子之和

算法题 Leetcode 110.平衡二叉树 题目链接:110.平衡二叉树 大佬视频讲解&#xff1a;平衡二叉树视频讲解 个人思路 可以用递归法&#xff0c;计算左右子树的高度差&#xff0c;当超过1时就不为平衡二叉树了&#xff1b; 解法 回顾一下二叉树节点的深度与高度&#xff1b; …

Python学习:基础语法

版本查看 python --version编码 默认情况下&#xff0c;Python 3 源码文件以 UTF-8 编码&#xff0c;所有字符串都是 unicode 字符串。 特殊情况下&#xff0c;也可以为源码文件指定不同的编码&#xff1a; # -*- coding: cp-1252 -*-标识符 第一个字符必须是字母表中字母或…

rt-thread组件之audio组件(结合mp3player包使用)

前言 继上一篇RT-Thread组件之Audio框架i2s驱动的编写的编写&#xff0c;应用层使用rt-thread软件包里面的wavplayer组件以及 rt-thread组件之audio组件(结合wavplayer包使用)的文章本篇使用的是 mp3player软件包&#xff0c;与wavplayer设计框架基本上是一样的&#xff0c;只…