
Commit 170568d

fix.
1 parent 4c673c6 commit 170568d

3 files changed (+69, -38 lines)


sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala

Lines changed: 19 additions & 18 deletions
@@ -159,7 +159,11 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
    * @since 1.4.0
    */
   def jdbc(url: String, table: String, properties: Properties): DataFrame = {
-    jdbc(url, table, JDBCRelation.columnPartition(null), properties)
+    // properties should override settings in extraOptions.
+    this.extraOptions = this.extraOptions ++ properties.asScala
+    // explicit url and dbtable should override all
+    this.extraOptions += (JDBCOptions.JDBC_URL -> url, JDBCOptions.JDBC_TABLE_NAME -> table)
+    format("jdbc").load()
   }
 
   /**
@@ -177,7 +181,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
    * @param upperBound the maximum value of `columnName` used to decide partition stride.
    * @param numPartitions the number of partitions. This, along with `lowerBound` (inclusive),
    *                      `upperBound` (exclusive), form partition strides for generated WHERE
-   *                      clause expressions used to split the column `columnName` evenly.
+   *                      clause expressions used to split the column `columnName` evenly. When
+   *                      the input is less than 1, the number is set to 1.
    * @param connectionProperties JDBC database connection arguments, a list of arbitrary string
    *                             tag/value. Normally at least a "user" and "password" property
    *                             should be included. "fetchsize" can be used to control the
@@ -192,9 +197,13 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
       upperBound: Long,
       numPartitions: Int,
       connectionProperties: Properties): DataFrame = {
-    val partitioning = JDBCPartitioningInfo(columnName, lowerBound, upperBound, numPartitions)
-    val parts = JDBCRelation.columnPartition(partitioning)
-    jdbc(url, table, parts, connectionProperties)
+    // columnName, lowerBound, upperBound and numPartitions override settings in extraOptions.
+    this.extraOptions ++= Map(
+      JDBCOptions.JDBC_PARTITION_COLUMN -> columnName,
+      JDBCOptions.JDBC_LOWER_BOUND -> lowerBound.toString,
+      JDBCOptions.JDBC_UPPER_BOUND -> upperBound.toString,
+      JDBCOptions.JDBC_NUM_PARTITIONS -> numPartitions.toString)
+    jdbc(url, table, connectionProperties)
   }
 
   /**
@@ -220,22 +229,14 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
       table: String,
       predicates: Array[String],
       connectionProperties: Properties): DataFrame = {
+    // connectionProperties should override settings in extraOptions.
+    val params = extraOptions.toMap ++ connectionProperties.asScala.toMap
+    val options = new JDBCOptions(url, table, params)
     val parts: Array[Partition] = predicates.zipWithIndex.map { case (part, i) =>
       JDBCPartition(part, i) : Partition
     }
-    jdbc(url, table, parts, connectionProperties)
-  }
-
-  private def jdbc(
-      url: String,
-      table: String,
-      parts: Array[Partition],
-      connectionProperties: Properties): DataFrame = {
-    // connectionProperties should override settings in extraOptions.
-    this.extraOptions = this.extraOptions ++ connectionProperties.asScala
-    // explicit url and dbtable should override all
-    this.extraOptions += ("url" -> url, "dbtable" -> table)
-    format("jdbc").load()
+    val relation = JDBCRelation(parts, options)(sparkSession)
+    sparkSession.baseRelationToDataFrame(relation)
   }
 
   /**

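A hedged illustration of the DataFrameReader change above (not part of the patch): the Properties-based jdbc overloads now only copy their arguments into the reader's option map and delegate to format("jdbc").load(), so the reads below are expected to be equivalent. The URL, table name, and credentials are placeholder values in the style of the test suite.

import java.util.Properties

val props = new Properties()
props.setProperty("user", "testUser")      // placeholder credentials
props.setProperty("password", "testPass")

// Properties-based overload: after this change it forwards to format("jdbc").load().
val viaProperties = spark.read.jdbc("jdbc:h2:mem:testdb0", "TEST.PEOPLE", props)

// Generic option API: the overload above ends up setting the same keys
// ("url", "dbtable", plus the entries copied from props).
val viaOptions = spark.read.format("jdbc")
  .option("url", "jdbc:h2:mem:testdb0")
  .option("dbtable", "TEST.PEOPLE")
  .option("user", "testUser")
  .option("password", "testPass")
  .load()

// Partitioned overload: columnName, lowerBound, upperBound and numPartitions are
// likewise copied into the partitioning options before delegating to the overload above.
val viaPartitioning =
  spark.read.jdbc("jdbc:h2:mem:testdb0", "TEST.PEOPLE", "THEID", 0L, 4L, 3, props)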
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala

Lines changed: 2 additions & 1 deletion
@@ -137,7 +137,8 @@ private[sql] case class JDBCRelation(
   }
 
   override def toString: String = {
+    val partitioningInfo = if (parts.nonEmpty) s" [numPartitions=${parts.length}]" else ""
     // credentials should not be included in the plan output, table information is sufficient.
-    s"JDBCRelation(${jdbcOptions.table})"
+    s"JDBCRelation(${jdbcOptions.table})" + partitioningInfo
   }
 }

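A hedged sketch of the JDBCRelation.toString change (not part of the patch), using the same placeholder URL and table as above: the string form of a partitioned relation, and hence the JDBC scan node printed by explain(), is expected to carry the partition count, while an unpartitioned relation keeps the old form.

// Expected to print a plan whose JDBC scan shows:
//   JDBCRelation(TEST.PEOPLE) [numPartitions=3]
// An unpartitioned read would still show JDBCRelation(TEST.PEOPLE).
spark.read
  .jdbc("jdbc:h2:mem:testdb0", "TEST.PEOPLE", "THEID", 0L, 4L, 3, new java.util.Properties())
  .explain()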
sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala

Lines changed: 48 additions & 19 deletions
@@ -24,12 +24,12 @@ import java.util.{Calendar, GregorianCalendar, Properties}
 import org.h2.jdbc.JdbcSQLException
 import org.scalatest.{BeforeAndAfter, PrivateMethodTester}
 
-import org.apache.spark.{SparkException, SparkFunSuite}
+import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.{DataFrame, Row}
 import org.apache.spark.sql.execution.DataSourceScanExec
 import org.apache.spark.sql.execution.command.ExplainCommand
 import org.apache.spark.sql.execution.datasources.LogicalRelation
-import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCRDD, JdbcUtils}
+import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCRDD, JDBCRelation, JdbcUtils}
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
@@ -209,6 +209,16 @@ class JDBCSuite extends SparkFunSuite
     conn.close()
   }
 
+  // Check whether the tables are fetched in the expected degree of parallelism
+  def checkNumPartitions(df: DataFrame, expectedNumPartitions: Int): Unit = {
+    val jdbcRelations = df.queryExecution.analyzed.collect {
+      case LogicalRelation(r: JDBCRelation, _, _) => r
+    }
+    assert(jdbcRelations.length == 1)
+    assert(jdbcRelations.head.parts.length == expectedNumPartitions,
+      s"Expecting a JDBCRelation with $expectedNumPartitions partitions, but got:`$jdbcRelations`")
+  }
+
   test("SELECT *") {
     assert(sql("SELECT * FROM foobar").collect().size === 3)
   }
@@ -313,13 +323,23 @@ class JDBCSuite extends SparkFunSuite
   }
 
   test("SELECT * partitioned") {
-    assert(sql("SELECT * FROM parts").collect().size == 3)
+    val df = sql("SELECT * FROM parts")
+    checkNumPartitions(df, expectedNumPartitions = 3)
+    assert(df.collect().length == 3)
   }
 
   test("SELECT WHERE (simple predicates) partitioned") {
-    assert(sql("SELECT * FROM parts WHERE THEID < 1").collect().size === 0)
-    assert(sql("SELECT * FROM parts WHERE THEID != 2").collect().size === 2)
-    assert(sql("SELECT THEID FROM parts WHERE THEID = 1").collect().size === 1)
+    val df1 = sql("SELECT * FROM parts WHERE THEID < 1")
+    checkNumPartitions(df1, expectedNumPartitions = 3)
+    assert(df1.collect().length === 0)
+
+    val df2 = sql("SELECT * FROM parts WHERE THEID != 2")
+    checkNumPartitions(df2, expectedNumPartitions = 3)
+    assert(df2.collect().length === 2)
+
+    val df3 = sql("SELECT THEID FROM parts WHERE THEID = 1")
+    checkNumPartitions(df3, expectedNumPartitions = 3)
+    assert(df3.collect().length === 1)
   }
 
   test("SELECT second field partitioned") {
@@ -370,24 +390,27 @@ class JDBCSuite extends SparkFunSuite
   }
 
   test("Partitioning via JDBCPartitioningInfo API") {
-    assert(
-      spark.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", "THEID", 0, 4, 3, new Properties())
-        .collect().length === 3)
+    val df = spark.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", "THEID", 0, 4, 3, new Properties())
+    checkNumPartitions(df, expectedNumPartitions = 3)
+    assert(df.collect().length === 3)
   }
 
   test("Partitioning via list-of-where-clauses API") {
     val parts = Array[String]("THEID < 2", "THEID >= 2")
-    assert(spark.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", parts, new Properties())
-      .collect().length === 3)
+    val df = spark.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", parts, new Properties())
+    checkNumPartitions(df, expectedNumPartitions = 2)
+    assert(df.collect().length === 3)
   }
 
   test("Partitioning on column that might have null values.") {
-    assert(
-      spark.read.jdbc(urlWithUserAndPass, "TEST.EMP", "theid", 0, 4, 3, new Properties())
-        .collect().length === 4)
-    assert(
-      spark.read.jdbc(urlWithUserAndPass, "TEST.EMP", "THEID", 0, 4, 3, new Properties())
-        .collect().length === 4)
+    val df = spark.read.jdbc(urlWithUserAndPass, "TEST.EMP", "theid", 0, 4, 3, new Properties())
+    checkNumPartitions(df, expectedNumPartitions = 3)
+    assert(df.collect().length === 4)
+
+    val df2 = spark.read.jdbc(urlWithUserAndPass, "TEST.EMP", "THEID", 0, 4, 3, new Properties())
+    checkNumPartitions(df2, expectedNumPartitions = 3)
+    assert(df2.collect().length === 4)
+
     // partitioning on a nullable quoted column
     assert(
       spark.read.jdbc(urlWithUserAndPass, "TEST.EMP", """"Dept"""", 0, 4, 3, new Properties())
@@ -404,6 +427,7 @@ class JDBCSuite extends SparkFunSuite
       numPartitions = 0,
       connectionProperties = new Properties()
     )
+    checkNumPartitions(res, expectedNumPartitions = 1)
     assert(res.count() === 8)
   }
 
@@ -417,6 +441,7 @@ class JDBCSuite extends SparkFunSuite
       numPartitions = 10,
       connectionProperties = new Properties()
     )
+    checkNumPartitions(res, expectedNumPartitions = 4)
     assert(res.count() === 8)
   }
 
@@ -430,6 +455,7 @@ class JDBCSuite extends SparkFunSuite
       numPartitions = 4,
       connectionProperties = new Properties()
     )
+    checkNumPartitions(res, expectedNumPartitions = 1)
    assert(res.count() === 8)
   }
 
@@ -450,7 +476,9 @@ class JDBCSuite extends SparkFunSuite
   }
 
   test("SELECT * on partitioned table with a nullable partition column") {
-    assert(sql("SELECT * FROM nullparts").collect().size == 4)
+    val df = sql("SELECT * FROM nullparts")
+    checkNumPartitions(df, expectedNumPartitions = 3)
+    assert(df.collect().length == 4)
   }
 
   test("H2 integral types") {
@@ -722,7 +750,8 @@ class JDBCSuite extends SparkFunSuite
     }
     // test the JdbcRelation toString output
     df.queryExecution.analyzed.collect {
-      case r: LogicalRelation => assert(r.relation.toString == "JDBCRelation(TEST.PEOPLE)")
+      case r: LogicalRelation =>
+        assert(r.relation.toString == "JDBCRelation(TEST.PEOPLE) [numPartitions=3]")
     }
   }
 

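For reference, a hedged sketch (not part of the patch) of how the parallelism asserted by the new checkNumPartitions helper can also be observed from user code, using the same placeholder URL and table as above:

// Reading TEST.PEOPLE with bounds 0..4 over 3 partitions should yield an RDD with
// 3 partitions; the suite's helper asserts the same count via the analyzed plan.
val people = spark.read
  .jdbc("jdbc:h2:mem:testdb0", "TEST.PEOPLE", "THEID", 0L, 4L, 3, new java.util.Properties())
assert(people.rdd.getNumPartitions == 3)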