Skip to content

Commit ca5e7f4

Browse files
shrink the commits
1 parent 77eeb10 commit ca5e7f4

File tree

26 files changed

+189
-147
lines changed

26 files changed

+189
-147
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.analysis
1919

2020
import org.apache.spark.util.collection.OpenHashSet
2121
import org.apache.spark.sql.AnalysisException
22-
import org.apache.spark.sql.catalyst.errors.TreeNodeException
2322
import org.apache.spark.sql.catalyst.expressions._
2423
import org.apache.spark.sql.catalyst.plans.logical._
2524
import org.apache.spark.sql.catalyst.rules._
@@ -59,6 +58,7 @@ class Analyzer(
5958
ResolveReferences ::
6059
ResolveGroupingAnalytics ::
6160
ResolveSortReferences ::
61+
ResolveGenerate ::
6262
ImplicitGenerate ::
6363
ResolveFunctions ::
6464
GlobalAggregates ::
@@ -473,10 +473,47 @@ class Analyzer(
473473
*/
474474
object ImplicitGenerate extends Rule[LogicalPlan] {
475475
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
476-
case Project(Seq(Alias(g: Generator, _)), child) =>
477-
Generate(g, join = false, outer = false, None, child)
476+
case Project(Seq(Alias(g: Generator, name)), child) =>
477+
Generate(g, join = false, outer = false, child, qualifier = None, name :: Nil, Nil)
478+
case Project(Seq(MultiAlias(g: Generator, names)), child) =>
479+
Generate(g, join = false, outer = false, child, qualifier = None, names, Nil)
478480
}
479481
}
482+
483+
object ResolveGenerate extends Rule[LogicalPlan] {
484+
// Construct the output attributes for the generator,
485+
// The output attribute names can be either specified or
486+
// auto generated.
487+
private def makeGeneratorOutput(
488+
generator: Generator,
489+
attributeNames: Seq[String],
490+
qualifier: Option[String]): Array[Attribute] = {
491+
val elementTypes = generator.elementTypes
492+
493+
val raw = if (attributeNames.size == elementTypes.size) {
494+
attributeNames.zip(elementTypes).map {
495+
case (n, (t, nullable)) => AttributeReference(n, t, nullable)()
496+
}
497+
} else {
498+
elementTypes.zipWithIndex.map {
499+
// keep the default column names as Hive does _c0, _c1, _cN
500+
case ((t, nullable), i) => AttributeReference(s"_c$i", t, nullable)()
501+
}
502+
}
503+
504+
qualifier.map(q => raw.map(_.withQualifiers(q :: Nil))).getOrElse(raw).toArray[Attribute]
505+
}
506+
507+
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
508+
case p: Generate if !p.child.resolved || !p.generator.resolved => p
509+
case p: Generate if p.resolved == false =>
510+
// if the generator output names are not specified, we will use the default ones.
511+
val gOutput = makeGeneratorOutput(p.generator, p.attributeNames, p.qualifier)
512+
Generate(
513+
p.generator, p.join, p.outer, p.child, p.qualifier, gOutput.map(_.name), gOutput)
514+
}
515+
}
516+
480517
}
481518

