Skip to content

Commit

Permalink
MultiplexFlow
Browse files Browse the repository at this point in the history
This is useful when the source allows only a single subscription, but
the data is needed by multiple users.

Fixing Kotlin#4103.
  • Loading branch information
odedniv committed Apr 17, 2024
1 parent f8d1821 commit 3362f89
Show file tree
Hide file tree
Showing 2 changed files with 479 additions and 0 deletions.
186 changes: 186 additions & 0 deletions kotlinx-coroutines-core/common/src/flow/operators/MultiplexFlow.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
package kotlinx.coroutines.flow

import kotlinx.coroutines.*
import kotlinx.coroutines.channels.*
import kotlin.coroutines.*
import kotlinx.coroutines.sync.*

/**
* Constructs a [MultiplexFlow].
*
* Behavior:
* * [getAll] is called every time the total keys collected by flows returned by [MultiplexFlow.get] changes (when collection is started or stopped).
* * [getAll] is called with the total keys of all collected [MultiplexFlow.get] flows.
* * [MultiplexFlow.get] calls share the data between them, such that [getAll] is not invoked when all the keys provided to [MultiplexFlow.get] are already collected by another [MultiplexFlow.get] caller.
* If [replay] is 0, this rule does not apply and [getAll] is re-invoked for every change in collections.
* * Errors in calls to [getAll] trigger a rollback to the previous keys, and collections of all [MultiplexFlow.get] with one of the new keys will throw that error.
* * Follow-up [getAll] error, or an error after the [getAll] collection already succeeded, will clear all subscriptions and cause all [MultiplexFlow.get] collections to throw that error.
* * If the flow returned by [getAll] finishes, all current collections of [MultiplexFlow.get] finish as well, and follow-up collections will re-invoke [getAll].
*/
public fun <K, V> MultiplexFlow(
scope: CoroutineScope,
replay: Int = 1,
extraBufferCapacity: Int = 0,
onBufferOverflow: BufferOverflow = BufferOverflow.SUSPEND,
getAll: suspend (keys: Set<K>) -> Flow<Map<K, V>>,
): MultiplexFlow<K, V> {
return MultiplexFlow(
Multiplexer(
getAll, replay = replay, extraBufferCapacity = extraBufferCapacity, onBufferOverflow = onBufferOverflow
).launchIn(scope)
)
}

/**
* Allows multiplexing multiple subscriptions to a single [Flow].
*
* This is useful when the source allows only a single subscription, but the data is needed by multiple users.
*/
public class MultiplexFlow<K, V> internal constructor(private val multiplexer: Multiplexer<K, V>) {
/** Returns a [Flow] that emits [V] for the requested [K]s, based on the map provided by `getAll`. */
public operator fun get(vararg keys: K): Flow<V> = flow {
val subscriptions = multiplexer.incrementUsage(*keys)
try {
subscriptions.filterKeys { it in keys }.values.map { it.data }.merge().collectWhile {
when (it) {
is Multiplexer.Value -> emit(it.value)
is Multiplexer.Error -> throw it.error
is Multiplexer.Finish -> return@collectWhile false
}
true
}
} finally {
multiplexer.decrementUsage(*keys)
}
}
}

