Skip to content

Commit d1975a1

Browse files
beliefercloud-fan
authored andcommitted
[SPARK-39836][SQL] Simplify V2ExpressionBuilder by extract common method
### What changes were proposed in this pull request? Currently, `V2ExpressionBuilder` have a lot of similar code, we can extract them as one common method. We can simplify the implement with the common method. ### Why are the changes needed? Simplify `V2ExpressionBuilder` by extract common method. ### Does this PR introduce _any_ user-facing change? 'No'. Just update inner implementation. ### How was this patch tested? N/A Closes #37249 from beliefer/SPARK-39836. Authored-by: Jiaan Geng <[email protected]> Signed-off-by: Wenchen Fan <[email protected]>
1 parent 7358253 commit d1975a1

File tree

1 file changed

+59
-193
lines changed

1 file changed

+59
-193
lines changed

sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala

Lines changed: 59 additions & 193 deletions
Original file line numberDiff line numberDiff line change
@@ -90,124 +90,45 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) {
9090
}
9191
case Cast(child, dataType, _, true) =>
9292
generateExpression(child).map(v => new V2Cast(v, dataType))
93-
case Abs(child, true) => generateExpression(child)
94-
.map(v => new GeneralScalarExpression("ABS", Array[V2Expression](v)))
95-
case Coalesce(children) =>
96-
val childrenExpressions = children.flatMap(generateExpression(_))
97-
if (children.length == childrenExpressions.length) {
98-
Some(new GeneralScalarExpression("COALESCE", childrenExpressions.toArray[V2Expression]))
99-
} else {
100-
None
101-
}
102-
case Greatest(children) =>
103-
val childrenExpressions = children.flatMap(generateExpression(_))
104-
if (children.length == childrenExpressions.length) {
105-
Some(new GeneralScalarExpression("GREATEST", childrenExpressions.toArray[V2Expression]))
106-
} else {
107-
None
108-
}
109-
case Least(children) =>
110-
val childrenExpressions = children.flatMap(generateExpression(_))
111-
if (children.length == childrenExpressions.length) {
112-
Some(new GeneralScalarExpression("LEAST", childrenExpressions.toArray[V2Expression]))
113-
} else {
114-
None
115-
}
93+
case Abs(child, true) => generateExpressionWithName("ABS", Seq(child))
94+
case Coalesce(children) => generateExpressionWithName("COALESCE", children)
95+
case Greatest(children) => generateExpressionWithName("GREATEST", children)
96+
case Least(children) => generateExpressionWithName("LEAST", children)
11697
case Rand(child, hideSeed) =>
11798
if (hideSeed) {
11899
Some(new GeneralScalarExpression("RAND", Array.empty[V2Expression]))
119100
} else {
120-
generateExpression(child)
121-
.map(v => new GeneralScalarExpression("RAND", Array[V2Expression](v)))
122-
}
123-
case log: Logarithm =>
124-
val l = generateExpression(log.left)
125-
val r = generateExpression(log.right)
126-
if (l.isDefined && r.isDefined) {
127-
Some(new GeneralScalarExpression("LOG", Array[V2Expression](l.get, r.get)))
128-
} else {
129-
None
130-
}
131-
case Log10(child) => generateExpression(child)
132-
.map(v => new GeneralScalarExpression("LOG10", Array[V2Expression](v)))
133-
case Log2(child) => generateExpression(child)
134-
.map(v => new GeneralScalarExpression("LOG2", Array[V2Expression](v)))
135-
case Log(child) => generateExpression(child)
136-
.map(v => new GeneralScalarExpression("LN", Array[V2Expression](v)))
137-
case Exp(child) => generateExpression(child)
138-
.map(v => new GeneralScalarExpression("EXP", Array[V2Expression](v)))
139-
case Pow(left, right) =>
140-
val l = generateExpression(left)
141-
val r = generateExpression(right)
142-
if (l.isDefined && r.isDefined) {
143-
Some(new GeneralScalarExpression("POWER", Array[V2Expression](l.get, r.get)))
144-
} else {
145-
None
146-
}
147-
case Sqrt(child) => generateExpression(child)
148-
.map(v => new GeneralScalarExpression("SQRT", Array[V2Expression](v)))
149-
case Floor(child) => generateExpression(child)
150-
.map(v => new GeneralScalarExpression("FLOOR", Array[V2Expression](v)))
151-
case Ceil(child) => generateExpression(child)
152-
.map(v => new GeneralScalarExpression("CEIL", Array[V2Expression](v)))
153-
case round: Round =>
154-
val l = generateExpression(round.left)
155-
val r = generateExpression(round.right)
156-
if (l.isDefined && r.isDefined) {
157-
Some(new GeneralScalarExpression("ROUND", Array[V2Expression](l.get, r.get)))
158-
} else {
159-
None
160-
}
161-
case Sin(child) => generateExpression(child)
162-
.map(v => new GeneralScalarExpression("SIN", Array[V2Expression](v)))
163-
case Sinh(child) => generateExpression(child)
164-
.map(v => new GeneralScalarExpression("SINH", Array[V2Expression](v)))
165-
case Cos(child) => generateExpression(child)
166-
.map(v => new GeneralScalarExpression("COS", Array[V2Expression](v)))
167-
case Cosh(child) => generateExpression(child)
168-
.map(v => new GeneralScalarExpression("COSH", Array[V2Expression](v)))
169-
case Tan(child) => generateExpression(child)
170-
.map(v => new GeneralScalarExpression("TAN", Array[V2Expression](v)))
171-
case Tanh(child) => generateExpression(child)
172-
.map(v => new GeneralScalarExpression("TANH", Array[V2Expression](v)))
173-
case Cot(child) => generateExpression(child)
174-
.map(v => new GeneralScalarExpression("COT", Array[V2Expression](v)))
175-
case Asin(child) => generateExpression(child)
176-
.map(v => new GeneralScalarExpression("ASIN", Array[V2Expression](v)))
177-
case Asinh(child) => generateExpression(child)
178-
.map(v => new GeneralScalarExpression("ASINH", Array[V2Expression](v)))
179-
case Acos(child) => generateExpression(child)
180-
.map(v => new GeneralScalarExpression("ACOS", Array[V2Expression](v)))
181-
case Acosh(child) => generateExpression(child)
182-
.map(v => new GeneralScalarExpression("ACOSH", Array[V2Expression](v)))
183-
case Atan(child) => generateExpression(child)
184-
.map(v => new GeneralScalarExpression("ATAN", Array[V2Expression](v)))
185-
case Atanh(child) => generateExpression(child)
186-
.map(v => new GeneralScalarExpression("ATANH", Array[V2Expression](v)))
187-
case atan2: Atan2 =>
188-
val l = generateExpression(atan2.left)
189-
val r = generateExpression(atan2.right)
190-
if (l.isDefined && r.isDefined) {
191-
Some(new GeneralScalarExpression("ATAN2", Array[V2Expression](l.get, r.get)))
192-
} else {
193-
None
194-
}
195-
case Cbrt(child) => generateExpression(child)
196-
.map(v => new GeneralScalarExpression("CBRT", Array[V2Expression](v)))
197-
case ToDegrees(child) => generateExpression(child)
198-
.map(v => new GeneralScalarExpression("DEGREES", Array[V2Expression](v)))
199-
case ToRadians(child) => generateExpression(child)
200-
.map(v => new GeneralScalarExpression("RADIANS", Array[V2Expression](v)))
201-
case Signum(child) => generateExpression(child)
202-
.map(v => new GeneralScalarExpression("SIGN", Array[V2Expression](v)))
203-
case wb: WidthBucket =>
204-
val childrenExpressions = wb.children.flatMap(generateExpression(_))
205-
if (childrenExpressions.length == wb.children.length) {
206-
Some(new GeneralScalarExpression("WIDTH_BUCKET",
207-
childrenExpressions.toArray[V2Expression]))
208-
} else {
209-
None
210-
}
101+
generateExpressionWithName("RAND", Seq(child))
102+
}
103+
case log: Logarithm => generateExpressionWithName("LOG", log.children)
104+
case Log10(child) => generateExpressionWithName("LOG10", Seq(child))
105+
case Log2(child) => generateExpressionWithName("LOG2", Seq(child))
106+
case Log(child) => generateExpressionWithName("LN", Seq(child))
107+
case Exp(child) => generateExpressionWithName("EXP", Seq(child))
108+
case pow: Pow => generateExpressionWithName("POWER", pow.children)
109+
case Sqrt(child) => generateExpressionWithName("SQRT", Seq(child))
110+
case Floor(child) => generateExpressionWithName("FLOOR", Seq(child))
111+
case Ceil(child) => generateExpressionWithName("CEIL", Seq(child))
112+
case round: Round => generateExpressionWithName("ROUND", round.children)
113+
case Sin(child) => generateExpressionWithName("SIN", Seq(child))
114+
case Sinh(child) => generateExpressionWithName("SINH", Seq(child))
115+
case Cos(child) => generateExpressionWithName("COS", Seq(child))
116+
case Cosh(child) => generateExpressionWithName("COSH", Seq(child))
117+
case Tan(child) => generateExpressionWithName("TAN", Seq(child))
118+
case Tanh(child) => generateExpressionWithName("TANH", Seq(child))
119+
case Cot(child) => generateExpressionWithName("COT", Seq(child))
120+
case Asin(child) => generateExpressionWithName("ASIN", Seq(child))
121+
case Asinh(child) => generateExpressionWithName("ASINH", Seq(child))
122+
case Acos(child) => generateExpressionWithName("ACOS", Seq(child))
123+
case Acosh(child) => generateExpressionWithName("ACOSH", Seq(child))
124+
case Atan(child) => generateExpressionWithName("ATAN", Seq(child))
125+
case Atanh(child) => generateExpressionWithName("ATANH", Seq(child))
126+
case atan2: Atan2 => generateExpressionWithName("ATAN2", atan2.children)
127+
case Cbrt(child) => generateExpressionWithName("CBRT", Seq(child))
128+
case ToDegrees(child) => generateExpressionWithName("DEGREES", Seq(child))
129+
case ToRadians(child) => generateExpressionWithName("RADIANS", Seq(child))
130+
case Signum(child) => generateExpressionWithName("SIGN", Seq(child))
131+
case wb: WidthBucket => generateExpressionWithName("WIDTH_BUCKET", wb.children)
211132
case and: And =>
212133
// AND expects predicate
213134
val l = generateExpression(and.left, true)
@@ -258,10 +179,8 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) {
258179
assert(v.isInstanceOf[V2Predicate])
259180
new V2Not(v.asInstanceOf[V2Predicate])
260181
}
261-
case UnaryMinus(child, true) => generateExpression(child)
262-
.map(v => new GeneralScalarExpression("-", Array[V2Expression](v)))
263-
case BitwiseNot(child) => generateExpression(child)
264-
.map(v => new GeneralScalarExpression("~", Array[V2Expression](v)))
182+
case UnaryMinus(child, true) => generateExpressionWithName("-", Seq(child))
183+
case BitwiseNot(child) => generateExpressionWithName("~", Seq(child))
265184
case CaseWhen(branches, elseValue) =>
266185
val conditions = branches.map(_._1).flatMap(generateExpression(_, true))
267186
val values = branches.map(_._2).flatMap(generateExpression(_, true))
@@ -282,93 +201,30 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) {
282201
} else {
283202
None
284203
}
285-
case iff: If =>
286-
val childrenExpressions = iff.children.flatMap(generateExpression(_))
287-
if (iff.children.length == childrenExpressions.length) {
288-
Some(new GeneralScalarExpression("CASE_WHEN", childrenExpressions.toArray[V2Expression]))
289-
} else {
290-
None
291-
}
204+
case iff: If => generateExpressionWithName("CASE_WHEN", iff.children)
292205
case substring: Substring =>
293206
val children = if (substring.len == Literal(Integer.MAX_VALUE)) {
294207
Seq(substring.str, substring.pos)
295208
} else {
296209
substring.children
297210
}
298-
val childrenExpressions = children.flatMap(generateExpression(_))
299-
if (childrenExpressions.length == children.length) {
300-
Some(new GeneralScalarExpression("SUBSTRING",
301-
childrenExpressions.toArray[V2Expression]))
302-
} else {
303-
None
304-
}
305-
case Upper(child) => generateExpression(child)
306-
.map(v => new GeneralScalarExpression("UPPER", Array[V2Expression](v)))
307-
case Lower(child) => generateExpression(child)
308-
.map(v => new GeneralScalarExpression("LOWER", Array[V2Expression](v)))
309-
case translate: StringTranslate =>
310-
val childrenExpressions = translate.children.flatMap(generateExpression(_))
311-
if (childrenExpressions.length == translate.children.length) {
312-
Some(new GeneralScalarExpression("TRANSLATE",
313-
childrenExpressions.toArray[V2Expression]))
314-
} else {
315-
None
316-
}
317-
case trim: StringTrim =>
318-
val childrenExpressions = trim.children.flatMap(generateExpression(_))
319-
if (childrenExpressions.length == trim.children.length) {
320-
Some(new GeneralScalarExpression("TRIM", childrenExpressions.toArray[V2Expression]))
321-
} else {
322-
None
323-
}
324-
case trim: StringTrimLeft =>
325-
val childrenExpressions = trim.children.flatMap(generateExpression(_))
326-
if (childrenExpressions.length == trim.children.length) {
327-
Some(new GeneralScalarExpression("LTRIM", childrenExpressions.toArray[V2Expression]))
328-
} else {
329-
None
330-
}
331-
case trim: StringTrimRight =>
332-
val childrenExpressions = trim.children.flatMap(generateExpression(_))
333-
if (childrenExpressions.length == trim.children.length) {
334-
Some(new GeneralScalarExpression("RTRIM", childrenExpressions.toArray[V2Expression]))
335-
} else {
336-
None
337-
}
211+
generateExpressionWithName("SUBSTRING", children)
212+
case Upper(child) => generateExpressionWithName("UPPER", Seq(child))
213+
case Lower(child) => generateExpressionWithName("LOWER", Seq(child))
214+
case translate: StringTranslate => generateExpressionWithName("TRANSLATE", translate.children)
215+
case trim: StringTrim => generateExpressionWithName("TRIM", trim.children)
216+
case trim: StringTrimLeft => generateExpressionWithName("LTRIM", trim.children)
217+
case trim: StringTrimRight => generateExpressionWithName("RTRIM", trim.children)
338218
case overlay: Overlay =>
339219
val children = if (overlay.len == Literal(-1)) {
340220
Seq(overlay.input, overlay.replace, overlay.pos)
341221
} else {
342222
overlay.children
343223
}
344-
val childrenExpressions = children.flatMap(generateExpression(_))
345-
if (childrenExpressions.length == children.length) {
346-
Some(new GeneralScalarExpression("OVERLAY",
347-
childrenExpressions.toArray[V2Expression]))
348-
} else {
349-
None
350-
}
351-
case date: DateAdd =>
352-
val childrenExpressions = date.children.flatMap(generateExpression(_))
353-
if (childrenExpressions.length == date.children.length) {
354-
Some(new GeneralScalarExpression("DATE_ADD", childrenExpressions.toArray[V2Expression]))
355-
} else {
356-
None
357-
}
358-
case date: DateDiff =>
359-
val childrenExpressions = date.children.flatMap(generateExpression(_))
360-
if (childrenExpressions.length == date.children.length) {
361-
Some(new GeneralScalarExpression("DATE_DIFF", childrenExpressions.toArray[V2Expression]))
362-
} else {
363-
None
364-
}
365-
case date: TruncDate =>
366-
val childrenExpressions = date.children.flatMap(generateExpression(_))
367-
if (childrenExpressions.length == date.children.length) {
368-
Some(new GeneralScalarExpression("TRUNC", childrenExpressions.toArray[V2Expression]))
369-
} else {
370-
None
371-
}
224+
generateExpressionWithName("OVERLAY", children)
225+
case date: DateAdd => generateExpressionWithName("DATE_ADD", date.children)
226+
case date: DateDiff => generateExpressionWithName("DATE_DIFF", date.children)
227+
case date: TruncDate => generateExpressionWithName("TRUNC", date.children)
372228
case Second(child, _) =>
373229
generateExpression(child).map(v => new V2Extract("SECOND", v))
374230
case Minute(child, _) =>
@@ -429,6 +285,16 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) {
429285
case _ => operatorName
430286
}
431287
}
288+
289+
private def generateExpressionWithName(
290+
v2ExpressionName: String, children: Seq[Expression]): Option[V2Expression] = {
291+
val childrenExpressions = children.flatMap(generateExpression(_))
292+
if (childrenExpressions.length == children.length) {
293+
Some(new GeneralScalarExpression(v2ExpressionName, childrenExpressions.toArray[V2Expression]))
294+
} else {
295+
None
296+
}
297+
}
432298
}
433299

434300
object ColumnOrField {

0 commit comments

Comments
 (0)