导出BERT句子模型为ONNX并推理

在深度学习中,将模型导出为ONNX(Open Neural Network Exchange)格式并利用ONNX进行推理是提高推理速度和模型兼容性的一种常见做法。本文将介绍如何将BERT句子模型导出为ONNX格式,并使用ONNX Runtime进行推理,具体以中文文本处理为例。

1. 什么是ONNX?

ONNX 是一种开放的神经网络交换格式,旨在促进深度学习模型在不同平台和工具之间的共享和移植。它支持包括PyTorch、TensorFlow等多种主流框架,可以通过ONNX Runtime库高效推理。通过将模型转换为ONNX格式,我们可以获得跨平台部署的优势,并利用ONNX Runtime加速推理过程。

2. 准备工作

在导出和推理之前,需要安装以下库:

pip install torch transformers onnx onnxruntime

3. 导出BERT句子模型为ONNX

首先,我们将使用HuggingFace的transformers库加载一个预训练的BERT句子模型(text2vec-base-chinese),然后将其导出为ONNX格式。以下是导出模型的步骤和代码:

3.1 导出模型的代码

import torch
from transformers import BertTokenizer, BertModel# 加载预训练的BERT模型和分词器
tokenizer = BertTokenizer.from_pretrained('shibing624/text2vec-base-chinese')
model = BertModel.from_pretrained('shibing624/text2vec-base-chinese')# 读取要处理的句子
with open("corpus/words_nlu.txt", 'rt', encoding='utf-8') as f:nlu_words = [line.strip() for line in f.readlines()]
nlu_words.insert(0, "摄像头打开一下")  # 插入要比较的句子# 对句子进行编码
encoded_input = tokenizer(nlu_words, padding=True, truncation=True, return_tensors='pt')# 设置ONNX模型的保存路径
onnx_model_path = "text2vec-base-chinese.onnx"
model.eval()# 导出模型为ONNX格式
with torch.no_grad():torch.onnx.export(model,(encoded_input['input_ids'], encoded_input['attention_mask']),onnx_model_path,input_names=['input_ids', 'attention_mask'],output_names=['last_hidden_state'],opset_version=14,dynamic_axes={'input_ids': {0: 'batch_size', 1: 'sequence_length'},'attention_mask': {0: 'batch_size', 1: 'sequence_length'},'last_hidden_state': {0: 'batch_size', 1: 'sequence_length'}})
print(f"ONNX模型已导出到 {onnx_model_path}")

在这段代码中,我们将text2vec-base-chinese模型导出为ONNX格式,指定了输入和输出的名称,并使用了动态轴设置(如批大小和序列长度),这样可以处理不同长度的句子。

4. 使用ONNX进行推理

导出模型后,我们可以使用ONNX Runtime进行推理。以下是基于ONNX的推理代码。该代码实现了对输入文本进行预处理、调用ONNX模型进行推理、以及对模型输出进行均值池化处理。

4.1 ONNX推理代码

import numpy as np
from onnxruntime import InferenceSessionclass PIPE_NLU:def __init__(self, model_path="text2vec-base-chinese.onnx", vocab_path="vocab.txt") -> None:self.model_path = model_pathself.vocab_path = vocab_pathself.vocab = self.load_vocab(vocab_path)self.onnx_session = InferenceSession(model_path)print("成功加载NLU解码器")def load_vocab(self, vocab_path):"""加载BERT词汇表"""vocab = {}with open(vocab_path, 'r', encoding='utf-8') as f:for idx, line in enumerate(f):token = line.strip()vocab[token] = idxreturn vocabdef tokenize(self, text):"""将文本分词为BERT的input_ids"""tokens = ['[CLS]']for char in text:if char in self.vocab:tokens.append(char)else:tokens.append('[UNK]')tokens.append('[SEP]')input_ids = [self.vocab[token] if token in self.vocab else self.vocab['[UNK]'] for token in tokens]return input_idsdef preprocess(self, texts, max_length=128):"""对输入文本进行预处理"""input_ids_list = []attention_mask_list = []for text in texts:input_ids = self.tokenize(text)if len(input_ids) > max_length:input_ids = input_ids[:max_length]else:input_ids += [0] * (max_length - len(input_ids))attention_mask = [1 if idx != 0 else 0 for idx in input_ids]input_ids_list.append(input_ids)attention_mask_list.append(attention_mask)inputs = {'input_ids': np.array(input_ids_list, dtype=np.int64),'attention_mask': np.array(attention_mask_list, dtype=np.int64)}return inputsdef mean_pooling_numpy(self, model_output, attention_mask):"""对模型输出进行均值池化"""token_embeddings = model_outputinput_mask_expanded = np.expand_dims(attention_mask, -1).astype(float)return np.sum(token_embeddings * input_mask_expanded, axis=1) / np.clip(np.sum(input_mask_expanded, axis=1), a_min=1e-9, a_max=None)def compute_embeddings(self, texts):"""计算输入文本的句子嵌入"""onnx_inputs = self.preprocess(texts)onnx_outputs = self.onnx_session.run(None, onnx_inputs)last_hidden_state = onnx_outputs[0]sentence_embeddings = self.mean_pooling_numpy(last_hidden_state, onnx_inputs['attention_mask'])sentence_embeddings = sentence_embeddings / np.linalg.norm(sentence_embeddings, axis=1, keepdims=True)return sentence_embeddings

