Python - 深度学习系列38 重塑实体识别5-预测并行化改造

说明

在重塑实体识别4中梳理了数据流,然后我发现pipeline的串行效率太低了,所以做了并行化改造。里面还是有不少坑的,记录一下。

内容

1 pipeline

官方的pipeline看起来的确是比较好用的,主要是实现了比较好的数据预处理。因为在训练/使用过程中都要进行数据的令牌化与反令牌化,有些字符会被特殊处理,例如 '##A’等。
在这里插入图片描述
在使用过程中,我用200条新闻数据进行测试,用pipeline方法花了11分钟处理完毕,期间CUDA的使用率大约为10%。按此估算,即使用多接口并行的方式,那么一分钟最多处理2000条,一天最多处理0.14*2000~30万条数据。这个效率太低了。

2 并行化

最终的结论是不到30秒处理200条,显存只占用2.6G,理论上可以支持3个服务并行(以确保GPU的完全利用)。按最保守的估计,改造后的并行化应该可以提升3倍的效率,稍微激进一点,可以提升10倍的效率。这个之后可以进行测试。

一些主要的点如下

2.1 结果解析

结果可以分为:

  • 1 仅含解析出的实体列表,用逗号连接字符串表示。
  • 2 含实体及其起始位置的表示,这个用于标注反馈、二次增强处理。
  • 3 仅含BIO标签,主要用于和测试数据进行效果比对。

对应的相关函数,看起来有点繁杂,我自己都不太想看第二眼。

from datasets import ClassLabel
# 定义标签列表
label_list = ['B', 'I', 'O']
# 创建 ClassLabel 对象
class_label = ClassLabel(names=label_list)
def convert_entity_label_batch(x):x1 = xreturn class_label.int2str(x1)
# 定义函数将整数Tensor转换为字符串 | 反令牌函数,但是用不上;因为predict label列表的长度和 ss_padding相同
def tensor_to_string(tensor, tokenizer = None , skip_special_tokens = True):return tokenizer.decode(tensor.tolist(), skip_special_tokens=skip_special_tokens).replace(' ','')from datasets import ClassLabel
def detokenize(word_piece):"""将 WordPiece 令牌还原为原始句子。"""if word_piece.startswith('##'):x = word_piece[2:]else:x = word_piecereturn x
import re
def extract_bio_positions(bio_string):pattern = re.compile(r'B(I+)(O|$)')matches = pattern.finditer(bio_string)results = []for match in matches:start, end = match.span()results.append((start, end - 1))  # end-1 to include the last 'I'return results# 0.1ms
def parse_ent_pos_map_batch(some_dict = None):word_list = some_dict['token_words']label_list = [int(x) for x in list(some_dict['label_list'])]min_len = min(len(word_list),len(label_list))word_list = word_list[:min_len]label_list = label_list[:min_len]label_list1 =  list(map(convert_entity_label_batch,label_list))oriword_list1 = list(map(detokenize,word_list))ori_word_str =''.join(oriword_list1)# 补到等长label_str = ''for i in range(len(label_list1)):len_of_ori_word = len(oriword_list1[i])if len_of_ori_word == 1:tem_str = label_list1[i]else:if label_list1[i] in ['I','O']:tem_str = label_list1[i] * len_of_ori_wordelse:tem_str = 'B' + 'I' * (len_of_ori_word -1)        label_str += tem_strpos_list = extract_bio_positions(label_str)part_ent_list = [(ori_word_str[x[0]:x[1]] , *x) for x in pos_list]return part_ent_list# =============
def make_BIO_by_len(some_len):default_str = 'I' * some_lenstr_list = list(default_str)str_list[0] ='B'return str_list
def gen_BIO_list2(some_dict):the_content = some_dict['clean_data']ent_list =  some_dict['ent_tuple_list']content_list = list(the_content)tag_list = list('O'* len(content_list))for ent_info in ent_list:start = ent_info[1]end = ent_info[2]label_len = end-starttem_bio_list = make_BIO_by_len(label_len)tag_list[start:end] = tem_bio_listres_dict = {}res_dict['x'] = content_listres_dict['y'] = tag_listreturn res_dictdef trim_len(some_dict = None):padding_BIO = some_dict['padding_BIO']ss_len = some_dict['ss_len']return padding_BIO[:ss_len]

