diff --git a/kotlinx-coroutines-core/common/src/channels/Produce.kt b/kotlinx-coroutines-core/common/src/channels/Produce.kt index 3342fb6ec9..b7eefa78bc 100644 --- a/kotlinx-coroutines-core/common/src/channels/Produce.kt +++ b/kotlinx-coroutines-core/common/src/channels/Produce.kt @@ -137,7 +137,7 @@ internal fun CoroutineScope.produce( return coroutine } -internal open class ProducerCoroutine( +private class ProducerCoroutine( parentContext: CoroutineContext, channel: Channel ) : ChannelCoroutine(parentContext, channel, true, active = true), ProducerScope { override val isActive: Boolean diff --git a/kotlinx-coroutines-core/common/src/flow/internal/FlowCoroutine.kt b/kotlinx-coroutines-core/common/src/flow/internal/FlowCoroutine.kt index b395525620..9a81eefa2d 100644 --- a/kotlinx-coroutines-core/common/src/flow/internal/FlowCoroutine.kt +++ b/kotlinx-coroutines-core/common/src/flow/internal/FlowCoroutine.kt @@ -51,33 +51,11 @@ internal fun scopedFlow(@BuilderInference block: suspend CoroutineScope.(Flo flowScope { block(this@flow) } } -internal fun CoroutineScope.flowProduce( - context: CoroutineContext, - capacity: Int = 0, - @BuilderInference block: suspend ProducerScope.() -> Unit -): ReceiveChannel { - val channel = Channel(capacity) - val newContext = newCoroutineContext(context) - val coroutine = FlowProduceCoroutine(newContext, channel) - coroutine.start(CoroutineStart.ATOMIC, coroutine, block) - return coroutine -} - private class FlowCoroutine( context: CoroutineContext, uCont: Continuation ) : ScopeCoroutine(context, uCont) { - public override fun childCancelled(cause: Throwable): Boolean { - if (cause is ChildCancelledException) return true - return cancelImpl(cause) - } -} - -private class FlowProduceCoroutine( - parentContext: CoroutineContext, - channel: Channel -) : ProducerCoroutine(parentContext, channel) { - public override fun childCancelled(cause: Throwable): Boolean { + override fun childCancelled(cause: Throwable): Boolean { if (cause is ChildCancelledException) return true return cancelImpl(cause) } diff --git a/kotlinx-coroutines-core/common/src/flow/internal/Merge.kt b/kotlinx-coroutines-core/common/src/flow/internal/Merge.kt index 9eca8aa0c2..c18adba3b7 100644 --- a/kotlinx-coroutines-core/common/src/flow/internal/Merge.kt +++ b/kotlinx-coroutines-core/common/src/flow/internal/Merge.kt @@ -22,7 +22,7 @@ internal class ChannelFlowTransformLatest( override suspend fun flowCollect(collector: FlowCollector) { assert { collector is SendingCollector } // So cancellation behaviour is not leaking into the downstream - flowScope { + coroutineScope { var previousFlow: Job? = null flow.collect { value -> previousFlow?.apply { @@ -49,7 +49,7 @@ internal class ChannelFlowMerge( ChannelFlowMerge(flow, concurrency, context, capacity, onBufferOverflow) override fun produceImpl(scope: CoroutineScope): ReceiveChannel { - return scope.flowProduce(context, capacity, block = collectToFun) + return scope.produce(context, capacity, block = collectToFun) } override suspend fun collectTo(scope: ProducerScope) { @@ -87,7 +87,7 @@ internal class ChannelLimitedFlowMerge( ChannelLimitedFlowMerge(flows, context, capacity, onBufferOverflow) override fun produceImpl(scope: CoroutineScope): ReceiveChannel { - return scope.flowProduce(context, capacity, block = collectToFun) + return scope.produce(context, capacity, block = collectToFun) } override suspend fun collectTo(scope: ProducerScope) { diff --git a/kotlinx-coroutines-core/common/src/flow/operators/Merge.kt b/kotlinx-coroutines-core/common/src/flow/operators/Merge.kt index 432160f340..84005a4a31 100644 --- a/kotlinx-coroutines-core/common/src/flow/operators/Merge.kt +++ b/kotlinx-coroutines-core/common/src/flow/operators/Merge.kt @@ -61,7 +61,7 @@ public fun Flow.flatMapConcat(transform: suspend (value: T) -> Flow * its concurrent merging so that only one properly configured channel is used for execution of merging logic. * * @param concurrency controls the number of in-flight flows, at most [concurrency] flows are collected - * at the same time. By default it is equal to [DEFAULT_CONCURRENCY]. + * at the same time. By default, it is equal to [DEFAULT_CONCURRENCY]. */ @FlowPreview public fun Flow.flatMapMerge( diff --git a/kotlinx-coroutines-core/common/test/flow/operators/FlatMapMergeFastPathTest.kt b/kotlinx-coroutines-core/common/test/flow/operators/FlatMapMergeFastPathTest.kt index a92189c45c..f810221848 100644 --- a/kotlinx-coroutines-core/common/test/flow/operators/FlatMapMergeFastPathTest.kt +++ b/kotlinx-coroutines-core/common/test/flow/operators/FlatMapMergeFastPathTest.kt @@ -39,19 +39,14 @@ class FlatMapMergeFastPathTest : FlatMapMergeBaseTest() { @Test fun testCancellationExceptionDownstream() = runTest { - val flow = flow { - emit(1) - hang { expect(2) } - }.flatMapMerge { + val flow = flowOf(1, 2, 3).flatMapMerge { flow { emit(it) - expect(1) throw CancellationException("") } }.buffer(64) - assertFailsWith(flow) - finish(3) + assertEquals(listOf(1, 2, 3), flow.toList()) } @Test diff --git a/kotlinx-coroutines-core/common/test/flow/operators/FlatMapMergeTest.kt b/kotlinx-coroutines-core/common/test/flow/operators/FlatMapMergeTest.kt index 7470289ece..c2ce346d9b 100644 --- a/kotlinx-coroutines-core/common/test/flow/operators/FlatMapMergeTest.kt +++ b/kotlinx-coroutines-core/common/test/flow/operators/FlatMapMergeTest.kt @@ -69,19 +69,14 @@ class FlatMapMergeTest : FlatMapMergeBaseTest() { @Test fun testCancellationExceptionDownstream() = runTest { - val flow = flow { - emit(1) - hang { expect(2) } - }.flatMapMerge { + val flow = flowOf(1, 2, 3).flatMapMerge { flow { emit(it) - expect(1) throw CancellationException("") } } - assertFailsWith(flow) - finish(3) + assertEquals(listOf(1, 2, 3), flow.toList()) } @Test diff --git a/kotlinx-coroutines-core/common/test/flow/operators/FlattenConcatTest.kt b/kotlinx-coroutines-core/common/test/flow/operators/FlattenConcatTest.kt index 084af5b9bb..4ec7cc3cd1 100644 --- a/kotlinx-coroutines-core/common/test/flow/operators/FlattenConcatTest.kt +++ b/kotlinx-coroutines-core/common/test/flow/operators/FlattenConcatTest.kt @@ -36,4 +36,17 @@ class FlattenConcatTest : FlatMapBaseTest() { consumer.cancelAndJoin() finish(2) } + + @Test + fun testCancellation() = runTest { + val flow = flow { + repeat(5) { + emit(flow { + if (it == 2) throw CancellationException("") + emit(1) + }) + } + } + assertFailsWith(flow.flattenConcat()) + } } diff --git a/kotlinx-coroutines-core/common/test/flow/operators/MergeTest.kt b/kotlinx-coroutines-core/common/test/flow/operators/MergeTest.kt index 1248188554..f084798487 100644 --- a/kotlinx-coroutines-core/common/test/flow/operators/MergeTest.kt +++ b/kotlinx-coroutines-core/common/test/flow/operators/MergeTest.kt @@ -45,6 +45,64 @@ abstract class MergeTest : TestBase() { assertEquals(listOf("source"), result) } + @Test + fun testOneSourceCancelled() = runTest { + val flow = flow { + expect(1) + emit(1) + expect(2) + yield() + throw CancellationException("") + } + + val otherFlow = flow { + repeat(5) { + emit(1) + yield() + } + + expect(3) + } + + val result = listOf(flow, otherFlow).merge().toList() + assertEquals(MutableList(6) { 1 }, result) + finish(4) + } + + @Test + fun testOneSourceCancelledNonFused() = runTest { + val flow = flow { + expect(1) + emit(1) + expect(2) + yield() + throw CancellationException("") + } + + val otherFlow = flow { + repeat(5) { + emit(1) + yield() + } + + expect(3) + } + + val result = listOf(flow, otherFlow).nonFuseableMerge().toList() + assertEquals(MutableList(6) { 1 }, result) + finish(4) + } + + private fun Iterable>.nonFuseableMerge(): Flow { + return channelFlow { + forEach { flow -> + launch { + flow.collect { send(it) } + } + } + } + } + @Test fun testIsolatedContext() = runTest { val flow = flow {