From e37eb8a5316e12f81a30eb12aac0be8370034b58 Mon Sep 17 00:00:00 2001 From: chenzhx Date: Fri, 22 Apr 2022 21:23:29 +0800 Subject: [PATCH] [SPARK-38897][SQL]DS V2 supports push down string functions --- .../expressions/GeneralScalarExpression.java | 36 +++++++++++++++ .../expressions/datetime/CurrentDate.java | 43 ++++++++++++++++++ .../util/V2ExpressionSQLBuilder.java | 15 +++++++ .../catalyst/util/V2ExpressionBuilder.scala | 45 ++++++++++++++++++- .../org/apache/spark/sql/jdbc/H2Dialect.scala | 2 +- .../apache/spark/sql/jdbc/JDBCV2Suite.scala | 28 ++++++++++++ 6 files changed, 167 insertions(+), 2 deletions(-) create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/datetime/CurrentDate.java diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java index 58082d5ee09c..d525f5699cf0 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java @@ -148,6 +148,42 @@ *
  • Since version: 3.3.0
  • * * + *
  • Name: SUBSTRING + * + *
  • + *
  • Name: UPPER + * + *
  • + *
  • Name: LOWER + * + *
  • + *
  • Name: TRANSLATE + * + *
  • + *
  • Name: TRIM + * + *
  • + *
  • Name: OVERLAY + * + *
  • * * Note: SQL semantic conforms ANSI standard, so some expressions are not supported when ANSI off, * including: add, subtract, multiply, divide, remainder, pmod. diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/datetime/CurrentDate.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/datetime/CurrentDate.java new file mode 100644 index 000000000000..1bb6276f32dd --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/datetime/CurrentDate.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.expressions.datetime; + +import org.apache.spark.sql.connector.expressions.Expression; +import org.apache.spark.sql.types.DataType; +import scala.None; +import scala.Option; + +import java.io.Serializable; + +/** + * Represents a cast expression in the public logical expression API. + * + * @since 3.3.0 + */ +public class CurrentDate implements Expression, Serializable { + private String timeZoneId; + + CurrentDate(String timeZoneId) { + this.timeZoneId = timeZoneId; + } + + public String timeZoneId() { return timeZoneId; } + + @Override + public Expression[] children() { return EMPTY_EXPRESSION; } +} \ No newline at end of file diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java index c9dfa2003e3c..6d4a0a85b7d3 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java @@ -102,6 +102,12 @@ public String build(Expression expr) { case "FLOOR": case "CEIL": case "WIDTH_BUCKET": + case "SUBSTRING": + case "UPPER": + case "LOWER": + case "TRANSLATE": + case "TRIM": + case "OVERLAY": return visitSQLFunction(name, Arrays.stream(e.children()).map(c -> build(c)).toArray(String[]::new)); case "CASE_WHEN": { @@ -228,4 +234,13 @@ protected String visitSQLFunction(String funcName, String[] inputs) { protected String visitUnexpectedExpr(Expression expr) throws IllegalArgumentException { throw new IllegalArgumentException("Unexpected V2 expression: " + expr); } + + protected String visitLike(String name, String l, String r, char escape) throws IllegalArgumentException { + switch (escape) { + case '\\' : + return l + " " + name + " " + r; + default: + return l + " " + name + " " + r + " ESCAPE " + escape; + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala index 487b809d48a0..b5a5933fb825 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.util -import org.apache.spark.sql.catalyst.expressions.{Abs, Add, And, BinaryComparison, BinaryOperator, BitwiseAnd, BitwiseNot, BitwiseOr, BitwiseXor, CaseWhen, Cast, Ceil, Coalesce, Contains, Divide, EndsWith, EqualTo, Exp, Expression, Floor, In, InSet, IsNotNull, IsNull, Literal, Log, Multiply, Not, Or, Pow, Predicate, Remainder, Sqrt, StartsWith, StringPredicate, Subtract, UnaryMinus, WidthBucket} +import org.apache.spark.sql.catalyst.expressions.{Abs, Add, And, BinaryComparison, BinaryOperator, BitwiseAnd, BitwiseNot, BitwiseOr, BitwiseXor, CaseWhen, Cast, Ceil, Coalesce, Contains, Divide, EndsWith, EqualTo, Exp, Expression, Floor, In, InSet, IsNotNull, IsNull, Literal, Log, Lower, Multiply, Not, Or, Overlay, Pow, Predicate, Remainder, Sqrt, StartsWith, StringPredicate, StringTranslate, StringTrim, Substring, Subtract, UnaryMinus, Upper, WidthBucket} import org.apache.spark.sql.connector.expressions.{Cast => V2Cast, Expression => V2Expression, FieldReference, GeneralScalarExpression, LiteralValue} import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse, AlwaysTrue, And => V2And, Not => V2Not, Or => V2Or, Predicate => V2Predicate} import org.apache.spark.sql.execution.datasources.PushableColumn @@ -200,6 +200,49 @@ class V2ExpressionBuilder( } else { None } + case Substring(str, pos, len) => + val s = generateExpression(str) + val p = generateExpression(pos) + val l = generateExpression(len) + if (s.isDefined && p.isDefined && l.isDefined) { + Some(new GeneralScalarExpression("SUBSTRING", Array[V2Expression](s.get, p.get, l.get))) + } else { + None + } + case Upper(child) => generateExpression(child) + .map(v => new GeneralScalarExpression("UPPER", Array[V2Expression](v))) + case Lower(child) => generateExpression(child) + .map(v => new GeneralScalarExpression("LOWER", Array[V2Expression](v))) + case StringTranslate(str, matching, replace) => + val s = generateExpression(str) + val m = generateExpression(matching) + val r = generateExpression(replace) + if (s.isDefined && m.isDefined && r.isDefined) { + Some(new GeneralScalarExpression("TRANSLATE", + Array[V2Expression](s.get, m.get, r.get))) + } else { + None + } + case StringTrim(str, trim) => + val s = generateExpression(str) + if (trim.isDefined) { + trim.flatMap(generateExpression(_)).map { t => + new GeneralScalarExpression("TRIM", Array[V2Expression](s.get, t)) + } + } else { + Some(new GeneralScalarExpression("TRIM", Array[V2Expression](s.get))) + } + case Overlay(input, replace, pos, len) => + val i = generateExpression(input) + val r = generateExpression(replace) + val p = generateExpression(pos) + val l = generateExpression(len) + if (i.isDefined && r.isDefined && p.isDefined && l.isDefined) { + Some(new GeneralScalarExpression("OVERLAY", + Array[V2Expression](i.get, r.get, p.get, l.get))) + } else { + None + } // TODO supports other expressions case _ => None } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala index 0aa971c0d3ab..beca3e657e79 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala @@ -35,7 +35,7 @@ private object H2Dialect extends JdbcDialect { class H2SQLBuilder extends JDBCSQLBuilder { override def visitSQLFunction(funcName: String, inputs: Array[String]): String = { funcName match { - case "WIDTH_BUCKET" => + case "WIDTH_BUCKET" | "OVERLAY" => val functionInfo = super.visitSQLFunction(funcName, inputs) throw QueryCompilationErrors.noSuchFunctionError("H2", functionInfo) case _ => super.visitSQLFunction(funcName, inputs) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala index 5cfa2f465a2b..33532c6d1d83 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala @@ -626,6 +626,34 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel } } + test("scan with filter push-down with string functions") { + val df1 = sql("select * FROM h2.test.employee where " + + "substr(name, 2, 1) = 'e'" + + " AND upper(name) = 'JEN' AND lower(name) = 'jen' ") + checkFiltersRemoved(df1) + val expectedPlanFragment1 = + "PushedFilters: [NAME IS NOT NULL, (SUBSTRING(NAME, 2, 1)) = 'e', " + + "UPPER(NAME) = 'JEN', LOWER(NAME) = 'jen']" + checkPushedInfo(df1, expectedPlanFragment1) + checkAnswer(df1, Seq(Row(6, "jen", 12000, 1200, true))) + + val df2 = sql("select * FROM h2.test.employee where " + + "trim(name) = 'jen'" + + "AND translate(name, 'e', 1) = 'j1n'") + checkFiltersRemoved(df2) + val expectedPlanFragment2 = + "PushedFilters: [NAME IS NOT NULL, TRIM(NAME) = 'jen', " + + "(TRANSLATE(NAME, 'e', '1')) = 'j1n']" + checkPushedInfo(df2, expectedPlanFragment2) + checkAnswer(df2, Seq(Row(6, "jen", 12000, 1200, true))) + + // H2 does not support width_bucket + val df3 = sql("select * FROM h2.test.employee where(OVERLAY(NAME, '1', 2, -1)) = 'j1n'") + checkFiltersRemoved(df3, false) + checkPushedInfo(df3, "PushedFilters: [NAME IS NOT NULL]") + checkAnswer(df3, Seq(Row(6, "jen", 12000, 1200, true))) + } + test("scan with aggregate push-down: MAX AVG with filter and group by") { val df = sql("select MAX(SaLaRY), AVG(BONUS) FROM h2.test.employee where dept > 0" + " group by DePt")