2.2 批量预测

看起来同样很繁杂,但是不得不细看。首先,数据会按照几个长度 20,50,198分为三部分处理,batch_predict每次仅处理一个批次。在这里,将数据转为定长的令牌长度,然后转入CUDA进行批量预测。结果再按照实体-位置 tuple, 实体列表和BIO三种方式进行解析。

from functools import partial
import transformers 
import torch 
from transformers import AutoModelForMaskedLM, AutoTokenizer,AutoModelForTokenClassification
from functools import partial
# some_batch 是原文经过padding的数据,['ss_hash','ss','ss_len', 'ss_padding'], 其中ss_padding的长度是固定的
# 模型文件和令牌文件都放在model_path之下,model比较大,避免重载;而tokenize会有padding过程,必须重载
# 模型先载入cuda
def batch_predict(some_batch, ss_padding_len = None, model = None, model_path = None):# 因为tokenize会在令牌的前后加上分隔令牌,所以+2if ss_padding_len is None:ss_padding_len = some_batch['ss_padding'].apply(len).max()print('ss_padding_len is %s ' % ss_padding_len)max_len = ss_padding_len+2tokenizer = AutoTokenizer.from_pretrained(model_path)tencoder = partial(tokenizer.encode,truncation=True, max_length=max_len, is_split_into_words=True, return_tensors="pt",  padding='max_length')some_batch['ss_padding_token'] = some_batch['ss_padding'].apply(list).apply(tencoder)# 构成矩阵minput = torch.cat(list(some_batch['ss_padding_token'].values))# 将数据搬到GPU中处理再返回with torch.no_grad():input_cuda = minput.to(device)outputs_cuda = model(input_cuda).logitspredictions = torch.argmax(outputs_cuda, dim=2)predictions_list = list(predictions.to('cpu').numpy())predict_list1 = []for predictions in predictions_list:tem_pred_tag = [int(x) for x in predictions[1:-1]]predict_list1.append(tem_pred_tag)some_batch['label_list'] = predict_list1_s = cols2s(some_df =some_batch, cols= ['ss_padding','label_list'], cols_key_mapping= ['token_words', 'label_list'])_s1 = _s.apply(parse_ent_pos_map_batch)some_batch['ent_tuple_list'] = list(_s1)some_batch['ent_list'] = some_batch['ent_tuple_list'].apply(lambda x: ','.join([a[0] for a in x ]))_s = cols2s(some_batch, cols= ['ss_padding', 'ent_tuple_list'], cols_key_mapping= ['clean_data', 'ent_tuple_list'])s1 = _s.apply(gen_BIO_list2)ent_tuple_res_df1 = pd.DataFrame(s1.to_list())some_batch['padding_BIO'] = list(ent_tuple_res_df1['y'].apply(lambda x: ''.join(x)))_s00 = cols2s(some_batch, cols = ['ss_len', 'padding_BIO'], cols_key_mapping=['ss_len', 'padding_BIO'])some_batch['BIO'] = list(_s00.apply(trim_len))return some_batch    

3 迭代器

在推送数据处理时,可以采用迭代器来控制不同的批次数据

