[SPARK-41497][CORE] Fixing accumulator undercount in the case of the retry task with rdd cache

### What changes were proposed in this pull request?
As described in [SPARK-41497](https://issues.apache.org/jira/browse/SPARK-41497), when a task with an RDD cache fails after caching the block data successfully, the retry task will load from the cache. But because the first task attempt failed, its registered accumulator updates were discarded, so the accumulators end up undercounted.

The general idea of this PR is to add a visibility status for RDD blocks: an RDD block becomes visible only after one of the tasks generating it succeeds, which guarantees that the accumulators have already been updated.
The changes below implement this (a sketch of the failure scenario follows the list):
1. In `BlockManagerMasterEndpoint`, add `visibleRDDBlocks` to record which RDD blocks are visible, and `tidToRddBlockIds` to track the RDD blocks generated by each task ID, so that the visibility status can be updated based on task status;
2. In `BlockInfoManager`, add `invisibleRDDBlocks` to track the RDD blocks that are not yet visible in this block manager; once an RDD block becomes visible, the master asks the BlockManagers holding the block to update its visibility status;
3. In `RDD` getOrCompute, re-compute the partition to update the accumulators if the cached RDD block is not visible, even if the cached data exists, and report the taskId-to-RDDBlock relationship to `BlockManagerMasterEndpoint`;
4. When a task finishes successfully, ask `BlockManagerMasterEndpoint` to mark its blocks as visible and broadcast the visibility status to the `BlockManagers` holding the cached data.
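
To make the failure mode concrete, here is a minimal sketch of the undercount scenario (the accumulator and RDD are illustrative; the failure itself is described in comments rather than injected in code):

```scala
val acc = sc.longAccumulator("rows")
val cached = sc.parallelize(1 to 100, numSlices = 2).map { x =>
  acc.add(1) // runs once per row that is actually computed
  x
}.cache()

// Attempt 1 computes and caches the partitions, then the task fails
// afterwards. Spark discards the failed attempt's accumulator updates,
// but the cached blocks remain. Attempt 2 (the retry) reads the cached
// blocks, so acc.add never runs again and acc.value ends up undercounted.
// With spark.rdd.cache.visibilityTracking.enabled=true, the retry
// recomputes the partition to replay the accumulator updates instead.
cached.count()
```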

### Why are the changes needed?
Bug fix.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Added new unit tests.

Closes #39459 from ivoson/SPARK-41497.

Authored-by: Tengfei Huang <tengfei.h@gmail.com>
Signed-off-by: Mridul Muralidharan <mridul<at>gmail.com>
ivoson authored and Mridul Muralidharan committed Mar 3, 2023
1 parent 2e6ce42 commit fd50043
Showing 12 changed files with 489 additions and 26 deletions.
11 changes: 11 additions & 0 deletions core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -2468,4 +2468,15 @@ package object config {
.version("3.4.0")
.booleanConf
.createWithDefault(false)

private[spark] val RDD_CACHE_VISIBILITY_TRACKING_ENABLED =
ConfigBuilder("spark.rdd.cache.visibilityTracking.enabled")
.internal()
.doc("Set to be true to enabled RDD cache block's visibility status. Once it's enabled," +
" a RDD cache block can be used only when it's marked as visible. And a RDD block will be" +
" marked as visible only when one of the tasks generating the cache block finished" +
" successfully. This is relevant in context of consistent accumulator status.")
.version("3.4.0")
.booleanConf
.createWithDefault(false)
}
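
For illustration only, a minimal sketch of turning the flag on via `SparkConf` (the flag is `.internal()`, so it is not a public knob; the app name below is made up):

```scala
import org.apache.spark.SparkConf

// Hypothetical setup enabling the internal visibility-tracking flag.
val conf = new SparkConf()
  .setAppName("cache-visibility-demo")
  .set("spark.rdd.cache.visibilityTracking.enabled", "true")
```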
10 changes: 6 additions & 4 deletions core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -372,10 +372,12 @@ abstract class RDD[T: ClassTag](
val blockId = RDDBlockId(id, partition.index)
var readCachedBlock = true
// This method is called on executors, so we need to call SparkEnv.get instead of sc.env.
SparkEnv.get.blockManager.getOrElseUpdate(blockId, storageLevel, elementClassTag, () => {
readCachedBlock = false
computeOrReadCheckpoint(partition, context)
}) match {
SparkEnv.get.blockManager.getOrElseUpdateRDDBlock(
context.taskAttemptId(), blockId, storageLevel, elementClassTag, () => {
readCachedBlock = false
computeOrReadCheckpoint(partition, context)
}
) match {
// Block hit.
case Left(blockResult) =>
if (readCachedBlock) {
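
The only change here is that the task attempt ID is now threaded into the block-manager call, so a cached block can be attributed to the exact attempt that produced it. For reference, a small sketch of where that ID comes from:

```scala
// taskAttemptId() is unique per task attempt, so a retry of the same
// partition carries a different ID than the failed first run.
val ctx = org.apache.spark.TaskContext.get() // non-null on executors
val attemptId: Long = ctx.taskAttemptId()
```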
11 changes: 11 additions & 0 deletions core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -38,6 +38,7 @@ import org.apache.spark.errors.SparkCoreErrors
import org.apache.spark.executor.{ExecutorMetrics, TaskMetrics}
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config
import org.apache.spark.internal.config.RDD_CACHE_VISIBILITY_TRACKING_ENABLED
import org.apache.spark.internal.config.Tests.TEST_NO_STAGE_RETRY
import org.apache.spark.network.shuffle.{BlockStoreClient, MergeFinalizerListener}
import org.apache.spark.network.shuffle.protocol.MergeStatuses
@@ -303,6 +304,10 @@ private[spark] class DAGScheduler(
private val shuffleSendFinalizeRpcExecutor: ExecutorService =
ThreadUtils.newDaemonFixedThreadPool(shuffleFinalizeRpcThreads, "shuffle-merge-finalize-rpc")

/** Whether rdd cache visibility tracking is enabled. */
private val trackingCacheVisibility: Boolean =
sc.getConf.get(RDD_CACHE_VISIBILITY_TRACKING_ENABLED)

/**
* Called by the TaskSetManager to report that a task has started.
*/
@@ -1787,6 +1792,12 @@ private[spark] class DAGScheduler(
case _: ExceptionFailure | _: TaskKilled => updateAccumulators(event)
case _ =>
}
if (trackingCacheVisibility) {
// Update rdd blocks' visibility status.
blockManagerMaster.updateRDDBlockVisibility(
event.taskInfo.taskId, visible = event.reason == Success)
}

postTaskEnd(event)

event.reason match {
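
The `BlockManagerMasterEndpoint` changes that receive this message are not shown in this excerpt. A rough sketch of what the handler plausibly does, using the field names from the description above (the control flow and the `markRDDBlockVisible` helper are assumptions, not quotes from the PR):

```scala
// Hypothetical handler shape inside BlockManagerMasterEndpoint, assuming:
//   tidToRddBlockIds: mutable.HashMap[Long, mutable.HashSet[RDDBlockId]]
//   visibleRDDBlocks: mutable.HashSet[RDDBlockId]
private def updateRDDBlockVisibility(taskId: Long, visible: Boolean): Unit = {
  if (visible) {
    tidToRddBlockIds.remove(taskId).getOrElse(Set.empty).foreach { blockId =>
      visibleRDDBlocks.add(blockId)
      // Ask each block manager holding the block to mark it visible.
      getLocations(blockId).foreach(bm => markRDDBlockVisible(bm, blockId))
    }
  } else {
    // Failed task: drop the bookkeeping; its blocks stay invisible.
    tidToRddBlockIds.remove(taskId)
  }
}
```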
core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala
@@ -139,7 +139,7 @@ private[storage] object BlockInfo {
*
* This class is thread-safe.
*/
private[storage] class BlockInfoManager extends Logging {
private[storage] class BlockInfoManager(trackingCacheVisibility: Boolean = false) extends Logging {

private type TaskAttemptId = Long

@@ -150,6 +150,12 @@ private[storage] class BlockInfoManager extends Logging {
*/
private[this] val blockInfoWrappers = new ConcurrentHashMap[BlockId, BlockInfoWrapper]

/**
* Record invisible RDD blocks stored in the block manager; entries are removed when blocks
* are marked as visible or when blocks are removed by [[removeBlock()]].
*/
private[this] val invisibleRDDBlocks = new mutable.HashSet[RDDBlockId]

/**
* Stripe used to control multi-threaded access to block information.
*
@@ -180,6 +186,32 @@

// ----------------------------------------------------------------------------------------------

// Exposed for test only.
private[storage] def containsInvisibleRDDBlock(blockId: RDDBlockId): Boolean = {
invisibleRDDBlocks.synchronized {
invisibleRDDBlocks.contains(blockId)
}
}

private[spark] def isRDDBlockVisible(blockId: RDDBlockId): Boolean = {
if (trackingCacheVisibility) {
invisibleRDDBlocks.synchronized {
blockInfoWrappers.containsKey(blockId) && !invisibleRDDBlocks.contains(blockId)
}
} else {
// Always visible if the feature flag is disabled.
true
}
}

private[spark] def tryMarkBlockAsVisible(blockId: RDDBlockId): Unit = {
if (trackingCacheVisibility) {
invisibleRDDBlocks.synchronized {
invisibleRDDBlocks.remove(blockId)
}
}
}

/**
* Called at the start of a task in order to register that task with this [[BlockInfoManager]].
* This must be called prior to calling any other BlockInfoManager methods from that task.
@@ -399,7 +431,19 @@
try {
val wrapper = new BlockInfoWrapper(newBlockInfo, lock)
while (true) {
val previous = blockInfoWrappers.putIfAbsent(blockId, wrapper)
val previous = if (trackingCacheVisibility) {
invisibleRDDBlocks.synchronized {
val res = blockInfoWrappers.putIfAbsent(blockId, wrapper)
if (res == null) {
// Add to invisible blocks if it didn't exist before.
blockId.asRDDId.foreach(invisibleRDDBlocks.add)
}
res
}
} else {
blockInfoWrappers.putIfAbsent(blockId, wrapper)
}

if (previous == null) {
// New block; lock it for writing.
val result = lockForWriting(blockId, blocking = false)
@@ -502,7 +546,10 @@
throw new IllegalStateException(
s"Task $taskAttemptId called remove() on block $blockId without a write lock")
} else {
blockInfoWrappers.remove(blockId)
invisibleRDDBlocks.synchronized {
blockInfoWrappers.remove(blockId)
blockId.asRDDId.foreach(invisibleRDDBlocks.remove)
}
info.readerCount = 0
info.writerTask = BlockInfo.NO_WRITER
writeLocksByTask.get(taskAttemptId).remove(blockId)
@@ -525,6 +572,9 @@
blockInfoWrappers.clear()
readLocksByTask.clear()
writeLocksByTask.clear()
invisibleRDDBlocks.synchronized {
invisibleRDDBlocks.clear()
}
}

}
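
To summarize the intended semantics, a small sketch in the spirit of the package-internal tests added by this PR (not copied from them; `BlockInfo` and the manager are `private[storage]`, so this only compiles inside `org.apache.spark.storage`):

```scala
import scala.reflect.ClassTag

val manager = new BlockInfoManager(trackingCacheVisibility = true)
val blockId = RDDBlockId(0, 0)
val info = new BlockInfo(StorageLevel.MEMORY_ONLY, ClassTag.Any, tellMaster = false)

// Registering a new RDD block records it in invisibleRDDBlocks.
assert(manager.lockNewBlockForWriting(blockId, info))
assert(!manager.isRDDBlockVisible(blockId))

// After a generating task succeeds, the master tells block managers to
// mark the block visible; only then may reads use the cached data.
manager.tryMarkBlockAsVisible(blockId)
assert(manager.isRDDBlockVisible(blockId))
```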
92 changes: 80 additions & 12 deletions core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -43,7 +43,7 @@ import org.apache.spark.errors.SparkCoreErrors
import org.apache.spark.executor.DataReadMethod
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config
import org.apache.spark.internal.config.Network
import org.apache.spark.internal.config.{Network, RDD_CACHE_VISIBILITY_TRACKING_ENABLED}
import org.apache.spark.memory.{MemoryManager, MemoryMode}
import org.apache.spark.metrics.source.Source
import org.apache.spark.network._
@@ -199,8 +199,11 @@ private[spark] class BlockManager(
new DiskBlockManager(conf, deleteFilesOnStop = deleteFilesOnStop, isDriver = isDriver)
}

/** Whether rdd cache visibility tracking is enabled. */
private val trackingCacheVisibility: Boolean = conf.get(RDD_CACHE_VISIBILITY_TRACKING_ENABLED)

// Visible for testing
private[storage] val blockInfoManager = new BlockInfoManager
private[storage] val blockInfoManager = new BlockInfoManager(trackingCacheVisibility)

private val futureExecutionContext = ExecutionContext.fromExecutorService(
ThreadUtils.newDaemonCachedThreadPool("block-manager-future", 128))
@@ -1323,31 +1326,74 @@
blockInfoManager.releaseAllLocksForTask(taskAttemptId)
}

/**
* Retrieve the given rdd block if it exists and is visible, otherwise call the provided
* `makeIterator` method to compute the block, persist it, and return its values.
*
* @return either a BlockResult if the block was successfully cached, or an iterator if the block
* could not be cached.
*/
def getOrElseUpdateRDDBlock[T](
taskId: Long,
blockId: RDDBlockId,
level: StorageLevel,
classTag: ClassTag[T],
makeIterator: () => Iterator[T]): Either[BlockResult, Iterator[T]] = {
val isCacheVisible = isRDDBlockVisible(blockId)
val res = getOrElseUpdate(blockId, level, classTag, makeIterator, isCacheVisible)
if (res.isLeft && !isCacheVisible) {
// Block exists but is not visible; report the taskId -> blockId mapping to the master.
master.updateRDDBlockTaskInfo(blockId, taskId)
}

res
}

/**
* Retrieve the given block if it exists, otherwise call the provided `makeIterator` method
* to compute the block, persist it, and return its values.
*
* @return either a BlockResult if the block was successfully cached, or an iterator if the block
* could not be cached.
*/
def getOrElseUpdate[T](
private def getOrElseUpdate[T](
blockId: BlockId,
level: StorageLevel,
classTag: ClassTag[T],
makeIterator: () => Iterator[T]): Either[BlockResult, Iterator[T]] = {
// Attempt to read the block from local or remote storage. If it's present, then we don't need
// to go through the local-get-or-put path.
get[T](blockId)(classTag) match {
case Some(block) =>
return Left(block)
case _ =>
// Need to compute the block.
makeIterator: () => Iterator[T],
isCacheVisible: Boolean): Either[BlockResult, Iterator[T]] = {
// Track whether the data was computed, so we can force the computation later if needed.
// The reason we defer the forced computation is that once the executor is decommissioned,
// we have a better chance to replicate the cache block, thanks to the `checkShouldStore`
// validation when putting a new block.
var computed: Boolean = false
val iterator = () => {
computed = true
makeIterator()
}
if (isCacheVisible) {
// Attempt to read the block from local or remote storage. If it's present, then we don't need
// to go through the local-get-or-put path.
get[T](blockId)(classTag) match {
case Some(block) =>
return Left(block)
case _ =>
// Need to compute the block.
}
}

// TODO: need a better way to handle blocks with indeterminate/unordered results; replicas
// of the same blockId could differ, and the reported accumulators might not match
// the cached results.
// Initially we hold no locks on this block.
doPutIterator(blockId, makeIterator, level, classTag, keepReadLock = true) match {
doPutIterator(blockId, iterator, level, classTag, keepReadLock = true) match {
case None =>
// doPut() didn't hand work back to us, so the block already existed or was successfully
// stored. Therefore, we now hold a read lock on the block.
if (!isCacheVisible && !computed) {
// Force compute to report accumulator updates.
Utils.getIteratorSize(makeIterator())
}
val blockResult = getLocalValues(blockId).getOrElse {
// Since we held a read lock between the doPut() and get() calls, the block should not
// have been evicted, so get() not returning the block indicates some internal error.
Expand Down Expand Up @@ -1422,6 +1468,28 @@ private[spark] class BlockManager(
blockStoreUpdater.save()
}

// Check whether an RDD block is visible.
private[spark] def isRDDBlockVisible(blockId: RDDBlockId): Boolean = {
// Cached blocks are always visible if the feature flag is disabled.
if (!trackingCacheVisibility) {
return true
}

// If the RDD block visibility information is not available in the block manager,
// ask the master for the information.
if (blockInfoManager.isRDDBlockVisible(blockId)) {
return true
}

if (master.isRDDBlockVisible(blockId)) {
// Cache the visibility status if block exists.
blockInfoManager.tryMarkBlockAsVisible(blockId)
true
} else {
false
}
}

/**
* Helper method used to abstract common code from [[BlockStoreUpdater.save()]]
* and [[doPutIterator()]].
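
The subtle piece above is the forced recompute: for a cached but invisible block, `Utils.getIteratorSize(makeIterator())` drains a fresh iterator purely for its side effects. A minimal illustration of why draining suffices (the counter stands in for an accumulator; `Utils` is `private[spark]`, so this pattern runs only inside Spark's own code and tests):

```scala
var counted = 0L // stand-in for an accumulator
val it = (1 to 5).iterator.map { x => counted += 1; x }
assert(counted == 0) // iterators are lazy: nothing has run yet

// Draining the iterator forces the map function, and with it every
// accumulator update, to execute; the element values are discarded.
org.apache.spark.util.Utils.getIteratorSize(it)
assert(counted == 5)
```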
core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
@@ -108,6 +108,19 @@ class BlockManagerMaster(
res
}

def updateRDDBlockTaskInfo(blockId: RDDBlockId, taskId: Long): Unit = {
driverEndpoint.askSync[Unit](UpdateRDDBlockTaskInfo(blockId, taskId))
}

def updateRDDBlockVisibility(taskId: Long, visible: Boolean): Unit = {
driverEndpoint.ask[Unit](UpdateRDDBlockVisibility(taskId, visible))
}

/** Check whether an RDD block is visible. */
def isRDDBlockVisible(blockId: RDDBlockId): Boolean = {
driverEndpoint.askSync[Boolean](GetRDDBlockVisibility(blockId))
}

/** Get locations of the blockId from the driver */
def getLocations(blockId: BlockId): Seq[BlockManagerId] = {
driverEndpoint.askSync[Seq[BlockManagerId]](GetLocations(blockId))
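
The message types used above (`UpdateRDDBlockTaskInfo`, `UpdateRDDBlockVisibility`, `GetRDDBlockVisibility`) are defined in BlockManagerMessages.scala, which is not part of this excerpt. From the call sites, their shapes are presumably plain case classes along these lines:

```scala
// Shapes inferred from the ask/askSync call sites above; the actual
// definitions live in this PR's BlockManagerMessages.scala changes.
case class UpdateRDDBlockTaskInfo(blockId: RDDBlockId, taskId: Long)
  extends ToBlockManagerMaster
case class UpdateRDDBlockVisibility(taskId: Long, visible: Boolean)
  extends ToBlockManagerMaster
case class GetRDDBlockVisibility(blockId: RDDBlockId)
  extends ToBlockManagerMaster
```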
