bert-NER 转化成 onnx 模型

保存模型

加载模型

from transformers import AutoTokenizer, AutoModel, AutoConfigNER_MODEL_PATH = './save_model'
ner_tokenizer = AutoTokenizer.from_pretrained(NER_MODEL_PATH)
ner_config = AutoConfig.from_pretrained(NER_MODEL_PATH)
ner_model = AutoModelForTokenClassification.from_pretrained(NER_MODEL_PATH)
ner_model.eval()

测试ner效果

在这里插入图片描述

测试速度

在这里插入图片描述

导出到onnx

# !pip install onnx onnxruntime -i https://pypi.tuna.tsinghua.edu.cn/simple/# 导出 onnx 模型
import onnxruntime
from itertools import chain
from transformers.onnx.features import FeaturesManagerconfig = ner_config
tokenizer = ner_tokenizer
model = ner_model
output_onnx_path = "bert-ner.onnx"onnx_config = FeaturesManager._SUPPORTED_MODEL_TYPE['bert']['sequence-classification'](config)
dummy_inputs = onnx_config.generate_dummy_inputs(tokenizer, framework='pt')torch.onnx.export(model,(dummy_inputs,),f=output_onnx_path,input_names=list(onnx_config.inputs.keys()),output_names=list(onnx_config.outputs.keys()),dynamic_axes={name: axes for name, axes in chain(onnx_config.inputs.items(), onnx_config.outputs.items())},do_constant_folding=True,use_external_data_format=onnx_config.use_external_data_format(model.num_parameters()),enable_onnx_checker=True,opset_version=onnx_config.default_onnx_opset,
)

加载ONNX模型

自定义pipeline

from onnxruntime import SessionOptions, GraphOptimizationLevel, InferenceSessionclass PipeLineOnnx:def __init__(self, tokenizer, onnx_path, config):self.tokenizer = tokenizerself.config = config  # label2id, id2labeloptions = SessionOptions() # initialize session optionsoptions.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL# 设置线程数
#         options.intra_op_num_threads = 4# 这里的路径传上一节保存的onnx模型地址self.session = InferenceSession(onnx_path, sess_options=options, providers=["CPUExecutionProvider"])# disable session.run() fallback mechanism, it prevents for a reset of the execution providerself.session.disable_fallback() def __call__(self, text):inputs = self.tokenizer(text, padding=True, truncation=True, return_tensors='pt')ids = inputs["input_ids"]inputs_offset = self.tokenizer.encode_plus(text, return_offsets_mapping=True).offset_mappinginputs_detach = {k: v.detach().cpu().numpy() for k, v in inputs.items()}# 运行 ONNX 模型# 这里的logits要有export的时候output_names相对应output = self.session.run(output_names=['logits'], input_feed=inputs_detach)[0]logits = torch.tensor(output)num_labels = len(self.config.label2id)active_logits = logits.view(-1, num_labels) # shape (batch_size * seq_len, num_labels)softmax = torch.softmax(active_logits, axis=1)scores = torch.max(softmax, axis=1).values.cpu().detach().numpy()flattened_predictions = torch.argmax(active_logits, axis=1) # shape (batch_size*seq_len,) - predictions at the token leveltokens = self.tokenizer.convert_ids_to_tokens(ids.squeeze().tolist())token_predictions = [self.config.id2label[i] for i in flattened_predictions.cpu().numpy()]wp_preds = list(zip(tokens, token_predictions)) # list of tuples. Each tuple = (wordpiece, prediction)ner_result = [{"index": idx, "word":i,"entity":j, "start": k[0], "end": k[1], "score": s} for idx, (i,j,k,s) in enumerate(zip(tokens, token_predictions, inputs_offset, scores)) if j != 'O']return post_process(ner_result)def allow_merge(a, b):a_flag, a_type = a.split('-')b_flag, b_type = b.split('-')if b_flag == 'B' or a_flag == 'E':return Falseif a_type != b_type:return Falseif (a_flag, b_flag) in [("B", "I"),("B", "E"),("I", "I"),("I", "E")]:return Truereturn Falsedef divide_entities(ner_results):divided_entities = []current_entity = []for item in sorted(ner_results, key=lambda x: x['index']):if not current_entity:current_entity.append(item)elif allow_merge(current_entity[-1]['entity'], item['entity']):current_entity.append(item)else:divided_entities.append(current_entity)current_entity = [item]divided_entities.append(current_entity)return divided_entitiesdef merge_entities(same_entities):def avg(scores):return sum(scores)/len(scores)return {'entity': same_entities[0]['entity'].split("-")[1],'score': avg([e['score'] for e in same_entities]),'word': ''.join(e['word'].replace('##', '') for e in same_entities),'start': same_entities[0]['start'],'end': same_entities[-1]['end']}def post_process(ner_results):return [merge_entities(i) for i in divide_entities(ner_results)]