# 迭代器切分
import pandas as pd
class DataFrameBatchIterator:def __init__(self, dataframe, batch_size):self.dataframe = dataframeself.batch_size = batch_size# 【我增加的】self.fail_batch_list = []def __iter__(self):num_rows = len(self.dataframe)num_batches = (num_rows - 1) // self.batch_size + 1for i in range(num_batches):start_idx = i * self.batch_sizeend_idx = (i + 1) * self.batch_sizebatch_data = self.dataframe.iloc[start_idx:end_idx]yield batch_data# 【我增加的】def clear_fail(self):self.fail_batch_list = []# 【我增加的】def get_some_batch(self, batch_idx):return self.dataframe.iloc[self.batch_size * batch_idx: self.batch_size * (batch_idx + 1)]# 【我增加的】记录失败的批次def rec_fail_batch_idx(self, batch_idx):self.fail_batch_list.append(batch_idx)
# 创建一个示例 DataFrame
data = {'Name': ['John', 'Jane', 'Mike', 'Alice', 'Bob'],'Age': [25, 30, 35, 28, 32],'City': ['New York', 'Paris', 'London', 'Tokyo', 'Sydney']}
df = pd.DataFrame(data)
# 创建 DataFrame 迭代器
batch_iterator = DataFrameBatchIterator(df, batch_size=2)
import tqdm
# 使用迭代器逐批次处理数据
for i,batch in tqdm.tqdm(enumerate(batch_iterator)):try:# 在这里可以对当前批次的数据进行相应的操作# 例如进行数据清洗、特征处理、模型训练等# 示例:打印当前批次的数据
#         raise Exception(e) print(batch)except:print('>>> %s Fail' % i)batch_iterator.rec_fail_batch_idx(i)

以下是实际的调度

# 假设处理长度为1万的句子
# 20 * 2000 ~ 4w
# 50 * 800 ~  4w
# 200 * 200 ~ 4w
import warnings 
warnings.filterwarnings('ignore')
batch_slice_para = {20:2000, 50:800, 200:200}
batch_len_list = sorted(list(batch_slice_para.keys()))
batch_len_list.insert(0,0)batch_df_list = []
for i in range(len(batch_len_list)):if i >0:sel = (ss_df['ss_len'] >=batch_len_list[i-1]) & (ss_df['ss_len'] < batch_len_list[i])if sel.sum():padding_len = batch_len_list[i]padding_batch = batch_slice_para[padding_len]tem_df= ss_df[sel]# tem_df['ss_padding'] = tem_df['ss'].apply(lambda x: x.ljust(padding_len,'a'))tem_df['ss_padding'] = tem_df['ss']tem_df_iterator = DataFrameBatchIterator(tem_df, padding_batch)batch_df_list.append(tem_df_iterator)else:batch_df_list.append(None)

对每个批次执行处理,载入模型

label_list = ['B','I','O']
model_checkpoint = 'model03'
model = AutoModelForTokenClassification.from_pretrained(model_checkpoint, num_labels=len(label_list))device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Device: %s' % device)
# 自动切换设备
if model.device.type != device:model.to(device)print('>>> 检测到模型设备与当前指定不一致,切换 %s' % device )
else:print('>>> 模型设备一致,不切换 %s' % device)

分批次预测(主要是确保显存不溢出)

batch_res_list = []
for some_iter in batch_df_list:for some_batch in some_iter:batch_res = batch_predict(some_batch, model = model, model_path = 'model03')batch_res_list.append(batch_res)

结果合并

batch_res_df = pd.concat(batch_res_list, ignore_index= True)
mdf = pd.merge(input_df , batch_res_df[['ss_hash', 'ent_list']],how='left', on ='ss_hash')

在这里插入图片描述

4 总结

一个在理论上证明可以显著提升效率的点在于,模型进行实体识别时先切分了短句,然后按短句进行了去重:相同短句的实体结果一定是相同的。

实验中,200条新闻产生了约5万个短句,去重后只剩下约3.5万。所以即使在这一步也是有提升的。当然,这种方式同样也可以被用于pipeline。

还有就是在处理填充时,并不按照最大长度统一填充。而是按照句子长度的统计特性分为了短、中、长三种方式。从统计上看,约70%的短句长度是在20个字符以内的,真正超过50个字符的短句(中间无分隔符),即使从语法上来看也是比较奇怪的。
这样在填充数据时浪费就比pipeline要小,同样显存可以装下更多的数据。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/22679.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Solidwokrs钣金拆图之移动面命令使用技巧

Solidwokrs钣金拆图之移动面命令使用技巧 Chapter1 Solidwokrs钣金拆图之移动面命令使用技巧Chapter2 solidworks如何删除外部参考 Chapter1 Solidwokrs钣金拆图之移动面命令使用技巧 原文链接&#xff1a;https://www.sohu.com/a/441562400_728492 今天给大家介绍一个SolidW…

