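Shuffle Read is initiated from the compute method of these RDDs. The following uses ShuffledRDD as an example to describe the Shuffle Read process.

0. Flowchart

1. Entry function

The entry point of the Shuffle Read operation is the ShuffledRDD.compute method.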

override def compute(split: Partition, context: TaskContext): Iterator[(K, C)] = {
    val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]]
    // Obtain a reader for this partition's range and read the shuffled records.
    SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context)
      .read()
      .asInstanceOf[Iterator[(K, C)]]
}




1.1. HashShuffleReader.read

override def read(): Iterator[Product2[K, C]] = {
    val ser = Serializer.getSerializer(dep.serializer)
    val iter = BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context, ser)
    val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
      if (dep.mapSideCombine) {
        new InterruptibleIterator(context, dep.aggregator.get.combineCombinersByKey(iter, context))
      } else {
        new InterruptibleIterator(context, dep.aggregator.get.combineValuesByKey(iter, context))
      }
    } else {
      require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!")
      // Convert the Product2s to pairs since this is what downstream RDDs currently expect
      iter.asInstanceOf[Iterator[Product2[K, C]]].map(pair => (pair._1, pair._2))
    }
    // Sort the output if there is a sort ordering defined.
    dep.keyOrdering match {
      case Some(keyOrd: Ordering[K]) =>
        // Create an ExternalSorter to sort the data. Note that if spark.shuffle.spill is disabled,
        // the ExternalSorter won't spill to disk.
        val sorter = new ExternalSorter[K, C, C](ordering = Some(keyOrd), serializer = Some(ser))
        sorter.insertAll(aggregatedIter)
        sorter.iterator
      case None =>
        aggregatedIter
    }
}

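(1) BlockStoreShuffleFetcher is an object with a single method, fetch, which retrieves the shuffle data for a given shuffleId and partition. fetch returns an iterator; traversing this iterator yields the corresponding data records;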


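Note: the mapSideCombine handling here differs from the method called during Shuffle Write. On the read side, records that were already combined on the map side are merged with combineCombinersByKey, whereas Shuffle Write applies combineValuesByKey to the raw values.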
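The difference between the two aggregation paths can be illustrated with plain Scala collections. The sketch below is not Spark's Aggregator: `SimpleAggregator`, `AggregatorSketch`, and the sample data are hypothetical names that only mirror the createCombiner / mergeValue / mergeCombiners idea.

```scala
// Hypothetical stand-in for an aggregator, built on plain collections.
final case class SimpleAggregator[K, V, C](
    createCombiner: V => C,
    mergeValue: (C, V) => C,
    mergeCombiners: (C, C) => C) {

  // For raw (K, V) records: no map-side combine has happened yet.
  def combineValuesByKey(iter: Iterator[(K, V)]): Map[K, C] =
    iter.foldLeft(Map.empty[K, C]) { case (acc, (k, v)) =>
      acc.updated(k, acc.get(k).map(c => mergeValue(c, v)).getOrElse(createCombiner(v)))
    }

  // For (K, C) records that were already combined on the map side.
  def combineCombinersByKey(iter: Iterator[(K, C)]): Map[K, C] =
    iter.foldLeft(Map.empty[K, C]) { case (acc, (k, c)) =>
      acc.updated(k, acc.get(k).map(prev => mergeCombiners(prev, c)).getOrElse(c))
    }
}

object AggregatorSketch {
  // Word-count style aggregator: V is a count of 1, C is a partial sum.
  val sum: SimpleAggregator[String, Int, Int] =
    SimpleAggregator[String, Int, Int](v => v, _ + _, _ + _)

  // Raw values, as a reducer sees them when mapSideCombine is false.
  val fromValues: Map[String, Int] =
    sum.combineValuesByKey(Iterator("a" -> 1, "b" -> 1, "a" -> 1))

  // Partial sums, as a reducer sees them when mapSideCombine is true.
  val fromCombiners: Map[String, Int] =
    sum.combineCombinersByKey(Iterator("a" -> 2, "b" -> 1, "a" -> 3))
}
```

Both paths end with one combined value per key; they differ only in whether the incoming records are raw values or partial results.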



2. BlockStoreShuffleFetcher

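A Shuffle Map Stage writes its output to multiple nodes. Several ShuffleMapTasks run on the same node, and each task creates its own independent blocks; the number of blocks per task is determined by the number of reducers (the number of shuffle output partitions). As a result, one reducer (one partition) may correspond to multiple blocks on a single node.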
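As a small illustration of this layout, the sketch below lists the blocks one reducer would need from a node that ran several map tasks. `BlockLayoutSketch` and its helpers are hypothetical, and the `shuffle_<shuffleId>_<mapId>_<reduceId>` naming is used here only as an assumed block-name pattern.

```scala
object BlockLayoutSketch {
  // Assumed name pattern for a shuffle block.
  def blockName(shuffleId: Int, mapId: Int, reduceId: Int): String =
    s"shuffle_${shuffleId}_${mapId}_${reduceId}"

  // Blocks one reducer must read from a node that ran the given map tasks:
  // one block per map task, all sharing the same reduceId.
  def blocksForReducer(shuffleId: Int, mapIdsOnNode: Seq[Int], reduceId: Int): Seq[String] =
    mapIdsOnNode.map(mapId => blockName(shuffleId, mapId, reduceId))
}
```

For example, if map tasks 0, 2, and 5 ran on the same node, `BlockLayoutSketch.blocksForReducer(0, Seq(0, 2, 5), 1)` yields three block names for reducer 1 on that node.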


def fetch[T](
      shuffleId: Int,
      reduceId: Int,
      context: TaskContext,
      serializer: Serializer)
    : Iterator[T] =
{
    logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId))
    val blockManager = SparkEnv.get.blockManager
    val startTime = System.currentTimeMillis
    // Ask the MapOutputTracker where the map outputs for this reducer live.
    val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, reduceId)
    // The rest of the method body is covered in sections 2.2 to 2.4.


2.1. MapOutputTracker.getServerStatuses


protected val mapStatuses: Map[Int, Array[MapStatus]]





def getServerStatuses(shuffleId: Int, reduceId: Int): Array[(BlockManagerId, Long)] = {
    val statuses = mapStatuses.get(shuffleId).orNull
    if (statuses == null) {
      logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them")
      var fetchedStatuses: Array[MapStatus] = null
      fetching.synchronized {
        // Someone else is fetching it; wait for them to be done
        while (fetching.contains(shuffleId)) {
          try {
            fetching.wait()
          } catch {
            case e: InterruptedException =>
          }
        }
        // Either while we waited the fetch happened successfully, or
        // someone fetched it in between the get and the fetching.synchronized.
        fetchedStatuses = mapStatuses.get(shuffleId).orNull
        if (fetchedStatuses == null) {
          // We have to do the fetch, get others to wait for us.
          fetching += shuffleId
        }
      }
      if (fetchedStatuses == null) {
        // We won the race to fetch the output locs; do so
        logInfo("Doing the fetch; tracker actor = " + trackerActor)
        // This try-finally prevents hangs due to timeouts:
        try {
          val fetchedBytes =
            askTracker(GetMapOutputStatuses(shuffleId)).asInstanceOf[Array[Byte]]
          fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes)
          logInfo("Got the output locations")
          mapStatuses.put(shuffleId, fetchedStatuses)
        } finally {
          fetching.synchronized {
            fetching -= shuffleId
            fetching.notifyAll()
          }
        }
      }
      if (fetchedStatuses != null) {
        fetchedStatuses.synchronized {
          return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, fetchedStatuses)
        }
      } else {
        logError("Missing all output locations for shuffle " + shuffleId)
        throw new MetadataFetchFailedException(
          shuffleId, reduceId, "Missing all output locations for shuffle " + shuffleId)
      }
    } else {
      statuses.synchronized {
        return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, statuses)
      }
    }
}






(6) Call MapOutputTracker.convertMapStatuses to convert the fetched MapStatus objects into (BlockManagerId, BlockSize) pairs; one BlockManagerId may correspond to multiple BlockSize values.
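The "only one thread fetches, the others wait" coordination in getServerStatuses can be shown in isolation. The following is a minimal sketch under stated assumptions: `FetchOnceCache`, `FetchOnceDemo`, and `slowFetch` are hypothetical names, and a sleeping function stands in for the request sent to the tracker.

```scala
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicInteger
import scala.collection.mutable

// Hypothetical cache in which one thread performs the slow fetch for a key
// while other threads asking for the same key wait for its result.
final class FetchOnceCache[V <: AnyRef](slowFetch: Int => V) {
  private val cache = new ConcurrentHashMap[Int, V]()
  private val fetching = mutable.HashSet.empty[Int]

  def get(id: Int): V = {
    val cached = cache.get(id)
    if (cached != null) return cached

    var result: V = null.asInstanceOf[V]
    fetching.synchronized {
      // Someone else is fetching this id; wait until they finish.
      while (fetching.contains(id)) fetching.wait()
      // Re-check: the cache may have been filled while we waited.
      result = cache.get(id)
      if (result == null) fetching += id // we won the race
    }

    if (result == null) {
      try {
        result = slowFetch(id)
        cache.put(id, result)
      } finally {
        // Always wake the waiters, even if the fetch threw.
        fetching.synchronized {
          fetching -= id
          fetching.notifyAll()
        }
      }
    }
    result
  }
}

object FetchOnceDemo {
  // Returns (value seen by each thread, number of times slowFetch ran).
  def demo(): (Seq[String], Int) = {
    val calls = new AtomicInteger(0)
    val cache = new FetchOnceCache[String](id => {
      calls.incrementAndGet()
      Thread.sleep(50) // simulate a slow remote call
      s"statuses-for-shuffle-$id"
    })
    val out = new Array[String](4)
    val threads = out.indices.map { i =>
      new Thread(new Runnable { def run(): Unit = { out(i) = cache.get(7) } })
    }
    threads.foreach(_.start())
    threads.foreach(_.join())
    (out.toSeq, calls.get())
  }
}
```

Because the value is stored in the cache before the id is removed from `fetching`, every thread that wakes up or arrives later finds the cached value and does not repeat the fetch.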

2.1.1. MapOutputTrackerMasterActor handles the GetMapOutputStatuses message

case GetMapOutputStatuses(shuffleId: Int) =>
      val hostPort = sender.path.address.hostPort
      logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + hostPort)
      val mapOutputStatuses = tracker.getSerializedMapOutputStatuses(shuffleId)
      val serializedSize = mapOutputStatuses.size
      if (serializedSize > maxAkkaFrameSize) {
        val msg = s"Map output statuses were $serializedSize bytes which " +
          s"exceeds spark.akka.frameSize ($maxAkkaFrameSize bytes)."
        /* For SPARK-1244 we'll opt for just logging an error and then throwing an exception.
         * Note that on exception the actor will just restart. A bigger refactoring (SPARK-1239)
         * will ultimately remove this entire code path. */
        val exception = new SparkException(msg)
        logError(msg, exception)
        throw exception
      }
      // Reply to the requesting executor with the serialized statuses.
      sender ! mapOutputStatuses



2.1.2. MapOutputTrackerMaster.getSerializedMapOutputStatuses

def getSerializedMapOutputStatuses(shuffleId: Int): Array[Byte] = {
    var statuses: Array[MapStatus] = null
    var epochGotten: Long = -1
    epochLock.synchronized {
      if (epoch > cacheEpoch) {
        // The epoch advanced, so previously cached bytes are stale.
        cachedSerializedStatuses.clear()
        cacheEpoch = epoch
      }
      cachedSerializedStatuses.get(shuffleId) match {
        case Some(bytes) =>
          return bytes
        case None =>
          statuses = mapStatuses.getOrElse(shuffleId, Array[MapStatus]())
          epochGotten = epoch
      }
    }
    // If we got here, we failed to find the serialized locations in the cache, so we pulled
    // out a snapshot of the locations as "statuses"; let's serialize and return that
    val bytes = MapOutputTracker.serializeMapStatuses(statuses)
    logInfo("Size of output statuses for shuffle %d is %d bytes".format(shuffleId, bytes.length))
    // Add them into the table only if the epoch hasn't changed while we were working
    epochLock.synchronized {
      if (epoch == epochGotten) {
        cachedSerializedStatuses(shuffleId) = bytes
      }
    }
    bytes
}
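The epoch check in this method is a general cache-invalidation pattern: cached entries are trusted only while the epoch they were computed in is still current. A minimal sketch of the idea follows; `EpochCache`, `incrementEpoch`, and `compute` are hypothetical names, not part of MapOutputTrackerMaster.

```scala
import scala.collection.mutable

// Hypothetical epoch-guarded cache: bumping the epoch invalidates all entries.
final class EpochCache[V](compute: Int => V) {
  private val lock = new Object
  private var epoch = 0L
  private var cacheEpoch = 0L
  private val cached = mutable.HashMap.empty[Int, V]

  def incrementEpoch(): Unit = lock.synchronized { epoch += 1 }

  def get(id: Int): V = {
    var epochGotten = -1L
    val hit: Option[V] = lock.synchronized {
      if (epoch > cacheEpoch) { // stale cache: drop everything
        cached.clear()
        cacheEpoch = epoch
      }
      val found = cached.get(id)
      if (found.isEmpty) epochGotten = epoch
      found
    }
    hit match {
      case Some(v) => v
      case None =>
        // Compute outside the lock so slow work does not block other callers.
        val value = compute(id)
        lock.synchronized {
          // Store only if the epoch did not change while we were computing.
          if (epoch == epochGotten) cached(id) = value
        }
        value
    }
  }
}
```

Calling `get` twice for the same id computes once; after `incrementEpoch()`, the next `get` recomputes because the old entry is discarded.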




2.1.3. MapOutputTracker.convertMapStatuses

private def convertMapStatuses(
      shuffleId: Int,
      reduceId: Int,
      statuses: Array[MapStatus]): Array[(BlockManagerId, Long)] = {
    assert (statuses != null)
    // One (location, size) pair per map output, in map-task order.
    statuses.map {
      status =>
        if (status == null) {
          logError("Missing an output location for shuffle " + shuffleId)
          throw new MetadataFetchFailedException(
            shuffleId, reduceId, "Missing an output location for shuffle " + shuffleId)
        } else {
          (status.location, status.getSizeForBlock(reduceId))
        }
    }
}

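The result is an array of (BlockManagerId, BlockSize) pairs, one per map output. One BlockManagerId may therefore correspond to multiple BlockSize values; in other words, the same BlockManagerId can appear several times in the array.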



2.2. Building the ShuffleBlockId mapping


val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, reduceId)
    logDebug("Fetching map output location for shuffle %d, reduce %d took %d ms".format(
      shuffleId, reduceId, System.currentTimeMillis - startTime))
    // Group the (index, size) entries by the BlockManagerId that holds them.
    val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[(Int, Long)]]
    for (((address, size), index) <- statuses.zipWithIndex) {
      splitsByAddress.getOrElseUpdate(address, ArrayBuffer()) += ((index, size))
    }
    val blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])] = splitsByAddress.toSeq.map {
      case (address, splits) =>
        (address, splits.map(s => (ShuffleBlockId(shuffleId, s._1, reduceId), s._2)))
    }

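(1) statuses has type Array[(BlockManagerId, Long)], where BlockManagerId is the location of the block and Long is the size of the block;

(2) The for loop converts statuses into a [BlockManagerId, ArrayBuffer[(Int, Long)]] structure, which expresses that, on a given BlockManagerId, one reduce corresponds to multiple blocks. Int is the index into statuses (used as the map ID when the ShuffleBlockId is built) and Long is the size of the block;


This completes the mapping from BlockManagerId to Seq[(BlockId, Long)]; BlockId is a ShuffleBlockId and Long is the size of the corresponding block.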
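The grouping step can be tried out without Spark. In the sketch below, plain strings stand in for BlockManagerId and ShuffleBlockId; `BlocksByAddressSketch`, `group`, and the `shuffle_<shuffleId>_<mapId>_<reduceId>` name format are assumptions made for illustration.

```scala
import scala.collection.mutable.{ArrayBuffer, HashMap}

object BlocksByAddressSketch {
  // statuses holds one (address, blockSize) pair per map task;
  // the array index plays the role of the map id.
  def group(
      shuffleId: Int,
      reduceId: Int,
      statuses: Array[(String, Long)]): Map[String, Seq[(String, Long)]] = {
    val splitsByAddress = new HashMap[String, ArrayBuffer[(Int, Long)]]
    for (((address, size), mapId) <- statuses.zipWithIndex) {
      splitsByAddress.getOrElseUpdate(address, ArrayBuffer()) += ((mapId, size))
    }
    // Turn each (mapId, size) into a (blockName, size) pair for this reducer.
    splitsByAddress.map { case (address, splits) =>
      address -> splits.map { case (mapId, size) =>
        (s"shuffle_${shuffleId}_${mapId}_${reduceId}", size)
      }.toList
    }.toMap
  }
}
```

With statuses `Array(("nodeA", 10L), ("nodeB", 20L), ("nodeA", 30L))`, the entry for nodeA contains two blocks (map ids 0 and 2) and the entry for nodeB contains one (map id 1).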

2.3. Creating the ShuffleBlockFetcherIterator object


val blockFetcherItr = new ShuffleBlockFetcherIterator(
      SparkEnv.get.conf.getLong("spark.reducer.maxMbInFlight", 48) * 1024 * 1024)


2.3.1. ShuffleBlockFetcherIterator.initialize

private[this] def initialize(): Unit = {
    // Add a task completion callback (called in both success case and failure case) to cleanup.
    context.addTaskCompletionListener(_ => cleanup())
    // Split local and remote blocks.
    val remoteRequests = splitLocalRemoteBlocks()
    // Add the remote requests into our queue in a random order
    fetchRequests ++= Utils.randomize(remoteRequests)
    // Send out initial requests for blocks, up to our maxBytesInFlight
    while (fetchRequests.nonEmpty &&
      (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) {
      sendRequest(fetchRequests.dequeue())
    }
    val numFetches = remoteRequests.size - fetchRequests.size
    logInfo("Started " + numFetches + " remote fetches in" + Utils.getUsedTimeMs(startTime))
    // Get Local Blocks
    fetchLocalBlocks()
    logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime))
}

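(1) The splitLocalRemoteBlocks method wraps each remote [BlockManagerId, Seq[(BlockId, Long)]] entry into one or more FetchRequest objects; the number of objects is governed by the sum of the Long values (the block sizes) and the maxBytesInFlight parameter;

(2) The FetchRequest array returned by splitLocalRemoteBlocks is shuffled into random order and added to the fetchRequests queue;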
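How block sizes and maxBytesInFlight together determine the number of requests can be sketched as follows. This is a simplified illustration, not the actual splitLocalRemoteBlocks code: `FetchRequestSketch` and `split` are hypothetical, addresses and block ids are plain strings, and the choice of one fifth of maxBytesInFlight as the per-request target is an assumption of this sketch.

```scala
import scala.collection.mutable.ArrayBuffer

object FetchRequestSketch {
  final case class FetchRequest(address: String, blocks: Seq[(String, Long)]) {
    val size: Long = blocks.map(_._2).sum
  }

  // Cut one address's blocks into requests of roughly targetSize bytes each.
  def split(
      address: String,
      blocks: Seq[(String, Long)],
      maxBytesInFlight: Long): Seq[FetchRequest] = {
    // Assumption: aim for about a fifth of the in-flight budget per request,
    // so that several requests can be outstanding at the same time.
    val targetSize = math.max(maxBytesInFlight / 5, 1L)
    val requests = ArrayBuffer.empty[FetchRequest]
    var current = ArrayBuffer.empty[(String, Long)]
    var currentSize = 0L
    for ((id, size) <- blocks) {
      current += ((id, size))
      currentSize += size
      if (currentSize >= targetSize) {
        // The accumulated blocks reached the target: close this request.
        requests += FetchRequest(address, current.toList)
        current = ArrayBuffer.empty[(String, Long)]
        currentSize = 0L
      }
    }
    // Whatever is left over becomes a final, smaller request.
    if (current.nonEmpty) requests += FetchRequest(address, current.toList)
    requests.toList
  }
}
```

With three 10-byte blocks and maxBytesInFlight = 100, the target is 20 bytes, so the first two blocks form one request and the third block forms a second one.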


2.3.2. ShuffleBlockFetcherIterator.sendRequest

private[this] def sendRequest(req: FetchRequest) {
    logDebug("Sending request for %d blocks (%s) from %s".format(
      req.blocks.size, Utils.bytesToString(req.size), req.address.hostPort))
    bytesInFlight += req.size
    // so we can look up the size of each blockID
    val sizeMap = req.blocks.map { case (blockId, size) => (blockId.toString, size) }.toMap
    val blockIds = req.blocks.map(_._1.toString)
    val address = req.address
    shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray,
      new BlockFetchingListener {
        override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit = {
          // Only add the buffer to results queue if the iterator is not zombie,
          // i.e. cleanup() has not been called yet.
          if (!isZombie) {
            // Increment the ref count because we need to pass this to a different thread.
            // This needs to be released after use.
            buf.retain()
            results.put(new SuccessFetchResult(BlockId(blockId), sizeMap(blockId), buf))
          }
          logTrace("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
        }
        override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = {
          logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e)
          results.put(new FailureFetchResult(BlockId(blockId), e))
        }
      }
    )
}

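This method reads remote blocks. It uses a ShuffleClient, whose concrete implementation is NettyBlockTransferService, and calls its fetchBlocks method to read the blocks. When a read succeeds, NettyBlockTransferService invokes the onBlockFetchSuccess callback, which wraps the result in a SuccessFetchResult object and puts it on the results queue.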
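The callback-plus-queue handoff described here can be reproduced with a blocking queue from the JDK. The sketch below is only an illustration of the pattern: `ResultsQueueSketch`, `Listener`, `fetchAsync`, and the rule that ids starting with "bad" fail are all hypothetical, and a background thread stands in for the network client.

```scala
import java.util.concurrent.LinkedBlockingQueue

object ResultsQueueSketch {
  sealed trait FetchResult
  final case class FetchSuccess(blockId: String, data: Array[Byte]) extends FetchResult
  final case class FetchFailure(blockId: String, error: Throwable) extends FetchResult

  // Hypothetical listener mirroring a success/failure callback pair.
  trait Listener {
    def onSuccess(blockId: String, data: Array[Byte]): Unit
    def onFailure(blockId: String, e: Throwable): Unit
  }

  // Hypothetical client: "fetches" on a background thread, then calls back.
  def fetchAsync(blockIds: Seq[String], listener: Listener): Unit = {
    val worker = new Thread(new Runnable {
      def run(): Unit = blockIds.foreach { id =>
        if (id.startsWith("bad")) listener.onFailure(id, new RuntimeException(s"missing $id"))
        else listener.onSuccess(id, id.getBytes("UTF-8"))
      }
    })
    worker.start()
  }

  def fetchAll(blockIds: List[String]): List[FetchResult] = {
    val results = new LinkedBlockingQueue[FetchResult]()
    fetchAsync(blockIds, new Listener {
      def onSuccess(blockId: String, data: Array[Byte]): Unit =
        results.put(FetchSuccess(blockId, data))
      def onFailure(blockId: String, e: Throwable): Unit =
        results.put(FetchFailure(blockId, e))
    })
    // The consumer blocks in take() until each result has been delivered.
    blockIds.map(_ => results.take())
  }
}
```

The producer side never hands data to the consumer directly; both successes and failures travel through the same queue, so the consuming iterator only needs to call `take()` and inspect the result type.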

2.3.3. ShuffleBlockFetcherIterator.fetchLocalBlocks

private[this] def fetchLocalBlocks() {
    val iter = localBlocks.iterator
    while (iter.hasNext) {
      val blockId = iter.next()
      try {
        // Local blocks are read directly from the local BlockManager.
        val buf = blockManager.getBlockData(blockId)
        results.put(new SuccessFetchResult(blockId, 0, buf))
      } catch {
        case e: Exception =>
          // If we see an exception, stop immediately.
          logError(s"Error occurred while fetching local blocks", e)
          results.put(new FailureFetchResult(blockId, e))
          return
      }
    }
}


2.4. Returning the iterator


val itr = blockFetcherItr.flatMap(unpackBlock)
    val completionIter = CompletionIterator[T, Iterator[T]](itr, {
    new InterruptibleIterator[T](context, completionIter) {
      val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency()
      override def next(): T = {