1.UDF
import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession

object Spark03 {
  def main(args: Array[String]): Unit = {
    val sparkConf = new SparkConf().setAppName("Sql").setMaster("local")
    val spark = SparkSession.builder().config(sparkConf).getOrCreate()

    val df = spark.read.json("data/input.json")
    df.createOrReplaceTempView("user")

    // Register a UDF that prepends a prefix to the name column
    spark.udf.register("prefixName", (name: String) => {
      "Name" + name
    })
    // Apply the UDF inside a SQL query
    spark.sql("select name, prefixName(name) from user").show()

    spark.close()
  }
}
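For comparison, the same logic can be expressed through the DataFrame API instead of SQL. A minimal sketch (not part of the original example), assuming the same df with a name column:

import org.apache.spark.sql.functions.{col, udf}

// Wrap the Scala function as a column-level UDF
val prefixName = udf((name: String) => "Name" + name)
// Select the raw name alongside the prefixed name
df.select(col("name"), prefixName(col("name")).as("prefixedName")).show()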
2.UDAF
1.Weakly typed implementation
Based on the UserDefinedAggregateFunction abstract class. Note that this class is deprecated as of Spark 3.0; the strongly typed Aggregator shown in the next subsection is the recommended replacement.
import org.apache.spark.SparkConf
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, LongType, StructField, StructType}

object Spark04_UDAF {
  def main(args: Array[String]): Unit = {
    val sparkConf = new SparkConf().setMaster("local").setAppName("Sql")
    val spark = SparkSession.builder().config(sparkConf).getOrCreate()

    val df = spark.read.json("data/input.json")
    df.createOrReplaceTempView("user")

    // Register the weakly typed UDAF under the name avgAge
    spark.udf.register("avgAge", new MyAvgUDAF())
    spark.sql("select avgAge(age) from user").show()

    spark.close()
  }

  // Weakly typed UDAF implementation
  class MyAvgUDAF extends UserDefinedAggregateFunction {
    // Input schema: a single Long column (age)
    override def inputSchema: StructType = {
      StructType(
        Array(
          StructField("age", LongType)
        )
      )
    }
    // Buffer schema: running total and count
    override def bufferSchema: StructType = {
      StructType(
        Array(
          StructField("total", LongType),
          StructField("count", LongType)
        )
      )
    }
    // Output type
    override def dataType: DataType = LongType
    // Whether the function is deterministic (same input always yields the same output)
    override def deterministic: Boolean = true
    // Initialize the buffer: buffer(0) = total, buffer(1) = count
    override def initialize(buffer: MutableAggregationBuffer): Unit = {
      buffer.update(0, 0L)
      buffer.update(1, 0L)
    }
    // Fold one input row into the buffer
    override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
      buffer.update(0, buffer.getLong(0) + input.getLong(0))
      buffer.update(1, buffer.getLong(1) + 1)
    }
    // Merge buffers coming from different partitions
    override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
      buffer1.update(0, buffer1.getLong(0) + buffer2.getLong(0))
      buffer1.update(1, buffer1.getLong(1) + buffer2.getLong(1))
    }
    // Compute the final result: average age
    override def evaluate(buffer: Row): Any = {
      buffer.getLong(0) / buffer.getLong(1)
    }
  }
}
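The weakly typed UDAF can also be applied through the DataFrame API, since UserDefinedAggregateFunction exposes an apply method that takes columns and returns an aggregate column. A minimal sketch (not part of the original example), reusing MyAvgUDAF and df from above:

import org.apache.spark.sql.functions.col

val avgAge = new MyAvgUDAF()
// Applying the UDAF object to a column yields an aggregate expression
df.select(avgAge(col("age")).as("avgAge")).show()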
2.Strongly typed implementation
Based on the Aggregator abstract class.
import org.apache.spark.SparkConf
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{Encoder, Encoders, SparkSession, functions}

object Spark05_UDAF {
  def main(args: Array[String]): Unit = {
    val sparkConf = new SparkConf().setMaster("local").setAppName("Sql")
    val spark = SparkSession.builder().config(sparkConf).getOrCreate()

    val df = spark.read.json("data/input.json")
    df.createOrReplaceTempView("user")

    // Wrap the strongly typed Aggregator so it can be registered and used in SQL
    spark.udf.register("avgAge", functions.udaf(new MyAvgUDAF1))
    spark.sql("select avgAge(age) from user").show()

    spark.close()
  }

  // Aggregation buffer: running total and count
  case class Buff(var total: Long, var count: Long)

  // Strongly typed UDAF: input Long (age), buffer Buff, output Long (average)
  class MyAvgUDAF1 extends Aggregator[Long, Buff, Long] {
    // Initial (zero) value of the buffer
    override def zero: Buff = {
      Buff(0L, 0L)
    }
    // Fold one input value into the buffer
    override def reduce(b: Buff, a: Long): Buff = {
      b.total = b.total + a
      b.count = b.count + 1
      b
    }
    // Merge buffers coming from different partitions
    override def merge(b1: Buff, b2: Buff): Buff = {
      b1.total = b1.total + b2.total
      b1.count = b1.count + b2.count
      b1
    }
    // Compute the final result: average age
    override def finish(reduction: Buff): Long = {
      reduction.total / reduction.count
    }
    // Encoder for the buffer type
    override def bufferEncoder: Encoder[Buff] = Encoders.product
    // Encoder for the output type
    override def outputEncoder: Encoder[Long] = Encoders.scalaLong
  }
}
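An Aggregator can also be used directly on a typed Dataset without registering it in SQL, in which case the input type is the whole record rather than a single column. A minimal sketch, assuming the JSON maps to a hypothetical case class User(name: String, age: Long):

import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{Encoder, Encoders}

case class User(name: String, age: Long)
case class AvgBuff(var total: Long, var count: Long)

// Input is the whole User record; the age field is extracted inside reduce
object TypedAvgAge extends Aggregator[User, AvgBuff, Long] {
  def zero: AvgBuff = AvgBuff(0L, 0L)
  def reduce(b: AvgBuff, u: User): AvgBuff = { b.total += u.age; b.count += 1; b }
  def merge(b1: AvgBuff, b2: AvgBuff): AvgBuff = { b1.total += b2.total; b1.count += b2.count; b1 }
  def finish(r: AvgBuff): Long = r.total / r.count
  def bufferEncoder: Encoder[AvgBuff] = Encoders.product
  def outputEncoder: Encoder[Long] = Encoders.scalaLong
}

import spark.implicits._
val ds = df.as[User]
// toColumn turns the Aggregator into a TypedColumn usable in a typed select
ds.select(TypedAvgAge.toColumn.name("avgAge")).show()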
3.MySQL data source
import org.apache.spark.SparkConf
import org.apache.spark.sql.{SaveMode, SparkSession}

object Spark06 {
  def main(args: Array[String]): Unit = {
    // General read/write API:
    // spark.read.format("json").load(...)
    // df.write.format("json").save(...)

    // MySQL over JDBC
    val sparkConf = new SparkConf().setMaster("local").setAppName("Sql")
    val spark = SparkSession.builder().config(sparkConf).getOrCreate()
    import spark.implicits._

    // Read the user9 table from the test database
    val df = spark.read.format("jdbc")
      .option("url", "jdbc:mysql:///test")
      .option("driver", "com.mysql.jdbc.Driver")
      .option("user", "root")
      .option("password", "root")
      .option("dbtable", "user9")
      .load()
    df.show()

    // Append the data into the user8 table
    df.write.format("jdbc")
      .option("url", "jdbc:mysql:///test")
      .option("driver", "com.mysql.jdbc.Driver")
      .option("user", "root")
      .option("password", "root")
      .option("dbtable", "user8")
      .mode(SaveMode.Append)
      .save()

    spark.close()
  }
}
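For larger tables, the JDBC source can read in parallel when a numeric partition column is supplied. A minimal sketch (the partition column and bounds are hypothetical, not from the original example):

// Read user9 in 4 parallel partitions, split on a numeric column
val partitionedDF = spark.read.format("jdbc")
  .option("url", "jdbc:mysql:///test")
  .option("driver", "com.mysql.jdbc.Driver")
  .option("user", "root")
  .option("password", "root")
  .option("dbtable", "user9")
  .option("partitionColumn", "id") // hypothetical numeric column
  .option("lowerBound", "1")
  .option("upperBound", "10000")
  .option("numPartitions", "4")
  .load()
partitionedDF.show()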
4.Hive
Maven dependencies:
<dependency>
    <groupId>mysql</groupId>
    <artifactId>mysql-connector-java</artifactId>
    <version>5.1.27</version>
</dependency>
<dependency>
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-hive_2.12</artifactId>
    <version>3.0.0</version>
</dependency>
import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession

object Spark07_Hive {
  def main(args: Array[String]): Unit = {
    val sparkConf = new SparkConf().setMaster("local").setAppName("Sql")
    // enableHiveSupport requires hive-site.xml on the classpath to connect to an existing Hive metastore
    val spark = SparkSession.builder().enableHiveSupport().config(sparkConf).getOrCreate()
    import spark.implicits._

    spark.sql("show tables").show()

    spark.close()
  }
}
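Once Hive support is enabled, Hive tables can be created, written, and queried with plain SQL through the same session. A minimal sketch (the table name and values are hypothetical):

spark.sql("create table if not exists user_hive(name string, age bigint)")
spark.sql("insert into user_hive values ('zhangsan', 20)")
spark.sql("select avg(age) from user_hive").show()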