Skip to content

Commit 7d847b8

Browse files
committed
Fix code generation. Fix joins.
1 parent e4cc4b0 commit 7d847b8

File tree

6 files changed

+24
-27
lines changed

6 files changed

+24
-27
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2336,14 +2336,14 @@ object ResolveCreateNamedStruct extends Rule[LogicalPlan] {
23362336
*/
23372337
object SortMaps extends Rule[LogicalPlan] {
23382338
private def containsUnorderedMap(e: Expression): Boolean =
2339-
MapType.containsUnorderedMap(e.dataType)
2339+
e.resolved && MapType.containsUnorderedMap(e.dataType)
23402340

23412341
override def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressions {
2342-
case cmp @ BinaryComparison(left, right) if cmp.resolved && containsUnorderedMap(left) =>
2342+
case cmp @ BinaryComparison(left, right) if containsUnorderedMap(left) =>
23432343
cmp.withNewChildren(OrderMaps(left) :: right :: Nil)
2344-
case cmp @ BinaryComparison(left, right) if cmp.resolved && containsUnorderedMap(right) =>
2344+
case cmp @ BinaryComparison(left, right) if containsUnorderedMap(right) =>
23452345
cmp.withNewChildren(left :: OrderMaps(right) :: Nil)
2346-
case sort: SortOrder if sort.resolved && containsUnorderedMap(sort.child) =>
2346+
case sort: SortOrder if containsUnorderedMap(sort.child) =>
23472347
sort.copy(child = OrderMaps(sort.child))
23482348
} transform {
23492349
case a: Aggregate if a.resolved && a.groupingExpressions.exists(containsUnorderedMap) =>

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -484,6 +484,7 @@ class CodegenContext {
484484
case dt: DataType if dt.isInstanceOf[AtomicType] => s"$c1.equals($c2)"
485485
case array: ArrayType => genComp(array, c1, c2) + " == 0"
486486
case struct: StructType => genComp(struct, c1, c2) + " == 0"
487+
case map: MapType if map.ordered => genComp(map, c1, c2) + " == 0"
487488
case udt: UserDefinedType[_] => genEqual(udt.sqlType, c1, c2)
488489
case _ =>
489490
throw new IllegalArgumentException(
@@ -567,9 +568,9 @@ class CodegenContext {
567568
ArrayData bValues = b.valueArray();
568569
int minLength = (lengthA > lengthB) ? lengthB : lengthA;
569570
for (int i = 0; i < minLength; i++) {
570-
${javaType(keyType)} keyA = ${getValue("aKeys", valueType, "i")};
571-
${javaType(keyType)} keyB = ${getValue("bKeys", valueType, "i")};
572-
int comp = ${genComp(valueType, "keyA", "keyB")};
571+
${javaType(keyType)} keyA = ${getValue("aKeys", keyType, "i")};
572+
${javaType(keyType)} keyB = ${getValue("bKeys", keyType, "i")};
573+
int comp = ${genComp(keyType, "keyA", "keyB")};
573574
if (comp != 0) {
574575
return comp;
575576
}
@@ -584,19 +585,13 @@ class CodegenContext {
584585
} else {
585586
${javaType(valueType)} valueA = ${getValue("aValues", valueType, "i")};
586587
${javaType(valueType)} valueB = ${getValue("bValues", valueType, "i")};
587-
int comp = ${genComp(valueType, "valueA", "valueB")};
588+
comp = ${genComp(valueType, "valueA", "valueB")};
588589
if (comp != 0) {
589590
return comp;
590591
}
591592
}
592593
}
593-
594-
if (lengthA < lengthB) {
595-
return -1;
596-
} else if (lengthA > lengthB) {
597-
return 1;
598-
}
599-
return 0;
594+
return lengthA - lengthB;
600595
}
601596
"""
602597
addNewFunction(compareFunc, funcCode)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -419,8 +419,8 @@ case class EqualTo(left: Expression, right: Expression)
419419
case TypeCheckResult.TypeCheckSuccess =>
420420
// Maps are only allowed when they are ordered.
421421
if (MapType.containsUnorderedMap(left.dataType)) {
422-
TypeCheckResult.TypeCheckFailure("Cannot use unordered map type in EqualTo, but " +
423-
s"the actual input type is ${left.dataType.catalogString}.")
422+
TypeCheckResult.TypeCheckFailure(
423+
s"Cannot use unordered map type in EqualTo: ${left.dataType.catalogString}.")
424424
} else {
425425
TypeCheckResult.TypeCheckSuccess
426426
}
@@ -452,8 +452,8 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp
452452
EqualNullSafe
453453
// Maps are only allowed when they are ordered.
454454
if (MapType.containsUnorderedMap(left.dataType)) {
455-
TypeCheckResult.TypeCheckFailure("Cannot use unordered map type in EqualNullSafe, but " +
456-
s"the actual input type is ${left.dataType.catalogString}.")
455+
TypeCheckResult.TypeCheckFailure(
456+
s"Cannot use unordered map type in EqualNullSafe: ${left.dataType.catalogString}.")
457457
} else {
458458
TypeCheckResult.TypeCheckSuccess
459459
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -212,8 +212,8 @@ class AnalysisErrorSuite extends AnalysisTest {
212212

213213
errorTest(
214214
"sorting by unsupported column types",
215-
mapRelation.orderBy('map.asc),
216-
"sort" :: "type" :: "map<int,int>" :: Nil)
215+
intervalRelation.orderBy('interval.asc),
216+
"sort" :: "type" :: "calendarinterval" :: Nil)
217217

218218
errorTest(
219219
"sorting by attributes are not from grouping expressions",

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,8 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
4040
val e = intercept[AnalysisException] {
4141
assertSuccess(expr)
4242
}
43-
assert(e.getMessage.contains(
44-
s"cannot resolve '${expr.sql}' due to data type mismatch:"))
43+
assert(e.getMessage.contains("cannot resolve "))
44+
assert(e.getMessage.contains("due to data type mismatch:"))
4545
assert(e.getMessage.contains(errorMessage))
4646
}
4747

@@ -51,8 +51,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
5151
}
5252

5353
def assertErrorForDifferingTypes(expr: Expression): Unit = {
54-
assertError(expr,
55-
s"differing types in '${expr.sql}'")
54+
assertError(expr, "differing types in")
5655
}
5756

5857
test("check types for unary arithmetic") {
@@ -99,6 +98,8 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
9998
assertSuccess(LessThanOrEqual('intField, 'stringField))
10099
assertSuccess(GreaterThan('intField, 'stringField))
101100
assertSuccess(GreaterThanOrEqual('intField, 'stringField))
101+
assertSuccess(EqualTo('mapField, 'mapField))
102+
assertSuccess(EqualNullSafe('mapField, 'mapField))
102103

103104
// We will transform EqualTo with numeric and boolean types to CaseKeyWhen
104105
assertSuccess(EqualTo('intField, 'booleanField))
@@ -111,8 +112,6 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
111112
assertErrorForDifferingTypes(GreaterThan('intField, 'booleanField))
112113
assertErrorForDifferingTypes(GreaterThanOrEqual('intField, 'booleanField))
113114

114-
assertError(EqualTo('mapField, 'mapField), "Cannot use map type in EqualTo")
115-
assertError(EqualNullSafe('mapField, 'mapField), "Cannot use map type in EqualNullSafe")
116115
assertError(LessThan('mapField, 'mapField),
117116
s"requires ${TypeCollection.Ordered.simpleString} type")
118117
assertError(LessThanOrEqual('mapField, 'mapField),

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,4 +57,7 @@ object TestRelations {
5757

5858
val mapRelation = LocalRelation(
5959
AttributeReference("map", MapType(IntegerType, IntegerType))())
60+
61+
val intervalRelation = LocalRelation(
62+
AttributeReference("interval", CalendarIntervalType)())
6063
}

0 commit comments

Comments
 (0)