forked from typelevel/cats-effect
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Polymorphic AsyncAwait implementation
This preliminary work adds an async/await implementation based off the now built-in mechanism in the Scala 2 compiler. The grittiest details of the implementation are borrowed from : * scala/scala#8816 * https://github.com/retronym/monad-ui/tree/master/src/main/scala/monadui * https://github.com/scala/scala-async Due to the reliance on Dispatcher#unsafeRunSync, the implementation currently only works on JVM. Error propagation / cancellation seems to behave as it should. NB : it is worth noting that despite it compiling, using this with OptionT/EitherT/IorT is currently unsafe, for two reasons: * what seems to be a bug in the MonadCancel instance tied to those Error-able types: See https://gitter.im/typelevel/cats-effect-dev?at=60818cf4ae90f3684098c042 * The fact that calling `unsafeRunSync` is `F[A] => A`, which obviously doesn't work for types that have an error channel that isn't accounted for by the CE typeclasses.
- Loading branch information
Showing
3 changed files
with
310 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,175 @@ | ||
/* | ||
* Copyright 2020-2021 Typelevel | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package cats.effect.std | ||
|
||
import scala.annotation.compileTimeOnly | ||
import scala.reflect.macros.whitebox | ||
|
||
import cats.effect.std.Dispatcher | ||
import cats.effect.kernel.Outcome | ||
import cats.effect.kernel.Sync | ||
import cats.effect.kernel.Async | ||
import cats.effect.kernel.syntax.all._ | ||
|
||
class AsyncAwaitDsl[F[_]](implicit F: Async[F]) { | ||
|
||
/** | ||
* Type member used by the macro expansion to recover what `F` is without typetags | ||
*/ | ||
type _AsyncContext[A] = F[A] | ||
|
||
/** | ||
* Value member used by the macro expansion to recover the Async instance associated to the block. | ||
*/ | ||
implicit val _AsyncInstance: Async[F] = F | ||
|
||
/** | ||
* Non-blocking await the on result of `awaitable`. This may only be used directly within an enclosing `async` block. | ||
* | ||
* Internally, this will register the remainder of the code in enclosing `async` block as a callback | ||
* in the `onComplete` handler of `awaitable`, and will *not* block a thread. | ||
*/ | ||
@compileTimeOnly("[async] `await` must be enclosed in an `async` block") | ||
def await[T](awaitable: F[T]): T = | ||
??? // No implementation here, as calls to this are translated to `onComplete` by the macro. | ||
|
||
/** | ||
* Run the block of code `body` asynchronously. `body` may contain calls to `await` when the results of | ||
* a `Future` are needed; this is translated into non-blocking code. | ||
*/ | ||
def async[T](body: => T): F[T] = macro AsyncAwaitDsl.asyncImpl[F, T] | ||
|
||
} | ||
|
||
object AsyncAwaitDsl { | ||
|
||
type Callback = Either[Throwable, AnyRef] => Unit | ||
|
||
def asyncImpl[F[_], T]( | ||
c: whitebox.Context | ||
)(body: c.Tree): c.Tree = { | ||
import c.universe._ | ||
if (!c.compilerSettings.contains("-Xasync")) { | ||
c.abort( | ||
c.macroApplication.pos, | ||
"The async requires the compiler option -Xasync (supported only by Scala 2.12.12+ / 2.13.3+)" | ||
) | ||
} else | ||
try { | ||
val awaitSym = typeOf[AsyncAwaitDsl[Any]].decl(TermName("await")) | ||
def mark(t: DefDef): Tree = { | ||
import language.reflectiveCalls | ||
c.internal | ||
.asInstanceOf[{ | ||
def markForAsyncTransform( | ||
owner: Symbol, | ||
method: DefDef, | ||
awaitSymbol: Symbol, | ||
config: Map[String, AnyRef] | ||
): DefDef | ||
}] | ||
.markForAsyncTransform( | ||
c.internal.enclosingOwner, | ||
t, | ||
awaitSym, | ||
Map.empty | ||
) | ||
} | ||
val name = TypeName("stateMachine$async") | ||
// format: off | ||
q""" | ||
final class $name(dispatcher: _root_.cats.effect.std.Dispatcher[${c.prefix}._AsyncContext], callback: _root_.cats.effect.std.AsyncAwaitDsl.Callback) extends _root_.cats.effect.std.AsyncAwaitStateMachine(dispatcher, callback) { | ||
${mark(q"""override def apply(tr$$async: _root_.cats.effect.kernel.Outcome[${c.prefix}._AsyncContext, _root_.scala.Throwable, _root_.scala.AnyRef]): _root_.scala.Unit = ${body}""")} | ||
} | ||
${c.prefix}._AsyncInstance.recoverWith { | ||
_root_.cats.effect.std.Dispatcher[${c.prefix}._AsyncContext].use { dispatcher => | ||
${c.prefix}._AsyncInstance.async_[_root_.scala.AnyRef](cb => new $name(dispatcher, cb).start()) | ||
} | ||
}{ | ||
case _root_.cats.effect.std.AsyncAwaitDsl.CancelBridge => | ||
${c.prefix}._AsyncInstance.map(${c.prefix}._AsyncInstance.canceled)(_ => null.asInstanceOf[AnyRef]) | ||
}.asInstanceOf[${c.macroApplication.tpe}] | ||
""" | ||
} catch { | ||
case e: ReflectiveOperationException => | ||
c.abort( | ||
c.macroApplication.pos, | ||
"-Xasync is provided as a Scala compiler option, but the async macro is unable to call c.internal.markForAsyncTransform. " + e.getClass.getName + " " + e.getMessage | ||
) | ||
} | ||
} | ||
|
||
// A marker exception to communicate cancellation through the async runtime. | ||
object CancelBridge extends Throwable with scala.util.control.NoStackTrace | ||
} | ||
|
||
abstract class AsyncAwaitStateMachine[F[_]]( | ||
dispatcher: Dispatcher[F], | ||
callback: AsyncAwaitDsl.Callback | ||
)(implicit F: Sync[F]) extends Function1[Outcome[F, Throwable, AnyRef], Unit] { | ||
|
||
// FSM translated method | ||
//def apply(v1: Outcome[IO, Throwable, AnyRef]): Unit = ??? | ||
|
||
private[this] var state$async: Int = 0 | ||
|
||
/** Retrieve the current value of the state variable */ | ||
protected def state: Int = state$async | ||
|
||
/** Assign `i` to the state variable */ | ||
protected def state_=(s: Int): Unit = state$async = s | ||
|
||
protected def completeFailure(t: Throwable): Unit = | ||
callback(Left(t)) | ||
|
||
protected def completeSuccess(value: AnyRef): Unit = { | ||
callback(Right(value)) | ||
} | ||
|
||
protected def onComplete(f: F[AnyRef]): Unit = { | ||
dispatcher.unsafeRunAndForget(f.guaranteeCase(outcome => F.delay(this(outcome)))) | ||
} | ||
|
||
protected def getCompleted(f: F[AnyRef]): Outcome[F, Throwable, AnyRef] = { | ||
val _ = f | ||
null | ||
} | ||
|
||
protected def tryGet(tr: Outcome[F, Throwable, AnyRef]): AnyRef = | ||
tr match { | ||
case Outcome.Succeeded(value) => | ||
// TODO discuss how to propagate "errors"" from other | ||
// error channels than the Async's, such as None | ||
// in OptionT. Maybe some ad-hoc polymorphic construct | ||
// with a custom path-dependent "bridge" exception type... | ||
// ... or something | ||
dispatcher.unsafeRunSync(value) | ||
case Outcome.Errored(e) => | ||
callback(Left(e)) | ||
this // sentinel value to indicate the dispatch loop should exit. | ||
case Outcome.Canceled() => | ||
callback(Left(AsyncAwaitDsl.CancelBridge)) | ||
this | ||
} | ||
|
||
def start(): Unit = { | ||
// Required to kickstart the async state machine. | ||
// `def apply` does not consult its argument when `state == 0`. | ||
apply(null) | ||
} | ||
|
||
} |
121 changes: 121 additions & 0 deletions
121
tests/jvm/src/test/scala-2/cats/effect/std/AsyncAwaitSpec.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
/* | ||
* Copyright 2020-2021 Typelevel | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package cats.effect | ||
package std | ||
|
||
import scala.concurrent.duration._ | ||
import cats.syntax.all._ | ||
import cats.data.Kleisli | ||
|
||
class AsyncAwaitSpec extends BaseSpec { | ||
|
||
"IOAsyncAwait" should { | ||
object IOAsyncAwait extends cats.effect.std.AsyncAwaitDsl[IO] | ||
import IOAsyncAwait.{await => ioAwait, _} | ||
|
||
"work on success" in real { | ||
|
||
val io = IO.sleep(100.millis) >> IO.pure(1) | ||
|
||
val program = async(ioAwait(io) + ioAwait(io)) | ||
|
||
program.flatMap { res => | ||
IO { | ||
res must beEqualTo(2) | ||
} | ||
} | ||
} | ||
|
||
"propagate errors outward" in real { | ||
|
||
case object Boom extends Throwable | ||
val io = IO.raiseError[Int](Boom) | ||
|
||
val program = async(ioAwait(io)) | ||
|
||
program.attempt.flatMap { res => | ||
IO { | ||
res must beEqualTo(Left(Boom)) | ||
} | ||
} | ||
} | ||
|
||
"propagate canceled outcomes outward" in real { | ||
|
||
val io = IO.canceled | ||
|
||
val program = async(ioAwait(io)) | ||
|
||
program.start.flatMap(_.join).flatMap { res => | ||
IO { | ||
res must beEqualTo(Outcome.canceled[IO, Throwable, Unit]) | ||
} | ||
} | ||
} | ||
|
||
"be cancellable" in real { | ||
|
||
val program = for { | ||
ref <- Ref[IO].of(0) | ||
_ <- async { ioAwait(IO.sleep(100.millis) *> ref.update(_ + 1)) } | ||
.start | ||
.flatMap(_.cancel) | ||
_ <- IO.sleep(200.millis) | ||
result <- ref.get | ||
} yield { | ||
result | ||
} | ||
|
||
program.flatMap { res => | ||
IO { | ||
res must beEqualTo(0) | ||
} | ||
} | ||
|
||
} | ||
|
||
"suspend side effects" in real { | ||
var x = 0 | ||
val program = async(x += 1) | ||
|
||
for { | ||
before <- IO(x must beEqualTo(0)) | ||
_ <- program | ||
after <- IO(x must beEqualTo(1)) | ||
} yield before && after | ||
} | ||
} | ||
|
||
"KleisliAsyncAwait" should { | ||
type F[A] = Kleisli[IO, Int, A] | ||
object KleisliAsyncAwait extends cats.effect.std.AsyncAwaitDsl[F] | ||
import KleisliAsyncAwait.{await => kAwait, _} | ||
|
||
"work on successes" in real { | ||
val io = Temporal[F].sleep(100.millis) >> Kleisli(x => IO.pure(x + 1)) | ||
|
||
val program = async(kAwait(io) + kAwait(io)) | ||
|
||
program.run(0).flatMap { res => | ||
IO { | ||
res must beEqualTo(2) | ||
} | ||
} | ||
} | ||
} | ||
|
||
} |