Skip to content

Commit

Permalink
Polymorphic AsyncAwait implementation
Browse files Browse the repository at this point in the history
This preliminary work adds an async/await implementation based off
the now built-in mechanism in the Scala 2 compiler.

The grittiest details of the implementation are borrowed from :

* scala/scala#8816
* https://github.com/retronym/monad-ui/tree/master/src/main/scala/monadui
* https://github.com/scala/scala-async

Due to the reliance on Dispatcher#unsafeRunSync, the implementation
currently only works on JVM.

Error propagation / cancellation seems to behave as it should.

NB : it is worth noting that despite it compiling, using this with
OptionT/EitherT/IorT is currently unsafe, for two reasons:

* what seems to be a bug in the MonadCancel instance tied to those
Error-able types: See https://gitter.im/typelevel/cats-effect-dev?at=60818cf4ae90f3684098c042
* The fact that calling `unsafeRunSync` is `F[A] => A`, which obviously
doesn't work for types that have an error channel that isn't accounted
for by the CE typeclasses.
  • Loading branch information
Baccata committed Apr 22, 2021
1 parent df6861d commit dd7aa87
Show file tree
Hide file tree
Showing 3 changed files with 310 additions and 9 deletions.
23 changes: 14 additions & 9 deletions build.sbt
Expand Up @@ -222,13 +222,12 @@ lazy val kernel = crossProject(JSPlatform, JVMPlatform)
libraryDependencies += "org.specs2" %%% "specs2-core" % Specs2Version % Test)
.settings(dottyLibrarySettings)
.settings(libraryDependencies += "org.typelevel" %%% "cats-core" % CatsVersion)
.jsSettings(
Compile / doc / sources := {
if (isDotty.value)
Seq()
else
(Compile / doc / sources).value
})
.jsSettings(Compile / doc / sources := {
if (isDotty.value)
Seq()
else
(Compile / doc / sources).value
})

/**
* Reference implementations (including a pure ConcurrentBracket), generic ScalaCheck
Expand Down Expand Up @@ -303,7 +302,8 @@ lazy val tests = crossProject(JSPlatform, JVMPlatform)
name := "cats-effect-tests",
libraryDependencies ++= Seq(
"org.typelevel" %%% "discipline-specs2" % DisciplineVersion % Test,
"org.typelevel" %%% "cats-kernel-laws" % CatsVersion % Test)
"org.typelevel" %%% "cats-kernel-laws" % CatsVersion % Test),
scalacOptions ++= List("-Xasync")
)
.jvmSettings(
Test / fork := true,
Expand All @@ -329,7 +329,12 @@ lazy val std = crossProject(JSPlatform, JVMPlatform)
else
"org.specs2" %%% "specs2-scalacheck" % Specs2Version % Test
},
libraryDependencies += "org.scalacheck" %%% "scalacheck" % ScalaCheckVersion % Test
libraryDependencies += "org.scalacheck" %%% "scalacheck" % ScalaCheckVersion % Test,
libraryDependencies ++= {
if (!isDotty.value)
Seq("org.scala-lang" % "scala-reflect" % scalaVersion.value % "provided")
else Seq()
}
)

/**
Expand Down
175 changes: 175 additions & 0 deletions std/jvm/src/main/scala-2/AsyncAwait.scala
@@ -0,0 +1,175 @@
/*
* Copyright 2020-2021 Typelevel
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package cats.effect.std

import scala.annotation.compileTimeOnly
import scala.reflect.macros.whitebox

import cats.effect.std.Dispatcher
import cats.effect.kernel.Outcome
import cats.effect.kernel.Sync
import cats.effect.kernel.Async
import cats.effect.kernel.syntax.all._

class AsyncAwaitDsl[F[_]](implicit F: Async[F]) {

/**
* Type member used by the macro expansion to recover what `F` is without typetags
*/
type _AsyncContext[A] = F[A]

