@@ -19,14 +19,15 @@ package org.apache.spark.sql.execution
 
 import java.sql.{Timestamp, Date}
 
-import org.apache.spark.serializer.Serializer
-import org.apache.spark.{SparkEnv, SparkConf, ShuffleDependency, SparkContext}
+import org.scalatest.{FunSuite, BeforeAndAfterAll}
+
 import org.apache.spark.rdd.ShuffledRDD
+import org.apache.spark.serializer.Serializer
+import org.apache.spark.ShuffleDependency
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.Row
-import org.scalatest.{FunSuite, BeforeAndAfterAll}
-
-import org.apache.spark.sql.{MyDenseVectorUDT, SQLContext, QueryTest}
+import org.apache.spark.sql.test.TestSQLContext._
+import org.apache.spark.sql.{MyDenseVectorUDT, QueryTest}
 
 class SparkSqlSerializer2DataTypeSuite extends FunSuite {
   // Make sure that we will not use serializer2 for unsupported data types.
@@ -67,18 +68,17 @@ class SparkSqlSerializer2DataTypeSuite extends FunSuite {
 }
 
 abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll {
-
-  @transient var sparkContext: SparkContext = _
-  @transient var sqlContext: SQLContext = _
-  // We may have an existing SparkEnv (e.g. the one used by TestSQLContext).
-  @transient val existingSparkEnv = SparkEnv.get
   var allColumns: String = _
   val serializerClass: Class[Serializer] =
     classOf[SparkSqlSerializer2].asInstanceOf[Class[Serializer]]
+  var numShufflePartitions: Int = _
+  var useSerializer2: Boolean = _
 
   override def beforeAll(): Unit = {
-    sqlContext.sql("set spark.sql.shuffle.partitions=5")
-    sqlContext.sql("set spark.sql.useSerializer2=true")
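+    // Save the current settings so that afterAll() can restore them.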
+    numShufflePartitions = conf.numShufflePartitions
+    useSerializer2 = conf.useSqlSerializer2
+
+    sql("set spark.sql.useSerializer2=true")
 
     val supportedTypes =
       Seq(StringType, BinaryType, NullType, BooleanType,
@@ -112,18 +112,15 @@ abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll
           new Timestamp(i))
       }
 
-    sqlContext.createDataFrame(rdd, schema).registerTempTable("shuffle")
+    createDataFrame(rdd, schema).registerTempTable("shuffle")
 
     super.beforeAll()
   }
 
   override def afterAll(): Unit = {
-    sqlContext.dropTempTable("shuffle")
-    sparkContext.stop()
-    sqlContext = null
-    sparkContext = null
-    // Set the existing SparkEnv back.
-    SparkEnv.set(existingSparkEnv)
+    dropTempTable("shuffle")
+    sql(s"set spark.sql.shuffle.partitions=$numShufflePartitions")
+    sql(s"set spark.sql.useSerializer2=$useSerializer2")
     super.afterAll()
   }
 
@@ -144,64 +141,40 @@ abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll
   }
 
   test("key schema and value schema are not nulls") {
-    val df = sqlContext.sql(s"SELECT DISTINCT ${allColumns} FROM shuffle")
+    val df = sql(s"SELECT DISTINCT ${allColumns} FROM shuffle")
     checkSerializer(df.queryExecution.executedPlan, serializerClass)
     checkAnswer(
       df,
-      sqlContext.table("shuffle").collect())
+      table("shuffle").collect())
   }
 
   test("value schema is null") {
-    val df = sqlContext.sql(s"SELECT col0 FROM shuffle ORDER BY col0")
+    val df = sql(s"SELECT col0 FROM shuffle ORDER BY col0")
     checkSerializer(df.queryExecution.executedPlan, serializerClass)
     assert(
       df.map(r => r.getString(0)).collect().toSeq ===
-      sqlContext.table("shuffle").select("col0").map(r => r.getString(0)).collect().sorted.toSeq)
+      table("shuffle").select("col0").map(r => r.getString(0)).collect().sorted.toSeq)
+  }
+}
+
+/** Tests SparkSqlSerializer2 with sort based shuffle without sort merge. */
+class SparkSqlSerializer2SortShuffleSuite extends SparkSqlSerializer2Suite {
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    // Sort merge will not be triggered.
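+    // (200 partitions is assumed to stay at or below the default
+    // spark.shuffle.sort.bypassMergeThreshold of 200, so the sort shuffle
+    // takes the bypass-merge path rather than performing a sort merge.)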
+    sql("set spark.sql.shuffle.partitions = 200")
   }
 
   test("key schema is null") {
     val aggregations = allColumns.split(",").map(c => s"COUNT($c)").mkString(",")
-    val df = sqlContext.sql(s"SELECT $aggregations FROM shuffle")
+    val df = sql(s"SELECT $aggregations FROM shuffle")
     checkSerializer(df.queryExecution.executedPlan, serializerClass)
     checkAnswer(
       df,
       Row(1000, 1000, 0, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000))
   }
 }
 
-/** Tests SparkSqlSerializer2 with hash based shuffle. */
-class SparkSqlSerializer2HashShuffleSuite extends SparkSqlSerializer2Suite {
-  override def beforeAll(): Unit = {
-    val sparkConf =
-      new SparkConf()
-        .set("spark.driver.allowMultipleContexts", "true")
-        .set("spark.sql.testkey", "true")
-        .set("spark.shuffle.manager", "hash")
-
-    sparkContext = new SparkContext("local[2]", "Serializer2SQLContext", sparkConf)
-    sqlContext = new SQLContext(sparkContext)
-    super.beforeAll()
-  }
-}
-
-/** Tests SparkSqlSerializer2 with sort based shuffle without sort merge. */
-class SparkSqlSerializer2SortShuffleSuite extends SparkSqlSerializer2Suite {
-  override def beforeAll(): Unit = {
-    // Since spark.sql.shuffle.partition is 5, we will not do sort merge when
-    // spark.shuffle.sort.bypassMergeThreshold is also 5.
-    val sparkConf =
-      new SparkConf()
-        .set("spark.driver.allowMultipleContexts", "true")
-        .set("spark.sql.testkey", "true")
-        .set("spark.shuffle.manager", "sort")
-        .set("spark.shuffle.sort.bypassMergeThreshold", "5")
-
-    sparkContext = new SparkContext("local[2]", "Serializer2SQLContext", sparkConf)
-    sqlContext = new SQLContext(sparkContext)
-    super.beforeAll()
-  }
-}
-
 /** For now, we will use SparkSqlSerializer for sort based shuffle with sort merge. */
 class SparkSqlSerializer2SortMergeShuffleSuite extends SparkSqlSerializer2Suite {
 
@@ -210,15 +183,8 @@ class SparkSqlSerializer2SortMergeShuffleSuite extends SparkSqlSerializer2Suite
     classOf[SparkSqlSerializer].asInstanceOf[Class[Serializer]]
 
   override def beforeAll(): Unit = {
-    val sparkConf =
-      new SparkConf()
-        .set("spark.driver.allowMultipleContexts", "true")
-        .set("spark.sql.testkey", "true")
-        .set("spark.shuffle.manager", "sort")
-        .set("spark.shuffle.sort.bypassMergeThreshold", "0") // Always do sort merge.
-
-    sparkContext = new SparkContext("local[2]", "Serializer2SQLContext", sparkConf)
-    sqlContext = new SQLContext(sparkContext)
     super.beforeAll()
+    // To trigger the sort merge.
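+    // (201 is assumed to exceed the default spark.shuffle.sort.bypassMergeThreshold
+    // of 200, so the sort-merge code path is exercised.)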
+    sql("set spark.sql.shuffle.partitions = 201")
   }
 }