Test model:
https://huggingface.co/RWKV/rwkv-5-world-3b
Before exporting, make one modification to modeling_rwkv5.py:
# out = out.reshape(B * T, H * S)
out = out.reshape(B * T, H * S, 1) # <<--- modified
out = F.group_norm(out, num_groups=H, weight=lxw, bias=lxb).reshape(B, T, H * S)
This is needed because PyTorch's ONNX export currently has a bug: group_norm with a 2D input cannot be exported, so the input is reshaped to 3D first.
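A minimal repro of the issue (the toy module GN below is illustrative, not part of modeling_rwkv5.py, and the exact failure may vary across PyTorch versions): exporting F.group_norm on a 2D (N, C) tensor trips the bug, while adding a trailing dimension of size 1, as in the patch above, exports cleanly.

import torch
import torch.nn.functional as F

class GN(torch.nn.Module):
    def __init__(self, num_groups, channels):
        super().__init__()
        self.num_groups = num_groups
        self.weight = torch.nn.Parameter(torch.ones(channels))
        self.bias = torch.nn.Parameter(torch.zeros(channels))

    def forward(self, x):  # x: (N, C)
        # Workaround: reshape (N, C) -> (N, C, 1) so the exporter accepts it,
        # then drop the trailing dim again.
        out = F.group_norm(x.reshape(x.shape[0], x.shape[1], 1),
                           num_groups=self.num_groups,
                           weight=self.weight, bias=self.bias)
        return out.reshape(x.shape[0], x.shape[1])

torch.onnx.export(GN(4, 64), torch.randn(8, 64), "gn.onnx", opset_version=15)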
Note:
rwkv_linear_attention_v5_cpu uses for t in range(T): to split the computation along the time dimension. Because this loop is unrolled at trace time, the ONNX model exported for the first prompt (prefill) stage has a different graph structure from the one exported for subsequent decoding steps; this part needs to be reworked before a single ONNX model can serve both prompt and decoding.
Shape-dependent branches such as if hidden.size(1) == 1 can cause the same problem.
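The underlying reason is that torch.onnx.export traces the model with one concrete example input: a Python-level loop over T is unrolled into T copies of its body, and a Python-level if on a tensor shape keeps only the branch that the example input happens to take. A toy illustration of both effects (not the actual RWKV code):

import torch

class Toy(torch.nn.Module):
    def forward(self, x):
        B, T, C = x.shape
        if T == 1:  # frozen at trace time: only one branch survives
            return x * 2.0
        ys = []
        for t in range(T):  # unrolled at trace time into T copies of the body
            ys.append(x[:, t] + 1.0)
        return torch.stack(ys, dim=1)

# Exporting with T=4 bakes a 4-step unrolled graph; exporting with T=1
# takes the other branch entirely. The two ONNX files differ in structure,
# which is why one exported model cannot serve both prefill and decoding here.
torch.onnx.export(Toy(), torch.randn(1, 4, 8), "prefill.onnx", opset_version=15)
torch.onnx.export(Toy(), torch.randn(1, 1, 8), "decode.onnx", opset_version=15)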
In addition, this RWKV model can be optimized further for efficient inference. For example, each layer currently updates its state by scattering into the last dimension of a stacked tensor:
state[1][:, :, :, :, self.layer_id] = layer_state
This performs significantly worse than putting layer_id on the outermost dimension:
state[1][self.layer_id, :, :, :, :] = layer_state
Better still, each layer's layer_state can be stored as a separate entry of a list, just as transformer-architecture models handle their KV caches; this increases the number of model inputs and outputs, but avoids the costly ScatterND operator entirely (see the sketch below).
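A minimal sketch of that per-layer layout (ToyBlock/ToyModel are illustrative stand-ins, not the real RWKV blocks): each layer reads and writes its own tensor, so the exported graph wires every layer's state straight to a dedicated output instead of scattering into one big stacked tensor.

import torch
from torch import nn

class ToyBlock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, hidden, layer_state):
        new_state = layer_state + hidden.mean(dim=1)  # toy recurrence
        return self.proj(hidden), new_state

class ToyModel(nn.Module):
    def __init__(self, dim, layer_num):
        super().__init__()
        self.blocks = nn.ModuleList(ToyBlock(dim) for _ in range(layer_num))

    def forward(self, hidden, states_in):
        # One tensor per layer: each update is a plain per-layer output,
        # so no ScatterND appears in the exported graph.
        states_out = []
        for block, s in zip(self.blocks, states_in):
            hidden, s = block(hidden, s)
            states_out.append(s)
        return hidden, states_out

dim, layer_num = 8, 3
model = ToyModel(dim, layer_num)
states = [torch.zeros(1, dim) for _ in range(layer_num)]
torch.onnx.export(
    model, (torch.randn(1, 4, dim), states), "toy_states.onnx",
    input_names=["hidden"] + [f"state_{i}_in" for i in range(layer_num)],
    output_names=["out"] + [f"state_{i}_out" for i in range(layer_num)],
    opset_version=15,
)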
Export code for reference (you can try exporting with device=cpu):
import os
import argparse
import torch
from torch import nn
from transformers import AutoModelForCausalLM, AutoTokenizer


class LLMForCausalLMWrapper(nn.Module):
    def __init__(self, model, config, args):
        super().__init__()
        self.model = model
        self.config = config
        self.args = args

    def forward(self, input_ids, state):
        outputs = self.model(
            input_ids=input_ids,
            state=state,
            use_cache=True,
        )
        logits = outputs.logits
        state_out = outputs.state
        return logits, state_out


def export_llm_to_single_onnx(model, config, dtype, args, model_name):
    llama_model_wrapper = LLMForCausalLMWrapper(model, config, args)

    onnx_file_name = os.path.join(args.out_dir, f"{model_name}.onnx")

    hidden_size = config.hidden_size
    layer_num = config.num_hidden_layers
    head_num = config.hidden_size // config.num_attention_heads
    head_hidden_size = config.hidden_size // head_num

    batch = 1
    N = 4
    input_ids_shape = [batch, N]
    input_ids = torch.ones(input_ids_shape, dtype=torch.int64).to(args.device)

    dynamic_axes = {
        'input_ids': {1: 'N'},
    }
    if args.dyn_batch:
        dynamic_axes['input_ids'][0] = "batch"

    state_0 = torch.randn([batch, hidden_size, layer_num], dtype=dtype).to(args.device)
    state_1 = torch.randn([batch, head_num, head_hidden_size, head_hidden_size, layer_num],
                          dtype=dtype).to(args.device)
    state_2 = torch.randn([batch, hidden_size, layer_num], dtype=dtype).to(args.device)
    state = [state_0, state_1, state_2]

    in_names = ["input_ids", "state_0_in", "state_1_in", "state_2_in"]
    out_names = ["lm_logits", "state_0_out", "state_1_out", "state_2_out"]

    input_datas = (input_ids, state)

    torch.onnx.export(
        llama_model_wrapper,
        input_datas,
        onnx_file_name,
        opset_version=args.opset,
        do_constant_folding=True,
        input_names=in_names,
        output_names=out_names,
        dynamic_axes=dynamic_axes,
    )


def export_rwkv(args):
    device = args.device
    dtype_map = {
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
    }
    dtype = dtype_map[args.dtype]

    print(f"begin load model from {args.model_path}")
    model = AutoModelForCausalLM.from_pretrained(
        args.model_path, device_map=device, torch_dtype=dtype, trust_remote_code=True).eval()
    model.rwkv.blocks = model.rwkv.blocks[:1]  # export only the first block for debugging; remove for a full export
    print(f"finish load model from {args.model_path}")

    config = model.config
    print("config:", config)

    print("begin export llm")
    export_llm_to_single_onnx(model, config, dtype, args, "llm_onnx")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='export llm')
    parser.add_argument('-m', '--model_path', required=True, type=str)
    parser.add_argument('-o', '--out_dir', required=False, type=str, default="")
    parser.add_argument('--opset', required=False, type=int, default=15)
    parser.add_argument('-d', '--device', required=False, type=str,
                        choices=["cpu", "cuda"], default="cuda")
    parser.add_argument('-p', '--dtype', required=False, type=str,
                        choices=["float32", "float16", "bfloat16"], default="float16")
    parser.add_argument('--add_topk_warper', required=False, type=int, default=0)
    parser.add_argument('--topk', required=False, type=int, default=4)
    parser.add_argument('--dyn_batch', action='store_true')

    args = parser.parse_args()
    export_rwkv(args)
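Assuming the script above is saved as export_rwkv.py (file name and paths are illustrative), a typical CPU export would be:

python export_rwkv.py -m ./rwkv-5-world-3b -o ./out -d cpu -p float32 --opset 15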
For exporting other models, and for running onnxsim on large models, see:
https://github.com/luchangli03/export_llama_to_onnx (export llama to onnx)
https://github.com/luchangli03/onnxsim_large_model (simplify >2GB large onnx model)
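For models under the 2 GB protobuf limit, the stock onnxsim command line is enough; the onnxsim_large_model repo above exists precisely because protobuf rejects single ONNX files larger than 2 GB, so for the full-size model use that repo's entry point instead (see its README for the exact invocation).

pip install onnxsim
onnxsim llm_onnx.onnx llm_onnx_sim.onnx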