/**
* Value member used by the macro expansion to recover the Async instance associated to the block.
*/
implicit val _AsyncInstance: Async[F] = F

/**
* Non-blocking await the on result of `awaitable`. This may only be used directly within an enclosing `async` block.
*
* Internally, this will register the remainder of the code in enclosing `async` block as a callback
* in the `onComplete` handler of `awaitable`, and will *not* block a thread.
*/
@compileTimeOnly("[async] `await` must be enclosed in an `async` block")
def await[T](awaitable: F[T]): T =
??? // No implementation here, as calls to this are translated to `onComplete` by the macro.

/**
* Run the block of code `body` asynchronously. `body` may contain calls to `await` when the results of
* a `Future` are needed; this is translated into non-blocking code.
*/
def async[T](body: => T): F[T] = macro AsyncAwaitDsl.asyncImpl[F, T]

}

object AsyncAwaitDsl {

type Callback = Either[Throwable, AnyRef] => Unit

def asyncImpl[F[_], T](
c: whitebox.Context
)(body: c.Tree): c.Tree = {
import c.universe._
if (!c.compilerSettings.contains("-Xasync")) {
c.abort(
c.macroApplication.pos,
"The async requires the compiler option -Xasync (supported only by Scala 2.12.12+ / 2.13.3+)"
)
} else
try {
val awaitSym = typeOf[AsyncAwaitDsl[Any]].decl(TermName("await"))
def mark(t: DefDef): Tree = {
import language.reflectiveCalls
c.internal
.asInstanceOf[{
def markForAsyncTransform(
owner: Symbol,
method: DefDef,
awaitSymbol: Symbol,
config: Map[String, AnyRef]
): DefDef
}]
.markForAsyncTransform(
c.internal.enclosingOwner,
t,
awaitSym,
Map.empty
)
}
val name = TypeName("stateMachine$async")
// format: off
q"""
final class $name(dispatcher: _root_.cats.effect.std.Dispatcher[${c.prefix}._AsyncContext], callback: _root_.cats.effect.std.AsyncAwaitDsl.Callback) extends _root_.cats.effect.std.AsyncAwaitStateMachine(dispatcher, callback) {
${mark(q"""override def apply(tr$$async: _root_.cats.effect.kernel.Outcome[${c.prefix}._AsyncContext, _root_.scala.Throwable, _root_.scala.AnyRef]): _root_.scala.Unit = ${body}""")}
}
${c.prefix}._AsyncInstance.recoverWith {
_root_.cats.effect.std.Dispatcher[${c.prefix}._AsyncContext].use { dispatcher =>
${c.prefix}._AsyncInstance.async_[_root_.scala.AnyRef](cb => new $name(dispatcher, cb).start())
}
}{
case _root_.cats.effect.std.AsyncAwaitDsl.CancelBridge =>
${c.prefix}._AsyncInstance.map(${c.prefix}._AsyncInstance.canceled)(_ => null.asInstanceOf[AnyRef])
}.asInstanceOf[${c.macroApplication.tpe}]
"""
} catch {
case e: ReflectiveOperationException =>
c.abort(
c.macroApplication.pos,
"-Xasync is provided as a Scala compiler option, but the async macro is unable to call c.internal.markForAsyncTransform. " + e.getClass.getName + " " + e.getMessage
)
}
}

// A marker exception to communicate cancellation through the async runtime.
object CancelBridge extends Throwable with scala.util.control.NoStackTrace
}

