Skip to content

Commit b16ddb8

Browse files
committed
Add distinct & sorting. Improve aggregation.
1 parent edec2d8 commit b16ddb8

File tree

2 files changed

+27
-22
lines changed

2 files changed

+27
-22
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2348,12 +2348,36 @@ object SortMaps extends Rule[LogicalPlan] {
23482348
cmp.withNewChildren(SortMap(left) :: right :: Nil)
23492349
case cmp @ BinaryComparison(left, right) if cmp.resolved && hasUnorderedMap(right) =>
23502350
cmp.withNewChildren(left :: SortMap(right) :: Nil)
2351+
case sort: SortOrder if sort.resolved && hasUnorderedMap(sort.child) =>
2352+
sort.copy(child = SortMap(sort.child))
23512353
} transform {
23522354
case a: Aggregate if a.resolved && a.groupingExpressions.exists(hasUnorderedMap) =>
2353-
a.transformExpressionsUp {
2355+
// Modify the top level grouping expressions
2356+
val replacements = a.groupingExpressions.collect {
23542357
case a: Attribute if hasUnorderedMap(a) =>
2358+
a -> Alias(SortMap(a), a.name)(exprId = a.exprId, qualifier = a.qualifier)
2359+
case e if hasUnorderedMap(e) =>
2360+
e -> SortMap(e)
2361+
}
2362+
2363+
// Tranform the expression tree.
2364+
a.transformExpressionsUp {
2365+
case e =>
2366+
// TODO create an expression map!
2367+
replacements
2368+
.find(_._1.semanticEquals(e))
2369+
.map(_._2)
2370+
.getOrElse(e)
2371+
}
2372+
2373+
case Distinct(child) if child.resolved && child.output.exists(hasUnorderedMap) =>
2374+
val projectList = child.output.map { a =>
2375+
if (hasUnorderedMap(a)) {
23552376
Alias(SortMap(a), a.name)(exprId = a.exprId, qualifier = a.qualifier)
2356-
case e if hasUnorderedMap(e) => SortMap(e)
2377+
} else {
2378+
a
2379+
}
23572380
}
2381+
Distinct(Project(projectList, child))
23582382
}
23592383
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala

Lines changed: 1 addition & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -439,12 +439,7 @@ class AnalysisErrorSuite extends AnalysisTest {
439439
checkDataType(dataType, shouldSuccess = true)
440440
}
441441

442-
val unsupportedDataTypes = Seq(
443-
MapType(StringType, LongType),
444-
new StructType()
445-
.add("f1", FloatType, nullable = true)
446-
.add("f2", MapType(StringType, LongType), nullable = true),
447-
new UngroupableUDT())
442+
val unsupportedDataTypes = Seq(new UngroupableUDT())
448443
unsupportedDataTypes.foreach { dataType =>
449444
checkDataType(dataType, shouldSuccess = false)
450445
}
@@ -479,20 +474,6 @@ class AnalysisErrorSuite extends AnalysisTest {
479474
AttributeReference("c", BinaryType)(exprId = ExprId(4)))))
480475

481476
assertAnalysisError(plan, "binary type expression `a` cannot be used in join conditions" :: Nil)
482-
483-
val plan2 =
484-
Join(
485-
LocalRelation(
486-
AttributeReference("a", MapType(IntegerType, StringType))(exprId = ExprId(2)),
487-
AttributeReference("b", IntegerType)(exprId = ExprId(1))),
488-
LocalRelation(
489-
AttributeReference("c", MapType(IntegerType, StringType))(exprId = ExprId(4)),
490-
AttributeReference("d", IntegerType)(exprId = ExprId(3))),
491-
Cross,
492-
Some(EqualTo(AttributeReference("a", MapType(IntegerType, StringType))(exprId = ExprId(2)),
493-
AttributeReference("c", MapType(IntegerType, StringType))(exprId = ExprId(4)))))
494-
495-
assertAnalysisError(plan2, "map type expression `a` cannot be used in join conditions" :: Nil)
496477
}
497478

498479
test("PredicateSubQuery is used outside of a filter") {

0 commit comments

Comments
 (0)