
Commit d4950e6

Authored by chenghao-intel, committed by JoshRosen
[SPARK-9735][SQL] Respect the user-specified schema over the inferred partition schema for HadoopFsRelation
This change enables the unit test `hadoopFsRelationSuite.Partition column type casting`, which previously threw the exception below because the auto-inferred partition schema was treated with higher priority than the user-specified one.

```
java.lang.ClassCastException: java.lang.Integer cannot be cast to org.apache.spark.unsafe.types.UTF8String
    at org.apache.spark.sql.catalyst.expressions.BaseGenericInternalRow$class.getUTF8String(rows.scala:45)
    at org.apache.spark.sql.catalyst.expressions.GenericInternalRow.getUTF8String(rows.scala:220)
    at org.apache.spark.sql.catalyst.expressions.JoinedRow.getUTF8String(JoinedRow.scala:102)
    at org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificUnsafeProjection.apply(generated.java:62)
    at org.apache.spark.sql.execution.datasources.DataSourceStrategy$$anonfun$17$$anonfun$apply$9.apply(DataSourceStrategy.scala:212)
    at org.apache.spark.sql.execution.datasources.DataSourceStrategy$$anonfun$17$$anonfun$apply$9.apply(DataSourceStrategy.scala:212)
    at scala.collection.Iterator$$anon$11.next(Iterator.scala:328)
    at scala.collection.Iterator$$anon$11.next(Iterator.scala:328)
    at scala.collection.Iterator$class.foreach(Iterator.scala:727)
    at scala.collection.AbstractIterator.foreach(Iterator.scala:1157)
    at scala.collection.generic.Growable$class.$plus$plus$eq(Growable.scala:48)
    at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:103)
    at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:47)
    at scala.collection.TraversableOnce$class.to(TraversableOnce.scala:273)
    at scala.collection.AbstractIterator.to(Iterator.scala:1157)
    at scala.collection.TraversableOnce$class.toBuffer(TraversableOnce.scala:265)
    at scala.collection.AbstractIterator.toBuffer(Iterator.scala:1157)
    at scala.collection.TraversableOnce$class.toArray(TraversableOnce.scala:252)
    at scala.collection.AbstractIterator.toArray(Iterator.scala:1157)
    at org.apache.spark.rdd.RDD$$anonfun$collect$1$$anonfun$12.apply(RDD.scala:903)
    at org.apache.spark.rdd.RDD$$anonfun$collect$1$$anonfun$12.apply(RDD.scala:903)
    at org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1846)
    at org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1846)
    at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:66)
    at org.apache.spark.scheduler.Task.run(Task.scala:88)
    at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:214)
    at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1145)
    at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:615)
    at java.lang.Thread.run(Thread.java:745)

07:44:01.344 ERROR org.apache.spark.executor.Executor: Exception in task 14.0 in stage 3.0 (TID 206)
java.lang.ClassCastException: java.lang.Integer cannot be cast to org.apache.spark.unsafe.types.UTF8String
    (stack trace identical to the one above)
```

Author: Cheng Hao <[email protected]>

Closes #8026 from chenghao-intel/partition_discovery.
1 parent (3535b91) · commit d4950e6
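For context, here is a minimal sketch of the kind of mismatch this commit addresses, assuming a SQLContext and a Parquet-capable build; the path, column names, and values are illustrative and not taken from the patch, and the bundled test exercises the behavior through `saveAsTable`/`sqlContext.table` rather than a path-based read. The data is written with an integer-valued partition column, while the reader declares that column as a string; before this change the inferred IntegerType could win over the declared type and reads could fail with the ClassCastException above.

```scala
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.types.{StringType, StructField, StructType}

// Hypothetical reproduction sketch; names and the path are illustrative.
def readWithUserSchema(sqlContext: SQLContext, path: String): Unit = {
  import sqlContext.implicits._

  // Write data partitioned by an integer-looking column `p`,
  // producing directories such as p=1 and p=2.
  val df = Seq(("x", 1), ("y", 2)).toDF("value", "p")
  df.write.partitionBy("p").parquet(path)

  // The user-specified schema declares `p` as a string, while partition
  // discovery would infer IntegerType from the directory names.
  val userSchema = StructType(Seq(
    StructField("value", StringType),
    StructField("p", StringType)))

  // With the user-specified schema respected, the partition values are cast
  // to the declared StringType instead of keeping the inferred IntegerType.
  sqlContext.read.schema(userSchema).parquet(path).show()
}
```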


2 files changed: 55 additions, 16 deletions


sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala

Lines changed: 25 additions & 4 deletions
```diff
@@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection
 import org.apache.spark.sql.execution.{FileRelation, RDDConversions}
 import org.apache.spark.sql.execution.datasources.{PartitioningUtils, PartitionSpec, Partition}
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{StringType, StructType}
 import org.apache.spark.sql._
 import org.apache.spark.util.SerializableConfiguration
 
@@ -544,11 +544,32 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio
   }
 
   private def discoverPartitions(): PartitionSpec = {
-    val typeInference = sqlContext.conf.partitionColumnTypeInferenceEnabled()
     // We use leaf dirs containing data files to discover the schema.
     val leafDirs = fileStatusCache.leafDirToChildrenFiles.keys.toSeq
-    PartitioningUtils.parsePartitions(leafDirs, PartitioningUtils.DEFAULT_PARTITION_NAME,
-      typeInference)
+    userDefinedPartitionColumns match {
+      case Some(userProvidedSchema) if userProvidedSchema.nonEmpty =>
+        val spec = PartitioningUtils.parsePartitions(
+          leafDirs, PartitioningUtils.DEFAULT_PARTITION_NAME, typeInference = false)
+
+        // Without auto inference, all of value in the `row` should be null or in StringType,
+        // we need to cast into the data type that user specified.
+        def castPartitionValuesToUserSchema(row: InternalRow) = {
+          InternalRow((0 until row.numFields).map { i =>
+            Cast(
+              Literal.create(row.getString(i), StringType),
+              userProvidedSchema.fields(i).dataType).eval()
+          }: _*)
+        }
+
+        PartitionSpec(userProvidedSchema, spec.partitions.map { part =>
+          part.copy(values = castPartitionValuesToUserSchema(part.values))
+        })
+
+      case _ =>
+        // user did not provide a partitioning schema
+        PartitioningUtils.parsePartitions(leafDirs, PartitioningUtils.DEFAULT_PARTITION_NAME,
+          typeInference = sqlContext.conf.partitionColumnTypeInferenceEnabled())
+    }
   }
 
   /**
```
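The key mechanism in the hunk above is `castPartitionValuesToUserSchema`: partitions are parsed with `typeInference = false`, so every partition value arrives as a string taken from the directory name, and each value is then converted to the type declared in the user-provided schema by evaluating a Catalyst `Cast` expression. Below is a tiny standalone sketch of that conversion step; it is not part of the patch and only illustrates the `Cast`/`Literal` evaluation the new code relies on.

```scala
object CastSketch {
  import org.apache.spark.sql.catalyst.expressions.{Cast, Literal}
  import org.apache.spark.sql.types.{IntegerType, StringType}

  def main(args: Array[String]): Unit = {
    // A raw partition value parsed without type inference is just the string
    // from the directory name, e.g. "15" from a path segment like p=15.
    val rawValue = "15"

    // Evaluating Cast(Literal(raw), targetType) yields the value in the
    // user-declared type -- here the Int 15 -- which is how
    // castPartitionValuesToUserSchema rebuilds each partition row.
    val casted = Cast(Literal.create(rawValue, StringType), IntegerType).eval()
    println(casted) // prints 15
  }
}
```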

sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala

Lines changed: 30 additions & 12 deletions
```diff
@@ -510,21 +510,39 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes
     }
   }
 
-  // HadoopFsRelation.discoverPartitions() called by refresh(), which will ignore
-  // the given partition data type.
-  ignore("Partition column type casting") {
+  test("SPARK-9735 Partition column type casting") {
     withTempPath { file =>
-      val input = partitionedTestDF.select('a, 'b, 'p1.cast(StringType).as('ps), 'p2)
-
-      input
-        .write
-        .format(dataSourceName)
-        .mode(SaveMode.Overwrite)
-        .partitionBy("ps", "p2")
-        .saveAsTable("t")
+      val df = (for {
+        i <- 1 to 3
+        p2 <- Seq("foo", "bar")
+      } yield (i, s"val_$i", 1.0d, p2, 123, 123.123f)).toDF("a", "b", "p1", "p2", "p3", "f")
+
+      val input = df.select(
+        'a,
+        'b,
+        'p1.cast(StringType).as('ps1),
+        'p2,
+        'p3.cast(FloatType).as('pf1),
+        'f)
 
       withTempTable("t") {
-        checkAnswer(sqlContext.table("t"), input.collect())
+        input
+          .write
+          .format(dataSourceName)
+          .mode(SaveMode.Overwrite)
+          .partitionBy("ps1", "p2", "pf1", "f")
+          .saveAsTable("t")
+
+        input
+          .write
+          .format(dataSourceName)
+          .mode(SaveMode.Append)
+          .partitionBy("ps1", "p2", "pf1", "f")
+          .saveAsTable("t")
+
+        val realData = input.collect()
+
+        checkAnswer(sqlContext.table("t"), realData ++ realData)
      }
    }
  }
```
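For readers less familiar with `partitionBy`, the table written by this test is stored as nested `key=value` directories, one level per partition column in the order given to `.partitionBy("ps1", "p2", "pf1", "f")`; those directory-name strings are exactly what the `castPartitionValuesToUserSchema` step converts back to the declared types (StringType for `ps1`, FloatType for `pf1`, and so on) at read time. A hypothetical layout for one partition, with the table root and file names being illustrative and dependent on the data source:

```
t/
└── ps1=1/
    └── p2=foo/
        └── pf1=123.0/
            └── f=123.123/
                └── part-r-00000-<illustrative-suffix>.parquet
```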