IO进程线程(六)进程

文章目录 一、进程状态&#xff08;二&#xff09;进程状态切换实例1. 实例1 二、进程的创建&#xff08;一&#xff09;原理&#xff08;二&#xff09;fork函数--创建进程1. 定义2. 不关注返回值3. 关注返回值 &#xff08;三&#xff09; 父子进程的执行顺序&#xff08;四&…

【Redis数据库百万字详解】数据持久化

文章目录 一、持久化1.1、什么是持久化1.2、持久化方式1.3、RDB优缺点1.4、AOF优缺点 二、RDB持久化触发机制2.1、手动触发2.2、自动触发 三、RDB持久化配置3.1、配置文件3.2、配置查询/设置3.3、禁用持久化3.4、RDB文件恢复 四、RDB持久化案例4.1、手动持久化4.2、自动持久化案…

2024第26届大湾区国际电机博览会暨发展论坛

2024第二十六届大湾区国际电机博览会 暨发展论坛 2024第26届大湾区国际电机博览会暨发展论坛 The 26th Greater Bay Area International Motor Expo and Development Forum 时间&#xff1a;2024年12月4-6日 地址&#xff1a;深圳国际会展中心&#xff08;宝安新馆&#x…

安全生产新篇章:可燃气体报警器检验周期的国家标准解读

随着工业化进程的加快&#xff0c;安全生产成为了重中之重。 可燃气体报警器作为预防火灾和爆炸事故的重要设备&#xff0c;其准确性和可靠性直接关系到企业的生产安全和员工的生命财产安全。 因此&#xff0c;国家对可燃气体报警器的检验周期有着明确的规定&#xff0c;以确…

美洽工作台3.0,全新发布!

美洽工作台3.0&#xff0c;全新发布 想要效率翻倍&#xff0c;就要一步到位&#xff01; 工作台 3.0&#xff0c;为效率而生 1. 更丰富的外观选择&#xff0c;让界面焕然一新&#xff0c;新增导航主题色选择&#xff0c;深色 Dark、浅色 Light 随意切换 2. 自定义你的专属导…

Python 识别图片形式pdf的尝试(未解决)

想识别出pdf页面右下角某处的编号。pdf是图片形式页面。查了下方法&#xff0c;有源码是先将页面提取成jpg&#xff0c;再用pytesseract提取图片文件中的内容。 直接用图片来识别。纯数字的图片&#xff0c;如条形码&#xff0c;可识别。带中文的不可以&#xff0c;很乱。 识别…

吴恩达深度学习笔记:机器学习(ML)策略(1)(ML strategy(1))1.3-1.4

目录 第三门课 结构化机器学习项目&#xff08;Structuring Machine Learning Projects&#xff09;第一周 机器学习&#xff08;ML&#xff09;策略&#xff08;1&#xff09;&#xff08;ML strategy&#xff08;1&#xff09;&#xff09;1.3 单一数字评估指标&#xff08;S…

Linux|如何安装 Java

引言 Java是最受欢迎的编程语言之一&#xff0c;JVM&#xff08;Java的虚拟机&#xff09;是运行Java应用程序的运行时环境。这两个平台是许多流行软件所需的&#xff0c;包括Tomcat&#xff0c;Jetty&#xff0c;Cassandra&#xff0c;Glassfish和Jenkins。 本教程[1]将指导您…

2024年应用经济学、管理科学与社会国际学术会议(ICAEMSS 2024)

2024年应用经济学、管理科学与社会国际学术会议&#xff08;ICAEMSS 2024&#xff09; 会议简介 2024年应用经济学、管理科学与社会国际学术会议将聚焦应用经济学和管理科学的前沿问题&#xff0c;深入探讨社会变革中的经济管理与科学应用。参会者将分享最新研究成果&#xf…

短剧小程序App系统源码:打造个性化追剧体验

