@@ -24,12 +24,12 @@ import java.util.{Calendar, GregorianCalendar, Properties}
 import org.h2.jdbc.JdbcSQLException
 import org.scalatest.{BeforeAndAfter, PrivateMethodTester}
 
-import org.apache.spark.{SparkException, SparkFunSuite}
+import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.{DataFrame, Row}
 import org.apache.spark.sql.execution.DataSourceScanExec
 import org.apache.spark.sql.execution.command.ExplainCommand
 import org.apache.spark.sql.execution.datasources.LogicalRelation
-import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCRDD, JdbcUtils}
+import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCRDD, JDBCRelation, JdbcUtils}
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
@@ -209,6 +209,16 @@ class JDBCSuite extends SparkFunSuite
     conn.close()
   }
 
+  // Check whether the tables are fetched in the expected degree of parallelism
+  def checkNumPartitions(df: DataFrame, expectedNumPartitions: Int): Unit = {
+    val jdbcRelations = df.queryExecution.analyzed.collect {
+      case LogicalRelation(r: JDBCRelation, _, _) => r
+    }
+    assert(jdbcRelations.length == 1)
+    assert(jdbcRelations.head.parts.length == expectedNumPartitions,
+      s"Expecting a JDBCRelation with $expectedNumPartitions partitions, but got:`$jdbcRelations`")
+  }
+
   test("SELECT *") {
     assert(sql("SELECT * FROM foobar").collect().size === 3)
   }
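A note on the helper above: matching JDBCRelation.parts out of the analyzed plan verifies the planned degree of parallelism without forcing physical planning (as df.rdd.getNumPartitions would). A minimal usage sketch, mirroring the reads exercised later in this suite:

    // Plan a bounds-partitioned read (column THEID, bounds [0, 4), 3 partitions)
    // and check the planned partition count; nothing is executed at this point.
    val df = spark.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", "THEID", 0, 4, 3, new Properties())
    checkNumPartitions(df, expectedNumPartitions = 3)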
@@ -313,13 +323,23 @@ class JDBCSuite extends SparkFunSuite
   }
 
   test("SELECT * partitioned") {
-    assert(sql("SELECT * FROM parts").collect().size == 3)
+    val df = sql("SELECT * FROM parts")
+    checkNumPartitions(df, expectedNumPartitions = 3)
+    assert(df.collect().length == 3)
   }
 
   test("SELECT WHERE (simple predicates) partitioned") {
-    assert(sql("SELECT * FROM parts WHERE THEID < 1").collect().size === 0)
-    assert(sql("SELECT * FROM parts WHERE THEID != 2").collect().size === 2)
-    assert(sql("SELECT THEID FROM parts WHERE THEID = 1").collect().size === 1)
+    val df1 = sql("SELECT * FROM parts WHERE THEID < 1")
+    checkNumPartitions(df1, expectedNumPartitions = 3)
+    assert(df1.collect().length === 0)
+
+    val df2 = sql("SELECT * FROM parts WHERE THEID != 2")
+    checkNumPartitions(df2, expectedNumPartitions = 3)
+    assert(df2.collect().length === 2)
+
+    val df3 = sql("SELECT THEID FROM parts WHERE THEID = 1")
+    checkNumPartitions(df3, expectedNumPartitions = 3)
+    assert(df3.collect().length === 1)
   }
 
   test("SELECT second field partitioned") {
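The parts table queried above is registered in the suite's setup, outside this diff. A sketch of the kind of registration involved, assuming the standard JDBC source options (the suite's exact DDL and credentials are not shown here):

    // Hypothetical registration: a fixed partitioning spec baked into the table options.
    sql(
      s"""
         |CREATE TEMPORARY TABLE parts
         |USING org.apache.spark.sql.jdbc
         |OPTIONS (url '$url', dbtable 'TEST.PEOPLE', user 'testUser', password 'testPass',
         |  partitionColumn 'THEID', lowerBound '0', upperBound '4', numPartitions '3')
       """.stripMargin)

Because numPartitions is fixed at registration time, every query over parts plans the same three partitions; WHERE predicates filter rows but never change the partition count, which is why each checkNumPartitions call above expects 3.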
@@ -370,24 +390,27 @@ class JDBCSuite extends SparkFunSuite
   }
 
   test("Partitioning via JDBCPartitioningInfo API") {
-    assert(
-      spark.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", "THEID", 0, 4, 3, new Properties())
-      .collect().length === 3)
+    val df = spark.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", "THEID", 0, 4, 3, new Properties())
+    checkNumPartitions(df, expectedNumPartitions = 3)
+    assert(df.collect().length === 3)
   }
 
   test("Partitioning via list-of-where-clauses API") {
     val parts = Array[String]("THEID < 2", "THEID >= 2")
-    assert(spark.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", parts, new Properties())
-      .collect().length === 3)
+    val df = spark.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", parts, new Properties())
+    checkNumPartitions(df, expectedNumPartitions = 2)
+    assert(df.collect().length === 3)
   }
 
   test("Partitioning on column that might have null values.") {
-    assert(
-      spark.read.jdbc(urlWithUserAndPass, "TEST.EMP", "theid", 0, 4, 3, new Properties())
-        .collect().length === 4)
-    assert(
-      spark.read.jdbc(urlWithUserAndPass, "TEST.EMP", "THEID", 0, 4, 3, new Properties())
-        .collect().length === 4)
+    val df = spark.read.jdbc(urlWithUserAndPass, "TEST.EMP", "theid", 0, 4, 3, new Properties())
+    checkNumPartitions(df, expectedNumPartitions = 3)
+    assert(df.collect().length === 4)
+
+    val df2 = spark.read.jdbc(urlWithUserAndPass, "TEST.EMP", "THEID", 0, 4, 3, new Properties())
+    checkNumPartitions(df2, expectedNumPartitions = 3)
+    assert(df2.collect().length === 4)
+
     // partitioning on a nullable quoted column
     assert(
       spark.read.jdbc(urlWithUserAndPass, "TEST.EMP", """"Dept"""", 0, 4, 3, new Properties())
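For column-based partitioning, the relation expands (lowerBound, upperBound, numPartitions) into one WHERE clause per partition. A simplified sketch of that stride logic (illustrative only, not the production JDBCRelation.columnPartition code):

    // With lower = 0, upper = 4, n = 3, the stride is (4 - 0) / 3 = 1, yielding:
    //   "THEID < 1 or THEID is null", "THEID >= 1 AND THEID < 2", "THEID >= 2"
    def wherePredicates(column: String, lower: Long, upper: Long, n: Int): Seq[String] = {
      val stride = (upper - lower) / n
      (0 until n).map { i =>
        val lo = if (i == 0) null else s"$column >= ${lower + i * stride}"
        val hi = if (i == n - 1) null else s"$column < ${lower + (i + 1) * stride}"
        if (hi == null) lo
        else if (lo == null) s"$hi or $column is null"
        else s"$lo AND $hi"
      }
    }

The "or $column is null" disjunct on the first partition is what lets the TEST.EMP reads above return all four rows even when the partition column holds NULLs. The real implementation also clamps the planned partition count when the request cannot be honored (e.g. a numPartitions of zero, or one larger than the bound range allows); the next three tests pin that clamping down.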
@@ -404,6 +427,7 @@ class JDBCSuite extends SparkFunSuite
       numPartitions = 0,
       connectionProperties = new Properties()
     )
+    checkNumPartitions(res, expectedNumPartitions = 1)
     assert(res.count() === 8)
   }
 
@@ -417,6 +441,7 @@ class JDBCSuite extends SparkFunSuite
       numPartitions = 10,
       connectionProperties = new Properties()
     )
+    checkNumPartitions(res, expectedNumPartitions = 4)
     assert(res.count() === 8)
   }
 
@@ -430,6 +455,7 @@ class JDBCSuite extends SparkFunSuite
       numPartitions = 4,
       connectionProperties = new Properties()
     )
+    checkNumPartitions(res, expectedNumPartitions = 1)
     assert(res.count() === 8)
   }
 
@@ -450,7 +476,9 @@ class JDBCSuite extends SparkFunSuite
   }
 
   test("SELECT * on partitioned table with a nullable partition column") {
-    assert(sql("SELECT * FROM nullparts").collect().size == 4)
+    val df = sql("SELECT * FROM nullparts")
+    checkNumPartitions(df, expectedNumPartitions = 3)
+    assert(df.collect().length == 4)
   }
 
   test("H2 integral types") {
@@ -722,7 +750,8 @@ class JDBCSuite extends SparkFunSuite
     }
     // test the JdbcRelation toString output
     df.queryExecution.analyzed.collect {
-      case r: LogicalRelation => assert(r.relation.toString == "JDBCRelation(TEST.PEOPLE)")
+      case r: LogicalRelation =>
+        assert(r.relation.toString == "JDBCRelation(TEST.PEOPLE) [numPartitions=3]")
     }
   }
 