导出RWKV模型为onnx

测试模型:

https://huggingface.co/RWKV/rwkv-5-world-3b

导出前对modeling_rwkv5.py进行一个修改:

#        out = out.reshape(B * T, H * S)
        out = out.reshape(B * T, H * S, 1) # <<--- modified
        out = F.group_norm(out, num_groups=H, weight=lxw, bias=lxb).reshape(B, T, H * S)

因为目前存pytorch导出onnx在bug,不支持2d输入的group_norm导出。

注意:

rwkv_linear_attention_v5_cpu中使用 for t in range(T):来拆分计算,这导致首次prompt和后续decoding阶段导出的onnx模型结构不一样。这部分需要改进后才能导出同时适用于prompt和decoding的onnx。

if hidden.size(1) == 1这样的判断逻辑也可能导致上述问题。

此外,为了高效的推理,这个rwkv还可以进一步优化,例如state是把按照

        state[1][:, :, :, :, self.layer_id] = layer_state
更新每一层的状态,这种方法比把layer_id放在最外层性能是显著更差的:

        state[1][self.layer_id, :, :, :, :] = layer_state

甚至说可以就像transformer架构模型一样,直接把每一层的layer_state单独存在一个List里面,虽然增加了模型的输入输出个数,但是避免了复杂的ScatterND算子。

导出代码参考(可以尝试device=cpu导出):

import os
import argparse
import torch
from torch import nn
from transformers import AutoModelForCausalLM, AutoTokenizerclass LLMForCausalLMWrapper(nn.Module):def __init__(self, model, config, args):super().__init__()self.model = modelself.config = configself.args = argsdef forward(self,input_ids,state,):outputs = self.model(input_ids=input_ids,state=state,use_cache=True,)logits = outputs.logitsstate_out = outputs.statereturn logits, state_outdef export_llm_to_single_onnx(model, config, dtype, args, model_name):llama_model_wrapper = LLMForCausalLMWrapper(model, config, args)onnx_file_name = os.path.join(args.out_dir, f"{model_name}.onnx")hidden_size = config.hidden_sizelayer_num = config.num_hidden_layershead_num = config.hidden_size // config.num_attention_headshead_hidden_size = config.hidden_size // head_numbatch = 1N = 4input_ids_shape = [batch, N]input_ids = torch.ones(input_ids_shape, dtype=torch.int64).to(args.device)dynamic_axes = {'input_ids': {1: 'N', },}if args.dyn_batch:dynamic_axes['input_ids'][0] = "batch"state_0 = torch.randn([batch, hidden_size, layer_num], dtype=dtype).to(args.device)state_1 = torch.randn([batch, head_num, head_hidden_size, head_hidden_size, layer_num], dtype=dtype).to(args.device)state_2 = torch.randn([batch, hidden_size, layer_num], dtype=dtype).to(args.device)state = [state_0, state_1, state_2]in_names = ["input_ids", "state_0_in", "state_1_in", "state_2_in"]kv_caches_in = []out_names = ["lm_logits", "state_0_out", "state_1_out", "state_2_out"]input_datas = (input_ids, state)torch.onnx.export(llama_model_wrapper,input_datas,onnx_file_name,opset_version=args.opset,do_constant_folding=True,input_names=in_names,output_names=out_names,dynamic_axes=dynamic_axes,)def export_rwkv(args):device = args.devicedtype_map = {"float32": torch.float32,"float16": torch.float16,"bfloat16": torch.bfloat16,}dtype = dtype_map[args.dtype]print(f"begin load model from {args.model_path}")model = AutoModelForCausalLM.from_pretrained(args.model_path, device_map=device, torch_dtype=dtype, trust_remote_code=True).eval()model.rwkv.blocks = model.rwkv.blocks[:1]  # only export few layer for debugprint(f"finish load model from {args.model_path}")config = model.configprint("config:", config)print(f"begin export llm")export_llm_to_single_onnx(model, config, dtype, args, "llm_onnx")if __name__ == "__main__":parser = argparse.ArgumentParser(description='export llm',)parser.add_argument('-m', '--model_path', required=True, type=str)parser.add_argument('-o', '--out_dir', required=False, type=str, default="")parser.add_argument('--opset', required=False, type=int, default=15)parser.add_argument('-d', '--device', required=False, type=str, choices=["cpu", "cuda"], default="cuda")parser.add_argument('-p', '--dtype', required=False, type=str,choices=["float32", "float16", "bfloat16"], default="float16")parser.add_argument('--add_topk_warper', required=False, type=int, default=0)parser.add_argument('--topk', required=False, type=int, default=4)parser.add_argument('--dyn_batch', action='store_true')args = parser.parse_args()export_rwkv(args)

