UDF 函数
UDF 是我们用户可以自定义的函数,我们通过SparkSession对象来调用 udf 的 register(name:String,func(A1,A2,A3...)) 方法来注册一个我们自定义的函数。其中,name 是我们自定义的函数名称,func 是我们自定义的函数,它可以有很多个参数。
通过 UDF 函数,我们可以针对某一列数据或者某单元格数据进行针对的处理。
案例 1
定义一个函数,给 Andy 的 name 字段的值前 + "Name: "。
def main(args: Array[String]): Unit = {val conf = new SparkConf()conf.setAppName("spark sql udf").setMaster("local[*]")val spark = SparkSession.builder().config(conf).getOrCreate()import spark.implicits._val df = spark.read.json("data/sql/people.json")df.createOrReplaceTempView("people")spark.udf.register("prefixName",(name:String)=>{if (name.equals("Andy"))"Name: " + nameelsename})spark.sql("select prefixName(name) as name,age,sex from people").show()spark.stop()}
这里我们定义了一个自定义的 UDF 函数:prefixName,它会判断name字段的值是否为 "Andy",如果是,就会在她的值前+"Name: "。
运行结果:
+----------+---+---+
| name|age|sex|
+----------+---+---+
| Michael| 30| 男|
|Name: Andy| 19| 女|
| Justin| 19| 男|
|Bernadette| 20| 女|
| Gretchen| 23| 女|
| David| 27| 男|
| Joseph| 33| 女|
| Trish| 27| 女|
| Alex| 33| 女|
| Ben| 25| 男|
+----------+---+---+
UDAF 函数
强类型的DataSet和弱类型的DataFrame都提供了相关聚合函数,如count、countDistinct、avg、max、min。
UDAF 也就是我们用户的自定义聚合函数。聚合函数就比如 avg、sum这种函数,需要先把所有数据放到一起(缓冲区),再进行统一处理的一个函数。
实现 UDAF 函数需要有我们自定义的聚合函数的类(主要任务就是计算),我们可以继承 UserDefinedAggregateFunction,并实现里面的八种方法,来实现弱类型的聚合函数。(Spark3.0之后就不推荐使用了,更加推荐强类型的聚合函数)
我们可以继承Aggregator来实现强类型的聚合函数。
案例1 - 平均年龄
case 类可以直接构建对象,不需要new,因为样例类可以自动生成它的伴生对象和apply方法。
弱类型实现
import org.apache.spark.SparkConf
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, IntegerType, LongType, StructField, StructType}/*** 弱类型*/
object UDAFTest01 {def main(args: Array[String]): Unit = {val conf = new SparkConf()conf.setAppName("spark sql udaf").setMaster("local[*]")val spark = SparkSession.builder().config(conf).getOrCreate()import spark.implicits._val df = spark.read.json("data/sql/people.json")df.createOrReplaceTempView("people")spark.udf.register("avgAge",new MyAvgUDAF())spark.sql("select avgAge(age) from people").show()spark.stop()}
}
class MyAvgUDAF extends UserDefinedAggregateFunction{// 输入数据的结构 INoverride def inputSchema: StructType = {StructType(Array(StructField("age",LongType)))}// 缓冲区数据的结构 BUFFERoverride def bufferSchema: StructType = {StructType(Array(StructField("total",LongType),StructField("count",LongType)))}// 函数计算结果的数据类型 OUToverride def dataType: DataType = LongType// 函数的稳定性 (传入相同的参数结果是否相同)override def deterministic: Boolean = true// 缓冲区初始化override def initialize(buffer: MutableAggregationBuffer): Unit = {//这两种写法都一样
// buffer(0) = 0L
// buffer(1) = 0L//第二种方法buffer.update(0,0L) //total 给缓冲区的第0个数据结构-total-初始化赋值0Lbuffer.update(1,0L) //count 给缓冲区的第1个数据结构-count-初始化赋值0L}// 数据过来之后 如何更新缓冲区override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {// 第一个参数代表缓冲区的第i个数据结构 0代表total 1代表count// 第二个参数是对第一个参数的数据结构进行重新赋值// buffer.getLong(0)是取出缓冲区第0个值-也就是total的值,给它+上输入的值中的第0个值(因为我们输入结构只有一个就是age:Long)buffer.update(0,buffer.getLong(0)+input.getLong(0))buffer.update(1,buffer.getLong(1)+1) //count 每次数据过来+1}// 多个缓冲区数据合并override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {buffer1.update(0,buffer1.getLong(0)+buffer2.getLong(0))buffer1.update(1,buffer1.getLong(1)+buffer2.getLong(1))}// 计算结果操作override def evaluate(buffer: Row): Any = {buffer.getLong(0)/buffer.getLong(1)}
}
运行结果:
+-----------+
|avgage(age)|
+-----------+
| 25|
+-----------+
强类型实现
import org.apache.spark.SparkConf
import org.apache.spark.sql.{Encoder, Encoders, Row, SparkSession, functions}
import org.apache.spark.sql.expressions.Aggregator/*** 强类型*/
object UDAFTest02 {def main(args: Array[String]): Unit = {val conf = new SparkConf()conf.setAppName("spark sql udaf").setMaster("local[*]")val spark = SparkSession.builder().config(conf).getOrCreate()import spark.implicits._val df = spark.read.json("data/sql/people.json")df.createOrReplaceTempView("people")spark.udf.register("avgAge",functions.udaf(new MyAvg_UDAF()))spark.sql("select avgAge(age) from people").show()spark.stop()}
}/*** 自定义聚合函数类:* 1.继承org.apache.spark.sql.expressions.Aggregator,定义泛型:* IN : 输入数据类型 Long* BUF : 缓冲区数据类型* OUT : 输出数据类型 Long* 2.重写方法*/
//样例类中的参数默认是 val 所以这里必须指定为var
case class Buff(var total: Long,var count: Long)
class MyAvg_UDAF extends Aggregator[Long,Buff,Long]{// zero: Buff zero代表这个方法是用来初始值(0值)// Buff是我们的case类 也就是说明这里是用来给 缓冲区进行初始化override def zero: Buff = {Buff(0L,0L)}// 根据输入数据更新缓冲区 要求返回-Buffoverride def reduce(buff: Buff, in: Long): Buff = {buff.total += inbuff.count += 1buff}// 合并缓冲区 同样返回buff1override def merge(buff1: Buff, buff2: Buff): Buff = {buff1.total += buff2.totalbuff1.count += buff2.countbuff1}// 计算结果override def finish(buff: Buff): Long = {buff.total/buff.count}// 网络传输需要序列化 缓冲区的编码操作 -编码override def bufferEncoder: Encoder[Buff] = Encoders.product// 输出的编码操作 -解码override def outputEncoder: Encoder[Long] = Encoders.scalaLong
}
运行结果:
+-----------+
|avgage(age)|
+-----------+
| 25|
+-----------+
早期UDAF强类型聚合函数
SQL:结构化数据查询 & DSL:面向对象查询(有对象有方法,与类型相关,所以通过DSL语句结合起来使用)
早期的UDAF强类型聚合函数使用DSL操作。
定义一个case类对应数据类型,然后通过as[对象]方法将DataFrame转为DataSet类型,然后将我们的UDAF聚合类转为列对象。
import org.apache.spark.SparkConf
import org.apache.spark.sql.{Dataset, Encoder, Encoders, Row, SparkSession, TypedColumn, functions}
import org.apache.spark.sql.expressions.{Aggregator, MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, IntegerType, LongType, StructField, StructType}/*** 早期的UDAF强类型聚合函数使用DSL操作*/
object UDAFTest03 {def main(args: Array[String]): Unit = {val conf = new SparkConf()conf.setAppName("spark sql udaf").setMaster("local[*]")val spark = SparkSession.builder().config(conf).getOrCreate()import spark.implicits._val df = spark.read.json("data/sql/people.json")val ds: Dataset[User] = df.as[User]// 将UDAF强类型聚合函数转为查询的类对象val udafCol: TypedColumn[User, Long] = new OldAvg_UDAF().toColumnds.select(udafCol).show()spark.stop()}
}/*** 自定义聚合函数类:* 1.继承org.apache.spark.sql.expressions.Aggregator,定义泛型:* IN : 输入数据类型 User* BUF : 缓冲区数据类型* OUT : 输出数据类型 Long* 2.重写方法*/
//样例类中的参数默认是 val 所以这里必须指定为var
case class User(name: String,age: Long,sex: String)
case class Buff(var total: Long,var count: Long)
class OldAvg_UDAF extends Aggregator[User,Buff,Long]{// zero: Buff zero代表这个方法是用来初始值(0值)// Buff是我们的case类 也就是说明这里是用来给 缓冲区进行初始化override def zero: Buff = {Buff(0L,0L)}// 根据输入数据更新缓冲区 要求返回-Buffoverride def reduce(buff: Buff, in: User): Buff = {buff.total += in.agebuff.count += 1buff}// 合并缓冲区 同样返回buff1override def merge(buff1: Buff, buff2: Buff): Buff = {buff1.total += buff2.totalbuff1.count += buff2.countbuff1}// 计算结果override def finish(buff: Buff): Long = {buff.total/buff.count}// 网络传输需要序列化 缓冲区的编码操作 -编码override def bufferEncoder: Encoder[Buff] = Encoders.product// 输出的编码操作 -解码override def outputEncoder: Encoder[Long] = Encoders.scalaLong
}
运行结果:
+------------------------------------------+
|OldAvg_UDAF(com.study.spark.core.sql.User)|
+------------------------------------------+
| 25|
+------------------------------------------+