UDAF的使用(弱类型 基于DataFrame)
用户自定义UDAF聚合函数需要实现以下两个步骤:
1、弱类型聚合函数
继承UserDefinedAggregateFunction
2、注册为函数:ss.udf.register(“avgCus”, new CusAvgFun)
package SparkSQL
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
/**
* 自定义UDAF函数
* 聚合函数 对多行数据生效 进来多行 输出一行
*/
object TestUDAFFunction {
def main(args: Array[String]): Unit = {
val ss = SparkSession.builder().master("local").appName("UDAF Function").getOrCreate()
val sc = ss.sparkContext
import ss.implicits._
var data: DataFrame = sc.parallelize(Array(("zs", 15), ("ls", 20), ("ww", 18), ("ml", 25), ("zq", 30))).toDF("name", "ageAndheigth")
//注册聚合函数
ss.udf.register("avgCus", new CusAvgFun)
data.createGlobalTempView("student")
//聚合函数的使用
ss.sql("select avgCus(ageAndheigth) as valuecu from global_temp.student").show()
}
}
/**
* 用户自定义聚合函数
*/
class CusAvgFun extends UserDefinedAggregateFunction {
//输入数据的类型 类似于创建dataframe时候指定列的数据类型 在整个底层计算中应该是以row传递数据
override def inputSchema: StructType = StructType(StructField("data", LongType) :: Nil)
//缓冲区中值的数据类型 这里就是你在计算过程中所需要的数据的类型 如果求平均数在这里就是两个中间值 封装成Row传值
override def bufferSchema: StructType = StructType(StructField("sum", LongType) :: StructField("count", LongType) :: Nil)
//输出数据的类型
override def dataType: DataType = DoubleType
//如果此函数是确定性的,即给定相同的输入,返回true,始终返回相同的输出。
override def deterministic: Boolean = true
//初始化给定的聚合缓冲区,即聚合缓冲区的零值。约定应该是在两个初始缓冲区上应用合并函数只应返回初始缓冲区本身。
//这里说一下 这里是中间值的数据 因为刚刚这中间值一共设置了两个 所以这里要按照顺序更新两个值
//分别是sum值和buffer值
override def initialize(buffer: MutableAggregationBuffer): Unit = {
buffer.update(0, 0l)
buffer.update(1, 0l)
}
//更新值 buffer代表的是缓冲区的值 input代表的是新输入的数据的值
//其中input代表的是新输入的数据 封装成Row,其中Row中有一个值 这个值就是刚定义的输入数据的类型 long类型的data
//buffer中有两个值 分别是定义的缓冲区的值 一个sum 一个count
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
buffer.update(0, buffer.getLong(0) + input.getLong(0))
buffer.update(1, buffer.getLong(1) + 1l)
}
//合并两个聚合缓冲区并将更新的缓冲区值存储回“buffer1”。当我们将两个部分聚合的数据合并在一起时,会调用此方法。
//会将一个缓冲区的数据拉取到另一个缓冲区完成合并。所以在第二个参数buffer2就是另外一个缓冲区的数据只不过封装成了Row
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
buffer1.update(0, buffer1.getLong(0) + buffer2.getLong(0))
buffer1.update(1, buffer1.getLong(1) + buffer2.getLong(1))
}
//根据给定的聚合缓冲区计算此[[UserDefinedAggregateFunction]]的最终结果。
//这里的是指最终缓冲区的数据 有两个值 分区是刚定义的缓冲区的数据类型 sum和count 按顺序的
override def evaluate(buffer: Row): Any = {
(buffer.getLong(0) / buffer.getLong(1)).toDouble
}
}