
Commit 1cd1919

Fix
1 parent b14993e commit 1cd1919

4 files changed: +70 -8 lines

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala

Lines changed: 14 additions & 8 deletions
@@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
 import org.apache.spark.sql.catalyst.json._
 import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
 import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, BadRecordException, FailFastMode, GenericArrayData, MapData}
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.Utils
@@ -515,10 +516,15 @@ case class JsonToStructs(
     child: Expression,
     timeZoneId: Option[String] = None)
   extends UnaryExpression with TimeZoneAwareExpression with CodegenFallback with ExpectsInputTypes {
-  override def nullable: Boolean = true
 
-  def this(schema: DataType, options: Map[String, String], child: Expression) =
-    this(schema, options, child, None)
+  val forceNullableSchema = SQLConf.get.getConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA)
+
+  // The JSON input data might be missing certain fields. We force the nullability
+  // of the user-provided schema to avoid data corruptions. In particular, the parquet-mr encoder
+  // can generate incorrect files if values are missing in columns declared as non-nullable.
+  val nullableSchema = if (forceNullableSchema) schema.asNullable else schema
+
+  override def nullable: Boolean = true
 
   // Used in `FunctionRegistry`
   def this(child: Expression, schema: Expression) =
@@ -535,22 +541,22 @@ case class JsonToStructs(
       child = child,
       timeZoneId = None)
 
-  override def checkInputDataTypes(): TypeCheckResult = schema match {
+  override def checkInputDataTypes(): TypeCheckResult = nullableSchema match {
     case _: StructType | ArrayType(_: StructType, _) =>
       super.checkInputDataTypes()
     case _ => TypeCheckResult.TypeCheckFailure(
-      s"Input schema ${schema.simpleString} must be a struct or an array of structs.")
+      s"Input schema ${nullableSchema.simpleString} must be a struct or an array of structs.")
   }
 
   @transient
-  lazy val rowSchema = schema match {
+  lazy val rowSchema = nullableSchema match {
     case st: StructType => st
     case ArrayType(st: StructType, _) => st
   }
 
   // This converts parsed rows to the desired output by the given schema.
   @transient
-  lazy val converter = schema match {
+  lazy val converter = nullableSchema match {
     case _: StructType =>
       (rows: Seq[InternalRow]) => if (rows.length == 1) rows.head else null
     case ArrayType(_: StructType, _) =>
@@ -563,7 +569,7 @@ case class JsonToStructs(
       rowSchema,
       new JSONOptions(options + ("mode" -> FailFastMode.name), timeZoneId.get))
 
-  override def dataType: DataType = schema
+  override def dataType: DataType = nullableSchema
 
   override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
     copy(timeZoneId = Option(timeZoneId))
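The user-visible effect of the change above: from_json now reports a fully nullable result type, so fields missing from the input can come back as null without contradicting the declared schema. A minimal spark-shell sketch (the schema, JSON literal, and variable names below are illustrative, not taken from this commit):

import org.apache.spark.sql.functions.{from_json, lit}
import org.apache.spark.sql.types._

// Illustrative user schema: both fields are declared non-nullable,
// but the JSON input below is missing "b".
val userSchema = new StructType()
  .add("a", LongType, nullable = false)
  .add("b", StringType, nullable = false)

val df = spark.range(1)
  .select(from_json(lit("""{"a": 1}"""), userSchema) as "j")

// With spark.sql.fromJsonForceNullableSchema=true (the default introduced here),
// dataType is schema.asNullable, so every reported field is nullable.
val resultType = df.schema("j").dataType.asInstanceOf[StructType]
assert(resultType.forall(_.nullable))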

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 8 additions & 0 deletions
@@ -493,6 +493,14 @@ object SQLConf {
     .stringConf
     .createWithDefault("_corrupt_record")
 
+  val FROM_JSON_FORCE_NULLABLE_SCHEMA = buildConf("spark.sql.fromJsonForceNullableSchema")
+    .internal()
+    .doc("When true, force the output schema of the from_json() function to be nullable " +
+      "(including all the fields). Otherwise, the schema might not be compatible with " +
+      "actual data, which leads to corruptions.")
+    .booleanConf
+    .createWithDefault(true)
+
   val BROADCAST_TIMEOUT = buildConf("spark.sql.broadcastTimeout")
     .doc("Timeout in seconds for the broadcast wait time in broadcast joins.")
     .timeConf(TimeUnit.SECONDS)
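Because the conf is marked internal(), it does not show up in user-facing documentation, but it can still be set per session. A sketch of opting back into the old behavior (assuming a running SparkSession named spark):

// Restore the pre-patch behavior for this session only. Not recommended in
// general: non-nullable schemas with missing JSON fields can corrupt parquet output.
spark.conf.set("spark.sql.fromJsonForceNullableSchema", "false")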

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala

Lines changed: 28 additions & 0 deletions
@@ -23,6 +23,7 @@ import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.errors.TreeNodeException
 import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeTestUtils, DateTimeUtils, GenericArrayData, PermissiveMode}
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 
@@ -680,4 +681,31 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
       )
     }
   }
+
+  test("from_json missing fields") {
+    val conf = SQLConf.get
+    for (forceJsonNullableSchema <- Seq(false, true)) {
+      conf.setConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA, forceJsonNullableSchema)
+      val input =
+        """{
+          |  "a": 1,
+          |  "c": "foo"
+          |}
+          |"""
+          .stripMargin
+      val jsonSchema = new StructType()
+        .add("a", LongType, nullable = false)
+        .add("b", StringType, nullable = false)
+        .add("c", StringType, nullable = false)
+      val output = InternalRow(1L, null, UTF8String.fromString("foo"))
+      checkEvaluation(
+        JsonToStructs(jsonSchema, Map.empty, Literal.create(input, StringType), gmtId),
+        output
+      )
+      val schema = JsonToStructs(jsonSchema, Map.empty, Literal.create(input, StringType), gmtId)
+        .dataType
+      val schemaToCompare = if (forceJsonNullableSchema) jsonSchema.asNullable else jsonSchema
+      assert(schemaToCompare == schema)
+    }
+  }
 }

sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala

Lines changed: 20 additions & 0 deletions
@@ -43,6 +43,7 @@ import org.apache.spark.sql.catalyst.{InternalRow, ScalaReflection}
 import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow}
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.execution.datasources.SQLHadoopMapReduceCommitProtocol
+import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
@@ -780,6 +781,25 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext {
       assert(option.compressionCodecClassName == "UNCOMPRESSED")
     }
   }
+
+  test("SPARK-23173 Writing a file with data converted from JSON with an incorrect user schema") {
+    withTempPath { file =>
+      val jsonData =
+        """{
+          |  "a": 1,
+          |  "c": "foo"
+          |}
+          |"""
+          .stripMargin
+      val jsonSchema = new StructType()
+        .add("a", LongType, nullable = false)
+        .add("b", StringType, nullable = false)
+        .add("c", StringType, nullable = false)
+      spark.range(1).select(from_json(lit(jsonData), jsonSchema) as "input")
+        .write.parquet(file.getAbsolutePath)
+      checkAnswer(spark.read.parquet(file.getAbsolutePath), Seq(Row(Row(1, null, "foo"))))
+    }
+  }
 }
 
 class JobCommitFailureParquetOutputCommitter(outputPath: Path, context: TaskAttemptContext)
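A complementary check, not part of this commit, is that a parquet file written through from_json ends up with optional (nullable) nested fields. A sketch under assumed names (the path is hypothetical; in the test above it comes from withTempPath):

import org.apache.spark.sql.functions.{from_json, lit}
import org.apache.spark.sql.types._

// Hypothetical output path for illustration.
val path = "/tmp/spark-23173-check"
val schema = new StructType().add("a", LongType, nullable = false)

spark.range(1)
  .select(from_json(lit("""{"a": 1}"""), schema) as "input")
  .write.mode("overwrite").parquet(path)

// Reading back, the nested field is reported nullable, matching what was
// forced at write time by the nullable from_json schema.
val readBack = spark.read.parquet(path).schema("input").dataType.asInstanceOf[StructType]
assert(readBack.forall(_.nullable))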
