@@ -131,7 +131,7 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils {
131131 s " AsyncState # $state, next = $nextState"
132132 }
133133
134- class AsyncStateWithIf (stats : List [c.Tree ], state : Int )
134+ class AsyncStateWithoutAwait (stats : List [c.Tree ], state : Int )
135135 extends AsyncState (stats, state, 0 ) {
136136 // nextState unused, since encoded in then and else branches
137137
@@ -387,17 +387,50 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils {
387387 this += If (cond,
388388 Block (mkStateTree(thenState), Apply (Ident (" resume" ), List ())),
389389 Block (mkStateTree(elseState), Apply (Ident (" resume" ), List ())))
390- new AsyncStateWithIf (stats.toList, state) {
390+ new AsyncStateWithoutAwait (stats.toList, state) {
391391 override val varDefs = self.varDefs.toList
392392 }
393393 }
394-
394+
395+ /**
396+ * Build `AsyncState` ending with a match expression.
397+ *
398+ * The cases of the match simply resume at the state of their corresponding right-hand side.
399+ *
400+ * @param scrutTree tree of the scrutinee
401+ * @param cases list of case definitions
402+ * @param stateFirstCase state of the right-hand side of the first case
403+ * @param perCaseBudget maximum number of states per case
404+ * @return an `AsyncState` representing the match expression
405+ */
406+ def resultWithMatch (scrutTree : c.Tree , cases : List [CaseDef ], stateFirstCase : Int , perCasebudget : Int ): AsyncState = {
407+ // 1. build list of changed cases
408+ val newCases = for ((cas, num) <- cases.zipWithIndex) yield cas match {
409+ case CaseDef (pat, guard, rhs) => CaseDef (pat, guard, Block (mkStateTree(num * perCasebudget + stateFirstCase), Apply (Ident (" resume" ), List ())))
410+ }
411+ // 2. insert changed match tree at the end of the current state
412+ this += Match (c.resetAllAttrs(scrutTree.duplicate), newCases)
413+ new AsyncStateWithoutAwait (stats.toList, state) {
414+ override val varDefs = self.varDefs.toList
415+ }
416+ }
417+
395418 override def toString : String = {
396419 val statsBeforeAwait = stats.mkString(" \n " )
397420 s " ASYNC STATE: \n $statsBeforeAwait \n awaitable: $awaitable \n result name: $resultName"
398421 }
399422 }
400423
424+ /**
425+ * An `AsyncBlockBuilder` builds a `ListBuffer[AsyncState]` based on the expressions of a `Block(stats, expr)` (see `Async.asyncImpl`).
426+ *
427+ * @param stats a list of expressions
428+ * @param expr the last expression of the block
429+ * @param startState the start state
430+ * @param endState the state to continue with
431+ * @param budget the maximum number of states in this block
432+ * @param toRename a `Map` for renaming the given key symbols to the mangled value names
433+ */
401434 class AsyncBlockBuilder (stats : List [c.Tree ], expr : c.Tree , startState : Int , endState : Int ,
402435 budget : Int , private var toRename : Map [c.Symbol , c.Name ]) {
403436 val asyncStates = ListBuffer [builder.AsyncState ]()
@@ -413,7 +446,15 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils {
413446 case Apply (fun, _) if fun.symbol == awaitMethod => true
414447 case _ => false
415448 }) throw new FallbackToCpsException
416-
449+
450+ def builderForBranch (tree : c.Tree , state : Int , nextState : Int , budget : Int , nameMap : Map [c.Symbol , c.Name ]): AsyncBlockBuilder = {
451+ val (branchStats, branchExpr) = tree match {
452+ case Block (s, e) => (s, e)
453+ case _ => (List (tree), Literal (Constant (())))
454+ }
455+ new AsyncBlockBuilder (branchStats, branchExpr, state, nextState, budget, nameMap)
456+ }
457+
417458 // populate asyncStates
418459 for (stat <- stats) stat match {
419460 // the val name = await(..) pattern
@@ -450,29 +491,44 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils {
450491 asyncStates +=
451492 // the two Int arguments are the start state of the then branch and the else branch, respectively
452493 stateBuilder.resultWithIf(cond, currState + 1 , currState + thenBudget)
453-
454- val thenBuilder = thenp match {
455- case Block (thenStats, thenExpr) =>
456- new AsyncBlockBuilder (thenStats, thenExpr, currState + 1 , currState + ifBudget, thenBudget, toRename)
457- case _ =>
458- new AsyncBlockBuilder (List (thenp), Literal (Constant (())), currState + 1 , currState + ifBudget, thenBudget, toRename)
494+
495+ List ((thenp, currState + 1 , thenBudget), (elsep, currState + thenBudget, elseBudget)) foreach { case (tree, state, branchBudget) =>
496+ val builder = builderForBranch(tree, state, currState + ifBudget, branchBudget, toRename)
497+ asyncStates ++= builder.asyncStates
498+ toRename ++= builder.toRename
459499 }
460- asyncStates ++= thenBuilder.asyncStates
461- toRename ++= thenBuilder.toRename
462-
463- val elseBuilder = elsep match {
464- case Block (elseStats, elseExpr) =>
465- new AsyncBlockBuilder (elseStats, elseExpr, currState + thenBudget, currState + ifBudget, elseBudget, toRename)
466- case _ =>
467- new AsyncBlockBuilder (List (elsep), Literal (Constant (())), currState + thenBudget, currState + ifBudget, elseBudget, toRename)
468- }
469- asyncStates ++= elseBuilder.asyncStates
470- toRename ++= elseBuilder.toRename
471-
500+
472501 // create new state builder for state `currState + ifBudget`
473502 currState = currState + ifBudget
474503 stateBuilder = new builder.AsyncStateBuilder (currState, toRename)
475-
504+
505+ case Match (scrutinee, cases) =>
506+ vprintln(" transforming match expr: " + stat)
507+ checkForUnsupportedAwait(scrutinee)
508+
509+ val matchBudget : Int = remainingBudget / 2
510+ remainingBudget -= matchBudget // TODO test if budget > 0
511+ // state that we continue with after match: currState + matchBudget
512+
513+ val perCaseBudget : Int = matchBudget / cases.size
514+ asyncStates +=
515+ // the two Int arguments are the start state of the first case and the per-case state budget, respectively
516+ stateBuilder.resultWithMatch(scrutinee, cases, currState + 1 , perCaseBudget)
517+
518+ for ((cas, num) <- cases.zipWithIndex) {
519+ val (casStats, casExpr) = cas match {
520+ case CaseDef (_, _, Block (s, e)) => (s, e)
521+ case CaseDef (_, _, rhs) => (List (rhs), Literal (Constant (())))
522+ }
523+ val builder = new AsyncBlockBuilder (casStats, casExpr, currState + (num * perCaseBudget) + 1 , currState + matchBudget, perCaseBudget, toRename)
524+ asyncStates ++= builder.asyncStates
525+ toRename ++= builder.toRename
526+ }
527+
528+ // create new state builder for state `currState + matchBudget`
529+ currState = currState + matchBudget
530+ stateBuilder = new builder.AsyncStateBuilder (currState, toRename)
531+
476532 case _ =>
477533 checkForUnsupportedAwait(stat)
478534 stateBuilder += stat
0 commit comments