From 654b36d9616da25268ae8f93bb422f94ac5172c3 Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Fri, 22 Jul 2022 11:33:56 +0800 Subject: [PATCH 1/2] [SPARK-39836][SQL] Simplify V2ExpressionBuilder by extract common method. --- .../catalyst/util/V2ExpressionBuilder.scala | 160 +++--------------- 1 file changed, 28 insertions(+), 132 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala index 59cbcf4833482..0a78dfe9958b1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala @@ -92,27 +92,9 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) { generateExpression(child).map(v => new V2Cast(v, dataType)) case Abs(child, true) => generateExpression(child) .map(v => new GeneralScalarExpression("ABS", Array[V2Expression](v))) - case Coalesce(children) => - val childrenExpressions = children.flatMap(generateExpression(_)) - if (children.length == childrenExpressions.length) { - Some(new GeneralScalarExpression("COALESCE", childrenExpressions.toArray[V2Expression])) - } else { - None - } - case Greatest(children) => - val childrenExpressions = children.flatMap(generateExpression(_)) - if (children.length == childrenExpressions.length) { - Some(new GeneralScalarExpression("GREATEST", childrenExpressions.toArray[V2Expression])) - } else { - None - } - case Least(children) => - val childrenExpressions = children.flatMap(generateExpression(_)) - if (children.length == childrenExpressions.length) { - Some(new GeneralScalarExpression("LEAST", childrenExpressions.toArray[V2Expression])) - } else { - None - } + case Coalesce(children) => generateExpressionWithName("COALESCE", children) + case Greatest(children) => generateExpressionWithName("GREATEST", children) + case Least(children) => generateExpressionWithName("LEAST", children) case Rand(child, hideSeed) => if (hideSeed) { Some(new GeneralScalarExpression("RAND", Array.empty[V2Expression])) @@ -120,14 +102,7 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) { generateExpression(child) .map(v => new GeneralScalarExpression("RAND", Array[V2Expression](v))) } - case log: Logarithm => - val l = generateExpression(log.left) - val r = generateExpression(log.right) - if (l.isDefined && r.isDefined) { - Some(new GeneralScalarExpression("LOG", Array[V2Expression](l.get, r.get))) - } else { - None - } + case log: Logarithm => generateExpressionWithName("LOG", log.children) case Log10(child) => generateExpression(child) .map(v => new GeneralScalarExpression("LOG10", Array[V2Expression](v))) case Log2(child) => generateExpression(child) @@ -136,28 +111,14 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) { .map(v => new GeneralScalarExpression("LN", Array[V2Expression](v))) case Exp(child) => generateExpression(child) .map(v => new GeneralScalarExpression("EXP", Array[V2Expression](v))) - case Pow(left, right) => - val l = generateExpression(left) - val r = generateExpression(right) - if (l.isDefined && r.isDefined) { - Some(new GeneralScalarExpression("POWER", Array[V2Expression](l.get, r.get))) - } else { - None - } + case pow: Pow => generateExpressionWithName("POWER", pow.children) case Sqrt(child) => generateExpression(child) .map(v => new GeneralScalarExpression("SQRT", Array[V2Expression](v))) case Floor(child) => generateExpression(child) .map(v => new GeneralScalarExpression("FLOOR", Array[V2Expression](v))) case Ceil(child) => generateExpression(child) .map(v => new GeneralScalarExpression("CEIL", Array[V2Expression](v))) - case round: Round => - val l = generateExpression(round.left) - val r = generateExpression(round.right) - if (l.isDefined && r.isDefined) { - Some(new GeneralScalarExpression("ROUND", Array[V2Expression](l.get, r.get))) - } else { - None - } + case round: Round => generateExpressionWithName("ROUND", round.children) case Sin(child) => generateExpression(child) .map(v => new GeneralScalarExpression("SIN", Array[V2Expression](v))) case Sinh(child) => generateExpression(child) @@ -184,14 +145,7 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) { .map(v => new GeneralScalarExpression("ATAN", Array[V2Expression](v))) case Atanh(child) => generateExpression(child) .map(v => new GeneralScalarExpression("ATANH", Array[V2Expression](v))) - case atan2: Atan2 => - val l = generateExpression(atan2.left) - val r = generateExpression(atan2.right) - if (l.isDefined && r.isDefined) { - Some(new GeneralScalarExpression("ATAN2", Array[V2Expression](l.get, r.get))) - } else { - None - } + case atan2: Atan2 => generateExpressionWithName("ATAN2", atan2.children) case Cbrt(child) => generateExpression(child) .map(v => new GeneralScalarExpression("CBRT", Array[V2Expression](v))) case ToDegrees(child) => generateExpression(child) @@ -200,14 +154,7 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) { .map(v => new GeneralScalarExpression("RADIANS", Array[V2Expression](v))) case Signum(child) => generateExpression(child) .map(v => new GeneralScalarExpression("SIGN", Array[V2Expression](v))) - case wb: WidthBucket => - val childrenExpressions = wb.children.flatMap(generateExpression(_)) - if (childrenExpressions.length == wb.children.length) { - Some(new GeneralScalarExpression("WIDTH_BUCKET", - childrenExpressions.toArray[V2Expression])) - } else { - None - } + case wb: WidthBucket => generateExpressionWithName("WIDTH_BUCKET", wb.children) case and: And => // AND expects predicate val l = generateExpression(and.left, true) @@ -282,93 +229,32 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) { } else { None } - case iff: If => - val childrenExpressions = iff.children.flatMap(generateExpression(_)) - if (iff.children.length == childrenExpressions.length) { - Some(new GeneralScalarExpression("CASE_WHEN", childrenExpressions.toArray[V2Expression])) - } else { - None - } + case iff: If => generateExpressionWithName("CASE_WHEN", iff.children) case substring: Substring => val children = if (substring.len == Literal(Integer.MAX_VALUE)) { Seq(substring.str, substring.pos) } else { substring.children } - val childrenExpressions = children.flatMap(generateExpression(_)) - if (childrenExpressions.length == children.length) { - Some(new GeneralScalarExpression("SUBSTRING", - childrenExpressions.toArray[V2Expression])) - } else { - None - } + generateExpressionWithName("SUBSTRING", children) case Upper(child) => generateExpression(child) .map(v => new GeneralScalarExpression("UPPER", Array[V2Expression](v))) case Lower(child) => generateExpression(child) .map(v => new GeneralScalarExpression("LOWER", Array[V2Expression](v))) - case translate: StringTranslate => - val childrenExpressions = translate.children.flatMap(generateExpression(_)) - if (childrenExpressions.length == translate.children.length) { - Some(new GeneralScalarExpression("TRANSLATE", - childrenExpressions.toArray[V2Expression])) - } else { - None - } - case trim: StringTrim => - val childrenExpressions = trim.children.flatMap(generateExpression(_)) - if (childrenExpressions.length == trim.children.length) { - Some(new GeneralScalarExpression("TRIM", childrenExpressions.toArray[V2Expression])) - } else { - None - } - case trim: StringTrimLeft => - val childrenExpressions = trim.children.flatMap(generateExpression(_)) - if (childrenExpressions.length == trim.children.length) { - Some(new GeneralScalarExpression("LTRIM", childrenExpressions.toArray[V2Expression])) - } else { - None - } - case trim: StringTrimRight => - val childrenExpressions = trim.children.flatMap(generateExpression(_)) - if (childrenExpressions.length == trim.children.length) { - Some(new GeneralScalarExpression("RTRIM", childrenExpressions.toArray[V2Expression])) - } else { - None - } + case translate: StringTranslate => generateExpressionWithName("TRANSLATE", translate.children) + case trim: StringTrim => generateExpressionWithName("TRIM", trim.children) + case trim: StringTrimLeft => generateExpressionWithName("LTRIM", trim.children) + case trim: StringTrimRight => generateExpressionWithName("RTRIM", trim.children) case overlay: Overlay => val children = if (overlay.len == Literal(-1)) { Seq(overlay.input, overlay.replace, overlay.pos) } else { overlay.children } - val childrenExpressions = children.flatMap(generateExpression(_)) - if (childrenExpressions.length == children.length) { - Some(new GeneralScalarExpression("OVERLAY", - childrenExpressions.toArray[V2Expression])) - } else { - None - } - case date: DateAdd => - val childrenExpressions = date.children.flatMap(generateExpression(_)) - if (childrenExpressions.length == date.children.length) { - Some(new GeneralScalarExpression("DATE_ADD", childrenExpressions.toArray[V2Expression])) - } else { - None - } - case date: DateDiff => - val childrenExpressions = date.children.flatMap(generateExpression(_)) - if (childrenExpressions.length == date.children.length) { - Some(new GeneralScalarExpression("DATE_DIFF", childrenExpressions.toArray[V2Expression])) - } else { - None - } - case date: TruncDate => - val childrenExpressions = date.children.flatMap(generateExpression(_)) - if (childrenExpressions.length == date.children.length) { - Some(new GeneralScalarExpression("TRUNC", childrenExpressions.toArray[V2Expression])) - } else { - None - } + generateExpressionWithName("OVERLAY", children) + case date: DateAdd => generateExpressionWithName("DATE_ADD", date.children) + case date: DateDiff => generateExpressionWithName("DATE_DIFF", date.children) + case date: TruncDate => generateExpressionWithName("TRUNC", date.children) case Second(child, _) => generateExpression(child).map(v => new V2Extract("SECOND", v)) case Minute(child, _) => @@ -429,6 +315,16 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) { case _ => operatorName } } + + private def generateExpressionWithName( + v2ExpressionName: String, children: Seq[Expression]): Option[V2Expression] = { + val childrenExpressions = children.flatMap(generateExpression(_)) + if (childrenExpressions.length == children.length) { + Some(new GeneralScalarExpression(v2ExpressionName, childrenExpressions.toArray[V2Expression])) + } else { + None + } + } } object ColumnOrField { From a8899b93b07c03b5d0c24ced54623d3ee471593f Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Sat, 23 Jul 2022 19:47:25 +0800 Subject: [PATCH 2/2] Update code --- .../catalyst/util/V2ExpressionBuilder.scala | 90 +++++++------------ 1 file changed, 30 insertions(+), 60 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala index 0a78dfe9958b1..70c85def45d99 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala @@ -90,8 +90,7 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) { } case Cast(child, dataType, _, true) => generateExpression(child).map(v => new V2Cast(v, dataType)) - case Abs(child, true) => generateExpression(child) - .map(v => new GeneralScalarExpression("ABS", Array[V2Expression](v))) + case Abs(child, true) => generateExpressionWithName("ABS", Seq(child)) case Coalesce(children) => generateExpressionWithName("COALESCE", children) case Greatest(children) => generateExpressionWithName("GREATEST", children) case Least(children) => generateExpressionWithName("LEAST", children) @@ -99,61 +98,36 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) { if (hideSeed) { Some(new GeneralScalarExpression("RAND", Array.empty[V2Expression])) } else { - generateExpression(child) - .map(v => new GeneralScalarExpression("RAND", Array[V2Expression](v))) + generateExpressionWithName("RAND", Seq(child)) } case log: Logarithm => generateExpressionWithName("LOG", log.children) - case Log10(child) => generateExpression(child) - .map(v => new GeneralScalarExpression("LOG10", Array[V2Expression](v))) - case Log2(child) => generateExpression(child) - .map(v => new GeneralScalarExpression("LOG2", Array[V2Expression](v))) - case Log(child) => generateExpression(child) - .map(v => new GeneralScalarExpression("LN", Array[V2Expression](v))) - case Exp(child) => generateExpression(child) - .map(v => new GeneralScalarExpression("EXP", Array[V2Expression](v))) + case Log10(child) => generateExpressionWithName("LOG10", Seq(child)) + case Log2(child) => generateExpressionWithName("LOG2", Seq(child)) + case Log(child) => generateExpressionWithName("LN", Seq(child)) + case Exp(child) => generateExpressionWithName("EXP", Seq(child)) case pow: Pow => generateExpressionWithName("POWER", pow.children) - case Sqrt(child) => generateExpression(child) - .map(v => new GeneralScalarExpression("SQRT", Array[V2Expression](v))) - case Floor(child) => generateExpression(child) - .map(v => new GeneralScalarExpression("FLOOR", Array[V2Expression](v))) - case Ceil(child) => generateExpression(child) - .map(v => new GeneralScalarExpression("CEIL", Array[V2Expression](v))) + case Sqrt(child) => generateExpressionWithName("SQRT", Seq(child)) + case Floor(child) => generateExpressionWithName("FLOOR", Seq(child)) + case Ceil(child) => generateExpressionWithName("CEIL", Seq(child)) case round: Round => generateExpressionWithName("ROUND", round.children) - case Sin(child) => generateExpression(child) - .map(v => new GeneralScalarExpression("SIN", Array[V2Expression](v))) - case Sinh(child) => generateExpression(child) - .map(v => new GeneralScalarExpression("SINH", Array[V2Expression](v))) - case Cos(child) => generateExpression(child) - .map(v => new GeneralScalarExpression("COS", Array[V2Expression](v))) - case Cosh(child) => generateExpression(child) - .map(v => new GeneralScalarExpression("COSH", Array[V2Expression](v))) - case Tan(child) => generateExpression(child) - .map(v => new GeneralScalarExpression("TAN", Array[V2Expression](v))) - case Tanh(child) => generateExpression(child) - .map(v => new GeneralScalarExpression("TANH", Array[V2Expression](v))) - case Cot(child) => generateExpression(child) - .map(v => new GeneralScalarExpression("COT", Array[V2Expression](v))) - case Asin(child) => generateExpression(child) - .map(v => new GeneralScalarExpression("ASIN", Array[V2Expression](v))) - case Asinh(child) => generateExpression(child) - .map(v => new GeneralScalarExpression("ASINH", Array[V2Expression](v))) - case Acos(child) => generateExpression(child) - .map(v => new GeneralScalarExpression("ACOS", Array[V2Expression](v))) - case Acosh(child) => generateExpression(child) - .map(v => new GeneralScalarExpression("ACOSH", Array[V2Expression](v))) - case Atan(child) => generateExpression(child) - .map(v => new GeneralScalarExpression("ATAN", Array[V2Expression](v))) - case Atanh(child) => generateExpression(child) - .map(v => new GeneralScalarExpression("ATANH", Array[V2Expression](v))) + case Sin(child) => generateExpressionWithName("SIN", Seq(child)) + case Sinh(child) => generateExpressionWithName("SINH", Seq(child)) + case Cos(child) => generateExpressionWithName("COS", Seq(child)) + case Cosh(child) => generateExpressionWithName("COSH", Seq(child)) + case Tan(child) => generateExpressionWithName("TAN", Seq(child)) + case Tanh(child) => generateExpressionWithName("TANH", Seq(child)) + case Cot(child) => generateExpressionWithName("COT", Seq(child)) + case Asin(child) => generateExpressionWithName("ASIN", Seq(child)) + case Asinh(child) => generateExpressionWithName("ASINH", Seq(child)) + case Acos(child) => generateExpressionWithName("ACOS", Seq(child)) + case Acosh(child) => generateExpressionWithName("ACOSH", Seq(child)) + case Atan(child) => generateExpressionWithName("ATAN", Seq(child)) + case Atanh(child) => generateExpressionWithName("ATANH", Seq(child)) case atan2: Atan2 => generateExpressionWithName("ATAN2", atan2.children) - case Cbrt(child) => generateExpression(child) - .map(v => new GeneralScalarExpression("CBRT", Array[V2Expression](v))) - case ToDegrees(child) => generateExpression(child) - .map(v => new GeneralScalarExpression("DEGREES", Array[V2Expression](v))) - case ToRadians(child) => generateExpression(child) - .map(v => new GeneralScalarExpression("RADIANS", Array[V2Expression](v))) - case Signum(child) => generateExpression(child) - .map(v => new GeneralScalarExpression("SIGN", Array[V2Expression](v))) + case Cbrt(child) => generateExpressionWithName("CBRT", Seq(child)) + case ToDegrees(child) => generateExpressionWithName("DEGREES", Seq(child)) + case ToRadians(child) => generateExpressionWithName("RADIANS", Seq(child)) + case Signum(child) => generateExpressionWithName("SIGN", Seq(child)) case wb: WidthBucket => generateExpressionWithName("WIDTH_BUCKET", wb.children) case and: And => // AND expects predicate @@ -205,10 +179,8 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) { assert(v.isInstanceOf[V2Predicate]) new V2Not(v.asInstanceOf[V2Predicate]) } - case UnaryMinus(child, true) => generateExpression(child) - .map(v => new GeneralScalarExpression("-", Array[V2Expression](v))) - case BitwiseNot(child) => generateExpression(child) - .map(v => new GeneralScalarExpression("~", Array[V2Expression](v))) + case UnaryMinus(child, true) => generateExpressionWithName("-", Seq(child)) + case BitwiseNot(child) => generateExpressionWithName("~", Seq(child)) case CaseWhen(branches, elseValue) => val conditions = branches.map(_._1).flatMap(generateExpression(_, true)) val values = branches.map(_._2).flatMap(generateExpression(_, true)) @@ -237,10 +209,8 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) { substring.children } generateExpressionWithName("SUBSTRING", children) - case Upper(child) => generateExpression(child) - .map(v => new GeneralScalarExpression("UPPER", Array[V2Expression](v))) - case Lower(child) => generateExpression(child) - .map(v => new GeneralScalarExpression("LOWER", Array[V2Expression](v))) + case Upper(child) => generateExpressionWithName("UPPER", Seq(child)) + case Lower(child) => generateExpressionWithName("LOWER", Seq(child)) case translate: StringTranslate => generateExpressionWithName("TRANSLATE", translate.children) case trim: StringTrim => generateExpressionWithName("TRIM", trim.children) case trim: StringTrimLeft => generateExpressionWithName("LTRIM", trim.children)