手撸nano-gpt

nano GPT

跟着youtube上AndrejKarpathy大佬复现一个简单GPT

1.数据集准备

很小的莎士比亚数据集

wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

1.1简单的tokenize

数据和等下的模型较简单,所以这里用了个很简单的直接按照字母去分割的tokenize

复杂些的可以用**tiktoken**: openai在gpt2上用的。

with open('input.txt', 'r', encoding='utf-8') as f:text = f.read()print(len(text))
#>>> 1115394chars = sorted(list(set(text)))
vocab_size = len(chars)stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}string = 'hii there'
decode(encode(string)) == string
#>>> True

1.2切分训练集

import torchdata = torch.tensor(encode(text), dtype=torch.long)
print(data.shape, data.dtype)
#>>> torch.Size([1115394]) torch.int64n = int(0.9 * len(data))
train_data = data[:n]
val_data = data[n:]

1.3获取小批量

注意,target在切分的时候错开了一个位置。原因是如果原串是[1,2,3,4,5]。当我们的input是[1,2,3]的时候应该生成一个[1,2,3,4]。

实现如下

torch.manual_seed(1337)
batch_size, block_size = 4, 8def get_batch(split):data = train_data if split == 'train' else val_dataix = torch.randint(len(data) - block_size, (batch_size, ))x = torch.stack([data[i: i+block_size] for i in ix])y = torch.stack([data[i+1:i+1+block_size] for i in ix])return x, yxb, yb = get_batch('train')
print(xb.shape)
print(xb)print('target:')
print(yb.shape)
print(yb)for b in range(batch_size):for t in range(block_size):# print(xb[b: :t+1])context = xb[b, :t+1].tolist()target = yb[b, t]print(f'when input is {context}, target is {target}')

image-20240308150934834

2.模型定义

2.1模型代码

这里具体解释一下为什么inputs, target送入模型前要做reshape。因为F.cross_entropy规定了 input 的shape必须是 [N, C] 其中N是样本数C是类别数这里也就是我们的vocab_size。与之对应,我们的 target 的shape就应该是[N]。input 送入模型后我们会得到input中每一个位置的下一个位置的预测,如果原文本是 [1,2,3],input : [1,2] ,target : [2,3]。那么送入 input 后我们可能会得到[2, 2.7]然后用这个和target计算损失。

import torch
import torch.nn as nn
from torch.nn import functional as Ftorch.manual_seed(1337)class BigramLanguageModel(nn.Module):def __init__(self, vocab_size):super().__init__()self.token_embedding = nn.Embedding(vocab_size, vocab_size)def forward(self, inputs, target=None):# inputs: [B,L], target: [B,1]logits = self.token_embedding(inputs) #[B,L,C]if target is None: loss = Noneelse:B, T, C = logits.shapelogits = logits.reshape(B*T, C)target = target.reshape(-1)loss = F.cross_entropy(logits, target)return logits, lossdef generate(self, idx, max_new_tokens):# idx is [B, T] array of indices in the current contextfor _ in range(max_new_tokens):logits, loss = self(idx)# 关注最后一个位置logits = logits[:, -1, :] # [B, C]probs = F.softmax(logits, dim=-1) # [B, C]idx_next = torch.multinomial(probs, num_samples=1) # [B, 1]idx = torch.cat([idx, idx_next], dim=1)return idxm = BigramLanguageModel(vocab_size)
logits, loss = m(xb, yb)
print(logits.shape)
print(loss)idx = torch.zeros((1, 1), dtype=torch.long)
print(decode(m.generate(idx, max_new_tokens=100)[0].tolist()))

2.2优化器及训练

optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)batch_size = 32
for steps in range(1000):xb, yb = get_batch('train')logits, loss = m(xb, yb)optimizer.zero_grad(set_to_none=True)loss.backward()optimizer.step()
print(loss.item())

2.3 生成

print(decode(m.generate(idx, max_new_tokens=300)[0].tolist()))

3.加入注意力机制

3.1单头注意力

