Skip to content

Commit d95ba40

Browse files
committed
[SPARK-24924][SQL] Add mapping for built-in Avro data source
1 parent 0c83f71 commit d95ba40

File tree

2 files changed: +11 additions, -7 deletions

external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ import org.apache.avro.generic.GenericData.{EnumSymbol, Fixed}
 import org.apache.commons.io.FileUtils
 
 import org.apache.spark.sql._
+import org.apache.spark.sql.execution.datasources.DataSource
 import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils}
 import org.apache.spark.sql.types._

@@ -51,6 +52,13 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
     checkAnswer(newEntries, originalEntries)
   }
 
+  test("resolve avro data source") {
+    Seq("avro", "com.databricks.spark.avro").foreach { provider =>
+      assert(DataSource.lookupDataSource(provider, spark.sessionState.conf) ===
+        classOf[org.apache.spark.sql.avro.AvroFileFormat])
+    }
+  }
+
   test("reading from multiple paths") {
     val df = spark.read.format("avro").load(episodesAvro, episodesAvro)
     assert(df.count == 16)
@@ -456,7 +464,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
     // get the same values back.
     withTempPath { tempDir =>
       val name = "AvroTest"
-      val namespace = "com.databricks.spark.avro"
+      val namespace = "org.apache.spark.avro"
       val parameters = Map("recordName" -> name, "recordNamespace" -> namespace)
 
       val avroDir = tempDir + "/namedAvro"

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -571,6 +571,7 @@ object DataSource extends Logging {
     val nativeOrc = classOf[OrcFileFormat].getCanonicalName
     val socket = classOf[TextSocketSourceProvider].getCanonicalName
     val rate = classOf[RateStreamProvider].getCanonicalName
+    val avro = "org.apache.spark.sql.avro.AvroFileFormat"
 
     Map(
       "org.apache.spark.sql.jdbc" -> jdbc,
@@ -592,6 +593,7 @@
       "org.apache.spark.ml.source.libsvm.DefaultSource" -> libsvm,
       "org.apache.spark.ml.source.libsvm" -> libsvm,
       "com.databricks.spark.csv" -> csv,
+      "com.databricks.spark.avro" -> avro,
       "org.apache.spark.sql.execution.streaming.TextSocketSourceProvider" -> socket,
       "org.apache.spark.sql.execution.streaming.RateSourceProvider" -> rate
     )
@@ -635,12 +637,6 @@
         "Hive built-in ORC data source must be used with Hive support enabled. " +
         "Please use the native ORC data source by setting 'spark.sql.orc.impl' to " +
         "'native'")
-    } else if (provider1.toLowerCase(Locale.ROOT) == "avro" ||
-      provider1 == "com.databricks.spark.avro") {
-      throw new AnalysisException(
-        s"Failed to find data source: ${provider1.toLowerCase(Locale.ROOT)}. " +
-        "Please find an Avro package at " +
-        "http://spark.apache.org/third-party-projects.html")
     } else {
       throw new ClassNotFoundException(
         s"Failed to find data source: $provider1. Please find packages at " +

0 commit comments

Comments
 (0)