/** Internal implementation that multiplexes the data to [MultiplexFlow]. */
internal class Multiplexer<K, V>(
private val getAll: suspend (keys: Set<K>) -> Flow<Map<K, V>>,
private val replay: Int,
private val extraBufferCapacity: Int,
private val onBufferOverflow: BufferOverflow,
) {
/** Current collected flows in [MultiplexFlow.get]. */
internal val subscriptions = MutableStateFlow(mapOf<K, DataAndUsers<V>>())

/** Last [subscriptions] keys, to know what changed. */
private var lastUsedKeys = setOf<K>()

/** Last [getAll] flow processor, so we can replace it with another. */
private var lastFlowsProcessor: Job? = null

/** Must only be called exactly once. */
internal fun launchIn(scope: CoroutineScope): Multiplexer<K, V> = also {
scope.launch {
try {
subscriptions.collect { current ->
val usedKeys = current.usedKeys()
if (replay > 0 && usedKeys == lastUsedKeys) return@collect
lastFlowsProcessor?.cancel()
for ((data, _) in current.values) data.resetReplayCache()
if (usedKeys.isEmpty()) {
lastUsedKeys = usedKeys
return@collect
}
val flow = tryGetAll(usedKeys) ?: return@collect
lastUsedKeys = usedKeys
// Getting succeeded, processing the flow.
lastFlowsProcessor = launch { processFlow(usedKeys, flow) }
}
} finally {
lastFlowsProcessor?.cancel()
for ((data, _) in subscriptions.value.values) data.emit(Finish())
}
}
}

internal suspend fun incrementUsage(vararg keys: K) = subscriptions.updateAndGet { previous ->
previous + keys.associateWith {
(previous[it] ?: DataAndUsers(mutableSharedFlow())) + currentCoroutineContext()
}
}

private fun mutableSharedFlow() = MutableSharedFlow<Emitted<V>>(
replay = replay,
extraBufferCapacity = extraBufferCapacity,
onBufferOverflow = onBufferOverflow,
)

internal suspend fun decrementUsage(vararg keys: K) {
subscriptions.update { previous ->
previous + keys.associateWith { previous[it]!! - currentCoroutineContext() }
}
}

/** Tries [getAll], rolling back and returning `null` on failure. */
private suspend fun tryGetAll(keys: Set<K>): Flow<Map<K, V>>? = try {
getAll(keys)
} catch (e: CancellationException) {
throw e
} catch (e: Throwable) {
// Failed to get, rolling back.
rollbackSubscriptions(current = keys, target = lastUsedKeys, cause = e)
lastUsedKeys = if (lastUsedKeys.isEmpty()) {
keys // Forcing a change to clear the subscription.
} else {
setOf() // Prevent infinite retries.
}
null
}

/** Processes the flow returned by [getAll], updating [DataAndUsers.data] of each entry. */
private suspend fun processFlow(keys: Set<K>, flow: Flow<Map<K, V>>) {
try {
flow.collect { allValues: Map<K, V> ->
for ((key, value) in allValues) {
if (key !in keys) continue // Ignoring keys that weren't subscribed.
subscriptions.value[key]!!.data.emit(Value(value))
}
}
} catch (e: CancellationException) {
throw e
} catch (e: Throwable) {
// Failed to collect, cancelling everything.
rollbackSubscriptions(current = keys, target = setOf(), cause = e)
return
}
// Collection finished normally, emitting Finish.
for ((data, _) in subscriptions.value.values) data.emit(Finish())
}

/** Rollbacks to [target] by removing the extras from [subscriptions] and setting the [DataAndUsers.data] of the removed keys to the error provided in the [cause]. */
private suspend fun rollbackSubscriptions(current: Set<K>, target: Set<K>, cause: Throwable) {
val toRemove = current - target
// Clearing the users of the keys to remove and emitting an error.
subscriptions.update { previous ->
previous.mapValues { (key: K, dataAndUsers: DataAndUsers<V>) ->
if (key in toRemove) {
dataAndUsers.copy(users = setOf()).also { it.data.emit(Error(cause)) }
} else {
dataAndUsers
}
}
}
}

private fun Map<K, DataAndUsers<V>>.usedKeys(): Set<K> = filterValues { it.users.isNotEmpty() }.keys

internal data class DataAndUsers<V>(
val data: MutableSharedFlow<Emitted<V>>,
val users: Set<CoroutineContext> = setOf(),
) {
operator fun plus(user: CoroutineContext) = copy(data = data, users = users + user)

operator fun minus(user: CoroutineContext) = copy(data = data, users = users - user)
}

internal sealed interface Emitted<V>

internal data class Value<V>(val value: V) : Emitted<V>

internal data class Error<V>(val error: Throwable) : Emitted<V>

internal class Finish<V> : Emitted<V>
}

0 comments on commit 3362f89

Please sign in to comment.