From 4339b81b87e1d597932ecfe2e8d708a617e0e09a Mon Sep 17 00:00:00 2001 From: angerszhu Date: Thu, 18 Jun 2020 22:25:42 +0800 Subject: [PATCH 1/5] save --- .../expressions/complexTypeExtractors.scala | 61 +++++++++++++++++++ .../org/apache/spark/sql/SQLQuerySuite.scala | 23 +++++++ 2 files changed, 84 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 9c600c9d39cf7..69e5885db405e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -59,6 +59,23 @@ object ExtractValue { GetArrayStructFields(child, fields(ordinal).copy(name = fieldName), ordinal, fields.length, containsNull) + case (ExtractNestedArray(StructType(fields), containsNull, containsNullSeq), + NonNullLiteral(v, StringType)) => + child match { + case ExtractGetArrayStructField(_, num) if num == containsNullSeq.size => + val fieldName = v.toString + val ordinal = findField(fields, fieldName, resolver) + val row = (0 until num).foldRight(child) { (_, e) => + GetArrayItem(e, Literal(0)) + } + val innerArray = GetArrayStructFields(row, fields(ordinal).copy(name = fieldName), + ordinal, fields.length, containsNull) + containsNullSeq.foldRight(innerArray: Expression) { (_, expr) => + new CreateArray(Seq(expr)) + } + case _ => GetArrayItem(child, extraction) + } + case (_: ArrayType, _) => GetArrayItem(child, extraction) case (MapType(kt, _, _), _) => GetMapValue(child, extraction) @@ -95,6 +112,50 @@ object ExtractValue { trait ExtractValue extends Expression +object ExtractNestedArray { + + type ReturnType = Option[(DataType, Boolean, Seq[Boolean])] + + def unapply(dataType: DataType): ReturnType = { + extractArrayType(dataType) + } + + def extractArrayType(dataType: DataType): ReturnType = { + dataType match { + case ArrayType(dt, containsNull) => + extractArrayType(dt) match { + case Some((d, cn, seq)) => Some(d, cn, containsNull +: seq) + case None => Some(dt, containsNull, Seq.empty[Boolean]) + } + case _ => None + } + } +} + +/** + * Extract GetArrayStructField from Expression + */ +object ExtractGetArrayStructField { + + type ReturnType = Option[(Expression, Int)] + + def unapply(expr: Expression): ReturnType = { + extractArrayStruct(expr) + } + + def extractArrayStruct(expr: Expression): ReturnType = { + expr match { + case gas @ GetArrayStructFields(child, _, _, _, _) => + extractArrayStruct(child) match { + case Some((e, deep)) => Some(e, deep + 1) + case None => Some(child, 1) + } + case _ => None + } + } +} + + /** * Returns the value of fields in the Struct `child`. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index a219b91627b2b..ddce32b5f460a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -22,7 +22,11 @@ import java.net.{MalformedURLException, URL} import java.sql.{Date, Timestamp} import java.util.concurrent.atomic.AtomicBoolean +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + import org.apache.spark.{AccumulatorSuite, SparkException} + import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} import org.apache.spark.sql.catalyst.expressions.GenericRow import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, Partial} @@ -3521,6 +3525,25 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark |""".stripMargin), Row(1)) } } + + test("SPARK-32002: Support Extract valve from nested Array(a:Struct(Array(b:Struct)))") { + withTempView("rows") { + val df = spark.read + .json(Seq( + """{"a": [{"b": [{"c": [1,2]}]}]}""", + """{"a": [{"b": [{"c": [1]}, {"c": [2]}]}]}""", + """{"a":[{}]}""").toDS()) + df.createOrReplaceTempView("nest") + + checkAnswer(sql( + """ + |SELECT a.b.c FROM nest + """.stripMargin), + Row(ArrayBuffer(ArrayBuffer(ArrayBuffer(1, 2)))) :: + Row(ArrayBuffer(ArrayBuffer(ArrayBuffer(1), ArrayBuffer(2)))) :: + Row(ArrayBuffer(null)) :: Nil) + } + } } case class Foo(bar: Option[String]) From da9a3d514c845a74ec8fd7fa00a945bdeb99092a Mon Sep 17 00:00:00 2001 From: angerszhu Date: Thu, 18 Jun 2020 22:27:56 +0800 Subject: [PATCH 2/5] Update SQLQuerySuite.scala --- .../src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index ddce32b5f460a..a0a8d0055e36a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -3526,7 +3526,7 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark } } - test("SPARK-32002: Support Extract valve from nested Array(a:Struct(Array(b:Struct)))") { + test("SPARK-32002: Support Extract valve from nested ArrayStruct") { withTempView("rows") { val df = spark.read .json(Seq( From b8f4a698d3269cc4344d75dfd193c09bfae67c36 Mon Sep 17 00:00:00 2001 From: angerszhu Date: Thu, 18 Jun 2020 23:09:27 +0800 Subject: [PATCH 3/5] Update SQLQuerySuite.scala --- .../src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 2 -- 1 file changed, 2 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index a0a8d0055e36a..40184200d3063 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -22,11 +22,9 @@ import java.net.{MalformedURLException, URL} import java.sql.{Date, Timestamp} import java.util.concurrent.atomic.AtomicBoolean -import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import org.apache.spark.{AccumulatorSuite, SparkException} - import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} import org.apache.spark.sql.catalyst.expressions.GenericRow import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, Partial} From 400ad7d04d55af4a25d2b5c92061de3f0533eecf Mon Sep 17 00:00:00 2001 From: angerszhu Date: Fri, 19 Jun 2020 11:05:02 +0800 Subject: [PATCH 4/5] Update SelectedFieldSuite.scala --- .../spark/sql/catalyst/expressions/SelectedFieldSuite.scala | 6 ------ 1 file changed, 6 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SelectedFieldSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SelectedFieldSuite.scala index 3c826e812b5cc..084259c92a409 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SelectedFieldSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SelectedFieldSuite.scala @@ -378,12 +378,6 @@ class SelectedFieldSuite extends AnalysisTest { StructField("subfield1", IntegerType, nullable = false) :: Nil)) :: Nil))) } - testSelect(arrayWithMultipleFields, "col7.field3.subfield1") { - StructField("col7", ArrayType(StructType( - StructField("field3", ArrayType(StructType( - StructField("subfield1", IntegerType, nullable = false) :: Nil))) :: Nil))) - } - // Array with a nested int array // |-- col1: string (nullable = false) // |-- col8: array (nullable = true) From b6e92c025a4f3746165a3ba660fe80cca63bb91f Mon Sep 17 00:00:00 2001 From: AngersZhuuuu Date: Sat, 20 Jun 2020 14:43:53 +0800 Subject: [PATCH 5/5] Update sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala Co-authored-by: Ruslan Dautkhanov --- .../src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 40184200d3063..5c62880f57b61 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -3524,7 +3524,7 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark } } - test("SPARK-32002: Support Extract valve from nested ArrayStruct") { + test("SPARK-32002: Support Extract value from nested ArrayStruct") { withTempView("rows") { val df = spark.read .json(Seq(