目录
2.4 Action操作
2.5 Transformation 操作
与RDD类似的操作
存储相关
select相关
where相关
groupBy相关
orderBy相关
join相关
集合相关
空值处理
窗口函数
内建函数
2.6 SQL语句
2.7 输入与输出
Parquet文件:
json文件:
CSV文件:
2.8 UDF & UDAF
1、UDF
2、UDAF
2.9 访问Hive
第3节 Spark SQL原理
2.1 SparkSQL中的join
1、Broadcast Hash Join
2、Shuffle Hash Join
3、Shuffle Sort Merge Join
4、Cartesian product join(了解)
5、Broadcast nested loop join(了解)
2.2 SQL解析过程
2.4 Action操作
与RDD类似的操作
- show、collect、collectAsList、head、first、count、take、takeAsList、reduce
与结构相关
- printSchema、explain、columns、dtypes、col
EMPNO,ENAME,JOB,MGR,HIREDATE,SAL,COMM,DEPTNO
7369,SMITH,CLERK,7902,2001-01-02 22:12:13,800,,20
7499,ALLEN,SALESMAN,7698,2002-01-02 22:12:13,1600,300,30
7521,WARD,SALESMAN,7698,2003-01-02 22:12:13,1250,500,30
7566,JONES,MANAGER,7839,2004-01-02 22:12:13,2975,,20
7654,MARTIN,SALESMAN,7698,2005-01-02 22:12:13,1250,1400,30
7698,BLAKE,MANAGER,7839,2005-04-02 22:12:13,2850,,30
7782,CLARK,MANAGER,7839,2006-03-02 22:12:13,2450,,10
7788,SCOTT,ANALYST,7566,2007-03-02 22:12:13,3000,,20
7839,KING,PRESIDENT,,2006-03-02 22:12:13,5000,,10
7844,TURNER,SALESMAN,7698,2009-07-02 22:12:13,1500,0,30
7876,ADAMS,CLERK,7788,2010-05-02 22:12:13,1100,,20
7900,JAMES,CLERK,7698,2011-06-02 22:12:13,950,,30
7902,FORD,ANALYST,7566,2011-07-02 22:12:13,3000,,20
7934,MILLER,CLERK,7782,2012-11-02 22:12:13,1300,,10
package com.ch.sparksql
import org.apache.spark.sql.{DataFrame, SparkSession}
object ActionDemo {
def main(args: Array[String]): Unit = {
val spark = SparkSession
.builder()
.appName("Demo1")
.master("local[*]")
.getOrCreate()
val sc = spark.sparkContext
sc.setLogLevel("warn")
val df: DataFrame = spark.read
.option("header", "true")
.option("inferschema", "true")
.csv("data/emp.dat")
df.printSchema()
//root
// |-- EMPNO: integer (nullable = true)
// |-- ENAME: string (nullable = true)
// |-- JOB: string (nullable = true)
// |-- MGR: integer (nullable = true)
// |-- HIREDATE: timestamp (nullable = true)
// |-- SAL: integer (nullable = true)
// |-- COMM: integer (nullable = true)
// |-- DEPTNO: integer (nullable = true)
df.toJSON.show(false)
//+-----------------------------------------------------------------------------------------------------------------------------------------+
//|value |
//+-----------------------------------------------------------------------------------------------------------------------------------------+
//|{"EMPNO":7369,"ENAME":"SMITH","JOB":"CLERK","MGR":7902,"HIREDATE":"2001-01-02T22:12:13.000+08:00","SAL":800,"DEPTNO":20} |
// ......
// 显示第一行数据, 非 字段名
val h0 = df.head()
println(h0) // [7369,SMITH,CLERK,7902,2001-01-02 22:12:13.0,800,null,20]
println("==================================================")
val df1 = df
println(df1.count) // 14
println("==================================================")
// 缺省显示20行
df1.union(df1).show()
//.....
//| 7698| BLAKE| MANAGER|7839|2005-04-02 22:12:13|2850|null| 30|
//+-----+------+---------+----+-------------------+----+----+------+
//only showing top 20 rows
println("==================================================")
// 显示2行
df1.show(2)
//+-----+-----+--------+----+-------------------+----+----+------+
//|EMPNO|ENAME| JOB| MGR| HIREDATE| SAL|COMM|DEPTNO|
//+-----+-----+--------+----+-------------------+----+----+------+
//| 7369|SMITH| CLERK|7902|2001-01-02 22:12:13| 800|null| 20|
//| 7499|ALLEN|SALESMAN|7698|2002-01-02 22:12:13|1600| 300| 30|
//+-----+-----+--------+----+-------------------+----+----+------+
//only showing top 2 rows
println("==================================================")
// 不截断字符
df1.toJSON.show(false)
println("==================================================")
// 显示10行,不截断字符
df1.toJSON.show(10, false)
//+-----------------------------------------------------------------------------------------------------------------------------------------+
//|value |
//+-----------------------------------------------------------------------------------------------------------------------------------------+
//|{"EMPNO":7369,"ENAME":"SMITH","JOB":"CLERK","MGR":7902,"HIREDATE":"2001-01-02T22:12:13.000+08:00","SAL":800,"DEPTNO":20} |
// ....
println("==================================================")
spark.catalog.listFunctions.show(10000, false)
//+---------------------------+--------+-----------+-------------------------------------------------------------------------+-----------+
//|name |database|description|className |isTemporary|
//+---------------------------+--------+-----------+-------------------------------------------------------------------------+-----------+
//|! |null |null |org.apache.spark.sql.catalyst.expressions.Not |true |
// .......
println("==================================================")
// collect返回的是数组, Array[org.apache.spark.sql.Row]
val c1 = df1.collect()
// collectAsList返回的是List, List[org.apache.spark.sql.Row]
val c2 = df1.collectAsList()
// 返回 org.apache.spark.sql.Row
val h1 = df1.head()
val f1 = df1.first()
// 返回 Array[org.apache.spark.sql.Row],长度为3
val h2 = df1.head(3)
val f2 = df1.take(3)
// 返回 List[org.apache.spark.sql.Row],长度为2
val t2 = df1.takeAsList(2)
spark.close()
}
}
// 结构属性
df1.columns // 查看列名
df1.dtypes // 查看列名和类型
df1.explain() // 参看执行计划
df1.col("name") // 获取某个列
df1.printSchema // 常用
2.5 Transformation 操作
select * from tab where ... group by ... having... order by...
RDD类似的操作
持久化/缓存与checkpoint
select
where
group by / 聚合
order by
join
集合操作
空值操作(函数)
函数
与RDD类似的操作
map、filter、flatMap、mapPartitions、sample、 randomSplit、 limit、distinct、dropDuplicates、describe
df1.map(row=>row.getAs[Int](0)).show
// randomSplit(与RDD类似,将DF、DS按给定参数分成多份) 随机返回其中 %50 60 70的数据
val df2 = df1.randomSplit(Array(0.5, 0.6, 0.7))
df2(0).count
df2(1).count
df2(2).count
// 取10行数据生成新的DataSet
val df2 = df1.limit(10)
// distinct,去重
val df2 = df1.union(df1)
df2.distinct.count
// dropDuplicates,按列值去重
df2.dropDuplicates.show
df2.dropDuplicates("mgr", "deptno").show
df2.dropDuplicates("mgr").show
df2.dropDuplicates("deptno").show
// 返回全部列的统计(count、mean、stddev、min、max)
ds1.describe().show
// 返回指定列的统计
ds1.describe("sal").show
ds1.describe("sal", "comm").show
存储相关
cacheTable、persist、checkpoint、unpersist、cache
备注:Dataset 默认的存储级别是 MEMORY_AND_DISK
import org.apache.spark.storage.StorageLevel
spark.sparkContext.setCheckpointDir("hdfs://linux121:9000/checkpoint")
df1.show()
df1.checkpoint()
df1.cache()
df1.persist(StorageLevel.MEMORY_ONLY)
df1.count()
df1.unpersist(true)
df1.createOrReplaceTempView("t1")
spark.catalog.cacheTable("t1")
spark.catalog.uncacheTable("t1")
select相关
列的多种表示、select、selectExpr
// 列的多种表示方法。使用""、$""、'、col()、ds("")
// 注意:不要混用;必要时使用spark.implicitis._;并非每个表示在所有的地方都有效
// 建议使用 $ 形式, 可以遍历
df1.select($"ename", $"hiredate", $"sal").show
df1.select("ename", "hiredate", "sal").show
df1.select('ename, 'hiredate, 'sal).show
df1.select(col("ename"), col("hiredate"), col("sal")).show
df1.select(df1("ename"), df1("hiredate"), df1("sal")).show
// 下面的写法无效,其他列的表示法有效
df1.select("ename", "hiredate", "sal"+100).show
df1.select("ename", "hiredate", "sal+100").show
// 这样写才符合语法
df1.select($"ename", $"hiredate", $"sal"+100).show
df1.select('ename, 'hiredate, 'sal+100).show
// 可使用expr表达式(expr里面只能使用引号)
df1.select(expr("comm+100"), expr("sal+100"), expr("ename")).show
df1.selectExpr("ename as name").show
// 平方
df1.selectExpr("power(sal, 2)", "sal").show
// 四舍五入
df1.selectExpr("round(sal, -3) as newsal", "sal", "ename").show
// drop、withColumn、 withColumnRenamed、casting
// drop 删除一个或多个列,得到新的DF
df1.drop("mgr")
df1.drop("empno", "mgr")
// withColumn,修改列值
val df2 = df1.withColumn("sal", $"sal"+1000)
df2.show
// withColumnRenamed,更改列名
df1.withColumnRenamed("sal", "newsal")
// 备注:drop、withColumn、withColumnRenamed返回的是DF
// cast,类型转换
df1.selectExpr("cast(empno as string)").printSchema
import org.apache.spark.sql.types._
df1.select('empno.cast(StringType)).printSchema
where相关
where == filter
// where操作, 用 "=" 也可以
df1.filter("sal>1000").show
df1.filter("sal>1000 and job=='MANAGER'").show
// filter操作
df1.where("sal>1000").show
df1.where("sal>1000 and job=='MANAGER'").show
groupBy相关
groupBy、agg、max、min、avg、sum、count(后面5个为内置函数)
// 在idea中使用, 记得导包 org.apache.spark.sql.function._
// groupBy、max、min、mean、sum、count(与df1.count不同)
df1.groupBy("Job").sum("sal").show
df1.groupBy("Job").max("sal").show
df1.groupBy("Job").min("sal").show
df1.groupBy("Job").avg("sal").show
df1.groupBy("Job").count.show
// 类似having子句
df1.groupBy("Job").avg("sal").where("avg(sal) > 2000").show
df1.groupBy("Job").avg("sal").where($"avg(sal)" > 2000).show
// agg
df1.groupBy("Job").agg("sal"->"max", "sal"->"min", "sal"->"avg", "sal"->"sum", "sal"->"count").show
df1.groupBy("deptno").agg("sal"->"max", "sal"->"min", "sal"->"avg", "sal"->"sum", "sal"->"count").show
// 这种方式更好理解
df1.groupBy("Job").agg(max("sal"), min("sal"), avg("sal"), sum("sal"), count("sal")).show
// 给列取别名
df1.groupBy("Job").agg(max("sal"), min("sal"), avg("sal"), sum("sal"),
count("sal")).withColumnRenamed("min(sal)", "min1").show
// 给列取别名,最简便
df1.groupBy("Job").agg(max("sal").as("max1"), min("sal").as("min2"), avg("sal").as("avg3"),
sum("sal").as("sum4"), count("sal").as("count5")).show
orderBy相关
orderBy == sort
// orderBy 建议用 $符方式
df1.orderBy("sal").show
df1.orderBy($"sal").show
df1.orderBy($"sal".asc).show
// 降序
df1.orderBy(-$"sal").show
df1.orderBy('sal).show
df1.orderBy(col("sal")).show
df1.orderBy(df1("sal")).show
df1.orderBy($"sal".desc).show
df1.orderBy(-'sal).show
df1.orderBy(-'deptno, -'sal).show
// sort,以下语句等价
df1.sort("sal").show
df1.sort($"sal").show
df1.sort($"sal".asc).show
df1.sort('sal).show
df1.sort(col("sal")).show
df1.sort(df1("sal")).show
df1.sort($"sal".desc).show
df1.sort(-'sal).show
df1.sort(-'deptno, -'sal).show
join相关
// 1、笛卡尔积, 不建议永
df1.crossJoin(df1).count
// 2、等值连接(单字段)(连接字段empno,仅显示了一次)
df1.join(df1, "empno").count
// 3、等值连接(多字段)(连接字段empno、ename,仅显示了一次)
df1.join(df1, Seq("empno", "ename")).show
// 定义第一个数据集
case class StudentAge(sno: Int, name: String, age: Int)
val lst = List(StudentAge(1,"Alice", 18), StudentAge(2,"Andy", 19),
StudentAge(3,"Bob",17), StudentAge(4,"Justin", 21),
StudentAge(5,"Cindy", 20))
val ds1 = spark.createDataset(lst)
ds1.show()
// 定义第二个数据集
case class StudentHeight(sname: String, height: Int)
val rdd = sc.makeRDD(List(StudentHeight("Alice", 160),
StudentHeight("Andy", 159),
StudentHeight("Bob", 170),
StudentHeight("Cindy", 165),
StudentHeight("Rose", 160)))
val ds2 = rdd.toDS
// 备注:不能使用双引号,而且这里是 ===
ds1.join(ds2, $"name"===$"sname").show
ds1.join(ds2, 'name==='sname).show
ds1.join(ds2, ds1("name")===ds2("sname")).show
ds1.join(ds2, ds1("sname")===ds2("sname"), "inner").show
// 多种连接方式 , 啥也不写, 默认内连接
ds1.join(ds2, $"name"===$"sname").show
ds1.join(ds2, $"name"===$"sname", "inner").show
// 左外
ds1.join(ds2, $"name"===$"sname", "left").show
ds1.join(ds2, $"name"===$"sname", "left_outer").show
// 右外
ds1.join(ds2, $"name"===$"sname", "right").show
ds1.join(ds2, $"name"===$"sname", "right_outer").show
// 全外连接
ds1.join(ds2, $"name"===$"sname", "outer").show
ds1.join(ds2, $"name"===$"sname", "full").show
ds1.join(ds2, $"name"===$"sname", "full_outer").show
备注:DS在join操作之后变成了DF
集合相关
union==unionAll(过期)、intersect、except
// union、unionAll(过期)、intersect、except。集合的交、并、差
val ds3 = ds1.select("name")
val ds4 = ds2.select("sname")
// union 求并集,不去重
ds3.union(ds4).show
// unionAll、union 等价;unionAll过期方法,不建议使用
ds3.unionAll(ds4).show
// intersect 求交
ds3.intersect(ds4).show
// except 求差
ds3.except(ds4).show
空值处理
// NaN (Not a Number), 经过非法运算得到的结果
math.sqrt(-1.0)
math.sqrt(-1.0).isNaN()
df1.show
// 删除所有列中, 有空值和NaN的那些行
df1.na.drop.show
// 删除某列有 空值和NaN 的那些行
df1.na.drop(Array("mgr")).show
// 对全部列填充;对指定单列填充;对指定多列填充
df1.na.fill(1000).show
df1.na.fill(1000, Array("comm")).show
df1.na.fill(Map("mgr"->2000, "comm"->1000)).show
// 对指定的值进行替换
df1.na.replace("comm" :: "deptno" :: Nil, Map(0 -> 100, 10 -> 100)).show
// 查询空值列或非空值列。isNull、isNotNull为内置函数
df1.filter("comm is null").show
df1.where("comm is null").show
df1.filter($"comm".isNull).show
df1.where($"comm".isNull).show
df1.filter(col("comm").isNull).show
df1.filter("comm is not null").show
df1.filter(col("comm").isNotNull).show
窗口函数
一般情况下窗口函数不用 DSL 处理,直接用SQL更方便
参考源码Window.scala、WindowSpec.scala(主要)
import org.apache.spark.sql.expressions.Window
val w1 = Window.partitionBy("cookieid").orderBy("createtime")
val w2 = Window.partitionBy("cookieid").orderBy("pv")
val w3 = w1.rowsBetween(Window.unboundedPreceding, Window.currentRow)
val w4 = w1.rowsBetween(-1, 1)
// 聚组函数【用分析函数的数据集】
df.select($"cookieid", $"pv", sum("pv").over(w1).alias("pv1")).show
df.select($"cookieid", $"pv", sum("pv").over(w3).alias("pv1")).show
df.select($"cookieid", $"pv", sum("pv").over(w4).as("pv1")).show
// 排名
df.select($"cookieid", $"pv", rank().over(w2).alias("rank")).show
df.select($"cookieid", $"pv", dense_rank().over(w2).alias("denserank")).show
df.select($"cookieid", $"pv", row_number().over(w2).alias("rownumber")).show
// lag、lead
df.select($"cookieid", $"pv", lag("pv", 2).over(w2).alias("rownumber")).show
df.select($"cookieid", $"pv", lag("pv", -2).over(w2).alias("rownumber")).show
内建函数
http://spark.apache.org/docs/latest/api/sql/index.html
2.6 SQL语句
总体而言:SparkSQL与HQL兼容;与HQL相比,SparkSQL更简洁。
createTempView、createOrReplaceTempView、spark.sql("SQL")
package com.ch.sparksql
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Dataset, SparkSession}
case class Info(id: String, tags: String)
object SQLDemo {
def main(args: Array[String]): Unit = {
val spark = SparkSession.builder()
.appName(this.getClass.getCanonicalName)
.master("local[*]")
.getOrCreate()
spark.sparkContext.setLogLevel("warn")
// 导入隐式转换包
import spark.implicits._
// 准备数据
val arr = Array("1 1,2,3", "2 2,3", "3 1,2")
// 转成 RDD
val rdd: RDD[Info] = spark.sparkContext.makeRDD(arr)
.map { line =>
val fields: Array[String] = line.split("\\s+")
// 可以通过 case class 转成 DateSet
Info(fields(0), fields(1))
}
val ds: Dataset[Info] = spark.createDataset(rdd)
ds.createOrReplaceTempView("t1")
ds.show
// +---+-----+
// | id| tags|
// +---+-----+
// | 1|1,2,3|
// | 2| 2,3|
// | 3| 1,2|
// +---+-----+
// 用SQL处理 - HQL
spark.sql(
"""
|select id, tag
| from t1
| lateral view explode(split(tags, ",")) t2 as tag
|""".stripMargin
).show
// SparkSQL, 可以简化成下面的样子
// explode(split(tags, ",")) tag
spark.sql(
"""
|select id, explode(split(tags, ",")) tag
| from t1
|""".stripMargin
).show
spark.close()
}
}
2.7 输入与输出
SparkSQL内建支持的数据源包括:Parquet、JSON、CSV、Avro、Images、BinaryFiles(Spark 3.0)。其中Parquet是默认的数据源。
// 内部使用
DataFrameReader.format(args).option("key", "value").schema(args).load()
// 开发API
SparkSession.read
val df1 = spark.read.format("parquet").load("data/users.parquet")
// Use Parquet; you can omit format("parquet") if you wish as it's the default
val df2 = spark.read.load("data/users.parquet")
// Use CSV
val df3 = spark.read.format("csv")
.option("inferSchema", "true")
.option("header", "true")
.load("data/people1.csv")
// Use JSON
val df4 = spark.read.format("json")
.load("data/emp.json")
// 内部使用
DataFrameWriter.format(args)
.option(args)
.bucketBy(args)
.partitionBy(args)
.save(path)
// 开发API
DataFrame.write
Parquet文件:
spark.sql(
"""
|CREATE OR REPLACE TEMPORARY VIEW users
|USING parquet
|OPTIONS (path "data/users.parquet")
|""".stripMargin
)
spark.sql("select * from users").show
df.write.format("parquet")
.mode("overwrite")
.option("compression", "snappy")
.save("data/parquet")
json文件:
val fileJson = "data/emp.json"
val df6 = spark.read.format("json").load(fileJson)
spark.sql(
"""
|CREATE OR REPLACE TEMPORARY VIEW emp
| USING json
| options(path "data/emp.json")
|""".stripMargin)
spark.sql("SELECT * FROM emp").show()
spark.sql("SELECT * FROM emp").write
.format("json")
.mode("overwrite")
.save("data/json"
CSV文件:
// CSV
val fileCSV = "data/people1.csv"
val df = spark.read.format("csv")
.option("header", "true")
.option("inferschema", "true")
.load(fileCSV)
spark.sql(
"""
|CREATE OR REPLACE TEMPORARY VIEW people
| USING csv
|options(path "data/people1.csv",
| header "true",
| inferschema "true")
|""".stripMargin)
spark.sql("select * from people")
.write
.format("csv")
.mode("overwrite")
.save("data/csv")
JDBC:
val jdbcDF = spark
.read
.format("jdbc")
.option("url", "jdbc:mysql://linux123:3306/ebiz?useSSL=false")
//&useUnicode=true
.option("driver", "com.mysql.jdbc.Driver")
.option("dbtable", "lagou_product_info")
.option("user", "hive")
.option("password", "12345678")
.load()
jdbcDF.show()
jdbcDF.write
.format("jdbc")
// 下面跟上字符集, 可以避免乱码
.option("url", "jdbc:mysql://linux123:3306/ebiz?useSSL=false&characterEncoding=utf8")
.option("user", "hive")
.option("password", "12345678")
.option("driver", "com.mysql.jdbc.Driver")
.option("dbtable", "lagou_product_info_back")
.mode("append") // (SaveMode.Append)
.save
备注:如果有中文, 注意表的字符集,否则会有乱码
- SaveMode.ErrorIfExists(默认)。若表存在,则会直接报异 常,数据不能存入数据库
- SaveMode.Append。若表存在,则追加在该表中;若该表不存在,则会先创建表,再插入数据
- SaveMode.Overwrite。先将已有的表及其数据全都删除,再重新创建该表,最后插入新的数据
- SaveMode.Ignore。若表不存在,则创建表并存入数据;若表存在,直接跳过数据的存储,不会报错
-- 创建表
create table lagou_product_info_back as
select * from lagou_product_info;
-- 检查表的字符集
show create table lagou_product_info_back;
show create table lagou_product_info;
-- 修改表的字符集
alter table lagou_product_info_back convert to character set utf8;
2.8 UDF & UDAF
1、UDF
UDF(User Defined Function),自定义函数。函数的输入、输出都是一条数据记录,类似于Spark SQL中普通的数学或字符串函数。实现上看就是普通的Scala函数;
UDAF(User Defined Aggregation Funcation),用户自定义聚合函数。函数本身作用于数据集合,能够在聚合操作的基础上进行自定义操作(多条数据输入,一条数据输出);类似于在group by之后使用的sum、avg等函数;
用Scala编写的UDF与普通的Scala函数几乎没有任何区别,唯一需要多执行的一个步骤是要在SQLContext注册它。
def len(bookTitle: String):Int = bookTitle.length
spark.udf.register("len", len _)
val booksWithLongTitle = spark.sql("select title, author from books where len(title) > 10")
编写的UDF可以放到SQL语句的fields部分,也可以作为where、groupBy或者having子句的一部分。
也可以在使用UDF时,传入常量而非表的列名。稍稍修改一下前面的函数,让长度10作为函数的参数传入:
def lengthLongerThan(bookTitle: String, length: Int): Boolean = bookTitle.length > length
spark.udf.register("longLength", lengthLongerThan _)
val booksWithLongTitle = spark.sql("select title, author from books where longLength(title, 10)")
若使用DataFrame的API,则以字符串的形式将UDF传入:
val booksWithLongTitle = dataFrame.filter("longLength(title, 10)")
DataFrame的API也可以接收Column对象,可以用$符号来包裹一个字符串表示一个Column。$是定义在SQLImplicits 对象中的一个隐式转换。此时,UDF的定义也不相同,不能直接定义Scala函数,而是要用定义在org.apache.spark.sql.functions中的 udf 方法来接收一个函数。这种方式无需register:
import org.apache.spark.sql.functions._
val longLength = udf((bookTitle: String, length: Int) => bookTitle.length > length)
import spark.implicits._
val booksWithLongTitle = dataFrame.filter(longLength($"title", lit(10)))
完整示例:
package cn.lagou.sparksql
import org.apache.spark.sql.{Row, SparkSession}
class UDF {
def main(args: Array[String]): Unit = {
val spark = SparkSession.builder()
.appName(this.getClass.getCanonicalName)
.master("local[*]")
.getOrCreate()
spark.sparkContext.setLogLevel("WARN")
val data = List(("scala", "author1"), ("spark", "author2"),
("hadoop", "author3"), ("hive", "author4"),
("strom", "author5"), ("kafka", "author6"))
val df = spark.createDataFrame(data).toDF("title", "author")
df.createTempView("books")
// 定义函数并注册
def len1(bookTitle: String): Int = bookTitle.length
spark.udf.register("len1", len1 _)
// UDF可以在select语句、where语句等多处使用
spark.sql("select title, author, len1(title) from books").show
spark.sql("select title, author from books where len1(title)>5").show
// UDF可以在DataFrame、Dataset的API中使用
import spark.implicits._
df.filter("len1(title)>5").show
// 不能通过编译
// df.filter(len1($"title")>5).show
// 能通过编译,但不能执行
// df.select("len1(title)").show
// 不能通过编译
// df.select(len1($"title")).show
// 如果要在DSL语法中使用$符号包裹字符串表示一个Column,需要用udf方法来接收函数。这种函数无需注册
import org.apache.spark.sql.functions._
// val len2 = udf((bookTitle: String) => bookTitle.length)
// val a:(String) => Int = (bookTitle: String) => bookTitle.length
// val len2 = udf(a)
val len2 = udf(len1 _)
df.filter(len2($"title") > 5).show
df.select($"title", $"author", len2($"title")).show
// 不使用UDF 实现同样的功能
df.map { case Row(title: String, author: String) => (title, author, title.length) }.show
spark.stop()
}
}
2、UDAF
数据如下:
id, name, sales, discount, state, saleDate
1, "Widget Co", 1000.00, 0.00, "AZ", "2019-01-01"
2, "Acme Widgets", 2000.00, 500.00, "CA", "2019-02-01"
3, "Widgetry", 1000.00, 200.00, "CA", "2020-01-11"
4, "Widgets R Us", 2000.00, 0.0, "CA", "2020-02-19"
5, "Ye Olde Widgete", 3000.00, 0.0, "MA", "2020-02-28"
最后要得到的结果为:
(2020年的合计值 – 2019年的合计值) / 2019年的合计值
(6000 - 3000) / 3000 = 1
执行以下SQL得到最终的结果:
select userFunc(sales, saleDate) from table1;
即计算逻辑在userFunc中实现
定义初值
分区内合并
分区间合并
计算最终结果
普通的UDF不支持数据的聚合运算。如当要对销售数据执行年度同比计算,就需要对当年和上一年的销量分别求和,然后再利用公式进行计算。此时需要使用UDAF,Spark为所有的UDAF定义了一个父类UserDefinedAggregateFunction 。要继承这个类,需要实现父类的几个抽象方法:
- inputSchema用于定义与DataFrame列有关的输入样式
- bufferSchema用于定义存储聚合运算时产生的中间数据结果的Schema
- dataType标明了UDAF函数的返回值类型
- deterministic是一个布尔值,用以标记针对给定的一组输入,UDAF是否总是生成相同的结果
- initialize对聚合运算中间结果的初始化
- update函数的第一个参数为bufferSchema中两个Field的索引,默认以0开始;UDAF的核心计算都发生在update函数中;update函数的第二个参数input: Row对应的并非DataFrame的行,而是被inputSchema投影了的行
- merge函数负责合并两个聚合运算的buffer,再将其存储到MutableAggregationBuffer中
- evaluate函数完成对聚合Buffer值的运算,得到最终的结果
UDAF--类型不安全
package com.ch.sparksql
import org.apache.log4j.{Level, Logger}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DoubleType, StringType, StructType}
import org.apache.spark.sql.{Row, SparkSession}
class TypeUnsafeUDAF extends UserDefinedAggregateFunction {
// 定义输入数据的类型, 名字可以随便写, 属于形参列表
override def inputSchema: StructType =
new StructType().add("sales1", DoubleType).add("saleDate1", StringType)
// 定义数据缓存的类型, 即计算数据存放的位置
override def bufferSchema: StructType =
new StructType().add("year2019", DoubleType).add("year2020", DoubleType)
// 定义最终返回结果的类型
override def dataType: DataType = DoubleType
// 对于相同的结果是否有相同的输出 90%以上的场合都是true
override def deterministic: Boolean = true
// 数据缓存的初始化,
override def initialize(buffer: MutableAggregationBuffer): Unit = {
buffer.update(0, 0.0) // 第一个位置, 初值为 0.0
buffer.update(1, 0.0) // 第二个位置, 初值为 0.0
}
// 分区内数据合并, 因为传入参数是 input, 不需要返回值, 因为直接更新了 buffer
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
// 销售量、上面的 input 位置为 0 的 sales1
val sales = input.getAs[Double](0)
// 销售日期(year) 上面的 input 位置为 1 的 saleDate, 取字符串的前 4 个
val saleYear = input.getAs[String](1).take(4)
// 根据 saleYear 的值, 加到不同的 buffer 中
saleYear match{
case "2019" => buffer(0) = buffer.getAs[Double](0) + sales
case "2020" => buffer(1) = buffer.getAs[Double](1) + sales
case _ => println("Error!")
}
}
// 分区间数据合并, 因为传入数据是 buffer
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
// 两个分区的 buffer1 buffer2 的相对位置数据 分别相加
buffer1(0) = buffer1.getAs[Double](0) + buffer2.getAs[Double](0)
buffer1(1) = buffer1.getAs[Double](1) + buffer2.getAs[Double](1)
}
// 计算最终的结果, 另外修改返回值为 Double 类型
override def evaluate(buffer: Row): Double = {
// (2020年的合计值 – 2019年的合计值) / 2019年的合计值
// 因为浮点型数据的特点, 写成下面的格式
// 这里的 buffer 即为上面的 buffer1 和 buffer2 合并后的结果
if (math.abs(buffer.getAs[Double](0)) < 0.000000001) 0.0
else (buffer.getAs[Double](1) - buffer.getAs[Double](0)) / buffer.getAs[Double](0)
}
}
object TypeUnsafeUDAFTest{
def main(args: Array[String]): Unit = {
Logger.getLogger("org").setLevel(Level.WARN)
val spark = SparkSession.builder()
.appName(s"${this.getClass.getCanonicalName}")
.master("local[*]")
.getOrCreate()
val sales = Seq(
(1, "Widget Co", 1000.00, 0.00, "AZ", "2019-01-02"),
(2, "Acme Widgets", 1000.00, 500.00, "CA", "2019-02-01"),
(3, "Widgetry", 1000.00, 200.00, "CA", "2020-01-11"),
(4, "Widgets R Us", 2000.00, 0.0, "CA", "2020-02-19"),
(5, "Ye Olde Widgete", 3000.00, 0.0, "MA", "2020-02-28"))
val salesDF = spark.createDataFrame(sales)
.toDF("id", "name", "sales", "discount", "state", "saleDate")
salesDF.createTempView("sales")
// 注册自定义函数
val userFunc = new TypeUnsafeUDAF
spark.udf.register("userFunc", userFunc)
// 执行 SQL
spark.sql("select userFunc(sales, saleDate) as rate from sales").show()
spark.stop()
}
}
UDAF--类型安全
package com.ch.sparksql
import org.apache.log4j.{Level, Logger}
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{Encoder, Encoders, SparkSession, TypedColumn}
// 定义两个样例类用来存放数据
case class Sales(id: Int, name1: String, sales: Double, discount: Double, name2: String, stime: String)
case class SalesBuffer(var sales2019: Double, var sales2020: Double)
// 直接在下面把输入数据说明清楚, 则下面的各方法参数会自动生成 [输入, ]
class TypeSafeUDAF extends Aggregator[Sales, SalesBuffer, Double]{
// 定义初值
override def zero: SalesBuffer = SalesBuffer(0.0, 0.0)
// 分区内的数据合并, 数据类型需要自己定义
override def reduce(buffer: SalesBuffer, input: Sales): SalesBuffer = {
val sales: Double = input.sales
val year = input.stime.take(4) // 取出前4位字符
year match {
case "2019" => buffer.sales2019 += sales
case "2020" => buffer.sales2020 += sales
case _ => println("ERROR")
}
buffer // 把分区内的合并数据返回出去
}
// 两个buffer, 分区间的数据合并
override def merge(b1: SalesBuffer, b2: SalesBuffer): SalesBuffer = {
SalesBuffer(b1.sales2019 + b2.sales2019, b1.sales2020 + b2.sales2020)
}
// 计算最终结果
override def finish(reduction: SalesBuffer): Double = {
if (math.abs(reduction.sales2019) < 0.0000000001) 0.0
else (reduction.sales2020 - reduction.sales2019) / reduction.sales2019
}
// 定义编码器, case class → product
override def bufferEncoder: Encoder[SalesBuffer] = Encoders.product
override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}
object TypeSafeUDAFTest{
def main(args: Array[String]): Unit = {
Logger.getLogger("org").setLevel(Level.WARN)
val spark = SparkSession.builder()
.appName(s"${this.getClass.getCanonicalName}")
.master("local[*]")
.getOrCreate()
val sales = Seq(
Sales(1, "Widget Co", 1000.00, 0.00, "AZ", "2019-01-02"),
Sales(2, "Acme Widgets", 2000.00, 500.00, "CA", "2019-02-01"),
Sales(3, "Widgetry", 1000.00, 200.00, "CA", "2020-01-11"),
Sales(4, "Widgets R Us", 2000.00, 0.0, "CA", "2020-02-19"),
Sales(5, "Ye Olde Widgete", 3000.00, 0.0, "MA", "2020-02-28"))
import spark.implicits._
val ds = spark.createDataset(sales)
ds.show
val rate: TypedColumn[Sales, Double] = new TypeSafeUDAF().toColumn.name("rate")
ds.select(rate).show
spark.stop()
}
}
2.9 访问Hive
在 pom 文件中增加依赖:
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-hive_2.12</artifactId>
<version>${spark.version}</version>
</dependency>
在 resources中增加hive-site.xml文件,在文件中增加内容:
<configuration>
<property>
<name>hive.metastore.uris</name>
<value>thrift://linux123:9083</value>
</property>
</configuration>
备注:最好使用 metastore service 连接Hive;使用直连 metastore 的方式时,SparkSQL程序会修改 Hive 的版本信息;
默认Spark使用 Hive 1.2.1进行编译,包含对应的serde, udf, udaf等。
package com.ch.sparksql
import org.apache.spark.sql.{DataFrame, SaveMode, SparkSession}
object AccessHive {
def main(args: Array[String]): Unit = {
val spark = SparkSession
.builder()
.appName("Demo1")
.master("local[*]")
.enableHiveSupport()
// Spark使用与Hive相同的约定写parquet数据
.config("spark.sql.parquet.writeLegacyFormat", "true")
.getOrCreate()
val sc = spark.sparkContext
sc.setLogLevel("warn")
spark.sql("show databases").show
spark.sql("select * from ods.ods_trade_product_info").show
val df: DataFrame = spark.table("ods.ods_trade_product_info")
df.show()
// 保存到另行创建的表
df.write.mode(SaveMode.Append).saveAsTable("ods.ods_trade_product_info_backup")
spark.table("ods.ods_trade_product_info_backup").show
spark.close()
}
}
第3节 Spark SQL原理
2.1 SparkSQL中的join
数据分析中将两个数据集进行 Join 操作是很常见的场景。在 Spark 的物理计划阶段,Spark 的 Join Selection 类会根据 Join hints 策略、Join 表的大小、 Join 是等值 Join 还是不等值以及参与 Join 的 key 是否可以排序等条件来选择最终的 Join 策略,最后 Spark 会利用选择好的 Join 策略执行最终的计算。当前 Spark 一共支持五种 Join 策略:
- Broadcast hash join (BHJ) → 系统自动在map端join
- Shuffle hash join(SHJ)
- Shuffle sort merge join (SMJ)
- Shuffle-and-replicate nested loop join,又称笛卡尔积(Cartesian product join)
- Broadcast nested loop join (BNLJ)
其中 BHJ 和 SMJ 这两种 Join 策略是我们运行 Spark 作业最常见的。JoinSelection 会先根据 Join 的 Key 为等值 Join来选择 Broadcast hash join、Shuffle hash join 以及 Shuffle sort merge join 中的一个;如果 Join 的 Key 为不等值Join 或者没有指定 Join 条件,则会选择 Broadcast nested loop join 或 Shuffle-and-replicate nested loop join。
不同的 Join 策略在执行上效率差别很大,了解每种 Join 策略的执行过程和适用条件是很有必要的。
1、Broadcast Hash Join
Broadcast Hash Join 的实现是将小表的数据广播到 Spark 所有的 Executor 端,这个广播过程和我们自己去广播数据没什么区别:
- 利用 collect 算子将小表的数据从 Executor 端拉到 Driver 端
- 在 Driver 端调用 sparkContext.broadcast 广播到所有 Executor 端
- 在 Executor 端使用广播的数据与大表进行 Join 操作(实际上是执行map操作)
这种 Join 策略避免了 Shuffle 操作。一般而言,Broadcast Hash Join 会比其他 Join 策略执行的要快。
使用这种 Join 策略必须满足以下条件:
- 小表的数据必须很小,可以通过 spark.sql.autoBroadcastJoinThreshold 参数来配置,默认是 10MB
- 如果内存比较大,可以将阈值适当加大
- 将 spark.sql.autoBroadcastJoinThreshold 参数设置为 -1,可以关闭这种连接方式
- 只能用于等值 Join,不要求参与 Join 的 keys 可排序
2、Shuffle Hash Join
当表中的数据比较大,又不适合使用广播,这个时候就可以考虑使用 Shuffle Hash Join。
Shuffle Hash Join 同样是在大表和小表进行 Join 的时候选择的一种策略。它的计算思想是:把大表和小表按照相同的分区算法和分区数进行分区(根据参与 Join 的 keys 进行分区),这样就保证了 hash 值一样的数据都分发到同一个分区中,然后在同一个 Executor 中两张表 hash 值一样的分区就可以在本地进行 hash Join 了。在进行 Join 之前,还会对小表的分区构建 Hash Map。Shuffle hash join 利用了分治思想,把大问题拆解成小问题去解决。
要启用 Shuffle Hash Join 必须满足以下条件:
- 仅支持等值 Join,不要求参与 Join 的 Keys 可排序
- spark.sql.join.preferSortMergeJoin 参数必须设置为 false,参数是从 Spark 2.0.0 版本引入的,默认值为true,也就是默认情况下选择 Sort Merge Join
- 小表的大小(plan.stats.sizeInBytes)必须小于 spark.sql.autoBroadcastJoinThreshold * spark.sql.shuffle.partitions(默认值200)
- 而且小表大小(stats.sizeInBytes)的三倍必须小于等于大表的大小(stats.sizeInBytes),也就是a.stats.sizeInBytes * 3 < = b.stats.sizeInBytes
3、Shuffle Sort Merge Join
前面两种 Join 策略对表的大小都有条件的,如果参与 Join 的表都很大,这时候就得考虑用 Shuffle Sort Merge Join了。
Shuffle Sort Merge Join 的实现思想:
- 将两张表按照 join key 进行shuffle,保证join key值相同的记录会被分在相应的分区
- 对每个分区内的数据进行排序
- 排序后再对相应的分区内的记录进行连接
无论分区有多大,Sort Merge Join都不用把一侧的数据全部加载到内存中,而是即用即丢;因为两个序列都有序。从头遍历,碰到key相同的就输出,如果不同,左边小就继续取左边,反之取右边。从而大大提高了大数据量下sql join的稳定性。
要启用 Shuffle Sort Merge Join 必须满足以下条件:
- 仅支持等值 Join,并且要求参与 Join 的 Keys 可排序
4、Cartesian product join(了解)
如果 Spark 中两张参与 Join 的表没指定连接条件,那么会产生 Cartesian product join,这个 Join 得到的结果其实就是两张表行数的乘积。
5、Broadcast nested loop join(了解)
可以把 Broadcast nested loop join 的执行看做下面的计算:
for record_1 in relation_1:
for record_2 in relation_2:
# join condition is executed
可以看出 Broadcast nested loop join 在某些情况会对某张表重复扫描多次,效率非常低下。从名字可以看出,这种join 会根据相关条件对小表进行广播,以减少表的扫描次数。
Broadcast nested loop join 支持等值和不等值 Join,支持所有的 Join 类型。
2.2 SQL解析过程
Spark SQL 可以说是 Spark 中的精华部分。原来基于 RDD 构建大数据计算任务,重心在向 DataSet 转移,原来基于RDD 写的代码也在迁移。使用 Spark SQL 编码好处是非常大的,尤其是在性能方面,有很大提升。Spark SQL 中各种内嵌的性能优化比写 RDD 遵守各种最佳实践更靠谱的,尤其对新手来说。如先 filter 操作再 map 操作,SparkSQL 中会自动进行谓词下推;Spark SQL中会自动使用 broadcast join 来广播小表,把 shuffle join 转化为 map join等等。
Spark SQL对SQL语句的处理和关系型数据库类似,即词法/语法解析、绑定、优化、执行。Spark SQL会先将SQL语句解析成一棵树,然后使用规则(Rule)对Tree进行绑定、优化等处理过程。Spark SQL由Core、Catalyst、Hive、Hive-ThriftServer四部分构成:
- Core: 负责处理数据的输入和输出,如获取数据,查询结果输出成DataFrame等
- Catalyst: 负责处理整个查询过程,包括解析、绑定、优化等 (核心)
- Hive: 负责对Hive数据进行处理
- Hive-ThriftServer: 主要用于对Hive的访问
Spark SQL的代码复杂度是问题的本质复杂度带来的,Spark SQL中的 Catalyst 框架大部分逻辑是在一个 Tree 类型的数据结构上做各种折腾,基于 Scala 来实现还是很优雅的,Scala 的偏函数和强大的 Case 正则匹配,让整个代码看起来非常优雅。
SparkSession 是编写 Spark 应用代码的入口,启动一个 spark-shell 会提供给你一个创建 SparkSession, 这个对象是整个 Spark 应用的起始点。以下是 SparkSession 的一些重要的变量和方法:
package com.ch.sparksql
import org.apache.spark.sql.{DataFrame, SparkSession}
object Plan {
def main(args: Array[String]): Unit = {
val spark = SparkSession
.builder()
.appName("Demo1")
.master("local[*]")
.getOrCreate()
spark.sparkContext.setLogLevel("warn")
import spark.implicits._
Seq((0, "zhansan", 10),
(1, "lisi", 11),
(2, "wangwu", 12)).toDF("id", "name", "age").createOrReplaceTempView("stu")
Seq((0, "chinese", 80), (0, "math", 100), (0, "english", 98),
(1, "chinese", 86), (1, "math", 97), (1, "english", 90),
(2, "chinese", 90), (2, "math", 94), (2, "english", 88)
).toDF("id", "subject", "score").createOrReplaceTempView("score")
val df: DataFrame = spark.sql(
"""
|select sum(v), name
| from (select stu.id, 100 + 10 + score.score as v, name
| from stu join score
| where stu.id = score.id and stu.age >= 11) tmp
|group by name
|""".stripMargin)
println(df.queryExecution)
// df.show()
val df1: DataFrame = spark.sql(
"""
|select sum(v), name
| from (select stu.id, 100 + 10 + score.score as v, name
| from stu join score on stu.id = score.id where stu.age >= 11) tmp
|group by name
|""".stripMargin)
println(df1.queryExecution)
val df2: DataFrame = spark.sql(
"""
|select sum(v), name
| from (select stu.id, 100 + 10 + score.score as v, name
| from stu join score on stu.id = score.id where stu.age >= 11) tmp
|group by name
|""".stripMargin)
println(df2.queryExecution)
// df1.show()
// 打印执行计划
// println(df.queryExecution)
spark.close()
}
}
queryExecution 就是整个执行计划的执行引擎,里面有执行过程中各个中间过程变量,整个执行流程如下:
上面例子中的 SQL 语句经过 Parser 解析后就会变成一个抽象语法树,对应解析后的逻辑计划 AST 为:
== Parsed Logical Plan ==
'Aggregate ['name], [unresolvedalias('sum('v), None), 'name]
+- 'SubqueryAlias `tmp`
+- 'Project ['stu.id, ((100 + 10) + 'score.score) AS v#26, 'name]
+- 'Filter (('stu.id = 'score.id) && ('stu.age >= 11))
+- 'Join Inner
:- 'UnresolvedRelation `stu`
+- 'UnresolvedRelation `score`
备注:在执行计划中 Project/Projection 代表的意思是投影
选, 投, 连, 三种最基本的操作
其中过滤条件变为了 Filter 节点,这个节点是 UnaryNode(一元节点) 类型, 只有一个孩子。两个表中的数据变为了UnresolvedRelation 节点,节点类型为 LeafNode ,即叶子节点, JOIN 操作为节点, 这个是一个 BinaryNode 节点,有两个孩子。
以上节点都是 LogicalPlan 类型的, 可以理解为进行各种操作的 Operator, SparkSQL 对各种操作定义了各种Operator。
这些 operator 组成的抽象语法树就是整个 Catatyst 优化的基础,Catatyst 优化器会在这个树上面进行各种折腾,把树上面的节点挪来挪去来进行优化。
经过 Parser 有了抽象语法树,但是并不知道 score,sum 这些东西是啥,所以就需要 analyer 来定位。
analyzer 会把 AST 上所有 Unresolved 的东西都转变为 resolved 状态,SparkSQL 有很多resolve 规则:
- ResolverRelations。解析表(列)的基本类型等信息
- ResolveFuncions。解析出来函数的基本信息
- ResolveReferences。解析引用,通常是解析列名
== Analyzed Logical Plan ==
sum(v): bigint, name: string
Aggregate [name#8], [sum(cast(v#26 as bigint)) AS sum(v)#28L, name#8]
+- SubqueryAlias `tmp`
+- Project [id#7, ((100 + 10) + score#22) AS v#26, name#8]
+- Filter ((id#7 = id#20) && (age#9 >= 11))
+- Join Inner
:- SubqueryAlias `stu`
: +- Project [_1#3 AS id#7, _2#4 AS name#8, _3#5 AS age#9]
: +- LocalRelation [_1#3, _2#4, _3#5]
+- SubqueryAlias `score`
+- Project [_1#16 AS id#20, _2#17 AS subject#21, _3#18 AS score#22]
+- LocalRelation [_1#16, _2#17, _3#18]
下面要进行逻辑优化了,常见的逻辑优化有:
== Optimized Logical Plan ==
Aggregate [name#8], [sum(cast(v#26 as bigint)) AS sum(v)#28L, name#8]
+- Project [(110 + score#22) AS v#26, name#8]
+- Join Inner, (id#7 = id#20)
:- LocalRelation [id#7, name#8]
+- LocalRelation [id#20, score#22]
这里用到的优化有:谓词下推(Push Down Predicate)、常量折叠(Constant Folding)、字段裁剪(Columning Pruning)
做完逻辑优化,还需要先转换为物理执行计划,将逻辑上可行的执行计划变为 Spark 可以真正执行的计划:
SparkSQL 把逻辑节点转换为了相应的物理节点, 比如 Join 算子,Spark 根据不同场景为该算子制定了不同的算法策略。 (代码生成)
== Physical Plan ==
*(2) HashAggregate(keys=[name#8], functions=[sum(cast(v#26 as bigint))], output=[sum(v)#28L, name#8])
+- Exchange hashpartitioning(name#8, 200)
+- *(1) HashAggregate(keys=[name#8], functions=[partial_sum(cast(v#26 as bigint))], output=[name#8, sum#38L])
+- *(1) Project [(110 + score#22) AS v#26, name#8]
+- *(1) BroadcastHashJoin [id#7], [id#20], Inner, BuildLeft
:- BroadcastExchange HashedRelationBroadcastMode(List(cast(input[0, int, false] as bigint)))
: +- LocalTableScan [id#7, name#8]
+- LocalTableScan [id#20, score#22]
数据在一个一个的 plan 中流转,然后每个 plan 里面表达式都会对数据进行处理,就相当于经过了一个个小函数的调用处理,这里面有大量的函数调用开销。是不是可以把这些小函数内联一下,当成一个大函数,WholeStageCodegen 就是干这事的。可以看到最终执行计划每个节点前面有个 * 号,说明整段代码生成被启用,Project、BroadcastHashJoin、HashAggregate 这一段都启用了整段代码生成,级联为了大函数。