一:程序

1.需求

  实现一个求平均值的UDAF。

  这里保留Double格式化,在完成求平均值后与系统的AVG进行对比,观察正确性。

 

2.SparkSQLUDFDemo程序

 1 package com.scala.it
 2 
 3 import org.apache.spark.sql.hive.HiveContext
 4 import org.apache.spark.{SparkConf, SparkContext}
 5 
 6 import scala.math.BigDecimal.RoundingMode
 7 
 8 object SparkSQLUDFDemo {
 9   def main(args: Array[String]): Unit = {
10     val conf = new SparkConf()
11       .setMaster("local[*]")
12       .setAppName("udf")
13     val sc = SparkContext.getOrCreate(conf)
14     val sqlContext = new HiveContext(sc)
15 
16     // ==================================
17     // 写一个Double数据格式化的自定义函数(给定保留多少位小数部分)
18     sqlContext.udf.register(
19       "doubleValueFormat", // 自定义函数名称
20       (value: Double, scale: Int) => {
21         // 自定义函数处理的代码块
22         BigDecimal.valueOf(value).setScale(scale, RoundingMode.HALF_DOWN).doubleValue()
23       })
24 
25     // 自定义UDAF
26     sqlContext.udf.register("selfAvg", AvgUDAF)
27 
28     sqlContext.sql(
29       """
30         |SELECT
31         |  deptno,
32         |  doubleValueFormat(AVG(sal), 2) AS avg_sal,
33         |  doubleValueFormat(selfAvg(sal), 2) AS self_avg_sal
34         |FROM hadoop09.emp
35         |GROUP BY deptno
36       """.stripMargin).show()
37 
38   }
39 }

 

3.AvgUDAF程序

 1 package com.scala.it
 2 
 3 import org.apache.spark.sql.Row
 4 import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
 5 import org.apache.spark.sql.types._
 6 
 7 
 8 object AvgUDAF extends UserDefinedAggregateFunction{
 9   override def inputSchema: StructType = {
10     // 给定UDAF的输出参数类型
11     StructType(
12       StructField("sal", DoubleType) :: Nil
13     )
14   }
15 
16   override def bufferSchema: StructType = {
17     // 在计算过程中会涉及到的缓存数据类型
18     StructType(
19       StructField("total_sal", DoubleType) ::
20         StructField("count_sal", LongType) :: Nil
21     )
22   }
23 
24   override def dataType: DataType = {
25     // 给定该UDAF返回的数据类型
26     DoubleType
27   }
28 
29   override def deterministic: Boolean = {
30     // 主要用于是否支持近似查找,如果为false:表示支持多次查询允许结果不一样,为true表示结果必须一样
31     true
32   }
33 
34   override def initialize(buffer: MutableAggregationBuffer): Unit = {
35     // 初始化 ===> 初始化缓存数据
36     buffer.update(0, 0.0) // 初始化total_sal
37     buffer.update(1, 0L) // 初始化count_sal
38   }
39 
40   override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
41     // 根据输入的数据input,更新缓存buffer的内容
42     // 获取输入的sal数据
43     val inputSal = input.getDouble(0)
44 
45     // 获取缓存中的数据
46     val totalSal = buffer.getDouble(0)
47     val countSal = buffer.getLong(1)
48 
49     // 更新缓存数据
50     buffer.update(0, totalSal + inputSal)
51     buffer.update(1, countSal + 1L)
52   }
53 
54   override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
55     // 当两个分区的数据需要进行合并的时候,该方法会被调用
56     // 功能:将buffer2中的数据合并到buffer1中
57     // 获取缓存区数据
58     val buf1Total = buffer1.getDouble(0)
59     val buf1Count = buffer1.getLong(1)
60 
61     val buf2Total = buffer2.getDouble(0)
62     val buf2Count = buffer2.getLong(1)
63 
64     // 更新缓存区
65     buffer1.update(0, buf1Total + buf2Total)
66     buffer1.update(1, buf1Count + buf2Count)
67   }
68 
69   override def evaluate(buffer: Row): Any = {
70     // 求返回值
71     buffer.getDouble(0) / buffer.getLong(1)
72   }
73 }

 

4.效果

  048 SparkSQL自定义UDAF函数_sql

 

二:知识点

1.udf注册

  048 SparkSQL自定义UDAF函数_spark_02

 

2.解释上面的update

  重要的是两个参数的意思,不然程序有些看不懂。

  所以,程序的意思是,第一位存储总数,第二位存储个数。

  048 SparkSQL自定义UDAF函数_数据_03

 

3.还要解释一个StructType的生成

  在以前的程序中,是使用Array来生成的。如:

    048 SparkSQL自定义UDAF函数_sql_04

  在上面的程序中,不是这种方式,使用集合的方式。

    048 SparkSQL自定义UDAF函数_spark_05