482519
/**

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,12 @@ trait CheckAnalysis {
3838
throw new AnalysisException(msg)
3939
}
4040

41+
def containsMultipleGenerators(exprs: Seq[Expression]): Boolean = {
42+
exprs.flatMap(_.collect {
43+
case e: Generator => true
44+
}).length >= 1
45+
}
46+
4147
def checkAnalysis(plan: LogicalPlan): Unit = {
4248
// We transform up and order the rules so as to catch the first possible failure instead
4349
// of the result of cascading resolution failures.
@@ -107,6 +113,12 @@ trait CheckAnalysis {
107113
failAnalysis(
108114
s"unresolved operator ${operator.simpleString}")
109115

116+
case p @ Project(exprs, _) if containsMultipleGenerators(exprs) =>
117+
failAnalysis(
118+
s"""Only a single table generating function is allowed in a SELECT clause, found:
119+
| ${exprs.map(_.prettyString).mkString(",")}""".stripMargin)
120+
121+
110122
case _ => // Analysis successful!
111123
}
112124
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -284,12 +284,13 @@ package object dsl {
284284
seed: Int = (math.random * 1000).toInt): LogicalPlan =
285285
Sample(fraction, withReplacement, seed, logicalPlan)
286286

287+
// TODO specify the output column names
287288
def generate(
288289
generator: Generator,
289290
join: Boolean = false,
290291
outer: Boolean = false,
291-
alias: Option[String] = None): LogicalPlan =
292-
Generate(generator, join, outer, None, logicalPlan)
292+
alias: Option[String] = None): Generate =
293+
Generate(generator, join, outer, logicalPlan, alias)
293294

294295
def insertInto(tableName: String, overwrite: Boolean = false): LogicalPlan =
295296
InsertIntoTable(

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala

Lines changed: 7 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -42,47 +42,27 @@ abstract class Generator extends Expression {
4242

4343
override type EvaluatedType = TraversableOnce[Row]
4444

45-
override lazy val dataType =
46-
ArrayType(StructType(output.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata))))
45+
override def dataType: DataType = ???
4746

4847
override def nullable: Boolean = false
4948

5049
/**
51-
* Should be overridden by specific generators. Called only once for each instance to ensure
52-
* that rule application does not change the output schema of a generator.
50+
* The output element data types in structure of Seq[(DataType, Nullable)]
5351
*/
54-
protected def makeOutput(): Seq[Attribute]
55-
56-
private var _output: Seq[Attribute] = null
57-
58-
def output: Seq[Attribute] = {
59-
if (_output == null) {
60-
_output = makeOutput()
61-
}
62-
_output
63-
}
52+
def elementTypes: Seq[(DataType, Boolean)]
6453

6554
/** Should be implemented by child classes to perform specific Generators. */
6655
override def eval(input: Row): TraversableOnce[Row]
67-
68-
/** Overridden `makeCopy` also copies the attributes that are produced by this generator. */
69-
override def makeCopy(newArgs: Array[AnyRef]): this.type = {
70-
val copy = super.makeCopy(newArgs)
71-
copy._output = _output
72-
copy
73-
}
7456
}
7557

