Skip to content

Commit 5ff56c5

Browse files
committed
Merge pull request #5 from phaller/wip-match
Support await inside match expressions
2 parents e5e7b9a + 11fe7a0 commit 5ff56c5

File tree

2 files changed

+152
-23
lines changed

2 files changed

+152
-23
lines changed

src/main/scala/scala/async/ExprBuilder.scala

Lines changed: 79 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils {
131131
s"AsyncState #$state, next = $nextState"
132132
}
133133

134-
class AsyncStateWithIf(stats: List[c.Tree], state: Int)
134+
class AsyncStateWithoutAwait(stats: List[c.Tree], state: Int)
135135
extends AsyncState(stats, state, 0) {
136136
// nextState unused, since encoded in then and else branches
137137

@@ -387,17 +387,50 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils {
387387
this += If(cond,
388388
Block(mkStateTree(thenState), Apply(Ident("resume"), List())),
389389
Block(mkStateTree(elseState), Apply(Ident("resume"), List())))
390-
new AsyncStateWithIf(stats.toList, state) {
390+
new AsyncStateWithoutAwait(stats.toList, state) {
391391
override val varDefs = self.varDefs.toList
392392
}
393393
}
394-
394+
395+
/**
396+
* Build `AsyncState` ending with a match expression.
397+
*
398+
* The cases of the match simply resume at the state of their corresponding right-hand side.
399+
*
400+
* @param scrutTree tree of the scrutinee
401+
* @param cases list of case definitions
402+
* @param stateFirstCase state of the right-hand side of the first case
403+
* @param perCaseBudget maximum number of states per case
404+
* @return an `AsyncState` representing the match expression
405+
*/
406+
def resultWithMatch(scrutTree: c.Tree, cases: List[CaseDef], stateFirstCase: Int, perCasebudget: Int): AsyncState = {
407+
// 1. build list of changed cases
408+
val newCases = for ((cas, num) <- cases.zipWithIndex) yield cas match {
409+
case CaseDef(pat, guard, rhs) => CaseDef(pat, guard, Block(mkStateTree(num * perCasebudget + stateFirstCase), Apply(Ident("resume"), List())))
410+
}
411+
// 2. insert changed match tree at the end of the current state
412+
this += Match(c.resetAllAttrs(scrutTree.duplicate), newCases)
413+
new AsyncStateWithoutAwait(stats.toList, state) {
414+
override val varDefs = self.varDefs.toList
415+
}
416+
}
417+
395418
override def toString: String = {
396419
val statsBeforeAwait = stats.mkString("\n")
397420
s"ASYNC STATE:\n$statsBeforeAwait \nawaitable: $awaitable \nresult name: $resultName"
398421
}
399422
}
400423

424+
/**
425+
* An `AsyncBlockBuilder` builds a `ListBuffer[AsyncState]` based on the expressions of a `Block(stats, expr)` (see `Async.asyncImpl`).
426+
*
427+
* @param stats a list of expressions
428+
* @param expr the last expression of the block
429+
* @param startState the start state
430+
* @param endState the state to continue with
431+
* @param budget the maximum number of states in this block
432+
* @param toRename a `Map` for renaming the given key symbols to the mangled value names
433+
*/
401434
class AsyncBlockBuilder(stats: List[c.Tree], expr: c.Tree, startState: Int, endState: Int,
402435
budget: Int, private var toRename: Map[c.Symbol, c.Name]) {
403436
val asyncStates = ListBuffer[builder.AsyncState]()
@@ -413,7 +446,15 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils {
413446
case Apply(fun, _) if fun.symbol == awaitMethod => true
414447
case _ => false
415448
}) throw new FallbackToCpsException
416-
449+
450+
def builderForBranch(tree: c.Tree, state: Int, nextState: Int, budget: Int, nameMap: Map[c.Symbol, c.Name]): AsyncBlockBuilder = {
451+
val (branchStats, branchExpr) = tree match {
452+
case Block(s, e) => (s, e)
453+
case _ => (List(tree), Literal(Constant(())))
454+
}
455+
new AsyncBlockBuilder(branchStats, branchExpr, state, nextState, budget, nameMap)
456+
}
457+
417458
// populate asyncStates
418459
for (stat <- stats) stat match {
419460
// the val name = await(..) pattern
@@ -450,29 +491,44 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils {
450491
asyncStates +=
451492
// the two Int arguments are the start state of the then branch and the else branch, respectively
452493
stateBuilder.resultWithIf(cond, currState + 1, currState + thenBudget)
453-
454-
val thenBuilder = thenp match {
455-
case Block(thenStats, thenExpr) =>
456-
new AsyncBlockBuilder(thenStats, thenExpr, currState + 1, currState + ifBudget, thenBudget, toRename)
457-
case _ =>
458-
new AsyncBlockBuilder(List(thenp), Literal(Constant(())), currState + 1, currState + ifBudget, thenBudget, toRename)
494+
495+
List((thenp, currState + 1, thenBudget), (elsep, currState + thenBudget, elseBudget)) foreach { case (tree, state, branchBudget) =>
496+
val builder = builderForBranch(tree, state, currState + ifBudget, branchBudget, toRename)
497+
asyncStates ++= builder.asyncStates
498+
toRename ++= builder.toRename
459499
}
460-
asyncStates ++= thenBuilder.asyncStates
461-
toRename ++= thenBuilder.toRename
462-
463-
val elseBuilder = elsep match {
464-
case Block(elseStats, elseExpr) =>
465-
new AsyncBlockBuilder(elseStats, elseExpr, currState + thenBudget, currState + ifBudget, elseBudget, toRename)
466-
case _ =>
467-
new AsyncBlockBuilder(List(elsep), Literal(Constant(())), currState + thenBudget, currState + ifBudget, elseBudget, toRename)
468-
}
469-
asyncStates ++= elseBuilder.asyncStates
470-
toRename ++= elseBuilder.toRename
471-
500+
472501
// create new state builder for state `currState + ifBudget`
473502
currState = currState + ifBudget
474503
stateBuilder = new builder.AsyncStateBuilder(currState, toRename)
475-
504+
505+
case Match(scrutinee, cases) =>
506+
vprintln("transforming match expr: " + stat)
507+
checkForUnsupportedAwait(scrutinee)
508+
509+
val matchBudget: Int = remainingBudget / 2
510+
remainingBudget -= matchBudget //TODO test if budget > 0
511+
// state that we continue with after match: currState + matchBudget
512+
513+
val perCaseBudget: Int = matchBudget / cases.size
514+
asyncStates +=
515+
// the two Int arguments are the start state of the first case and the per-case state budget, respectively
516+
stateBuilder.resultWithMatch(scrutinee, cases, currState + 1, perCaseBudget)
517+
518+
for ((cas, num) <- cases.zipWithIndex) {
519+
val (casStats, casExpr) = cas match {
520+
case CaseDef(_, _, Block(s, e)) => (s, e)
521+
case CaseDef(_, _, rhs) => (List(rhs), Literal(Constant(())))
522+
}
523+
val builder = new AsyncBlockBuilder(casStats, casExpr, currState + (num * perCaseBudget) + 1, currState + matchBudget, perCaseBudget, toRename)
524+
asyncStates ++= builder.asyncStates
525+
toRename ++= builder.toRename
526+
}
527+
528+
// create new state builder for state `currState + matchBudget`
529+
currState = currState + matchBudget
530+
stateBuilder = new builder.AsyncStateBuilder(currState, toRename)
531+
476532
case _ =>
477533
checkForUnsupportedAwait(stat)
478534
stateBuilder += stat
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
/**
2+
* Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com>
3+
*/
4+
5+
package scala.async
6+
package run
7+
package match0
8+
9+
import language.{reflectiveCalls, postfixOps}
10+
import scala.concurrent.{Future, ExecutionContext, future, Await}
11+
import scala.concurrent.duration._
12+
import scala.async.Async.{async, await}
13+
import org.junit.runner.RunWith
14+
import org.junit.runners.JUnit4
15+
import org.junit.Test
16+
17+
18+
class TestMatchClass {
19+
20+
import ExecutionContext.Implicits.global
21+
22+
def m1(x: Int): Future[Int] = future {
23+
Thread.sleep(1000)
24+
x + 2
25+
}
26+
27+
def m2(y: Int): Future[Int] = async {
28+
val f = m1(y)
29+
var z = 0
30+
y match {
31+
case 10 =>
32+
val x1 = await(f)
33+
z = x1 + 2
34+
case 20 =>
35+
val x2 = await(f)
36+
z = x2 - 2
37+
}
38+
z
39+
}
40+
41+
def m3(y: Int): Future[Int] = async {
42+
val f = m1(y)
43+
var z = 0
44+
y match {
45+
case 0 =>
46+
val x2 = await(f)
47+
z = x2 - 2
48+
case 1 =>
49+
val x1 = await(f)
50+
z = x1 + 2
51+
}
52+
z
53+
}
54+
}
55+
56+
57+
@RunWith(classOf[JUnit4])
58+
class MatchSpec {
59+
60+
@Test def `support await in a simple match expression`() {
61+
val o = new TestMatchClass
62+
val fut = o.m2(10) // matches first case
63+
val res = Await.result(fut, 2 seconds)
64+
res mustBe (14)
65+
}
66+
67+
@Test def `support await in a simple match expression 2`() {
68+
val o = new TestMatchClass
69+
val fut = o.m3(1) // matches second case
70+
val res = Await.result(fut, 2 seconds)
71+
res mustBe (5)
72+
}
73+
}

0 commit comments

Comments
 (0)