class Head(nn.Module):def __init__(self, head_size):super().__init__()self.key = nn.Linear(n_embd, head_size)self.query = nn.Linear(n_embd, head_size)self.value = nn.Linear(n_embd, head_size)self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))self.dropout = nn.Dropout(dropout)def forward(self, x):B, T, C  = x.shapek = self.key(x)q = self.query(x)v = self.value(x)wei = q @ k.transpose(-2, -1) * C ** -0.5wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))wei = F.softmax(wei, dim=-1)wei = self.dropout(wei)out = wei @ vreturn out

3.2多头注意力

class MultiHeadAttention(nn.Module):def __init__(self, num_heads, head_size):super().__init__()self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])self.proj = nn.Linear(n_embd, n_embd)self.dropout = nn.Dropout(dropout)def forward(self, x):out = torch.cat([head(x) for head in self.heads], dim=-1)out = self.dropout(self.proj(out))return out

3.3前馈神经网络

class FeedForward(nn.Module):def __init__(self, n_embed):super().__init__()self.net = nn.Sequential(nn.Linear(n_embed, 4 * n_embed),nn.ReLU(),nn.Linear(4 * n_embed, n_embed))self.dropout = nn.Dropout(dropout)def forward(self, x):return self.dropout(self.net(x))

3.4transformerBlock

这里实现的是一个简易版的,如果n_embd是32, n_head=4, 那么每个单独的头只会产生 [B, T, 8] 这个尺寸的信息,然后将4个头的信息在dim=-1这个维度拼接起来即可。

class Block(nn.Module):def __init__(self, n_embd, n_head):super().__init__()head_size = n_embd // n_headself.sa = MultiHeadAttention(n_head, head_size)self.ffwd = FeedForward(n_embd)self.ln1 = nn.LayerNorm(n_embd)self.ln2 = nn.LayerNorm(n_embd)def forward(self, x):x = x + self.sa(self.ln1(x))x = x + self.ffwd(self.ln2(x))return x

4.最终训练

加入注意力机制和扩大模型后我们得到了这样的模型以及超参数

参数

batch_size = 64
block_size = 256
max_iters = 5000
eval_interval = 500
lr = 3e-4
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
n_embd = 384
n_head = 6
n_layer = 6
dropout = 0.2

模型

class BigramLanguageModel(nn.Module):def __init__(self, n_embd):super().__init__()self.token_embedding = nn.Embedding(vocab_size, n_embd)self.lm_head = nn.Linear(n_embd, vocab_size)self.position_embedding_table = nn.Embedding(block_size, n_embd)self.blocks = nn.Sequential(Block(n_embd, n_head=n_head),Block(n_embd, n_head=n_head),Block(n_embd, n_head=n_head),Block(n_embd, n_head=n_head),Block(n_embd, n_head=n_head),Block(n_embd, n_head=n_head),nn.LayerNorm(n_embd))self.ffwd = FeedForward(n_embd)def forward(self, inputs, target=None):# inputs: [B,L], target: [B,L]B, T = inputs.shapetok_emb = self.token_embedding(inputs)  # [B,T,C]pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # [T, C]x = tok_emb + pos_embx = self.blocks(x)x = self.ffwd(x)logits = self.lm_head(x) # [B, T, C] C = vocab_sizeif target is None:loss = Noneelse:B, T, C = logits.shapelogits = logits.reshape(B * T, C)target = target.reshape(-1)loss = F.cross_entropy(logits, target)return logits, lossdef generate(self, idx, max_new_tokens):# idx is [B, T] array of indices in the current contextfor _ in range(max_new_tokens):idx_cond = idx[:, -block_size:]logits, loss = self(idx_cond)# 关注最后一个位置logits = logits[:, -1, :]  # [B, C]probs = F.softmax(logits, dim=-1)  # [B, C]idx_next = torch.multinomial(probs, num_samples=1)  # [B, 1]idx = torch.cat([idx, idx_next], dim=1)return idx

在A800上训练可以得到如下结果

可以看到loss已经降的不错了,只不过说出来的话还不太合理hhh

