Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

A prototype of ThreadContextElement on k/js and k/native #3325

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
77 changes: 77 additions & 0 deletions kotlinx-coroutines-core/common/src/ThreadContextElement.common.kt
@@ -0,0 +1,77 @@
/*
* Copyright 2016-2022 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
*/

package kotlinx.coroutines

import kotlin.coroutines.*

/**
* Defines elements in [CoroutineContext] that are installed into thread context
* every time the coroutine with this element in the context is resumed on a thread.
*
* Implementations of this interface define a type [S] of the thread-local state that they need to store on
* resume of a coroutine and restore later on suspend. The infrastructure provides the corresponding storage.
*
* Example usage looks like this:
*
* ```
* // Appends "name" of a coroutine to a current thread name when coroutine is executed
* class CoroutineName(val name: String) : ThreadContextElement<String> {
* // declare companion object for a key of this element in coroutine context
* companion object Key : CoroutineContext.Key<CoroutineName>
*
* // provide the key of the corresponding context element
* override val key: CoroutineContext.Key<CoroutineName>
* get() = Key
*
* // this is invoked before coroutine is resumed on current thread
* override fun updateThreadContext(context: CoroutineContext): String {
* val previousName = Thread.currentThread().name
* Thread.currentThread().name = "$previousName # $name"
* return previousName
* }
*
* // this is invoked after coroutine has suspended on current thread
* override fun restoreThreadContext(context: CoroutineContext, oldState: String) {
* Thread.currentThread().name = oldState
* }
* }
*
* // Usage
* launch(Dispatchers.Main + CoroutineName("Progress bar coroutine")) { ... }
* ```
*
* Every time this coroutine is resumed on a thread, UI thread name is updated to
* "UI thread original name # Progress bar coroutine" and the thread name is restored to the original one when
* this coroutine suspends.
*
* To use [ThreadLocal] variable within the coroutine use [ThreadLocal.asContextElement][asContextElement] function.
*/
public interface ThreadContextElement<S> : CoroutineContext.Element {
/**
* Updates context of the current thread.
* This function is invoked before the coroutine in the specified [context] is resumed in the current thread
* when the context of the coroutine this element.
* The result of this function is the old value of the thread-local state that will be passed to [restoreThreadContext].
* This method should handle its own exceptions and do not rethrow it. Thrown exceptions will leave coroutine which
* context is updated in an undefined state and may crash an application.
*
* @param context the coroutine context.
*/
public fun updateThreadContext(context: CoroutineContext): S

/**
* Restores context of the current thread.
* This function is invoked after the coroutine in the specified [context] is suspended in the current thread
* if [updateThreadContext] was previously invoked on resume of this coroutine.
* The value of [oldState] is the result of the previous invocation of [updateThreadContext] and it should
* be restored in the thread-local state by this function.
* This method should handle its own exceptions and do not rethrow it. Thrown exceptions will leave coroutine which
* context is updated in an undefined state and may crash an application.
*
* @param context the coroutine context.
* @param oldState the value returned by the previous invocation of [updateThreadContext].
*/
public fun restoreThreadContext(context: CoroutineContext, oldState: S)
}
Expand Up @@ -4,6 +4,97 @@

package kotlinx.coroutines.internal

import kotlinx.coroutines.*
import kotlin.coroutines.*
import kotlin.jvm.*

internal expect fun threadContextElements(context: CoroutineContext): Any

@JvmField
internal val NO_THREAD_ELEMENTS = Symbol("NO_THREAD_ELEMENTS")

// Used when there are >= 2 active elements in the context
@Suppress("UNCHECKED_CAST")
private class ThreadState(@JvmField val context: CoroutineContext, n: Int) {
private val values = arrayOfNulls<Any>(n)
private val elements = arrayOfNulls<ThreadContextElement<Any?>>(n)
private var i = 0

fun append(element: ThreadContextElement<*>, value: Any?) {
values[i] = value
elements[i++] = element as ThreadContextElement<Any?>
}

fun restore(context: CoroutineContext) {
for (i in elements.indices.reversed()) {
elements[i]!!.restoreThreadContext(context, values[i])
}
}
}

// Counts ThreadContextElements in the context
// Any? here is Int | ThreadContextElement (when count is one)
internal val countAll =
fun (countOrElement: Any?, element: CoroutineContext.Element): Any? {
if (element is ThreadContextElement<*>) {
val inCount = countOrElement as? Int ?: 1
return if (inCount == 0) element else inCount + 1
}
return countOrElement
}

// Find one (first) ThreadContextElement in the context, it is used when we know there is exactly one
private val findOne =
fun (found: ThreadContextElement<*>?, element: CoroutineContext.Element): ThreadContextElement<*>? {
if (found != null) return found
return element as? ThreadContextElement<*>
}

