spark写出分布式的训练算法_利用 Spark 和 scikit-learn 将你的模型训练加快 100 倍...

在 Ibotta,我们训练了许多机器学习模型。这些模型为我们的推荐系统、搜索引擎、定价优化引擎、数据质量等提供动力。它们在与我们的移动应用程序交互时为数百万用户做出预测。

当我们使用 Spark 进行数据处理时,我们首选的机器学习框架是 scikit-learn。随着计算机变得越来越便宜,机器学习解决方案的上市时间变得越来越关键,我们探索了加快模型训练的各种方法。其中一个解决方案是将 Spark 和 scikit-learn 中的元素组合到我们自己的混合解决方案中。

sk-dist 的介绍

我们很高兴地宣布我们的开源项目 sk-dist 的启动。该项目的目标是为使用 Spark 分发 scikit 学习元估计器提供一个通用框架。元估计器的例子有决策树集合(随机林和额外随机树)、超参数调解器(网格搜索和随机搜索)和多分类技术(一对多和多对一)。

我们的主要动机是填补传统机器学习模型空间的空白。在神经网络和深度学习的空间之外,我们发现我们的训练模型的大部分计算时间并没有花在训练单个数据集的单个模型上。相反,大部分时间都花在使用元估计器在数据集上训练模型的多次迭代上。

例子

让我们谈谈手写数字数据集。在这里,我们对手写数字的图像进行了适当的编码、分类。我们可以很快在一台机器上训练 1797 条记录的支持向量机,花费的时间不到一秒钟。但超参数调整需要在训练数据的不同子集上进行大量的训练。

如下图所示,我们已经构建了一个总计需要 1050 个训练的参数网格。在拥有 100 多个核的 Spark 上使用 sk dist 只需 3.4 秒。这项工作的总时间是 7.2 分钟,意思是在没有并行化的单机上训练要花这么长时间。import timefrom sklearn import datasets, svm

from skdist.distribute.search import DistGridSearchCV

from pyspark.sql import SparkSession # instantiate spark session

spark = (

SparkSession

.builder

.getOrCreate()

)

sc = spark.sparkContext # the digits dataset

digits = datasets.load_digits()

X = digits["data"]

y = digits["target"] # create a classifier: a support vector classifier

classifier = svm.SVC()

param_grid = {

"C": [0.01, 0.01, 0.1, 1.0, 10.0, 20.0, 50.0],

"gamma": ["scale", "auto", 0.001, 0.01, 0.1],

"kernel": ["rbf", "poly", "sigmoid"]

}

scoring = "f1_weighted"

cv = 10# hyperparameter optimization

start = time.time()

model = DistGridSearchCV(

classifier, param_grid,

sc=sc, cv=cv, scoring=scoring,

verbose=True

)

model.fit(X,y)

print("Train time: {0}".format(time.time() - start))

print("Best score: {0}".format(model.best_score_))------------------------------

Spark context found; running with spark

Fitting 10 folds for each of 105 candidates, totalling 1050 fits

Train time: 3.380601406097412

Best score: 0.981450024203508

这个例子演示了一个常见的场景,在这个场景中,将数据拟合到内存中并训练单个分类器是很简单的,但是适合超参数优化所需的匹配数量会迅速增加。下面是一个运行网格搜索问题的例子,和上面的 sk dist 示例类似:

带sk-dist的网格搜索

对于 ibotta 传统机器学习的实际应用,我们经常发现自己处于类似这样的情况中:中小型数据(10k 到 1M 的记录)和许多简单分类器迭代以适应超参数调整、集成和多分类解决方案。

现有解决方案

传统的机器学习元估计器训练方法已经存在。第一个是最简单的:scikit-learn 使用 joblib 内置的元估计器并行化。这与 sk-dist 的操作非常相似,但是它有一个主要的限制:性能受限于任何机器的资源。即使与理论上拥有数百个内核的单机相比,Spark 仍然具有一些优势,如执行器的微调内存规范、容错,以及成本控制选项,如对工作节点使用 spot 实例。

另一个现有的解决方案是 Spark ML,它是 Spark 的一个本地机器学习库,支持许多与 scikit-learn 相同的算法来解决分类和回归问题。它还具有诸如树集合和网格搜索之类的元估计器,以及对多分类问题的支持。

分布在不同的维度上

如上所示,Spark ML 将针对分布在多个执行器上的数据来训练单个模型。当数据量很大,以至于无法存入一台机器上的内存时,这种方法可以很好地工作。然而,当数据量很小时,在单台机器上这可能会比 scikit-learn 的学习效果差。此外,例如,当训练一个随机森林时,Spark ML 按顺序训练每个决策树。此项工作的时间将与决策树的数量成线性比例,和分配给该任务的资源无关。

对于网格搜索,Spark ML 实现了一个并行参数,该参数将并行地训练各个模型。然而,每个单独的模型仍在对分布在执行器之间的数据进行训练。这项任务的总并行度只是纯粹按照模型维度来的,而不是数据分布的维度。

