Skip to content

Commit 2992a0e

Browse files
Bogdan Raducanu authored and hvanhovell committed
[SPARK-13721][SQL] Support outer generators in DataFrame API
## What changes were proposed in this pull request?
Added the explode_outer, posexplode_outer and inline_outer functions and expressions. Also fixed a bug in GenerateExec.scala for CollectionGenerator: previously the outer case was handled correctly only for null collections, not for empty ones.
## How was this patch tested?
New tests added to GeneratorFunctionSuite.
Author: Bogdan Raducanu <[email protected]>
Closes #16608 from bogdanrdc/SPARK-13721.
1 parent 83dff87 commit 2992a0e

File tree

9 files changed

+150
-16
lines changed

9 files changed

+150
-16
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1619,11 +1619,18 @@ class Analyzer(
16191619
case _ => expr
16201620
}
16211621

1622-
/** Extracts a [[Generator]] expression and any names assigned by aliases to their output. */
16231622
private object AliasedGenerator {
1624-
def unapply(e: Expression): Option[(Generator, Seq[String])] = e match {
1625-
case Alias(g: Generator, name) if g.resolved => Some((g, name :: Nil))
1626-
case MultiAlias(g: Generator, names) if g.resolved => Some(g, names)
1623+
/**
1624+
* Extracts a [[Generator]] expression, any names assigned by aliases to the outputs
1625+
* and the outer flag. The outer flag is used when joining the generator output.
1626+
* @param e the [[Expression]]
1627+
* @return (the [[Generator]], seq of output names, outer flag)
1628+
*/
1629+
def unapply(e: Expression): Option[(Generator, Seq[String], Boolean)] = e match {
1630+
case Alias(GeneratorOuter(g: Generator), name) if g.resolved => Some((g, name :: Nil, true))
1631+
case MultiAlias(GeneratorOuter(g: Generator), names) if g.resolved => Some(g, names, true)
1632+
case Alias(g: Generator, name) if g.resolved => Some((g, name :: Nil, false))
1633+
case MultiAlias(g: Generator, names) if g.resolved => Some(g, names, false)
16271634
case _ => None
16281635
}
16291636
}
@@ -1644,7 +1651,8 @@ class Analyzer(
16441651
var resolvedGenerator: Generate = null
16451652

16461653
val newProjectList = projectList.flatMap {
1647-
case AliasedGenerator(generator, names) if generator.childrenResolved =>
1654+
1655+
case AliasedGenerator(generator, names, outer) if generator.childrenResolved =>
16481656
// It's a sanity check, this should not happen as the previous case will throw
16491657
// exception earlier.
16501658
assert(resolvedGenerator == null, "More than one generator found in SELECT.")
@@ -1653,7 +1661,7 @@ class Analyzer(
16531661
Generate(
16541662
generator,
16551663
join = projectList.size > 1, // Only join if there are other expressions in SELECT.
1656-
outer = false,
1664+
outer = outer,
16571665
qualifier = None,
16581666
generatorOutput = ResolveGenerate.makeGeneratorOutput(generator, names),
16591667
child)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,9 +163,11 @@ object FunctionRegistry {
163163
expression[Abs]("abs"),
164164
expression[Coalesce]("coalesce"),
165165
expression[Explode]("explode"),
166+
expressionGeneratorOuter[Explode]("explode_outer"),
166167
expression[Greatest]("greatest"),
167168
expression[If]("if"),
168169
expression[Inline]("inline"),
170+
expressionGeneratorOuter[Inline]("inline_outer"),
169171
expression[IsNaN]("isnan"),
170172
expression[IfNull]("ifnull"),
171173
expression[IsNull]("isnull"),
@@ -176,6 +178,7 @@ object FunctionRegistry {
176178
expression[Nvl]("nvl"),
177179
expression[Nvl2]("nvl2"),
178180
expression[PosExplode]("posexplode"),
181+
expressionGeneratorOuter[PosExplode]("posexplode_outer"),
179182
expression[Rand]("rand"),
180183
expression[Randn]("randn"),
181184
expression[Stack]("stack"),
@@ -508,4 +511,13 @@ object FunctionRegistry {
508511
new ExpressionInfo(clazz.getCanonicalName, name)
509512
}
510513
}
514+
515+
/**
 * Builds a function-registry entry for the "outer" variant of the generator `T`.
 *
 * The entry is derived from the regular `expression[T](name)` entry: its
 * [[ExpressionInfo]] is reused unchanged, and its builder is wrapped so that the
 * expression it produces is enclosed in a [[GeneratorOuter]].
 *
 * @param name the SQL function name to register the outer variant under
 * @tparam T the generator expression class being wrapped
 * @return a `(name, (info, builder))` tuple in the same shape `expression[T]` returns
 */
private def expressionGeneratorOuter[T <: Generator : ClassTag](name: String)
516+
: (String, (ExpressionInfo, FunctionBuilder)) = {
517+
// Reuse the info and builder of the plain entry; the name it returns is discarded
// because the same `name` is returned below.
val (_, (info, generatorBuilder)) = expression[T](name)
518+
val outerBuilder = (args: Seq[Expression]) => {
519+
// The builder is typed as producing an Expression, so the result is cast to
// Generator before wrapping; T <: Generator is the basis for that cast.
GeneratorOuter(generatorBuilder(args).asInstanceOf[Generator])
520+
}
521+
(name, (info, outerBuilder))
522+
}
511523
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,15 @@ case class Stack(children: Seq[Expression]) extends Generator {
204204
}
205205
}
206206

207+
/**
 * Wrapper that marks a [[Generator]] as "outer".
 *
 * This expression is a marker only: both `eval` and `doGenCode` throw
 * [[UnsupportedOperationException]], so it cannot be executed directly.
 * The analyzer's `AliasedGenerator` extractor unwraps it and passes `outer = true`
 * to the resulting `Generate` operator.
 *
 * @param child the generator being marked as outer
 */
case class GeneratorOuter(child: Generator) extends UnaryExpression with Generator {
208+
// Never evaluated directly; always throws.
final override def eval(input: InternalRow = null): TraversableOnce[InternalRow] =
209+
throw new UnsupportedOperationException(s"Cannot evaluate expression: $this")
210+
211+
// Code generation is likewise unsupported for the marker itself.
final override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
212+
throw new UnsupportedOperationException(s"Cannot evaluate expression: $this")
213+
214+
// The output schema is exactly that of the wrapped generator.
override def elementSchema: StructType = child.elementSchema
215+
}
207216
/**
208217
* A base class for [[Explode]] and [[PosExplode]].
209218
*/

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -101,10 +101,17 @@ case class Generate(
101101

102102
override def producedAttributes: AttributeSet = AttributeSet(generatorOutput)
103103

104-
def qualifiedGeneratorOutput: Seq[Attribute] = qualifier.map { q =>
105-
// prepend the new qualifier to the existed one
106-
generatorOutput.map(a => a.withQualifier(Some(q)))
107-
}.getOrElse(generatorOutput)
104+
/**
 * The generator output attributes with this operator's `qualifier` applied (when one
 * is defined) and with nullability widened when `outer` is set.
 *
 * @return one attribute per entry of `generatorOutput`, in the same order
 */
def qualifiedGeneratorOutput: Seq[Attribute] = {
105+
val qualifiedOutput = qualifier.map { q =>
106+
// set the qualifier `q` on every generator output attribute
107+
generatorOutput.map(a => a.withQualifier(Some(q)))
108+
}.getOrElse(generatorOutput)
109+
val nullableOutput = qualifiedOutput.map {
110+
// if outer, make all attributes nullable, otherwise keep existing nullability
111+
a => a.withNullability(outer || a.nullable)
112+
}
113+
nullableOutput
114+
}
108115

109116
def output: Seq[Attribute] = {
110117
if (join) child.output ++ qualifiedGeneratorOutput else qualifiedGeneratorOutput

sql/core/src/main/scala/org/apache/spark/sql/Column.scala

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -166,10 +166,7 @@ class Column(val expr: Expression) extends Logging {
166166

167167
// Leave an unaliased generator with an empty list of names since the analyzer will generate
168168
// the correct defaults after the nested expression's type has been resolved.
169-
case explode: Explode => MultiAlias(explode, Nil)
170-
case explode: PosExplode => MultiAlias(explode, Nil)
171-
172-
case jt: JsonTuple => MultiAlias(jt, Nil)
169+
case g: Generator => MultiAlias(g, Nil)
173170

174171
case func: UnresolvedFunction => UnresolvedAlias(func, Some(Column.generateAlias))
175172

sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -162,11 +162,15 @@ case class GenerateExec(
162162
val index = ctx.freshName("index")
163163

164164
// Add a check if the generate outer flag is true.
165-
val checks = optionalCode(outer, data.isNull)
165+
val checks = optionalCode(outer, s"($index == -1)")
166166

167167
// Add position
168168
val position = if (e.position) {
169-
Seq(ExprCode("", "false", index))
169+
if (outer) {
170+
Seq(ExprCode("", s"$index == -1", index))
171+
} else {
172+
Seq(ExprCode("", "false", index))
173+
}
170174
} else {
171175
Seq.empty
172176
}

sql/core/src/main/scala/org/apache/spark/sql/functions.scala

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2870,6 +2870,15 @@ object functions {
28702870
*/
28712871
def explode(e: Column): Column = withExpr { Explode(e.expr) }
28722872

2873+
/**
2874+
* Creates a new row for each element in the given array or map column.
2875+
* Unlike explode, if the array/map is null or empty then null is produced.
2876+
*
2877+
* @group collection_funcs
2878+
* @since 2.2.0
2879+
*/
2880+
def explode_outer(e: Column): Column = withExpr { GeneratorOuter(Explode(e.expr)) }
2881+
28732882
/**
28742883
* Creates a new row for each element with position in the given array or map column.
28752884
*
@@ -2878,6 +2887,15 @@ object functions {
28782887
*/
28792888
def posexplode(e: Column): Column = withExpr { PosExplode(e.expr) }
28802889

2890+
/**
2891+
* Creates a new row for each element with position in the given array or map column.
2892+
* Unlike posexplode, if the array/map is null or empty then the row (null, null) is produced.
2893+
*
2894+
* @group collection_funcs
2895+
* @since 2.2.0
2896+
*/
2897+
def posexplode_outer(e: Column): Column = withExpr { GeneratorOuter(PosExplode(e.expr)) }
2898+
28812899
/**
28822900
* Extracts json object from a json string based on json path specified, and returns json string
28832901
* of the extracted json object. It will return null if the input json string is invalid.

sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,13 +87,27 @@ class GeneratorFunctionSuite extends QueryTest with SharedSQLContext {
8787
Row(1) :: Row(2) :: Row(3) :: Nil)
8888
}
8989

90+
test("single explode_outer") {
91+
val df = Seq((1, Seq(1, 2, 3)), (2, Seq())).toDF("a", "intList")
92+
checkAnswer(
93+
df.select(explode_outer('intList)),
94+
Row(1) :: Row(2) :: Row(3) :: Row(null) :: Nil)
95+
}
96+
9097
test("single posexplode") {
9198
val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")
9299
checkAnswer(
93100
df.select(posexplode('intList)),
94101
Row(0, 1) :: Row(1, 2) :: Row(2, 3) :: Nil)
95102
}
96103

104+
test("single posexplode_outer") {
105+
val df = Seq((1, Seq(1, 2, 3)), (2, Seq())).toDF("a", "intList")
106+
checkAnswer(
107+
df.select(posexplode_outer('intList)),
108+
Row(0, 1) :: Row(1, 2) :: Row(2, 3) :: Row(null, null) :: Nil)
109+
}
110+
97111
test("explode and other columns") {
98112
val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")
99113

@@ -110,6 +124,26 @@ class GeneratorFunctionSuite extends QueryTest with SharedSQLContext {
110124
Row(1, Seq(1, 2, 3), 3) :: Nil)
111125
}
112126

127+
test("explode_outer and other columns") {
128+
val df = Seq((1, Seq(1, 2, 3)), (2, Seq())).toDF("a", "intList")
129+
130+
checkAnswer(
131+
df.select($"a", explode_outer('intList)),
132+
Row(1, 1) ::
133+
Row(1, 2) ::
134+
Row(1, 3) ::
135+
Row(2, null) ::
136+
Nil)
137+
138+
checkAnswer(
139+
df.select($"*", explode_outer('intList)),
140+
Row(1, Seq(1, 2, 3), 1) ::
141+
Row(1, Seq(1, 2, 3), 2) ::
142+
Row(1, Seq(1, 2, 3), 3) ::
143+
Row(2, Seq(), null) ::
144+
Nil)
145+
}
146+
113147
test("aliased explode") {
114148
val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")
115149

@@ -122,6 +156,18 @@ class GeneratorFunctionSuite extends QueryTest with SharedSQLContext {
122156
Row(6) :: Nil)
123157
}
124158

159+
// Checks that an aliased explode_outer can be referenced by its alias, both directly
// and through an aggregate.
test("aliased explode_outer") {
160+
// The second row carries an empty list, which the outer generator turns into one null row.
val df = Seq((1, Seq(1, 2, 3)), (2, Seq())).toDF("a", "intList")
161+
162+
checkAnswer(
163+
df.select(explode_outer('intList).as('int)).select('int),
164+
Row(1) :: Row(2) :: Row(3) :: Row(null) :: Nil)
165+
166+
// Use explode_outer here as well (previously plain explode, duplicating the
// "aliased explode" test). sum ignores the null row, so the total is still 6.
checkAnswer(
167+
df.select(explode_outer('intList).as('int)).select(sum('int)),
168+
Row(6) :: Nil)
169+
}
170+
125171
test("explode on map") {
126172
val df = Seq((1, Map("a" -> "b"))).toDF("a", "map")
127173

@@ -130,6 +176,15 @@ class GeneratorFunctionSuite extends QueryTest with SharedSQLContext {
130176
Row("a", "b"))
131177
}
132178

179+
test("explode_outer on map") {
180+
val df = Seq((1, Map("a" -> "b")), (2, Map[String, String]()),
181+
(3, Map("c" -> "d"))).toDF("a", "map")
182+
183+
checkAnswer(
184+
df.select(explode_outer('map)),
185+
Row("a", "b") :: Row(null, null) :: Row("c", "d") :: Nil)
186+
}
187+
133188
test("explode on map with aliases") {
134189
val df = Seq((1, Map("a" -> "b"))).toDF("a", "map")
135190

@@ -138,6 +193,14 @@ class GeneratorFunctionSuite extends QueryTest with SharedSQLContext {
138193
Row("a", "b"))
139194
}
140195

196+
test("explode_outer on map with aliases") {
197+
val df = Seq((3, None), (1, Some(Map("a" -> "b")))).toDF("a", "map")
198+
199+
checkAnswer(
200+
df.select(explode_outer('map).as("key1" :: "value1" :: Nil)).select("key1", "value1"),
201+
Row("a", "b") :: Row(null, null) :: Nil)
202+
}
203+
141204
test("self join explode") {
142205
val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")
143206
val exploded = df.select(explode('intList).as('i))
@@ -207,6 +270,19 @@ class GeneratorFunctionSuite extends QueryTest with SharedSQLContext {
207270
Row(1) :: Row(2) :: Nil)
208271
}
209272

273+
test("inline_outer") {
274+
val df = Seq((1, "2"), (3, "4"), (5, "6")).toDF("col1", "col2")
275+
val df2 = df.select(when('col1 === 1, null).otherwise(array(struct('col1, 'col2))).as("col1"))
276+
checkAnswer(
277+
df2.selectExpr("inline(col1)"),
278+
Row(3, "4") :: Row(5, "6") :: Nil
279+
)
280+
checkAnswer(
281+
df2.selectExpr("inline_outer(col1)"),
282+
Row(null, null) :: Row(3, "4") :: Row(5, "6") :: Nil
283+
)
284+
}
285+
210286
test("SPARK-14986: Outer lateral view with empty generate expression") {
211287
checkAnswer(
212288
sql("select nil from values 1 lateral view outer explode(array()) n as nil"),

sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ class ExpressionToSQLSuite extends SQLBuilderTest with SQLTestUtils {
9393
checkSqlGeneration("SELECT array(1,2,3)")
9494
checkSqlGeneration("SELECT coalesce(null, 1, 2)")
9595
checkSqlGeneration("SELECT explode(array(1,2,3))")
96+
checkSqlGeneration("SELECT explode_outer(array())")
9697
checkSqlGeneration("SELECT greatest(1,null,3)")
9798
checkSqlGeneration("SELECT if(1==2, 'yes', 'no')")
9899
checkSqlGeneration("SELECT isnan(15), isnan('invalid')")
@@ -102,6 +103,8 @@ class ExpressionToSQLSuite extends SQLBuilderTest with SQLTestUtils {
102103
checkSqlGeneration("SELECT map(1, 'a', 2, 'b')")
103104
checkSqlGeneration("SELECT named_struct('c1',1,'c2',2,'c3',3)")
104105
checkSqlGeneration("SELECT nanvl(a, 5), nanvl(b, 10), nanvl(d, c) from t2")
106+
checkSqlGeneration("SELECT posexplode_outer(array())")
107+
checkSqlGeneration("SELECT inline_outer(array(struct('a', 1)))")
105108
checkSqlGeneration("SELECT rand(1)")
106109
checkSqlGeneration("SELECT randn(3)")
107110
checkSqlGeneration("SELECT struct(1,2,3)")

0 commit comments

Comments
 (0)