加载模型

from transformers import AutoTokenizer, AutoConfigNER_MODEL_PATH = './save_model'
ner_tokenizer = AutoTokenizer.from_pretrained(NER_MODEL_PATH)
ner_config = AutoConfig.from_pretrained(NER_MODEL_PATH)pipe2 = PipeLineOnnx(ner_tokenizer, "bert-ner.onnx", config=ner_config)

测试效果

在这里插入图片描述

测试速度

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/9073.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【雅思写作】Vince9120雅思小作文笔记——P1 Intro(前言)

文章目录 链接P1 Intro(前言)字数限制题型综述(problem types overview)1. **柱状图(Bar Chart)** - 描述不同类别在某个或多个变量上的数据量比较。2. **线图(Line Graph)** - 展示…

冯喜运:5.10黄金反弹受阻,原油EIA库存激增引发市场情绪

【黄金消息面分析】:据最新市场数据显示,现货黄金在周四欧市早盘经历了显著下滑,价格一度跌破2310美元/盎司的关口,日内高点回落达10美元,截至发稿,黄金小幅反弹,交投于2312美元/盎司附近。此番…

【工具】如何提取一个mp4文件的关键帧

文章目录 怎么做如何安装ffmepgUbuntu 或 DebianCentOS 或 FedoramacOSWindows其他 Linux 发行版 实践什么是关键帧 怎么做 你可以使用ffmpeg这个强大的多媒体处理工具来提取mp4文件中的关键帧。以下是一个示例命令,可以使用ffmpeg从mp4文件中提取关键帧&#xff1…

即将开幕,邀您共赴创新之旅“2024上海国际消费者科技及创新展览会”

备受期待的2024上海国际消费者科技及创新展览会(以下简称“CTIS”)即将于6月13日至15日亮相上海新国际博览中心N1-N3馆。 2024上海国际消费者科技及创新展览会总面积达40,000平方米,涵盖600余家展商,预计吸引40,000多位观众莅临现…

单片机——直流电机

