diff --git a/README.md b/README.md index d16bbddc..e6842894 100644 --- a/README.md +++ b/README.md @@ -268,7 +268,7 @@ and use that as a temp location for this data. jdbcdriver No Determined by the JDBC URL's subprotocol - The class name of the JDBC driver to load before JDBC operations. This class must be on the classpath. In most cases, it should not be necessary to specify this option, as the appropriate driver classname should automatically be determined by the JDBC URL's subprotocol. + The class name of the JDBC driver to use. This class must be on the classpath. In most cases, it should not be necessary to specify this option, as the appropriate driver classname should automatically be determined by the JDBC URL's subprotocol. diststyle diff --git a/project/SparkRedshiftBuild.scala b/project/SparkRedshiftBuild.scala index 01bc9fba..ef8480a9 100644 --- a/project/SparkRedshiftBuild.scala +++ b/project/SparkRedshiftBuild.scala @@ -86,6 +86,9 @@ object SparkRedshiftBuild extends Build { // For testing, we use an Amazon driver, which is available from // http://docs.aws.amazon.com/redshift/latest/mgmt/configure-jdbc-connection.html "com.amazon.redshift" % "jdbc4" % "1.1.7.1007" % "test" from "https://s3.amazonaws.com/redshift-downloads/drivers/RedshiftJDBC4-1.1.7.1007.jar", + // Although support for the postgres driver is lower priority than support for Amazon's + // official Redshift driver, we still run basic tests with it. + "postgresql" % "postgresql" % "8.3-606.jdbc4" % "test", "com.google.guava" % "guava" % "14.0.1" % "test", "org.scalatest" %% "scalatest" % "2.2.1" % "test", "org.mockito" % "mockito-core" % "1.10.19" % "test" diff --git a/src/it/scala/com/databricks/spark/redshift/IntegrationSuiteBase.scala b/src/it/scala/com/databricks/spark/redshift/IntegrationSuiteBase.scala index 03254c00..9e8eacc8 100644 --- a/src/it/scala/com/databricks/spark/redshift/IntegrationSuiteBase.scala +++ b/src/it/scala/com/databricks/spark/redshift/IntegrationSuiteBase.scala @@ -60,7 +60,7 @@ trait IntegrationSuiteBase protected val AWS_S3_SCRATCH_SPACE: String = loadConfigFromEnv("AWS_S3_SCRATCH_SPACE") require(AWS_S3_SCRATCH_SPACE.contains("s3n"), "must use s3n:// URL") - protected val jdbcUrl: String = { + protected def jdbcUrl: String = { s"$AWS_REDSHIFT_JDBC_URL?user=$AWS_REDSHIFT_USER&password=$AWS_REDSHIFT_PASSWORD" } diff --git a/src/it/scala/com/databricks/spark/redshift/PostgresDriverIntegrationSuite.scala b/src/it/scala/com/databricks/spark/redshift/PostgresDriverIntegrationSuite.scala new file mode 100644 index 00000000..c27ddef0 --- /dev/null +++ b/src/it/scala/com/databricks/spark/redshift/PostgresDriverIntegrationSuite.scala @@ -0,0 +1,45 @@ +/* + * Copyright 2015 Databricks + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.databricks.spark.redshift + +import org.apache.spark.sql.Row +import org.apache.spark.sql.types.{IntegerType, StructField, StructType} + +/** + * Basic integration tests with the Postgres JDBC driver. 
+ */ +class PostgresDriverIntegrationSuite extends IntegrationSuiteBase { + + override def jdbcUrl: String = { + super.jdbcUrl.replace("jdbc:redshift", "jdbc:postgresql") + } + + test("postgresql driver takes precedence for jdbc:postgresql:// URIs") { + val conn = DefaultJDBCWrapper.getConnector(None, jdbcUrl) + try { + assert(conn.getClass.getName === "org.postgresql.jdbc4.Jdbc4Connection") + } finally { + conn.close() + } + } + + test("roundtrip save and load") { + val df = sqlContext.createDataFrame(sc.parallelize(Seq(Row(1)), 1), + StructType(StructField("foo", IntegerType) :: Nil)) + testRoundtripSaveAndLoad(s"save_with_one_empty_partition_$randomSuffix", df) + } +} diff --git a/src/main/scala/com/databricks/spark/redshift/RedshiftJDBCWrapper.scala b/src/main/scala/com/databricks/spark/redshift/RedshiftJDBCWrapper.scala index 99c07ca4..c0dc9dad 100644 --- a/src/main/scala/com/databricks/spark/redshift/RedshiftJDBCWrapper.scala +++ b/src/main/scala/com/databricks/spark/redshift/RedshiftJDBCWrapper.scala @@ -17,12 +17,12 @@ package com.databricks.spark.redshift -import java.net.URI import java.sql.{ResultSet, PreparedStatement, Connection, Driver, DriverManager, ResultSetMetaData, SQLException} import java.util.Properties import java.util.concurrent.atomic.AtomicInteger import java.util.concurrent.{ThreadFactory, Executors} +import scala.collection.JavaConverters._ import scala.concurrent.{Await, ExecutionContext, Future} import scala.concurrent.duration.Duration import scala.util.Try @@ -53,27 +53,28 @@ private[redshift] class JDBCWrapper { } /** - * Given a JDBC subprotocol, returns the appropriate driver class so that it can be registered - * with Spark. If the user has explicitly specified a driver class in their configuration then - * that class will be used. Otherwise, we will attempt to load the correct driver class based on + * Given a JDBC subprotocol, returns the name of the appropriate driver class to use. + * + * If the user has explicitly specified a driver class in their configuration then that class will + * be used. Otherwise, we will attempt to load the correct driver class based on * the JDBC subprotocol. 
    *
-   * @param jdbcSubprotocol 'redshift' or 'postgres'
+   * @param jdbcSubprotocol 'redshift' or 'postgresql'
    * @param userProvidedDriverClass an optional user-provided explicit driver class name
    * @return the driver class
    */
   private def getDriverClass(
       jdbcSubprotocol: String,
-      userProvidedDriverClass: Option[String]): Class[Driver] = {
-    userProvidedDriverClass.map(Utils.classForName).getOrElse {
+      userProvidedDriverClass: Option[String]): String = {
+    userProvidedDriverClass.getOrElse {
       jdbcSubprotocol match {
         case "redshift" =>
           try {
-            Utils.classForName("com.amazon.redshift.jdbc41.Driver")
+            Utils.classForName("com.amazon.redshift.jdbc41.Driver").getName
           } catch {
             case _: ClassNotFoundException =>
               try {
-                Utils.classForName("com.amazon.redshift.jdbc4.Driver")
+                Utils.classForName("com.amazon.redshift.jdbc4.Driver").getName
               } catch {
                 case e: ClassNotFoundException =>
                   throw new ClassNotFoundException(
@@ -81,12 +82,16 @@ private[redshift] class JDBCWrapper {
                     "instructions on downloading and configuring the official Amazon driver.", e)
               }
           }
-        case "postgres" => Utils.classForName("org.postgresql.Driver")
+        case "postgresql" => "org.postgresql.Driver"
         case other => throw new IllegalArgumentException(s"Unsupported JDBC protocol: '$other'")
       }
-    }.asInstanceOf[Class[Driver]]
+    }
   }
 
+  /**
+   * Reflectively calls Spark's `DriverRegistry.register()`, which handles corner-cases related to
+   * using JDBC drivers that are not accessible from the bootstrap classloader.
+   */
   private def registerDriver(driverClass: String): Unit = {
     // DriverRegistry.register() is one of the few pieces of private Spark functionality which
     // we need to rely on. This class was relocated in Spark 1.5.0, so we need to use reflection
@@ -194,9 +199,19 @@ private[redshift] class JDBCWrapper {
    */
   def getConnector(userProvidedDriverClass: Option[String], url: String): Connection = {
     val subprotocol = url.stripPrefix("jdbc:").split(":")(0)
-    val driverClass: Class[Driver] = getDriverClass(subprotocol, userProvidedDriverClass)
-    registerDriver(driverClass.getCanonicalName)
-    DriverManager.getConnection(url, new Properties())
+    val driverClass: String = getDriverClass(subprotocol, userProvidedDriverClass)
+    registerDriver(driverClass)
+    // Note that we purposely don't call DriverManager.getConnection() here: we want to ensure
+    // that an explicitly-specified user-provided driver class can take precedence over the default
+    // class, but DriverManager.getConnection() might select a driver according to its own
+    // precedence rules. At the same time, we don't want to create a driver per connection, so we
+    // rely on the DriverManager's registered driver instances to handle that singleton logic.
+    val driver: Driver = DriverManager.getDrivers.asScala.collectFirst {
+      case d if d.getClass.getCanonicalName == driverClass => d
+    }.getOrElse {
+      throw new IllegalArgumentException(s"Did not find registered driver with class $driverClass")
+    }
+    driver.connect(url, new Properties())
   }
 
   /**
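
A minimal usage sketch (not part of this diff) of how the `jdbcdriver` option documented in the README interacts with the driver resolution above. The cluster URL, credentials, table name, and tempdir bucket are placeholders, and a SQLContext is assumed to be in scope:

    // Hypothetical example: reading through the spark-redshift data source with an explicit
    // driver override. If the `jdbcdriver` option were omitted, getDriverClass() would infer
    // org.postgresql.Driver from the jdbc:postgresql subprotocol of the URL.
    val df = sqlContext.read
      .format("com.databricks.spark.redshift")
      .option("url", "jdbc:postgresql://examplecluster.example.com:5439/dev?user=user&password=pass")
      .option("jdbcdriver", "org.postgresql.Driver") // optional explicit override
      .option("dbtable", "example_table")
      .option("tempdir", "s3n://example-bucket/scratch/")
      .load()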