基于Scala开发Spark ML的ALS推荐模型实战

推荐系统,广泛应用到电商,营销行业。本文通过Scala,开发Spark ML的ALS算法训练推荐模型,用于电影评分预测推荐。

算法简介

ALS算法是Spark ML中实现协同过滤的矩阵分解方法。

ALS,即交替最小二乘法(Alternating Least Squares),是协同过滤技术中的一种经典算法。它通过对用户和物品的潜在特征进行建模,来预测用户对未知物品的评分或偏好。具体介绍如下:

  1. 矩阵分解模型:在推荐系统中,我们通常有一个用户-物品的评分矩阵,其中行表示用户,列表示物品,矩阵中的值代表用户对物品的评分。然而,这个矩阵通常是非常稀疏的,因为用户只给少数物品评分。ALS算法就是在这样的不完整评分矩阵上操作,通过矩阵分解来补全缺失值,进而产生推荐。
  2. 算法原理:ALS算法的核心思想是通过迭代过程更新用户和物品的潜在因子向量。在每次迭代中,一个评分被建模为用户潜在特征向量和物品潜在特征向量的点积,加上一个偏差项。通过最小化实际评分和预测评分之间的差异来不断优化这些潜在特征向量。
  3. Spark ML实现:在Spark ML库中,ALS算法被用于处理大规模的数据集,并提供了多种参数以适应不同的数据特性和需求。例如,可以设置潜在因子的数量、正则化参数、迭代次数等。此外,Spark ML的ALS还支持隐式反馈数据的变体,这对于无法获取明确评分的数据非常有用。

总的来说,ALS是一种强大的推荐系统算法,尤其适用于处理大规模稀疏数据集。通过合理地选择和调整参数,可以在保持高效计算的同时获得良好的推荐质量。

代码实战

pom.xml文件更新,加入相关依赖

<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"><modelVersion>4.0.0</modelVersion><groupId>org.example</groupId><artifactId>sparkGNU2023</artifactId><version>1.0-SNAPSHOT</version><properties><maven.compiler.source>8</maven.compiler.source><maven.compiler.target>8</maven.compiler.target><project.build.sourceEncoding>UTF-8</project.build.sourceEncoding><scala.version>2.13</scala.version><spark.version>3.4.1</spark.version><log4j.version>1.2.17</log4j.version><slf4j.version>1.7.22</slf4j.version></properties><dependencies><!--日志相关依赖--><dependency><groupId>org.slf4j</groupId><artifactId>jcl-over-slf4j</artifactId><version>${slf4j.version}</version></dependency><dependency><groupId>org.slf4j</groupId><artifactId>slf4j-api</artifactId><version>${slf4j.version}</version></dependency><dependency><groupId>org.slf4j</groupId><artifactId>slf4j-log4j12</artifactId><version>${slf4j.version}</version></dependency><dependency><groupId>log4j</groupId><artifactId>log4j</artifactId><version>${log4j.version}</version></dependency><dependency><groupId>com.thoughtworks.paranamer</groupId><artifactId>paranamer</artifactId><version>2.8</version></dependency><dependency><groupId>org.apache.spark</groupId><artifactId>spark-core_2.13</artifactId><version>3.4.1</version></dependency><dependency><groupId>org.apache.spark</groupId><artifactId>spark-sql_2.13</artifactId><version>${spark.version}</version></dependency><dependency><groupId>org.apache.spark</groupId><artifactId>spark-streaming_2.13</artifactId><version>${spark.version}</version></dependency><dependency><groupId>org.apache.spark</groupId><artifactId>spark-hive_2.13</artifactId><version>${spark.version}</version></dependency><dependency><groupId>org.apache.spark</groupId><artifactId>spark-streaming-kafka-0-10_2.13</artifactId><version>3.4.1</version></dependency><dependency><groupId>org.apache.spark</groupId><artifactId>spark-mllib_2.13</artifactId><version>${spark.version}</version></dependency><dependency><groupId>org.apache.spark</groupId><artifactId>spark-streaming-kafka-0-8_2.11</artifactId><version>2.4.8</version></dependency><dependency><groupId>mysql</groupId><artifactId>mysql-connector-java</artifactId><version>8.0.30</version></dependency><dependency><groupId>org.apache.flume.flume-ng-clients</groupId><artifactId>flume-ng-log4jappender</artifactId><version>1.11.0</version></dependency><!--        flume 拦截器相关依赖--><dependency><groupId>org.apache.flume</groupId><artifactId>flume-ng-core</artifactId><version>1.9.0</version><scope>provided</scope></dependency><dependency><groupId>com.alibaba</groupId><artifactId>fastjson</artifactId><version>1.2.62</version></dependency></dependencies><build><plugins><plugin><groupId>org.apache.maven.plugins</groupId><artifactId>maven-compiler-plugin</artifactId><version>3.8.1</version><configuration><source>1.8</source><target>1.8</target></configuration></plugin><plugin><groupId>org.apache.maven.plugins</groupId><artifactId>maven-assembly-plugin</artifactId><version>3.6.0</version><configuration><descriptorRefs><descriptorRef>jar-with-dependencies</descriptorRef></descriptorRefs></configuration><executions><execution><id>make-assembly</id><phase>package</phase><goals><goal>single</goal></goals></execution></executions></plugin></plugins></build></project>

