-
Notifications
You must be signed in to change notification settings - Fork 28.9k
[SPARK-18874][SQL] First phase: Deferring the correlated predicate pull up to Optimizer phase #16954
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-18874][SQL] First phase: Deferring the correlated predicate pull up to Optimizer phase #16954
Changes from all commits
fdb8d2a
5f62a2c
8a8a7af
f0d2e7f
55842fa
c677ed8
00c890e
27cb36a
19cdbb0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -108,6 +108,28 @@ object TypeCoercion { | |
| case _ => None | ||
| } | ||
|
|
||
| /** | ||
| * Chooses the type that both operands of a binary comparison should be converted to | ||
| * when exactly one operand is a String, or when a Date is compared with a Timestamp | ||
| * (in which case the target type is String). | ||
| * | ||
| * Yields `None` for every pair of types that is not listed below. | ||
| */ | ||
| val findCommonTypeForBinaryComparison: (DataType, DataType) => Option[DataType] = { | ||
| // A String compared with a null-typed operand stays a String comparison. | ||
| case (StringType, NullType) => Some(StringType) | ||
| case (NullType, StringType) => Some(StringType) | ||
| // Every relative timestamp/date/string comparison is turned into a string comparison. | ||
| // This is what a user would expect, because timestamp strings sort lexicographically, | ||
| // e.g. TimeStamp(2013-01-01 00:00 ...) < "2014" = true | ||
| case (StringType, DateType) => Some(StringType) | ||
| case (DateType, StringType) => Some(StringType) | ||
| case (StringType, TimestampType) => Some(StringType) | ||
| case (TimestampType, StringType) => Some(StringType) | ||
| case (DateType, TimestampType) => Some(StringType) | ||
| case (TimestampType, DateType) => Some(StringType) | ||
|
||
| // A String compared with any other atomic type takes the type of the non-String side. | ||
| // These two cases must stay below the String-target cases above, which they overlap. | ||
| case (_: StringType, other: AtomicType) if other != StringType => Some(other) | ||
| case (other: AtomicType, _: StringType) if other != StringType => Some(other) | ||
| case _ => None | ||
| } | ||
|
|
||
| /** | ||
| * Case 2 type widening (see the classdoc comment above for TypeCoercion). | ||
| * | ||
|
|
@@ -305,6 +327,14 @@ object TypeCoercion { | |
| * Promotes strings that appear in arithmetic expressions. | ||
| */ | ||
| object PromoteStrings extends Rule[LogicalPlan] { | ||
| /** | ||
| * Returns `expr` converted to `targetType`: a null-typed expression becomes a null | ||
| * literal of the target type, an expression of a different type is wrapped in a Cast, | ||
| * and an expression that already has the target type is returned unchanged. | ||
| */ | ||
| private def castExpr(expr: Expression, targetType: DataType): Expression = { | ||
| expr.dataType match { | ||
| case NullType => Literal.create(null, targetType) | ||
| case sourceType if sourceType != targetType => Cast(expr, targetType) | ||
| case _ => expr | ||
| } | ||
| } | ||
|
|
||
| def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { | ||
| // Skip nodes whose children have not been resolved yet. | ||
| case e if !e.childrenResolved => e | ||
|
|
@@ -321,37 +351,10 @@ object TypeCoercion { | |
| case p @ Equality(left @ TimestampType(), right @ StringType()) => | ||
| p.makeCopy(Array(left, Cast(right, TimestampType))) | ||
|
|
||
| // We should cast all relative timestamp/date/string comparison into string comparisons | ||
| // This behaves as a user would expect because timestamp strings sort lexicographically. | ||
| // i.e. TimeStamp(2013-01-01 00:00 ...) < "2014" = true | ||
| case p @ BinaryComparison(left @ StringType(), right @ DateType()) => | ||
| p.makeCopy(Array(left, Cast(right, StringType))) | ||
| case p @ BinaryComparison(left @ DateType(), right @ StringType()) => | ||
| p.makeCopy(Array(Cast(left, StringType), right)) | ||
| case p @ BinaryComparison(left @ StringType(), right @ TimestampType()) => | ||
| p.makeCopy(Array(left, Cast(right, StringType))) | ||
| case p @ BinaryComparison(left @ TimestampType(), right @ StringType()) => | ||
| p.makeCopy(Array(Cast(left, StringType), right)) | ||
|
|
||
| // Comparisons between dates and timestamps. | ||
| case p @ BinaryComparison(left @ TimestampType(), right @ DateType()) => | ||
| p.makeCopy(Array(Cast(left, StringType), Cast(right, StringType))) | ||
| case p @ BinaryComparison(left @ DateType(), right @ TimestampType()) => | ||
| p.makeCopy(Array(Cast(left, StringType), Cast(right, StringType))) | ||
|
|
||
| // Checking NullType | ||
| case p @ BinaryComparison(left @ StringType(), right @ NullType()) => | ||
| p.makeCopy(Array(left, Literal.create(null, StringType))) | ||
| case p @ BinaryComparison(left @ NullType(), right @ StringType()) => | ||
| p.makeCopy(Array(Literal.create(null, StringType), right)) | ||
|
|
||
| // When comparing a string with an atomic type, cast the string to that type. | ||
| case p @ BinaryComparison(left @ StringType(), right @ AtomicType()) | ||
| if right.dataType != StringType => | ||
| p.makeCopy(Array(Cast(left, right.dataType), right)) | ||
| case p @ BinaryComparison(left @ AtomicType(), right @ StringType()) | ||
| if left.dataType != StringType => | ||
| p.makeCopy(Array(left, Cast(right, left.dataType))) | ||
| case p @ BinaryComparison(left, right) | ||
| if findCommonTypeForBinaryComparison(left.dataType, right.dataType).isDefined => | ||
| val commonType = findCommonTypeForBinaryComparison(left.dataType, right.dataType).get | ||
| p.makeCopy(Array(castExpr(left, commonType), castExpr(right, commonType))) | ||
|
|
||
| case Sum(e @ StringType()) => Sum(Cast(e, DoubleType)) | ||
| case Average(e @ StringType()) => Average(Cast(e, DoubleType)) | ||
|
|
@@ -365,17 +368,72 @@ object TypeCoercion { | |
| } | ||
|
|
||
| /** | ||
| * Convert the value and in list expressions to the common operator type | ||
| * by looking at all the argument types and finding the closest one that | ||
| * all the arguments can be cast to. When no common operator type is found | ||
| * the original expression will be returned and an Analysis Exception will | ||
| * be raised at type checking phase. | ||
| * Handles type coercion for both IN expression with subquery and IN | ||
| * expressions without subquery. | ||
| * 1. In the first case, find the common type by comparing the left hand side (LHS) | ||
| * expression types against corresponding right hand side (RHS) expression derived | ||
| * from the subquery expression's plan output. Inject appropriate casts in the | ||
| * LHS and RHS side of IN expression. | ||
| * | ||
| * 2. In the second case, convert the value and in list expressions to the | ||
| * common operator type by looking at all the argument types and finding | ||
| * the closest one that all the arguments can be cast to. When no common | ||
| * operator type is found the original expression will be returned and an | ||
| * Analysis Exception will be raised at the type checking phase. | ||
| */ | ||
| object InConversion extends Rule[LogicalPlan] { | ||
| /** | ||
| * Returns the value expressions of `expr` when it is a CreateNamedStruct (the | ||
| * representation of multiple columns in an IN clause), otherwise a one-element | ||
| * sequence holding `expr` itself. | ||
| */ | ||
| private def flattenExpr(expr: Expression): Seq[Expression] = expr match { | ||
| // Multiple columns: unpack the named struct into its list of value expressions. | ||
| case struct: CreateNamedStruct => struct.valExprs | ||
| // Single column: nothing to unpack. | ||
| case single => Seq(single) | ||
| } | ||
|
|
||
| def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { | ||
| // Skip nodes whose children have not been resolved yet. | ||
| case e if !e.childrenResolved => e | ||
|
|
||
| // Handle type casting required between value expression and subquery output | ||
| // in IN subquery. | ||
| case i @ In(a, Seq(ListQuery(sub, children, exprId))) | ||
| if !i.resolved && flattenExpr(a).length == sub.output.length => | ||
| // LHS is the value expression of IN subquery. | ||
| val lhs = flattenExpr(a) | ||
|
|
||
| // RHS is the subquery output. | ||
| val rhs = sub.output | ||
|
|
||
| val commonTypes = lhs.zip(rhs).flatMap { case (l, r) => | ||
| findCommonTypeForBinaryComparison(l.dataType, r.dataType) | ||
| .orElse(findTightestCommonType(l.dataType, r.dataType)) | ||
| } | ||
|
|
||
| // The number of columns/expressions must match between LHS and RHS of an | ||
| // IN subquery expression. | ||
| if (commonTypes.length == lhs.length) { | ||
| val castedRhs = rhs.zip(commonTypes).map { | ||
| case (e, dt) if e.dataType != dt => Alias(Cast(e, dt), e.name)() | ||
| case (e, _) => e | ||
| } | ||
| val castedLhs = lhs.zip(commonTypes).map { | ||
| case (e, dt) if e.dataType != dt => Cast(e, dt) | ||
| case (e, _) => e | ||
| } | ||
|
|
||
| // Before constructing the In expression, wrap the multiple LHS values | ||
| // in a CreateNamedStruct. | ||
| val newLhs = castedLhs match { | ||
| case Seq(lhs) => lhs | ||
| case _ => CreateStruct(castedLhs) | ||
| } | ||
|
|
||
| In(newLhs, Seq(ListQuery(Project(castedRhs, sub), children, exprId))) | ||
| } else { | ||
| i | ||
| } | ||
|
|
||
| case i @ In(a, b) if b.exists(_.dataType != a.dataType) => | ||
| findWiderCommonType(i.children.map(_.dataType)) match { | ||
| case Some(finalDataType) => i.withNewChildren(i.children.map(Cast(_, finalDataType))) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The de-duplication of the sub-plan now happens in the optimizer, when we actually pull up the correlated predicates, so the project case here is simplified.