1 .关于4线直流电机 两根12v供电线,通入12v,风扇以最高转速工作。 一根测速线,电机工作时输出测速信号,提供转速反馈。一根PWM控制信号线,电机工作时控制器输入PWM控制信号,以控制风扇转速(通常为占空比可…

Python爬虫基础知识学习(以爬取某二手房数据、某博数据与某红薯(书)评论数据为例)

一、爬虫基础流程 爬虫的过程模块化,基本上可以归纳为以下几个步骤: 1、分析网页URL:打开你想要爬取数据的网站,然后寻找真实的页面数据URL地址; 2、请求网页数据:模拟请求网页数据,这里我们介…

双翻斗雨量计学习

双翻斗雨量计用户手册(脉冲型) 本仪器由雨量计壳体、承雨口、漏斗、翻斗支撑、上漏斗雨量调节支架、上漏斗、汇集漏斗、计数翻斗雨量调节支架、计数翻斗、干簧管安装架、轴承螺钉、出水漏斗、腿部支架、干簧管、水平泡、调节支撑板、控制盒、调平装置、接…

安装oh-my-zsh(命令行工具)

文章目录 一、安装zsh、git、wget二、安装运行脚本1、curl/wget下载2、手动下载 三、切换主题1、编辑配置文件2、切换主题 四、安装插件1、zsh-syntax-highlighting(高亮语法错误)2、zsh-autosuggestions(自动补全) 五、更多优化配…

MySQL#MySql表的操作

目录 一、创建表 二、查看表结构 三、修改表 1.修改表的名字 2.新增一个列 3.修改列 4.删除列 5.修改列的名称 四、删除表 一、创建表 语法: CREATE TABLE table_name (field1 datatype,field2 datatype,field3 datatype ) character set 字符集 collate 校…

element-ui skeleton 组件源码分享

今日简单分享 skeleton 骨架屏组件源码,主要从以下四个方面来讲解: 1、skeleton 组件的页面结构 2、skeleton 组件的属性 3、skeleton item 组件的属性 4、skeleton 组件的 slot 一、skeleton 组件的页面结构 二、skeleton 组件的属性 2.1 animate…

漏洞管理是如何在攻击者之前识别漏洞从而帮助人们阻止攻击的

漏洞管理 是主动查找、评估和缓解组织 IT 环境中的安全漏洞、弱点、差距、错误配置和错误的过程。该过程通常扩展到整个 IT 环境,包括网络、应用程序、系统、基础设施、软件和第三方服务等。鉴于所涉及的高成本,组织根本无法承受网络攻击和数据泄露。如果…

JUC下的ForkJoinPool详解

详细介绍 ForkJoinPool 是 Java 并发包 (java.util.concurrent) 中的一个特殊线程池,专为分治算法设计,能够高效地处理大量可分解的并行任务。它基于工作窃取(work-stealing)算法,当一个工作线程的任务队列为空时&…

HFSS学习-day3-HFSS的工作界面

工作界面也称为用户界面,是HFSS软件使用者的工作环境:了解、熟悉这个工作环境是掌握HFSS软件使用的第一步 HFSS工作环境介绍 1.HFSS工作界面简单的组成说明2.工作界面中各个工作窗口功能主菜单工具栏项目管理窗口属性窗口信息管理窗口进程窗口三维模型窗口 3.HFSS主…

数据结构_栈和队列(Stack Queue)

✨✨所属专栏:数据结构✨✨ ✨✨作者主页:嶔某✨✨ 栈: 代码:function/数据结构_栈/stack.c 钦某/c-language-learning - 码云 - 开源中国 (gitee.com)https://gitee.com/wang-qin928/c-language-learning/blob/master/function/…

java中的oop(三)、构造器、javabean、uml类图、this、继承

!! 有get/set方法的情况基本就是说要搞个私有属性,不直接对外开放; 构造器 Person p new Person(); //其中的Person();就是构造器;---造对象;Constructor–建设者,建造者; 作用 搭配new 创建类的&…

docker学习-docker常用其他命令整理

随便写写,后面有空再更新 镜像命令,容器命令已在之前略有更新,这次不写, 一、后台启动命令 # 命令 docker run -d 容器名 # 例子 docker run -d centos # 启动centos,使用后台方式启动 # 问题: 使用doc…

大数据手册(Spark)--Spark 简介

Spark 简介 Apache Spark 是一种用于大数据工作负载的分布式开源处理系统。它使用内存中缓存和优化的查询执行方式,可针对任何规模的数据进行快速分析查询。Apache Spark 提供了简明、一致的 Java、Scala、Python 和 R 应用程序编程接口 (API)。 Apache Spark 是专…

代码随想录第四十三天|最后一块石头的重量 II 、目标和

题目链接:. - 力扣(LeetCode) 代码如下: 题目链接:. - 力扣(LeetCode) 代码如下:

用户行为分析与内容创新:Kompas.ai的数据驱动策略

在数字化营销的今天,用户行为数据分析已成为内容创新和策略调整的核心。通过深入理解用户的行为模式和偏好,品牌能够创造出更具吸引力和相关性的内容,从而实现精准营销。本文将探讨用户行为数据分析在内容创新和策略调整中的价值,…

【Linux】进程间通信方式之管道

🤖个人主页:晚风相伴-CSDN博客 💖如果觉得内容对你有帮助的话,还请给博主一键三连(点赞💜、收藏🧡、关注💚)吧 🙏如果内容有误的话,还望指出&…