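Why cover Sort-Based Shuffle? There are two reasons:
First, some readers see the name Sort-Based Shuffle and mistakenly conclude that the results Spark produces through it are sorted.
Second, Sort-Based Shuffle does have to sort, which brings in a sorting algorithm. This part is optional reading.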

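The core of Sort-Based Shuffle is to use ExternalSorter to sort the output of each ShuffleMapTask into a single file (FileSegmentGroup). To separate the content meant for different Reducer Tasks in the next stage, it also needs an index file (Index) that tells each parallel task of the downstream Stage which part belongs to it.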

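When ExternalSorter spills to disk, the Shuffle Map Task produces a data file, called File here, together with an index file. (File Group is a hash shuffle concept, best understood as a pool of files; to keep the two apart, the term File is used here, and its FileSegments are sorted by partition ID.) The FileSegments inside the File are kept in sorted order. Suppose there are 4 Reducer Tasks on the Reducer side: each downstream Task can easily use the index to locate which FileSegment of the File belongs to it, so the index works like a pointer.

A downstream Task first asks the Driver where the file is, then goes to the node that holds the File and talks to the BlockManager there. The BlockManager first reads the Index file and parses it according to its naming rule. For example, the first Task of the next stage normally fetches the first Segment. This is a pointer-lookup process.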

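To stress it again: the main benefit of Sort-Based Shuffle is that it reduces the number of temporary files written. Each map task ends up with only two files: a single data File made up of FileSegments that hold the content for the different partitions, and an index file, Index.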
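To make the "index as a pointer" idea concrete, here is a minimal sketch. It assumes the index holds one cumulative byte offset per partition boundary, so that reducer i reads the range between entry i and entry i + 1. The object and method names (IndexFileSketch, buildOffsets, segmentFor) are made up for illustration and are not Spark APIs.

```scala
object IndexFileSketch {
  // Cumulative offsets: entry i is where partition i starts, entry i + 1 is where it ends.
  def buildOffsets(partitionLengths: Array[Long]): Array[Long] =
    partitionLengths.scanLeft(0L)(_ + _)

  // (offset, length) of the segment that the reducer with this ID should fetch.
  def segmentFor(offsets: Array[Long], reduceId: Int): (Long, Long) =
    (offsets(reduceId), offsets(reduceId + 1) - offsets(reduceId))
}
```

With 4 reducers and segment lengths of 10, 0, 25 and 5 bytes, the offsets come out as 0, 10, 10, 35, 40, and reducer 2 reads 25 bytes starting at offset 10. An empty partition simply shows up as two equal neighbouring offsets.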

The Sort and Spill process on the Mapper side of Sort-Based Shuffle (nothing is sorted while data sits in the AppendOnlyMap; sorting happens only when spilling to disk)

Now let's go to the source code and see what the sorting in Sort-Based Shuffle actually does.


// Let the user specify short names for shuffle managers
val shortShuffleMgrNames = Map(
  "sort" -> classOf[org.apache.spark.shuffle.sort.SortShuffleManager].getName,
  "tungsten-sort" -> classOf[org.apache.spark.shuffle.sort.SortShuffleManager].getName)
val shuffleMgrName = conf.get("spark.shuffle.manager", "sort")
val shuffleMgrClass = shortShuffleMgrNames.getOrElse(shuffleMgrName.toLowerCase, shuffleMgrName)
val shuffleManager = instantiateClass[ShuffleManager](shuffleMgrClass)


override def runTask(context: TaskContext): MapStatus = {
  // Deserialize the RDD using the broadcast variable.
  val threadMXBean = ManagementFactory.getThreadMXBean
  val deserializeStartTime = System.currentTimeMillis()
  val deserializeStartCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
    threadMXBean.getCurrentThreadCpuTime
  } else 0L
  val ser = SparkEnv.get.closureSerializer.newInstance()
  val (rdd, dep) = ser.deserialize[(RDD[_], ShuffleDependency[_, _, _])](
    ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)
  _executorDeserializeTime = System.currentTimeMillis() - deserializeStartTime
  _executorDeserializeCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
    threadMXBean.getCurrentThreadCpuTime - deserializeStartCpuTime
  } else 0L

  var writer: ShuffleWriter[Any, Any] = null
  try {
    val manager = SparkEnv.get.shuffleManager
    writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context)
    writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])
    writer.stop(success = true).get
  } catch {
    case e: Exception =>
      try {
        if (writer != null) {
          writer.stop(success = false)
        }
      } catch {
        case e: Exception =>
          log.debug("Could not stop writer", e)
      }
      throw e
  }
}

manager = SparkEnv.get.shuffleManager fetches the shuffleManager that SparkEnv created through reflection, which is SortShuffleManager. So manager.getWriter is SortShuffleManager's getWriter:

/** Get a writer for a given partition. Called on executors by map tasks. */
override def getWriter[K, V](
    handle: ShuffleHandle,
    mapId: Int,
    context: TaskContext): ShuffleWriter[K, V] = {
  numMapsForShuffle.putIfAbsent(
    handle.shuffleId, handle.asInstanceOf[BaseShuffleHandle[_, _, _]].numMaps)
  val env = SparkEnv.get
  handle match {
    case unsafeShuffleHandle: SerializedShuffleHandle[K @unchecked, V @unchecked] =>
      new UnsafeShuffleWriter(/* constructor arguments omitted in this excerpt */)
    case bypassMergeSortHandle: BypassMergeSortShuffleHandle[K @unchecked, V @unchecked] =>
      new BypassMergeSortShuffleWriter(/* constructor arguments omitted in this excerpt */)
    case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] =>
      new SortShuffleWriter(shuffleBlockResolver, other, mapId, context)
  }
}

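SortShuffleManager's getWriter distinguishes three kinds of handle:

  • unsafeShuffleHandle: the path deeply optimized by Tungsten
  • bypassMergeSortHandle: the path where Sort-Based Shuffle can, to some extent, fall back to hash-shuffle behavior
  • BaseShuffleHandle: the path that uses SortShuffleWriter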

Back in ShuffleMapTask: once the writer has been obtained from the shuffle manager's getWriter, write is called to write the data.

var writer: ShuffleWriter[Any, Any] = null
try {
  val manager = SparkEnv.get.shuffleManager
  writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context)
  writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])
  writer.stop(success = true).get
} // catch block omitted; see runTask above



/** Write a bunch of records to this task's output */
override def write(records: Iterator[Product2[K, V]]): Unit = {
  sorter = if (dep.mapSideCombine) {
    require(dep.aggregator.isDefined, "Map-side combine without Aggregator specified!")
    new ExternalSorter[K, V, C](
      context, dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer)
  } else {
    // In this case we pass neither an aggregator nor an ordering to the sorter, because we don't
    // care whether the keys get sorted in each partition; that will be done on the reduce side
    // if the operation being run is sortByKey.
    new ExternalSorter[K, V, V](
      context, aggregator = None, Some(dep.partitioner), ordering = None, dep.serializer)
  }
  sorter.insertAll(records)

  // Don't bother including the time to open the merged output file in the shuffle write time,
  // because it just opens a single file, so is typically too fast to measure accurately
  // (see SPARK-3570).
  val output = shuffleBlockResolver.getDataFile(dep.shuffleId, mapId)
  val tmp = Utils.tempFileWith(output)
  try {
    val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID)
    val partitionLengths = sorter.writePartitionedFile(blockId, tmp)
    shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, tmp)
    mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths)
  } finally {
    if (tmp.exists() && !tmp.delete()) {
      logError(s"Error while deleting temp file ${tmp.getAbsolutePath}")
    }
  }
}


// Data structures to store in-memory objects before we spill. Depending on whether we have an
// Aggregator set, we either put objects into an AppendOnlyMap where we combine them, or we
// store them in an array buffer.
@volatile private var map = new PartitionedAppendOnlyMap[K, C]
@volatile private var buffer = new PartitionedPairBuffer[K, C]

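1. Combining on the map side: PartitionedAppendOnlyMap is a map-type data structure that holds key-value pairs. Aggregation happens locally: the key stays the same while its value is updated over and over. Underneath, PartitionedAppendOnlyMap is still an array; building the map on an array saves space and is more efficient. So how do you implement a map directly on an array? Number the array slots 0, 1, 2, 3, 4, ...; the even slots hold the map's keys and the odd slots hold the corresponding values.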
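The even/odd slot layout can be sketched as a toy map. This is only an illustration under simplifying assumptions: fixed capacity, linear probing, no growth and no null keys (the real map handles those cases). FlatArrayMapSketch and FlatArrayMap are hypothetical names; changeValue mirrors the shape of the update function used in insertAll.

```scala
object FlatArrayMapSketch {
  // A toy append-only map: key i lives in data(2 * i), its value in data(2 * i + 1).
  final class FlatArrayMap[K, V](capacity: Int = 8) {
    private val data = new Array[AnyRef](2 * capacity)
    private var size = 0

    // Linear probing from the key's hash; returns the slot holding `key`
    // or the first empty slot. Null keys are not supported in this sketch.
    private def slotOf(key: K): Int = {
      var pos = (key.hashCode & Int.MaxValue) % capacity
      while (data(2 * pos) != null && data(2 * pos) != key) {
        pos = (pos + 1) % capacity
      }
      pos
    }

    // update receives (hadValue, oldValue) and returns the new value for the key.
    def changeValue(key: K, update: (Boolean, V) => V): V = {
      val pos = slotOf(key)
      val hadValue = data(2 * pos) != null
      if (!hadValue) {
        // Keep one slot free so probing always terminates; a real map would grow instead.
        require(size < capacity - 1, "toy map is full")
        data(2 * pos) = key.asInstanceOf[AnyRef]
        size += 1
      }
      val old = if (hadValue) data(2 * pos + 1).asInstanceOf[V] else null.asInstanceOf[V]
      val updated = update(hadValue, old)
      data(2 * pos + 1) = updated.asInstanceOf[AnyRef]
      updated
    }

    def get(key: K): Option[V] = {
      val pos = slotOf(key)
      if (data(2 * pos) == null) None else Some(data(2 * pos + 1).asInstanceOf[V])
    }
  }
}
```

Keeping keys and values side by side in one array avoids allocating a separate entry object per pair, which is the space saving the paragraph above refers to.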


def insertAll(records: Iterator[Product2[K, V]]): Unit = {
  // TODO: stop combining if we find that the reduction factor isn't high
  val shouldCombine = aggregator.isDefined

  if (shouldCombine) {
    // Combine values in-memory first using our AppendOnlyMap
    val mergeValue = aggregator.get.mergeValue
    val createCombiner = aggregator.get.createCombiner
    var kv: Product2[K, V] = null
    val update = (hadValue: Boolean, oldValue: C) => {
      if (hadValue) mergeValue(oldValue, kv._2) else createCombiner(kv._2)
    }
    while (records.hasNext) {
      kv = records.next()
      map.changeValue((getPartition(kv._1), kv._1), update)
      maybeSpillCollection(usingMap = true)
    }
  } else {
    // Stick values into our buffer
    while (records.hasNext) {
      val kv = records.next()
      buffer.insert(getPartition(kv._1), kv._1, kv._2.asInstanceOf[C])
      maybeSpillCollection(usingMap = false)
    }
  }
}
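The combining branch of insertAll can be imitated with an ordinary hash map. The sketch below is a simplified stand-in, not Spark code: it ignores spilling, and getPartition here is a hypothetical hash-modulo partitioner. What it does keep is the idea that the in-memory map is keyed by (partition ID, key) and that createCombiner and mergeValue decide the stored value.

```scala
import scala.collection.mutable

object MapSideCombineSketch {
  // Hypothetical stand-in for a partitioner: non-negative hash modulo the partition count.
  def getPartition(key: Any, numPartitions: Int): Int =
    (key.hashCode & Int.MaxValue) % numPartitions

  // Combine records in memory, keyed by (partitionId, key).
  def combine[K, V, C](
      records: Iterator[(K, V)],
      numPartitions: Int,
      createCombiner: V => C,
      mergeValue: (C, V) => C): mutable.Map[(Int, K), C] = {
    val map = mutable.HashMap.empty[(Int, K), C]
    while (records.hasNext) {
      val (k, v) = records.next()
      val mapKey = (getPartition(k, numPartitions), k)
      val combined = map.get(mapKey) match {
        case Some(old) => mergeValue(old, v) // key seen before: fold the new value in
        case None      => createCombiner(v)  // first value for this key
      }
      map.update(mapKey, combined)
    }
    map
  }
}
```

For a word count, createCombiner would be the identity on the count and mergeValue would be addition, so repeated words collapse into one entry per (partition, word) before anything is written out.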



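The output file, output, is looked up from dep.shuffleId and mapId.
To write the data, dep.shuffleId, mapId, partitionLengths and tmp are used (tmp is the intermediate temporary file): the data is written to the file and the index is updated.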

val output = shuffleBlockResolver.getDataFile(dep.shuffleId, mapId)
val tmp = Utils.tempFileWith(output)
try {
  val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID)
  val partitionLengths = sorter.writePartitionedFile(blockId, tmp)
  shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, tmp)
  mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths)
} // finally block omitted; see write above


def writePartitionedFile(
    blockId: BlockId,
    outputFile: File): Array[Long] = {

  // Track location of each range in the output file
  val lengths = new Array[Long](numPartitions)
  val writer = blockManager.getDiskWriter(blockId, outputFile, serInstance, fileBufferSize,
    context.taskMetrics().shuffleWriteMetrics)

  if (spills.isEmpty) {
    // Case where we only have in-memory data
    val collection = if (aggregator.isDefined) map else buffer
    val it = collection.destructiveSortedWritablePartitionedIterator(comparator)
    while (it.hasNext) {
      val partitionId = it.nextPartition()
      while (it.hasNext && it.nextPartition() == partitionId) {
        it.writeNext(writer)
      }
      val segment = writer.commitAndGet()
      lengths(partitionId) = segment.length
    }
  } else {
    // We must perform merge-sort; get an iterator by partition and write everything directly.
    for ((id, elements) <- this.partitionedIterator) {
      if (elements.hasNext) {
        for (elem <- elements) {
          writer.write(elem._1, elem._2)
        }
        val segment = writer.commitAndGet()
        lengths(id) = segment.length
      }
    }
  }

  writer.close()
  // (task metrics updates omitted in this excerpt)
  lengths
}
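The way the lengths array falls out of writing records in partition order can be shown with a small in-memory sketch. It assumes the records are already ordered by partition ID and uses a ByteArrayOutputStream in place of the disk writer; PartitionedWriteSketch and writePartitioned are illustrative names, not Spark APIs.

```scala
import java.io.ByteArrayOutputStream
import java.nio.charset.StandardCharsets

object PartitionedWriteSketch {
  // Writes records that are already ordered by partition ID into one byte stream and
  // returns (bytes, lengths), where lengths(p) is the byte size of partition p's segment.
  def writePartitioned(
      sorted: Seq[(Int, String)],
      numPartitions: Int): (Array[Byte], Array[Long]) = {
    val out = new ByteArrayOutputStream()
    val lengths = new Array[Long](numPartitions)
    for ((partitionId, record) <- sorted) {
      val bytes = (record + "\n").getBytes(StandardCharsets.UTF_8)
      out.write(bytes, 0, bytes.length)
      lengths(partitionId) += bytes.length
    }
    (out.toByteArray, lengths)
  }
}
```

Because each partition's records are contiguous, one length per partition is enough to describe the whole file; partitions with no records just keep a length of 0.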



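Do you see any sorting going on in there? You may not have spotted it, but one line is the key: val it = collection.destructiveSortedWritablePartitionedIterator(comparator). It produces it, a WritablePartitionedIterator, which is then used to write the data.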


private[spark] trait WritablePartitionedPairCollection[K, V] {
  /**
   * Insert a key-value pair with a partition into the collection
   */
  def insert(partition: Int, key: K, value: V): Unit

  /**
   * Iterate through the data in order of partition ID and then the given comparator. This may
   * destroy the underlying collection.
   */
  def partitionedDestructiveSortedIterator(keyComparator: Option[Comparator[K]])
    : Iterator[((Int, K), V)]

  // ... (remaining members omitted in this excerpt)
}

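This is where the sorting shows up: records are ordered by partition ID, which makes writing fast and reading convenient. The key point is that the sort operates on the key, which here is the (partition ID, K) pair.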
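The difference between ordering by partition ID only and ordering by partition ID plus a key comparator can be sketched as follows. The comparators are simplified stand-ins written for (Int, String) pairs, and sortRecords is a hypothetical helper; the point is only that without a key comparator the keys inside a partition stay in whatever order they arrived.

```scala
import java.util.Comparator

object PartitionOrderingSketch {
  // Orders only by partition ID; keys inside a partition keep no particular order.
  val partitionComparator: Comparator[(Int, String)] = new Comparator[(Int, String)] {
    override def compare(a: (Int, String), b: (Int, String)): Int =
      Integer.compare(a._1, b._1)
  }

  // Orders by partition ID first and, within a partition, by the supplied key comparator.
  def partitionKeyComparator(keyCmp: Comparator[String]): Comparator[(Int, String)] =
    new Comparator[(Int, String)] {
      override def compare(a: (Int, String), b: (Int, String)): Int = {
        val byPartition = Integer.compare(a._1, b._1)
        if (byPartition != 0) byPartition else keyCmp.compare(a._2, b._2)
      }
    }

  // Sorts a copy of the records with the given comparator (stable for object arrays).
  def sortRecords(
      records: Seq[(Int, String)],
      cmp: Comparator[(Int, String)]): List[(Int, String)] = {
    val arr = records.toArray
    java.util.Arrays.sort(arr, cmp)
    arr.toList
  }
}
```

This is why the name Sort-Based Shuffle does not imply sorted results: with the partition-only comparator, a partition can contain "z" before "c", and nothing in the map-side write cares.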

/**
 * Implementation of WritablePartitionedPairCollection that wraps a map in which the keys are tuples
 * of (partition ID, K)
 */
private[spark] class PartitionedAppendOnlyMap[K, V]
  extends SizeTrackingAppendOnlyMap[(Int, K), V] with WritablePartitionedPairCollection[K, V] {

  def partitionedDestructiveSortedIterator(keyComparator: Option[Comparator[K]])
    : Iterator[((Int, K), V)] = {
    val comparator = keyComparator.map(partitionKeyComparator).getOrElse(partitionComparator)
    destructiveSortedIterator(comparator)
  }

  def insert(partition: Int, key: K, value: V): Unit = {
    update((partition, key), value)
  }
}


/**
 * Return an iterator of the map in sorted order. This provides a way to sort the map without
 * using additional memory, at the expense of destroying the validity of the map.
 */
def destructiveSortedIterator(keyComparator: Comparator[K]): Iterator[(K, V)] = {
  destroyed = true
  // Pack KV pairs into the front of the underlying array
  var keyIndex, newIndex = 0
  while (keyIndex < capacity) {
    if (data(2 * keyIndex) != null) {
      data(2 * newIndex) = data(2 * keyIndex)
      data(2 * newIndex + 1) = data(2 * keyIndex + 1)
      newIndex += 1
    }
    keyIndex += 1
  }
  assert(curSize == newIndex + (if (haveNullValue) 1 else 0))

  new Sorter(new KVArraySortDataFormat[K, AnyRef]).sort(data, 0, newIndex, keyComparator)

  new Iterator[(K, V)] {
    var i = 0
    var nullValueReady = haveNullValue
    def hasNext: Boolean = (i < newIndex || nullValueReady)
    def next(): (K, V) = {
      if (nullValueReady) {
        nullValueReady = false
        (null.asInstanceOf[K], nullValue)
      } else {
        val item = (data(2 * i).asInstanceOf[K], data(2 * i + 1).asInstanceOf[V])
        i += 1
        item
      }
    }
  }
}
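The compaction step is easier to follow on a tiny array. The sketch below repeats the pack-to-front loop on a flat key/value array with String keys and then sorts the live pairs. Note one deliberate simplification: it sorts through a temporary copy, whereas the method above sorts in place precisely to avoid that extra memory. CompactAndSortSketch and compactAndSort are illustrative names.

```scala
object CompactAndSortSketch {
  // `data` holds keys at even indices and values at odd indices; a null key marks an empty slot.
  // Packs the live pairs to the front, sorts them by key, and returns how many there are.
  def compactAndSort(data: Array[AnyRef], capacity: Int): Int = {
    var keyIndex, newIndex = 0
    while (keyIndex < capacity) {
      if (data(2 * keyIndex) != null) {
        data(2 * newIndex) = data(2 * keyIndex)
        data(2 * newIndex + 1) = data(2 * keyIndex + 1)
        newIndex += 1
      }
      keyIndex += 1
    }
    // Simplification: sort via a temporary copy instead of sorting the array in place.
    val pairs = (0 until newIndex).map(i => (data(2 * i).asInstanceOf[String], data(2 * i + 1)))
    for (((k, v), i) <- pairs.sortBy(_._1).zipWithIndex) {
      data(2 * i) = k
      data(2 * i + 1) = v
    }
    newIndex
  }
}
```

Slots past the returned count may still hold stale entries after the call, which is why the original marks the map as destroyed: only the first newIndex pairs are meaningful from then on.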

The key piece in there is the new Sorter call:

class Sorter[K, Buffer](private val s: SortDataFormat[K, Buffer]) {

  private val timSort = new TimSort(s)

  /**
   * Sorts the input buffer within range [lo, hi).
   */
  def sort(a: Buffer, lo: Int, hi: Int, c: Comparator[_ >: K]): Unit = {
    timSort.sort(a, lo, hi, c)
  }
}
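The [lo, hi) contract of sort can be tried out with the JDK alone. The sketch below uses java.util.Arrays.sort, which for object arrays is also TimSort-based, as a stand-in for Spark's Sorter; it is not the same class and works on a plain array rather than on a SortDataFormat buffer. RangeSortSketch and sortRange are illustrative names.

```scala
import java.util.Comparator

object RangeSortSketch {
  // Sorts only the elements from index lo (inclusive) to hi (exclusive),
  // leaving the rest of the array untouched.
  def sortRange(a: Array[String], lo: Int, hi: Int, c: Comparator[String]): Unit =
    java.util.Arrays.sort(a, lo, hi, c)
}
```

Sorting the range [0, 3) of the array d, c, b, a leaves b, c, d, a: the last element is outside the range and stays where it was, just as slots at or beyond newIndex are left alone in destructiveSortedIterator.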