导出其他模型和对大模型进行onnxsim参考:

GitHub - luchangli03/export_llama_to_onnx: export llama to onnx

GitHub - luchangli03/onnxsim_large_model: simplify >2GB large onnx model

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/736956.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

vue element plus Border 边框

我们对边框进行统一规范&#xff0c;可用于按钮、卡片、弹窗等组件里。 边框样式# 我们提供了以下几种边框样式&#xff0c;以供选择。 NameThicknessDemoSolid1pxDashed2px 圆角# 我们提供了以下几种圆角样式&#xff0c;以供选择。 No Radius border-radius: 0px Small …

ros | 如何在ubuntu上安装ros

只能说小鱼真的强&#xff01;&#xff01;&#xff01; wget http://fishros.com/install -O fishros && bash fishros一行命令安装 至于版本下面这个可能不太准确&#xff0c;自己建议再去查以下 Ubuntu ROS1 14.04 LTS indigo 16.04 LTS Kinetic 18.…

【竞技宝】LOL:knight阿狸伤害爆炸 BLG2-0轻取RA

北京时间2024年3月11日,英雄联盟LPL2024春季常规赛继续进行,昨日共进行三场比赛,首场比赛由BLG对阵RA。本场比赛BLG选手个人实力碾压RA2-0轻松击败对手。以下是本场比赛的详细战报。 第一局: BLG:剑魔、千珏、妮蔻、卡牌、洛 RA:乌迪尔、蔚、阿卡丽、斯莫德、芮尔 首局比赛,B…

QMS质量管理系统在离散型制造业的应用及效益

在离散型制造行业中&#xff0c;质量管理是确保产品满足客户期望和市场需求的关键环节。QMS质量管理系统的实施为企业提供了一种全面、系统的方法来管理和改进质量。 例如&#xff0c;在汽车制造行业&#xff0c;QMS质量管理系统可以应用于零部件采购和装配过程的质量控制。通过…

语音合成(TTS) 声音生成(TTA)最新技术 - 2024- 附论文地址和代码地址

文章目录 1. 我们的模型2. 声音生成模型&#xff1a;AudioLDM3. 语音合成模型&#xff1a;VoiceLDM 生成式 AI 是最近一年最受关注的课题&#xff0c;可以应用于游戏、虚拟现实等智能交互场景。 1. 我们的模型 由中国科学院计算所和东芝中国研发中心联合发表于AAAI 2024 论文…

Spring Cloud GateWay整合熔断器实现限流

其实网关是很强大&#xff0c;能做的事情很多&#xff0c;包含很多过滤器包括限流&#xff0c;具体的网关可以参考我的另外一篇博文Spring Cloud GateWay-过滤器 今天我们来说下网关如何限流&#xff0c;主要两种方案&#xff1a; Spring Cloud GateWay整合hystrx environme…

PCBA方案设计充气泵设计

随着科技的不断进步&#xff0c;充气泵在户外活动、露营和旅行中变得越来越常见。而充气泵的性能和稳定性主要依赖于其控制系统&#xff0c;其中芯片的设计和开发是充气泵方案的关键。SIC8833芯片是一款专门为充气泵设计的芯片&#xff0c;接下来我们来讲下充气泵方案芯片SIC88…

计网Lesson19 - 应用层之WWW

文章目录 1. 应用层概念1.1 应用层协议的定义 2. 应用层体系结构2.1 客户/服务器模式2.2 P2P模式&#xff08;点对点方式&#xff09; 3. 万维网3.1 概述3.2 统一资源定位符3.3 万维文档HTMLCSSJavaScript 1. 应用层概念 应用层是计算机网络体系结构的最顶层&#xff0c;是设计…

面试官:Spring Boot 是否可以使用 XML 配置?如果可以的话怎么配置

该文章专注于面试,面试只要回答关键点即可,不需要对框架有非常深入的回答,如果你想应付面试,是足够了,抓住关键点 面试官:Spring Boot 是否可以使用 XML 配置?如果可以的话怎么配置 Spring Boot 主要推崇使用注解驱动的 Java 配置方式,尤其是通过 @Configuration、@C…

USACO 2023 December, SilverProblem 1. Bovine Acrobatics题解

