
Conversation

@kiszk (Member) commented Mar 1, 2017

What changes were proposed in this pull request?

This PR improves the performance of operations with range() by changing the Java code generated by Catalyst. It is inspired by a blog article.

This PR changes the generated code in the following two ways:

  1. Replace a while-loop using long instance variables with a for-loop using int local variables.
  2. Suppress generation of the shouldStop() check when it is unnecessary (e.g. when append() is not generated).

These changes feed simpler Java code to the JIT compiler, which makes it easier for the JIT compiler to apply its optimizations. Performance is improved by 7.6x.
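In distilled form (excerpted from the full generated-code listings below; only the hot loop is shown), the change replaces

while (range_number != range_batchEnd) {
  long range_value = range_number;   // induction over a long instance variable
  range_number += 1L;
  // ... aggregate body ...
  if (shouldStop()) return;          // per-row check
}

with

long range_range = range_batchEnd - range_number;
if (range_range != 0L) {
  int range_localEnd = (int)(range_range / 1L);
  for (int range_localIdx = 0; range_localIdx < range_localEnd; range_localIdx++) {
    long range_value = ((long)range_localIdx * 1L) + range_number;  // int local induction variable
    // ... aggregate body ...
    // shouldStop check is eliminated
  }
  range_number = range_batchEnd;
}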

Benchmark program:

// Assumes Spark's internal micro-benchmark utility (org.apache.spark.util.Benchmark
// at the time of this PR) and a SparkSession named sparkSession in scope.
val N = 1 << 29
val iters = 2
val benchmark = new Benchmark("range.count", N * iters)
benchmark.addCase("with this PR") { i =>
  var n = 0
  var len = 0
  while (n < iters) {
    len += sparkSession.range(N).selectExpr("count(id)").collect.length
    n += 1
  }
}
benchmark.run

Performance result without this PR

OpenJDK 64-Bit Server VM 1.8.0_111-8u111-b14-2ubuntu0.16.04.2-b14 on Linux 4.4.0-47-generic
Intel(R) Xeon(R) CPU E5-2667 v3 @ 3.20GHz
range.count:                             Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
w/o this PR                                   1349 / 1356        796.2           1.3       1.0X

Performance result with this PR

OpenJDK 64-Bit Server VM 1.8.0_111-8u111-b14-2ubuntu0.16.04.2-b14 on Linux 4.4.0-47-generic
Intel(R) Xeon(R) CPU E5-2667 v3 @ 3.20GHz
range.count:                             Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
with this PR                                   177 /  271       6065.3           0.2       1.0X

Here is a comparison of the generated code without and with this PR. Only the agg_doAggregateWithoutKey method changes.

Generated code without this PR

/* 005 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 006 */   private Object[] references;
/* 007 */   private scala.collection.Iterator[] inputs;
/* 008 */   private boolean agg_initAgg;
/* 009 */   private boolean agg_bufIsNull;
/* 010 */   private long agg_bufValue;
/* 011 */   private org.apache.spark.sql.execution.metric.SQLMetric range_numOutputRows;
/* 012 */   private org.apache.spark.sql.execution.metric.SQLMetric range_numGeneratedRows;
/* 013 */   private boolean range_initRange;
/* 014 */   private long range_number;
/* 015 */   private TaskContext range_taskContext;
/* 016 */   private InputMetrics range_inputMetrics;
/* 017 */   private long range_batchEnd;
/* 018 */   private long range_numElementsTodo;
/* 019 */   private scala.collection.Iterator range_input;
/* 020 */   private UnsafeRow range_result;
/* 021 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder range_holder;
/* 022 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter range_rowWriter;
/* 023 */   private org.apache.spark.sql.execution.metric.SQLMetric agg_numOutputRows;
/* 024 */   private org.apache.spark.sql.execution.metric.SQLMetric agg_aggTime;
/* 025 */   private UnsafeRow agg_result;
/* 026 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder agg_holder;
/* 027 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter agg_rowWriter;
/* 028 */
/* 029 */   public GeneratedIterator(Object[] references) {
/* 030 */     this.references = references;
/* 031 */   }
/* 032 */
/* 033 */   public void init(int index, scala.collection.Iterator[] inputs) {
/* 034 */     partitionIndex = index;
/* 035 */     this.inputs = inputs;
/* 036 */     agg_initAgg = false;
/* 037 */
/* 038 */     this.range_numOutputRows = (org.apache.spark.sql.execution.metric.SQLMetric) references[0];
/* 039 */     this.range_numGeneratedRows = (org.apache.spark.sql.execution.metric.SQLMetric) references[1];
/* 040 */     range_initRange = false;
/* 041 */     range_number = 0L;
/* 042 */     range_taskContext = TaskContext.get();
/* 043 */     range_inputMetrics = range_taskContext.taskMetrics().inputMetrics();
/* 044 */     range_batchEnd = 0;
/* 045 */     range_numElementsTodo = 0L;
/* 046 */     range_input = inputs[0];
/* 047 */     range_result = new UnsafeRow(1);
/* 048 */     this.range_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(range_result, 0);
/* 049 */     this.range_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(range_holder, 1);
/* 050 */     this.agg_numOutputRows = (org.apache.spark.sql.execution.metric.SQLMetric) references[2];
/* 051 */     this.agg_aggTime = (org.apache.spark.sql.execution.metric.SQLMetric) references[3];
/* 052 */     agg_result = new UnsafeRow(1);
/* 053 */     this.agg_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(agg_result, 0);
/* 054 */     this.agg_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(agg_holder, 1);
/* 055 */
/* 056 */   }
/* 057 */
/* 058 */   private void agg_doAggregateWithoutKey() throws java.io.IOException {
/* 059 */     // initialize aggregation buffer
/* 060 */     agg_bufIsNull = false;
/* 061 */     agg_bufValue = 0L;
/* 062 */
/* 063 */     // initialize Range
/* 064 */     if (!range_initRange) {
/* 065 */       range_initRange = true;
/* 066 */       initRange(partitionIndex);
/* 067 */     }
/* 068 */
/* 069 */     while (true) {
/* 070 */       while (range_number != range_batchEnd) {
/* 071 */         long range_value = range_number;
/* 072 */         range_number += 1L;
/* 073 */
/* 074 */         // do aggregate
/* 075 */         // common sub-expressions
/* 076 */
/* 077 */         // evaluate aggregate function
/* 078 */         boolean agg_isNull1 = false;
/* 079 */
/* 080 */         long agg_value1 = -1L;
/* 081 */         agg_value1 = agg_bufValue + 1L;
/* 082 */         // update aggregation buffer
/* 083 */         agg_bufIsNull = false;
/* 084 */         agg_bufValue = agg_value1;
/* 085 */
/* 086 */         if (shouldStop()) return;
/* 087 */       }
/* 088 */
/* 089 */       if (range_taskContext.isInterrupted()) {
/* 090 */         throw new TaskKilledException();
/* 091 */       }
/* 092 */
/* 093 */       long range_nextBatchTodo;
/* 094 */       if (range_numElementsTodo > 1000L) {
/* 095 */         range_nextBatchTodo = 1000L;
/* 096 */         range_numElementsTodo -= 1000L;
/* 097 */       } else {
/* 098 */         range_nextBatchTodo = range_numElementsTodo;
/* 099 */         range_numElementsTodo = 0;
/* 100 */         if (range_nextBatchTodo == 0) break;
/* 101 */       }
/* 102 */       range_numOutputRows.add(range_nextBatchTodo);
/* 103 */       range_inputMetrics.incRecordsRead(range_nextBatchTodo);
/* 104 */
/* 105 */       range_batchEnd += range_nextBatchTodo * 1L;
/* 106 */     }
/* 107 */
/* 108 */   }
/* 109 */
/* 110 */   private void initRange(int idx) {
/* 111 */     java.math.BigInteger index = java.math.BigInteger.valueOf(idx);
/* 112 */     java.math.BigInteger numSlice = java.math.BigInteger.valueOf(2L);
/* 113 */     java.math.BigInteger numElement = java.math.BigInteger.valueOf(10000L);
/* 114 */     java.math.BigInteger step = java.math.BigInteger.valueOf(1L);
/* 115 */     java.math.BigInteger start = java.math.BigInteger.valueOf(0L);
/* 116 */     long partitionEnd;
/* 117 */
/* 118 */     java.math.BigInteger st = index.multiply(numElement).divide(numSlice).multiply(step).add(start);
/* 119 */     if (st.compareTo(java.math.BigInteger.valueOf(Long.MAX_VALUE)) > 0) {
/* 120 */       range_number = Long.MAX_VALUE;
/* 121 */     } else if (st.compareTo(java.math.BigInteger.valueOf(Long.MIN_VALUE)) < 0) {
/* 122 */       range_number = Long.MIN_VALUE;
/* 123 */     } else {
/* 124 */       range_number = st.longValue();
/* 125 */     }
/* 126 */     range_batchEnd = range_number;
/* 127 */
/* 128 */     java.math.BigInteger end = index.add(java.math.BigInteger.ONE).multiply(numElement).divide(numSlice)
/* 129 */     .multiply(step).add(start);
/* 130 */     if (end.compareTo(java.math.BigInteger.valueOf(Long.MAX_VALUE)) > 0) {
/* 131 */       partitionEnd = Long.MAX_VALUE;
/* 132 */     } else if (end.compareTo(java.math.BigInteger.valueOf(Long.MIN_VALUE)) < 0) {
/* 133 */       partitionEnd = Long.MIN_VALUE;
/* 134 */     } else {
/* 135 */       partitionEnd = end.longValue();
/* 136 */     }
/* 137 */
/* 138 */     java.math.BigInteger startToEnd = java.math.BigInteger.valueOf(partitionEnd).subtract(
/* 139 */       java.math.BigInteger.valueOf(range_number));
/* 140 */     range_numElementsTodo  = startToEnd.divide(step).longValue();
/* 141 */     if (range_numElementsTodo < 0) {
/* 142 */       range_numElementsTodo = 0;
/* 143 */     } else if (startToEnd.remainder(step).compareTo(java.math.BigInteger.valueOf(0L)) != 0) {
/* 144 */       range_numElementsTodo++;
/* 145 */     }
/* 146 */   }
/* 147 */
/* 148 */   protected void processNext() throws java.io.IOException {
/* 149 */     while (!agg_initAgg) {
/* 150 */       agg_initAgg = true;
/* 151 */       long agg_beforeAgg = System.nanoTime();
/* 152 */       agg_doAggregateWithoutKey();
/* 153 */       agg_aggTime.add((System.nanoTime() - agg_beforeAgg) / 1000000);
/* 154 */
/* 155 */       // output the result
/* 156 */
/* 157 */       agg_numOutputRows.add(1);
/* 158 */       agg_rowWriter.zeroOutNullBytes();
/* 159 */
/* 160 */       if (agg_bufIsNull) {
/* 161 */         agg_rowWriter.setNullAt(0);
/* 162 */       } else {
/* 163 */         agg_rowWriter.write(0, agg_bufValue);
/* 164 */       }
/* 165 */       append(agg_result);
/* 166 */     }
/* 167 */   }
/* 168 */ }

Generated code with this PR

/* 005 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 006 */   private Object[] references;
/* 007 */   private scala.collection.Iterator[] inputs;
/* 008 */   private boolean agg_initAgg;
/* 009 */   private boolean agg_bufIsNull;
/* 010 */   private long agg_bufValue;
/* 011 */   private org.apache.spark.sql.execution.metric.SQLMetric range_numOutputRows;
/* 012 */   private org.apache.spark.sql.execution.metric.SQLMetric range_numGeneratedRows;
/* 013 */   private boolean range_initRange;
/* 014 */   private long range_number;
/* 015 */   private TaskContext range_taskContext;
/* 016 */   private InputMetrics range_inputMetrics;
/* 017 */   private long range_batchEnd;
/* 018 */   private long range_numElementsTodo;
/* 019 */   private scala.collection.Iterator range_input;
/* 020 */   private UnsafeRow range_result;
/* 021 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder range_holder;
/* 022 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter range_rowWriter;
/* 023 */   private org.apache.spark.sql.execution.metric.SQLMetric agg_numOutputRows;
/* 024 */   private org.apache.spark.sql.execution.metric.SQLMetric agg_aggTime;
/* 025 */   private UnsafeRow agg_result;
/* 026 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder agg_holder;
/* 027 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter agg_rowWriter;
/* 028 */
/* 029 */   public GeneratedIterator(Object[] references) {
/* 030 */     this.references = references;
/* 031 */   }
/* 032 */
/* 033 */   public void init(int index, scala.collection.Iterator[] inputs) {
/* 034 */     partitionIndex = index;
/* 035 */     this.inputs = inputs;
/* 036 */     agg_initAgg = false;
/* 037 */
/* 038 */     this.range_numOutputRows = (org.apache.spark.sql.execution.metric.SQLMetric) references[0];
/* 039 */     this.range_numGeneratedRows = (org.apache.spark.sql.execution.metric.SQLMetric) references[1];
/* 040 */     range_initRange = false;
/* 041 */     range_number = 0L;
/* 042 */     range_taskContext = TaskContext.get();
/* 043 */     range_inputMetrics = range_taskContext.taskMetrics().inputMetrics();
/* 044 */     range_batchEnd = 0;
/* 045 */     range_numElementsTodo = 0L;
/* 046 */     range_input = inputs[0];
/* 047 */     range_result = new UnsafeRow(1);
/* 048 */     this.range_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(range_result, 0);
/* 049 */     this.range_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(range_holder, 1);
/* 050 */     this.agg_numOutputRows = (org.apache.spark.sql.execution.metric.SQLMetric) references[2];
/* 051 */     this.agg_aggTime = (org.apache.spark.sql.execution.metric.SQLMetric) references[3];
/* 052 */     agg_result = new UnsafeRow(1);
/* 053 */     this.agg_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(agg_result, 0);
/* 054 */     this.agg_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(agg_holder, 1);
/* 055 */
/* 056 */   }
/* 057 */
/* 058 */   private void agg_doAggregateWithoutKey() throws java.io.IOException {
/* 059 */     // initialize aggregation buffer
/* 060 */     agg_bufIsNull = false;
/* 061 */     agg_bufValue = 0L;
/* 062 */
/* 063 */     // initialize Range
/* 064 */     if (!range_initRange) {
/* 065 */       range_initRange = true;
/* 066 */       initRange(partitionIndex);
/* 067 */     }
/* 068 */
/* 069 */     while (true) {
/* 070 */       long range_range = range_batchEnd - range_number;
/* 071 */       if (range_range != 0L) {
/* 072 */         int range_localEnd = (int)(range_range / 1L);
/* 073 */         for (int range_localIdx = 0; range_localIdx < range_localEnd; range_localIdx++) {
/* 074 */           long range_value = ((long)range_localIdx * 1L) + range_number;
/* 075 */
/* 076 */           // do aggregate
/* 077 */           // common sub-expressions
/* 078 */
/* 079 */           // evaluate aggregate function
/* 080 */           boolean agg_isNull1 = false;
/* 081 */
/* 082 */           long agg_value1 = -1L;
/* 083 */           agg_value1 = agg_bufValue + 1L;
/* 084 */           // update aggregation buffer
/* 085 */           agg_bufIsNull = false;
/* 086 */           agg_bufValue = agg_value1;
/* 087 */
/* 088 */           // shouldStop check is eliminated
/* 089 */         }
/* 090 */         range_number = range_batchEnd;
/* 091 */       }
/* 092 */
/* 093 */       if (range_taskContext.isInterrupted()) {
/* 094 */         throw new TaskKilledException();
/* 095 */       }
/* 096 */
/* 097 */       long range_nextBatchTodo;
/* 098 */       if (range_numElementsTodo > 1000L) {
/* 099 */         range_nextBatchTodo = 1000L;
/* 100 */         range_numElementsTodo -= 1000L;
/* 101 */       } else {
/* 102 */         range_nextBatchTodo = range_numElementsTodo;
/* 103 */         range_numElementsTodo = 0;
/* 104 */         if (range_nextBatchTodo == 0) break;
/* 105 */       }
/* 106 */       range_numOutputRows.add(range_nextBatchTodo);
/* 107 */       range_inputMetrics.incRecordsRead(range_nextBatchTodo);
/* 108 */
/* 109 */       range_batchEnd += range_nextBatchTodo * 1L;
/* 110 */     }
/* 111 */
/* 112 */   }
/* 113 */
/* 114 */   private void initRange(int idx) {
/* 115 */     java.math.BigInteger index = java.math.BigInteger.valueOf(idx);
/* 116 */     java.math.BigInteger numSlice = java.math.BigInteger.valueOf(2L);
/* 117 */     java.math.BigInteger numElement = java.math.BigInteger.valueOf(10000L);
/* 118 */     java.math.BigInteger step = java.math.BigInteger.valueOf(1L);
/* 119 */     java.math.BigInteger start = java.math.BigInteger.valueOf(0L);
/* 120 */     long partitionEnd;
/* 121 */
/* 122 */     java.math.BigInteger st = index.multiply(numElement).divide(numSlice).multiply(step).add(start);
/* 123 */     if (st.compareTo(java.math.BigInteger.valueOf(Long.MAX_VALUE)) > 0) {
/* 124 */       range_number = Long.MAX_VALUE;
/* 125 */     } else if (st.compareTo(java.math.BigInteger.valueOf(Long.MIN_VALUE)) < 0) {
/* 126 */       range_number = Long.MIN_VALUE;
/* 127 */     } else {
/* 128 */       range_number = st.longValue();
/* 129 */     }
/* 130 */     range_batchEnd = range_number;
/* 131 */
/* 132 */     java.math.BigInteger end = index.add(java.math.BigInteger.ONE).multiply(numElement).divide(numSlice)
/* 133 */     .multiply(step).add(start);
/* 134 */     if (end.compareTo(java.math.BigInteger.valueOf(Long.MAX_VALUE)) > 0) {
/* 135 */       partitionEnd = Long.MAX_VALUE;
/* 136 */     } else if (end.compareTo(java.math.BigInteger.valueOf(Long.MIN_VALUE)) < 0) {
/* 137 */       partitionEnd = Long.MIN_VALUE;
/* 138 */     } else {
/* 139 */       partitionEnd = end.longValue();
/* 140 */     }
/* 141 */
/* 142 */     java.math.BigInteger startToEnd = java.math.BigInteger.valueOf(partitionEnd).subtract(
/* 143 */       java.math.BigInteger.valueOf(range_number));
/* 144 */     range_numElementsTodo  = startToEnd.divide(step).longValue();
/* 145 */     if (range_numElementsTodo < 0) {
/* 146 */       range_numElementsTodo = 0;
/* 147 */     } else if (startToEnd.remainder(step).compareTo(java.math.BigInteger.valueOf(0L)) != 0) {
/* 148 */       range_numElementsTodo++;
/* 149 */     }
/* 150 */   }
/* 151 */
/* 152 */   protected void processNext() throws java.io.IOException {
/* 153 */     while (!agg_initAgg) {
/* 154 */       agg_initAgg = true;
/* 155 */       long agg_beforeAgg = System.nanoTime();
/* 156 */       agg_doAggregateWithoutKey();
/* 157 */       agg_aggTime.add((System.nanoTime() - agg_beforeAgg) / 1000000);
/* 158 */
/* 159 */       // output the result
/* 160 */
/* 161 */       agg_numOutputRows.add(1);
/* 162 */       agg_rowWriter.zeroOutNullBytes();
/* 163 */
/* 164 */       if (agg_bufIsNull) {
/* 165 */         agg_rowWriter.setNullAt(0);
/* 166 */       } else {
/* 167 */         agg_rowWriter.write(0, agg_bufValue);
/* 168 */       }
/* 169 */       append(agg_result);
/* 170 */     }
/* 171 */   }
/* 172 */ }

Part of the shouldStop() suppression was originally developed by @inouehrs.

How was this patch tested?

Added new tests to DataFrameRangeSuite.

var shouldStopRequired: Boolean = false

def isShouldStopRequired: Boolean = {
  if (shouldStopRequired) return true
Member:

Can you just write shouldStopRequired || (this.parent != null && this.parent.isShouldStopRequired)?

Member Author:

Thank you, it looks better. Done.
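Presumably the revised helper then read as follows (a sketch; the exact condition evolved further in later commits, see below):

def isShouldStopRequired: Boolean = {
  shouldStopRequired || (this.parent != null && this.parent.isShouldStopRequired)
}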

@SparkQA commented Mar 1, 2017

Test build #73689 has finished for PR 17122 at commit 47f405c.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA commented Mar 1, 2017

Test build #73702 has finished for PR 17122 at commit 6913b45.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

val range = ctx.freshName("range")
// we need to place consume() before calling isShouldStopRequired
val body = consume(ctx, Seq(ev))
val shouldStop = if (isShouldStopRequired) {
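For context, a plausible completion of this truncated hunk, inferred from the per-row check and the "// shouldStop check is eliminated" comment in the generated output above ($number, $value, and $step are the codegen fresh names used in this file; treat this as a sketch, not the exact merged code):

val shouldStop = if (isShouldStopRequired) {
  s"if (shouldStop()) { $number = $value + ${step}L; return; }"
} else {
  "// shouldStop check is eliminated"
}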
Member:

isShouldStopRequired complicates the logic. Is it necessary? How much improvement does it bring?

Member Author:

I think isShouldStopRequired is simple logic: it just checks whether shouldStopRequired of this node or of one of its parents is true.

There are two reasons why isShouldStopRequired is necessary.

  1. Without isShouldStopRequired, the improvement degrades from 7.6x to 5.5x.
  2. Without isShouldStopRequired, the size of the loop body would increase, and we may miss some compiler optimizations: a JIT compiler applies loop optimizations such as loop unrolling only when the loop body is below a size threshold.
OpenJDK 64-Bit Server VM 1.8.0_111-8u111-b14-2ubuntu0.16.04.2-b14 on Linux 4.4.0-47-generic
Intel(R) Xeon(R) CPU E5-2667 v3 @ 3.20GHz
cnt:                                     Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
cnt                                            247 /  289       4340.6           0.2       1.0X


// to track the existence of an append() call in the current produce-consume cycle
// if append is not called (e.g. in aggregation), we can skip shouldStop in the inner-most loop
parent.shouldStopRequired = false
Member:

Do we need this? The default value of shouldStopRequired is already false.

Member Author:

I wanted to ensure that produce() starts with parent.shouldStopRequired = false, because I was afraid another produce-consume cycle might set shouldStopRequired to true if one parent has more than one produce-consume cycle.
However, in most cases that would not happen, so for simplicity I eliminated this.

| for (int $localIdx = 0; $localIdx < $localEnd; $localIdx++) {
|   long $value = ((long)$localIdx * ${step}L) + $number;
|   $body
|   $shouldStop
Member:

I think in most cases we need the shouldStop check. But currently shouldStopRequired is false by default, so you would consume many additional rows now.

Member:

Oh, never mind. The outer-most WholeStageCodegenExec's shouldStopRequired is true.

override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
val stopEarly = ctx.freshName("stopEarly")
ctx.addMutableState("boolean", stopEarly, s"$stopEarly = false;")
shouldStopRequired = true // loop may break early even without append in loop body
Member:

If this Limit is the parent of an Aggregate, the final shouldStopRequired is true. But actually we can skip the check.

Member Author:

Good catch. This implementation was based on a slightly older revision, which had no stopEarly() method. Removed this line.

@viirya (Member) commented Mar 2, 2017

This optimization looks good and the improvement is great. I have some comments regarding the changes of shouldStopRequired.

@SparkQA commented Mar 2, 2017

Test build #73730 has started for PR 17122 at commit 5ff8dca.

*/
final def produce(ctx: CodegenContext, parent: CodegenSupport): String = executeQuery {
this.parent = parent

Member:

extra space.

Member Author:

Good catch. Done.

throw new UnsupportedOperationException
}

/* for optimization */
Member:

This deserves a better comment.

Member Author:

I see. Done.

}

/* for optimization */
var shouldStopRequired: Boolean = false
Member:

Please add protected.

Member Author:

Sure, done

val localIdx = ctx.freshName("localIdx")
val localEnd = ctx.freshName("localEnd")
val range = ctx.freshName("range")
// we need to place consume() before calling isShouldStopRequired
Member:

Better to describe the reason: consume() may modify shouldStopRequired.

Member Author:

Thank you, done.

/*
* for optimization to suppress shouldStop() in a loop of WholeStageCodegen
*/
// true: require to insert shouldStop() into a loop
Member:

??

Member Author (@kiszk, Mar 2, 2017):

I do not understand your comment. Do you suggest removing line 213?

Member:

Your comment style looks weird. Please put true... in the /*... */

Member:

Btw, the usual style is:

 /**
  * ....
  * 
  */

Member Author:

Updated the comments around here.

@SparkQA commented Mar 2, 2017

Test build #73734 has started for PR 17122 at commit 9528ccc.

*/
// true: require to insert shouldStop() into a loop
protected var shouldStopRequired: Boolean = false

Member:

Please add a simple comment.

Member Author:

ditto

@SparkQA commented Mar 2, 2017

Test build #73737 has started for PR 17122 at commit 6697928.

@viirya (Member) commented Mar 2, 2017

retest this please.

}

// set true if doConsume() inserts append() method that requires shouldStop() in the loop
protected var shouldStopRequired: Boolean = false
Contributor:

Do you think it is possible to do this without a mutable var? Code generation has way too much mutable state as it is.

Member Author:

@hvanhovell We can do it technically, and it looks simple. On the other hand, we would then have to maintain an immutable value carefully, investigating whether append() is required whenever we add a new CodegenSupport-related class.
In contrast, with the current approach the var is easy to maintain: when we add a new CodegenSupport-related class, no careful investigation is needed.

This is a trade-off between simplicity and maintainability. What do you think?

Contributor:

We spend quite a bit of time debugging issues caused by poorly managed mutable vars in code generation. So I'd rather avoid it.

Member:

I would +1 this. That is part of the reason why I said it complicates the logic in my previous comment.

Member Author:

I see. I will rewrite this using an immutable variable.

Member Author:

I updated to use an immutable variable.
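A minimal sketch of the immutable approach (the default below matches the diff context later in this thread; the overriding operator is illustrative only, not necessarily the exact set of operators changed by this PR):

// In CodegenSupport: by default the produced loop needs a shouldStop() check.
protected def shouldStopRequired: Boolean = true

// An operator that consumes all rows from its child without calling append()
// per row (e.g. an aggregate without grouping keys) can override it:
override protected def shouldStopRequired: Boolean = false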

@SparkQA commented Mar 2, 2017

Test build #73745 has finished for PR 17122 at commit 6697928.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA commented Mar 2, 2017

Test build #73781 has finished for PR 17122 at commit 8704083.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

* isShouldStopRequired: require to insert shouldStop() into the loop if true
*/
def isShouldStopRequired: Boolean = {
return shouldStopRequired && !(this.parent != null && !this.parent.isShouldStopRequired)
Member:

Is this easier to understand?

def isShouldStopRequired: Boolean = {
  assert(this.parent != null)
  shouldStopRequired && this.parent.isShouldStopRequired
}

Member Author:

Thank you for your suggestion. However, it caused an assertion failure at "SPARK-7150 range api" in DataFrameRangeSuite.

In the failing case, isShouldStopRequired is called up the parent chain: RangeExec -> FilterExec -> WholeStageCodegenExec.

return shouldStopRequired && !(this.parent != null && !this.parent.isShouldStopRequired)
}

// set false if doConsume() does not insert append() that requires shouldStop() in the loop
Member:

Suggestion for the comment:

/**
 * Set to false if this plan consumes all rows produced by children but doesn't output row to buffer
 * by calling `append()`, so the children don't require `shouldStop()` in the loop of producing rows.
 */

Member Author:

Looks good, done.

throw new UnsupportedOperationException
}

/**
Member:

Suggestion for the comment:

/**
 * For optimization to suppress `shouldStop()` in a loop of WholeStageCodegen.
 * Returning true means we need to insert `shouldStop()` into the loop producing rows, if any.
 */

Member Author:

Thanks, done

@viirya (Member) commented Mar 3, 2017

I have a few suggestions about the code comments. Otherwise LGTM.

@SparkQA commented Mar 3, 2017

Test build #73813 has finished for PR 17122 at commit 7f095c0.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

* to buffer by calling append(), so the children don't require shouldStop()
* in the loop of producing rows.
*/
protected def shouldStopRequired: Boolean = true
Contributor:

Should this be set to true for all blocking operators? Sort for instance?

Member Author (@kiszk, Mar 4, 2017):

Thank you for pointing out this. I overlooked Sort.

Member:

Curious about the performance improvement on Sort.

Member Author:

I cannot see a performance improvement on Sort. I think there are two reasons for this result.

One is that the loop body is too large. In the inner-most loop, the insertRow method is called, and I think its size is too large to enable loop optimizations.

The other is that the hot method is not here. I guess the hot method may be sort() at line 154 below.

Here is the generated code.

/* 114 */     while (true) {
/* 115 */       long range_range = range_batchEnd - range_number;
/* 116 */       if (range_range != 0L) {
/* 117 */         int range_localEnd = (int)(range_range / -1L);
/* 118 */         for (int range_localIdx = 0; range_localIdx < range_localEnd; range_localIdx++) {
/* 119 */           long range_value = ((long)range_localIdx * -1L) + range_number;
/* 120 */
/* 121 */           range_rowWriter.write(0, range_value);
/* 122 */           sort_sorter.insertRow((UnsafeRow)range_result);
/* 123 */
/* 124 */           // shouldStop check is eliminated
/* 125 */         }
/* 126 */         range_number = range_batchEnd;
/* 127 */       }
/* 128 */
/* 129 */       if (range_taskContext.isInterrupted()) {
/* 130 */         throw new TaskKilledException();
/* 131 */       }
/* 132 */
/* 133 */       long range_nextBatchTodo;
/* 134 */       if (range_numElementsTodo > 1000L) {
/* 135 */         range_nextBatchTodo = 1000L;
/* 136 */         range_numElementsTodo -= 1000L;
/* 137 */       } else {
/* 138 */         range_nextBatchTodo = range_numElementsTodo;
/* 139 */         range_numElementsTodo = 0;
/* 140 */         if (range_nextBatchTodo == 0) break;
/* 141 */       }
/* 142 */       range_numOutputRows.add(range_nextBatchTodo);
/* 143 */       range_inputMetrics.incRecordsRead(range_nextBatchTodo);
/* 144 */
/* 145 */       range_batchEnd += range_nextBatchTodo * -1L;
/* 146 */     }
/* 147 */
/* 148 */   }
/* 149 */
/* 150 */   protected void processNext() throws java.io.IOException {
/* 151 */     if (sort_needToSort) {
/* 152 */       long sort_spillSizeBefore = sort_metrics.memoryBytesSpilled();
/* 153 */       sort_addToSorter();
/* 154 */       sort_sortedIter = sort_sorter.sort();
/* 155 */       sort_sortTime.add(sort_sorter.getSortTimeNanos() / 1000000);
/* 156 */       sort_peakMemory.add(sort_sorter.getPeakMemoryUsage());
/* 157 */       sort_spillSize.add(sort_metrics.memoryBytesSpilled() - sort_spillSizeBefore);
/* 158 */       sort_metrics.incPeakExecutionMemory(sort_sorter.getPeakMemoryUsage());
/* 159 */       sort_needToSort = false;
/* 160 */     }
/* 161 */
/* 162 */     while (sort_sortedIter.hasNext()) {
/* 163 */       UnsafeRow sort_outputRow = (UnsafeRow)sort_sortedIter.next();
/* 164 */
/* 165 */       append(sort_outputRow);
/* 166 */
/* 167 */       if (shouldStop()) return;
/* 168 */     }
/* 169 */   }

Contributor:

How would sort be improved? Sort is a very expensive operation, so it dominates the runtime of the job. The only possible improvement here would be to avoid sorting the output of range altogether (assuming we do not overflow).

Member Author (@kiszk, Mar 6, 2017):

I agree with you. I just confirmed that this PR is not effective for Sort in general, to answer the question above.

Contributor:

The only reason I mentioned sort is that there is no use in stopping early and that it would not be correct to do so. I was not really expecting any improvement.

@SparkQA commented Mar 4, 2017

Test build #73898 has finished for PR 17122 at commit e6740a5.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA commented Mar 4, 2017

Test build #73904 has finished for PR 17122 at commit c2f7939.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

* Returning true means we need to insert shouldStop() into the loop producing rows, if any.
*/
def isShouldStopRequired: Boolean = {
return shouldStopRequired && !(this.parent != null && !this.parent.isShouldStopRequired)
Member:

I think shouldStopRequired && (this.parent == null || this.parent.isShouldStopRequired) is better to understand.

Member Author:

Thanks, done.
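With the suggestion applied, the helper presumably reads (a sketch; naming as in the diff context above):

def isShouldStopRequired: Boolean = {
  shouldStopRequired && (this.parent == null || this.parent.isShouldStopRequired)
}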

@viirya (Member) commented Mar 7, 2017

LGTM, except the comment for isShouldStopRequired condition.

@SparkQA commented Mar 7, 2017

Test build #74094 has finished for PR 17122 at commit 49d02e4.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@viirya (Member) commented Mar 9, 2017

LGTM cc @hvanhovell

@hvanhovell (Contributor):

LGTM, merging to master.

@asfgit closed this in fcb68e0 on Mar 10, 2017.