Skip to content

Commit 2ca9bb0

Browse files
mswit-databricksgatorsmile
authored andcommitted
[SPARK-23173][SQL] Avoid creating corrupt parquet files when loading data from JSON
## What changes were proposed in this pull request? The from_json() function accepts an additional parameter, where the user might specify the schema. The issue is that the specified schema might not be compatible with data. In particular, the JSON data might be missing data for fields declared as non-nullable in the schema. The from_json() function does not verify the data against such errors. When data with missing fields is sent to the parquet encoder, there is no verification either. The end results is a corrupt parquet file. To avoid corruptions, make sure that all fields in the user-specified schema are set to be nullable. Since this changes the behavior of a public function, we need to include it in release notes. The behavior can be reverted by setting `spark.sql.fromJsonForceNullableSchema=false` ## How was this patch tested? Added two new tests. Author: Michał Świtakowski <[email protected]> Closes #20694 from mswit-databricks/SPARK-23173.
1 parent 2c36736 commit 2ca9bb0

File tree

4 files changed

+70
-9
lines changed

4 files changed

+70
-9
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
3030
import org.apache.spark.sql.catalyst.json._
3131
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
3232
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, BadRecordException, FailFastMode, GenericArrayData, MapData}
33+
import org.apache.spark.sql.internal.SQLConf
3334
import org.apache.spark.sql.types._
3435
import org.apache.spark.unsafe.types.UTF8String
3536
import org.apache.spark.util.Utils
@@ -515,10 +516,15 @@ case class JsonToStructs(
515516
child: Expression,
516517
timeZoneId: Option[String] = None)
517518
extends UnaryExpression with TimeZoneAwareExpression with CodegenFallback with ExpectsInputTypes {
518-
override def nullable: Boolean = true
519519

520-
def this(schema: DataType, options: Map[String, String], child: Expression) =
521-
this(schema, options, child, None)
520+
val forceNullableSchema = SQLConf.get.getConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA)
521+
522+
// The JSON input data might be missing certain fields. We force the nullability
523+
// of the user-provided schema to avoid data corruptions. In particular, the parquet-mr encoder
524+
// can generate incorrect files if values are missing in columns declared as non-nullable.
525+
val nullableSchema = if (forceNullableSchema) schema.asNullable else schema
526+
527+
override def nullable: Boolean = true
522528

523529
// Used in `FunctionRegistry`
524530
def this(child: Expression, schema: Expression) =
@@ -535,22 +541,22 @@ case class JsonToStructs(
535541
child = child,
536542
timeZoneId = None)
537543

538-
override def checkInputDataTypes(): TypeCheckResult = schema match {
544+
override def checkInputDataTypes(): TypeCheckResult = nullableSchema match {
539545
case _: StructType | ArrayType(_: StructType, _) =>
540546
super.checkInputDataTypes()
541547
case _ => TypeCheckResult.TypeCheckFailure(
542-
s"Input schema ${schema.simpleString} must be a struct or an array of structs.")
548+
s"Input schema ${nullableSchema.simpleString} must be a struct or an array of structs.")
543549
}
544550