step 0: train loss 4.1744, val loss 4.1743
step 500: train loss 1.9218, val loss 2.0212
step 1000: train loss 1.5678, val loss 1.7493
step 1500: train loss 1.4277, val loss 1.6303
step 2000: train loss 1.3384, val loss 1.5647
step 2500: train loss 1.2810, val loss 1.5380
step 3000: train loss 1.2325, val loss 1.5121
step 3500: train loss 1.1924, val loss 1.5010
step 4000: train loss 1.1506, val loss 1.4956
step 4500: train loss 1.1204, val loss 1.5051Havingly made me been's wife.
Thy father's name be heard he will not say
Your undoubter'd prift, that's that sympirate.KING RICHARD III:
Those palasion most pallars, these measures
Shame laceling may be invenged by my breast.DUKE VINCENTIO:
Then, I think it, is approach'd lip.PRINCENTIUS:
The

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/740014.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

游戏免费下载平台模板源码

功能介绍 此游戏网站模板源码是专门为游戏下载站而设计的,旨在为网站开发者提供一个高效、易于维护和扩展的解决方案。 特点: 响应式设计:我们的模板可以自适应不同设备屏幕大小,从而为不同平台的用户提供最佳的浏览体验。 …

算法---滑动窗口练习-1(长度最小的子数组)

长度最小的子数组 1. 题目解析2. 讲解算法原理3. 编写代码 1. 题目解析 题目地址:长度最小的子数组 2. 讲解算法原理 首先,定义变量n为数组nums的长度,sum为当前子数组的和,len为最短子数组的长度,初始值为INT_MAX&am…

javascript中的structuredClone()克隆方法

前言: structuredClone 是 JavaScript 的方法之一,用于深拷贝一个对象。它的语法是 structuredClone(obj),其中 obj 是要拷贝的对象。structuredClone 方法将会创建一个与原始对象完全相同但是独立的副本。 案例: 当使用Web Work…

Shadertoy内置函数系列 - mod 取模运算

mod函数返回x % 3的结果 先看一个挑战问题题目: Create a pattern of alternating black and red columns, with 9 columns of each color. Then, hide every third column that is colored red.The shader should avoid using branching or conditional statemen…

2024年最新阿里云和腾讯云云服务器价格租用对比

2024年阿里云服务器和腾讯云服务器价格战已经打响,阿里云服务器优惠61元一年起,腾讯云服务器61元一年,2核2G3M、2核4G、4核8G、4核16G、8核16G、16核32G、16核64G等配置价格对比,阿腾云atengyun.com整理阿里云和腾讯云服务器详细配…

每日OJ题_路径dp②_力扣63. 不同路径 II

