从头构建gpt2 基于Transformer

从头构建gpt2 基于Transformer

VX关注{晓理紫|小李子},获取技术推送信息,如感兴趣,请转发给有需要的同学,谢谢支持!!

如果你感觉对你有所帮助,请关注我。

源码获取 VX关注晓理紫并回复“chatgpt-0”

在这里插入图片描述

  • 头文件以及超参数
import torch
import torch.nn as nn
from torch.nn import functional as F
# 加入为了扩大网络进行修改 head ,注意力、前向网络添加了dropout和设置蹭数目
#超参数
batch_size = 64
block_size = 34 #块大小 现在有34个上下文字符来预测
max_iters = 5000
eval_interval = 500
learning_rate=3e-4
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
n_embd = 384  #嵌入维度
n_head = 6    #有6个头,每个头有284/6维
n_layer = 6   # 6层
dropout = 0.2torch.manual_seed(1337)
  • 数据处理

with open('input.txt','r',encoding='utf-8') as f:text = f.read()chars = sorted(list(set(text)))
vocab_size = len(chars)stoi = {ch:i for i,ch in enumerate(chars)}itos = {i:ch for i,ch in enumerate(chars)}encode = lambda s : [stoi[c] for c in s]
decode = lambda l: ''.join([itos[i] for i in l])data = torch.tensor(encode(text),dtype=torch.long)n = int(0.9*len(data))
train_data = data[:n]
val_data = data[n:]def get_batch(split):data = train_data if split=="train" else val_dataix = torch.randint(len(data)-batch_size,(batch_size,))x = torch.stack([data[i:i+block_size] for i in ix])y = torch.stack([data[i+1:i+block_size+1] for i in ix])x,y = x.to(device),y.to(device)return x,y
  • 估计损失
@torch.no_grad()
def estimate_loos(model):out={}model.eval()for split in ['train','val']:losses = torch.zeros(eval_iters)for k in range(eval_iters):x,y = get_batch(split)logits,loss = model(x,y)losses[k] = loss.mean()out[split] = losses.mean()model.train()return out
  • 单头注意力
class Head(nn.Module):"""one head of self-attention"""def __init__(self, head_size):super(Head,self).__init__()self.key = nn.Linear(n_embd,head_size,bias=False)self.query = nn.Linear(n_embd,head_size,bias=False)self.value= nn.Linear(n_embd,head_size,bias=False)self.register_buffer('tril',torch.tril(torch.ones(block_size,block_size)))self.dropout = nn.Dropout(dropout)def forward(self,x):B,T,C = x.shapek = self.key(x) #(B,T,C)q = self.query(x) #(B,T,C)wei = q@k.transpose(-2,-1)*C**-0.5  #(B,T,C) @ (B,C,T)-->(B,T,T)wei = wei.masked_fill(self.tril[:T,:T]==0,float('-inf'))#(B,T,T)wei = F.softmax(wei,dim=-1) #(B,T,T)wei = self.dropout(wei)v= self.value(x)out = wei@vreturn out
  • 多头注意力

在这里插入图片描述

class MultiHeadAttention(nn.Module):"""multiple heads of self-attention in parallel"""def __init__(self, num_heads,head_size) -> None:super(MultiHeadAttention,self).__init__()self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])self.proj = nn.Linear(n_embd,n_embd) #投影,为了方便使用惨差跳连self.dropout = nn.Dropout(dropout)def forward(self,x):out = torch.cat([h(x) for h in self.heads],dim=-1)out = self.dropout(self.proj(out))return out
  • 前馈网络

在这里插入图片描述