abstract class AsyncAwaitStateMachine[F[_]](
dispatcher: Dispatcher[F],
callback: AsyncAwaitDsl.Callback
)(implicit F: Sync[F]) extends Function1[Outcome[F, Throwable, AnyRef], Unit] {

// FSM translated method
//def apply(v1: Outcome[IO, Throwable, AnyRef]): Unit = ???

private[this] var state$async: Int = 0

/** Retrieve the current value of the state variable */
protected def state: Int = state$async

/** Assign `i` to the state variable */
protected def state_=(s: Int): Unit = state$async = s

protected def completeFailure(t: Throwable): Unit =
callback(Left(t))

protected def completeSuccess(value: AnyRef): Unit = {
callback(Right(value))
}

protected def onComplete(f: F[AnyRef]): Unit = {
dispatcher.unsafeRunAndForget(f.guaranteeCase(outcome => F.delay(this(outcome))))
}

protected def getCompleted(f: F[AnyRef]): Outcome[F, Throwable, AnyRef] = {
val _ = f
null
}

protected def tryGet(tr: Outcome[F, Throwable, AnyRef]): AnyRef =
tr match {
case Outcome.Succeeded(value) =>
// TODO discuss how to propagate "errors"" from other
// error channels than the Async's, such as None
// in OptionT. Maybe some ad-hoc polymorphic construct
// with a custom path-dependent "bridge" exception type...
// ... or something
dispatcher.unsafeRunSync(value)
case Outcome.Errored(e) =>
callback(Left(e))
this // sentinel value to indicate the dispatch loop should exit.
case Outcome.Canceled() =>
callback(Left(AsyncAwaitDsl.CancelBridge))
this
}

def start(): Unit = {
// Required to kickstart the async state machine.
// `def apply` does not consult its argument when `state == 0`.
apply(null)
}

}
121 changes: 121 additions & 0 deletions tests/jvm/src/test/scala-2/cats/effect/std/AsyncAwaitSpec.scala
@@ -0,0 +1,121 @@
/*
* Copyright 2020-2021 Typelevel
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package cats.effect
package std

import scala.concurrent.duration._
import cats.syntax.all._
import cats.data.Kleisli

class AsyncAwaitSpec extends BaseSpec {

"IOAsyncAwait" should {
object IOAsyncAwait extends cats.effect.std.AsyncAwaitDsl[IO]
import IOAsyncAwait.{await => ioAwait, _}

"work on success" in real {

val io = IO.sleep(100.millis) >> IO.pure(1)

val program = async(ioAwait(io) + ioAwait(io))

program.flatMap { res =>
IO {
res must beEqualTo(2)
}
}
}

"propagate errors outward" in real {

case object Boom extends Throwable
val io = IO.raiseError[Int](Boom)

val program = async(ioAwait(io))

program.attempt.flatMap { res =>
IO {
res must beEqualTo(Left(Boom))
}
}
}

"propagate canceled outcomes outward" in real {

val io = IO.canceled

val program = async(ioAwait(io))

program.start.flatMap(_.join).flatMap { res =>
IO {
res must beEqualTo(Outcome.canceled[IO, Throwable, Unit])
}
}
}

"be cancellable" in real {

val program = for {
ref <- Ref[IO].of(0)
_ <- async { ioAwait(IO.sleep(100.millis) *> ref.update(_ + 1)) }
.start
.flatMap(_.cancel)
_ <- IO.sleep(200.millis)
result <- ref.get
} yield {
result
}

program.flatMap { res =>
IO {
res must beEqualTo(0)
}
}

}

"suspend side effects" in real {
var x = 0
val program = async(x += 1)

for {
before <- IO(x must beEqualTo(0))
_ <- program
after <- IO(x must beEqualTo(1))
} yield before && after
}
}

"KleisliAsyncAwait" should {
type F[A] = Kleisli[IO, Int, A]
object KleisliAsyncAwait extends cats.effect.std.AsyncAwaitDsl[F]
import KleisliAsyncAwait.{await => kAwait, _}

"work on successes" in real {
val io = Temporal[F].sleep(100.millis) >> Kleisli(x => IO.pure(x + 1))

val program = async(kAwait(io) + kAwait(io))

program.run(0).flatMap { res =>
IO {
res must beEqualTo(2)
}
}
}
}

}

0 comments on commit dd7aa87

Please sign in to comment.