Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes recover inconsistency with raise DSL on types other than Either #3052

Merged
merged 11 commits into from
Jun 1, 2023
Merged
6 changes: 6 additions & 0 deletions arrow-libs/core/arrow-core/api/arrow-core.api
Original file line number Diff line number Diff line change
Expand Up @@ -3283,9 +3283,12 @@ public final class arrow/core/raise/IorRaise : arrow/core/raise/Raise {
public final fun bindAllIor (Ljava/util/Map;)Ljava/util/Map;
public final fun bindAllIor (Ljava/util/Set;)Ljava/util/Set;
public fun catch (Larrow/core/continuations/Effect;Lkotlin/jvm/functions/Function3;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
public final fun combine (Ljava/lang/Object;)Ljava/lang/Object;
public final fun getCombineError ()Lkotlin/jvm/functions/Function2;
public fun invoke (Lkotlin/jvm/functions/Function1;)Ljava/lang/Object;
public fun invoke (Lkotlin/jvm/functions/Function2;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
public fun raise (Ljava/lang/Object;)Ljava/lang/Void;
public final fun recover (Lkotlin/jvm/functions/Function1;Lkotlin/jvm/functions/Function1;)Ljava/lang/Object;
public fun shift (Ljava/lang/Object;)Ljava/lang/Object;
}

Expand Down Expand Up @@ -3313,6 +3316,7 @@ public final class arrow/core/raise/NullableRaise : arrow/core/raise/Raise {
public fun invoke (Lkotlin/jvm/functions/Function2;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
public synthetic fun raise (Ljava/lang/Object;)Ljava/lang/Void;
public fun raise (Ljava/lang/Void;)Ljava/lang/Void;
public final fun recover (Lkotlin/jvm/functions/Function1;Lkotlin/jvm/functions/Function0;)Ljava/lang/Object;
public synthetic fun shift (Ljava/lang/Object;)Ljava/lang/Object;
public fun shift (Ljava/lang/Void;)Ljava/lang/Object;
}
Expand Down Expand Up @@ -3342,6 +3346,7 @@ public final class arrow/core/raise/OptionRaise : arrow/core/raise/Raise {
public fun invoke (Lkotlin/jvm/functions/Function2;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
public fun raise (Larrow/core/None;)Ljava/lang/Void;
public synthetic fun raise (Ljava/lang/Object;)Ljava/lang/Void;
public final fun recover (Lkotlin/jvm/functions/Function1;Lkotlin/jvm/functions/Function0;)Ljava/lang/Object;
public fun shift (Larrow/core/None;)Ljava/lang/Object;
public synthetic fun shift (Ljava/lang/Object;)Ljava/lang/Object;
}
Expand Down Expand Up @@ -3506,6 +3511,7 @@ public final class arrow/core/raise/ResultRaise : arrow/core/raise/Raise {
public fun invoke (Lkotlin/jvm/functions/Function2;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
public synthetic fun raise (Ljava/lang/Object;)Ljava/lang/Void;
public fun raise (Ljava/lang/Throwable;)Ljava/lang/Void;
public final fun recover (Lkotlin/jvm/functions/Function1;Lkotlin/jvm/functions/Function1;)Ljava/lang/Object;
public synthetic fun shift (Ljava/lang/Object;)Ljava/lang/Object;
public fun shift (Ljava/lang/Throwable;)Ljava/lang/Object;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ package arrow.core.raise
import arrow.atomic.Atomic
import arrow.atomic.updateAndGet
import arrow.core.Either
import arrow.core.EmptyValue.combine
import arrow.core.Ior
import arrow.core.NonEmptyList
import arrow.core.NonEmptySet
Expand Down Expand Up @@ -72,6 +73,15 @@ public class NullableRaise(private val raise: Raise<Null>) : Raise<Null> by rais
contract { returns() implies (value != null) }
return ensureNotNull(value) { null }
}

@RaiseDSL
public inline fun <A> recover(
@BuilderInference block: NullableRaise.() -> A,
recover: () -> A,
): A = when (val nullable = nullable(block)) {
null -> recover()
else -> nullable
}
}

public class ResultRaise(private val raise: Raise<Throwable>) : Raise<Throwable> by raise {
Expand All @@ -96,6 +106,15 @@ public class ResultRaise(private val raise: Raise<Throwable>) : Raise<Throwable>
@JvmName("bindAllResult")
public fun <A> NonEmptySet<Result<A>>.bindAll(): NonEmptySet<A> =
map { it.bind() }

@RaiseDSL
public inline fun <A> recover(
@BuilderInference block: ResultRaise.() -> A,
recover: (Throwable) -> A,
): A = result(block).fold(
onSuccess = { it },
onFailure = { recover(it) }
)
}

public class OptionRaise(private val raise: Raise<None>) : Raise<None> by raise {
Expand Down Expand Up @@ -129,10 +148,19 @@ public class OptionRaise(private val raise: Raise<None>) : Raise<None> by raise
contract { returns() implies (value != null) }
return ensureNotNull(value) { None }
}

@RaiseDSL
public inline fun <A> recover(
@BuilderInference block: OptionRaise.() -> A,
recover: () -> A,
): A = when (val option = option(block)) {
is None -> recover()
is Some<A> -> option.value
}
}

public class IorRaise<Error> @PublishedApi internal constructor(
private val combineError: (Error, Error) -> Error,
@PublishedApi internal val combineError: (Error, Error) -> Error,
private val state: Atomic<Option<Error>>,
private val raise: Raise<Error>,
) : Raise<Error> {
Expand Down Expand Up @@ -170,8 +198,23 @@ public class IorRaise<Error> @PublishedApi internal constructor(
public fun <K, V> Map<K, Ior<Error, V>>.bindAll(): Map<K, V> =
mapValues { (_, v) -> v.bind() }

private fun combine(other: Error): Error =
@PublishedApi
internal fun combine(other: Error): Error =
state.updateAndGet { prev ->
Some(prev.map { combineError(it, other) }.getOrElse { other })
}.getOrElse { other }

@RaiseDSL
public inline fun <A> recover(
@BuilderInference block: IorRaise<Error>.() -> A,
recover: (error: Error) -> A,
): A = when (val ior = ior(combineError, block)) {
is Ior.Both -> {
combine(ior.leftValue)
ior.rightValue
}

is Ior.Left -> recover(ior.value)
is Ior.Right -> ior.value
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,11 @@ package arrow.core.raise
import arrow.core.Either
import arrow.core.Ior
import arrow.core.test.nonEmptyList
import arrow.typeclasses.Semigroup
import io.kotest.assertions.throwables.shouldThrow
import io.kotest.core.spec.style.StringSpec
import io.kotest.matchers.shouldBe
import io.kotest.property.Arb
import io.kotest.property.arbitrary.filter
import io.kotest.property.arbitrary.int
import io.kotest.property.arbitrary.list
import io.kotest.property.arbitrary.string
import io.kotest.property.checkAll
import kotlinx.coroutines.async
import kotlinx.coroutines.awaitAll
Expand Down Expand Up @@ -79,4 +75,13 @@ class IorSpec : StringSpec({
}
}.message shouldBe "Boom!"
}

"Recover works as expected" {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This highlights the problem in the issue. This test fails without the change above.

ior(String::plus) {
val one = recover({ Ior.Left("Hello").bind() }) { 1 }
val two = Ior.Right(2).bind()
val three = Ior.Both(", World", 3).bind()
one + two + three
} shouldBe Ior.Both(", World", 6)
}
})
Original file line number Diff line number Diff line change
Expand Up @@ -118,4 +118,12 @@ class NullableSpec : StringSpec({
either.bind() + 3
} shouldBe 7
}

"Recover works as expected" {
nullable {
val one: Int = recover({ null.bind<Int>() }) { 1 }
val two = 2.bind()
one + two
} shouldBe 3
}
})
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package arrow.core.raise

import arrow.core.None
import arrow.core.Some
import arrow.core.some
import arrow.core.toOption
import io.kotest.core.spec.style.StringSpec
Expand Down Expand Up @@ -39,4 +40,12 @@ class OptionSpec : StringSpec({
throw IllegalStateException("This should not be executed")
} shouldBe None
}

"Recover works as expected" {
option {
val one: Int = recover({ None.bind<Int>() }) { 1 }
val two = Some(2).bind()
one + two
} shouldBe Some(3)
}
})
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,12 @@ class ResultSpec : StringSpec({
"Result - raise" {
result { raise(boom) } shouldBe Result.failure(boom)
}

"Recover works as expected" {
result {
val one: Int = recover({ Result.failure<Int>(boom).bind() }) { 1 }
val two = Result.success(2).bind()
one + two
} shouldBe Result.success(3)
}
})