class FeedFoward(nn.Module):"""a simple linear layer followed by a non-linearity"""def __init__(self,n_embd):super().__init__()self.net = nn.Sequential(nn.Linear(n_embd,4*n_embd), #从512变成2048nn.ReLU(),nn.Linear(4*n_embd,n_embd), #从2048变成512nn.Dropout(dropout),  #Dropout 是可以在惨差链接之前加的东西)def forward(self,x):out = self.net(x)return out

在这里插入图片描述

class Block(nn.Module):"""Transformer block:communication followed by computation"""def __init__(self, n_embd,n_head) -> None:#n_embd 需要嵌入维度中的嵌入数量#n_head 头部数量super().__init__()head_size = n_embd//n_headself.sa = MultiHeadAttention(n_head,head_size) #通过多头注意力进行计算self.ffwd = FeedFoward(n_embd) # 对注意力计算的结果进行提要完成self.ln1 = nn.LayerNorm(n_embd) #层规范  对于优化深层网络很重要 论文Layer Normalizationself.ln2 = nn.LayerNorm(n_embd) #层规范def forward(self,x):# 通过使用残差网络的跳连进行x = x + self.sa(self.ln1(x)) x = x + self.ffwd(self.ln2(x))return x
  • 整个语言模型
class BigramLangeNodel(nn.Module):def __init__(self):super(BigramLangeNodel,self).__init__()self.token_embedding_table = nn.Embedding(vocab_size,n_embd) #令牌嵌入表,对标记的身份进行编码self.position_embedding_table = nn.Embedding(block_size,n_embd) #位置嵌入表,对标记的位置进行编码。从0到block_size大小减一的每个位置将获得自己的嵌入向量self.blocks = nn.Sequential(*[Block(n_embd,n_head=n_head) for _ in range(n_layer)]) #通过n_layer设置构建的曾数self.ln_f = nn.LayerNorm(n_embd)self.lm_head = nn.Linear(n_embd,vocab_size)  #进行令牌嵌入到logits的转换,这是语言头def forward(self,idx,targets=None):B,T = idx.shapetok_emb= self.token_embedding_table(idx) #(B,T,C) C是嵌入大小    根据idx内的令牌的身份进行编码pos_emb = self.position_embedding_table(torch.arange(T,device=device)) #(T,C) 从0到T减一的整数都嵌入到表中x = tok_emb+pos_emb #(B,T,C) 标记的身份嵌入与位置嵌入相加。x保存了身份以及身份出现的位置# x = self.sa_head(x)  #(B,T,C)x = self.blocks(x)x = self.ln_f(x)logits = self.lm_head(x) #(B,T,vocab_size)if targets is None:loss = Noneelse:B,T,C = logits.shapelogits = logits.view(B*T,C)targets = targets.view(B*T)loss = F.cross_entropy(logits,targets)return logits,lossdef generate(self,idx,max_new_tokens):for _ in range(max_new_tokens):idx_cond = idx[:,-block_size] logits,loss= self(idx_cond)logits = logits[:,-1,:]# becomes (B,C)probs = F.softmax(logits,dim=-1)idx_next = torch.multinomial(probs,num_samples=1)idx = torch.cat((idx,idx_next),dim=1) #(B,T+1)return idx
  • 训练
model = BigramLangeNodel()
m  = model.to(device)optimizer = torch.optim.AdamW(model.parameters(),lr = learning_rate)for iter in range(max_iters):if iter % eval_interval==0:losses = estimate_loos(model)print(f"step {iter}:train loss {losses['train']:.4f},val loss{losses['val']:.4f}")xb,yb = get_batch('train')logits,loss = model(xb,yb)optimizer.zero_grad(set_to_none=True)loss.backward()optimizer.step()context = torch.zeros((1,1),dtype=torch.long,device=device)
print(decode(m.generate(context,max_new_tokens=500)[0].tolist()))
  • 训练损失
step 0:train loss 4.3975,val loss4.3983
step 500:train loss 1.8497,val loss1.9600
step 1000:train loss 1.6500,val loss1.8210
step 1500:train loss 1.5530,val loss1.7376
step 2000:train loss 1.5034,val loss1.6891
step 2500:train loss 1.4665,val loss1.6638
step 3000:train loss 1.4431,val loss1.6457
step 3500:train loss 1.4156,val loss1.6209
step 4000:train loss 1.3958,val loss1.6025
step 4500:train loss 1.3855,val loss1.5988

简单实现自注意力

VX 关注{晓理紫|小李子},获取技术推送信息,如感兴趣,请转发给有需要的同学,谢谢支持!!

如果你感觉对你有所帮助,请关注我。

源码获取 VX关注晓理紫并回复“chatgpt-0”
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/716571.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

CSS 自测题

盒模型的宽度计算 默认为标准盒模型 box-sizing:content-box; offsetWidth (内容宽度内边距 边框),无外边距 答案 122px通过 box-sizing: border-box; 可切换为 IE盒模型 offsetWidth width 即 100px margin 纵向重叠 相邻元素的 margin-top 和 margin-bottom 会发…

Benchmark学习笔记

小记一篇Benchmark的学习笔记 1.什么是benchmark 在维基百科中,是这样子讲的 “As computer architecture advanced, it became more difficult to compare the performance of various computer systems simply by looking at their specifications.Therefore, te…

python标识符、变量和常量

一、保留字与标识符 1.1保留字 保留字是指python中被赋予特定意义的单词,在开发程序时,不可以把这些保留字作为变量、函数、类、模块和其它对象的名称来使用。 比如:and、as、def、if、import、class、finally、with等 查询这些关键字的方…

【LeetCode】升级打怪之路 Day 11 加餐:单调队列

今日题目: 239. 滑动窗口最大值 | LeetCode 今天学习了单调队列这种特殊的数据结构,思路很新颖,值得学习。 Problem:单调队列 【必会】 与单调栈类似,单调队列也是一种特殊的数据结构,它相比与普通的 que…

【NR 定位】3GPP NR Positioning 5G定位标准解读(一)

目录 前言 1. 3GPP规划下的5G技术演进 2. 5G NR定位技术的发展 2.1 Rel-16首次对基于5G的定位技术进行标准化 2.2 Rel-17进一步提升5G定位技术的性能 3. Rel-18 关于5G定位技术的新方向、新进展 3.1 Sidelink高精度定位功能 3.2 针对上述不同用例,3GPP考虑按…

Go-知识简短变量声明

Go-知识简短变量声明 1. 简短变量声明符2. 简短变量赋值可能会重新声明3. 简短变量赋值不能用于函数外部4. 简短变量赋值作用域问题5. 总结 githuio地址:https://a18792721831.github.io/ 1. 简短变量声明符 在Go语言中,可以使用关键字var或直接使用简短…

【STK】手把手教你利用STK进行仿真-STK软件基础02 STK系统的软件界面01 STK的界面窗口组成

STK系统是Windows窗口类型的桌面应用软件,功能非常强大。在一个桌面应用软件中集成了仿真对象管理、仿真对象属性参数、设置、空间场景二三维可视化、场景显示控制欲操作、仿真结果报表定制与分析、对象数据管理、仿真过程控制、外部接口连接和系统集成编程等复杂的功能。 STK…

SpringBoot之Actuator的两种监控模式

SpringBoot之Actuator的两种监控模式 springboot提供了很多的检测端点(Endpoint),但是默认值开启了shutdown的Endpoint&#xff0c;其他默认都是关闭的,可根据需要自行开启 文章目录 SpringBoot之Actuator的两种监控模式1. pom.xml2. 监控模式1. HTTP2. JMX 1. pom.xml <de…

力扣 第 125 场双周赛 解题报告 | 珂学家 | 树形DP + 组合数学

前言 整体评价 T4感觉有简单的方法&#xff0c;无奈树形DP一条路上走到黑了&#xff0c;这场还是有难度的。 T1. 超过阈值的最少操作数 I 思路: 模拟 class Solution {public int minOperations(int[] nums, int k) {return (int)Arrays.stream(nums).filter(x -> x <…

VM虚拟机无法传输文件(更新时间24/3/3)

出现这个问题一般是未安装VMware Tools 以下为手动安装教程及可能出现的问题的解决方法&#xff1a; 1. 准备安装 2.用cmd手动启动安装 3. 安装过程默认即可&#xff0c;直接一直下一步 4.安装完成后会自动重启虚拟机&#xff08;没有的话手动重启即可&#xff09; 5.重启以后…

StarCoder2模型,释放你的大模型编码潜能

每周跟踪AI热点新闻动向和震撼发展 想要探索生成式人工智能的前沿进展吗&#xff1f;订阅我们的简报&#xff0c;深入解析最新的技术突破、实际应用案例和未来的趋势。与全球数同行一同&#xff0c;从行业内部的深度分析和实用指南中受益。不要错过这个机会&#xff0c;成为AI领…

部署若依前后端分离项目,连接数据库失败

部署若依前后端分离项目&#xff0c;连接数据库失败&#xff0c;异常如下&#xff1a; 解决方案&#xff1a;application配置文件里&#xff0c;连接数据库的参数useSSL的值改为false

leetcode 长度最小的子数组

在本题中&#xff0c;我们可以知道&#xff0c;是要求数组中组成和为target的最小子数组的长度。所以&#xff0c;我们肯定可以想到用两层for循环进行遍历&#xff0c;然后枚举所有的结果进行挑选&#xff0c;但这样时间复杂度过高。 我们可以采用滑动窗口&#xff0c;其实就是…

编写dockerfile挂载卷、数据容器卷

编写dockerfile挂载卷 编写dockerfile文件 [rootwq docker-test-volume]# vim dockerfile1 [rootwq docker-test-volume]# cat dockerfile1 FROM centosVOLUME ["volume01","volume02"]CMD echo "------end------" CMD /bin/bash [rootwq dock…

2024 年广东省职业院校技能大赛(高职组)“云计算应用”赛项样题 2

#需要资源或有问题的&#xff0c;可私博主&#xff01;&#xff01;&#xff01; #需要资源或有问题的&#xff0c;可私博主&#xff01;&#xff01;&#xff01; #需要资源或有问题的&#xff0c;可私博主&#xff01;&#xff01;&#xff01; 某企业根据自身业务需求&#…

每日OJ题_牛客_合法括号序列判断

目录 合法括号序列判断 解析代码 合法括号序列判断 合法括号序列判断__牛客网 解析代码 class Parenthesis {public:bool chkParenthesis(string A, int n){if (n & 1) // 如果n是奇数return false;stack<char> st;for (int i 0; i < n; i) {if (A[i] () {s…

笔记本hp6930p安装Android-x86补记

在上一篇日记中&#xff08;笔记本hp6930p安装Android-x86避坑日记-CSDN博客&#xff09;提到hp6930p安装Android-x86-9.0&#xff0c;无法正常启动&#xff0c;本文对此再做尝试&#xff0c;原因是&#xff1a;Android-x86-9.0不支持无线网卡&#xff0c;需要在BIOS中关闭WLAN…

B082-SpringCloud-Eureka

目录 微服务架构与springcloud架构演变为什么使用微服务微服务的通讯方式架构的选择springcloud概述场景模拟之基础架构的搭建模拟微服务之间的服务调用目前远程调用的问题 eureka注册中心的作用注册中心的实现服务提供者注册到注册中心 springcloud基于springboot 微服务架构与…

10 计算机结构

冯诺依曼体系结构 冯诺依曼体系结构&#xff0c;也被称为普林斯顿结构&#xff0c;是一种计算机架构&#xff0c;其核心特点包括将程序指令存储和数据存储合并在一起的存储器结构&#xff0c;程序指令和数据的宽度相同&#xff0c;通常都是16位或32位 我们常见的计算机,笔记本…

在Centos7中用Docker部署gitlab-ce

一、介绍 GitLab Community Edition (GitLab CE) 是一个开源的版本控制系统和协作平台&#xff0c;用于管理和追踪软件开发项目。它提供了一套完整的工具和功能&#xff0c;包括代码托管、版本控制、问题跟踪、持续集成、持续交付和协作功能&#xff0c;使团队能够更加高效地进…