Skip to content

Commit

Permalink
Introduce a separate slot for stealing tasks into in CoroutineScheduler
Browse files Browse the repository at this point in the history
It solves two problems:

* Stealing into exclusively owned local queue does no longer require and CAS'es or atomic operations where they were previously not needed. It should save a few cycles on the stealing code path
* The overall timing perturbations should be slightly better now: previously it was possible for the stolen task to be immediately got stolen again from the stealer thread because it was actually published to owner's queue, but its submission time was never updated

Fixes #3416
  • Loading branch information
qwwdfsad committed Nov 22, 2022
1 parent 287d038 commit 20daaa7
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 31 deletions.
17 changes: 13 additions & 4 deletions kotlinx-coroutines-core/jvm/src/scheduling/CoroutineScheduler.kt
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import kotlinx.coroutines.internal.*
import java.io.*
import java.util.concurrent.*
import java.util.concurrent.locks.*
import kotlin.jvm.internal.Ref.ObjectRef
import kotlin.math.*
import kotlin.random.*

Expand Down Expand Up @@ -598,6 +599,12 @@ internal class CoroutineScheduler(
@JvmField
val localQueue: WorkQueue = WorkQueue()

/**
* Slot that is used to steal tasks into to avoid re-adding them
* to the local queue. See [trySteal]
*/
private val stolenTask: ObjectRef<Task?> = ObjectRef()

/**
* Worker state. **Updated only by this worker thread**.
* By default, worker is in DORMANT state in the case when it was created, but all CPU tokens or tasks were taken.
Expand All @@ -617,7 +624,7 @@ internal class CoroutineScheduler(

/**
* It is set to the termination deadline when started doing [park] and it reset
* when there is a task. It servers as protection against spurious wakeups of parkNanos.
* when there is a task. It serves as protection against spurious wakeups of parkNanos.
*/
private var terminationDeadline = 0L

Expand Down Expand Up @@ -920,12 +927,14 @@ internal class CoroutineScheduler(
if (worker !== null && worker !== this) {
assert { localQueue.size == 0 }
val stealResult = if (blockingOnly) {
localQueue.tryStealBlockingFrom(victim = worker.localQueue)
localQueue.tryStealBlockingFrom(victim = worker.localQueue, stolenTask)
} else {
localQueue.tryStealFrom(victim = worker.localQueue)
localQueue.tryStealFrom(victim = worker.localQueue, stolenTask)
}
if (stealResult == TASK_STOLEN) {
return localQueue.poll()
val result = stolenTask.element
stolenTask.element = null
return result
} else if (stealResult > 0) {
minDelay = min(minDelay, stealResult)
}
Expand Down
22 changes: 11 additions & 11 deletions kotlinx-coroutines-core/jvm/src/scheduling/WorkQueue.kt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ package kotlinx.coroutines.scheduling
import kotlinx.atomicfu.*
import kotlinx.coroutines.*
import java.util.concurrent.atomic.*
import kotlin.jvm.internal.Ref.ObjectRef

internal const val BUFFER_CAPACITY_BASE = 7
internal const val BUFFER_CAPACITY = 1 shl BUFFER_CAPACITY_BASE
Expand All @@ -31,7 +32,7 @@ internal const val NOTHING_TO_STEAL = -2L
* (scheduler workers without a CPU permit steal blocking tasks via this mechanism). Such property enforces us to use CAS in
* order to properly claim value from the buffer.
* Moreover, [Task] objects are reusable, so it may seem that this queue is prone to ABA problem.
* Indeed it formally has ABA-problem, but the whole processing logic is written in the way that such ABA is harmless.
* Indeed, it formally has ABA-problem, but the whole processing logic is written in the way that such ABA is harmless.
* I have discovered a truly marvelous proof of this, which this KDoc is too narrow to contain.
*/
internal class WorkQueue {
Expand Down Expand Up @@ -100,23 +101,22 @@ internal class WorkQueue {
}

/**
* Tries stealing from [victim] queue into this queue.
* Tries stealing from [victim] queue into the [stolenTaskRef] argument.
*
* Returns [NOTHING_TO_STEAL] if queue has nothing to steal, [TASK_STOLEN] if at least task was stolen
* or positive value of how many nanoseconds should pass until the head of this queue will be available to steal.
*/
fun tryStealFrom(victim: WorkQueue): Long {
fun tryStealFrom(victim: WorkQueue, stolenTaskRef: ObjectRef<Task?>): Long {
assert { bufferSize == 0 }
val task = victim.pollBuffer()
if (task != null) {
val notAdded = add(task)
assert { notAdded == null }
stolenTaskRef.element = task
return TASK_STOLEN
}
return tryStealLastScheduled(victim, blockingOnly = false)
return tryStealLastScheduled(victim, stolenTaskRef, blockingOnly = false)
}

fun tryStealBlockingFrom(victim: WorkQueue): Long {
fun tryStealBlockingFrom(victim: WorkQueue, stolenTaskRef: ObjectRef<Task?>): Long {
assert { bufferSize == 0 }
var start = victim.consumerIndex.value
val end = victim.producerIndex.value
Expand All @@ -128,13 +128,13 @@ internal class WorkQueue {
val value = buffer[index]
if (value != null && value.isBlocking && buffer.compareAndSet(index, value, null)) {
victim.blockingTasksInBuffer.decrementAndGet()
add(value)
stolenTaskRef.element = value
return TASK_STOLEN
} else {
++start
}
}
return tryStealLastScheduled(victim, blockingOnly = true)
return tryStealLastScheduled(victim, stolenTaskRef, blockingOnly = true)
}

fun offloadAllWorkTo(globalQueue: GlobalQueue) {
Expand All @@ -147,7 +147,7 @@ internal class WorkQueue {
/**
* Contract on return value is the same as for [tryStealFrom]
*/
private fun tryStealLastScheduled(victim: WorkQueue, blockingOnly: Boolean): Long {
private fun tryStealLastScheduled(victim: WorkQueue, stolenTaskRef: ObjectRef<Task?>, blockingOnly: Boolean): Long {
while (true) {
val lastScheduled = victim.lastScheduledTask.value ?: return NOTHING_TO_STEAL
if (blockingOnly && !lastScheduled.isBlocking) return NOTHING_TO_STEAL
Expand All @@ -164,7 +164,7 @@ internal class WorkQueue {
* and dispatched another one. In the latter case we should retry to avoid missing task.
*/
if (victim.lastScheduledTask.compareAndSet(lastScheduled, null)) {
add(lastScheduled)
stolenTaskRef.element = lastScheduled
return TASK_STOLEN
}
continue
Expand Down
19 changes: 11 additions & 8 deletions kotlinx-coroutines-core/jvm/test/scheduling/WorkQueueStressTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import org.junit.*
import org.junit.Test
import java.util.concurrent.*
import kotlin.concurrent.*
import kotlin.jvm.internal.*
import kotlin.test.*

class WorkQueueStressTest : TestBase() {
Expand Down Expand Up @@ -52,17 +53,18 @@ class WorkQueueStressTest : TestBase() {

for (i in 0 until stealersCount) {
threads += thread(name = "stealer $i") {
val ref = Ref.ObjectRef<Task?>()
val myQueue = WorkQueue()
startLatch.await()
while (!producerFinished || producerQueue.size != 0) {
stolenTasks[i].addAll(myQueue.drain().map { task(it) })
myQueue.tryStealFrom(victim = producerQueue)
stolenTasks[i].addAll(myQueue.drain(ref).map { task(it) })
myQueue.tryStealFrom(victim = producerQueue, ref)
}

// Drain last element which is not counted in buffer
stolenTasks[i].addAll(myQueue.drain().map { task(it) })
myQueue.tryStealFrom(producerQueue)
stolenTasks[i].addAll(myQueue.drain().map { task(it) })
stolenTasks[i].addAll(myQueue.drain(ref).map { task(it) })
myQueue.tryStealFrom(producerQueue, ref)
stolenTasks[i].addAll(myQueue.drain(ref).map { task(it) })
}
}

Expand All @@ -89,13 +91,14 @@ class WorkQueueStressTest : TestBase() {
val stolen = GlobalQueue()
threads += thread(name = "stealer") {
val myQueue = WorkQueue()
val ref = Ref.ObjectRef<Task?>()
startLatch.await()
while (stolen.size != offerIterations) {
if (myQueue.tryStealFrom(producerQueue) != NOTHING_TO_STEAL) {
stolen.addAll(myQueue.drain().map { task(it) })
if (myQueue.tryStealFrom(producerQueue, ref) != NOTHING_TO_STEAL) {
stolen.addAll(myQueue.drain(ref).map { task(it) })
}
}
stolen.addAll(myQueue.drain().map { task(it) })
stolen.addAll(myQueue.drain(ref).map { task(it) })
}

startLatch.countDown()
Expand Down
22 changes: 14 additions & 8 deletions kotlinx-coroutines-core/jvm/test/scheduling/WorkQueueTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ package kotlinx.coroutines.scheduling
import kotlinx.coroutines.*
import org.junit.*
import org.junit.Test
import kotlin.jvm.internal.Ref.ObjectRef
import kotlin.test.*

class WorkQueueTest : TestBase() {
Expand All @@ -27,7 +28,7 @@ class WorkQueueTest : TestBase() {
fun testLastScheduledComesFirst() {
val queue = WorkQueue()
(1L..4L).forEach { queue.add(task(it)) }
assertEquals(listOf(4L, 1L, 2L, 3L), queue.drain())
assertEquals(listOf(4L, 1L, 2L, 3L), queue.drain(ObjectRef()))
}

@Test
Expand All @@ -38,9 +39,9 @@ class WorkQueueTest : TestBase() {
(0 until size).forEach { queue.add(task(it))?.let { t -> offload.addLast(t) } }

val expectedResult = listOf(129L) + (0L..126L).toList()
val actualResult = queue.drain()
val actualResult = queue.drain(ObjectRef())
assertEquals(expectedResult, actualResult)
assertEquals((0L until size).toSet().minus(expectedResult), offload.drain().toSet())
assertEquals((0L until size).toSet().minus(expectedResult.toSet()), offload.drain().toSet())
}

@Test
Expand All @@ -61,23 +62,28 @@ class WorkQueueTest : TestBase() {
timeSource.step(3)

val stealer = WorkQueue()
assertEquals(TASK_STOLEN, stealer.tryStealFrom(victim))
assertEquals(arrayListOf(1L), stealer.drain())
val ref = ObjectRef<Task?>()
assertEquals(TASK_STOLEN, stealer.tryStealFrom(victim, ref))
assertEquals(arrayListOf(1L), stealer.drain(ref))

assertEquals(TASK_STOLEN, stealer.tryStealFrom(victim))
assertEquals(arrayListOf(2L), stealer.drain())
assertEquals(TASK_STOLEN, stealer.tryStealFrom(victim, ref))
assertEquals(arrayListOf(2L), stealer.drain(ref))
}
}

internal fun task(n: Long) = TaskImpl(Runnable {}, n, NonBlockingContext)

internal fun WorkQueue.drain(): List<Long> {
internal fun WorkQueue.drain(ref: ObjectRef<Task?>): List<Long> {
var task: Task? = poll()
val result = arrayListOf<Long>()
while (task != null) {
result += task.submissionTime
task = poll()
}
if (ref.element != null) {
result += ref.element!!.submissionTime
ref.element = null
}
return result
}

Expand Down

0 comments on commit 20daaa7

Please sign in to comment.