4.2 推理流程

  1. 加载ONNX模型:通过InferenceSession加载ONNX模型。
  2. 加载词汇表:读取BERT的词汇表,用于将输入文本转化为模型可接受的input_ids格式。
  3. 文本预处理:将输入的文本进行分词、截断或填充为固定长度,并生成相应的注意力掩码attention_mask
  4. 模型推理:通过ONNX Runtime调用模型,获取句子的最后隐藏状态输出。
  5. 均值池化:对最后的隐藏状态进行均值池化,计算出句子的嵌入向量。
  6. 归一化嵌入:将句子嵌入向量进行归一化,使得向量长度为1。

5. 总结

通过将BERT模型导出为ONNX并使用ONNX Runtime进行推理,我们可以大幅度提升推理速度,同时保持了高精度的句子嵌入计算。在实际应用中,ONNX Runtime的跨平台特性和高性能表现使其成为模型部署和推理的理想选择。

使用上述步骤,您可以轻松将BERT句子模型应用到各种自然语言处理任务中,如语义相似度计算、文本分类和句子嵌入等。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/57825.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

05方差分析续

文章目录 1.Three way ANOVA2.Latin square design2.Hierarchical (nested) ANOVA3.Split-plot ANOVA4.Repeated measures ANOVA5.Mixed effect models 1.Three way ANOVA 三因素相关分析 单因子分析的代码 data(mtcars) nrow(mtcars) # 32 mtcars$cyl as.factor(mtcars$cyl…

c#子控件拖动父控件方法及父控件限在窗体内拖动

一、效果 拖放位置不超过窗体四边,超出后自动靠边停靠支持多子控件拖动指定控件拖放(含父控件或窗体)点击左上角logo弹出消息窗口(默认位置右下角)1.1 效果展示 1.2 关于MQTTnet(最新版v4.3.7.1207)实现在线客服功能,见下篇博文 https://github.com/dotnet/MQTTnet 网上…

BIO,NIO,直接内存,零拷贝

前置知识 什么是Socket? Socket是应用层与TCP/IP协议族通信的中间软件抽象层,它是一组接口,一般由操作系统提供。在设计模式中,Socket其实就是一个门面模式,它把复杂的TCP/IP协议处理和通信缓存管理等等都隐藏在Sock…

莱维飞行(Levy Flight)机制的介绍和MATLAB例程

文章目录 莱维飞行机制算法简介自然现象中的应用优化问题中的应用关键公式 MATLAB代码示例代码说明运行结果 莱维飞行机制算法的应用前景1. 自然科学中的应用2. 计算机科学中的应用3. 工程技术中的应用4. 金融与经济学中的应用5. 医疗与生物信息学中的应用6. 未来研究方向 结论…

【软件工程】软件工程入门

🌈 个人主页:十二月的猫-CSDN博客 🔥 系列专栏: 🏀软件开发必练内功_十二月的猫的博客-CSDN博客 💪🏻 十二月的寒冬阻挡不了春天的脚步,十二点的黑夜遮蔽不住黎明的曙光 目录 1. 前…

软件分享丨Marktext 编辑器

Marktext是一款开源免费的Markdown编辑器,它具有简洁优雅的界面设计和强大的功能,支持多种Markdown语法,包括表格、流程图、甘特图、数学公式、代码高亮等。Marktext还支持导出HTML和PDF格式的文档,非常适合需要编写Markdown文档的…

5G NR:BWP入门

