Parallelize your massive SHAP computations with MLlib and PySpark

https://medium.com/towards-data-science/parallelize-your-massive-shap-computations-with-mllib-and-pyspark-b00accc8667c


A stepwise guide for efficiently explaining your models using SHAP.

Photo by Pietro Jeng on Unsplash

Introduction to MLlib

Apache Spark’s Machine Learning Library (MLlib) is designed primarily for scalability and speed, leveraging the Spark runtime for common distributed use cases in supervised learning (such as classification and regression), unsupervised learning (such as clustering and collaborative filtering), and other areas like dimensionality reduction. In this article, I cover how we can use SHAP to explain a Gradient Boosted Trees (GBT) model that has been fit to our data at scale.

What are Gradient Boosted Trees?

Before we understand what Gradient Boosted Trees are, we need to understand boosting. Boosting is an ensemble technique that sequentially combines a number of weak learners to achieve an overall strong learner. In the case of Gradient Boosted Trees, each weak learner is a decision tree that sequentially minimizes the errors (MSE in the case of regression and log loss in the case of classification) left by the previous decision tree in that sequence. To read about GBTs in more detail, please refer to this blog post.
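As a point of reference, training such a model with MLlib takes only a few lines. The sketch below assumes a dataframe train_df that already has a “features” vector column and a “label” column; the names, parameters and model path are illustrative, not taken from this article.

from pyspark.ml.classification import GBTClassifier

# Minimal sketch: 'train_df' is assumed to already have a 'features' vector column and a 'label' column.
gbt_clf = GBTClassifier(featuresCol="features", labelCol="label", maxIter=50, maxDepth=5)
gbt_model = gbt_clf.fit(train_df)

# Persist the fitted model so it can later be reloaded with GBTClassificationModel.load().
gbt_model.write().overwrite().save("your-model-path")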

Understanding our imports

from pyspark.sql import SparkSession
from pyspark import SparkContext, SparkConf
from pyspark.ml.classification import GBTClassificationModel
import shap
import pyspark.sql.functions as F
from pyspark.sql.types import *

The first two imports are for initializing a Spark session, which will be used for converting our pandas dataframe to a Spark one. The third import is used to load our GBT model into memory, which will be passed to our SHAP explainer to generate explanations. The SHAP explainer itself is initialized via the fourth import. The penultimate and last imports are for using Spark SQL functions and SQL types; these will be used in our User-Defined Function (UDF), which I shall describe later.
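For instance, a session can be obtained as follows (a minimal sketch; on Databricks a spark object already exists, and the local-mode master shown here is just an illustrative assumption):

from pyspark import SparkConf
from pyspark.sql import SparkSession

# Minimal sketch: create (or reuse) a Spark session. On Databricks, 'spark' is already provided.
spark = SparkSession.builder.config(conf=SparkConf().set("spark.master", "local[*]")).getOrCreate()
sc = spark.sparkContext  # the underlying SparkContext, if lower-level RDD access is needed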

Converting our MLlib GBT feature vector to a Pandas dataframe

The SHAP explainer takes a dataframe as input. However, training an MLlib GBT model requires data preprocessing. More specifically, the categorical variables in our data need to be converted into numeric variables using either Category Indexing or One-Hot Encoding. To learn more about how to train a GBT model, refer to this article. The resulting “features” column is a SparseVector (to read more on it, check the “Preprocess Data” section in this example). It looks like something below:

SparseVector features column description — 1. default index value, 2. vector length, 3. list of indexes of the feature columns, 4. list of data values at the corresponding indexes in 3. [Image by author]
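For context, a preprocessing pipeline that produces such a “features” column typically looks roughly like the sketch below (Spark 3.x API; the column names cat_col_1, num_col_1 and the dataframe raw_df are hypothetical placeholders, not taken from this article):

from pyspark.ml import Pipeline
from pyspark.ml.feature import StringIndexer, OneHotEncoder, VectorAssembler

# Hypothetical column names, for illustration only.
categorical_cols = ["cat_col_1", "cat_col_2"]
numeric_cols = ["num_col_1", "num_col_2"]

# Category Indexing followed by One-Hot Encoding, then assembly into a single sparse 'features' vector.
indexers = [StringIndexer(inputCol=c, outputCol=c + "_idx") for c in categorical_cols]
encoder = OneHotEncoder(inputCols=[c + "_idx" for c in categorical_cols],
                        outputCols=[c + "_ohe" for c in categorical_cols])
assembler = VectorAssembler(inputCols=[c + "_ohe" for c in categorical_cols] + numeric_cols,
                            outputCol="features")

pipeline = Pipeline(stages=indexers + [encoder, assembler])
spark_df = pipeline.fit(raw_df).transform(raw_df)  # 'raw_df' is the untransformed input dataframe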

The “features” column shown above is for a single training instance. We need to transform this SparseVector for all of our training instances. One way to do it is to iteratively process each row and append it to our pandas dataframe, which we will then feed to our SHAP explainer (ouch!). There is a much faster way, which leverages the fact that we have all of our data loaded in memory (if not, we can load it in batches and perform the preprocessing for each in-memory batch). In Shikhar Dua’s words:

1. Create a list of dictionaries in which each dictionary corresponds to an input data row.

2. Create a data frame from this list.

So, based on the above method, we get something like this:

import pandas as pd

rows_list = []
for row in spark_df.rdd.collect():
    dict1 = {}
    dict1.update({k: v for k, v in zip(spark_df.columns, row.features)})
    rows_list.append(dict1)
pandas_df = pd.DataFrame(rows_list)

If rdd.collect() looks scary, it’s actually pretty simple to explain. Resilient Distributed Datasets (RDDs) are fundamental Spark data structures: immutable, distributed collections of objects. Each dataset in an RDD is further subdivided into logical partitions that can be computed on different worker nodes of our Spark cluster. So, all that PySpark RDD collect() does is retrieve the data from all the worker nodes to the driver node. As you might guess, this is a memory bottleneck, and if we are handling data larger than our driver node’s memory capacity, we need to increase the number of RDD partitions and filter them by partition index, as sketched below. Read how to do that here.
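The sketch below illustrates that partition-wise batching idea (the partition count is an assumption to tune to your driver’s memory budget; this is not the article’s exact code):

# Sketch: pull the data back to the driver one partition at a time instead of all at once.
num_partitions = 32  # assumption: tune this to your driver's memory budget
rdd = spark_df.rdd.repartition(num_partitions)

for part_idx in range(num_partitions):
    # Keep only the rows that live in partition 'part_idx' and collect just that slice.
    batch = rdd.mapPartitionsWithIndex(
        lambda idx, rows, p=part_idx: rows if idx == p else iter([])
    ).collect()
    # ... build a per-batch pandas dataframe here and run the preprocessing/SHAP step on it ...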

Don’t take my word on the execution performance. Check out the stats.

Performance profiling for inserting rows to a pandas dataframe. [Source (Thanks to Mikhail_Sam and Peter Mortensen): here]

Here are the metrics from one of my Databricks notebook scheduled job runs:

Input size: 11.9 GiB (~12.78GB), Total time Across All Tasks: 20 min, Number of records: 165.16K

Summary metrics for 125 completed tasks executed by the stage that ran the above cell. [Image by author]

Working with the SHAP Library

We are now ready to pass our preprocessed dataset to the SHAP TreeExplainer. Remember that SHAP is a local feature attribution method: it explains individual predictions as an algebraic sum of the Shapley values of our model’s features.

We use a TreeExplainer for the following reasons:

  1. Suitable: TreeExplainer is a class that computes SHAP values for tree-based models (Random Forest, XGBoost, LightGBM, GBT, etc).
  2. Exact: Instead of simulating missing features by random sampling, it makes use of the tree structure by simply ignoring decision paths that rely on the missing features. The TreeExplainer output is therefore deterministic and does not vary based on the background dataset.
  3. Efficient: Instead of iterating over each possible feature combination (or a subset thereof), all combinations are pushed through the tree simultaneously, using a more complex algorithm to keep track of each combination’s result — reducing complexity from O(TL2ᴹ) for all possible coalitions to the polynomial O(TLD²), where M is the number of features, T is the number of trees, L is the maximum number of leaves and D is the maximum tree depth.

The check_additivity flag controls a validation check that verifies whether the sum of the SHAP values equals the output of the model. Running that check requires model predictions, which is not supported for Spark models, so the flag needs to be set to False (it is ignored anyway). Once we get the SHAP values, we convert them from a NumPy array into a pandas dataframe so that they are easily interpretable.

One thing to note is that the dataset order is preserved when we convert a Spark dataframe to pandas, but the reverse is not true.

The points above lead us to the code snippet below:

gbt = GBTClassificationModel.load('your-model-path')
explainer = shap.TreeExplainer(gbt)
shap_values = explainer(pandas_df, check_additivity=False)
shap_pandas_df = pd.DataFrame(shap_values.values, columns=pandas_df.columns)
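Recall the ordering caveat above: row order survives the Spark-to-pandas conversion, but not the trip back. One defensive pattern (a sketch, not something the article itself relies on) is to attach an explicit row id while the data is still in pandas, so alignment never depends on row order once the SHAP values return to Spark:

# Sketch: carry an explicit id through the round trip so alignment never depends on row order.
pandas_df["row_id"] = range(len(pandas_df))            # original feature rows (order preserved from Spark)
shap_pandas_df["row_id"] = pandas_df["row_id"].values  # SHAP rows are produced in the same order

# Back in Spark, rows can be matched on 'row_id' instead of on position
# (drop or exclude 'row_id' before feeding the feature columns to the SHAP UDF).
spark_shapdf = spark.createDataFrame(shap_pandas_df)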

An Introduction to Pyspark UDFs and when to use them

How PySpark UDFs distribute individual tasks to worker (executor) nodes [Source: here]

User-Defined Functions are custom functions that operate on a particular row of our dataset. They are generally used when the native Spark functions are not sufficient to solve the problem. Native Spark functions are inherently faster than UDFs because they are JVM constructs whose methods are implemented by local calls to Java APIs. PySpark UDFs, on the other hand, are Python implementations that require data movement between the Python interpreter and the JVM (refer to arrow 4 in the picture above), which inevitably introduces some processing delay.
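To make that cost concrete, here is a tiny hedged comparison (the column name some_numeric_col is a placeholder, not from the article): the first transformation stays entirely in the JVM, while the second ships every value to a Python worker and back.

from pyspark.sql import functions as F
from pyspark.sql.types import DoubleType

# Native Spark function: executed inside the JVM, no Python round trip.
df_native = spark_df.withColumn("abs_val", F.abs(F.col("some_numeric_col")))

# Equivalent Python UDF: every value is serialized to the Python interpreter and back.
py_abs = F.udf(lambda x: abs(x) if x is not None else None, DoubleType())
df_udf = spark_df.withColumn("abs_val", py_abs(F.col("some_numeric_col")))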

If no processing delay can be tolerated, the best thing to do is create a Python wrapper that calls the Scala UDF from PySpark itself. A great example is shown in this blog. However, a PySpark UDF was sufficient for my use case, since it is easy to understand and to code.

The code below defines the Python function to be executed on each worker/executor node. We simply pick the highest SHAP values (absolute values, since we want to find the most impactful negative features as well), append them to the respective pos_features and neg_features lists, and in turn append both of these lists to a features list that is returned to the caller.

def shap_udf(row):
    # Collect the (feature, SHAP value) pairs from the incoming struct/Row.
    feature_dict = {}
    for feature in row.__fields__:
        feature_dict[feature] = row[feature]
    # Sort features by absolute SHAP value, most impactful first.
    dict_importance = {key: value for key, value in
                       sorted(feature_dict.items(), key=lambda item: abs(item[1]), reverse=True)}
    pos_features = []
    neg_features = []
    for k, v in dict_importance.items():
        if abs(v) >= <your-threshold-shap-value>:
            if v > 0:
                pos_features.append((k, v))
            else:
                neg_features.append((k, v))
    features = []
    features.append(pos_features[:5])
    features.append(neg_features[:5])
    return features

We then register our PySpark UDF with our Python function name (in my case, it is shap_udf) and specify the return type (mandatory in Python and Java) of the function in the parameters to F.udf(). The outer ArrayType() holds two lists, one for the positive features and the other for the negative ones. Each of those lists contains at most 5 (feature-name, shap-value) StructType() pairs, which is what the inner ArrayType() represents. Below is the code:

udf_obj = F.udf(shap_udf, ArrayType(ArrayType(StructType([
    StructField('Feature', StringType()),
    StructField('Shap_Value', FloatType()),
]))))

Now we just create a new Spark dataframe with a column called ‘Shap_Importance’ that invokes our UDF on each row of the spark_shapdf dataframe. To split the positive and negative features, we then add two columns, giving a new Spark dataframe called final_sparkdf. Our final code snippet looks like this:

new_sparkdf = spark_shapdf.withColumn('Shap_Importance', udf_obj(F.struct([spark_shapdf[x] for x in spark_shapdf.columns])))
final_sparkdf = new_sparkdf.withColumn('Positive_Shap', new_sparkdf.Shap_Importance[0]).withColumn('Negative_Shap', new_sparkdf.Shap_Importance[1])

And finally, we have extracted all the important features of our GBT model per testing instance without the use of any explicit for loops! The consolidated code can be found in the below GitHub gist.

from pyspark.sql import SparkSession
from pyspark import SparkContext, SparkConf
from pyspark.ml.classification import GBTClassificationModel
import shap
import pyspark.sql.functions as F
from pyspark.sql.types import *
import pandas as pd

# Convert the sparse feature vector that is passed to the MLlib GBT model into a pandas dataframe.
# This 'pandas_df' will be passed to the SHAP TreeExplainer.
rows_list = []
for row in spark_df.rdd.collect():
    dict1 = {}
    dict1.update({k: v for k, v in zip(spark_df.columns, row.features)})
    rows_list.append(dict1)
pandas_df = pd.DataFrame(rows_list)

# Load the GBT model from the path where you have saved it.
# Make sure the application where your notebook runs has access to the storage path!
gbt = GBTClassificationModel.load("<your path where the GBT model is loaded>")

explainer = shap.TreeExplainer(gbt)
# check_additivity requires predictions to be run, which is not supported by Spark [yet],
# so it needs to be set to False as it is ignored anyway.
shap_values = explainer(pandas_df, check_additivity=False)
shap_pandas_df = pd.DataFrame(shap_values.values, columns=pandas_df.columns)

spark = SparkSession.builder.config(conf=SparkConf().set("spark.master", "local[*]")).getOrCreate()
spark_shapdf = spark.createDataFrame(shap_pandas_df)

def shap_udf(row):
    # Works on a single Spark dataframe row; this work is distributed
    # among all the worker nodes of your Apache Spark cluster.
    feature_dict = {}
    for feature in row.__fields__:
        feature_dict[feature] = row[feature]
    # Sort features by absolute SHAP value, most impactful first.
    dict_importance = {key: value for key, value in
                       sorted(feature_dict.items(), key=lambda item: abs(item[1]), reverse=True)}
    pos_features = []
    neg_features = []
    for k, v in dict_importance.items():
        if abs(v) >= <your-threshold-shap-value>:
            if v > 0:
                pos_features.append((k, v))
            else:
                neg_features.append((k, v))
    features = []
    # Taking the top 5 features from pos and neg features. We can increase this number.
    features.append(pos_features[:5])
    features.append(neg_features[:5])
    return features

udf_obj = F.udf(shap_udf, ArrayType(ArrayType(StructType([
    StructField('Feature', StringType()),
    StructField('Shap_Value', FloatType()),
]))))

new_sparkdf = spark_shapdf.withColumn('Shap_Importance', udf_obj(F.struct([spark_shapdf[x] for x in spark_shapdf.columns])))
final_sparkdf = new_sparkdf.withColumn('Positive_Shap', new_sparkdf.Shap_Importance[0]).withColumn('Negative_Shap', new_sparkdf.Shap_Importance[1])

Get the most impactful Positive and Negative SHAP values from our fitted GBT Model
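From here, inspecting the per-row explanations is ordinary dataframe work; below is a small usage sketch, assuming the column names produced by the snippet above.

# Peek at the top positive and negative SHAP features for a few rows.
final_sparkdf.select('Positive_Shap', 'Negative_Shap').show(5, truncate=False)

# Or explode the positive features into one (Feature, Shap_Value) struct per row.
exploded = final_sparkdf.select(F.explode('Positive_Shap').alias('top_positive'))
exploded.select('top_positive.Feature', 'top_positive.Shap_Value').show(10)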

P.S. This is my first attempt at writing an article and if there are any factual or statistical inconsistencies, please reach out to me and I shall be more than happy to learn together with you! :)

References

[1] Soner Yıldırım, Gradient Boosted Decision Trees-Explained (2020), Towards Data Science

[2] Susan Li, Machine Learning with PySpark and MLlib — Solving a Binary Classification Problem (2018), Towards Data Science

[3] Stephen Offer, How to Train XGBoost With Spark (2020), Data Science and ML

[4] Use Apache Spark MLlib on Databricks (2021), Databricks

[5] Umberto Griffo, Don’t collect large RDDs (2020), Apache Spark — Best Practices and Tuning

[6] Nikhilesh Nukala, Yuhao Zhu, Guilherme Braccialli, Tom Goldenberg (2019), Spark UDF — Deep Insights in Performance, QuantumBlack
