纯c++实现transformer 训练+推理

项目地址

https://github.com/freelw/cpp-transformer

C++ 实现的 Transformer

这是一个无需依赖特殊库的 Transformer 的 C++ 实现,涵盖了训练与推理功能。

本项目使用C++复刻了《Dive into Deep Learning》中关于 Transformer 的第 11 章11.7小节点内容。构建了一个英法机器翻译模型。本项目自主开发了自动求导框架,仅依赖 C++ 标准库,旨在助力用户理解 Transformer 的底层原理。

项目亮点

注重原理

从基础操作入手构建模型,不依赖深度学习框架。这种方式清晰地展示了 Transformer 的运行机制。

自动求导

自主研发的自动求导框架简化了梯度计算流程,有助于更好地理解反向传播算法。

低依赖性

该项目仅依赖 C++ 标准库。尽管其性能可能不如那些使用高级库的项目,但它清晰呈现了每一个计算细节。这一特性使用户能够深入理解反向传播算法以及 Transformer 架构的底层原理。

快速开始

构建

./build_all.sh 

测试推理翻译

./test_translation.sh

输出

./test_translation.sh 
~/project/cpp-transformer/checkpoints/save ~/project/cpp-transformer
~/project/cpp-transformer
OMP_THREADS: 8
epochs : 0
dropout : 0.2
lr : 0.001
tiny : 0
data loaded
warmUp done
parameter size = 21388
all parameters require_grad = true
loading from checkpoint : ./checkpoints/save/checkpoint_20250402_150847_40.bin
loaded from checkpoint
serving mode
go now . <eos> 
translate res : <bos> allez-y maintenant maintenant maintenant . <eos> 
i try . <eos> 
translate res : <bos> j'essaye . <eos> 
cheers ! <eos> 
translate res : <bos> santé ! <eos> 
get up . <eos> 
translate res : <bos> lève-toi . <eos> 
hug me . <eos> 
translate res : <bos> <unk> dans vos bras ! <eos> 
i know . <eos> 
translate res : <bos> je sais . <eos> 
no way ! <eos> 
translate res : <bos> en aucune manière ! <eos> 
be nice . <eos> 
translate res : <bos> soyez gentille ! <eos> 
i jumped . <eos> 
translate res : <bos> j'ai sauté . <eos> 
congratulations ! <eos> 
translate res : <bos> à ! <eos> 

测试训练

在tiny训练集上进行训练(300句英法对照语料)

./train_tiny.sh

输出

./train_tiny.sh 
OMP_THREADS: 8
epochs : 10
dropout : 0.2
lr : 0.001
tiny : 0
data loaded
warmUp done
parameter size = 21388
all parameters require_grad = true
[300/300]checkpoint saved : ./checkpoints/checkpoint_20250402_164906_0.bin
epoch 0 loss : 9.0757 emit_clip : 3
[300/300]epoch 1 loss : 7.90043 emit_clip : 3
[300/300]epoch 2 loss : 6.8447 emit_clip : 3
[300/300]epoch 3 loss : 5.85042 emit_clip : 3
[300/300]epoch 4 loss : 5.00354 emit_clip : 3
[300/300]epoch 5 loss : 4.38405 emit_clip : 3
[300/300]epoch 6 loss : 3.96133 emit_clip : 3
[300/300]epoch 7 loss : 3.70218 emit_clip : 3
[300/300]epoch 8 loss : 3.51153 emit_clip : 3
[300/300]checkpoint saved : ./checkpoints/checkpoint_20250402_164906_9.bin
epoch 9 loss : 3.35273 emit_clip : 3

代码片段一览

前向

以PositionwiseFFN举例,我们只要声明前向过程即可,框架会自动生成计算图,在调用backward时自动求导

autograd::Node *PositionwiseFFN::forward(autograd::Node *x) {return dense2->forward(dense1->forward(x)->Relu());
}

自动求导实现

以矩阵乘法为例,在node.h node.cpp中,乘法会生成一个结果节点,关联两条边到两个乘数。

    Node *Node::operator*(Node *rhs) {auto *node = allocNode(*w * *(rhs->w));if (is_require_grad() || rhs->is_require_grad()) {node->require_grad();if (is_require_grad()) {node->edges.push_back(MulEdge::create(this, rhs->get_weight()));}if (rhs->is_require_grad()) {node->edges.push_back(MulEdge::create(rhs, w));}}return node;}

