Introduce a separate slot for stealing tasks into in CoroutineScheduler (#3537)

* Introduce a separate slot for stealing tasks into in CoroutineScheduler

It solves two problems:

* Stealing into an exclusively owned local queue no longer requires CASes or atomic operations where they were previously unnecessary. This should save a few cycles on the stealing code path.
* The overall timing perturbations should be slightly better now: previously, a stolen task could immediately be stolen again from the stealer thread, because it was actually published to the stealer's own queue while its submission time was never updated (#3416). A sketch of the new mechanism follows the commit message below.

* Move the victim argument in WorkQueue into the receiver position to simplify the overall code structure
* Fix oversubscription in CoroutineScheduler (-> Dispatchers.Default) (#3418)

Previously, a worker thread unconditionally processed tasks from its own local queue even when those tasks were CPU-intensive and no CPU token had been acquired.

Fixes #3416
Fixes #3418
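
To illustrate the first change (a minimal, self-contained Kotlin sketch with invented names such as Ref and SimpleQueue, not the actual scheduler code): the stealer hands a reusable reference cell to the victim's queue, and a successful steal writes the task into that cell, so the stolen task is never re-published to the stealer's own queue.

// Sketch only: a plain lock stands in for the real lock-free CAS loop,
// and Int stands in for the scheduler's Task type.
class Ref<T>(var element: T? = null)

class SimpleQueue {
    private val deque = ArrayDeque<Int>()

    @Synchronized
    fun add(task: Int) {
        deque.addLast(task)
    }

    // Victim-side steal: the stolen task goes into the caller-provided slot,
    // so the stealer can execute it directly instead of round-tripping it
    // through its own local queue with a stale submission time.
    @Synchronized
    fun trySteal(slot: Ref<Int>): Boolean {
        val task = deque.removeFirstOrNull() ?: return false
        slot.element = task
        return true
    }
}

fun main() {
    val victim = SimpleQueue().apply { add(42) }
    val slot = Ref<Int>() // one slot per worker, reused across steals
    if (victim.trySteal(slot)) {
        val task = slot.element // run the task...
        slot.element = null     // ...then clear the slot, as the worker does
        println("stole task $task")
    }
}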
qwwdfsad committed Jan 16, 2023
1 parent ebff885 commit 87d1af9
Showing 5 changed files with 209 additions and 70 deletions.
45 changes: 26 additions & 19 deletions kotlinx-coroutines-core/jvm/src/scheduling/CoroutineScheduler.kt
@@ -10,6 +10,7 @@ import kotlinx.coroutines.internal.*
import java.io.*
import java.util.concurrent.*
import java.util.concurrent.locks.*
+ import kotlin.jvm.internal.Ref.ObjectRef
import kotlin.math.*
import kotlin.random.*

@@ -263,16 +264,16 @@ internal class CoroutineScheduler(
val workers = ResizableAtomicArray<Worker>(corePoolSize + 1)

/**
- * Long describing state of workers in this pool.
- * Currently includes created, CPU-acquired and blocking workers each occupying [BLOCKING_SHIFT] bits.
+ * The `Long` value describing the state of workers in this pool.
+ * Currently includes created, CPU-acquired, and blocking workers, each occupying [BLOCKING_SHIFT] bits.
*/
private val controlState = atomic(corePoolSize.toLong() shl CPU_PERMITS_SHIFT)
private val createdWorkers: Int inline get() = (controlState.value and CREATED_MASK).toInt()
private val availableCpuPermits: Int inline get() = availableCpuPermits(controlState.value)

private inline fun createdWorkers(state: Long): Int = (state and CREATED_MASK).toInt()
private inline fun blockingTasks(state: Long): Int = (state and BLOCKING_MASK shr BLOCKING_SHIFT).toInt()
- public inline fun availableCpuPermits(state: Long): Int = (state and CPU_PERMITS_MASK shr CPU_PERMITS_SHIFT).toInt()
+ inline fun availableCpuPermits(state: Long): Int = (state and CPU_PERMITS_MASK shr CPU_PERMITS_SHIFT).toInt()

// Guarded by synchronization
private inline fun incrementCreatedWorkers(): Int = createdWorkers(controlState.incrementAndGet())
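
For context on the layout mentioned in the comment above, here is a self-contained decoding sketch. It assumes BLOCKING_SHIFT = 21 (the value used by the scheduler's companion-object constants; worth verifying against the source).

// Sketch of the controlState packing: created workers occupy the low
// BLOCKING_SHIFT bits, blocking workers the next BLOCKING_SHIFT bits,
// and available CPU permits the bits above those.
val BLOCKING_SHIFT = 21 // assumed value
val CREATED_MASK: Long = (1L shl BLOCKING_SHIFT) - 1
val BLOCKING_MASK: Long = CREATED_MASK shl BLOCKING_SHIFT
val CPU_PERMITS_SHIFT: Int = BLOCKING_SHIFT * 2
val CPU_PERMITS_MASK: Long = CREATED_MASK shl CPU_PERMITS_SHIFT

fun main() {
    var state = 4L shl CPU_PERMITS_SHIFT // corePoolSize = 4, all permits free
    state += 1L                          // one worker created
    state += 1L shl BLOCKING_SHIFT       // one worker went blocking
    println((state and CREATED_MASK).toInt())                           // 1 created
    println((state and BLOCKING_MASK shr BLOCKING_SHIFT).toInt())       // 1 blocking
    println((state and CPU_PERMITS_MASK shr CPU_PERMITS_SHIFT).toInt()) // 4 permits
}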
@@ -598,6 +599,12 @@ internal class CoroutineScheduler(
@JvmField
val localQueue: WorkQueue = WorkQueue()

+ /**
+  * A slot used to steal tasks into, to avoid re-adding them
+  * to the local queue. See [trySteal].
+  */
+ private val stolenTask: ObjectRef<Task?> = ObjectRef()

/**
* Worker state. **Updated only by this worker thread**.
* By default, a worker is in the DORMANT state when it was created but all CPU tokens or tasks were taken.
@@ -617,7 +624,7 @@

/**
* It is set to the termination deadline when started doing [park] and it reset
- * when there is a task. It servers as protection against spurious wakeups of parkNanos.
+ * when there is a task. It serves as protection against spurious wakeups of parkNanos.
*/
private var terminationDeadline = 0L

@@ -719,7 +726,6 @@
parkedWorkersStackPush(this)
return
}
- assert { localQueue.size == 0 }
workerCtl.value = PARKED // Update value once
/*
* inStack() prevents spurious wakeups, while workerCtl.value == PARKED
@@ -866,15 +872,16 @@
}
}

- fun findTask(scanLocalQueue: Boolean): Task? {
-     if (tryAcquireCpuPermit()) return findAnyTask(scanLocalQueue)
-     // If we can't acquire a CPU permit -- attempt to find blocking task
-     val task = if (scanLocalQueue) {
-         localQueue.poll() ?: globalBlockingQueue.removeFirstOrNull()
-     } else {
-         globalBlockingQueue.removeFirstOrNull()
-     }
-     return task ?: trySteal(blockingOnly = true)
+ fun findTask(mayHaveLocalTasks: Boolean): Task? {
+     if (tryAcquireCpuPermit()) return findAnyTask(mayHaveLocalTasks)
+     /*
+      * If we can't acquire a CPU permit, attempt to find blocking task:
+      * * Check if our queue has one (maybe mixed in with CPU tasks)
+      * * Poll global and try steal
+      */
+     return localQueue.pollBlocking()
+         ?: globalBlockingQueue.removeFirstOrNull()
+         ?: trySteal(blockingOnly = true)
}

private fun findAnyTask(scanLocalQueue: Boolean): Task? {
@@ -904,7 +911,6 @@
}

private fun trySteal(blockingOnly: Boolean): Task? {
- assert { localQueue.size == 0 }
val created = createdWorkers
// 0 to await an initialization and 1 to avoid excess stealing on single-core machines
if (created < 2) {
@@ -918,14 +924,15 @@
if (currentIndex > created) currentIndex = 1
val worker = workers[currentIndex]
if (worker !== null && worker !== this) {
-     assert { localQueue.size == 0 }
      val stealResult = if (blockingOnly) {
-         localQueue.tryStealBlockingFrom(victim = worker.localQueue)
+         worker.localQueue.tryStealBlocking(stolenTask)
      } else {
-         localQueue.tryStealFrom(victim = worker.localQueue)
+         worker.localQueue.trySteal(stolenTask)
      }
      if (stealResult == TASK_STOLEN) {
-         return localQueue.poll()
+         val result = stolenTask.element
+         stolenTask.element = null
+         return result
} else if (stealResult > 0) {
minDelay = min(minDelay, stealResult)
}
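The oversubscription fix is visible in findTask above: a worker that fails to acquire a CPU permit is now restricted to blocking work. A condensed sketch of that decision order, with simplified function-typed parameters standing in for the real members:

// Sketch: the order in which a worker looks for work after the fix. A worker
// without a CPU permit must not run CPU tasks from its local queue; the old
// unconditional local poll was the oversubscription bug (#3418).
fun findTaskSketch(
    tryAcquireCpuPermit: () -> Boolean,
    findAnyTask: () -> String?,        // local queue first, then global queues
    pollLocalBlocking: () -> String?,  // blocking tasks only
    pollGlobalBlocking: () -> String?,
    stealBlocking: () -> String?
): String? =
    if (tryAcquireCpuPermit()) findAnyTask() // permit held: any task is fine
    else pollLocalBlocking()                 // no permit: blocking-only, local first,
        ?: pollGlobalBlocking()              // then the global blocking queue,
        ?: stealBlocking()                   // then stealing blocking tasks
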
89 changes: 56 additions & 33 deletions kotlinx-coroutines-core/jvm/src/scheduling/WorkQueue.kt
@@ -7,6 +7,7 @@ package kotlinx.coroutines.scheduling
import kotlinx.atomicfu.*
import kotlinx.coroutines.*
import java.util.concurrent.atomic.*
+ import kotlin.jvm.internal.Ref.ObjectRef

internal const val BUFFER_CAPACITY_BASE = 7
internal const val BUFFER_CAPACITY = 1 shl BUFFER_CAPACITY_BASE
@@ -31,7 +32,7 @@ internal const val NOTHING_TO_STEAL = -2L
* (scheduler workers without a CPU permit steal blocking tasks via this mechanism). This property forces us to use CAS in
* order to properly claim a value from the buffer.
* Moreover, [Task] objects are reusable, so it may seem that this queue is prone to ABA problem.
- * Indeed it formally has ABA-problem, but the whole processing logic is written in the way that such ABA is harmless.
+ * Indeed, it formally has ABA-problem, but the whole processing logic is written in the way that such ABA is harmless.
* I have discovered a truly marvelous proof of this, which this KDoc is too narrow to contain.
*/
internal class WorkQueue {
@@ -46,10 +47,12 @@ internal class WorkQueue {
* [T2] changeProducerIndex (3)
* [T3] changeConsumerIndex (4)
*
- * Which can lead to resulting size bigger than actual size at any moment of time.
- * This is in general harmless because steal will be blocked by timer
+ * Which can lead to resulting size being negative or bigger than actual size at any moment of time.
+ * This is in general harmless because steal will be blocked by timer.
+ * Negative sizes can be observed only when non-owner reads the size, which happens only
+ * for diagnostic toString().
  */
- internal val bufferSize: Int get() = producerIndex.value - consumerIndex.value
+ private val bufferSize: Int get() = producerIndex.value - consumerIndex.value
internal val size: Int get() = if (lastScheduledTask.value != null) bufferSize + 1 else bufferSize
private val buffer: AtomicReferenceArray<Task?> = AtomicReferenceArray(BUFFER_CAPACITY)
private val lastScheduledTask = atomic<Task?>(null)
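
The interleaving in the comment above can be simulated single-threadedly (hypothetical IndexPair type; the real fields are atomics):

// Sketch: a non-owner reads producerIndex first; the owner then produces and
// consumes two tasks; the non-owner's later read of consumerIndex is newer,
// so the computed size is negative. Harmless: only diagnostic toString() does this.
class IndexPair(var producerIndex: Int = 5, var consumerIndex: Int = 5)

fun main() {
    val q = IndexPair()
    val observedProducer = q.producerIndex // non-owner read (1)
    q.producerIndex += 2                   // owner adds two tasks (2)
    q.consumerIndex += 2                   // owner consumes both (3)
    val observedConsumer = q.consumerIndex // non-owner read (4)
    println(observedProducer - observedConsumer) // prints -2
}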
@@ -100,41 +103,61 @@
}

/**
- * Tries stealing from [victim] queue into this queue.
+ * Tries stealing from this queue into the [stolenTaskRef] argument.
  *
  * Returns [NOTHING_TO_STEAL] if the queue has nothing to steal, [TASK_STOLEN] if at least one task was stolen,
  * or a positive value of how many nanoseconds should pass until the head of this queue will be available to steal.
  */
- fun tryStealFrom(victim: WorkQueue): Long {
-     assert { bufferSize == 0 }
-     val task = victim.pollBuffer()
+ fun trySteal(stolenTaskRef: ObjectRef<Task?>): Long {
+     val task = pollBuffer()
      if (task != null) {
-         val notAdded = add(task)
-         assert { notAdded == null }
+         stolenTaskRef.element = task
          return TASK_STOLEN
      }
-     return tryStealLastScheduled(victim, blockingOnly = false)
+     return tryStealLastScheduled(stolenTaskRef, blockingOnly = false)
}

- fun tryStealBlockingFrom(victim: WorkQueue): Long {
-     assert { bufferSize == 0 }
-     var start = victim.consumerIndex.value
-     val end = victim.producerIndex.value
-     val buffer = victim.buffer

-     while (start != end) {
-         val index = start and MASK
-         if (victim.blockingTasksInBuffer.value == 0) break
-         val value = buffer[index]
-         if (value != null && value.isBlocking && buffer.compareAndSet(index, value, null)) {
-             victim.blockingTasksInBuffer.decrementAndGet()
-             add(value)
-             return TASK_STOLEN
-         } else {
-             ++start
+ fun tryStealBlocking(stolenTaskRef: ObjectRef<Task?>): Long {
+     var start = consumerIndex.value
+     val end = producerIndex.value

+     while (start != end && blockingTasksInBuffer.value > 0) {
+         stolenTaskRef.element = tryExtractBlockingTask(start++) ?: continue
+         return TASK_STOLEN
      }
+     return tryStealLastScheduled(stolenTaskRef, blockingOnly = true)
  }

+ // Polls for blocking task, invoked only by the owner
+ fun pollBlocking(): Task? {
+     while (true) { // Poll the slot
+         val lastScheduled = lastScheduledTask.value ?: break
+         if (!lastScheduled.isBlocking) break
+         if (lastScheduledTask.compareAndSet(lastScheduled, null)) {
+             return lastScheduled
+         } // Failed -> someone else stole it
+     }

+     val start = consumerIndex.value
+     var end = producerIndex.value

+     while (start != end && blockingTasksInBuffer.value > 0) {
+         val task = tryExtractBlockingTask(--end)
+         if (task != null) {
+             return task
+         }
+     }
-     return tryStealLastScheduled(victim, blockingOnly = true)
+     return null
}

+ private fun tryExtractBlockingTask(index: Int): Task? {
+     val arrayIndex = index and MASK
+     val value = buffer[arrayIndex]
+     if (value != null && value.isBlocking && buffer.compareAndSet(arrayIndex, value, null)) {
+         blockingTasksInBuffer.decrementAndGet()
+         return value
+     }
+     return null
+ }

fun offloadAllWorkTo(globalQueue: GlobalQueue) {
@@ -145,11 +168,11 @@
}

/**
- * Contract on return value is the same as for [tryStealFrom]
+ * Contract on return value is the same as for [trySteal]
*/
- private fun tryStealLastScheduled(victim: WorkQueue, blockingOnly: Boolean): Long {
+ private fun tryStealLastScheduled(stolenTaskRef: ObjectRef<Task?>, blockingOnly: Boolean): Long {
while (true) {
-         val lastScheduled = victim.lastScheduledTask.value ?: return NOTHING_TO_STEAL
+         val lastScheduled = lastScheduledTask.value ?: return NOTHING_TO_STEAL
if (blockingOnly && !lastScheduled.isBlocking) return NOTHING_TO_STEAL

// TODO time wraparound ?
@@ -163,8 +186,8 @@
* If CAS has failed, either someone else had stolen this task or the owner executed this task
* and dispatched another one. In the latter case we should retry to avoid missing a task.
*/
-         if (victim.lastScheduledTask.compareAndSet(lastScheduled, null)) {
-             add(lastScheduled)
+         if (lastScheduledTask.compareAndSet(lastScheduled, null)) {
+             stolenTaskRef.element = lastScheduled
return TASK_STOLEN
}
continue
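To make the stealing contract concrete, an assumed caller-side interpretation of the return value (TASK_STOLEN = -1L is an assumption based on the constants file; NOTHING_TO_STEAL = -2L appears in the hunk above):

// Sketch of how a stealer interprets trySteal's / tryStealBlocking's result.
const val TASK_STOLEN = -1L      // assumed value
const val NOTHING_TO_STEAL = -2L // shown in the diff above

fun describeStealResult(stealResult: Long): String = when (stealResult) {
    TASK_STOLEN -> "success: the task now sits in the stolenTaskRef slot"
    NOTHING_TO_STEAL -> "nothing stealable: move on to the next victim"
    else -> "head task too fresh: retry after $stealResult ns (feeds minDelay)"
}
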
@@ -0,0 +1,89 @@
/*
* Copyright 2016-2022 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
*/

package kotlinx.coroutines.scheduling

import kotlinx.coroutines.*
import org.junit.Test
import java.util.concurrent.*
import java.util.concurrent.atomic.AtomicInteger

class CoroutineSchedulerOversubscriptionTest : TestBase() {

private val inDefault = AtomicInteger(0)

private fun CountDownLatch.runAndCheck() {
if (inDefault.incrementAndGet() > CORE_POOL_SIZE) {
error("Oversubscription detected")
}

await()
inDefault.decrementAndGet()
}

@Test
fun testOverSubscriptionDeterministic() = runTest {
val barrier = CountDownLatch(1)
val threadsOccupiedBarrier = CyclicBarrier(CORE_POOL_SIZE)
// All threads but one
repeat(CORE_POOL_SIZE - 1) {
launch(Dispatchers.Default) {
threadsOccupiedBarrier.await()
barrier.runAndCheck()
}
}
threadsOccupiedBarrier.await()
withContext(Dispatchers.Default) {
// Put a task in a local queue, it will be stolen
launch(Dispatchers.Default) {
barrier.runAndCheck()
}
// Put one more task to trick the local queue check
launch(Dispatchers.Default) {
barrier.runAndCheck()
}

withContext(Dispatchers.IO) {
try {
// Release the thread
delay(100)
} finally {
barrier.countDown()
}
}
}
}

@Test
fun testOverSubscriptionStress() = repeat(1000 * stressTestMultiplierSqrt) {
inDefault.set(0)
runTest {
val barrier = CountDownLatch(1)
val threadsOccupiedBarrier = CyclicBarrier(CORE_POOL_SIZE)
// All threads but one
repeat(CORE_POOL_SIZE - 1) {
launch(Dispatchers.Default) {
threadsOccupiedBarrier.await()
barrier.runAndCheck()
}
}
threadsOccupiedBarrier.await()
withContext(Dispatchers.Default) {
// Put a task in a local queue
launch(Dispatchers.Default) {
barrier.runAndCheck()
}
// Put one more task to trick the local queue check
launch(Dispatchers.Default) {
barrier.runAndCheck()
}

withContext(Dispatchers.IO) {
yield()
barrier.countDown()
}
}
}
}
}