最后,我们希望将我们的训练分布在与 Spark ML 不同的维度上。当使用中小型数据时,将数据拟合到内存中不是问题。对于随机森林的例子,我们希望将训练数据完整地广播给每个执行器,在每个执行者身上拟合一个独立的决策树,并将这些拟合的决策树带回给驱动器,以集合成一个随机森林。这个维度比串行分布数据和训练决策树快几个数量级。

特征

考虑到这些现有解决方案在我们的问题空间中的局限性,我们内部决定开发 sk-dist。归根结底,我们希望发布的是模型,而不是数据。

虽然 sk-dist 主要关注元估计器的分布式训练,但它也包括很多其它模块,如 Spark 的 scikit-learn 模型的分布式预测模块等。分布式训练——使用 Spark 进行分布式元估计训练,支持以下算法:带网格搜索和随机搜索的超参数优化、带随机林的树集合、额外树和随机树嵌入,以及一对一和一对多的多分类策略。

分布预测——具有 Spark 数据帧的拟合 scikit-learn 估计器的预测方法。这使得带有 scikit-learn 的大规模分布式预测可以在没有 Spark 的情况下进行。

特征编码——分布特征编码使用被称为编码器的灵活特征变换器来完成。不管有没有 Spark,它都可以起作用。它将推断数据类型,自动应用默认的特征变换器作为标准特征编码技术的最佳实现。它还可以作为一个完全可定制的功能联合,如编码器,它的附加优势是与 Spark 匹配的分布式 transformer。

用例

以下是判断 sk-dist 是否适合解决你的机器学习问题的一些准则:传统的机器学习方法,如广义线性模型、随机梯度下降、最近邻、决策树和朴素贝叶斯等,都能很好地应用于 sk-dist,这些方法都可以在 scikit-learn 中实现,并且可以直接应用于 sk-dist 元估计。

中小型数据、大数据不能很好地在 sk-dist 中起作用。记住,分布式训练的维度是沿着模型的轴,而不是数据。数据不仅需要放在每个执行器的内存中,而且要小到可以传播。根据 Spark 配置,最大传播大小可能会受到限制。

Spark 定向和访问——sk-dist 的核心功能需要运行 Spark。对于个人或小型数据科学团队来说,这并不总是可行的。

这里一个重要的注意事项是,虽然神经网络和深度学习在技术上可以用于 sk-dist,但这些技术需要大量的训练数据,有时需要专门的基础设施才能有效。深度学习不是 sk-dist 的最佳用例,因为它违反了上面的(1)和(2)。

开始

要开始使用 sk-dist,请查看安装指南。代码库还包含一个示例库,用于说明 sk-dist 的一些用例。欢迎所有人提交问题并为项目做出贡献。

雷锋网(公众号:雷锋网)雷锋网雷锋网

雷锋网版权文章,未经授权禁止转载。详情见转载须知。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/282112.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

理解LinkedHashMap

1. LinkedHashMap概述:LinkedHashMap是HashMap的一个子类,它保留插入的顺序,如果需要输出的顺序和输入时的相同,那么就选用LinkedHashMap。LinkedHashMap是Map接口的哈希表和链接列表实现,具有可预知的迭代顺序。此实现…

MySQL - 锁

一、什么是锁 锁是数据库系统区别于文件系统的一个关键特性。锁机制用于管理对共享资源的并发访问。 二、MySQL 不同存储引擎支持的锁机制 存储引擎支持的锁类型Myisam表锁Innodb行锁、表锁Memory表锁BDB页锁、表锁表锁:直接锁住的是一个表,开销小&…

数据库时区那些事儿 - MySQL的时区处理

原文地址 当JVM时区和数据库时区不一致的时候,会发生什么?这个问题也许你从来没有注意过,但是当把Java程序容器化的时候,问题就浮现出来了,因为目前几乎所有的Docker Image的时区都是UTC。本文探究了MySQL及其JDBC驱动…

java_函数的重载

函数的重载(Overload)概念:在同一个类中,允许存在一个以上的同名函数,只要他们的参数个数或者参数类型不同即可。函数功能一样,仅仅是参与运算的未知内同不同时,可以定义多函数,却使…

全新升级的AOP框架Dora.Interception[2]: 基于约定的拦截器定义方式

Dora.Interception(github地址,觉得不错不妨给一颗星)有别于其他AOP框架的最大的一个特点就是采用针对“约定”的拦截器定义方式。如果我们为拦截器定义了一个接口或者基类,那么拦截方法将失去任意注册依赖服务的灵活性。除此之外…

redis watch使用场景_redis不得不会的事务玩法

我们都知道redis追求的是简单,快速,高效,在这种情况下也就拒绝了支持window平台,学sqlserver的时候,我们知道事务还算是个比较复杂的东西,所以这吊毛要是照搬到redis中去,理所当然redis就不是那…

加快Android Studio的编译速度