在边中实现梯度的反向传播,注意左边和右边的操作方式不同(是否需要专置)

    class MatMulLEdge : public Edge {public:static Edge* create(Node *_node, Matrix *_param) {Edge *edge = new MatMulLEdge(_node, _param);edges.push_back(edge);return edge;}MatMulLEdge(Node *_node, Matrix *_param): Edge(MatMulL, _node), param(_param) {}virtual ~MatMulLEdge() {}void backward(Matrix *grad) override {assert(node->is_require_grad());// *node->get_grad() is grad of W*node->get_grad() += *(grad->at(*(param->transpose())));}private:Matrix *param; // Input Vector};class MatMulREdge : public Edge {public:static Edge* create(Node *_node, Matrix *_param) {Edge *edge = new MatMulREdge(_node, _param);edges.push_back(edge);return edge;}MatMulREdge(Node *_node, Matrix *_param): Edge(MatMulR, _node), param(_param) {}virtual ~MatMulREdge() {}void backward(Matrix *grad) override {assert(node->is_require_grad());// *node->get_grad() is grad of Input*node->get_grad() += *(param->transpose()->at(*grad));}private:Matrix *param; // W};

训练

在main.cpp train函数中,逻辑和pytorch类似,都是要将模型的所有parameters引用/指针传递给优化器,然后依次清理grad,反向传播,裁剪梯度,执行权重调整

            auto loss = dec_outputs->CrossEntropyMask(labels, mask);assert(loss->get_weight()->getShape().rowCnt == 1);assert(loss->get_weight()->getShape().colCnt == 1);loss_sum += (*loss->get_weight())[0][0];adam.zero_grad();loss->backward();if (adam.clip_grad(1)) {emit_clip++;}adam.step();

反向传播梯度公式推导

主要三个比较复杂的层 softmax交叉熵 softmax layernorm

https://github.com/freelw/cpp-transformer/blob/main/doc/equations/readme.md

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/76515.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Go 语言规范学习(7)

文章目录 Built-in functionsAppending to and copying slicesClearCloseManipulating complex numbersDeletion of map elementsLength and capacityMaking slices, maps and channelsMin and maxAllocationHandling panicsBootstrapping PackagesSource file organizationPac…

Python Cookbook-5.1 对字典排序

任务 你想对字典排序。这可能意味着需要先根据字典的键排序&#xff0c;然后再让对应值也处于同样的顺序。 解决方案 最简单的方法可以通过这样的描述来概括:先将键排序&#xff0c;然后由此选出对应值: def sortedDictValues(adict):keys adict.keys()keys.sort()return …

Git Rebase 操作中丢失提交的恢复方法

背景介绍 在团队协作中,使用 Git 进行版本控制是常见实践。然而,有时在执行 git rebase 或者其他操作后,我们可能会发现自己的提交记录"消失"了,这往往让开发者感到恐慌。本文将介绍几种在 rebase 后恢复丢失提交的方法。 问题描述 当我们执行以下操作时,可能…

C语言基础要素(019):输出ASCII码表

计算机以二进制处理信息&#xff0c;但二进制对人类并不友好。比如说我们规定用二进制值 01000001 表示字母’A’&#xff0c;显然通过键盘输入或屏幕阅读此数据而理解它为字母A&#xff0c;是比较困难的。为了有效的使用信息&#xff0c;先驱者们创建了一种称为ASCII码的交换代…

鸿蒙定位开发服务

引言 鸿蒙操作系统&#xff08;HarmonyOS&#xff09;作为面向万物互联时代的分布式操作系统&#xff0c;其定位服务&#xff08;Location Kit&#xff09;为开发者提供了多场景、高精度的位置能力支持。本文将从技术原理、开发流程到实战案例&#xff0c;全面解析鸿蒙定位服务…

rknn_convert的使用方法

rknn_convert是RKNN-Toolkit2提供的一套常用模型转换工具&#xff0c;通过封装上述API接口&#xff0c;用户只需编辑模型对应的yml配置文件&#xff0c;就可以通过指令转换模型。以下是如何使用rknn_convert工具的示例命令以及支持的指令参数&#xff1a; python -m rknn.api.…

解决 axios get请求瞎转义问题

在Vue.js项目中&#xff0c;axios 是一个常用的HTTP客户端库&#xff0c;用于发送HTTP请求。qs 是一个用于处理查询字符串的库&#xff0c;通常与 axios 结合使用&#xff0c;特别是在处理POST请求时&#xff0c;将对象序列化为URL编码的字符串。 1. 安装 axios 和 qs 首先&a…

【XTerminal】【树莓派】Linux系统下的函数调用编程

目录 一、XTerminal下的Linux系统调用编程 1.1理解进程和线程的概念并在Linux系统下完成相应操作 (1) 进程 (2)线程 (3) 进程 vs 线程 (4)Linux 下的实践操作 1.2Linux的“虚拟内存管理”和stm32正式物理内存&#xff08;内存映射&#xff09;的区别 (1)Linux虚拟内存管…

torch 拆分子张量 分割张量

目录 unbind拆分子张量 1. 沿着第n个维度拆分&#xff08;即按“批次”拆分&#xff09; split分割张量 常用用法&#xff1a; 总结&#xff1a; unbind拆分子张量 import torchquaternions torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]]) result torch.unbind(quaternio…

【Linux】内核驱动学习笔记(二)

7、framebuffer驱动详解 7.1、什么是framebuffer (1)裸机中如何操作LCD (2)OS下操作LCD的难点 (3)framebuffer帧缓冲&#xff08;简称fb&#xff09;是linux内核中虚拟出的一个设备 (4)framebuffer向应用层提供一个统一标准接口的显示设备 (5)从驱动来看&#xff0c;fb是一个…

用 Docker Compose 与 Nginx 反向代理部署 Vikunja 待办事项管理系统

在高效管理日常任务和项目的过程中&#xff0c;开源待办事项工具 Vikunja 以其简洁、直观的设计和多视图支持受到越来越多用户的青睐。本文将详细介绍如何使用 Docker Compose 快速部署 Vikunja&#xff0c;并通过 Nginx 反向代理实现 HTTPS 访问&#xff0c;从而确保服务安全稳…

使用Python快速接入DeepSeek API的步骤指南

使用Python快速接入DeepSeek API的步骤指南 1. 前期准备 注册DeepSeek账号 访问DeepSeek官网注册账号 完成邮箱验证等认证流程 获取API密钥 登录后进入控制台 → API管理 创建新的API Key并妥善保存 安装必要库 pip install requests # 可选&#xff1a;处理复杂场景 pip…

Redis 主要能够用来做什么

Redis&#xff08;Remote Dictionary Server&#xff09;是一种基于内存的键值存储数据库&#xff0c;它的性能极高&#xff0c;广泛应用于各种高并发场景。以下是 Redis 常见的用途&#xff1a; 1. 缓存&#xff08;Cache&#xff09; 作用&#xff1a;存储热点数据&#xf…

印度股票实时数据API接口选型指南:iTick.org如何成为开发者优选

在全球金融数字化浪潮中&#xff0c;印度股票市场因其高速增长潜力备受关注。对于量化交易开发者、金融科技公司而言&#xff0c;稳定可靠的股票报价API接口是获取市场数据的核心基础设施。本文将深度对比主流印度股票API&#xff0c;并揭示iTick在数据服务领域的独特优势。 一…

24.多路转接-poll

poll也是一种linux中的多路转接的方案 解决select的fd有上限的问题解决select每次调用都要重新设置关心的fd poll函数接口 poll, ppoll - wait for some event on a file descriptor#include <poll.h>int poll(struct pollfd *fds, nfds_t nfds, int timeout);DESCRIP…

Linux 基础入门操作 前言 linux操作指令介绍

1 linux 目录介绍 Linux 文件系统采用层次化的目录结构&#xff0c;所有目录都从根目录 / 开始 1.1 核心目录 / (根目录) 整个文件系统的起点、包含所有其他目录和文件 /bin (基本命令二进制文件) 存放系统最基本的shell命令&#xff1a;如 ls, cp, mv, rm, cat 等&#…

Chrome开发者工具实战:调试三剑客

在前端开发的世界里&#xff0c;Chrome开发者工具就是我们的瑞士军刀&#xff0c;它集成了各种强大的功能&#xff0c;帮助我们快速定位和解决代码中的问题。今天&#xff0c;就让我们一起来看看如何使用Chrome开发者工具中的“调试三剑客”&#xff1a;断点调试、调用栈跟踪和…

函数柯里化(Currying)介绍(一种将接受多个参数的函数转换为一系列接受单一参数的函数的技术)

文章目录 柯里化的特点示例普通函数柯里化实现使用Lodash进行柯里化 应用场景总结 函数柯里化&#xff08;Currying&#xff09;是一种将接受多个参数的函数转换为一系列接受单一参数的函数的技术。换句话说&#xff0c;柯里化将一个多参数函数转化为一系列嵌套的单参数函数。 …

torch.nn中的非线性激活介绍合集——Pytorch中的非线性激活

1、nn.ELU 基本语法&#xff1a; class torch.nn.ELU(alpha1.0, inplaceFalse)按元素应用 Exponential Linear Unit &#xff08;ELU&#xff09; 函数。 论文中描述的方法&#xff1a;通过指数线性单元 &#xff08;ELU&#xff09; 进行快速准确的深度网络学习。 ELU 定义为…

Databend Cloud Dashboard 全新升级:直击痛点,释放数据价值

自 Databend Cloud 上线以来&#xff0c;我们一直致力于为用户提供高效的数据处理与可视化体验。早期&#xff0c;我们在工作区的“图表”区域推出了轻量级可视化功能&#xff0c;支持积分卡、饼图、柱状图和折线图四种展示方式。这些功能简单易用&#xff0c;基本满足了用户对…