Shuffle Read is triggered from the compute methods of these RDDs. The following walks through the Shuffle Read process using ShuffledRDD as the example.
0. Flowchart
1. Entry Point
The entry point of the Shuffle Read operation is the ShuffledRDD.compute method.
override def compute(split: Partition, context: TaskContext): Iterator[(K, C)] = {
  val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]]
  SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context)
    .read()
    .asInstanceOf[Iterator[(K, C)]]
}
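(1) Obtain the ShuffleManager through SparkEnv. It has two implementations, HashShuffleManager and SortShuffleManager, and the getReader method of both returns a HashShuffleReader;
(2) Call the read method of HashShuffleReader;
(3) compute returns an iterator, so the user-supplied operations are only executed when an action or a persistence operation actually consumes it.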
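To make point (3) concrete, here is a minimal sketch in plain Scala with no Spark dependency (LazyIteratorDemo and its counter are hypothetical names). It shows that mapping over an Iterator does no work until something drains it, which is the same laziness compute relies on.

```scala
import java.util.concurrent.atomic.AtomicInteger

object LazyIteratorDemo {
  // Counts how many times the "user function" has really been applied.
  val calls = new AtomicInteger(0)

  // Plays the role of compute(): it only describes the work and returns an iterator.
  def compute(data: Seq[Int]): Iterator[Int] =
    data.iterator.map { x => calls.incrementAndGet(); x * 2 }

  def main(args: Array[String]): Unit = {
    val it = compute(Seq(1, 2, 3))
    println(s"after compute: ${calls.get}")   // nothing has run yet
    val out = it.toList                       // the "action" drains the iterator
    println(s"after action: ${calls.get}, result = $out")
  }
}
```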
1.1. HashShuffleReader.read
override def read(): Iterator[Product2[K, C]] = {
  val ser = Serializer.getSerializer(dep.serializer)
  val iter = BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context, ser)
  val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
    if (dep.mapSideCombine) {
      new InterruptibleIterator(context, dep.aggregator.get.combineCombinersByKey(iter, context))
    } else {
      new InterruptibleIterator(context, dep.aggregator.get.combineValuesByKey(iter, context))
    }
  } else {
    require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!")
    // Convert the Product2s to pairs since this is what downstream RDDs currently expect
    iter.asInstanceOf[Iterator[Product2[K, C]]].map(pair => (pair._1, pair._2))
  }
  // Sort the output if there is a sort ordering defined.
  dep.keyOrdering match {
    case Some(keyOrd: Ordering[K]) =>
      // Create an ExternalSorter to sort the data. Note that if spark.shuffle.spill is disabled,
      // the ExternalSorter won't spill to disk.
      val sorter = new ExternalSorter[K, C, C](ordering = Some(keyOrd), serializer = Some(ser))
      sorter.insertAll(aggregatedIter)
      context.taskMetrics.incMemoryBytesSpilled(sorter.memoryBytesSpilled)
      context.taskMetrics.incDiskBytesSpilled(sorter.diskBytesSpilled)
      sorter.iterator
    case None =>
      aggregatedIter
  }
}
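(1) BlockStoreShuffleFetcher is an object with a single method, fetch, which retrieves the shuffle data for a given shuffleId and partition. fetch returns an iterator; traversing this iterator yields the corresponding records;
(2) The rest of the method wraps that iterator in different iterators depending on the dependency's settings, e.g. whether aggregation or sorting is required.
Note: when mapSideCombine is set, the method called here differs from the one called during Shuffle Write:
on write: combineValuesByKey;
on read: combineCombinersByKey.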
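The difference between the two calls can be illustrated with a small in-memory stand-in. SimpleAggregator below is hypothetical: it only borrows the three function names of Spark's Aggregator (createCombiner, mergeValue, mergeCombiners) and, unlike Spark, never spills to disk. combineValuesByKey folds raw values into combiners, while combineCombinersByKey merges combiners that were already built on the map side.

```scala
import scala.collection.mutable

// Hypothetical, in-memory stand-in for Spark's Aggregator (no spilling to disk).
case class SimpleAggregator[K, V, C](
    createCombiner: V => C,
    mergeValue: (C, V) => C,
    mergeCombiners: (C, C) => C) {

  // Raw values -> combiners (map side, or reduce side without mapSideCombine).
  def combineValuesByKey(iter: Iterator[(K, V)]): Iterator[(K, C)] = {
    val acc = mutable.LinkedHashMap.empty[K, C]
    iter.foreach { case (k, v) =>
      val updated = acc.get(k) match {
        case Some(c) => mergeValue(c, v)
        case None    => createCombiner(v)
      }
      acc(k) = updated
    }
    acc.iterator
  }

  // Partial combiners -> final combiners (reduce side with mapSideCombine).
  def combineCombinersByKey(iter: Iterator[(K, C)]): Iterator[(K, C)] = {
    val acc = mutable.LinkedHashMap.empty[K, C]
    iter.foreach { case (k, c) =>
      val updated = acc.get(k) match {
        case Some(existing) => mergeCombiners(existing, c)
        case None           => c
      }
      acc(k) = updated
    }
    acc.iterator
  }
}
```

For a word count, the map side turns ("a", 1), ("a", 1) into ("a", 2); the reduce side then only has to add these partial sums with mergeCombiners, never seeing the raw values.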
2. BlockStoreShuffleFetcher
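A Shuffle Map Stage writes its output to multiple nodes. Several ShuffleMapTasks can run on the same node, and each task creates its own independent set of blocks, one per reducer (i.e. the number of shuffle output partitions). As a result, a single reducer (one partition) may correspond to several blocks on one node.
[Figure: relationship between Map and Reduce tasks]
Each reducer depends on every map, and every map writes one piece of output for each reducer. Put differently, a reducer has to read as many blocks as there are maps.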
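A quick way to see the "one block per (map, reduce) pair" relationship is to enumerate the blocks a single reducer has to read. This is an illustrative sketch: ShuffleBlockLayout is hypothetical, and its ShuffleBlockId only mirrors the shuffle_<shuffleId>_<mapId>_<reduceId> naming of Spark's ShuffleBlockId.

```scala
object ShuffleBlockLayout {
  // Hypothetical mirror of the shuffle_<shuffleId>_<mapId>_<reduceId> naming.
  case class ShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int) {
    def name: String = s"shuffle_${shuffleId}_${mapId}_${reduceId}"
  }

  // Each map task writes one block per reduce partition, so reducer reduceId
  // reads exactly one block from each of the numMaps map outputs.
  def blocksForReducer(shuffleId: Int, numMaps: Int, reduceId: Int): Seq[ShuffleBlockId] =
    (0 until numMaps).map(mapId => ShuffleBlockId(shuffleId, mapId, reduceId))

  // Total blocks written by the shuffle: one per (map, reduce) pair.
  def totalBlocks(numMaps: Int, numReduces: Int): Int = numMaps * numReduces
}
```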
First, MapOutputTracker.getServerStatuses is called to find the nodes that hold the blocks for this reducer, together with the size of each block.
def fetch[T](
    shuffleId: Int,
    reduceId: Int,
    context: TaskContext,
    serializer: Serializer)
  : Iterator[T] =
{
  logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId))
  val blockManager = SparkEnv.get.blockManager
  val startTime = System.currentTimeMillis
  val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, reduceId)
  ......
}
This calls the getServerStatuses method of MapOutputTracker.
2.1. MapOutputTracker.getServerStatuses
MapOutputTracker defines the following data structure:
protected val mapStatuses: Map[Int, Array[MapStatus]]
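mapStatuses behaves differently on the Driver and on the Executors:
(1) On the Driver, it records the map output of every ShuffleMapTask;
(2) On an Executor, it serves only as a cache; if the requested entry is missing, it is fetched from the Driver.
The following describes the case where the cache misses and the statuses are fetched from the Driver.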
def getServerStatuses(shuffleId: Int, reduceId: Int): Array[(BlockManagerId, Long)] = {
  val statuses = mapStatuses.get(shuffleId).orNull
  if (statuses == null) {
    logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them")
    var fetchedStatuses: Array[MapStatus] = null
    fetching.synchronized {
      // Someone else is fetching it; wait for them to be done
      while (fetching.contains(shuffleId)) {
        try {
          fetching.wait()
        } catch {
          case e: InterruptedException =>
        }
      }
      // Either while we waited the fetch happened successfully, or
      // someone fetched it in between the get and the fetching.synchronized.
      fetchedStatuses = mapStatuses.get(shuffleId).orNull
      if (fetchedStatuses == null) {
        // We have to do the fetch, get others to wait for us.
        fetching += shuffleId
      }
    }
    if (fetchedStatuses == null) {
      // We won the race to fetch the output locs; do so
      logInfo("Doing the fetch; tracker actor = " + trackerActor)
      // This try-finally prevents hangs due to timeouts:
      try {
        val fetchedBytes =
          askTracker(GetMapOutputStatuses(shuffleId)).asInstanceOf[Array[Byte]]
        fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes)
        logInfo("Got the output locations")
        mapStatuses.put(shuffleId, fetchedStatuses)
      } finally {
        fetching.synchronized {
          fetching -= shuffleId
          fetching.notifyAll()
        }
      }
    }
    if (fetchedStatuses != null) {
      fetchedStatuses.synchronized {
        return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, fetchedStatuses)
      }
    } else {
      logError("Missing all output locations for shuffle " + shuffleId)
      throw new MetadataFetchFailedException(
        shuffleId, reduceId, "Missing all output locations for shuffle " + shuffleId)
    }
  } else {
    statuses.synchronized {
      return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, statuses)
    }
  }
}
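(1) fetching records the shuffleIds currently being fetched. If another thread is already fetching this shuffleId, the current thread waits; otherwise it adds the shuffleId to fetching;
(2) If fetchedStatuses is still null, this thread performs the fetch itself;
(3) askTracker sends a GetMapOutputStatuses message to MapOutputTrackerMasterActor and returns the serialized MapStatus information;
(4) The returned bytes are deserialized into an array of MapStatus objects;
(5) mapStatuses.put stores the MapStatus array in the mapStatuses cache;
(6) MapOutputTracker.convertMapStatuses converts the MapStatus array into (BlockManagerId, BlockSize) tuples; one BlockManagerId may correspond to several BlockSize values.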
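Steps (1) and (2) implement a "single-flight" fetch: one thread per shuffleId goes to the Driver while the others wait on the fetching set. The sketch below reproduces that wait/notifyAll pattern in isolation; StatusCache and the fetchFromDriver function are hypothetical, and for simplicity the cache is guarded by the same lock as the fetching set.

```scala
import scala.collection.mutable

// Hypothetical cache in which at most one thread fetches a given shuffleId at a time.
class StatusCache(fetchFromDriver: Int => Array[Long]) {
  private val cache = mutable.HashMap.empty[Int, Array[Long]]
  private val fetching = mutable.HashSet.empty[Int]

  def get(shuffleId: Int): Array[Long] = {
    var result: Option[Array[Long]] = None
    fetching.synchronized {
      // Another thread is already fetching this shuffle: wait until it is done.
      while (fetching.contains(shuffleId)) fetching.wait()
      result = cache.get(shuffleId)
      // Nobody has it yet: claim the fetch so that other threads wait for us.
      if (result.isEmpty) fetching += shuffleId
    }
    result match {
      case Some(r) => r
      case None =>
        try {
          val fetched = fetchFromDriver(shuffleId)
          fetching.synchronized { cache(shuffleId) = fetched }
          fetched
        } finally {
          // Always release the claim and wake waiters, even if the fetch failed.
          fetching.synchronized {
            fetching -= shuffleId
            fetching.notifyAll()
          }
        }
    }
  }
}
```

The try/finally matters for the same reason as in the original: if the fetch throws, waiting threads must still be woken, otherwise they would block forever.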
2.1.1. MapOutputTrackerMasterActor handles the GetMapOutputStatuses message
case GetMapOutputStatuses(shuffleId: Int) =>
  val hostPort = sender.path.address.hostPort
  logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + hostPort)
  val mapOutputStatuses = tracker.getSerializedMapOutputStatuses(shuffleId)
  val serializedSize = mapOutputStatuses.size
  if (serializedSize > maxAkkaFrameSize) {
    val msg = s"Map output statuses were $serializedSize bytes which " +
      s"exceeds spark.akka.frameSize ($maxAkkaFrameSize bytes)."
    /* For SPARK-1244 we'll opt for just logging an error and then throwing an exception.
     * Note that on exception the actor will just restart. A bigger refactoring (SPARK-1239)
     * will ultimately remove this entire code path. */
    val exception = new SparkException(msg)
    logError(msg, exception)
    throw exception
  }
  sender ! mapOutputStatuses
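(1) Call MapOutputTrackerMaster.getSerializedMapOutputStatuses to obtain the serialized MapStatus data for the shuffleId; if its size exceeds spark.akka.frameSize, an exception is thrown;
(2) Otherwise, send the serialized MapStatus data back to the requester.
2.1.2. MapOutputTrackerMaster.getSerializedMapOutputStatuses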
def getSerializedMapOutputStatuses(shuffleId: Int): Array[Byte] = {
  var statuses: Array[MapStatus] = null
  var epochGotten: Long = -1
  epochLock.synchronized {
    if (epoch > cacheEpoch) {
      cachedSerializedStatuses.clear()
      cacheEpoch = epoch
    }
    cachedSerializedStatuses.get(shuffleId) match {
      case Some(bytes) =>
        return bytes
      case None =>
        statuses = mapStatuses.getOrElse(shuffleId, Array[MapStatus]())
        epochGotten = epoch
    }
  }
  // If we got here, we failed to find the serialized locations in the cache, so we pulled
  // out a snapshot of the locations as "statuses"; let's serialize and return that
  val bytes = MapOutputTracker.serializeMapStatuses(statuses)
  logInfo("Size of output statuses for shuffle %d is %d bytes".format(shuffleId, bytes.length))
  // Add them into the table only if the epoch hasn't changed while we were working
  epochLock.synchronized {
    if (epoch == epochGotten) {
      cachedSerializedStatuses(shuffleId) = bytes
    }
  }
  bytes
}
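(1) Check whether the cache is stale (the epoch has advanced); if so, clear it;
(2) Look up the serialized bytes in the cache: on a hit, return them directly; otherwise take the statuses from mapStatuses;
(3) Serialize the MapStatus array and, if the epoch has not changed in the meantime, store the bytes in the cache.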
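The epoch check is a general pattern: cache an expensive serialization, throw the whole cache away when a version counter advances, and refuse to store a result computed against an older version. A minimal sketch follows; EpochCache, serialize and the serializations counter are hypothetical names, not Spark API.

```scala
import scala.collection.mutable

// Hypothetical epoch-invalidated cache of serialized results.
class EpochCache(serialize: Int => Array[Byte]) {
  private val lock = new Object
  private var epoch = 0L
  private var cacheEpoch = 0L
  private val cached = mutable.HashMap.empty[Int, Array[Byte]]
  // Number of times serialize() actually ran (for observation only).
  var serializations = 0

  // Called when map outputs change; invalidates everything cached so far.
  def incrementEpoch(): Unit = lock.synchronized { epoch += 1 }

  def get(shuffleId: Int): Array[Byte] = {
    val (hit, epochGotten) = lock.synchronized {
      // A newer epoch means every cached entry may be stale: drop them all.
      if (epoch > cacheEpoch) { cached.clear(); cacheEpoch = epoch }
      (cached.get(shuffleId), epoch)
    }
    hit.getOrElse {
      // The expensive work happens outside the lock.
      val bytes = serialize(shuffleId)
      lock.synchronized {
        serializations += 1
        // Only cache the result if no epoch change happened while serializing.
        if (epoch == epochGotten) cached(shuffleId) = bytes
      }
      bytes
    }
  }
}
```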
2.1.3. MapOutputTracker.convertMapStatuses
private def convertMapStatuses(
    shuffleId: Int,
    reduceId: Int,
    statuses: Array[MapStatus]): Array[(BlockManagerId, Long)] = {
  assert (statuses != null)
  statuses.map {
    status =>
      if (status == null) {
        logError("Missing an output location for shuffle " + shuffleId)
        throw new MetadataFetchFailedException(
          shuffleId, reduceId, "Missing an output location for shuffle " + shuffleId)
      } else {
        (status.location, status.getSizeForBlock(reduceId))
      }
  }
}
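The result is an array of (BlockManagerId, BlockSize) tuples, so one BlockManagerId may correspond to several BlockSize values; that is, the same BlockManagerId can appear multiple times in the array.
Note: BlockSize is not the exact size of the block. MapStatus has two implementations, CompressedMapStatus and HighlyCompressedMapStatus.
CompressedMapStatus stores each block size in a lossy compressed form, so the original value cannot be recovered exactly;
HighlyCompressedMapStatus keeps only the average size of the non-empty blocks (plus which blocks are empty), so every non-empty block is reported with that same average size.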
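To see why a reported size is only approximate, consider a log-scale encoding that squeezes a size into one byte. The sketch below is modeled on the log-base-1.1 scheme that, to our understanding, CompressedMapStatus uses; treat SizeCodec and its constants as assumptions rather than Spark's actual code. Decoding rounds up, so a 1000-byte block comes back as somewhat more than 1000 bytes, within roughly 10%.

```scala
object SizeCodec {
  private val LogBase = 1.1

  // Encode a byte count into one unsigned byte on a log scale (lossy, rounds up).
  def compress(size: Long): Byte =
    if (size == 0L) 0.toByte
    else if (size <= 1L) 1.toByte
    else math.min(255, math.ceil(math.log(size.toDouble) / math.log(LogBase)).toInt).toByte

  // Decode back to an approximate byte count; sizes above 1.1^255 are capped.
  def decompress(compressed: Byte): Long =
    if (compressed == 0) 0L
    else math.pow(LogBase, (compressed & 0xFF).toDouble).toLong
}
```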
2.2. Building the ShuffleBlockId Mapping
Once the locations and sizes of all blocks for the reducer are known, BlockStoreShuffleFetcher.fetch builds the ShuffleBlockId mapping.
val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, reduceId)
logDebug("Fetching map output location for shuffle %d, reduce %d took %d ms".format(
  shuffleId, reduceId, System.currentTimeMillis - startTime))
val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[(Int, Long)]]
for (((address, size), index) <- statuses.zipWithIndex) {
  splitsByAddress.getOrElseUpdate(address, ArrayBuffer()) += ((index, size))
}
val blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])] = splitsByAddress.toSeq.map {
  case (address, splits) =>
    (address, splits.map(s => (ShuffleBlockId(shuffleId, s._1, reduceId), s._2)))
}
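(1) statuses has type Array[(BlockManagerId, Long)], where BlockManagerId is the location of a block and Long is the block size;
(2) The for loop turns statuses into a HashMap[BlockManagerId, ArrayBuffer[(Int, Long)]], which records, for each BlockManagerId, the multiple blocks this reducer needs from it; Int is the index into statuses and Long is the block size;
(3) Build the mapping from BlockManagerId to ShuffleBlockIds. The entries in statuses are ordered by map id (i.e. partition number) in ascending order (see the call to Stage.addOutputLoc in DAGScheduler.handleTaskCompletion, and MapOutputTracker.registerMapOutputs), so an entry's index is its map partition number, and the index saved by the for loop can be used to construct each ShuffleBlockId.
This completes the mapping from BlockManagerId to Seq[(BlockId, Long)], where BlockId is a ShuffleBlockId and Long is the size of the corresponding block.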
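The grouping step can be run in isolation with simplified types. In the sketch below, Address and ShuffleBlockId are hypothetical stand-ins for BlockManagerId and Spark's ShuffleBlockId, and a LinkedHashMap replaces HashMap only so that the output order is deterministic.

```scala
import scala.collection.mutable

object BlockGrouping {
  // Hypothetical stand-ins for BlockManagerId and ShuffleBlockId.
  case class Address(host: String, port: Int)
  case class ShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int)

  // statuses(i) is the (location, size) of map i's output for reduceId.
  def blocksByAddress(
      shuffleId: Int,
      reduceId: Int,
      statuses: Array[(Address, Long)]): Seq[(Address, Seq[(ShuffleBlockId, Long)])] = {
    val splitsByAddress = mutable.LinkedHashMap.empty[Address, mutable.ArrayBuffer[(Int, Long)]]
    // zipWithIndex supplies the map id, because statuses is ordered by map id.
    for (((address, size), mapId) <- statuses.zipWithIndex) {
      splitsByAddress.getOrElseUpdate(address, mutable.ArrayBuffer.empty[(Int, Long)]) += ((mapId, size))
    }
    splitsByAddress.toSeq.map { case (address, splits) =>
      (address, splits.toSeq.map { case (mapId, size) =>
        (ShuffleBlockId(shuffleId, mapId, reduceId), size) })
    }
  }
}
```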
2.3. Creating the ShuffleBlockFetcherIterator
BlockStoreShuffleFetcher.fetch then creates a ShuffleBlockFetcherIterator object.
val blockFetcherItr = new ShuffleBlockFetcherIterator(
  context,
  SparkEnv.get.blockManager.shuffleClient,
  blockManager,
  blocksByAddress,
  serializer,
  SparkEnv.get.conf.getLong("spark.reducer.maxMbInFlight", 48) * 1024 * 1024)
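As the class name suggests, this object is an iterator. Its constructor calls its own initialize method. The last argument, maxBytesInFlight, comes from spark.reducer.maxMbInFlight (48 MB by default here, i.e. 48 * 1024 * 1024 = 50,331,648 bytes).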
2.3.1. ShuffleBlockFetcherIterator.initialize
private[this] def initialize(): Unit = {
  // Add a task completion callback (called in both success case and failure case) to cleanup.
  context.addTaskCompletionListener(_ => cleanup())
  // Split local and remote blocks.
  val remoteRequests = splitLocalRemoteBlocks()
  // Add the remote requests into our queue in a random order
  fetchRequests ++= Utils.randomize(remoteRequests)
  // Send out initial requests for blocks, up to our maxBytesInFlight
  while (fetchRequests.nonEmpty &&
    (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) {
    sendRequest(fetchRequests.dequeue())
  }
  val numFetches = remoteRequests.size - fetchRequests.size
  logInfo("Started " + numFetches + " remote fetches in" + Utils.getUsedTimeMs(startTime))
  // Get Local Blocks
  fetchLocalBlocks()
  logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime))
}
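(1) splitLocalRemoteBlocks wraps each remote [BlockManagerId, Seq[(BlockId, Long)]] entry into one or more FetchRequest objects; how many is governed by the total of the block sizes (the Long values) and the maxBytesInFlight parameter;
(2) The FetchRequest array returned by splitLocalRemoteBlocks is shuffled randomly and appended to the fetchRequests queue;
(3) sendRequest is called on requests dequeued from fetchRequests until the queue is empty or sending the next one would push bytesInFlight above maxBytesInFlight (the first request is always sent);
(4) fetchLocalBlocks is called to read the local blocks.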
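Steps (1) and (3) can be sketched without any networking. FetchPlanner, FetchRequest and targetSize below are hypothetical: the sketch groups one host's blocks into requests of roughly targetSize bytes (skipping empty blocks, a choice of this sketch), then sends requests while the in-flight budget allows, always letting the first one through so that a single oversized request cannot stall the reader. How Spark derives its per-request target from maxBytesInFlight is not modeled here.

```scala
import scala.collection.mutable

object FetchPlanner {
  // Hypothetical stand-in for a fetch request: one host, several (blockId, size) pairs.
  case class FetchRequest(host: String, blocks: Seq[(String, Long)]) {
    val size: Long = blocks.map(_._2).sum
  }

  // Group one host's blocks into requests of roughly targetSize bytes each.
  def split(host: String, blocks: Seq[(String, Long)], targetSize: Long): Seq[FetchRequest] = {
    val out = mutable.ArrayBuffer.empty[FetchRequest]
    val cur = mutable.ArrayBuffer.empty[(String, Long)]
    var curSize = 0L
    for ((id, size) <- blocks if size > 0) { // empty blocks: nothing to fetch
      cur += ((id, size))
      curSize += size
      if (curSize >= targetSize) {
        out += FetchRequest(host, cur.toList)
        cur.clear()
        curSize = 0L
      }
    }
    if (cur.nonEmpty) out += FetchRequest(host, cur.toList)
    out.toList
  }

  // Send requests while the in-flight budget allows; the first one always goes out.
  def initialSends(queue: mutable.Queue[FetchRequest], maxBytesInFlight: Long): Seq[FetchRequest] = {
    var bytesInFlight = 0L
    val sent = mutable.ArrayBuffer.empty[FetchRequest]
    while (queue.nonEmpty &&
        (bytesInFlight == 0 || bytesInFlight + queue.head.size <= maxBytesInFlight)) {
      val req = queue.dequeue()
      bytesInFlight += req.size
      sent += req
    }
    sent.toList
  }
}
```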
2.3.2. ShuffleBlockFetcherIterator.sendRequest
private[this] def sendRequest(req: FetchRequest): Unit = {
  logDebug("Sending request for %d blocks (%s) from %s".format(
    req.blocks.size, Utils.bytesToString(req.size), req.address.hostPort))
  bytesInFlight += req.size
  // so we can look up the size of each blockID
  val sizeMap = req.blocks.map { case (blockId, size) => (blockId.toString, size) }.toMap
  val blockIds = req.blocks.map(_._1.toString)
  val address = req.address
  shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray,
    new BlockFetchingListener {
      override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit = {
        // Only add the buffer to results queue if the iterator is not zombie,
        // i.e. cleanup() has not been called yet.
        if (!isZombie) {
          // Increment the ref count because we need to pass this to a different thread.
          // This needs to be released after use.
          buf.retain()
          results.put(new SuccessFetchResult(BlockId(blockId), sizeMap(blockId), buf))
          shuffleMetrics.incRemoteBytesRead(buf.size)
          shuffleMetrics.incRemoteBlocksFetched(1)
        }
        logTrace("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
      }
      override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = {
        logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e)
        results.put(new FailureFetchResult(BlockId(blockId), e))
      }
    }
  )
}
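This method reads remote blocks through the ShuffleClient object, whose concrete implementation is NettyBlockTransferService, by calling its fetchBlocks method. When a block has been read successfully, NettyBlockTransferService invokes the onBlockFetchSuccess callback, which wraps the result in a SuccessFetchResult and puts it into the results queue; on failure, onBlockFetchFailure puts a FailureFetchResult into the queue instead.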
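The hand-off between the network callbacks and the reading thread is a plain producer/consumer queue: the callbacks put results, the iterator takes them. The sketch below models that with java.util.concurrent.LinkedBlockingQueue; ResultQueueDemo, Listener, fetchAsync and the two result classes are hypothetical stand-ins for BlockFetchingListener and SuccessFetchResult/FailureFetchResult, and a plain thread simulates the transport.

```scala
import java.util.concurrent.LinkedBlockingQueue

object ResultQueueDemo {
  // Hypothetical result types, shaped like SuccessFetchResult / FailureFetchResult.
  sealed trait FetchResult
  case class SuccessResult(blockId: String, bytes: Array[Byte]) extends FetchResult
  case class FailureResult(blockId: String, error: Throwable) extends FetchResult

  // Hypothetical two-callback listener, shaped like BlockFetchingListener.
  trait Listener {
    def onSuccess(blockId: String, bytes: Array[Byte]): Unit
    def onFailure(blockId: String, e: Throwable): Unit
  }

  // Shared queue: callbacks put into it, the reading thread takes from it.
  val results = new LinkedBlockingQueue[FetchResult]()

  val listener: Listener = new Listener {
    def onSuccess(blockId: String, bytes: Array[Byte]): Unit =
      results.put(SuccessResult(blockId, bytes))
    def onFailure(blockId: String, e: Throwable): Unit =
      results.put(FailureResult(blockId, e))
  }

  // Simulated asynchronous transport: invokes the callbacks from another thread.
  def fetchAsync(blockIds: Seq[String], l: Listener): Thread = {
    val t = new Thread(() =>
      blockIds.foreach { id =>
        if (id.startsWith("bad")) l.onFailure(id, new RuntimeException(s"cannot read $id"))
        else l.onSuccess(id, id.getBytes("UTF-8"))
      })
    t.start()
    t
  }
}
```

Because take() blocks until a result arrives, the consumer never has to poll; failures travel through the same queue, so the reader sees them in order with the successes.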
2.3.3. ShuffleBlockFetcherIterator.fetchLocalBlocks
private[this] def fetchLocalBlocks(): Unit = {
  val iter = localBlocks.iterator
  while (iter.hasNext) {
    val blockId = iter.next()
    try {
      val buf = blockManager.getBlockData(blockId)
      shuffleMetrics.incLocalBlocksFetched(1)
      shuffleMetrics.incLocalBytesRead(buf.size)
      buf.retain()
      results.put(new SuccessFetchResult(blockId, 0, buf))
    } catch {
      case e: Exception =>
        // If we see an exception, stop immediately.
        logError(s"Error occurred while fetching local blocks", e)
        results.put(new FailureFetchResult(blockId, e))
        return
    }
  }
}
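This method reads local blocks and wraps each result in a SuccessFetchResult that is put into the results queue; if reading a block throws, a FailureFetchResult is enqueued and the method stops immediately.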
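The "stop at the first local failure" behavior can be isolated as follows. LocalFetchDemo and its result types are hypothetical; read stands in for blockManager.getBlockData, and NonFatal is used where the original catches Exception. Unlike the original, which records a size of 0 for local blocks, this sketch records the byte length.

```scala
import scala.collection.mutable
import scala.util.control.NonFatal

object LocalFetchDemo {
  sealed trait Result
  case class Ok(blockId: String, size: Long) extends Result
  case class Failed(blockId: String, error: Throwable) extends Result

  // Read local blocks in order; on the first failure, record it and stop.
  def fetchLocal(blockIds: Seq[String], read: String => Array[Byte]): Seq[Result] = {
    val results = mutable.ArrayBuffer.empty[Result]
    val it = blockIds.iterator
    var stop = false
    while (!stop && it.hasNext) {
      val id = it.next()
      try {
        results += Ok(id, read(id).length.toLong)
      } catch {
        case NonFatal(e) =>
          results += Failed(id, e)
          stop = true
      }
    }
    results.toList
  }
}
```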
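2.4. Returning the Iterator
Once the ShuffleBlockFetcherIterator has been constructed, it is wrapped further and finally returned as an InterruptibleIterator: flatMap(unpackBlock) turns the fetched blocks into records, the CompletionIterator updates the shuffle read metrics when iteration finishes, and the InterruptibleIterator counts every record read.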
val itr = blockFetcherItr.flatMap(unpackBlock)
val completionIter = CompletionIterator[T, Iterator[T]](itr, {
  context.taskMetrics.updateShuffleReadMetrics()
})
new InterruptibleIterator[T](context, completionIter) {
  val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency()
  override def next(): T = {
    readMetrics.incRecordsRead(1)
    delegate.next()
  }
}
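The wrapping in this step combines three ideas: flatMap turns each fetched block into its records, a completion hook fires once the iterator is exhausted, and an outer iterator counts records as they are read. The sketch below reproduces that layering with hypothetical names (OnCompleteIterator, ReadPathDemo); it is not Spark's CompletionIterator or InterruptibleIterator, and it omits task interruption entirely.

```scala
// Hypothetical: runs onComplete exactly once, when the wrapped iterator is exhausted.
class OnCompleteIterator[A](sub: Iterator[A], onComplete: () => Unit) extends Iterator[A] {
  private var completed = false
  def next(): A = sub.next()
  def hasNext: Boolean = {
    val more = sub.hasNext
    if (!more && !completed) { completed = true; onComplete() }
    more
  }
}

object ReadPathDemo {
  var recordsRead = 0L     // plays the role of readMetrics.incRecordsRead
  var metricsUpdated = 0   // plays the role of updateShuffleReadMetrics()

  // Each "block" is just a Seq of records here; flatMap unpacks it like unpackBlock.
  def records(blocks: Iterator[Seq[Int]]): Iterator[Int] = {
    val flat = blocks.flatMap(block => block.iterator)
    val completion = new OnCompleteIterator[Int](flat, () => metricsUpdated += 1)
    new Iterator[Int] {
      def hasNext: Boolean = completion.hasNext
      def next(): Int = { recordsRead += 1; completion.next() }
    }
  }
}
```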