Spark代码可读性与性能优化——示例六(GroupBy、ReduceByKey)
1. 普通常见优化示例
1.1 错误示例 groupByKey
import org.apache.spark.{SparkConf, SparkContext}
object GroupNormal {
def main(args: Array[String]): Unit = {
val conf = new SparkConf().setAppName("GroupNormal")
val sc = new SparkContext(conf)
// 数据可能有几亿条,此处只做模拟示例
val dataRDD = sc.parallelize(List(
("hello", 2),
("java", 7),
("where", 1),
("rust", 2),
// 中间还有很多数据,不做展示
("scala", 1),
("java", 1),
("black", 9)
))
// 做一个词频统计
val result = dataRDD.groupByKey()
.mapValues(_.sum)
.sortBy(_._2, false)
result.take(10).foreach(println)
sc.stop()
}
}
1.2 正确示例 reduceByKey
// 修改此部分groupByKey代码为reduceByKey
val result = dataRDD
.reduceByKey(_ + _)
.sortBy(_._2, false)
result.take(10).foreach(println)
2. 高级优化
2.0. 需求:统计历年全国高考生中数学成绩前100名
2.1 数据示例
id | chinese | math | english | year |
3412312 | 121 | 115 | 134 | 2018 |
5231211 | 103 | 131 | 114 | 2010 |
…… | …… | …… | …… | …… |
2342354 | 134 | 105 | 124 | 2014 |
共计约2亿条数据
数据存于Hive中,表名tb_student_score,id值(唯一)代表学生,chinese代表语文,math代表数学,english代表英语
2.2 存在问题的代码示例
import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession
/**
* 数据分组错误示例
*
* @author ALion
* @version 2019/5/15 22:33
*/
object GroupDemo {
def main(args: Array[String]): Unit = {
val conf = new SparkConf().setAppName("GroupDemo")
val spark = SparkSession.builder()
.config(conf)
.enableHiveSupport()
.getOrCreate()
// 获取原始数据
val studentDF = spark.sql(
"""
|SELECT *
|FROM tb_student_score
|WHERE id IS NOT NULL AND math IS NOT NULL AND year IS NOT NULL
""".stripMargin)
// 开始进行分析
val resultRDD = studentDF.rdd
.map(row => {
val id = row.getLong(row.fieldIndex("id"))
val math = row.getInt(row.fieldIndex("math"))
val year = row.getInt(row.fieldIndex("year"))
(year, (id, math))
})
.groupByKey() // 按年分组
.mapValues(_.toSeq.sortWith(_._2 > _._2).take(100)) // 根据math对每个人进行降序排序,最后获取前100的人
// 触发Action,展示部分统计结果
resultRDD.take(10).foreach(println)
spark.stop()
}
}
首先,可以肯定的是代码逻辑毫无问题,能够满足业务需求。
其次,这部分代码又存在很大的性能问题:
spark.sql(“SELECT * FROM tb_student_score”)这种形势读取表中数据较慢,有更快的方式
groupByKey处,发生shuffle,大量数据被分到对应的年份的节点中,然后每个节点使用单线程在各年对应的所有数据中对学生进行排序,最后获取前100名
groupByKey处的shuffle可能发生数据倾斜,可能存在部分年份的数据不全或参考人数较少,而部分年份数据较多
另外,直接使用SQL的方案已附在文章末尾
2.3 如何解决代码中的问题?
首先,读取表可以采用DataFrame的API,指定Schema,能够加速表的读取
val tbSchema = StructType(Array(
StructField("id", LongType, true),
StructField("chinese", IntegerType, true),
StructField("math", IntegerType, true),
StructField("english", IntegerType, true),
StructField("year", IntegerType, true)
))
// 获取原始数据
val studentDF = spark.read.schema(tbSchema).table("tb_student_score")
.where("id IS NOT NULL AND math IS NOT NULL AND year IS NOT NULL")
其次,关于groupBy发生shuffle的问题以及排序的问题。似乎数据如果不按年份分组,针对每年所有的分数统一排序,就没有其他办法。因为待排序的数据不在一起好像就不能完整的排序啊?那还怎么谈取前100名啊?
其实不然,想想我们是不是可以先在每个数据分块本地排序一次获取前100名,最后将所有的前100汇总,进行一次总的排序获取总的前100名?这样的话,充分利用了每个分块的并行计算,提前做了部分排序,当数据shuffle的时候每个分块数据就只有100条,最后汇总进行一次排序的数据量就非常小了!其实这就是归并排序的思想,感兴趣的朋友可以搜索‘归并排序’看看。
优化后的示例代码如下:
// 开始进行分析
val resultRDD = studentDF.rdd
.mapPartitions {
// 自己实现时,如果为了性能更好,不建议这样的函数式写法
// 这里只是为了方便看
_.map { row =>
val id = row.getLong(row.fieldIndex("id"))
val math = row.getInt(row.fieldIndex("math"))
val year = row.getInt(row.fieldIndex("year"))
(year, (id, math))
}.toArray
.groupBy(_._1) // 先在每个分块前,获取历年的数学前100名,减少后续groupBy的shuffle数据量
.mapValues(_.map(_._2).sortWith(_._2 > _._2).take(100))
.toIterator
}.groupByKey() // 最后获取所有分块的前100名,再次排序,计算总的前100名
.mapValues(_.flatten.toSeq.sortWith(_._2 > _._2).take(100))
// 触发Action,展示部分统计结果
resultRDD.take(10).foreach(println)
上述代码,已经完成功能实现。那么,这样的代码是否是最好的呢?答案是否定的。因为当前的排序是针对每个分块(Partition)的,一个Executor上有多个分块,每个分块有前100条数据需要shuffle,显然如果一个Executor一共只有100条数据需要shuffle才是最理想的!如果我们能有办法同时操纵每个Executor上的所有数据,获取前100条数据,那该多好啊!
我们想要的排序流程示意图如下:
然而,Spark并没提供一个类似mapPartition的可以对Executor上所有分块统一操作的算子(不然的话,我们就可以像mapPartion那样统计每Executor的前100名了)。不过我们有一个算子reduceByKey,它会在每个节点合并数据后再shuffle到一个节点进行最后的合并,这种行为似乎与我们需要的逻辑类似,不过好像又有那么一点不一样。
你可能会说reduceByKey是合并,而我们的需求是排序啊!!!是的,这看上去似乎有点矛盾。
事实上,这样是行得通的:
首先,让我们假想有这样一个集合类型A(内部是可排序的,并且只能拥有前100的数据,多余的会被删除)
接着,把每个元素(id,math)转换成含有一个元素的集合A
最后,使用reduceByKey,将每个集合依次相加合并!!!没错!就是合并!这样最后一个集合就是包含前100名的集合了。
这样一个集合类型A,似乎在Scala、Java中不存在,不过有一个TreeSet能保证内部有序,我们可以在数据合并后手动提取前100,这样就可以了(另外,你也可以自己实现这样一个集合:3)
第一步,先将id和math转为一个对象,并为这个对象实现equals、hashCode、compareTo方法,保证后续在TreeSet中的排序不会出问题。另外,再实现一个toString方法,方便我们查看打印效果!:)
Person.class 代码 (因为Java比较易懂、易写这几个方法,这里优先采用Java的形式,后面会附上Scala对应的实现类)
public class Person implements Comparable<Person>, Serializable {
private long id;
private int math;
public Person(long id, int math) {
this.id = id;
this.math = math;
}
@Override
public int compareTo(Person person) {
int result = person.math - this.math; // 降序
if (result == 0) {
result = person.id - this.id > 0 ? 1 : -1;
}
return result;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Person person = (Person) o;
return id == person.id;
}
@Override
public int hashCode() {
return (int) (id ^ (id >>> 32));
}
@Override
public String toString() {
return "Person{" +
"id='" + id + '\'' +
", math=" + math +
'}';
}
}
TreeSet 使用示例
import scala.collection.immutable.TreeSet
object Demo {
def main(args: Array[String]): Unit = {
val set = TreeSet[Person](
new Person(1231232L, 108),
new Person(3214124L, 116),
new Person(1321313L, 121),
new Person(6435235L, 125)
)
// 获取前3名
for (elem <- set.take(3)) {
println(s"--> elem = $elem")
}
}
}
第二步,将原先的id、math封装为TreeSet
studentDF.rdd
.map(row => {
val id = row.getLong(row.fieldIndex("id"))
val math = row.getInt(row.fieldIndex("math"))
val year = row.getInt(row.fieldIndex("year"))
(year, TreeSet(new Person(id, math)))
})
最后,使用reduceByKey合并所有数据,得到前100名的结果
val resultRDD = studentDF.rdd
.map(row => {
val id = row.getLong(row.fieldIndex("id"))
val math = row.getInt(row.fieldIndex("math"))
val year = row.getInt(row.fieldIndex("year"))
(year, TreeSet(new Person(id, math)))
})
.reduceByKey((set1, set2) => set1 ++ set2 take 100) // 依次合并2个Set,并只保留前100
resultRDD.take(10).foreach(println)
Nice!!! 这样,我们就同时解决了排序问题和数据倾斜问题!
进一步优化(aggregateByKey)
细心的朋友应该已经发现了,reduceByKey之前的map为每条的数据都生成了一个TreeSet,这样会大大增加内存消耗。
其实,我们只想要每个节点放一个可变的TreeSet(并且还能一直只存前100)。这样内存消耗就会更小!
那么我们该如何做呢?设计一个MyTreeSet,采用aggregateByKey复用同一个Set,简略的示例如下:
MyTreeSet(简易实现,针对mutable.TreeSet封装)
import scala.collection.mutable
class MyTreeSet[A](firstNum: Int, elem: Seq[A])(implicit val ord: Ordering[A]) {
val set: mutable.TreeSet[A] = mutable.TreeSet[A](elem: _*)
def +=(elem: A): MyTreeSet[A] = {
this add elem
this
}
def add(elem: A): Unit = {
set.add(elem)
// 删除排在最后的多余元素
check10Size()
}
def ++=(that: MyTreeSet[A]) : MyTreeSet[A] = {
that.set.foreach(e => this add e)
this
}
def check10Size(): Unit = {
// 如果超过了firstNum个,就删除
if (set.size > firstNum) {
set -= set.last
}
}
override def toString: String = set.toString
}
object MyTreeSet {
def apply[A](elem: A*)(implicit ord: Ordering[A]): MyTreeSet[A] = new MyTreeSet[A](10, elem) // 默认保留前10
def apply[A](firstNum: Int, elem: A*)(implicit ord: Ordering[A]): MyTreeSet[A] = new MyTreeSet[A](firstNum, elem)
}
Spark部分代码
val resultRDD = studentDF.rdd
.map(row => {
val id = row.getLong(row.fieldIndex("id"))
val math = row.getInt(row.fieldIndex("math"))
val year = row.getInt(row.fieldIndex("year"))
(year, new Person(id, math))
}).aggregateByKey(MyTreeSet[Person](100)) (
(set, v) => set += v,
(set1, set2) => set1 ++= set2
)
2.4 最终代码,以及其他附件代码
最终代码
import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.{IntegerType, LongType, StructField, StructType}
import scala.collection.immutable.TreeSet
object GroupDemo3 {
def main(args: Array[String]): Unit = {
val conf = new SparkConf().setAppName("GroupDemo")
val spark = SparkSession.builder()
.config(conf)
.enableHiveSupport()
.getOrCreate()
val tbSchema = StructType(Array(
StructField("id", LongType, true),
StructField("chinese", IntegerType, true),
StructField("math", IntegerType, true),
StructField("english", IntegerType, true),
StructField("year", IntegerType, true)
))
// 获取原始数据
val studentDF = spark.read.schema(tbSchema).table("tb_student_score")
.where("id IS NOT NULL AND math IS NOT NULL AND year IS NOT NULL")
// 开始进行分析
val resultRDD = studentDF.rdd
.map(row => {
val id = row.getLong(row.fieldIndex("id"))
val math = row.getInt(row.fieldIndex("math"))
val year = row.getInt(row.fieldIndex("year"))
(year, new Person(id, math))
}).aggregateByKey(MyTreeSet[Person](100)) (
(set, v) => set += v,
(set1, set2) => set1 ++= set2
) // 依次合并2个Set,并只保留前100
// 触发Action,展示部分统计结果
resultRDD.take(10).foreach(println)
spark.stop()
}
}
Person的Scala实现
class PersonScala(val id: Long, val math: Int) extends Ordered[PersonScala] with Serializable {
override def compare(that: PersonScala): Int = {
var result = that.math - this.math // 降序
if (result == 0)
result = if (that.id - this.id > 0) 1 else -1
result
}
override def equals(obj: Any): Boolean = {
obj match {
case person: PersonScala => this.id == person.id
case _ => false
}
}
override def hashCode(): Int = (id ^ (id >>> 32)).toInt
override def toString: String = "Person{" + "id=" + id + ", math=" + math + '}'
}
object PersonScala {
def apply(id: Long, math: Int): PersonScala = new PersonScala(id, math)
}
示例——使用SQL获取历年数学的前100名(简单,但性能一般,且存在数据倾斜的可能)
def main(args: Array[String]): Unit = {
val conf = new SparkConf().setAppName("GroupDemo")
val spark = SparkSession.builder()
.config(conf)
.enableHiveSupport()
.getOrCreate()
// 使用sql分析
val resultDF = spark.sql(
"""
|SELECT year,id,math
|FROM (
| SELECT year,id,math,ROW_NUMBER() OVER (PARTITION BY year ORDER BY math DESC) rank
| FROM tb_student_score
|) g
|WHERE g.rank <= 100
""".stripMargin)
// 触发Action,展示部分统计结果
resultDF.show()
spark.stop()
}