有n件物品&#xff0c;m组叠罗汉&#xff0c;相邻罗汉差值至少为k。 第i件物品的重量和数量 由于m最大范围为1e9&#xff0c;开辟m组罗汉槽存储罗汉值&#xff0c;内存空间不够。 分析样例&#xff1a; 3 5 2 9 4 7 65 5 一开始我想的是层数&#xff0c;但是一层中存在数据…

初学Vue——Vue路由

0 什么是Vue路由 类似于Html中的超链接(<a>)一样&#xff0c;可以跳转页面的一种方式。 前端路由&#xff1a;URL中hash(#号之后的内容)与组件之间的对应关系&#xff0c;如下图&#xff1a; 当我们点击左侧导航栏时&#xff0c;浏览器的地址栏会发生变化&#xff0c;路…

python的数据容器--元组

元组的定义 #元组的定义 t1 (1,"hello",True) t2() t3tuple() t4("hello",) print(f"t4的数据类型{type(t4)},t4的内容是{t4}")#元组的嵌套使用 t5((1,2,3),(4,5,6)) print(f"t5的类型是{type(t5)}")#元组的元素提取 print(f"{t…

【Prometheus】k8s集群部署node-exporter

​ 目录 一、概述 1.1 prometheus简介 1.2 prometheus架构图 1.3 Exporter介绍 1.4 监控指标 1.5 参数定义 1.6 默认启用的参数 1.7 prometheus如何收集k8s/服务的–三种方式收集 二、安装node-exporter组件 【Prometheus】概念和工作原理介绍-CSDN博客 【云原生】ku…

借助 Terraform 功能协调部署 CI/CD 流水线-Part2

在第一部分的文章中&#xff0c;我们介绍了3个步骤&#xff0c;完成了教程的基础配置&#xff1a; 使用 Terraform 创建 AWS EKS Infra在 EKS 集群上部署 ArgoCD 及其依赖项设置 Bitbucket Pipeline并部署到 ECR Repo 本文将继续完成剩余的步骤&#xff0c;以实现 Terraform 编…

windows下pytorch的dataloader多进程(num_workers)问题,为何num_workers的值只能为0?

问题背景介绍 本人是windows系统&#xff0c;在使用torch.utils.data.Dataloader加载torchvision中的数据集时&#xff0c;将其中的形参num_workers设置为了大于0的数&#xff0c;然后出现以下错误。 原因 在 Windows 系统下&#xff0c;num_workers 参数在使用 PyTorch 的 t…

做58代运营到底有没有效果?

58同城代运营有没有效果&#xff0c;其实还是取决于你的产品、预算和想要达到的效果。比如&#xff0c;你是希望通过代运营做出品牌效应&#xff0c;还是希望提升销量&#xff1f;代运营能够带来的效果&#xff0c;主要是咨询和品牌推广&#xff0c;这两个都可以量化&#xff0…

解决Iterm2升级后遇到“Stashed changes“的问题

&#xff1c;&#xff1c;&#xff1c;&#xff1c;&#xff1c;&#xff1c;&#xff1c; Updated upstream ...... &#xff1e;&#xff1e;&#xff1e;&#xff1e;&#xff1e;&#xff1e;&#xff1e; Stashed changes冲突标记符的代码如题&#xff0c;最近有升级Item2…

终于搞懂lSTM的原理了

LSTM简介 一个目前很火的特殊的RNN结构&#xff0c; 有效解决了RNN的梯度爆炸和长序列记忆问题 优势 LSTM 通过引入遗忘门、输入门、输出门&#xff0c; 来实现对特殊特征的记忆和遗忘&#xff0c;来达到更好的对序列数据的处理和记忆效果。 原理图&#xff1a; 总结公式…

校园小情书微信小程序源码 | 社区小程序前后端开源 | 校园表白墙交友小程序

项目描述&#xff1a; 校园小情书微信小程序源码 | 社区小程序前后端开源 | 校园表白墙交友小程序 功能介绍&#xff1a; 表白墙 卖舍友 步数旅行 步数排行榜 情侣脸 漫画脸 个人主页 私信 站内消息 今日话题 评论点赞收藏 服务器环境要求&#xff1a;PHP7.0 MySQL5.7 效果…

Java设计模式-策略模式

策略模式1 概述2 结构3 案例实现4 优缺点5 使用场景6 JDK源码解析 策略模式 1 概述 先看下面的图片&#xff0c;我们去旅游选择出行模式有很多种&#xff0c;可以骑自行车、可以坐汽车、可以坐火车、可以坐飞机。 作为一个程序猿&#xff0c;开发需要选择一款开发工具&#x…