// Updates state for ThreadContextElements in the context using the given ThreadState
private val updateState =
fun (state: ThreadState, element: CoroutineContext.Element): ThreadState {
if (element is ThreadContextElement<*>) {
state.append(element, element.updateThreadContext(state.context))
}
return state
}

// countOrElement is pre-cached in dispatched continuation
// returns NO_THREAD_ELEMENTS if the contest does not have any ThreadContextElements
internal fun updateThreadContext(context: CoroutineContext, countOrElement: Any?): Any? {
@Suppress("NAME_SHADOWING")
val countOrElement = countOrElement ?: threadContextElements(context)
@Suppress("IMPLICIT_BOXING_IN_IDENTITY_EQUALS")
return when {
countOrElement === 0 -> NO_THREAD_ELEMENTS // very fast path when there are no active ThreadContextElements
// ^^^ identity comparison for speed, we know zero always has the same identity
countOrElement is Int -> {
// slow path for multiple active ThreadContextElements, allocates ThreadState for multiple old values
context.fold(ThreadState(context, countOrElement), updateState)
}
else -> {
// fast path for one ThreadContextElement (no allocations, no additional context scan)
@Suppress("UNCHECKED_CAST")
val element = countOrElement as ThreadContextElement<Any?>
element.updateThreadContext(context)
}
}
}

internal fun restoreThreadContext(context: CoroutineContext, oldState: Any?) {
when {
oldState === NO_THREAD_ELEMENTS -> return // very fast path when there are no ThreadContextElements
oldState is ThreadState -> {
// slow path with multiple stored ThreadContextElements
oldState.restore(context)
}
else -> {
// fast path for one ThreadContextElement, but need to find it
@Suppress("UNCHECKED_CAST")
val element = context.fold(null, findOne) as ThreadContextElement<Any?>
element.restoreThreadContext(context, oldState)
}
}
}
//internal expect fun threadContextElements(context: CoroutineContext): Any

internal fun threadContextElements(context: CoroutineContext): Any = context.fold(0, countAll)!!
127 changes: 127 additions & 0 deletions kotlinx-coroutines-core/common/test/ThreadContextElementTest.common.kt
@@ -0,0 +1,127 @@
/*
* Copyright 2016-2022 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
*/

package kotlinx.coroutines

import kotlin.coroutines.*
import kotlin.test.*

class ThreadContextElementCommonTest : TestBase() {

interface TestThreadContextElement : ThreadContextElement<Int> {
companion object Key : CoroutineContext.Key<TestThreadContextElement>
}

@Test
fun updatesAndRestores() = runTest {
expect(1)
var update = 0
var restore = 0
val threadContextElement = object : TestThreadContextElement {
override fun updateThreadContext(context: CoroutineContext): Int {
update++
return 0
}

override fun restoreThreadContext(context: CoroutineContext, oldState: Int) {
restore++
}

override val key: CoroutineContext.Key<*>
get() = TestThreadContextElement.Key
}
launch(Dispatchers.Unconfined + threadContextElement) {
assertEquals(1, update)
assertEquals(0, restore)
}
assertEquals(1, update)
assertEquals(1, restore)
finish(2)
}

class TestThreadContextIntElement(
val update: () -> Int,
val restore: (Int) -> Unit
) : TestThreadContextElement {
override val key: CoroutineContext.Key<*>
get() = TestThreadContextElement.Key

override fun updateThreadContext(context: CoroutineContext): Int {
return update()
}

override fun restoreThreadContext(context: CoroutineContext, oldState: Int) {
restore(oldState)
}
}

@Test
fun twoCoroutinesUpdateAndRestore() = runTest {
expect(1)
var state = 0

var updateA = 0
var restoreA = 0
var updateB = 0
var restoreB = 0

val lock = Job()
println("Launch A")
val jobA = launch(Dispatchers.Unconfined + TestThreadContextIntElement(
update = {
updateA++
state = 10; 0
},
restore = {
restoreA++
state = it
}
)) {
println("A started")
assertEquals(1, updateA)
assertEquals(10, state)
println("A lock reached")
lock.join()
assertEquals(1, restoreA)
assertEquals(1, updateB)
assertEquals(1, restoreB)
assertEquals(2, updateA)
println("A resumed")
assertEquals(10, state)
println("A completes")
}

println("Launch B")
launch(Dispatchers.Unconfined + TestThreadContextIntElement(
update = {
updateB++
state = 20; 0
},
restore = {
restoreB++
state = it
}
)) {
println("B started")
assertEquals(1, updateB)
assertEquals(20, state)
println("B lock complete")
lock.complete()
println("B wait join A")
jobA.join()
assertEquals(2, updateB)
assertEquals(1, restoreB)
assertEquals(2, updateA)
assertEquals(2, restoreA)
println("B resumed")
assertEquals(20, state)
println("B completes")
}

println("All complete")
assertEquals(0, state)

finish(2)
}
}