从Eclipse切换到Android Studio后,感觉Android Studio的build速度比Eclipse慢很多,以下几个方法可以提高Android Studio的编译速度使用Gradle 2.4Gradle 2.4对执行性能有很大的优化,但Android Studio现在默认使用的是Gradle 2.2,所以我们需要…

开发中 MySQL 规范

一、建表规范 1、数据库名、表名、字段名必须使用小写字母或数字,并且禁止以数字开头 示例:goods_category、agent_operate_201812_log 2、数据库名、表名、字段名要做到见名识意 示例:goods_category,不能 gc 3、配置表建议以 …

PaddleOCR在 Linux下的webAPI部署方案

很多小伙伴在使用OCR时都希望能采用API的方式调用,这样就可以跨端跨平台了。本文将介绍一种基于python的PaddleOCR识别WebAPI部署方案。喜欢的可以关注公众号,获取更多内容。一、 Linux环境下部署1.环境要求操作系统:CenterOS7;主…

影响程序员生涯的三个错误观念,你千万不要犯!

程序员在社会上,到底是怎样一个生活群体?是否能找到自己方向?其实,路一直都在那里,只是你看不到而已! 当初的你,可能一直被一些技术牵着鼻子走,并不是自己在做着自己想做的&#xff…

心电图计算心率公式_心电图到底能反应啥问题,看过之后你也能当“医生”

只要是经历过健康体检的健康人,或者做过手术的患者,基本都做过心电图检查。都说久病成医,所以有些人对血、尿常规等各项检查的结果都门清儿得很,最起码看一眼也能说出个大概齐。偏偏心电图这种常做的检查,不但老病号如…

获取正在运行的服务

手机上安装的App,在后台运行着很多不同功能的服务,最常见的例如消息推送相关的服务。如何查看这些服务?如何判断某个服务是否正在运行?如何停止某一个服务呢?请看下面的方法: package com.example.servicel…

openstack的vnc启动ssl

1、制作ssl证书# cd /etc/pki/tls/certs [rootwww certs]# make vnc.key Enter pass phrase:# 输入密码 Verifying - Enter pass phrase:#确认# 从private key 中删除密码# openssl rsa -in vnc.key -out vnc.key # make vnc.csr Country Name (2 letter code) [XX]:CN# 国家 S…

开发composer包

一、初始化&#xff08;生成composer.json文件&#xff09; composer init#输入你要创建的composer包项目命名空间 Package name (<vendor>/<name>) [root/tiny-laravel]: #haveyb/tiny-laravel #输入composer包的描述 Description []:#this is a tiny laravel h…

Linux本地yum源配置以及使用yum源安装gcc编译环境

本文档是图文安装本地yum源的教程&#xff0c;以安装gcc编译环境为例。 适用范围&#xff1a;所有的cetos,红帽,fedroa版本 适用人群&#xff1a;有一点linux基础的小白 范例系统版本&#xff1a;CentOS Linux release 7.3.1611 (Core) 范例环境&#xff1a;vmware 虚拟机 安装…

word如何设置上标形式_如何在word中设置特殊页码

获取更多业界资讯和深度好文● 点击蓝字关注我们 ●在日常工作中&#xff0c;我们编辑的word文档经常需要设置页码&#xff0c;但有时文档的第一页是封面&#xff0c;第二页才是正文&#xff0c;或者第二页是目录&#xff0c;第三页才是正文&#xff0c;如下图所示&#xff0c;…

[cf797c]Minimal string(贪心+模拟)

题意&#xff1a; 给出了字符串s的内容&#xff0c;字符串t&#xff0c;u初始默认为空&#xff0c;允许做两种操作&#xff1a; 1、把s字符串第一个字符转移到t字符串最后 2、把t字符串最后一个字符转移到u字符串最后 最后要求s、t字符串都为空&#xff0c;问u字符串字典序最小…

发布composer包到 Packagist,并设置自动同步(从github到Packagist)

一、发布composer包 1、将我们写好的项目包发布到github上 这一步不赘述&#xff0c;应该都会。 但是需要注意的是&#xff0c;我们一定要为我们的项目包打上tag之后再提交&#xff0c;否则 我们composer require时可能会报错 Could not find a version of package。 # 设置…

教你在CorelDRAW中导入位图

在CorelDRAW软件中不能直接打开位图图像&#xff0c;在实际操作中&#xff0c;用户需要使用导入位图图像的方法进行操作。导入位图图像时&#xff0c;可以导入整幅图像&#xff0c;也可以在导入的过程中对图像进行裁剪&#xff0c;或重新取样图像&#xff0c;导入整幅位图图像时…

.NET 6 中将 ASP.NET Core 注册成 Windows Service

前言使用 Visual Studio 中的 Worker Service项目模板:我们很容易创建出 Windows Service&#xff1a;IHost host Host.CreateDefaultBuilder(args).UseWindowsService().ConfigureServices(services >{services.AddHostedService<Worker>();}).Build();await host.R…