diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala index 59cbcf4833482..70c85def45d99 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala @@ -90,124 +90,45 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) { } case Cast(child, dataType, _, true) => generateExpression(child).map(v => new V2Cast(v, dataType)) - case Abs(child, true) => generateExpression(child) - .map(v => new GeneralScalarExpression("ABS", Array[V2Expression](v))) - case Coalesce(children) => - val childrenExpressions = children.flatMap(generateExpression(_)) - if (children.length == childrenExpressions.length) { - Some(new GeneralScalarExpression("COALESCE", childrenExpressions.toArray[V2Expression])) - } else { - None - } - case Greatest(children) => - val childrenExpressions = children.flatMap(generateExpression(_)) - if (children.length == childrenExpressions.length) { - Some(new GeneralScalarExpression("GREATEST", childrenExpressions.toArray[V2Expression])) - } else { - None - } - case Least(children) => - val childrenExpressions = children.flatMap(generateExpression(_)) - if (children.length == childrenExpressions.length) { - Some(new GeneralScalarExpression("LEAST", childrenExpressions.toArray[V2Expression])) - } else { - None - } + case Abs(child, true) => generateExpressionWithName("ABS", Seq(child)) + case Coalesce(children) => generateExpressionWithName("COALESCE", children) + case Greatest(children) => generateExpressionWithName("GREATEST", children) + case Least(children) => generateExpressionWithName("LEAST", children) case Rand(child, hideSeed) => if (hideSeed) { Some(new GeneralScalarExpression("RAND", Array.empty[V2Expression])) } else { - generateExpression(child) - .map(v => new GeneralScalarExpression("RAND", Array[V2Expression](v))) - } - case log: Logarithm => - val l = generateExpression(log.left) - val r = generateExpression(log.right) - if (l.isDefined && r.isDefined) { - Some(new GeneralScalarExpression("LOG", Array[V2Expression](l.get, r.get))) - } else { - None - } - case Log10(child) => generateExpression(child) - .map(v => new GeneralScalarExpression("LOG10", Array[V2Expression](v))) - case Log2(child) => generateExpression(child) - .map(v => new GeneralScalarExpression("LOG2", Array[V2Expression](v))) - case Log(child) => generateExpression(child) - .map(v => new GeneralScalarExpression("LN", Array[V2Expression](v))) - case Exp(child) => generateExpression(child) - .map(v => new GeneralScalarExpression("EXP", Array[V2Expression](v))) - case Pow(left, right) => - val l = generateExpression(left) - val r = generateExpression(right) - if (l.isDefined && r.isDefined) { - Some(new GeneralScalarExpression("POWER", Array[V2Expression](l.get, r.get))) - } else { - None - } - case Sqrt(child) => generateExpression(child) - .map(v => new GeneralScalarExpression("SQRT", Array[V2Expression](v))) - case Floor(child) => generateExpression(child) - .map(v => new GeneralScalarExpression("FLOOR", Array[V2Expression](v))) - case Ceil(child) => generateExpression(child) - .map(v => new GeneralScalarExpression("CEIL", Array[V2Expression](v))) - case round: Round => - val l = generateExpression(round.left) - val r = generateExpression(round.right) - if (l.isDefined && r.isDefined) { - Some(new GeneralScalarExpression("ROUND", Array[V2Expression](l.get, r.get))) - } else { - None - } - case Sin(child) => generateExpression(child) - .map(v => new GeneralScalarExpression("SIN", Array[V2Expression](v))) - case Sinh(child) => generateExpression(child) - .map(v => new GeneralScalarExpression("SINH", Array[V2Expression](v))) - case Cos(child) => generateExpression(child) - .map(v => new GeneralScalarExpression("COS", Array[V2Expression](v))) - case Cosh(child) => generateExpression(child) - .map(v => new GeneralScalarExpression("COSH", Array[V2Expression](v))) - case Tan(child) => generateExpression(child) - .map(v => new GeneralScalarExpression("TAN", Array[V2Expression](v))) - case Tanh(child) => generateExpression(child) - .map(v => new GeneralScalarExpression("TANH", Array[V2Expression](v))) - case Cot(child) => generateExpression(child) - .map(v => new GeneralScalarExpression("COT", Array[V2Expression](v))) - case Asin(child) => generateExpression(child) - .map(v => new GeneralScalarExpression("ASIN", Array[V2Expression](v))) - case Asinh(child) => generateExpression(child) - .map(v => new GeneralScalarExpression("ASINH", Array[V2Expression](v))) - case Acos(child) => generateExpression(child) - .map(v => new GeneralScalarExpression("ACOS", Array[V2Expression](v))) - case Acosh(child) => generateExpression(child) - .map(v => new GeneralScalarExpression("ACOSH", Array[V2Expression](v))) - case Atan(child) => generateExpression(child) - .map(v => new GeneralScalarExpression("ATAN", Array[V2Expression](v))) - case Atanh(child) => generateExpression(child) - .map(v => new GeneralScalarExpression("ATANH", Array[V2Expression](v))) - case atan2: Atan2 => - val l = generateExpression(atan2.left) - val r = generateExpression(atan2.right) - if (l.isDefined && r.isDefined) { - Some(new GeneralScalarExpression("ATAN2", Array[V2Expression](l.get, r.get))) - } else { - None - } - case Cbrt(child) => generateExpression(child) - .map(v => new GeneralScalarExpression("CBRT", Array[V2Expression](v))) - case ToDegrees(child) => generateExpression(child) - .map(v => new GeneralScalarExpression("DEGREES", Array[V2Expression](v))) - case ToRadians(child) => generateExpression(child) - .map(v => new GeneralScalarExpression("RADIANS", Array[V2Expression](v))) - case Signum(child) => generateExpression(child) - .map(v => new GeneralScalarExpression("SIGN", Array[V2Expression](v))) - case wb: WidthBucket => - val childrenExpressions = wb.children.flatMap(generateExpression(_)) - if (childrenExpressions.length == wb.children.length) { - Some(new GeneralScalarExpression("WIDTH_BUCKET", - childrenExpressions.toArray[V2Expression])) - } else { - None - } + generateExpressionWithName("RAND", Seq(child)) + } + case log: Logarithm => generateExpressionWithName("LOG", log.children) + case Log10(child) => generateExpressionWithName("LOG10", Seq(child)) + case Log2(child) => generateExpressionWithName("LOG2", Seq(child)) + case Log(child) => generateExpressionWithName("LN", Seq(child)) + case Exp(child) => generateExpressionWithName("EXP", Seq(child)) + case pow: Pow => generateExpressionWithName("POWER", pow.children) + case Sqrt(child) => generateExpressionWithName("SQRT", Seq(child)) + case Floor(child) => generateExpressionWithName("FLOOR", Seq(child)) + case Ceil(child) => generateExpressionWithName("CEIL", Seq(child)) + case round: Round => generateExpressionWithName("ROUND", round.children) + case Sin(child) => generateExpressionWithName("SIN", Seq(child)) + case Sinh(child) => generateExpressionWithName("SINH", Seq(child)) + case Cos(child) => generateExpressionWithName("COS", Seq(child)) + case Cosh(child) => generateExpressionWithName("COSH", Seq(child)) + case Tan(child) => generateExpressionWithName("TAN", Seq(child)) + case Tanh(child) => generateExpressionWithName("TANH", Seq(child)) + case Cot(child) => generateExpressionWithName("COT", Seq(child)) + case Asin(child) => generateExpressionWithName("ASIN", Seq(child)) + case Asinh(child) => generateExpressionWithName("ASINH", Seq(child)) + case Acos(child) => generateExpressionWithName("ACOS", Seq(child)) + case Acosh(child) => generateExpressionWithName("ACOSH", Seq(child)) + case Atan(child) => generateExpressionWithName("ATAN", Seq(child)) + case Atanh(child) => generateExpressionWithName("ATANH", Seq(child)) + case atan2: Atan2 => generateExpressionWithName("ATAN2", atan2.children) + case Cbrt(child) => generateExpressionWithName("CBRT", Seq(child)) + case ToDegrees(child) => generateExpressionWithName("DEGREES", Seq(child)) + case ToRadians(child) => generateExpressionWithName("RADIANS", Seq(child)) + case Signum(child) => generateExpressionWithName("SIGN", Seq(child)) + case wb: WidthBucket => generateExpressionWithName("WIDTH_BUCKET", wb.children) case and: And => // AND expects predicate val l = generateExpression(and.left, true) @@ -258,10 +179,8 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) { assert(v.isInstanceOf[V2Predicate]) new V2Not(v.asInstanceOf[V2Predicate]) } - case UnaryMinus(child, true) => generateExpression(child) - .map(v => new GeneralScalarExpression("-", Array[V2Expression](v))) - case BitwiseNot(child) => generateExpression(child) - .map(v => new GeneralScalarExpression("~", Array[V2Expression](v))) + case UnaryMinus(child, true) => generateExpressionWithName("-", Seq(child)) + case BitwiseNot(child) => generateExpressionWithName("~", Seq(child)) case CaseWhen(branches, elseValue) => val conditions = branches.map(_._1).flatMap(generateExpression(_, true)) val values = branches.map(_._2).flatMap(generateExpression(_, true)) @@ -282,93 +201,30 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) { } else { None } - case iff: If => - val childrenExpressions = iff.children.flatMap(generateExpression(_)) - if (iff.children.length == childrenExpressions.length) { - Some(new GeneralScalarExpression("CASE_WHEN", childrenExpressions.toArray[V2Expression])) - } else { - None - } + case iff: If => generateExpressionWithName("CASE_WHEN", iff.children) case substring: Substring => val children = if (substring.len == Literal(Integer.MAX_VALUE)) { Seq(substring.str, substring.pos) } else { substring.children } - val childrenExpressions = children.flatMap(generateExpression(_)) - if (childrenExpressions.length == children.length) { - Some(new GeneralScalarExpression("SUBSTRING", - childrenExpressions.toArray[V2Expression])) - } else { - None - } - case Upper(child) => generateExpression(child) - .map(v => new GeneralScalarExpression("UPPER", Array[V2Expression](v))) - case Lower(child) => generateExpression(child) - .map(v => new GeneralScalarExpression("LOWER", Array[V2Expression](v))) - case translate: StringTranslate => - val childrenExpressions = translate.children.flatMap(generateExpression(_)) - if (childrenExpressions.length == translate.children.length) { - Some(new GeneralScalarExpression("TRANSLATE", - childrenExpressions.toArray[V2Expression])) - } else { - None - } - case trim: StringTrim => - val childrenExpressions = trim.children.flatMap(generateExpression(_)) - if (childrenExpressions.length == trim.children.length) { - Some(new GeneralScalarExpression("TRIM", childrenExpressions.toArray[V2Expression])) - } else { - None - } - case trim: StringTrimLeft => - val childrenExpressions = trim.children.flatMap(generateExpression(_)) - if (childrenExpressions.length == trim.children.length) { - Some(new GeneralScalarExpression("LTRIM", childrenExpressions.toArray[V2Expression])) - } else { - None - } - case trim: StringTrimRight => - val childrenExpressions = trim.children.flatMap(generateExpression(_)) - if (childrenExpressions.length == trim.children.length) { - Some(new GeneralScalarExpression("RTRIM", childrenExpressions.toArray[V2Expression])) - } else { - None - } + generateExpressionWithName("SUBSTRING", children) + case Upper(child) => generateExpressionWithName("UPPER", Seq(child)) + case Lower(child) => generateExpressionWithName("LOWER", Seq(child)) + case translate: StringTranslate => generateExpressionWithName("TRANSLATE", translate.children) + case trim: StringTrim => generateExpressionWithName("TRIM", trim.children) + case trim: StringTrimLeft => generateExpressionWithName("LTRIM", trim.children) + case trim: StringTrimRight => generateExpressionWithName("RTRIM", trim.children) case overlay: Overlay => val children = if (overlay.len == Literal(-1)) { Seq(overlay.input, overlay.replace, overlay.pos) } else { overlay.children } - val childrenExpressions = children.flatMap(generateExpression(_)) - if (childrenExpressions.length == children.length) { - Some(new GeneralScalarExpression("OVERLAY", - childrenExpressions.toArray[V2Expression])) - } else { - None - } - case date: DateAdd => - val childrenExpressions = date.children.flatMap(generateExpression(_)) - if (childrenExpressions.length == date.children.length) { - Some(new GeneralScalarExpression("DATE_ADD", childrenExpressions.toArray[V2Expression])) - } else { - None - } - case date: DateDiff => - val childrenExpressions = date.children.flatMap(generateExpression(_)) - if (childrenExpressions.length == date.children.length) { - Some(new GeneralScalarExpression("DATE_DIFF", childrenExpressions.toArray[V2Expression])) - } else { - None - } - case date: TruncDate => - val childrenExpressions = date.children.flatMap(generateExpression(_)) - if (childrenExpressions.length == date.children.length) { - Some(new GeneralScalarExpression("TRUNC", childrenExpressions.toArray[V2Expression])) - } else { - None - } + generateExpressionWithName("OVERLAY", children) + case date: DateAdd => generateExpressionWithName("DATE_ADD", date.children) + case date: DateDiff => generateExpressionWithName("DATE_DIFF", date.children) + case date: TruncDate => generateExpressionWithName("TRUNC", date.children) case Second(child, _) => generateExpression(child).map(v => new V2Extract("SECOND", v)) case Minute(child, _) => @@ -429,6 +285,16 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) { case _ => operatorName } } + + private def generateExpressionWithName( + v2ExpressionName: String, children: Seq[Expression]): Option[V2Expression] = { + val childrenExpressions = children.flatMap(generateExpression(_)) + if (childrenExpressions.length == children.length) { + Some(new GeneralScalarExpression(v2ExpressionName, childrenExpressions.toArray[V2Expression])) + } else { + None + } + } } object ColumnOrField {