545551
@transient
546-
lazy val rowSchema = schema match {
552+
lazy val rowSchema = nullableSchema match {
547553
case st: StructType => st
548554
case ArrayType(st: StructType, _) => st
549555
}
550556

551557
// This converts parsed rows to the desired output by the given schema.
552558
@transient
553-
lazy val converter = schema match {
559+
lazy val converter = nullableSchema match {
554560
case _: StructType =>
555561
(rows: Seq[InternalRow]) => if (rows.length == 1) rows.head else null
556562
case ArrayType(_: StructType, _) =>
@@ -563,7 +569,7 @@ case class JsonToStructs(
563569
rowSchema,
564570
new JSONOptions(options + ("mode" -> FailFastMode.name), timeZoneId.get))
565571

566-
override def dataType: DataType = schema
572+
override def dataType: DataType = nullableSchema
567573

568574
override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
569575
copy(timeZoneId = Option(timeZoneId))

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -493,6 +493,14 @@ object SQLConf {
493493
.stringConf
494494
.createWithDefault("_corrupt_record")
495495

496+
val FROM_JSON_FORCE_NULLABLE_SCHEMA = buildConf("spark.sql.fromJsonForceNullableSchema")
497+
.internal()
498+
.doc("When true, force the output schema of the from_json() function to be nullable " +
499+
"(including all the fields). Otherwise, the schema might not be compatible with" +
500+
"actual data, which leads to curruptions.")
501+
.booleanConf
502+
.createWithDefault(true)
503+
496504
val BROADCAST_TIMEOUT = buildConf("spark.sql.broadcastTimeout")
497505
.doc("Timeout in seconds for the broadcast wait time in broadcast joins.")
498506
.timeConf(TimeUnit.SECONDS)

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,13 @@ import java.util.Calendar
2222
import org.apache.spark.SparkFunSuite
2323
import org.apache.spark.sql.catalyst.InternalRow
2424
import org.apache.spark.sql.catalyst.errors.TreeNodeException
25+
import org.apache.spark.sql.catalyst.plans.PlanTestBase
2526
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeTestUtils, DateTimeUtils, GenericArrayData, PermissiveMode}
27+
import org.apache.spark.sql.internal.SQLConf
2628
import org.apache.spark.sql.types._
2729
import org.apache.spark.unsafe.types.UTF8String
2830

29-
class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
31+
class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with PlanTestBase {
3032
val json =
3133
"""
3234
|{"store":{"fruit":[{"weight":8,"type":"apple"},{"weight":9,"type":"pear"}],
@@ -680,4 +682,30 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
680682
)
681683
}
682684
}
685+
686+
test("from_json missing fields") {
687+
for (forceJsonNullableSchema <- Seq(false, true)) {
688+
withSQLConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA.key -> forceJsonNullableSchema.toString) {
689+
val input =
690+
"""{
691+
| "a": 1,
692+
| "c": "foo"
693+
|}
694+
|""".stripMargin
695+
val jsonSchema = new StructType()
696+
.add("a", LongType, nullable = false)
697+
.add("b", StringType, nullable = false)
698+
.add("c", StringType, nullable = false)
699+
val output = InternalRow(1L, null, UTF8String.fromString("foo"))
700+
checkEvaluation(
701+
JsonToStructs(jsonSchema, Map.empty, Literal.create(input, StringType), gmtId),
702+
output
703+
)
704+
val schema = JsonToStructs(jsonSchema, Map.empty, Literal.create(input, StringType), gmtId)
705+
.dataType
706+
val schemaToCompare = if (forceJsonNullableSchema) jsonSchema.asNullable else jsonSchema
707+
assert(schemaToCompare == schema)
708+
}
709+
}
710+
}
683711
}

sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ import org.apache.spark.sql.catalyst.{InternalRow, ScalaReflection}
4343
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow}
4444
import org.apache.spark.sql.catalyst.util.DateTimeUtils
4545
import org.apache.spark.sql.execution.datasources.SQLHadoopMapReduceCommitProtocol
46+
import org.apache.spark.sql.functions._
4647
import org.apache.spark.sql.internal.SQLConf
4748
import org.apache.spark.sql.test.SharedSQLContext
4849
import org.apache.spark.sql.types._
@@ -780,6 +781,24 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext {
780781
assert(option.compressionCodecClassName == "UNCOMPRESSED")
781782
}
782783
}
784+
785+
test("SPARK-23173 Writing a file with data converted from JSON with and incorrect user schema") {
786+
withTempPath { file =>
787+
val jsonData =
788+
"""{
789+
| "a": 1,
790+
| "c": "foo"
791+
|}
792+
|""".stripMargin
793+
val jsonSchema = new StructType()
794+
.add("a", LongType, nullable = false)
795+
.add("b", StringType, nullable = false)
796+
.add("c", StringType, nullable = false)
797+
spark.range(1).select(from_json(lit(jsonData), jsonSchema) as "input")
798+
.write.parquet(file.getAbsolutePath)
799+
checkAnswer(spark.read.parquet(file.getAbsolutePath), Seq(Row(Row(1, null, "foo"))))
800+
}
801+
}
783802
}
784803

785804
class JobCommitFailureParquetOutputCommitter(outputPath: Path, context: TaskAttemptContext)

0 commit comments

Comments
 (0)