Spark代码可读性与性能优化——示例六(GroupBy、ReduceByKey)

1. 普通常见优化示例

1.1 错误示例 groupByKey

import org.apache.spark.{SparkConf, SparkContext}

object GroupNormal {

  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("GroupNormal")
    val sc = new SparkContext(conf)

    // 数据可能有几亿条,此处只做模拟示例
    val dataRDD = sc.parallelize(List(
      ("hello", 2),
      ("java", 7),
      ("where", 1),
      ("rust", 2),
      // 中间还有很多数据,不做展示
      ("scala", 1),
      ("java", 1),
      ("black", 9)
    ))

    // 做一个词频统计
    val result = dataRDD.groupByKey()
      .mapValues(_.sum)
      .sortBy(_._2, false)

    result.take(10).foreach(println)

    sc.stop()
  }

}

1.2 正确示例 reduceByKey

// 修改此部分groupByKey代码为reduceByKey
    val result = dataRDD
      .reduceByKey(_ + _)
      .sortBy(_._2, false)

    result.take(10).foreach(println)

2. 高级优化

2.0. 需求:统计历年全国高考生中数学成绩前100名

2.1 数据示例

id

chinese

math

english

year

3412312

121

115

134

2018

5231211

103

131

114

2010

……

……

……

……

……

2342354

134

105

124

2014

共计约2亿条数据
数据存于Hive中,表名tb_student_score,id值(唯一)代表学生,chinese代表语文,math代表数学,english代表英语

2.2 存在问题的代码示例

import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession

/**
  * 数据分组错误示例
  *
  * @author ALion
  * @version 2019/5/15 22:33
  */
object GroupDemo {

  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("GroupDemo")
    val spark = SparkSession.builder()
      .config(conf)
      .enableHiveSupport()
      .getOrCreate()

    // 获取原始数据
    val studentDF = spark.sql(
      """
        |SELECT * 
        |FROM tb_student_score
        |WHERE id IS NOT NULL AND math IS NOT NULL AND year IS NOT NULL
      """.stripMargin)
    
    // 开始进行分析
    val resultRDD = studentDF.rdd
      .map(row => {
        val id = row.getLong(row.fieldIndex("id"))
        val math = row.getInt(row.fieldIndex("math"))
        val year = row.getInt(row.fieldIndex("year"))

        (year, (id, math))
      })
      .groupByKey() // 按年分组
      .mapValues(_.toSeq.sortWith(_._2 > _._2).take(100)) // 根据math对每个人进行降序排序,最后获取前100的人
    
    // 触发Action,展示部分统计结果
    resultRDD.take(10).foreach(println)

    spark.stop()
  }

}

首先,可以肯定的是代码逻辑毫无问题,能够满足业务需求。
其次,这部分代码又存在很大的性能问题:
spark.sql(“SELECT * FROM tb_student_score”)这种形势读取表中数据较慢,有更快的方式
groupByKey处,发生shuffle,大量数据被分到对应的年份的节点中,然后每个节点使用单线程在各年对应的所有数据中对学生进行排序,最后获取前100名
groupByKey处的shuffle可能发生数据倾斜,可能存在部分年份的数据不全或参考人数较少,而部分年份数据较多
另外,直接使用SQL的方案已附在文章末尾

2.3 如何解决代码中的问题?

首先,读取表可以采用DataFrame的API,指定Schema,能够加速表的读取

val tbSchema = StructType(Array(
	  StructField("id", LongType, true),
	  StructField("chinese", IntegerType, true),
	  StructField("math", IntegerType, true),
	  StructField("english", IntegerType, true),
	  StructField("year", IntegerType, true)
	))
	
	// 获取原始数据
	val studentDF = spark.read.schema(tbSchema).table("tb_student_score")
	      .where("id IS NOT NULL AND math IS NOT NULL AND year IS NOT NULL")

其次,关于groupBy发生shuffle的问题以及排序的问题。似乎数据如果不按年份分组,针对每年所有的分数统一排序,就没有其他办法。因为待排序的数据不在一起好像就不能完整的排序啊?那还怎么谈取前100名啊?
其实不然,想想我们是不是可以先在每个数据分块本地排序一次获取前100名,最后将所有的前100汇总,进行一次总的排序获取总的前100名?这样的话,充分利用了每个分块的并行计算,提前做了部分排序,当数据shuffle的时候每个分块数据就只有100条,最后汇总进行一次排序的数据量就非常小了!其实这就是归并排序的思想,感兴趣的朋友可以搜索‘归并排序’看看。
优化后的示例代码如下:

