spark支持深度学习批量推理

背景

在数据量较大的业务场景中,spark在数据处理、传统机器学习训练、
深度学习相关业务,能取得较明显的效率提升。
本篇围绕spark大数据背景下的推理,介绍一些优雅的使用方式。

spark适用场景

  1. 大数据量自定义方法处理、类sql处理
  2. 传统机器学习方法(k-means、xgboost、lr…)
  3. 分布式深度学习推理
    在这里插入图片描述

目前在10亿+数据量的推理场景中使用,需要用户自己实现批数据准备,基于RDD的方法完成模型推理输出。
业务使用中的问题:

  1. 模型文件重复导入加载
  2. 自定义批数据准备,脱离深度学习dataloader框架,操作略显麻烦,有性能和内存oom等问题。

实践

spark加速深度学习推理

spark加速深度学习推理,基本思路为:

  1. 开启不定量worker并行执行(cpu或gpu)推理任务
  2. 所有worker共享同一份模型参数
  3. 依赖spark pandas udf功能,方便并行处理 dataframe 数据
  4. 依赖深度学习框架,方便实现最优批数据划分
    下面以pytorch resnet 为实践demo

加载&&广播模型参数

广播模型参数,不仅能减少模型重复加载带来的流量和io,而且能加速推理前模型加载的速度。
driver广播模型参数:

# Load ResNet50 on driver node and broadcast its state.
model_state = models.resnet50(pretrained=True).state_dict()
bc_model_state = sc.broadcast(model_state)

worker读取模型参数:

def get_model_for_eval():"""Gets the broadcasted model."""model = models.resnet50(pretrained=True)model.load_state_dict(bc_model_state.value)model.eval()return model

实现基于dataframe的dataset

目前主流的深度学习框架,dataset的实现大多基于本地存储,在读取分布式存储的场景 需要用户自定义实现。
自定义实现有2个方法:

  1. 使用分布式存储的api接口读取文件内容
  2. dataset读取dataframe二进制文件内容

方法一迭代与使用的存储类型会保持同步,且每次使用前需要明确使用的分布式存储,虽然实现方法容易但是使用流程略显麻烦。
方法二不需要关心分布式存储类型,只要需要获取并解析spark dataframe列传入内容即可。

本文采用方法二实现dataset:

# 从二进制流中解析图片信息
def pil_loader(binary_file):# open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)image_io = io.BytesIO(binary_file)img = Image.open(image_io)return img.convert('RGB')# Create a custom PyTorch dataset class.
class ImageDataset(Dataset):def __init__(self, data, transform=None):self.data = dataself.transform = transformdef __len__(self):return len(self.data)def __getitem__(self, index):image = pil_loader(self.data[index])if self.transform is not None:image = self.transform(image)return image

实现批量推理的pandas udf

Pandas udf是基于RDD的一个低门槛高性能的实现方法,pandas udf能自定义处理逻辑,以列的方式操作datafrme内容。
这是社区目前推荐的自定义处理方式。

# Define the function for model inference.
# PyArrow >= 1.0.0 must be installed;
@pandas_udf(ArrayType(FloatType()))
def predict_batch_udf(binaray_data: pd.Series) -> pd.Series:transform = transforms.Compose([transforms.Resize(224),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])])images = ImageDataset(binaray_data, transform=transform)loader = torch.utils.data.DataLoader(images, batch_size=500, num_workers=8)model = get_model_for_eval()model.to(device)all_predictions = []with torch.no_grad():for batch in loader:predictions = list(model(batch.to(device)).cpu().numpy())for prediction in predictions:all_predictions.append(prediction)return pd.Series(all_predictions)
# 调用pandas udf
predictions_df = df. \select(col("filename"), predict_batch_udf(col("data")).alias("prediction"))

更多代码细节:
https://github.com/Crazybean-lwb/deeplearning-pyspark/blob/master/examples/pytorch-inference.py

模型仓加速推理

打通到模型仓mlflow功能:

  • 模型存储和版本管理
  • 便捷取用
  • 适用spark datarame更高阶的pandas udf实现

在这里插入图片描述

# Create the PySpark UDF
import mlflow.pyfunc
pyfunc_udf = mlflow.pyfunc.spark_udf(spark, model_uri=model_uri)# 调用pandas udf
df = spark_df.withColumn("prediction", pyfunc_udf(struct([...])))

参考信息:

  1. pytorch分布式批量推理
  2. tensorflow分布式批量推理
  3. 模型仓mlflow协助分布式批量推理

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/66962.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

1+X智慧安防系统实施与运维技能等级证产教融合基地建设方案

一、系统概述 1X智慧安防系统实施与运维技能等级证产教融合体系统融合了产业需求、教育培训和技能认证,通过课程培训、实训基地和实习实训等方式培养学员的技能水平,并通过技能认证来评估其能力,以满足智慧安防行业对人才的需求,并…

迈向无限可能, ATEN宏正领跑设备切换行业革命!

随着互联网在各个领域的广泛应用,线上办公这一不受时间和地点制约、不受发展空间限制的办公模式开始广受追捧,预示着经济的发展正朝着新潮与活跃的方向不断跃进。当然,在互联网时代的背景下,多线程、多设备的线上办公模式也催生了许多问题:多设备间无法进行高速传输、切换;为保…

能直接运营的发接任务平台小程序搭建开发演示

有个项目估计做过互联网的小伙伴都听说过——发接任务平台。 基本每年都有发接任务平台关站,但又有新的平台出来,往复循环,无比热闹。这在互联网圈不常见,互联网项目很多都是风头过去了就结束了,但发接任务年年似乎都…

电商项目part10 高并发缓存实战

