Skip to content

Commit 47f405c

Browse files
committed
initial commit
1 parent 4ba9c6c commit 47f405c

File tree

4 files changed

+50
-7
lines changed

4 files changed

+50
-7
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,10 @@ trait CodegenSupport extends SparkPlan {
7777
*/
7878
final def produce(ctx: CodegenContext, parent: CodegenSupport): String = executeQuery {
7979
this.parent = parent
80+
81+
// to track the existence of apply() call in the current produce-consume cycle
82+
// if apply is not called (e.g. in aggregation), we can skip shouldStop in the inner-most loop
83+
parent.shouldStopRequired = false
8084
ctx.freshNamePrefix = variablePrefix
8185
s"""
8286
|${ctx.registerComment(s"PRODUCE: ${this.simpleString}")}
@@ -206,6 +210,15 @@ trait CodegenSupport extends SparkPlan {
206210
def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
207211
throw new UnsupportedOperationException
208212
}
213+
214+
/* for optimization */
215+
var shouldStopRequired: Boolean = false
216+
217+
def isShouldStopRequired: Boolean = {
218+
if (shouldStopRequired) return true
219+
if (this.parent != null) return this.parent.isShouldStopRequired
220+
false
221+
}
209222
}
210223

211224

@@ -418,6 +431,7 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co
418431
} else {
419432
""
420433
}
434+
shouldStopRequired = true
421435
s"""
422436
|${row.code}
423437
|append(${row.value}$doCopy);

sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -387,8 +387,8 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
387387
// How many values should be generated in the next batch.
388388
val nextBatchTodo = ctx.freshName("nextBatchTodo")
389389

390-
// The default size of a batch.
391-
val batchSize = 1000L
390+
// The default size of a batch, which must be a positive integer
391+
val batchSize = 1000
392392

393393
ctx.addNewFunction("initRange",
394394
s"""
@@ -434,6 +434,17 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
434434
val input = ctx.freshName("input")
435435
// Right now, Range is only used when there is one upstream.
436436
ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];")
437+
438+
val localIdx = ctx.freshName("localIdx")
439+
val localEnd = ctx.freshName("localEnd")
440+
val range = ctx.freshName("range")
441+
// we need to place consume() before calling isShouldStopRequired
442+
val body = consume(ctx, Seq(ev))
443+
val shouldStop = if (isShouldStopRequired) {
444+
s"if (shouldStop()) { $number = $value + ${step}L; return; }"
445+
} else {
446+
"// shouldStop check is eliminated"
447+
}
437448
s"""
438449
| // initialize Range
439450
| if (!$initTerm) {
@@ -442,11 +453,15 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
442453
| }
443454
|
444455
| while (true) {
445-
| while ($number != $batchEnd) {
446-
| long $value = $number;
447-
| $number += ${step}L;
448-
| ${consume(ctx, Seq(ev))}
449-
| if (shouldStop()) return;
456+
| long $range = $batchEnd - $number;
457+
| if ($range != 0L) {
458+
| int $localEnd = (int)($range / ${step}L);
459+
| for (int $localIdx = 0; $localIdx < $localEnd; $localIdx++) {
460+
| long $value = ((long)$localIdx * ${step}L) + $number;
461+
| $body
462+
| $shouldStop
463+
| }
464+
| $number = $batchEnd;
450465
| }
451466
|
452467
| if ($taskContext.isInterrupted()) {

sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ trait BaseLimitExec extends UnaryExecNode with CodegenSupport {
6969
override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
7070
val stopEarly = ctx.freshName("stopEarly")
7171
ctx.addMutableState("boolean", stopEarly, s"$stopEarly = false;")
72+
shouldStopRequired = true // the loop may break early even without an append in the loop body
7273

7374
ctx.addNewFunction("stopEarly", s"""
7475
@Override

sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,19 @@ class DataFrameRangeSuite extends QueryTest with SharedSQLContext with Eventuall
8989
val n = 9L * 1000 * 1000 * 1000 * 1000 * 1000 * 1000
9090
val res13 = spark.range(-n, n, n / 9).select("id")
9191
assert(res13.count == 18)
92+
93+
// range with a non-aggregation operation
94+
val res14 = spark.range(0, 100, 2).toDF.filter("50 <= id")
95+
res14.collect
96+
assert(res14.count == 25)
97+
98+
val res15 = spark.range(100, -100, -2).toDF.filter("id <= 0")
99+
res15.collect
100+
assert(res15.count == 50)
101+
102+
val res16 = spark.range(-1500, 1500, 3).toDF.filter("0 <= id")
103+
res16.collect
104+
assert(res16.count == 500)
92105
}
93106

94107
test("Range with randomized parameters") {

0 commit comments

Comments
 (0)