7658
/**
7759
* A generator that produces its output using the provided lambda function.
7860
*/
7961
case class UserDefinedGenerator(
80-
schema: Seq[Attribute],
62+
elementTypes: Seq[(DataType, Boolean)],
8163
function: Row => TraversableOnce[Row],
8264
children: Seq[Expression])
83-
extends Generator{
84-
85-
override protected def makeOutput(): Seq[Attribute] = schema
65+
extends Generator {
8666

8767
override def eval(input: Row): TraversableOnce[Row] = {
8868
val inputRow = new InterpretedProjection(children)
@@ -95,30 +75,18 @@ case class UserDefinedGenerator(
9575
/**
9676
* Given an input array produces a sequence of rows for each value in the array.
9777
*/
98-
case class Explode(attributeNames: Seq[String], child: Expression)
78+
case class Explode(child: Expression)
9979
extends Generator with trees.UnaryNode[Expression] {
10080

10181
override lazy val resolved =
10282
child.resolved &&
10383
(child.dataType.isInstanceOf[ArrayType] || child.dataType.isInstanceOf[MapType])
10484

105-
private lazy val elementTypes = child.dataType match {
85+
override def elementTypes: Seq[(DataType, Boolean)] = child.dataType match {
10686
case ArrayType(et, containsNull) => (et, containsNull) :: Nil
10787
case MapType(kt, vt, valueContainsNull) => (kt, false) :: (vt, valueContainsNull) :: Nil
10888
}
10989

110-
// TODO: Move this pattern into Generator.
111-
protected def makeOutput() =
112-
if (attributeNames.size == elementTypes.size) {
113-
attributeNames.zip(elementTypes).map {
114-
case (n, (t, nullable)) => AttributeReference(n, t, nullable)()
115-
}
116-
} else {
117-
elementTypes.zipWithIndex.map {
118-
case ((t, nullable), i) => AttributeReference(s"c_$i", t, nullable)()
119-
}
120-
}
121-
12290
override def eval(input: Row): TraversableOnce[Row] = {
12391
child.dataType match {
12492
case ArrayType(_, _) =>

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,8 @@ case class Alias(child: Expression, name: String)(
112112
extends NamedExpression with trees.UnaryNode[Expression] {
113113

114114
override type EvaluatedType = Any
115+
// Alias(Generator, xx) need to be transformed into Generate(generator, ...)
116+
override lazy val resolved = childrenResolved && !child.isInstanceOf[Generator]
115117

116118
override def eval(input: Row): Any = child.eval(input)
117119

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -477,16 +477,16 @@ object PushPredicateThroughProject extends Rule[LogicalPlan] {
477477
object PushPredicateThroughGenerate extends Rule[LogicalPlan] with PredicateHelper {
478478

479479
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
480-
case filter @ Filter(condition,
481-
generate @ Generate(generator, join, outer, alias, grandChild)) =>
480+
case filter @ Filter(condition, g: Generate) =>
482481
// Predicates that reference attributes produced by the `Generate` operator cannot
483482
// be pushed below the operator.
484483
val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition {
485-
conjunct => conjunct.references subsetOf grandChild.outputSet
484+
conjunct => conjunct.references subsetOf g.child.outputSet
486485
}
487486
if (pushDown.nonEmpty) {
488487
val pushDownPredicate = pushDown.reduce(And)
489-
val withPushdown = generate.copy(child = Filter(pushDownPredicate, grandChild))
488+
val withPushdown = Generate(g.generator, join = g.join, outer = g.outer,
489+
Filter(pushDownPredicate, g.child), g.qualifier, g.attributeNames, g.gOutput)
490490
stayUp.reduceOption(And).map(Filter(_, withPushdown)).getOrElse(withPushdown)
491491
} else {
492492
filter

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -40,34 +40,41 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extend
4040
* output of each into a new stream of rows. This operation is similar to a `flatMap` in functional
4141
* programming with one important additional feature, which allows the input rows to be joined with
4242
* their output.
43+
* @param generator the generator expression
4344
* @param join when true, each output row is implicitly joined with the input tuple that produced
4445
* it.
4546
* @param outer when true, each input row will be output at least once, even if the output of the
4647
* given `generator` is empty. `outer` has no effect when `join` is false.
47-
* @param alias when set, this string is applied to the schema of the output of the transformation
48-
* as a qualifier.
48+
* @param child Children logical plan node
49+
* @param qualifier Qualifier for the attributes of generator(UDTF)
50+
* @param attributeNames the column names for the generator(UDTF), will be _c0, _c1 .. _cN if
51+
* leave as default (empty)
52+
* @param gOutput The output of Generator.
4953
*/
5054
case class Generate(
5155
generator: Generator,
5256
join: Boolean,
5357
outer: Boolean,
54-
alias: Option[String],
55-
child: LogicalPlan)
58+
child: LogicalPlan,
59+
qualifier: Option[String] = None,
60+
attributeNames: Seq[String] = Nil,
61+
gOutput: Seq[Attribute] = Nil)
5662
extends UnaryNode {
5763

58-
protected def generatorOutput: Seq[Attribute] = {
59-
val output = alias
60-
.map(a => generator.output.map(_.withQualifiers(a :: Nil)))
61-
.getOrElse(generator.output)
62-
if (join && outer) {
63-
output.map(_.withNullability(true))
64-
} else {
65-
output
66-
}
64+
override lazy val resolved: Boolean = {
65+
generator.resolved &&
66+
childrenResolved &&
67+
attributeNames.length > 0 &&
68+
gOutput.map(_.name) == attributeNames
6769
}
6870

69-
override def output: Seq[Attribute] =
70-
if (join) child.output ++ generatorOutput else generatorOutput
71+
// we don't want the gOutput to be taken as part of the expressions
72+
// as that will cause exceptions like unresolved attributes etc.
73+
override def expressions: Seq[Expression] = generator :: Nil
74+
75+
def output: Seq[Attribute] = {
76+
if (join) child.output ++ gOutput else gOutput
77+
}
7178
}
7279

7380
case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode {

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {
9292

9393
assert(!Project(Seq(UnresolvedAttribute("a")), testRelation).resolved)
9494

95-
val explode = Explode(Nil, AttributeReference("a", IntegerType, nullable = true)())
95+
val explode = Explode(AttributeReference("a", IntegerType, nullable = true)())
9696
assert(!Project(Seq(Alias(explode, "explode")()), testRelation).resolved)
9797

9898
assert(!Project(Seq(Alias(Count(Literal(1)), "count")()), testRelation).resolved)

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -454,21 +454,21 @@ class FilterPushdownSuite extends PlanTest {
454454
test("generate: predicate referenced no generated column") {
455455
val originalQuery = {
456456
testRelationWithArrayType
457-
.generate(Explode(Seq("c"), 'c_arr), true, false, Some("arr"))
457+
.generate(Explode('c_arr), true, false, Some("arr"))
458458
.where(('b >= 5) && ('a > 6))
459459
}
460460
val optimized = Optimize(originalQuery.analyze)
461461
val correctAnswer = {
462462
testRelationWithArrayType
463463
.where(('b >= 5) && ('a > 6))
464-
.generate(Explode(Seq("c"), 'c_arr), true, false, Some("arr")).analyze
464+
.generate(Explode('c_arr), true, false, Some("arr")).analyze
465465
}
466466

467467
comparePlans(optimized, correctAnswer)
468468
}
469469

470470
test("generate: part of conjuncts referenced generated column") {
471-
val generator = Explode(Seq("c"), 'c_arr)
471+
val generator = Explode('c_arr)
472472
val originalQuery = {
473473
testRelationWithArrayType
474474
.generate(generator, true, false, Some("arr"))
@@ -499,7 +499,7 @@ class FilterPushdownSuite extends PlanTest {
499499
test("generate: all conjuncts referenced generated column") {
500500
val originalQuery = {
501501
testRelationWithArrayType
502-
.generate(Explode(Seq("c"), 'c_arr), true, false, Some("arr"))
502+
.generate(Explode('c_arr), true, false, Some("arr"))
503503
.where(('c > 6) || ('b > 5)).analyze
504504
}
505505
val optimized = Optimize(originalQuery)

sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -711,12 +711,15 @@ class DataFrame private[sql](
711711
*/
712712
def explode[A <: Product : TypeTag](input: Column*)(f: Row => TraversableOnce[A]): DataFrame = {
713713
val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType]
714-
val attributes = schema.toAttributes
714+
715+
val elementTypes = schema.toAttributes.map { attr => (attr.dataType, attr.nullable) }
716+
val names = schema.toAttributes.map(_.name)
717+
715718
val rowFunction =
716719
f.andThen(_.map(CatalystTypeConverters.convertToCatalyst(_, schema).asInstanceOf[Row]))
717-
val generator = UserDefinedGenerator(attributes, rowFunction, input.map(_.expr))
720+
val generator = UserDefinedGenerator(elementTypes, rowFunction, input.map(_.expr))
718721

719-
Generate(generator, join = true, outer = false, None, logicalPlan)
722+
Generate(generator, join = true, outer = false, logicalPlan, qualifier = None, names, Nil)
720723
}
721724

722725
/**
@@ -733,12 +736,16 @@ class DataFrame private[sql](
733736
: DataFrame = {
734737
val dataType = ScalaReflection.schemaFor[B].dataType
735738
val attributes = AttributeReference(outputColumn, dataType)() :: Nil
739+
// TODO handle the metadata?
740+
val elementTypes = attributes.map { attr => (attr.dataType, attr.nullable) }
741+
val names = attributes.map(_.name)
742+
736743
def rowFunction(row: Row): TraversableOnce[Row] = {
737744
f(row(0).asInstanceOf[A]).map(o => Row(CatalystTypeConverters.convertToCatalyst(o, dataType)))
738745
}
739-
val generator = UserDefinedGenerator(attributes, rowFunction, apply(inputColumn).expr :: Nil)
746+
val generator = UserDefinedGenerator(elementTypes, rowFunction, apply(inputColumn).expr :: Nil)
740747

741-
Generate(generator, join = true, outer = false, None, logicalPlan)
748+
Generate(generator, join = true, outer = false, logicalPlan, qualifier = None, names, Nil)
742749
}
743750

744751
/////////////////////////////////////////////////////////////////////////////

0 commit comments

Comments
 (0)