2 changes: 1 addition & 1 deletion README.md
@@ -268,7 +268,7 @@ and use that as a temp location for this data.
<td><tt>jdbcdriver</tt></td>
<td>No</td>
<td>Determined by the JDBC URL's subprotocol</td>
- <td>The class name of the JDBC driver to load before JDBC operations. This class must be on the classpath. In most cases, it should not be necessary to specify this option, as the appropriate driver classname should automatically be determined by the JDBC URL's subprotocol.</td>
+ <td>The class name of the JDBC driver to use. This class must be on the classpath. In most cases, it should not be necessary to specify this option, as the appropriate driver classname should automatically be determined by the JDBC URL's subprotocol.</td>
</tr>
<tr>
<td><tt>diststyle</tt></td>
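For context, the option this README row documents is set through the data source's options map. A hedged usage sketch (the cluster endpoint, table, and bucket names are placeholders; per the new wording, jdbcdriver can usually be omitted entirely):

// Illustrative only — all names below are placeholders:
val df = sqlContext.read
  .format("com.databricks.spark.redshift")
  .option("url", "jdbc:redshift://examplecluster.example.com:5439/dev?user=user&password=pass")
  .option("dbtable", "my_table")
  .option("tempdir", "s3n://my-bucket/tmp")
  // Usually unnecessary: the driver class name is determined by the URL's subprotocol.
  .option("jdbcdriver", "com.amazon.redshift.jdbc41.Driver")
  .load()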
3 changes: 3 additions & 0 deletions project/SparkRedshiftBuild.scala
@@ -86,6 +86,9 @@ object SparkRedshiftBuild extends Build {
// For testing, we use an Amazon driver, which is available from
// http://docs.aws.amazon.com/redshift/latest/mgmt/configure-jdbc-connection.html
"com.amazon.redshift" % "jdbc4" % "1.1.7.1007" % "test" from "https://s3.amazonaws.com/redshift-downloads/drivers/RedshiftJDBC4-1.1.7.1007.jar",
+ // Although support for the postgres driver is lower priority than support for Amazon's
+ // official Redshift driver, we still run basic tests with it.
+ "postgresql" % "postgresql" % "8.3-606.jdbc4" % "test",
"com.google.guava" % "guava" % "14.0.1" % "test",
"org.scalatest" %% "scalatest" % "2.2.1" % "test",
"org.mockito" % "mockito-core" % "1.10.19" % "test"
@@ -60,7 +60,7 @@ trait IntegrationSuiteBase
protected val AWS_S3_SCRATCH_SPACE: String = loadConfigFromEnv("AWS_S3_SCRATCH_SPACE")
require(AWS_S3_SCRATCH_SPACE.contains("s3n"), "must use s3n:// URL")

- protected val jdbcUrl: String = {
+ protected def jdbcUrl: String = {
s"$AWS_REDSHIFT_JDBC_URL?user=$AWS_REDSHIFT_USER&password=$AWS_REDSHIFT_PASSWORD"
}

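Changing jdbcUrl from a val to a def is what lets the new PostgresDriverIntegrationSuite (added below) override it to point at a jdbc:postgresql URL.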
PostgresDriverIntegrationSuite.scala (new file)
@@ -0,0 +1,45 @@
/*
* Copyright 2015 Databricks
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.databricks.spark.redshift

import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

/**
* Basic integration tests with the Postgres JDBC driver.
*/
class PostgresDriverIntegrationSuite extends IntegrationSuiteBase {

override def jdbcUrl: String = {
super.jdbcUrl.replace("jdbc:redshift", "jdbc:postgresql")
}

test("postgresql driver takes precedence for jdbc:postgresql:// URIs") {
val conn = DefaultJDBCWrapper.getConnector(None, jdbcUrl)
try {
assert(conn.getClass.getName === "org.postgresql.jdbc4.Jdbc4Connection")
} finally {
conn.close()
}
}

test("roundtrip save and load") {
val df = sqlContext.createDataFrame(sc.parallelize(Seq(Row(1)), 1),
StructType(StructField("foo", IntegerType) :: Nil))
testRoundtripSaveAndLoad(s"save_with_one_empty_partition_$randomSuffix", df)
}
}
@@ -17,12 +17,12 @@

package com.databricks.spark.redshift

- import java.net.URI
import java.sql.{ResultSet, PreparedStatement, Connection, Driver, DriverManager, ResultSetMetaData, SQLException}
import java.util.Properties
import java.util.concurrent.atomic.AtomicInteger
import java.util.concurrent.{ThreadFactory, Executors}

+ import scala.collection.JavaConverters._
import scala.concurrent.{Await, ExecutionContext, Future}
import scala.concurrent.duration.Duration
import scala.util.Try
@@ -53,40 +53,45 @@ private[redshift] class JDBCWrapper {
}

/**
- * Given a JDBC subprotocol, returns the appropriate driver class so that it can be registered
- * with Spark. If the user has explicitly specified a driver class in their configuration then
- * that class will be used. Otherwise, we will attempt to load the correct driver class based on
+ * Given a JDBC subprotocol, returns the name of the appropriate driver class to use.
+ *
+ * If the user has explicitly specified a driver class in their configuration then that class will
+ * be used. Otherwise, we will attempt to load the correct driver class based on
* the JDBC subprotocol.
*
- * @param jdbcSubprotocol 'redshift' or 'postgres'
+ * @param jdbcSubprotocol 'redshift' or 'postgresql'
* @param userProvidedDriverClass an optional user-provided explicit driver class name
* @return the driver class
*/
private def getDriverClass(
jdbcSubprotocol: String,
- userProvidedDriverClass: Option[String]): Class[Driver] = {
- userProvidedDriverClass.map(Utils.classForName).getOrElse {
+ userProvidedDriverClass: Option[String]): String = {
+ userProvidedDriverClass.getOrElse {
jdbcSubprotocol match {
case "redshift" =>
try {
Utils.classForName("com.amazon.redshift.jdbc41.Driver")
Utils.classForName("com.amazon.redshift.jdbc41.Driver").getName
} catch {
case _: ClassNotFoundException =>
try {
Utils.classForName("com.amazon.redshift.jdbc4.Driver")
Utils.classForName("com.amazon.redshift.jdbc4.Driver").getName
} catch {
case e: ClassNotFoundException =>
throw new ClassNotFoundException(
"Could not load an Amazon Redshift JDBC driver; see the README for " +
"instructions on downloading and configuring the official Amazon driver.", e)
}
}
case "postgres" => Utils.classForName("org.postgresql.Driver")
case "postgresql" => "org.postgresql.Driver"
case other => throw new IllegalArgumentException(s"Unsupported JDBC protocol: '$other'")
}
- }.asInstanceOf[Class[Driver]]
+ }
}

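To make the new resolution order concrete, a small illustration (hypothetical values, mirroring the subprotocol parsing done in getConnector below):

// Illustration only, not part of the diff:
val url = "jdbc:postgresql://examplehost:5439/dev" // hypothetical URL
val subprotocol = url.stripPrefix("jdbc:").split(":")(0) // "postgresql"
getDriverClass(subprotocol, None) // returns "org.postgresql.Driver"
getDriverClass(subprotocol, Some("my.custom.Driver")) // user override wins: "my.custom.Driver"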
/**
* Reflectively calls Spark's `DriverRegistry.register()`, which handles corner-cases related to
* using JDBC drivers that are not accessible from the bootstrap classloader.
*/
private def registerDriver(driverClass: String): Unit = {
// DriverRegistry.register() is one of the few pieces of private Spark functionality which
// we need to rely on. This class was relocated in Spark 1.5.0, so we need to use reflection
@@ -194,9 +199,19 @@ private[redshift] class JDBCWrapper {
*/
def getConnector(userProvidedDriverClass: Option[String], url: String): Connection = {
val subprotocol = url.stripPrefix("jdbc:").split(":")(0)
- val driverClass: Class[Driver] = getDriverClass(subprotocol, userProvidedDriverClass)
- registerDriver(driverClass.getCanonicalName)
- DriverManager.getConnection(url, new Properties())
+ val driverClass: String = getDriverClass(subprotocol, userProvidedDriverClass)
+ registerDriver(driverClass)
+ // Note that we purposely don't call DriverManager.getConnection() here: we want to ensure
+ // that an explicitly-specified user-provided driver class can take precedence over the
+ // default class, but DriverManager.getConnection() might return a connection according to a
+ // different precedence. At the same time, we don't want to create a driver per connection,
+ // so we rely on the DriverManager's registered driver instances to handle that singleton
+ // logic for us.
+ val driver: Driver = DriverManager.getDrivers.asScala.collectFirst {
+ case d if d.getClass.getCanonicalName == driverClass => d
(Inline review comment from the PR author on the line above: "I just realized that there's a bit of a bug here: if d is an instance of DriverWrapper then we need to compare against d.wrapped.getClass.getCanonicalName. I'll open a followup PR to fix this." A hedged sketch of that fix follows the diff below.)
+ }.getOrElse {
+ throw new IllegalArgumentException(s"Did not find registered driver with class $driverClass")
+ }
+ driver.connect(url, new Properties())
}

/**
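For reference, a minimal sketch of the follow-up fix described in the inline review comment. Assumptions not confirmed by this PR: Spark's DriverWrapper can be detected by its simple class name and exposes the real driver via a wrapped accessor; since DriverWrapper is not public Spark API, the accessor is invoked reflectively. Untested, illustrative only:

// Hypothetical helper for the fix described in the inline comment. Assumes
// DriverWrapper exposes the real driver via a `wrapped` accessor, invoked
// reflectively because the class is private to Spark.
def underlyingDriverClassName(d: Driver): String = {
  if (d.getClass.getName.endsWith("DriverWrapper")) {
    val accessor = d.getClass.getDeclaredMethod("wrapped")
    accessor.setAccessible(true)
    accessor.invoke(d).asInstanceOf[Driver].getClass.getCanonicalName
  } else {
    d.getClass.getCanonicalName
  }
}

// The lookup inside getConnector would then become:
val driver: Driver = DriverManager.getDrivers.asScala.collectFirst {
  case d if underlyingDriverClassName(d) == driverClass => d
}.getOrElse {
  throw new IllegalArgumentException(s"Did not find registered driver with class $driverClass")
}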