训练ALS模型

基于scala训练ALS模型

package base.charpter10import breeze.linalg.sum
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.recommendation.ALS
import org.apache.spark.sql.functions.{col, count, explode, when}
import org.apache.spark.sql.{DataFrame, SparkSession}/*** @projectName sparkGNU2023  * @package base.charpter10  * @className base.charpter10.MovieRecommender  * @description ${description}  * @author pblh123* @date 2024/3/29 15:18* @version 1.0**/object MovieRecommender {def main(args: Array[String]): Unit = {// 创建Spark会话val spark = SparkSession.builder().appName("MovieRecommender").master("local[*]").getOrCreate()import spark.implicits._// 假设我们有一个用户-物品评分数据集,格式为(userId, itemId, rating)/*** UserID,MovieID,Rating,Timestamp*  1,1193,5,978300760*  1,661,3,978302109*/// 指定CSV文件的路径,以及解析选项val csvFilePath = "data/ratings.csv"val csvOptions = Map("header" -> "true", // 是否有列名头"inferSchema" -> "true", // 是否自动推断数据类型"encoding" -> "UTF-8", // 如果有特定的编码格式,例如对于包含中文的CSV文件:)// 读取CSV文件并创建DataFrameval ratingsDF = spark.read.format("csv").options(csvOptions).load(csvFilePath)// 显示DataFrame的前几行以验证数据是否正确加载println("查看原始据数据样例:")ratingsDF.show(5)val ratings: DataFrame = ratingsDF.select("UserID", "MovieID", "Rating").withColumnRenamed("UserID", "userId").withColumnRenamed("MovieID", "itemId").withColumnRenamed("Rating", "rating")// 将数据集分割为训练集和测试集val Array(training, test) = ratings.randomSplit(Array(0.8, 0.2))println("查看训练集数据")training.show(5)println("查看测试集数据")test.show(5)// 设置ALS参数// 创建一个ALS实例并配置参数val als = new ALS().setMaxIter(10) // 设置最大迭代次数为5,10,本地测试时,设置过大,会报错.setRegParam(0.01) // 设置正则化参数为0.01.setUserCol("userId") // 设置用户列名为"userId".setItemCol("itemId") // 设置物品列名为"itemId".setRatingCol("rating") // 设置评分列名为"rating"/*** ALS(Alternating Least Squares)是一种基于矩阵分解的协同过滤算法,用于处理用户和物品之间的评分数据。各参数说明如下:*  setMaxIter: 设置最大迭代次数,决定模型训练的精细程度。迭代次数越多,模型通常越精确,但训练时间也可能更长。*  setRegParam: 设置正则化参数,用于控制模型的复杂度和过拟合程度。较小的正则化参数值可能导致模型过复杂,容易过拟合;较大的值则可能导致模型过于简单,欠拟合。*  setUserCol, setItemCol, setRatingCol: 分别设置用户ID列、物品ID列和评分列的名称。这些列名根据实际的数据结构来确定,用于告诉ALS算法在哪些列中查找用户、物品和评分信息。*/// 训练ALS模型println("开始训练模型")val model = als.fit(training)// 对测试集进行预测val predictions = model.transform(test)predictions.show()predictions.filter($"rating".isNotNull && $"prediction".isNotNull).count() // 确认有非空的评分和预测值// 评估模型val evaluator = new RegressionEvaluator().setMetricName("rmse").setLabelCol("rating").setPredictionCol("prediction")val rmse = evaluator.evaluate(predictions)println(s"Root-mean-square error = $rmse")// 为用户生成推荐// 该函数是基于一个模型(model)为所有用户推荐项目的函数。它将为每个用户推荐5个项目/*** +------+--------------------------------------------------------------------------------------------+*  |userId|recommendations[{itemid,pred_rating},{itemid,pred_rating},...]                                                                             |*  +------+--------------------------------------------------------------------------------------------+*  |12    |[{1864, 9.721167}, {2964, 8.815781}, {3867, 8.480173}, {1539, 7.8904114}, {563, 7.8829007}] |*  |22    |[{2964, 6.090676}, {3215, 5.6165895}, {1534, 5.4731245}, {718, 5.462125}, {2632, 5.4482727}]|*/val userRecs = model.recommendForAllUsers(5)userRecs.show(5,false)println("保存预测结果")
//    userRecs.write.mode("overwrite").parquet("models/recomALSmodel") // 保存为parquet格式,一般用于集群中
// userRecs是一个DataFrame,其中"recommendations"列是数组类型val explodedUserRecs = userRecs.withColumn("recommendations", explode($"recommendations")).select($"userId", $"recommendations.itemId".as("itemId"), $"recommendations.rating".as("PredRating"))explodedUserRecs.write.mode("overwrite").format("csv").save("predictRes/recomALS")  // PC 调试使用// 保存模型到指定路径val modelPath = "models/recomALSmodel"model.write.overwrite().save(modelPath)println(s"Model saved to $modelPath")// 停止Spark会话spark.stop()/*当程序试图停止Spark会话时,可能会触发清理临时文件的操作,从而导致出现NoSuchFileException异常。通常情况下,这不是代码逻辑的问题,而是Spark内部在清理资源时可能出现的问题。可以尝试重启Spark环境或者适当增大Spark的临时目录空间来避免此类问题。*/}}

运行代码,效果图如下

TodoList:目前RMSE计算出问题,原数据清洗没有做,模型参数还可以调整。后期调整更新后,再发一篇文章。

使用训练的模型预测新数据

scala开发应用模型demo代码

package base.charpter10import org.apache.spark.ml.recommendation.ALSModel
import org.apache.spark.sql.SparkSession/*** @projectName sparkGNU2023  * @package base.charpter10  * @className base.charpter10.RecommendationModelLoadDemo  * @description ${description}  * @author pblh123* @date 2024/3/29 15:36* @version 1.0**/object RecommendationModelLoadDemo {def main(args: Array[String]): Unit = {// 创建Spark会话val spark = SparkSession.builder().master("local[*]").appName("RecommendationModelUsageDemo").getOrCreate()import spark.implicits._// 加载之前保存的ALS模型val modelPath = "models/recomALSmodel"val loadedModel: ALSModel = ALSModel.load(modelPath)// 假设我们有一些新的用户-物品对,我们想要预测它们的评分val userItemPairs = Seq((1, 4), // 用户1对物品4的评分预测(2, 2) // 用户2对物品2的评分预测).toDF("userId", "itemId")// 使用模型进行评分预测val predictions = loadedModel.transform(userItemPairs)predictions.show()// 现在,假设我们想要为用户1生成前N个推荐物品val numRecommendations = 5 // 为用户推荐的物品数量val userRecs = loadedModel.recommendForAllUsers(numRecommendations)userRecs.show(5,false)// 停止Spark会话spark.stop()}}

运行效果如下

评估效果说明:目前的预测评分不合理,是因为模型没有经过精挑,优化,预测的记过会依据预测评分高低排序,选取得分高的前5个结果返回。后期模型调优后,结果就正常了。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/788004.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

4.2 JavaWeb Day05分层解耦

三层架构功能 controller层接收请求&#xff0c;响应数据&#xff0c;层内调用了service层的方法&#xff0c;service层仅负责业务逻辑处理&#xff0c;其中要获取数据&#xff0c;就要去调用dao层&#xff0c;由dao层进行数据访问操作去查询数据&#xff08;进行增删改查&…

new mars3d.layer.HeatLayer({实现动态修改热力图半径

1.使用热力图插件的时候&#xff0c;实现动态修改热力图效果半径 2.直接修改是不可以的&#xff0c;因为这个是热力图本身的参数。 因此我们需要拿到这个热力图对象之后&#xff0c;参考api文档&#xff0c;对整个 heatLayer.heatStyle进行传参修改。 heatStyle地址&#x…

HarmonyOS 应用开发之featureAbility接口切换particleAbility接口切换

featureAbility接口切换 FA模型接口Stage模型接口对应d.ts文件Stage模型对应接口getWant(callback: AsyncCallback<Want>): void; getWant(): Promise<Want>;ohos.app.ability.UIAbility.d.tslaunchWant: Want;startAbility(parameter: StartAbilityParameter, c…

【MySQL系列】使用 ALTER TABLE 语句修改表结构的方法

&#x1f49d;&#x1f49d;&#x1f49d;欢迎来到我的博客&#xff0c;很高兴能够在这里和您见面&#xff01;希望您在这里可以感受到一份轻松愉快的氛围&#xff0c;不仅可以获得有趣的内容和知识&#xff0c;也可以畅所欲言、分享您的想法和见解。 推荐:kwan 的首页,持续学…

智慧驿站式的“智慧公厕”,给城市新基建带来新变化

随着智慧城市建设的推进&#xff0c;智慧驿站作为一种多功能城市部件&#xff0c;正逐渐在城市中崭露头角。这些智慧驿站集合了智慧公厕的管理功能&#xff0c;为城市的新基建带来了全新的变革。本文以智慧驿站智慧公厕源头实力厂家广州中期科技有限公司&#xff0c;大量精品案…

FreeRTOS--3

1.总结任务调度算法之间的区别&#xff0c;重新实现一遍任务调度算法的代码。 一&#xff0c;抢占式调度&#xff1a;高优先级的任务可以打断低优先级任务的执行。 抢占式调度适用于任务优先级不同的任务。使用默认的任务去创建一个优先级比他高的任务&#xff0c;观察抢占式调…

HTML基础:脚本 script 标签

你好&#xff0c;我是云桃桃。 1枚程序媛&#xff0c;大专生&#xff0c;2年时间从1800到月入过万&#xff0c;工作5年买房。 分享成长心得。 255篇原创内容-公众号 后台回复“前端工具”可获取开发工具&#xff0c;持续更新中 后台回复“前端基础题”可得到前端基础100题汇…

CSS(二)---【常见属性、复合属性使用】

零.前言 本篇文章主要阐述CSS常见属性、复合属性&#xff0c;更多前置知识请见作者其它文章&#xff1a; CSS(一)---【CSS简介、导入方式、八种选择器、优先级】-CSDN博客 1.CSS属性 CSS的属性有上百个&#xff0c;但是我们并不需要全部学习&#xff0c;只要我们学习一部分…

MuJoCo 入门教程(一)

系列文章目录 前言 一、简介 MuJoCo 是多关节接触动力学&#xff08;Multi-Joint dynamics with Contact&#xff09;的缩写。它是一个通用物理引擎&#xff0c;旨在促进机器人、生物力学、图形和动画、机器学习以及其他需要快速、准确地仿真铰接结构与环境交互的领域的研究和开…

数据结构(初阶)第二节:顺序表

从本文正式进入对数据结构的讲解&#xff0c;开始前友友们要有C语言的基础&#xff0c;熟练掌握动态内存管理、结构体、指针等章节&#xff0c;方便后续的学习。 顺序表&#xff08;Sequence List&#xff09; 线性表的概念&#xff1a;线性表&#xff08;linear list&#xff…

手写简易操作系统(二十二)--时间管理

前情提要 上一节我们实现了硬盘的驱动&#xff0c;本来这一节打算实现文件系统的&#xff0c;但是文件系统中有个时间属性&#xff0c;所以这里我们先实现操作系统的时间管理。 一、除法 我们的系统是一个32位的系统&#xff0c;在编译一些除法的时候编译的时候没问题&#…

保持ssh断开后,程序不会停止执行

保持ssh断开后&#xff0c;程序不会停止执行 一、前言 笔者做远程部署搞了一阵子&#xff0c;快结项时发现一旦我关闭了ssh连接窗口&#xff0c;远程服务器会自动杀掉我在ssh连接状态下运行的程序。 这怎么行&#xff0c;岂不是想要它一直运行还得要一台电脑一直打开ssh连接咯…

曲线降采样之道格拉斯-普克算法Douglas–Peucker

曲线降采样之道格拉斯-普克算法Douglas–Peucker 该算法的目的是&#xff0c;给定一条由线段构成的曲线&#xff0c;找到一条点数较少的相似曲线&#xff0c;来近似描述原始的曲线&#xff0c;达到降低时间、空间复杂度和平滑曲线的目的。 附赠自动驾驶学习资料和量产经验&…

【C语言】“vid”Microsoft Visual Studio安装及应用(检验内存泄露)

文章目录 前言安装包获取配置VLD完成 前言 我们在写代码时往往容易存在内存泄漏的情况&#xff0c;所以存在这样一个名为VLD的工具用来检验内存泄漏&#xff0c;现在我来教大家安装一下 安装包获取 vld下载网址&#xff1a;https://github.com/KindDragon/vld/releases/tag/…

YOLOv8结合SCI低光照图像增强算法!让夜晚目标无处遁形!【含端到端推理脚本】

这里的"SCI"代表的并不是论文等级,而是论文采用的方法 — “自校准光照学习” ~ 左侧为SCI模型增强后图片的检测效果,右侧为原始v8n检测效果 这篇文章的主要内容是通过使用SCI模型和YOLOv8进行算法联调,最终实现了如上所示的效果:在增强图像可见度的同时,对图像…

【中文视觉语言模型+本地部署 】23.08 阿里Qwen-VL:能对图片理解、定位物体、读取文字的视觉语言模型 (推理最低12G显存+)

项目主页&#xff1a;https://github.com/QwenLM/Qwen-VL 通义前问网页在线使用——&#xff08;文本问答&#xff0c;图片理解&#xff0c;文档解析&#xff09;&#xff1a;https://tongyi.aliyun.com/qianwen/ 论文v3. : 一个全能的视觉语言模型 23.10 Qwen-VL: A Versatile…

读取信息boot.bin和xclbin命令

bootgen读Boot.bin命令 johnjohn-virtual-machine:~/project_zynq/kv260_image_ubuntu22.04$ bootgen -read BOOT-k26-starter-kit-202305_2022.2.bin xclbinutil读xclbin命令 johnjohn-virtual-machine:~/project_zynq/kv260_image_ubuntu22.04$ xclbinutil -i kv260-smartca…

Linux权限提升总结

几个信息收集的项目推荐 运行这几个项目就会在目标主机上收集一些敏感信息供我们参考和使用 一个综合探针&#xff1a;traitor 一个自动化提权&#xff1a;BeRoot(gtfo3bins&lolbas) 使用python2运行beroot.py就可以运行程序&#xff0c;然后就可以收集到系统中的大量信…

mysql锁表问题

问题描述 偶尔应用日志会打印锁表超时回滚 org.springframework.dao.CannotAcquireLockException: ### Error updating database. Cause: com.mysql.cj.jdbc.exceptions.MySQLTransactionRollbackException: Lock wait timeout exceeded; try restarting transactionmysql锁…