随着数字媒体的迅猛发展&#xff0c;短剧作为一种新兴的娱乐形式&#xff0c;越来越受到广大观众的喜爱。为了满足用户对短剧内容的个性化需求&#xff0c;短剧小程序App系统应运而生。本文将深入探讨短剧App源码的核心功能&#xff0c;以及如何通过多语言支持和国际支付等技术…

超声波洗眼镜机是智商税吗?四款不能错过的超声波清洗机实力种草

在日常生活中&#xff0c;眼镜成为了我们不可或缺的伙伴&#xff0c;无论是阅读书籍、工作还是享受自然风光&#xff0c;清晰的视野总是至关重要。然而&#xff0c;眼镜上不可避免地会沾染灰尘、油脂甚至细菌&#xff0c;影响我们的视觉体验。传统的眼镜清洗方法虽然简单&#…

雷池WAF《动态防护》功能体验

一、雷池简介&#xff08;官方&#xff09; 自 2016 年起&#xff0c;长亭就开源了雷池的语义分析算法自动机引擎&#xff0c;随后又陆续开源了雷池相关风控插件和引擎通信协议。雷池的商业版本自发布以来&#xff0c;得到了各大咨询机构和众多顶级企业的认可。然而&#xff0…

MT3050 区间最小值

思路&#xff1a; 使用ST表 ST模板可参考MT3024 maxmin 代码&#xff1a; 1.暴力9/10&#xff1a; #include <bits/stdc.h> using namespace std; const int N 1e5 10; int n, m; int a[N]; int main() {ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);cin …

2024精选热门骨传导耳机推荐,你不会还不挑选吧?

骨传导耳机作为最近两年来才兴起的耳机品类&#xff0c;在街头的出现频率并不是很高&#xff0c;很多人对骨传导耳机不够了解甚至没听说过。骨传导耳机不入耳的设计&#xff0c;安全性、舒适性和稳定性都更高&#xff0c;既然有这么多的优势&#xff0c;那就为大家挑选几款高性…

16. 最接近的三数之和 - 力扣

1. 题目 给你一个长度为 n 的整数数组 nums 和 一个目标值 target。请你从 nums 中选出三个整数&#xff0c;使它们的和与 target 最接近。 返回这三个数的和。 假定每组输入只存在恰好一个解。 2. 示例 3.分析 做这道题目前&#xff0c;先做这道&#xff1a;三数之和 &#x…

手动操作Telnet不嫌累?要不试一下我自制的自动执行指令Telnet工具吧!

网管小贾 / sysadm.cc 昨天发生了一件事&#xff0c;我现在仍记忆犹新。 一大早我就被秘书喊进了胡总的办公室…… 一进门&#xff0c;只见我们部门的赖经理也在。 我打完招呼&#xff0c;胡总就问我&#xff0c;最近调到我们部门实习的小王表现如何。 我偷偷瞥了一眼赖经理…

【已有项目版】uniapp项目发版pda -- Android Studio

必备资料清单&#xff1a; 构建完成的app项目 在HBuilderX开发的uniapp项目 .keystore文件 文章目录 1. 安装Android Studio&#xff1a;https://developer.android.google.cn/studio?hlzh-cn2. 安装Android 离线SDK&#xff1a;https://nativesupport.dcloud.net.cn/AppDocs…

短期业绩波动较大被券商不予评级,金种子酒背靠华润如何发力?

《港湾商业观察》施子夫 王璐 虽然一季度成功实现了扭亏为盈&#xff0c;但从近些年年报来看&#xff0c;金种子酒&#xff08;600199.SH&#xff09;的业绩压力依然不容小觑。白酒主业萎靡不振时&#xff0c;金种子酒开始了剥离非主营业务。 这些措施能否有利于主业向好&am…

jmeter的infludb+grafana实时监控平台

目的&#xff1a;可以实时查看到jmeter拷机信息 框架&#xff1a;将 Jmeter 的数据导入 InfluxDB &#xff0c;再用 Grafana 从 InfluxDB 中获取数据并以特定的模板进行展示 性能监控平台部署实践 一、influxDB 官网&#xff1a;https://www.influxdata.com/downloads/ wget h…