// 开始进行分析
    val resultRDD = studentDF.rdd
      .mapPartitions {
        // 自己实现时,如果为了性能更好,不建议这样的函数式写法
        // 这里只是为了方便看
        _.map { row =>
          val id = row.getLong(row.fieldIndex("id"))
          val math = row.getInt(row.fieldIndex("math"))
          val year = row.getInt(row.fieldIndex("year"))

          (year, (id, math))
        }.toArray
          .groupBy(_._1) // 先在每个分块前,获取历年的数学前100名,减少后续groupBy的shuffle数据量
          .mapValues(_.map(_._2).sortWith(_._2 > _._2).take(100))
          .toIterator
      }.groupByKey() // 最后获取所有分块的前100名,再次排序,计算总的前100名
       .mapValues(_.flatten.toSeq.sortWith(_._2 > _._2).take(100))
    
    // 触发Action,展示部分统计结果
    resultRDD.take(10).foreach(println)

上述代码,已经完成功能实现。那么,这样的代码是否是最好的呢?答案是否定的。因为当前的排序是针对每个分块(Partition)的,一个Executor上有多个分块,每个分块有前100条数据需要shuffle,显然如果一个Executor一共只有100条数据需要shuffle才是最理想的!如果我们能有办法同时操纵每个Executor上的所有数据,获取前100条数据,那该多好啊!

我们想要的排序流程示意图如下:

groupby spark groupby spark性能_groupby spark

然而,Spark并没提供一个类似mapPartition的可以对Executor上所有分块统一操作的算子(不然的话,我们就可以像mapPartion那样统计每Executor的前100名了)。不过我们有一个算子reduceByKey,它会在每个节点合并数据后再shuffle到一个节点进行最后的合并,这种行为似乎与我们需要的逻辑类似,不过好像又有那么一点不一样。

你可能会说reduceByKey是合并,而我们的需求是排序啊!!!是的,这看上去似乎有点矛盾。

事实上,这样是行得通的:

首先,让我们假想有这样一个集合类型A(内部是可排序的,并且只能拥有前100的数据,多余的会被删除)
接着,把每个元素(id,math)转换成含有一个元素的集合A
最后,使用reduceByKey,将每个集合依次相加合并!!!没错!就是合并!这样最后一个集合就是包含前100名的集合了。
这样一个集合类型A,似乎在Scala、Java中不存在,不过有一个TreeSet能保证内部有序,我们可以在数据合并后手动提取前100,这样就可以了(另外,你也可以自己实现这样一个集合:3)

第一步,先将id和math转为一个对象,并为这个对象实现equals、hashCode、compareTo方法,保证后续在TreeSet中的排序不会出问题。另外,再实现一个toString方法,方便我们查看打印效果!:)

Person.class 代码 (因为Java比较易懂、易写这几个方法,这里优先采用Java的形式,后面会附上Scala对应的实现类)

public class Person implements Comparable<Person>, Serializable {

    private long id;

    private int math;

    public Person(long id, int math) {
        this.id = id;
        this.math = math;
    }

    @Override
    public int compareTo(Person person) {
        int result = person.math - this.math; // 降序
        if (result == 0) {
            result = person.id - this.id > 0 ? 1 : -1;
        }
        return result;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;

        Person person = (Person) o;

        return id == person.id;
    }

    @Override
    public int hashCode() {
        return (int) (id ^ (id >>> 32));
    }

    @Override
    public String toString() {
        return "Person{" +
                "id='" + id + '\'' +
                ", math=" + math +
                '}';
    }

}

TreeSet 使用示例

import scala.collection.immutable.TreeSet

object Demo {

  def main(args: Array[String]): Unit = {
    val set = TreeSet[Person](
      new Person(1231232L, 108),
      new Person(3214124L, 116),
      new Person(1321313L, 121),
      new Person(6435235L, 125)
    )

    // 获取前3名
    for (elem <- set.take(3)) {
      println(s"--> elem = $elem")
    }
  }

}

第二步,将原先的id、math封装为TreeSet

studentDF.rdd
      .map(row => {
        val id = row.getLong(row.fieldIndex("id"))
        val math = row.getInt(row.fieldIndex("math"))
        val year = row.getInt(row.fieldIndex("year"))

        (year, TreeSet(new Person(id, math)))
      })

最后,使用reduceByKey合并所有数据,得到前100名的结果

val resultRDD = studentDF.rdd
      .map(row => {
        val id = row.getLong(row.fieldIndex("id"))
        val math = row.getInt(row.fieldIndex("math"))
        val year = row.getInt(row.fieldIndex("year"))

        (year, TreeSet(new Person(id, math)))
      })
      .reduceByKey((set1, set2) => set1 ++ set2 take 100)  // 依次合并2个Set,并只保留前100

    resultRDD.take(10).foreach(println)

