一:程序
1.需求
实现一个求平均值的UDAF。
这里保留Double格式化,在完成求平均值后与系统的AVG进行对比,观察正确性。
2.SparkSQLUDFDemo程序
1 package com.scala.it 2 3 import org.apache.spark.sql.hive.HiveContext 4 import org.apache.spark.{SparkConf, SparkContext} 5 6 import scala.math.BigDecimal.RoundingMode 7 8 object SparkSQLUDFDemo { 9 def main(args: Array[String]): Unit = { 10 val conf = new SparkConf() 11 .setMaster("local[*]") 12 .setAppName("udf") 13 val sc = SparkContext.getOrCreate(conf) 14 val sqlContext = new HiveContext(sc) 15 16 // ================================== 17 // 写一个Double数据格式化的自定义函数(给定保留多少位小数部分) 18 sqlContext.udf.register( 19 "doubleValueFormat", // 自定义函数名称 20 (value: Double, scale: Int) => { 21 // 自定义函数处理的代码块 22 BigDecimal.valueOf(value).setScale(scale, RoundingMode.HALF_DOWN).doubleValue() 23 }) 24 25 // 自定义UDAF 26 sqlContext.udf.register("selfAvg", AvgUDAF) 27 28 sqlContext.sql( 29 """ 30 |SELECT 31 | deptno, 32 | doubleValueFormat(AVG(sal), 2) AS avg_sal, 33 | doubleValueFormat(selfAvg(sal), 2) AS self_avg_sal 34 |FROM hadoop09.emp 35 |GROUP BY deptno 36 """.stripMargin).show() 37 38 } 39 }
3.AvgUDAF程序
1 package com.scala.it 2 3 import org.apache.spark.sql.Row 4 import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction} 5 import org.apache.spark.sql.types._ 6 7 8 object AvgUDAF extends UserDefinedAggregateFunction{ 9 override def inputSchema: StructType = { 10 // 给定UDAF的输出参数类型 11 StructType( 12 StructField("sal", DoubleType) :: Nil 13 ) 14 } 15 16 override def bufferSchema: StructType = { 17 // 在计算过程中会涉及到的缓存数据类型 18 StructType( 19 StructField("total_sal", DoubleType) :: 20 StructField("count_sal", LongType) :: Nil 21 ) 22 } 23 24 override def dataType: DataType = { 25 // 给定该UDAF返回的数据类型 26 DoubleType 27 } 28 29 override def deterministic: Boolean = { 30 // 主要用于是否支持近似查找,如果为false:表示支持多次查询允许结果不一样,为true表示结果必须一样 31 true 32 } 33 34 override def initialize(buffer: MutableAggregationBuffer): Unit = { 35 // 初始化 ===> 初始化缓存数据 36 buffer.update(0, 0.0) // 初始化total_sal 37 buffer.update(1, 0L) // 初始化count_sal 38 } 39 40 override def update(buffer: MutableAggregationBuffer, input: Row): Unit = { 41 // 根据输入的数据input,更新缓存buffer的内容 42 // 获取输入的sal数据 43 val inputSal = input.getDouble(0) 44 45 // 获取缓存中的数据 46 val totalSal = buffer.getDouble(0) 47 val countSal = buffer.getLong(1) 48 49 // 更新缓存数据 50 buffer.update(0, totalSal + inputSal) 51 buffer.update(1, countSal + 1L) 52 } 53 54 override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = { 55 // 当两个分区的数据需要进行合并的时候,该方法会被调用 56 // 功能:将buffer2中的数据合并到buffer1中 57 // 获取缓存区数据 58 val buf1Total = buffer1.getDouble(0) 59 val buf1Count = buffer1.getLong(1) 60 61 val buf2Total = buffer2.getDouble(0) 62 val buf2Count = buffer2.getLong(1) 63 64 // 更新缓存区 65 buffer1.update(0, buf1Total + buf2Total) 66 buffer1.update(1, buf1Count + buf2Count) 67 } 68 69 override def evaluate(buffer: Row): Any = { 70 // 求返回值 71 buffer.getDouble(0) / buffer.getLong(1) 72 } 73 }
4.效果
二:知识点
1.udf注册
2.解释上面的update
重要的是两个参数的意思,不然程序有些看不懂。
所以,程序的意思是,第一位存储总数,第二位存储个数。
3.还要解释一个StructType的生成
在以前的程序中,是使用Array来生成的。如:
在上面的程序中,不是这种方式,使用集合的方式。