AVG computes an average, so the output type is Double.
1) Create a weakly typed aggregate function class that extends UserDefinedAggregateFunction
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DoubleType, LongType, StructType}

class MyAgeFunction extends UserDefinedAggregateFunction {
  // Input schema: a single Long column "age"
  override def inputSchema: StructType = new StructType().add("age", LongType)

  // Buffer schema: a running sum and a count
  override def bufferSchema: StructType = new StructType().add("sum", LongType).add("count", LongType)

  // Output type: the average is a Double
  override def dataType: DataType = DoubleType

  // The same input always produces the same output
  override def deterministic: Boolean = true

  // Initialize sum and count to 0
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L
    buffer(1) = 0L
  }

  // Fold one input row into the buffer
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = buffer.getLong(0) + input.getLong(0)
    buffer(1) = buffer.getLong(1) + 1
  }

  // Merge partial buffers from different partitions
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }

  // Final result: sum / count, converted to Double
  override def evaluate(buffer: Row): Any = buffer.getLong(0).toDouble / buffer.getLong(1)
}
Using the aggregate function:
import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession

def main(args: Array[String]): Unit = {
  val conf = new SparkConf().setAppName("Spark01_Custom").setMaster("local[*]")
  val spark = SparkSession.builder().config(conf).getOrCreate()

  val rdd1 = spark.sparkContext.makeRDD(List(("chun", 21), ("chun1", 23), ("chun3", 22)))
  import spark.implicits._
  val frame = rdd1.toDF("name", "age")
  frame.createGlobalTempView("people")

  // Register the UDAF so it can be called by name in SQL
  val udaf = new MyAgeFunction
  spark.udf.register("avgAge", udaf)
  spark.sql("select avgAge(age) from global_temp.people").show()
}
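For the sample data, the expected result is (21 + 23 + 22) / 3 = 22.0. The output should look roughly like the sketch below; the exact column header depends on the Spark version:

+-----------+
|avgage(age)|
+-----------+
|       22.0|
+-----------+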
2) Create a strongly typed aggregate function AVG (extends Aggregator[input type, buffer type, output type])
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{Encoder, Encoders}

class MyAgeClassFunction extends Aggregator[UserBean, AvgBuffer, Double] {
  // Initial (zero) value of the buffer
  override def zero: AvgBuffer = AvgBuffer(0, 0)

  // Fold one input value into the buffer
  override def reduce(b: AvgBuffer, a: UserBean): AvgBuffer = {
    b.sum += a.age
    b.count += 1
    b
  }

  // Merge partial buffers from different partitions
  override def merge(b1: AvgBuffer, b2: AvgBuffer): AvgBuffer = {
    b1.sum = b1.sum + b2.sum
    b1.count = b1.count + b2.count
    b1
  }

  // Final result; toDouble avoids integer division
  override def finish(reduction: AvgBuffer): Double = reduction.sum.toDouble / reduction.count

  // Encoders for the buffer (a case class) and the Double output
  override def bufferEncoder: Encoder[AvgBuffer] = Encoders.product
  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}
case class UserBean(name: String, age: Int)
case class AvgBuffer(var sum: Int, var count: Int)
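Because reduce and merge update the buffer in place, AvgBuffer declares its fields as var. A minimal sketch of an equivalent immutable variant (a hypothetical alternative, not from the original) would build fresh buffers instead:

// Hypothetical immutable style: return a new buffer rather than mutating;
// AvgBuffer could then drop the var modifiers.
override def reduce(b: AvgBuffer, a: UserBean): AvgBuffer =
  AvgBuffer(b.sum + a.age, b.count + 1)

override def merge(b1: AvgBuffer, b2: AvgBuffer): AvgBuffer =
  AvgBuffer(b1.sum + b2.sum, b1.count + b2.count)

Either style works, since Aggregator only requires that reduce and merge return the updated buffer; whether the object is reused is an implementation choice.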
Usage:
import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession

def main(args: Array[String]): Unit = {
  val conf = new SparkConf().setAppName("Spark02_Custom2").setMaster("local[*]")
  val spark = SparkSession.builder().config(conf).getOrCreate()

  val rdd = spark.sparkContext.makeRDD(List(("chun1", 23), ("chun2", 24), ("chun3", 25)))
  import spark.implicits._

  // A typed Aggregator is applied to a Dataset as a TypedColumn
  val udaf = new MyAgeClassFunction
  val avgColumn = udaf.toColumn.name("avgAge")

  val userRDD = rdd.map { case (name, age) => UserBean(name, age) }
  val ds = userRDD.toDS()

  // Apply the aggregator to the Dataset
  ds.select(avgColumn).show()

  val rdd1 = ds.rdd
  ds.show()
  rdd1.foreach(println)

  spark.stop()
}
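With the data above, ds.select(avgColumn).show() should report (23 + 24 + 25) / 3 = 24.0 in a single column named avgAge (a sketch; the exact rendering may vary):

+------+
|avgAge|
+------+
|  24.0|
+------+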
Note that with the strongly typed approach, each row of the Dataset is a UserBean, i.e. a case class instance, rather than a generic Row as in the weakly typed version.
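A minimal sketch of that contrast, assuming the frame DataFrame and the ds Dataset from the two examples above were in scope together:

// DataFrame -> RDD[Row]: fields are extracted positionally and untyped
val rowRDD = frame.rdd
rowRDD.foreach(row => println(row.getString(0) + ", " + row.getInt(1)))

// Dataset[UserBean] -> RDD[UserBean]: fields are accessed by name,
// with types checked at compile time
val beanRDD = ds.rdd
beanRDD.foreach(u => println(u.name + ", " + u.age))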