Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CancellableQueueSynchronizer and ReadWriteMutex #2045

Open
wants to merge 9 commits into
base: develop
Choose a base branch
from
18 changes: 12 additions & 6 deletions benchmarks/src/jmh/kotlin/benchmarks/ChannelSinkBenchmark.kt
Expand Up @@ -24,6 +24,11 @@ open class ChannelSinkBenchmark {
private val unconfinedOneElement = Dispatchers.Unconfined + tl.asContextElement()
private val unconfinedTwoElements = Dispatchers.Unconfined + tl.asContextElement() + tl2.asContextElement()

private val elements = (0 until N).toList()

@Param("0", "1", "8", "32")
var channelCapacity = 0

@Benchmark
fun channelPipeline(): Int = runBlocking {
run(unconfined)
Expand All @@ -41,14 +46,14 @@ open class ChannelSinkBenchmark {

private suspend inline fun run(context: CoroutineContext): Int {
return Channel
.range(1, 10_000, context)
.filter(context) { it % 4 == 0 }
.fold(0) { a, b -> a + b }
.range(context) // should not allocate `Int`s!
.filter(context) { it % 4 == 0 } // should not allocate `Int`s!
.fold(0) { a, b -> if (a % 8 == 0) a else b } // should not allocate `Int`s!
}

private fun Channel.Factory.range(start: Int, count: Int, context: CoroutineContext) = GlobalScope.produce(context) {
for (i in start until (start + count))
send(i)
private fun Channel.Factory.range(context: CoroutineContext) = GlobalScope.produce(context, capacity = channelCapacity) {
for (i in 0 until N)
send(elements[i]) // should not allocate `Int`s!
}

// Migrated from deprecated operators, are good only for stressing channels
Expand All @@ -69,3 +74,4 @@ open class ChannelSinkBenchmark {
}
}

private const val N = 10_000
6 changes: 3 additions & 3 deletions benchmarks/src/jmh/kotlin/benchmarks/SemaphoreBenchmark.kt
Expand Up @@ -48,7 +48,7 @@ open class SemaphoreBenchmark {
val semaphore = Semaphore(_3_maxPermits)
val jobs = ArrayList<Job>(coroutines)
repeat(coroutines) {
jobs += GlobalScope.launch {
jobs += GlobalScope.launch(dispatcher) {
repeat(n) {
semaphore.withPermit {
doGeomDistrWork(WORK_INSIDE)
Expand All @@ -66,7 +66,7 @@ open class SemaphoreBenchmark {
val semaphore = Channel<Unit>(_3_maxPermits)
val jobs = ArrayList<Job>(coroutines)
repeat(coroutines) {
jobs += GlobalScope.launch {
jobs += GlobalScope.launch(dispatcher) {
repeat(n) {
semaphore.send(Unit) // acquire
doGeomDistrWork(WORK_INSIDE)
Expand All @@ -87,4 +87,4 @@ enum class SemaphoreBenchDispatcherCreator(val create: (parallelism: Int) -> Cor

private const val WORK_INSIDE = 50
private const val WORK_OUTSIDE = 50
private const val BATCH_SIZE = 100000
private const val BATCH_SIZE = 1000000
@@ -0,0 +1,43 @@
/*
* Copyright 2016-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
*/

package benchmarks

import kotlinx.coroutines.*
import kotlinx.coroutines.sync.*
import org.openjdk.jmh.annotations.*
import java.util.concurrent.TimeUnit
import kotlin.test.*

@Warmup(iterations = 5, time = 1)
@Measurement(iterations = 10, time = 1)
@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.MILLISECONDS)
@State(Scope.Benchmark)
@Fork(1)
open class SequentialSemaphoreAsMutexBenchmark {
val s = Semaphore(1)

@Benchmark
fun benchmark() : Unit = runBlocking {
val s = Semaphore(permits = 1, acquiredPermits = 1)
var step = 0
launch(Dispatchers.Unconfined) {
repeat(N) {
assertEquals(it * 2, step)
step++
s.acquire()
}
}
repeat(N) {
assertEquals(it * 2 + 1, step)
step++
s.release()
}
}
}

fun main() = SequentialSemaphoreAsMutexBenchmark().benchmark()

private val N = 1_000_000
16 changes: 15 additions & 1 deletion kotlinx-coroutines-core/api/kotlinx-coroutines-core.api
Expand Up @@ -51,7 +51,7 @@ public final class kotlinx/coroutines/CancellableContinuation$DefaultImpls {
public static synthetic fun tryResume$default (Lkotlinx/coroutines/CancellableContinuation;Ljava/lang/Object;Ljava/lang/Object;ILjava/lang/Object;)Ljava/lang/Object;
}

public class kotlinx/coroutines/CancellableContinuationImpl : kotlin/coroutines/jvm/internal/CoroutineStackFrame, kotlinx/coroutines/CancellableContinuation, kotlinx/coroutines/channels/Waiter {
public class kotlinx/coroutines/CancellableContinuationImpl : kotlin/coroutines/jvm/internal/CoroutineStackFrame, kotlinx/coroutines/CancellableContinuation, kotlinx/coroutines/Waiter {
public fun <init> (Lkotlin/coroutines/Continuation;I)V
public final fun callCancelHandler (Lkotlinx/coroutines/CancelHandler;Ljava/lang/Throwable;)V
public final fun callOnCancellation (Lkotlin/jvm/functions/Function1;Ljava/lang/Throwable;)V
Expand All @@ -64,6 +64,7 @@ public class kotlinx/coroutines/CancellableContinuationImpl : kotlin/coroutines/
public fun getStackTraceElement ()Ljava/lang/StackTraceElement;
public fun initCancellability ()V
public fun invokeOnCancellation (Lkotlin/jvm/functions/Function1;)V
public fun invokeOnCancellation (Lkotlinx/coroutines/internal/Segment;I)V
public fun isActive ()Z
public fun isCancelled ()Z
public fun isCompleted ()Z
Expand Down Expand Up @@ -1257,6 +1258,7 @@ public class kotlinx/coroutines/selects/SelectImplementation : kotlinx/coroutine
public fun invoke (Lkotlinx/coroutines/selects/SelectClause1;Lkotlin/jvm/functions/Function2;)V
public fun invoke (Lkotlinx/coroutines/selects/SelectClause2;Ljava/lang/Object;Lkotlin/jvm/functions/Function2;)V
public fun invoke (Lkotlinx/coroutines/selects/SelectClause2;Lkotlin/jvm/functions/Function2;)V
public fun invokeOnCancellation (Lkotlinx/coroutines/internal/Segment;I)V
public fun onTimeout (JLkotlin/jvm/functions/Function1;)V
public fun selectInRegistrationPhase (Ljava/lang/Object;)V
public fun trySelect (Ljava/lang/Object;Ljava/lang/Object;)Z
Expand Down Expand Up @@ -1327,6 +1329,18 @@ public final class kotlinx/coroutines/sync/MutexKt {
public static synthetic fun withLock$default (Lkotlinx/coroutines/sync/Mutex;Ljava/lang/Object;Lkotlin/jvm/functions/Function0;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object;
}

public abstract interface class kotlinx/coroutines/sync/ReadWriteMutex {
public abstract fun getWrite ()Lkotlinx/coroutines/sync/Mutex;
public abstract fun readLock (Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
public abstract fun readUnlock ()V
}

public final class kotlinx/coroutines/sync/ReadWriteMutexKt {
public static final fun ReadWriteMutex ()Lkotlinx/coroutines/sync/ReadWriteMutex;
public static final fun read (Lkotlinx/coroutines/sync/ReadWriteMutex;Lkotlin/jvm/functions/Function0;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
public static final fun write (Lkotlinx/coroutines/sync/ReadWriteMutex;Lkotlin/jvm/functions/Function0;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
}

public abstract interface class kotlinx/coroutines/sync/Semaphore {
public abstract fun acquire (Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
public abstract fun getAvailablePermits ()I
Expand Down
8 changes: 4 additions & 4 deletions kotlinx-coroutines-core/build.gradle
Expand Up @@ -266,8 +266,8 @@ task jvmStressTest(type: Test, dependsOn: compileTestKotlinJvm) {
testLogging.showStandardStreams = true
systemProperty 'kotlinx.coroutines.scheduler.keep.alive.sec', '100000' // any unpark problem hangs test
// Adjust internal algorithmic parameters to increase the testing quality instead of performance.
systemProperty 'kotlinx.coroutines.semaphore.segmentSize', '1'
systemProperty 'kotlinx.coroutines.semaphore.maxSpinCycles', '10'
systemProperty 'kotlinx.coroutines.cqs.segmentSize', '1'
systemProperty 'kotlinx.coroutines.cqs.maxSpinCycles', '10'
systemProperty 'kotlinx.coroutines.bufferedChannel.segmentSize', '2'
systemProperty 'kotlinx.coroutines.bufferedChannel.expandBufferCompletionWaitIterations', '1'
}
Expand Down Expand Up @@ -302,8 +302,8 @@ static void configureJvmForLincheck(task, additional = false) {
'--add-exports', 'java.base/jdk.internal.util=ALL-UNNAMED'] // in the model checking mode
// Adjust internal algorithmic parameters to increase the testing quality instead of performance.
var segmentSize = additional ? '2' : '1'
task.systemProperty 'kotlinx.coroutines.semaphore.segmentSize', segmentSize
task.systemProperty 'kotlinx.coroutines.semaphore.maxSpinCycles', '1' // better for the model checking mode
task.systemProperty 'kotlinx.coroutines.cqs.segmentSize', segmentSize
task.systemProperty 'kotlinx.coroutines.cqs.maxSpinCycles', '1' // better for the model checking mode
task.systemProperty 'kotlinx.coroutines.bufferedChannel.segmentSize', segmentSize
task.systemProperty 'kotlinx.coroutines.bufferedChannel.expandBufferCompletionWaitIterations', '1'
}
Expand Down
105 changes: 82 additions & 23 deletions kotlinx-coroutines-core/common/src/CancellableContinuationImpl.kt
Expand Up @@ -5,7 +5,6 @@
package kotlinx.coroutines

import kotlinx.atomicfu.*
import kotlinx.coroutines.channels.Waiter
import kotlinx.coroutines.internal.*
import kotlin.coroutines.*
import kotlin.coroutines.intrinsics.*
Expand All @@ -15,6 +14,15 @@ private const val UNDECIDED = 0
private const val SUSPENDED = 1
private const val RESUMED = 2

private const val DECISION_SHIFT = 29
private const val INDEX_MASK = (1 shl DECISION_SHIFT) - 1
private const val NO_INDEX = INDEX_MASK

private inline val Int.decision get() = this shr DECISION_SHIFT
private inline val Int.index get() = this and INDEX_MASK
@Suppress("NOTHING_TO_INLINE")
private inline fun decisionAndIndex(decision: Int, index: Int) = (decision shl DECISION_SHIFT) + index

@JvmField
internal val RESUME_TOKEN = Symbol("RESUME_TOKEN")

Expand Down Expand Up @@ -44,7 +52,7 @@ internal open class CancellableContinuationImpl<in T>(
* less dependencies.
*/

/* decision state machine
/** decision state machine

+-----------+ trySuspend +-----------+
| UNDECIDED | -------------> | SUSPENDED |
Expand All @@ -56,9 +64,12 @@ internal open class CancellableContinuationImpl<in T>(
| RESUMED |
+-----------+

Note: both tryResume and trySuspend can be invoked at most once, first invocation wins
Note: both tryResume and trySuspend can be invoked at most once, first invocation wins.
If the cancellation handler is specified via a [Segment] instance and the index in it
(so [Segment.onCancellation] should be called), the [_decisionAndIndex] field may store
this index additionally to the "decision" value.
*/
private val _decision = atomic(UNDECIDED)
private val _decisionAndIndex = atomic(decisionAndIndex(UNDECIDED, NO_INDEX))

/*
=== Internal states ===
Expand Down Expand Up @@ -144,7 +155,7 @@ internal open class CancellableContinuationImpl<in T>(
detachChild()
return false
}
_decision.value = UNDECIDED
_decisionAndIndex.value = decisionAndIndex(UNDECIDED, NO_INDEX)
_state.value = Active
return true
}
Expand Down Expand Up @@ -194,10 +205,13 @@ internal open class CancellableContinuationImpl<in T>(
_state.loop { state ->
if (state !is NotCompleted) return false // false if already complete or cancelling
// Active -- update to final state
val update = CancelledContinuation(this, cause, handled = state is CancelHandler)
val update = CancelledContinuation(this, cause, handled = state is CancelHandler || state is Segment<*>)
if (!_state.compareAndSet(state, update)) return@loop // retry on cas failure
// Invoke cancel handler if it was present
(state as? CancelHandler)?.let { callCancelHandler(it, cause) }
when (state) {
is CancelHandler -> callCancelHandler(state, cause)
is Segment<*> -> callSegmentOnCancellation(state, cause)
}
// Complete state update
detachChildIfNonResuable()
dispatchResume(resumeMode) // no need for additional cancellation checks
Expand Down Expand Up @@ -234,6 +248,12 @@ internal open class CancellableContinuationImpl<in T>(
fun callCancelHandler(handler: CancelHandler, cause: Throwable?) =
callCancelHandlerSafely { handler.invoke(cause) }

private fun callSegmentOnCancellation(segment: Segment<*>, cause: Throwable?) {
val index = _decisionAndIndex.value.index
check(index != NO_INDEX) { "The index for Segment.onCancellation(..) is broken" }
callCancelHandlerSafely { segment.onCancellation(index, cause) }
}

fun callOnCancellation(onCancellation: (cause: Throwable) -> Unit, cause: Throwable) {
try {
onCancellation.invoke(cause)
Expand All @@ -253,19 +273,19 @@ internal open class CancellableContinuationImpl<in T>(
parent.getCancellationException()

private fun trySuspend(): Boolean {
_decision.loop { decision ->
when (decision) {
UNDECIDED -> if (this._decision.compareAndSet(UNDECIDED, SUSPENDED)) return true
_decisionAndIndex.loop { cur ->
when (cur.decision) {
UNDECIDED -> if (this._decisionAndIndex.compareAndSet(cur, decisionAndIndex(SUSPENDED, cur.index))) return true
RESUMED -> return false
else -> error("Already suspended")
}
}
}

private fun tryResume(): Boolean {
_decision.loop { decision ->
when (decision) {
UNDECIDED -> if (this._decision.compareAndSet(UNDECIDED, RESUMED)) return true
_decisionAndIndex.loop { cur ->
when (cur.decision) {
UNDECIDED -> if (this._decisionAndIndex.compareAndSet(cur, decisionAndIndex(RESUMED, cur.index))) return true
SUSPENDED -> return false
else -> error("Already resumed")
}
Expand All @@ -275,7 +295,7 @@ internal open class CancellableContinuationImpl<in T>(
@PublishedApi
internal fun getResult(): Any? {
val isReusable = isReusable()
// trySuspend may fail either if 'block' has resumed/cancelled a continuation
// trySuspend may fail either if 'block' has resumed/cancelled a continuation,
// or we got async cancellation from parent.
if (trySuspend()) {
/*
Expand Down Expand Up @@ -350,14 +370,44 @@ internal open class CancellableContinuationImpl<in T>(
override fun resume(value: T, onCancellation: ((cause: Throwable) -> Unit)?) =
resumeImpl(value, resumeMode, onCancellation)

/**
* An optimized version for the code below that does not allocate
* a cancellation handler object and efficiently stores the specified
* [segment] and [index] in this [CancellableContinuationImpl].
*
* The only difference is that `segment.onCancellation(..)` is never
* called if this continuation is already completed; thus,
* the semantics is similar to [BeforeResumeCancelHandler].
*
* ```
* invokeOnCancellation { cause ->
* segment.onCancellation(index, cause)
* }
* ```
*/
override fun invokeOnCancellation(segment: Segment<*>, index: Int) {
_decisionAndIndex.update {
check(it.index == NO_INDEX) {
"invokeOnCancellation should be called at most once"
}
decisionAndIndex(it.decision, index)
}
invokeOnCancellationImpl(segment)
}

public override fun invokeOnCancellation(handler: CompletionHandler) {
val cancelHandler = makeCancelHandler(handler)
invokeOnCancellationImpl(cancelHandler)
}

private fun invokeOnCancellationImpl(handler: Any) {
assert { handler is CancelHandler || handler is Segment<*> }
_state.loop { state ->
when (state) {
is Active -> {
if (_state.compareAndSet(state, cancelHandler)) return // quit on cas success
if (_state.compareAndSet(state, handler)) return // quit on cas success
}
is CancelHandler -> multipleHandlersError(handler, state)
is CancelHandler, is Segment<*> -> multipleHandlersError(handler, state)
is CompletedExceptionally -> {
/*
* Continuation was already cancelled or completed exceptionally.
Expand All @@ -371,7 +421,13 @@ internal open class CancellableContinuationImpl<in T>(
* because we play type tricks on Kotlin/JS and handler is not necessarily a function there
*/
if (state is CancelledContinuation) {
callCancelHandler(handler, (state as? CompletedExceptionally)?.cause)
val cause: Throwable? = (state as? CompletedExceptionally)?.cause
if (handler is CancelHandler) {
callCancelHandler(handler, cause)
} else {
val segment = handler as Segment<*>
callSegmentOnCancellation(segment, cause)
}
}
return
}
Expand All @@ -380,14 +436,16 @@ internal open class CancellableContinuationImpl<in T>(
* Continuation was already completed, and might already have cancel handler.
*/
if (state.cancelHandler != null) multipleHandlersError(handler, state)
// BeforeResumeCancelHandler does not need to be called on a completed continuation
if (cancelHandler is BeforeResumeCancelHandler) return
// BeforeResumeCancelHandler and Segment.invokeOnCancellation(..)
// do NOT need to be called on completed continuation.
if (handler is BeforeResumeCancelHandler || handler is Segment<*>) return
handler as CancelHandler
if (state.cancelled) {
// Was already cancelled while being dispatched -- invoke the handler directly
callCancelHandler(handler, state.cancelCause)
return
}
val update = state.copy(cancelHandler = cancelHandler)
val update = state.copy(cancelHandler = handler)
if (_state.compareAndSet(state, update)) return // quit on cas success
}
else -> {
Expand All @@ -396,15 +454,16 @@ internal open class CancellableContinuationImpl<in T>(
* Change its state to CompletedContinuation, unless we have BeforeResumeCancelHandler which
* does not need to be called in this case.
*/
if (cancelHandler is BeforeResumeCancelHandler) return
val update = CompletedContinuation(state, cancelHandler = cancelHandler)
if (handler is BeforeResumeCancelHandler || handler is Segment<*>) return
handler as CancelHandler
val update = CompletedContinuation(state, cancelHandler = handler)
if (_state.compareAndSet(state, update)) return // quit on cas success
}
}
}
}

private fun multipleHandlersError(handler: CompletionHandler, state: Any?) {
private fun multipleHandlersError(handler: Any, state: Any?) {
error("It's prohibited to register multiple handlers, tried to register $handler, already has $state")
}

Expand Down