UDF
A UDF (user-defined function) takes one input value and returns one result: one in, one out.
spark.udf.register("toUppperCaseUdf",(cloumn:String) => cloumn.toUpperCase)
spark.sql("select toUppperCaseUdf(name) from t_user")
UDAF
Many in, one out, like the built-in aggregate function sum.
UDAF without generic type constraints
This style extends UserDefinedAggregateFunction (note: deprecated since Spark 3.0 in favor of the typed Aggregator covered in the next section).
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DoubleType, LongType, StructField, StructType}

object AverageUDAF extends UserDefinedAggregateFunction {

  /**
   * inputSchema declares the types of the arguments passed when the UDAF is called.
   * "numInput" is only a field name and can be chosen freely.
   * StructField("numInput", DoubleType, nullable = true) :: Nil is equivalent to
   * List(StructField("numInput", DoubleType, nullable = true))
   */
  override def inputSchema: StructType = {
    StructType(
      StructField("numInput", DoubleType, nullable = true) :: Nil
    )
  }

  /**
   * Schema of the aggregation buffer.
   * For an average, what keeps accumulating is the running sum of ages and the running count.
   */
  override def bufferSchema: StructType = {
    StructType(
      StructField("buff1", DoubleType, nullable = true) :: StructField("buff2", LongType, nullable = true) :: Nil
    )
  }

  /**
   * Result type returned by the UDAF.
   */
  override def dataType: DataType = DoubleType

  /**
   * Whether the function is deterministic: given the same input it always
   * produces the same output. That holds for an average.
   */
  override def deterministic: Boolean = true

  /**
   * Initialize the aggregation buffer.
   * Equivalent to buffer(0) = 0.0; buffer(1) = 0L
   */
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer.update(0, 0.0)
    buffer.update(1, 0L)
  }

  /**
   * The per-partition aggregation logic. Within one partition, rows are consumed one at
   * a time and the aggregated column of each row is accumulated into the buffer.
   * @param buffer the aggregation buffer: slot 0 holds the running age sum, slot 1 the running count
   * @param input  the current input row from the table
   */
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer.update(0, buffer.getDouble(0) + input.getDouble(0))
    buffer.update(1, buffer.getLong(1) + 1)
  }

  /**
   * Each partition has its own buffer; merge combines the partition buffers into one.
   */
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1.update(0, buffer1.getDouble(0) + buffer2.getDouble(0))
    buffer1.update(1, buffer1.getLong(1) + buffer2.getLong(1))
  }

  /**
   * Perform the final computation on the fully merged buffer.
   */
  override def evaluate(buffer: Row): Any = {
    buffer.getDouble(0) / buffer.getLong(1)
  }
}
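As a concrete trace: with ages 20.0 and 30.0 in one partition and 40.0 in another, update produces the partition buffers (50.0, 2) and (40.0, 1); merge combines them into (90.0, 3); evaluate then returns 90.0 / 3 = 30.0.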
import org.apache.spark.sql.SparkSession

object TestUDF {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("test")
      .getOrCreate()
    // Register the UDAF under a SQL-callable name (assumes a temp view t_user exists)
    spark.udf.register("AverageUDAF", AverageUDAF)
    spark.sql("select AverageUDAF(age) from t_user group by sex").show()
    spark.stop()
  }
}
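Once registered, the UDAF can also be invoked from the DataFrame API; a minimal sketch, assuming a DataFrame df over the same t_user data (df is not defined in the original):

import org.apache.spark.sql.functions.{callUDF, col}

// Apply the registered UDAF per sex group; df is a hypothetical DataFrame
df.groupBy("sex").agg(callUDF("AverageUDAF", col("age")).as("avgAge")).show()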
UDAF with generic type constraints
The principle is the same, but this style declares type parameters when the UDAF is defined, which makes the function type-safe. A UDAF written this way cannot be invoked directly in SQL; it is used through the typed Dataset/DataFrame API instead (since Spark 3.0 it can be bridged into SQL; see the sketch after the invocation example below).
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{Encoder, Encoders, Row}

/**
 * The type parameters IN, BUF, OUT are bound here to Row, the custom case class Buffer, and Double:
 * IN:  the type of each input record before aggregation
 * BUF: the type of the per-partition aggregation buffer, here the case class Buffer
 * OUT: the type of the final aggregated result
 */
object AverageFemaleUDAF extends Aggregator[Row, Buffer, Double] {

  /**
   * The initial (zero) value of the aggregation buffer.
   */
  override def zero: Buffer = Buffer(0.0, 0L)

  /**
   * Folds each row of the current partition into the buffer. In Buffer, the age field
   * accumulates the total age and count accumulates the number of matching rows.
   * Assumes column 1 of the row holds the age and column 2 holds the sex.
   * @param b the aggregation buffer
   * @param a the current input row
   */
  override def reduce(b: Buffer, a: Row): Buffer = {
    if (a.getString(2) == "Female") {
      b.age += a.getInt(1)
      b.count += 1
    }
    b
  }

  /**
   * Merges two aggregation buffers.
   */
  override def merge(b1: Buffer, b2: Buffer): Buffer = {
    b1.age += b2.age
    b1.count += b2.count
    b1
  }

  /**
   * Performs the final computation on the fully merged buffer, yielding the UDAF result.
   */
  override def finish(reduction: Buffer): Double = reduction.age / reduction.count

  /**
   * Encoder for the aggregation buffer type.
   */
  override def bufferEncoder: Encoder[Buffer] = Encoders.product

  /**
   * Encoder for the final result type.
   */
  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}

/**
 * A case class serving as the aggregation buffer.
 * @param age   running sum of ages
 * @param count running count of matching rows
 */
case class Buffer(var age: Double, var count: Long)
Invocation
val femaleAvg = AverageFemaleUDAF.toColumn.name("avg")
df.select(femaleAvg).show()
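Since Spark 3.0, a typed Aggregator can also be registered for SQL by wrapping it with functions.udaf. A minimal sketch with a hypothetical aggregator over a plain Double input, reusing the Buffer case class above (the Row-based AverageFemaleUDAF would additionally need an explicit input encoder):

import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.functions

// Hypothetical Aggregator[Double, Buffer, Double] that averages raw Double values
object AverageAgeUDAF extends Aggregator[Double, Buffer, Double] {
  override def zero: Buffer = Buffer(0.0, 0L)
  override def reduce(b: Buffer, a: Double): Buffer = { b.age += a; b.count += 1; b }
  override def merge(b1: Buffer, b2: Buffer): Buffer = { b1.age += b2.age; b1.count += b2.count; b1 }
  override def finish(reduction: Buffer): Double = reduction.age / reduction.count
  override def bufferEncoder: Encoder[Buffer] = Encoders.product
  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}

// Wrap the Aggregator as a UserDefinedFunction and register it for SQL (Spark 3.0+)
spark.udf.register("averageAge", functions.udaf(AverageAgeUDAF))
spark.sql("select sex, averageAge(age) from t_user group by sex").show()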
UDTF
One in, many out: a value in one column of a row is expanded into multiple rows, similar to flatMap.
import java.util

import org.apache.spark.sql.{Encoder, Encoders, Row, SparkSession}
import org.apache.spark.sql.types.{StringType, StructField, StructType}

import scala.collection.mutable.ListBuffer

object TestUDTF {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("test")
      .getOrCreate()
    val schema = StructType(List(
      StructField("movie", StringType, nullable = false),
      StructField("category", StringType, nullable = false)
    ))
    val rows = new util.ArrayList[Row]()
    rows.add(Row("<八佰>", "战争,历史"))
    rows.add(Row("<我是传奇>", "科幻,丧尸"))
    val df1 = spark.createDataFrame(rows, schema)
    df1.show()
    // Tuple encoder for the flatMap output; Encoders.product keeps the two columns
    // readable instead of serializing the whole tuple into a single binary column
    implicit val flatMapEncoder: Encoder[(String, String)] = Encoders.product[(String, String)]
    // Expand the comma-separated category column into one (movie, category) pair per category
    val tableArray = df1.flatMap(row => {
      val pairs = new ListBuffer[(String, String)]()
      val categoryArray = row.getString(1).split(",")
      for (c <- categoryArray) {
        pairs.append((row.getString(0), c))
      }
      pairs
    }).collect()
    // Collecting to the driver is only for demonstration; in practice call
    // .toDF("movie", "category") directly on the Dataset returned by flatMap
    val df2 = spark.createDataFrame(tableArray).toDF("movie", "category")
    df2.show()
    spark.stop()
  }
}
Result
+------+--------+
| movie|category|
+------+--------+
| <八佰>|    战争|
| <八佰>|    历史|
|<我是传奇>|    科幻|
|<我是传奇>|    丧尸|
+------+--------+
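For comparison, the same one-to-many expansion can be done without a custom flatMap by using the built-in explode and split functions, assuming df1 from above:

import org.apache.spark.sql.functions.{col, explode, split}

// Split the comma-separated category string into an array, then emit one row per element
df1.select(col("movie"), explode(split(col("category"), ",")).as("category")).show()

// Equivalent SQL, assuming df1 is registered under the hypothetical view name t_movie:
// df1.createOrReplaceTempView("t_movie")
// spark.sql("select movie, explode(split(category, ',')) as category from t_movie").show()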