缓存的数据一致性 只要使用到缓存,无论是本地内存做缓存还是使用 redis 做缓存,那么就会存在数据同步的问题。 先读缓存数据,缓存数据有,则立即返回结果;如果没有数据,则从数据库读数据,并且把…

spring 错误百科

一、使用Spring出错根源 1、隐式规则的存在 你可能忽略了 Sping Boot 中 SpringBootApplication 是有一个默认的扫描包范围的。这就是一个隐私规则。如果你原本不知道,那么犯错概率还是很高的。类似的案例这里不再赘述。 2、默认配置不合理 3、追求奇技淫巧 4、…

iOS系统修复软件 Fix My iPhone for Mac

Fix My iPhone for Mac是一款iOS系统恢复工具。修复您的iPhone卡在Apple徽标,黑屏,冻结屏幕,iTunes更新/还原错误和超过20个iOS 12升级失败。这个macOS桌面应用程序提供快速,即时的解决方案来修复您的iOS系统问题,而不…

记录一下自己对linux分区挂载的理解

一直狠模糊,分两个区,一个挂载/, 一个挂载/home 两者是什么关系 实测 先看挂载的内容 然后umount /home后创建一个新文件 再挂载回去 发现旧分区又回来了,说明路径只是个抽象的概念,分区挂载,互相之间数据是不影响…

layui数据表格实现表格中嵌套表格,并且可以折叠展开

效果: 思路: 1、最外层的表格先渲染,在done回调中向每个tr后面插入一个用来嵌套子级表格的tr。 tr的class和table的id需要用索引 i 关联 //向每一行tr后面追加显示子table的trlet trEles $(".layui-table-view[lay-idlist] tbody tr&…

iOS练手项目知识点汇总

基础理解篇 Objective-C是一种面向对象的编程语言,它支持元编程。元编程是指编写程序来生成或操纵其他程序的技术。 Objective-C中,元编程可以使用Objective-C的动态特性来实现。例如可以使用Objective-C的运行时函数来动态地创建类、添加属性和方法等等…

【Flutter】使用Android Studio 创建第一个flutter应用。

前言 首先下载好 flutter sdk和 Android Studio。 FlutterSDK下载 Android Studio官网 配置 我的是 windows。 where.exe flutter dart查看flutter安装环境。 如果没有,自己在环境变量的path添加下flutter安装路径。 在将 Path 变量更新后,打开一个…

Java应用CPU占用过高故障排除

一、背景 最近测试反馈测试环境接口偶现有访问超时,然后APP提示是网络失败,看了一下测试环境的应用完全没啥问题,一直以为是网络问题。 今天测试有反馈了,赶紧看了一下测试服务器,这次终于有症状了,CPU直…

ARM编程模型-指令流水线

流水线技术通过多个功能部件并行工作来缩短程序执行时间,提高处理器核的效率和吞吐率,从而成为微处理器设计中最为重要的技术之一。 1. 3级流水线 到ARM7为止的ARM处理器使用简单的3级流水线,它包括下列流水线级。 (1&#xff0…

第63步 深度学习图像识别:多分类建模误判病例分析(Tensorflow)

基于WIN10的64位系统演示 一、写在前面 上两期我们基于TensorFlow和Pytorch环境做了图像识别的多分类任务建模。这一期我们做误判病例分析,分两节介绍,分别基于TensorFlow和Pytorch环境的建模和分析。 本期以健康组、肺结核组、COVID-19组、细菌性&am…

Python安装与Pycharm配置

Python与Pycharm安装 用了一年的Python最近被一个问题难倒了,pip安装一直不能用,报错说被另一个程序使用。被逼到只能重新安装python了,正好记录一下这个过程,写这篇笔记。(突然想到可能是配Arcgis的python接口&#…

金融风控数据分析-信用评分卡建模(附数据集下载地址)

本文引用自: 金融风控:信用评分卡建模流程 - 知乎 (zhihu.com) 在原文的基础上加上了一部分自己的理解,转载在CSDN上作为保留记录。 本文涉及到的数据集可直接从天池上面下载: Give Me Some Credit给我一些荣誉_数据集-阿里云…

2023年新风口,Kwai快手海外公会入驻详细指南!

Kwai快手海外公会发展前景作为中国短视频市场的领头羊,Kwai快手在海外市场也取得了一定的成绩。但是,海外市场的发展前景还存在一些不确定因素,需要快手进一步探索和努力。首先,海外市场对于内容的申请找cmxyci要求与国内市场有所…

MonoDETR: Depth-guided Transformer for Monocular 3D Object Detection 论文解读

文章目录 1.Abstract2. Introduction3. Related workDETR base methods 4. Method4.1Feature ExtractionVisual Featuresdepth featuresforeground depth map 4. 2 Depth guided transformerVisual and depth encodersDepth-guided-decoderDepth positional encoding 4. 3 Dete…

【vue2第十一章】v-model的原理详解 与 如何使用v-model对父子组件的value绑定 和修饰符.sync

v-model的原理详解 v-model的本质就是一个语法糖,实际上就是 :value"msg" 与 input"msg $event.target.value" 的简写。 :value"msg" 从数据单向绑定到input框,当data数据中的msg内容一旦改变,而input框数据…

OpenCV(十八):图像直方图

目录 1.直方图统计 2.直方图均衡化 3.直方图匹配 1.直方图统计 直方图统计是一种用于分析图像或数据的统计方法,它通过统计每个数值或像素值的频率分布来了解数据的分布情况。 在OpenCV中,可以使用函数cv::calcHist()来计算图像的直方图。 calcHist(…

uni-app之android离线打包

一 AndroidStudio创建项目 1.1,上一节演示了uni-app云打包,下面演示怎样androidStudio离线打包。在AndroidStudio里面新建空项目 1.2,下载uni-app离线SDK,离线SDK主要用于App本地离线打包及扩展原生能力,SDK下载链接h…