目录 力扣63. 不同路径 II 解析代码 力扣63. 不同路径 II 63. 不同路径 II 难度 中等 一个机器人位于一个 m x n 网格的左上角 (起始点在下图中标记为 “Start” )。 机器人每次只能向下或者向右移动一步。机器人试图达到网格的右下角(…

鸿蒙Harmony应用开发—ArkTS声明式开发(基础手势:Select)

提供下拉选择菜单&#xff0c;可以让用户在多个选项之间选择。 说明&#xff1a; 该组件从API Version 8开始支持。后续版本如有新增内容&#xff0c;则采用上角标单独标记该内容的起始版本。 子组件 无 接口 Select(options: Array<SelectOption>) 参数&#xff1a;…

git撤回代码提交commit或者修改commit提交注释

执行commit后&#xff0c;还没执行push时&#xff0c;想要撤销之前的提交commit 撤销提交 使用命令&#xff1a; git reset --soft HEAD^命令详解&#xff1a; HEAD^ 表示上一个版本&#xff0c;即上一次的commit&#xff0c;也可以写成HEAD~1 如果进行两次的commit&#xf…

算法打卡day15|二叉树篇04|110.平衡二叉树、257. 二叉树的所有路径、404.左叶子之和

算法题 Leetcode 110.平衡二叉树 题目链接:110.平衡二叉树 大佬视频讲解&#xff1a;平衡二叉树视频讲解 个人思路 可以用递归法&#xff0c;计算左右子树的高度差&#xff0c;当超过1时就不为平衡二叉树了&#xff1b; 解法 回顾一下二叉树节点的深度与高度&#xff1b; …

Python学习:基础语法

版本查看 python --version编码 默认情况下&#xff0c;Python 3 源码文件以 UTF-8 编码&#xff0c;所有字符串都是 unicode 字符串。 特殊情况下&#xff0c;也可以为源码文件指定不同的编码&#xff1a; # -*- coding: cp-1252 -*-标识符 第一个字符必须是字母表中字母或…

rt-thread组件之audio组件(结合mp3player包使用)

前言 继上一篇RT-Thread组件之Audio框架i2s驱动的编写的编写&#xff0c;应用层使用rt-thread软件包里面的wavplayer组件以及 rt-thread组件之audio组件(结合wavplayer包使用)的文章本篇使用的是 mp3player软件包&#xff0c;与wavplayer设计框架基本上是一样的&#xff0c;只…

java-单列集合-set系列

set集合继承collection,所以API都差不多&#xff0c;我就不多加介绍 直接见图看他们的特点 我们主要讲述的是set系列里的HashSet、LinkedHashSet、TreeSet HashSet HashSet它的底层是哈希表 哈希表由数组集合红黑树组成 特点&#xff1a;增删改查都性能良好 哈希表具体是…

网络安全攻击数据的多维度可视化分析

简介 本研究项目通过应用多种数据处理与可视化技术&#xff0c;对网络安全攻击事件数据集进行了深度分析。首先&#xff0c;利用Pandas库读取并预处理数据&#xff0c;包括检查缺失值、剔除冗余信息以及将时间戳转化为日期时间格式以利于后续时间序列分析。 研究步骤 数据分析…

git commit --amend

git commit --amend 1. 修改已经输入的commit 1. 修改已经输入的commit 我已经输入了commit fix: 删除无用代码 然后现在表示不准确&#xff0c;然后我通过命令git commit --amend修改commit

Python 导入Excel三维坐标数据 生成三维曲面地形图(面) 2、线条平滑曲面但有间隔

环境和包: 环境 python:python-3.12.0-amd64包: matplotlib 3.8.2 pandas 2.1.4 openpyxl 3.1.2 scipy 1.12.0 代码: import pandas as pd import matplotlib.pyplot as plt from mpl_toolkits.mplot3d import Axes3D from scipy.interpolate import griddata imp…

【李沐论文精读】GPT、GPT-2和GPT-3论文精读

论文&#xff1a; GPT&#xff1a;Improving Language Understanding by Generative Pre-Training GTP-2&#xff1a;Language Models are Unsupervised Multitask Learners GPT-3&#xff1a;Language Models are Few-Shot Learners 参考&#xff1a;GPT、GPT-2、GPT-3论文精读…

java基础2-常用API

常用API Math类 帮助我们进行数学计算的工具类。 里面的方法都是静态的。 3.常见方法如下&#xff1a; abs:获取绝对值 absExact:获取绝对值 ceil:向上取整 floor:向下取整 round:四舍五入 max:获取最大值 …

Stable Diffusion 模型:从噪声中生成逼真图像

你好&#xff0c;我是郭震 简介 Stable Diffusion 模型是一种生成式模型&#xff0c;可以从噪声中生成逼真的图像。它由 Google AI 研究人员于 2022 年提出&#xff0c;并迅速成为图像生成领域的热门模型。 数学基础 Stable Diffusion模型基于一种称为扩散概率模型(Diffusion P…

并查集算法

文章目录 并查集并查集引入1.初始化2.查询3.合并路径压缩代码模板(1)朴素并查集&#xff1a;(2)维护size的并查集&#xff1a;(3)维护到祖宗节点距离的并查集&#xff1a; 并查集 并查集引入 并查集&#xff08;Union-find Sets&#xff09;是一种非常精巧而实用的数据结构&a…

设计模式 -- 1:简单工厂模式

目录 代码记录代码部分 代码记录 设计模式的代码注意要运用到面向对象的思想 考虑到紧耦合和松耦合 把具体的操作类分开 不让其互相影响&#xff08;注意这点&#xff09; 下面是UML类图 代码部分 #include <iostream> #include <memory> // 引入智能指针的头文…