- 用户自定义的
UDF
- 定义:
UDF(User-Defined-Function)
,也就是最基本的函数,它提供了SQL
中对字段转换的功能,不涉及聚合操作。例如将日期类型转换成字符串类型,格式化字段。 - 用法
object UDFTest {
case class Person(name: String, age: Int)
def main(args: Array[String]): Unit = {
//常见SparkSession
val sparkSession: SparkSession = SparkSession.builder().appName("DataFrameTest").master("local[2]").getOrCreate()
//根据文件获取RDD
val personRDD: RDD[String] = sparkSession.sparkContext.textFile("C:\\Users\\39402\\Desktop\\person.txt")
/**
* 注册一个udf函数,
* toString:为自定义函数的引用名,
* (str: String) => str + "我是UDF自定义函数":这个是自定义的函数体,它是一个匿名函数
*/
sparkSession.udf.register("toString", (str: String) => str + "我是UDF自定义函数")
import sparkSession.implicits._
//引入隐式转换
//利用反射将RDD转换成DataFrame
val personDF: DataFrame = personRDD.map(_.split(",")).map(line => Person(line(0), line(1).toInt)).toDF()
//将DataFrame注册成一张表
personDF.createOrReplaceTempView("person")
//利用Spark的SQL来查询数据,其中toString就是我们自定义的UDF函数
sparkSession.sql("select toString(name),age from person").show()
}
}
- 用户自定义的
UDAF
- 定义:
UDAF
函数是用户自定义的聚合函数,为Spark SQL
提供对数据集的聚合功能,类似于max()、min()、count()
等功能,只不过自定义的功能是根据具体的业务功能来确定的。因为DataFrame是弱类型的,DataSet是强类型,所以自定义的UDAF
也提供了两种实现,一个是弱类型的一个是强类型的。 - 弱类型用法,需要继承
UserDefindAggregateFunction
,实现它的方法
package com.lyz.sql.udf
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
object MyCustomUDAF extends UserDefinedAggregateFunction {
//:: Nil 作用就是为StructField常见Array集合,并放入进去
def inputSchema: StructType = StructType(StructField("age", IntegerType) :: Nil)
//缓存字段类型,也就是每个分区的共享变量
def bufferSchema: StructType = StructType(StructField("sum", IntegerType) :: StructField("count", IntegerType) :: Nil)
//UDF输出数据类型
def dataType: DataType = IntegerType
//输入类型和输出类型是否一致
def deterministic: Boolean = true
//初始化分区中的共享变量
def initialize(buffer: MutableAggregationBuffer): Unit = {
//初始化每个分区上的年龄总和为0
buffer(0) = 0
//初始化每个分区上的人数为0
buffer(1) = 0
}
//每个分区中每一条记录,聚合的时候需要调用该方法
def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
//将新输入进来的数据一之前合并的结果做聚合操作,
//buffer(0)就是上边定义的年龄总和sum,也就是每个分区上的年龄总和
buffer(0) = buffer.getInt(0) + input.getInt(0)
//buffer(1)就是上边定义的人的个数count,也就是每个分区上的人个数
buffer(1) = buffer.getInt(1) + 1
}
//对分区结果进行合并
def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
// buffer1(0)就是所有分区的年龄总和
//buffer1.getInt(0) + buffer2.getInt(0):就是将没分区上的年龄相加
//下标为0的就是年龄总和
buffer1(0) = buffer1.getInt(0) + buffer2.getInt(0)
//buffer(1)就是所有分区的人个数
//buffer1.getInt(1) + buffer2.getInt(1):就是将每个分区人个数聚合在一起,
//下标为1就是人的个数
buffer1(1) = buffer1.getInt(1) + buffer2.getInt(1)
}
//最终结算结果
def evaluate(buffer: Row): Any = {
buffer.getInt(0) / buffer.getInt(1)
}
}
package com.lyz.sql.udf
import com.lyz.sql.dataframe.DataFrameTest.Person
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, SparkSession}
object MyCustomUDAFMain {
def main(args: Array[String]): Unit = {
val sparkSession: SparkSession = SparkSession.builder().appName("DataFrameTest").master("local[2]").getOrCreate()
//根据文件获取RDD
val personRDD: RDD[String] = sparkSession.sparkContext.textFile("C:\\Users\\39402\\Desktop\\person.txt")
import sparkSession.implicits._
//引入隐式转换
//利用反射将RDD转换成DataFrame
val personDF: DataFrame = personRDD.map(_.split(",")).map(line => Person(line(0), line(1).toInt)).toDF()
sparkSession.udf.register("myCustomUDAF", MyCustomUDAF)
personDF.createOrReplaceTempView("person")
/**
* 输出结果为:15
*/
sparkSession.sql("select myCustomUDAF(age) from person").show()
}
}
- 强类型用法,需要继承
Aggregate
,实现它的方法。既然是强类型,那么其中肯定涉及到对象的存在
package com.lyz.sql.udf
import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator
//输入
case class Person(name: String, age: Int)
//缓存变量,也就是逻辑介质,
case class Avg(sum: Int, count: Int)
object MyCutomUDAFStrong extends Aggregator[Person, Avg, Int] {
//初始化缓存变量
def zero: Avg = Avg(0, 0)
/**
* 每个分区计算各自的结果
*
* @param b :聚合后的缓存变量
* @param a :新输入的数据
* @return b:聚合后的缓存变量
*/
def reduce(b: Avg, a: Person): Avg = {
b.sum += a.age
b.count += 1
b
}
//合并每个分区的结果
def merge(b1: Avg, b2: Avg): Avg = {
b1.sum += b2.sum
b1.count += b2.count
b1
}
//最后完成平均值的计算
def finish(reduction: Avg): Int = {
reduction.sum / reduction.count
}
//Encoders.product:是对scala元组和case类型转换的编码器
def bufferEncoder: Encoder[Avg] = Encoders.product
//设定输出值的编码器
def outputEncoder: Encoder[Int] = Encoders.scalaInt
}
package com.lyz.sql.udf
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Dataset, SparkSession, TypedColumn}
object MyCustomStrongMain {
def main(args: Array[String]): Unit = {
val sparkSession: SparkSession = SparkSession.builder().appName("DataFrameTest").master("local[2]").getOrCreate()
//根据文件获取RDD
val personRDD: RDD[String] = sparkSession.sparkContext.textFile("C:\\Users\\39402\\Desktop\\person.txt")
import sparkSession.implicits._ //引入隐式转换
//里用RDD生成Dataset
val personDS: Dataset[Person] = personRDD.map(_.split(",")).map(line => Person(line(0), line(1).toInt)).toDS()
//将这个函数转成TypedColumn,并且提供一个别名
val avgAge: TypedColumn[Person, Int] = MyCustomUDAFStrong.toColumn.name("ageAvg")
personDS.select(avgAge).show()
}
}