
Commit 0b3ec99

dongjoon-hyun authored and daspalrahul committed
[SPARK-25534][SQL] Make SQLHelper trait
## What changes were proposed in this pull request?

Currently, Spark has 7 `withTempPath` and 6 `withSQLConf` functions. This PR aims to remove duplicated and inconsistent code and reduce them to the following meaningful implementations.

**withTempPath**
- `SQLHelper.withTempPath`: The one which was used in `SQLTestUtils`.

**withSQLConf**
- `SQLHelper.withSQLConf`: The one which was used in `PlanTest`.
- `ExecutorSideSQLConfSuite.withSQLConf`: The one which doesn't throw `AnalysisException` on StaticConf changes.
- `SQLTestUtils.withSQLConf`: The one which overrides intentionally to change the active session.

```scala
  protected override def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = {
    SparkSession.setActiveSession(spark)
    super.withSQLConf(pairs: _*)(f)
  }
```

## How was this patch tested?

Pass the Jenkins with the existing tests.

Closes apache#22548 from dongjoon-hyun/SPARK-25534.

Authored-by: Dongjoon Hyun <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
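Not part of the original commit message: a minimal sketch, under assumed names, of how a benchmark or suite consumes the consolidated trait after this change. `MyReadBenchmark` and its body are hypothetical; only `withSQLConf` and `withTempPath` come from the new `SQLHelper` trait.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.plans.SQLHelper
import org.apache.spark.sql.internal.SQLConf

// Hypothetical benchmark: mixes in SQLHelper instead of redefining
// withTempPath/withSQLConf locally (the duplicated pattern this PR removes).
object MyReadBenchmark extends SQLHelper {
  private val spark = SparkSession.builder()
    .master("local[1]")
    .appName("MyReadBenchmark")
    .getOrCreate()

  def run(): Unit = {
    // The previous value of the conf is restored when the block exits.
    withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") {
      // withTempPath passes in a path that does not exist yet and deletes
      // whatever the block creates there once it returns.
      withTempPath { dir =>
        spark.range(1000).write.parquet(dir.getCanonicalPath)
        spark.read.parquet(dir.getCanonicalPath).count()
      }
    }
  }
}
```

The same pattern replaces the hand-rolled helpers removed from the benchmark objects in the diffs below.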
1 parent bafbd32 commit 0b3ec99

File tree

9 files changed: +81, -132 lines changed


sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala

Lines changed: 1 addition & 30 deletions
@@ -22,7 +22,6 @@ import org.scalatest.Suite
 import org.scalatest.Tag
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode
@@ -57,7 +56,7 @@ trait CodegenInterpretedPlanTest extends PlanTest {
  * Provides helper methods for comparing plans, but without the overhead of
  * mandating a FunSuite.
  */
-trait PlanTestBase extends PredicateHelper { self: Suite =>
+trait PlanTestBase extends PredicateHelper with SQLHelper { self: Suite =>
 
   // TODO(gatorsmile): remove this from PlanTest and all the analyzer rules
   protected def conf = SQLConf.get
@@ -174,32 +173,4 @@ trait PlanTestBase extends PredicateHelper { self: Suite =>
         plan1 == plan2
     }
   }
-
-  /**
-   * Sets all SQL configurations specified in `pairs`, calls `f`, and then restores all SQL
-   * configurations.
-   */
-  protected def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = {
-    val conf = SQLConf.get
-    val (keys, values) = pairs.unzip
-    val currentValues = keys.map { key =>
-      if (conf.contains(key)) {
-        Some(conf.getConfString(key))
-      } else {
-        None
-      }
-    }
-    (keys, values).zipped.foreach { (k, v) =>
-      if (SQLConf.staticConfKeys.contains(k)) {
-        throw new AnalysisException(s"Cannot modify the value of a static config: $k")
-      }
-      conf.setConfString(k, v)
-    }
-    try f finally {
-      keys.zip(currentValues).foreach {
-        case (key, Some(value)) => conf.setConfString(key, value)
-        case (key, None) => conf.unsetConf(key)
-      }
-    }
-  }
 }

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SQLHelper.scala

Lines changed: 64 additions & 0 deletions
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.plans
+
+import java.io.File
+
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.util.Utils
+
+trait SQLHelper {
+
+  /**
+   * Sets all SQL configurations specified in `pairs`, calls `f`, and then restores all SQL
+   * configurations.
+   */
+  protected def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = {
+    val conf = SQLConf.get
+    val (keys, values) = pairs.unzip
+    val currentValues = keys.map { key =>
+      if (conf.contains(key)) {
+        Some(conf.getConfString(key))
+      } else {
+        None
+      }
+    }
+    (keys, values).zipped.foreach { (k, v) =>
+      if (SQLConf.staticConfKeys.contains(k)) {
+        throw new AnalysisException(s"Cannot modify the value of a static config: $k")
+      }
+      conf.setConfString(k, v)
+    }
+    try f finally {
+      keys.zip(currentValues).foreach {
+        case (key, Some(value)) => conf.setConfString(key, value)
+        case (key, None) => conf.unsetConf(key)
+      }
+    }
+  }
+
+  /**
+   * Generates a temporary path without creating the actual file/directory, then pass it to `f`. If
+   * a file/directory is created there by `f`, it will be delete after `f` returns.
+   */
+  protected def withTempPath(f: File => Unit): Unit = {
+    val path = Utils.createTempDir()
+    path.delete()
+    try f(path) finally Utils.deleteRecursively(path)
+  }
+}
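Illustrative usage of the new trait (not part of the diff): `withSQLConf` snapshots the current values, applies the overrides, and restores or unsets them in the `finally` block, while a static conf key is rejected with `AnalysisException`. The suite below is hypothetical.

```scala
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.plans.SQLHelper
import org.apache.spark.sql.internal.SQLConf

// Hypothetical suite exercising the save/restore and static-conf behavior.
class SQLHelperUsageSuite extends SparkFunSuite with SQLHelper {

  test("withSQLConf restores the previous value on exit") {
    val key = SQLConf.SHUFFLE_PARTITIONS.key
    val before = SQLConf.get.getConfString(key)
    withSQLConf(key -> "5") {
      assert(SQLConf.get.getConfString(key) == "5")
    }
    assert(SQLConf.get.getConfString(key) == before)
  }

  test("withSQLConf rejects static confs") {
    // spark.sql.warehouse.dir is a static conf, so overriding it throws.
    intercept[AnalysisException] {
      withSQLConf("spark.sql.warehouse.dir" -> "/tmp/never-created") {}
    }
  }
}
```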

sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala

Lines changed: 3 additions & 20 deletions
@@ -19,25 +19,25 @@ package org.apache.spark.sql.execution.benchmark
 import java.io.File
 
 import scala.collection.JavaConverters._
-import scala.util.{Random, Try}
+import scala.util.Random
 
 import org.apache.spark.SparkConf
 import org.apache.spark.benchmark.Benchmark
 import org.apache.spark.sql.{DataFrame, DataFrameWriter, Row, SparkSession}
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.plans.SQLHelper
 import org.apache.spark.sql.execution.datasources.parquet.{SpecificParquetRecordReaderBase, VectorizedParquetRecordReader}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.vectorized.ColumnVector
-import org.apache.spark.util.Utils
 
 
 /**
  * Benchmark to measure data source read performance.
  * To run this:
  *  spark-submit --class <this class> <spark sql test jar>
  */
-object DataSourceReadBenchmark {
+object DataSourceReadBenchmark extends SQLHelper {
   val conf = new SparkConf()
     .setAppName("DataSourceReadBenchmark")
     // Since `spark.master` always exists, overrides this value
@@ -54,27 +54,10 @@ object DataSourceReadBenchmark {
   spark.conf.set(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key, "true")
   spark.conf.set(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, "true")
 
-  def withTempPath(f: File => Unit): Unit = {
-    val path = Utils.createTempDir()
-    path.delete()
-    try f(path) finally Utils.deleteRecursively(path)
-  }
-
   def withTempTable(tableNames: String*)(f: => Unit): Unit = {
     try f finally tableNames.foreach(spark.catalog.dropTempView)
   }
 
-  def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = {
-    val (keys, values) = pairs.unzip
-    val currentValues = keys.map(key => Try(spark.conf.get(key)).toOption)
-    (keys, values).zipped.foreach(spark.conf.set)
-    try f finally {
-      keys.zip(currentValues).foreach {
-        case (key, Some(value)) => spark.conf.set(key, value)
-        case (key, None) => spark.conf.unset(key)
-      }
-    }
-  }
   private def prepareTable(dir: File, df: DataFrame, partition: Option[String] = None): Unit = {
     val testDf = if (partition.isDefined) {
       df.write.partitionBy(partition.get)

sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/FilterPushdownBenchmark.scala

Lines changed: 3 additions & 21 deletions
@@ -19,16 +19,16 @@ package org.apache.spark.sql.execution.benchmark
 
 import java.io.File
 
-import scala.util.{Random, Try}
+import scala.util.Random
 
 import org.apache.spark.SparkConf
 import org.apache.spark.benchmark.{Benchmark, BenchmarkBase}
 import org.apache.spark.sql.{DataFrame, SparkSession}
+import org.apache.spark.sql.catalyst.plans.SQLHelper
 import org.apache.spark.sql.functions.monotonically_increasing_id
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.internal.SQLConf.ParquetOutputTimestampType
 import org.apache.spark.sql.types.{ByteType, Decimal, DecimalType, TimestampType}
-import org.apache.spark.util.Utils
 
 /**
  * Benchmark to measure read performance with Filter pushdown.
@@ -40,7 +40,7 @@ import org.apache.spark.util.Utils
  *   Results will be written to "benchmarks/FilterPushdownBenchmark-results.txt".
  * }}}
  */
-object FilterPushdownBenchmark extends BenchmarkBase {
+object FilterPushdownBenchmark extends BenchmarkBase with SQLHelper {
 
   private val conf = new SparkConf()
     .setAppName(this.getClass.getSimpleName)
@@ -60,28 +60,10 @@ object FilterPushdownBenchmark extends BenchmarkBase {
 
   private val spark = SparkSession.builder().config(conf).getOrCreate()
 
-  def withTempPath(f: File => Unit): Unit = {
-    val path = Utils.createTempDir()
-    path.delete()
-    try f(path) finally Utils.deleteRecursively(path)
-  }
-
   def withTempTable(tableNames: String*)(f: => Unit): Unit = {
     try f finally tableNames.foreach(spark.catalog.dropTempView)
  }
 
-  def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = {
-    val (keys, values) = pairs.unzip
-    val currentValues = keys.map(key => Try(spark.conf.get(key)).toOption)
-    (keys, values).zipped.foreach(spark.conf.set)
-    try f finally {
-      keys.zip(currentValues).foreach {
-        case (key, Some(value)) => spark.conf.set(key, value)
-        case (key, None) => spark.conf.unset(key)
-      }
-    }
-  }
-
   private def prepareTable(
       dir: File, numRows: Int, width: Int, useStringForValue: Boolean): Unit = {
     import spark.implicits._

sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala

Lines changed: 2 additions & 10 deletions
@@ -16,21 +16,19 @@
  */
 package org.apache.spark.sql.execution.datasources.csv
 
-import java.io.File
-
 import org.apache.spark.SparkConf
 import org.apache.spark.benchmark.Benchmark
 import org.apache.spark.sql.{Column, Row, SparkSession}
+import org.apache.spark.sql.catalyst.plans.SQLHelper
 import org.apache.spark.sql.functions.lit
 import org.apache.spark.sql.types._
-import org.apache.spark.util.Utils
 
 /**
  * Benchmark to measure CSV read/write performance.
  * To run this:
  *  spark-submit --class <this class> --jars <spark sql test jar>
  */
-object CSVBenchmarks {
+object CSVBenchmarks extends SQLHelper {
   val conf = new SparkConf()
 
   val spark = SparkSession.builder
@@ -40,12 +38,6 @@ object CSVBenchmarks {
     .getOrCreate()
   import spark.implicits._
 
-  def withTempPath(f: File => Unit): Unit = {
-    val path = Utils.createTempDir()
-    path.delete()
-    try f(path) finally Utils.deleteRecursively(path)
-  }
-
   def quotedValuesBenchmark(rowsNum: Int, numIters: Int): Unit = {
     val benchmark = new Benchmark(s"Parsing quoted values", rowsNum)

sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonBenchmarks.scala

Lines changed: 2 additions & 9 deletions
@@ -21,16 +21,16 @@ import java.io.File
 import org.apache.spark.SparkConf
 import org.apache.spark.benchmark.Benchmark
 import org.apache.spark.sql.{Row, SparkSession}
+import org.apache.spark.sql.catalyst.plans.SQLHelper
 import org.apache.spark.sql.functions.lit
 import org.apache.spark.sql.types._
-import org.apache.spark.util.Utils
 
 /**
  * The benchmarks aims to measure performance of JSON parsing when encoding is set and isn't.
 * To run this:
 *  spark-submit --class <this class> --jars <spark sql test jar>
 */
-object JSONBenchmarks {
+object JSONBenchmarks extends SQLHelper {
   val conf = new SparkConf()
 
   val spark = SparkSession.builder
@@ -40,13 +40,6 @@ object JSONBenchmarks {
     .getOrCreate()
   import spark.implicits._
 
-  def withTempPath(f: File => Unit): Unit = {
-    val path = Utils.createTempDir()
-    path.delete()
-    try f(path) finally Utils.deleteRecursively(path)
-  }
-
-
   def schemaInferring(rowsNum: Int): Unit = {
     val benchmark = new Benchmark("JSON schema inferring", rowsNum)

sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CheckpointFileManagerSuite.scala

Lines changed: 2 additions & 8 deletions
@@ -25,12 +25,12 @@ import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs._
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.plans.SQLHelper
 import org.apache.spark.sql.catalyst.util.quietly
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
-import org.apache.spark.util.Utils
 
-abstract class CheckpointFileManagerTests extends SparkFunSuite {
+abstract class CheckpointFileManagerTests extends SparkFunSuite with SQLHelper {
 
   def createManager(path: Path): CheckpointFileManager
 
@@ -88,12 +88,6 @@ abstract class CheckpointFileManagerTests extends SparkFunSuite {
       fm.delete(path) // should not throw exception
     }
   }
-
-  protected def withTempPath(f: File => Unit): Unit = {
-    val path = Utils.createTempDir()
-    path.delete()
-    try f(path) finally Utils.deleteRecursively(path)
-  }
 }
 
 class CheckpointFileManagerSuite extends SparkFunSuite with SharedSparkSession {

sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala

Lines changed: 0 additions & 13 deletions
@@ -40,7 +40,6 @@ import org.apache.spark.sql.catalyst.plans.PlanTestBase
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.execution.FilterExec
-import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.util.UninterruptibleThread
 import org.apache.spark.util.Utils
 
@@ -167,18 +166,6 @@ private[sql] trait SQLTestUtilsBase
     super.withSQLConf(pairs: _*)(f)
   }
 
-  /**
-   * Generates a temporary path without creating the actual file/directory, then pass it to `f`. If
-   * a file/directory is created there by `f`, it will be delete after `f` returns.
-   *
-   * @todo Probably this method should be moved to a more general place
-   */
-  protected def withTempPath(f: File => Unit): Unit = {
-    val path = Utils.createTempDir()
-    path.delete()
-    try f(path) finally Utils.deleteRecursively(path)
-  }
-
   /**
    * Copy file in jar's resource to a temp file, then pass it to `f`.
    * This function is used to make `f` can use the path of temp file(e.g. file:/), instead of

0 commit comments