简介 5G NR 系统带宽比4G LTE 大了很多,4G LTE 最大支持带宽为20MHz, 而5G NR 的FR1 最大支持带宽为100MHz, FR2 最大支持带宽为 400MHz。 带宽越大,意味了终端功耗越多。为了减少终端的功耗,5G NR 引入了BWP(Band Wid…

不写单元测试的我,被批了

最近在看单元测试的东西,想跟大家聊聊我的感受。单元测试这块说实在的,我并不太熟悉,我几乎不写单元测试,也不太爱写单元测试。 当我推广消息推送平台austin的时候,有过批评我整个项目没有单元测试,也有过…

《a16z : 2024 年加密货币现状报告》解析

加密社 原文链接:State of Crypto 2024 - a16z crypto译者:AI翻译官,校对:翻译小组 当我们两年前第一次发布年度加密状态报告的时候,情况跟现在很不一样。那时候,加密货币还没成为政策制定者关心的大事。 比…

生信软件39 - GATK最佳实践流程重构,提高17倍分析速度的LUSH流程

1. LUSH流程简介 基因组测序通常用于分子诊断、分期和预后,而大量测序数据在分析时间方面提出了挑战。 对于从FASTQ到VCF的整个流程,LUSH流程在非GVCF和GVCF模式下都大大降低了运行时间,30 X WGS数据耗时不到2 h,从BAM到VCF约需…

使用 ASP.NET Core 8.0 创建最小 API

构建最小 API,以创建具有最小依赖项的 HTTP API。 它们非常适合需要在 ASP.NET Core 中仅包括最少文件、功能和依赖项的微服务和应用。 本教程介绍使用 ASP.NET Core 生成最小 API 的基础知识。 在 ASP.NET Core 中创建 API 的另一种方法是使用控制器。 有关在最小 …

认识CSS语法

CSS(网页美容) 重点:选择器、盒子模型、浮动、定位、动画,伸缩布局 Css的作用: 美化网页:CSS控制标签的样式 网页布局:CSS控制标签的位置 概念:层叠样式表(级联样式表…

Maven(解决思路)

1.前言 作为一名一线的开发人员,maven大概率是我们用的最多的依赖管理,但是你知道我们的maven出现问题后怎么去排查么?不对,确切的来说,假如你去导入的包没有被成功导入,你有什么方法去排查、去解决这个问题…

Linux-Centos操作系统备份及还原(整机镜像制作与还原)--再生龙

适用场景 Linux系统设备需要备份整机数据,或者需要还原到多台设备上。适用再生龙工具进行整机备用和还原。 镜像制作 下载再生龙镜像:clonezilla-live-2.6.4-10-amd64.iso,制作启动盘-设置U盘启动 启动后界面如下选择第四项other modes of…

力扣143:重排链表

给定一个单链表 L 的头节点 head ,单链表 L 表示为: L0 → L1 → … → Ln - 1 → Ln请将其重新排列后变为: L0 → Ln → L1 → Ln - 1 → L2 → Ln - 2 → … 不能只是单纯的改变节点内部的值,而是需要实际的进行节点交换。 示…

如何使用的是github提供的Azure OpenAI服务

使用的是github提供的Azure OpenAI的服务gpt-4o 说明:使用的是github提供的Azure OpenAI的服务,可以无限薅羊毛。开源地址 进入: 地址 进入后点击 右上角“Get API key”按钮 点击“Get developer key” 选择Beta版本“Generate new to…

HarmonyOS开发 - 本地持久化之实现LocalStorage实例

用户首选项为应用提供Key-Value键值型的数据处理能力,支持应用持久化轻量级数据,并对其修改和查询。数据存储形式为键值对,键的类型为字符串型,值的存储数据类型包括数字型、字符型、布尔型以及这3种类型的数组类型。 说明&#x…

C#通过异或(^)运算符制作二进制加密(C#实现加密)

快速了解异或运算符&#xff1a; 异或运算符在C#中用 “^” 来表示 口诀&#xff1a;相同取0&#xff0c;相异取1 简单加密解密winform示例&#xff1a; /// <summary>/// 异或运算符加密实现/// </summary>/// <param name"p_int_Num">初始值<…

中小企业设备维护新策略:Spring Boot系统设计与实现

2相关技术 2.1 MYSQL数据库 MySQL是一个真正的多用户、多线程SQL数据库服务器。 是基于SQL的客户/服务器模式的关系数据库管理系统&#xff0c;它的有点有有功能强大、使用简单、管理方便、安全可靠性高、运行速度快、多线程、跨平台性、完全网络化、稳定性等&#xff0c;非常…

全面指南:Visual Studio Code 的下载、安装、使用与插件管理

活着&#xff0c;就是一场盛大的遇见&#xff0c;与世界&#xff0c;与自己&#xff0c;与每一个瞬间的奇迹 文章目录 前言下载 Visual Studio Code安装 Visual Studio CodewindowsmacOSLinux 使用 Visual Studio CodeVisual Studio Code 插件安装方法语言支持代码格式化与美化…