Nice!!! 这样,我们就同时解决了排序问题和数据倾斜问题!
进一步优化(aggregateByKey)
细心的朋友应该已经发现了,reduceByKey之前的map为每条的数据都生成了一个TreeSet,这样会大大增加内存消耗。
其实,我们只想要每个节点放一个可变的TreeSet(并且还能一直只存前100)。这样内存消耗就会更小!
那么我们该如何做呢?设计一个MyTreeSet,采用aggregateByKey复用同一个Set,简略的示例如下:
MyTreeSet(简易实现,针对mutable.TreeSet封装)

import scala.collection.mutable

class MyTreeSet[A](firstNum: Int, elem: Seq[A])(implicit val ord: Ordering[A]) {

  val set: mutable.TreeSet[A] = mutable.TreeSet[A](elem: _*)

  def +=(elem: A): MyTreeSet[A] = {
    this add elem

    this
  }

  def add(elem: A): Unit = {
    set.add(elem)

    // 删除排在最后的多余元素
    check10Size()
  }

  def ++=(that: MyTreeSet[A]) : MyTreeSet[A] = {
    that.set.foreach(e => this add e)

    this
  }

  def check10Size(): Unit = {
    // 如果超过了firstNum个,就删除
    if (set.size > firstNum) {
      set -= set.last
    }
  }

  override def toString: String = set.toString
}

object MyTreeSet {

  def apply[A](elem: A*)(implicit ord: Ordering[A]): MyTreeSet[A] = new MyTreeSet[A](10, elem) // 默认保留前10
  
  def apply[A](firstNum: Int, elem: A*)(implicit ord: Ordering[A]): MyTreeSet[A] = new MyTreeSet[A](firstNum, elem)
  
}

Spark部分代码

val resultRDD = studentDF.rdd
      .map(row => {
        val id = row.getLong(row.fieldIndex("id"))
        val math = row.getInt(row.fieldIndex("math"))
        val year = row.getInt(row.fieldIndex("year"))

        (year, new Person(id, math))
      }).aggregateByKey(MyTreeSet[Person](100)) (
          (set, v) => set += v,
          (set1, set2) => set1 ++= set2
      )

2.4 最终代码,以及其他附件代码
最终代码

import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.{IntegerType, LongType, StructField, StructType}

import scala.collection.immutable.TreeSet

object GroupDemo3 {

  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("GroupDemo")
    val spark = SparkSession.builder()
      .config(conf)
      .enableHiveSupport()
      .getOrCreate()

    val tbSchema = StructType(Array(
      StructField("id", LongType, true),
      StructField("chinese", IntegerType, true),
      StructField("math", IntegerType, true),
      StructField("english", IntegerType, true),
      StructField("year", IntegerType, true)
    ))

    // 获取原始数据
    val studentDF = spark.read.schema(tbSchema).table("tb_student_score")
      .where("id IS NOT NULL AND math IS NOT NULL AND year IS NOT NULL")

    // 开始进行分析
    val resultRDD = studentDF.rdd
      .map(row => {
        val id = row.getLong(row.fieldIndex("id"))
        val math = row.getInt(row.fieldIndex("math"))
        val year = row.getInt(row.fieldIndex("year"))

        (year, new Person(id, math))
      }).aggregateByKey(MyTreeSet[Person](100)) (
          (set, v) => set += v,
          (set1, set2) => set1 ++= set2
      ) // 依次合并2个Set,并只保留前100

    // 触发Action,展示部分统计结果
    resultRDD.take(10).foreach(println)

    spark.stop()
  }

}

Person的Scala实现

class PersonScala(val id: Long, val math: Int) extends Ordered[PersonScala]  with Serializable {

  override def compare(that: PersonScala): Int = {
    var result = that.math - this.math // 降序
    if (result == 0)
      result = if (that.id - this.id > 0) 1 else -1
    result
  }

  override def equals(obj: Any): Boolean = {
    obj match {
      case person: PersonScala => this.id == person.id
      case _ => false
    }
  }

  override def hashCode(): Int = (id ^ (id >>> 32)).toInt

  override def toString: String = "Person{" + "id=" + id + ", math=" + math + '}'

}

object PersonScala {

  def apply(id: Long, math: Int): PersonScala = new PersonScala(id, math)

}

示例——使用SQL获取历年数学的前100名(简单,但性能一般,且存在数据倾斜的可能)

def main(args: Array[String]): Unit = {
   val conf = new SparkConf().setAppName("GroupDemo")
   val spark = SparkSession.builder()
     .config(conf)
     .enableHiveSupport()
     .getOrCreate()
   
   // 使用sql分析
   val resultDF = spark.sql(
     """
       |SELECT year,id,math
       |FROM (
       | SELECT year,id,math,ROW_NUMBER() OVER (PARTITION BY year ORDER BY math DESC) rank 
       | FROM tb_student_score
       |) g
       |WHERE g.rank <= 100
     """.stripMargin)
   
   // 触发Action,展示部分统计结果
   resultDF.show()

   spark.stop()
 }