手搓没有softmax 的gpt

手搓没有softmax 的gpt

  • 代码
  • 解析代码

代码

import pandas as pd
from tqdm import tqdm
import numpy as np
import paddle# 
class FeedForward(paddle.nn.Layer):def __init__(self, hidden_dim):super(FeedForward, self).__init__()self.fc_one = paddle.nn.Linear(hidden_dim, hidden_dim // 2, bias_attr=False)self.fc_two = paddle.nn.Linear(hidden_dim // 2, hidden_dim, bias_attr=False)self.gre = paddle.nn.GELU()def forward(self, feed_x):feed_x = self.fc_one(feed_x)feed_x = self.gre(feed_x)feed_x = self.fc_two(feed_x)return feed_x# 注意力层
class Attention(paddle.nn.Layer):def __init__(self, hidden_dim, heads):super(Attention, self).__init__()self.q = paddle.nn.Linear(hidden_dim, heads * hidden_dim, bias_attr=False)self.k = paddle.nn.Linear(hidden_dim, heads * hidden_dim, bias_attr=False)self.v = paddle.nn.Linear(hidden_dim, heads * hidden_dim, bias_attr=False)self.heads = headsdef forward(self, sx):b, s, h = sx.shapeq = paddle.nn.functional.relu(self.q(sx))k = paddle.nn.functional.relu(self.k(sx))v = self.v(sx)qk = q.reshape([b, s, self.heads, h]).transpose([0, 2, 1, 3]) @ k.reshape([b, s, self.heads, h]).transpose([0, 2, 3, 1])mask = paddle.triu(paddle.ones([s, s]))# mask[mask == 0] = -np.infqk_mask = qk * maskqk = qk_mask / (paddle.sum(qk_mask, -2).unsqueeze([-2]) + 0.00000000000001)qkv = qk.transpose([0, 1, 3, 2]) @ v.reshape([b, s, self.heads, h]).transpose([0, 2, 1, 3])qkv = qkv.transpose([0, 2, 3, 1])qkv = paddle.nn.functional.max_pool1d(qkv.reshape([b, -1, self.heads]), self.heads).reshape([b, s, h])return qkvclass GPT(paddle.nn.Layer):def __init__(self, voc_size, hidden_dim, row_layers, lora=False):super(GPT, self).__init__()self.em = paddle.nn.Embedding(voc_size, hidden_dim)self.cv = Attention(hidden_dim, row_layers)self.feed = FeedForward(hidden_dim)self.lora = FeedForward(hidden_dim)self.lora_flag = loraself.out_layer = paddle.nn.Linear(hidden_dim, voc_size, bias_attr=False)self.layer_nor = paddle.nn.LayerNorm(hidden_dim, bias_attr=False)# self.p_next = paddle.to_tensor(list(range(voc_size))).astype("int64").reshape([1, -1])def forward(self, sx):if self.lora_flag:with paddle.no_grad():sx = self.em(sx)sx += self.cv(sx)sx = self.layer_nor(sx)sx += self.feed(sx)sx = self.layer_nor(sx)sx += self.lora(sx)with paddle.no_grad():out = self.out_layer(sx)else:sx = self.em(sx)sx += self.cv(sx)sx = self.layer_nor(sx)sx += self.feed(sx)sx = self.layer_nor(sx)sx += self.lora(sx)out = self.out_layer(sx)return outdef load_lora(self, lora_name):self.lora.load_dict(paddle.load(lora_name))def save_lora(self, lora_name):paddle.save(self.lora.state_dict(), lora_name)def gen_basic_data():seq_len = 32with open("fixed_couplets_in.txt", "r", encoding="utf-8") as f:train_data = f.readlines()with open("fixed_couplets_out.txt", "r", encoding="utf-8") as f:dev_data = f.readlines()train_data = [i.strip().split() for i in tqdm(train_data)]dev_data = [i.strip().split() for i in tqdm(dev_data)]train_data_list = []data_id_index = 0for i, j in tqdm(zip(train_data, dev_data)):one = i + ["。"] + j + list("|_{}_|".format(data_id_index))data_id_index += 1train_data_list += oneseq_len_count = 1with open("train_data_list.txt", "a", encoding="utf-8") as f:voc = dict()for i in tqdm(range(0, len(train_data_list), seq_len)):if i > 0:j = i + seq_lenone = train_data_list[i - seq_len_count:j - seq_len_count]seq_len_count += 1else:j = i + seq_lenone = train_data_list[i:j]if len(one) == seq_len:f.write(str(one) + "\n")for k in one:voc[k] = ""del train_data_listdel train_datadel dev_datavoc = ["<|pad|>"] + list(voc.keys())voc_dict = {k: v for v, k in enumerate(voc)}pd.to_pickle(voc, "voc_data.pandas_pickle")with open("train_data_list.txt", "r", encoding="utf-8") as f:train_data = f.readlines()train_data_list = [[voc_dict[j] for j in eval(i)] for i in tqdm(train_data)]pd.to_pickle(train_data_list, "train_data.pandas_pickle")def train_data():voc_id = pd.read_pickle("voc_data.pandas_pickle")net = GPT(len(voc_id) + 1, 128, 2)loss_func = paddle.nn.CrossEntropyLoss(ignore_index=-1)opt = paddle.optimizer.Adam(learning_rate=0.0001, parameters=net.parameters())bar = tqdm(range(1700))batch_size = 1200data_set = pd.read_pickle("train_data.pandas_pickle")acc_list = []for epoch in bar:np.random.shuffle(data_set)for i in range(0, len(data_set), batch_size):j = i + batch_sizedata = paddle.to_tensor(data_set[i:j]).astype("int64")label = data[:, 1:]input_data = data[:, :-1]out = net(input_data)loss = loss_func(out.reshape([-1, out.shape[-1]]), label.reshape([-1]))acc = paddle.metric.accuracy(out.reshape([-1, len(voc_id) + 1]), label.reshape([-1, 1]))acc_list.append(acc.item())bar.set_description("epoch___{}___step___{}_loss___{:.5f}_acc__{:.5f}__{:.5f}".format(epoch, j, loss.item(),np.mean(acc_list), (paddle.argmax(out, -1) == label).numpy().mean()))opt.clear_grad()loss.backward()opt.step()paddle.save(net.state_dict(), "model_{}.paddle".format(epoch))def train_data_lora(lora_one_name):voc_id = pd.read_pickle("voc_data.pandas_pickle")net = GPT(len(voc_id) + 1, 128, 2, True)net.load_dict(paddle.load("basic.paddle"))loss_func = paddle.nn.CrossEntropyLoss(ignore_index=-1)opt = paddle.optimizer.Adam(learning_rate=0.00001, parameters=net.parameters())bar = tqdm(range(1700))batch_size = 1200data_set = pd.read_pickle("train_data.pandas_pickle")# plt.ion()acc_list = []for epoch in bar:np.random.shuffle(data_set)for i in range(0, len(data_set), batch_size):j = i + batch_sizedata = paddle.to_tensor(data_set[i:j]).astype("int64")label = data[:, -1:]input_data = data[:, :-1]out = net(input_data)loss = loss_func(out.reshape([-1, out.shape[-1]]), label.reshape([-1]))acc = paddle.metric.accuracy(out.reshape([-1, len(voc_id) + 1]), label.reshape([-1, 1]))acc_list.append(acc.item())bar.set_description("epoch___{}___step___{}_loss___{:.5f}_acc__{:.5f}__{:.5f}".format(epoch, j, loss.item(),np.mean(acc_list),(paddle.argmax(out, -1) ==label).numpy().mean()))opt.clear_grad()loss.backward()opt.step()paddle.save(net.lora.state_dict(), "model_{}.paddle".format(lora_one_name))if __name__ == '__main__':# gen_basic_data()train_data()# net = CvFoBlock(256, 2,8)# net(paddle.randn([3, 5, 256]))# eval_data()

解析代码

该代码定义了一个GPT模型,其中包括了三个子模块:FeedForward、Attention和GPT。

FeedForward模块是一个前馈神经网络,用于对输入进行非线性变换。

Attention模块实现了多头注意力机制,用于计算输入的注意力权重。

GPT模块是整个模型的主体,包括了嵌入层、注意力层、前馈层和输出层。它将输入通过嵌入层得到词向量表示,然后通过注意力层和前馈层进行特征提取和转换,最后通过输出层得到最终的预测结果。在其中还包含了LayerNorm层用于归一化输入。

模型的前向传播过程中根据lora_flag标志来决定是否加载预训练的lora模型。如果需要加载lora模型,则在前向传播过程中先进行一次运算,并将结果与输入相加,再进行后续的层运算。最后得到的输出结果经过线性变换得到预测结果。

该模型还提供了加载和保存lora模型的方法,用于在训练过程中保存和加载预训练的lora模型。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/602716.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

C#-数组

数组 (array) 是一种包含若干变量的数据结构&#xff0c;这些变量都可以通过计算索引进行访问。数组中包含的变量&#xff08;又称数组的元素&#xff09;具有相同的类型&#xff0c;该类型称为数组的元素类型。 数组类型为引用类型&#xff0c;因此数组变量的声明只是为数组实…

python 使用hash 给超级多文件高速去重

python 使用hash 给超级多文件高速去重 代码解释代码 import os from glob import globfrom tqdm import tqdm import hashlib def gen_hash(text):# 创建一个SHA256哈希对象hash_object = hashlib.sha256()

C语言基础知识(6):UDP网络编程

UDP 是不具有可靠性的数据报协议。细微的处理它会交给上层的应用去完成。在 UDP 的情况下&#xff0c;虽然可以确保发送消息的大小&#xff0c;却不能保证消息一定会到达。因此&#xff0c;应用有时会根据自己的需要进行重发处理。 1.UDP协议的主要特点&#xff1a; &#xf…

c#调试程序一次启动两个工程(多个工程)

概述 c# - Visual Studio : debug multiple projects at the same time? 以在解决方案中设置多个启动项目(右键单击解决方案&#xff0c;转到设置启动项目&#xff0c;选择多个启动项目)&#xff0c;并为包含在解决方案(无、开始、不调试就开始)。如果您将多个项目设置为开始…

Oracle文件自动“减肥”记

&#x1f4e2;&#x1f4e2;&#x1f4e2;&#x1f4e3;&#x1f4e3;&#x1f4e3; 哈喽&#xff01;大家好&#xff0c;我是【IT邦德】&#xff0c;江湖人称jeames007&#xff0c;10余年DBA及大数据工作经验 一位上进心十足的【大数据领域博主】&#xff01;&#x1f61c;&am…

用通俗易懂的方式讲解:ChatGPT 开放的多模态的DALL-E 3功能,好玩到停不下来!

最近 ChatGPT 对 Plus 用户逐步开放一些多模态的功能&#xff0c;包括 &#xff08;图像生成&#xff09;、 GPT-4V&#xff08;图像识别&#xff09;等&#xff0c;很多网友乐此不疲地对这些新功能进行试用&#xff0c; 目前已经解锁了不少有趣的玩法&#xff0c;我将这些好玩…

C#,入门教程(09)——运算符的基础知识

上一篇&#xff1a; C#&#xff0c;入门教程(08)——基本数据类型及使用的基础知识https://blog.csdn.net/beijinghorn/article/details/123906998 一、算术运算符号 算术运算符号包括&#xff1a;四则运算 加 , 减-, 乘*, 除/与取模%。 // 加法&#xff0c;运算 int va 1 …

CSS3 边框border、outline、box-shadow

1 border 语法&#xff1a;border: width style color 2 outline 语法&#xff1a;outline: width style color 2.1 outline-offet MDN解释&#xff1a;用于设置outline与一个元素边缘或边框之间的间隙 即&#xff1a;设置outline相对border外边缘的偏移&#xff0c;可以为…

C#不会循环响应的Action设计与实现

目录 一、简述二、测试代码三、测试的输出四、核心代码五、其它 一、简述 特点&#xff1a; 不光是能防止直接的死循环调用&#xff1b;还能防止间接的死循环调用&#xff1b;还支持对不同参数判定&#xff0c;不同参数的调用可以不当循环调用&#xff1b; 消息事件系统中必…

SpringBoot 调用mybatis报错:Invalid bound statement (not found):

启动SpringBoot报错&#xff1a;Invalid bound statement (not found): 参考此文排查 命中了第6条 记录一手坑爹的Invalid bound statement (not found)&#xff08;六个方面&#xff09; mapper文件路径配置错误 订正以后 问题解决

如何在Ubuntu安装SVN服务并结合cpolar实现公网TCP地址远程访问本地服务

文章目录 前言1. Ubuntu安装SVN服务2. 修改配置文件2.1 修改svnserve.conf文件2.2 修改passwd文件2.3 修改authz文件 3. 启动svn服务4. 内网穿透4.1 安装cpolar内网穿透4.2 创建隧道映射本地端口 5. 测试公网访问6. 配置固定公网TCP端口地址6.1 保留一个固定的公网TCP端口地址6…

冒泡排序数据结构实验报告

实验目的&#xff1a; 理解冒泡排序算法的原理和基本思路。熟悉冒泡排序在实际应用中的场景和优化方法。 实验内容&#xff08;实验题目与说明&#xff09; 编写一个双向冒泡排序算法&#xff0c;即在排序过程中以交替的正、反两个方向进行遍历。若第一趟把关键字最大的记录…

物联网产品中,终端、网关、协议、PaaS、SaaS之间的关系

在互联网产品中&#xff0c;经常提到的终端、网关、协议、PaaS、SaaS之间&#xff0c;到底有什么关系呢&#xff1f; 一、基本概念 在百度/其他地方搜集的信息中&#xff0c;对于终端、网关、协议、PaaS、SaaS的解释各有不同&#xff0c;整理如下&#xff1a; 终端&#xff1…

SSR 服务器端渲染:提升用户体验的新趋势(上)

&#x1f90d; 前端开发工程师&#xff08;主业&#xff09;、技术博主&#xff08;副业&#xff09;、已过CET6 &#x1f368; 阿珊和她的猫_CSDN个人主页 &#x1f560; 牛客高级专题作者、在牛客打造高质量专栏《前端面试必备》 &#x1f35a; 蓝桥云课签约作者、已在蓝桥云…

第6章-第3节-Java中的字符串缓冲区:StringBuilder和StringBuffer

1、字符串缓冲区 理解&#xff1a;Java内存层面的一款容器(crud操作) 引入场景&#xff1a; 根据需求需要对某字符串内容进行频繁的改动操作&#xff0c; 如果使用String类原生的方式进行处理&#xff0c;则会在内存中产生大量的对象&#xff1b; 面临的问题&…

Generator - JavaScript的异步颠覆者

&#x1f9d1;‍&#x1f393; 个人主页&#xff1a;《爱蹦跶的大A阿》 &#x1f525;当前正在更新专栏&#xff1a;《VUE》 、《JavaScript保姆级教程》、《krpano》 ​ ​ 目录 ✨ 前言 什么是Generator 生成器函数的执行流程控制 异步编程应用 ✨ 结语 ✨ 前言 Java…

Gitee

Gitee码云 0. 笔记说明1. Gitee概述2. Gitee和GitHub3. 创建Git远程仓库4. 分享已有项目到Gitee5. 文件恢复和合并6. 文件push或pull冲突7. 添加项目成员 0. 笔记说明 该笔记以IDEA 2023专业版进行操作需提前注册好个人gitee账号安装好IDEA的相关gitee插件或者安装Git Bash软件…

【机器学习】循环神经网络(二)-LSTM示例(keras)国际航空乘客问题的回归问题...

使用 Keras 在 Python 中使用 LSTM 循环神经网络进行时间序列预测 国际航空乘客问题的回归问题 这个文件是一个CSV格式的数据集&#xff0c;它包含了从1949年1月到1960年12月的每个月的国际航空乘客的总数&#xff08;以千为单位&#xff09;。第一行是列名&#xff0c;分别是&…

Baumer工业相机堡盟工业相机如何通过NEOAPI SDK修改图像像素格式Mono8或者Mono10(C++)

Baumer工业相机堡盟工业相机如何通过NEOAPI SDK修改图像像素格式Mono8或者Mono10&#xff08;C&#xff09; Baumer工业相机Baumer工业相机的图像像素格式的技术背景CameraExplorer如何查看修改相机图像像素格式信息在NEOAPI SDK里通过函数修改图像像素格式修改像素格式测试演示…

二刷Laravel 教程(用户注册)总结Ⅳ

一、显示用户信息 1&#xff09;resource Route::resource(users, UsersController); 相当于下面这7个路由 我们先用 Artisan 命令查看目前应用的路由&#xff1a; php artisan route:list 2&#xff09; compact 方法 //我们将用户对象 $user 通过 compact 方法转化为一个关联…