
Commit 7b6e1d5

viirya authored and dongjoon-hyun committed
[SPARK-25557][SQL] Nested column predicate pushdown for ORC
### What changes were proposed in this pull request?

We added nested column predicate pushdown for Parquet in #27728. This patch extends the feature support to ORC.

### Why are the changes needed?

Extending the feature to ORC for feature parity, and better performance for handling nested predicate pushdown.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Unit tests.

Closes #28761 from viirya/SPARK-25557.

Authored-by: Liang-Chi Hsieh <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
1 parent: 6c3d0a4 · commit: 7b6e1d5

11 files changed: +460 −310 lines

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 3 additions & 3 deletions
```diff
@@ -2108,9 +2108,9 @@ object SQLConf {
       .doc("A comma-separated list of data source short names or fully qualified data source " +
         "implementation class names for which Spark tries to push down predicates for nested " +
         "columns and/or names containing `dots` to data sources. This configuration is only " +
-        "effective with file-based data source in DSv1. Currently, Parquet implements " +
-        "both optimizations while ORC only supports predicates for names containing `dots`. The " +
-        "other data sources don't support this feature yet. So the default value is 'parquet,orc'.")
+        "effective with file-based data sources in DSv1. Currently, Parquet and ORC implement " +
+        "both optimizations. The other data sources don't support this feature yet. So the " +
+        "default value is 'parquet,orc'.")
       .version("3.0.0")
       .stringConf
       .createWithDefault("parquet,orc")
```
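As a usage sketch, this config gates which file sources may receive nested-column (and dotted-name) predicates. The config key below is an assumption inferred from the default value shown in this diff, and the path and schema are purely illustrative:

```scala
// Assumed config key; the diff only shows its doc string and default value "parquet,orc".
spark.conf.set(
  "spark.sql.optimizer.nestedPredicatePushdown.supportedFileSources", "parquet,orc")

// With ORC listed, a predicate on a nested field such as `a.b` becomes eligible for
// pushdown into the ORC reader instead of being evaluated only after rows are returned.
val filtered = spark.read.orc("/tmp/some_orc_table").filter("a.b = 1")
```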

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala

Lines changed: 2 additions & 0 deletions
```diff
@@ -668,6 +668,8 @@ abstract class PushableColumnBase {
     import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper
     def helper(e: Expression): Option[Seq[String]] = e match {
       case a: Attribute =>
+        // Attribute that contains dot "." in name is supported only when
+        // nested predicate pushdown is enabled.
        if (nestedPredicatePushdownEnabled || !a.name.contains(".")) {
          Some(Seq(a.name))
        } else {
```
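A simplified standalone sketch of the rule documented by the new comment (not the actual Catalyst helper, which pattern-matches on `Expression` and handles nested field accesses separately):

```scala
// Simplified sketch: an attribute whose name contains a dot is only pushable when
// nested predicate pushdown is enabled, since otherwise "a.b" would be ambiguous.
def pushableTopLevelName(
    name: String,
    nestedPredicatePushdownEnabled: Boolean): Option[Seq[String]] =
  if (nestedPredicatePushdownEnabled || !name.contains(".")) {
    // The name is kept as a single identifier part; nested field accesses contribute
    // additional parts elsewhere in the real helper.
    Some(Seq(name))
  } else {
    None
  }

// pushableTopLevelName("a.b", nestedPredicatePushdownEnabled = false)  // None
// pushableTopLevelName("a.b", nestedPredicatePushdownEnabled = true)   // Some(Seq("a.b"))
```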

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFiltersBase.scala

Lines changed: 43 additions & 7 deletions
```diff
@@ -17,8 +17,11 @@

 package org.apache.spark.sql.execution.datasources.orc

+import java.util.Locale
+
+import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
 import org.apache.spark.sql.sources.{And, Filter}
-import org.apache.spark.sql.types.{AtomicType, BinaryType, DataType}
+import org.apache.spark.sql.types.{AtomicType, BinaryType, DataType, StructField, StructType}

 /**
  * Methods that can be shared when upgrading the built-in Hive.
@@ -37,12 +40,45 @@ trait OrcFiltersBase {
   }

   /**
-   * Return true if this is a searchable type in ORC.
-   * Both CharType and VarcharType are cleaned at AstBuilder.
+   * This method returns a map which contains ORC field name and data type. Each key
+   * represents a column; `dots` are used as separators for nested columns. If any part
+   * of the names contains `dots`, it is quoted to avoid confusion. See
+   * `org.apache.spark.sql.connector.catalog.quoted` for implementation details.
+   *
+   * BinaryType, UserDefinedType, ArrayType and MapType are ignored.
    */
-  protected[sql] def isSearchableType(dataType: DataType) = dataType match {
-    case BinaryType => false
-    case _: AtomicType => true
-    case _ => false
+  protected[sql] def getSearchableTypeMap(
+      schema: StructType,
+      caseSensitive: Boolean): Map[String, DataType] = {
+    import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper
+
+    def getPrimitiveFields(
+        fields: Seq[StructField],
+        parentFieldNames: Seq[String] = Seq.empty): Seq[(String, DataType)] = {
+      fields.flatMap { f =>
+        f.dataType match {
+          case st: StructType =>
+            getPrimitiveFields(st.fields, parentFieldNames :+ f.name)
+          case BinaryType => None
+          case _: AtomicType =>
+            Some(((parentFieldNames :+ f.name).quoted, f.dataType))
+          case _ => None
+        }
+      }
+    }
+
+    val primitiveFields = getPrimitiveFields(schema.fields)
+    if (caseSensitive) {
+      primitiveFields.toMap
+    } else {
+      // Don't consider ambiguity here, i.e. more than one field are matched in case insensitive
+      // mode, just skip pushdown for these fields, they will trigger Exception when reading,
+      // See: SPARK-25175.
+      val dedupPrimitiveFields = primitiveFields
+        .groupBy(_._1.toLowerCase(Locale.ROOT))
+        .filter(_._2.size == 1)
+        .mapValues(_.head._2)
+      CaseInsensitiveMap(dedupPrimitiveFields)
+    }
   }
 }
```
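For intuition, here is a self-contained sketch of the flattening `getSearchableTypeMap` performs. It is a simplified re-implementation: the committed method quotes names via `MultipartIdentifierHelper.quoted`, restricts leaves to `AtomicType`, and wraps the result in `CaseInsensitiveMap` with deduplication when the analyzer is case-insensitive (SPARK-25175); the `quote` helper below is a minimal stand-in.

```scala
import org.apache.spark.sql.types._

// Simplified re-implementation for illustration only, not the committed code.
def flattenSearchableFields(schema: StructType): Map[String, DataType] = {
  // Minimal stand-in for the `.quoted` helper: backtick-quote parts containing dots.
  def quote(part: String): String =
    if (part.contains(".")) s"`${part.replace("`", "``")}`" else part

  def loop(fields: Seq[StructField], parents: Seq[String]): Seq[(String, DataType)] =
    fields.flatMap { f =>
      f.dataType match {
        case st: StructType => loop(st.fields, parents :+ f.name)
        case BinaryType | _: ArrayType | _: MapType => Nil // not searchable in ORC
        case dt => Seq(((parents :+ f.name).map(quote).mkString("."), dt))
      }
    }

  loop(schema.fields, Nil).toMap
}

// Example: struct<id: int, a: struct<`b.c`: string>> flattens to
//   Map("id" -> IntegerType, "a.`b.c`" -> StringType)
```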

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala

Lines changed: 3 additions & 4 deletions
```diff
@@ -27,6 +27,7 @@ import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownFilters}
 import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex
 import org.apache.spark.sql.execution.datasources.orc.OrcFilters
 import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.sources.Filter
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
@@ -60,10 +61,8 @@ case class OrcScanBuilder(
         // changed `hadoopConf` in executors.
         OrcInputFormat.setSearchArgument(hadoopConf, f, schema.fieldNames)
       }
-      val dataTypeMap = schema.map(f => quoteIfNeeded(f.name) -> f.dataType).toMap
-      // TODO (SPARK-25557): ORC doesn't support nested predicate pushdown, so they are removed.
-      val newFilters = filters.filter(!_.containsNestedColumn)
-      _pushedFilters = OrcFilters.convertibleFilters(schema, dataTypeMap, newFilters).toArray
+      val dataTypeMap = OrcFilters.getSearchableTypeMap(schema, SQLConf.get.caseSensitiveAnalysis)
+      _pushedFilters = OrcFilters.convertibleFilters(schema, dataTypeMap, filters).toArray
     }
     filters
   }
```
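With the scan builder no longer dropping nested filters, a pushed nested predicate should be observable in the physical plan. A hedged way to check it (path, schema and the exact plan text are assumptions; rendering varies by Spark version):

```scala
// Illustrative check only, not part of this commit's tests.
case class Inner(b: Int)
case class Outer(a: Inner)

val path = "/tmp/nested_orc_check" // assumed scratch location
spark.createDataFrame(Seq(Outer(Inner(1)), Outer(Inner(2)))).write.mode("overwrite").orc(path)

val q = spark.read.orc(path).filter("a.b = 1")
q.explain(true)
// Expected to show the nested predicate under the ORC scan's pushed filters, e.g.
//   PushedFilters: [IsNotNull(a.b), EqualTo(a.b,1)]
```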

sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileBasedDataSourceTest.scala

Lines changed: 39 additions & 1 deletion
```diff
@@ -22,8 +22,10 @@ import java.io.File
 import scala.reflect.ClassTag
 import scala.reflect.runtime.universe.TypeTag

-import org.apache.spark.sql.{DataFrame, SaveMode}
+import org.apache.spark.sql.{DataFrame, Row, SaveMode}
+import org.apache.spark.sql.functions.struct
 import org.apache.spark.sql.test.SQLTestUtils
+import org.apache.spark.sql.types.StructType

 /**
  * A helper trait that provides convenient facilities for file-based data source testing.
@@ -103,4 +105,40 @@ private[sql] trait FileBasedDataSourceTest extends SQLTestUtils {
       df: DataFrame, path: File): Unit = {
     df.write.mode(SaveMode.Overwrite).format(dataSourceName).save(path.getCanonicalPath)
   }
+
+  /**
+   * Takes single level `inputDF` dataframe to generate multi-level nested
+   * dataframes as new test data. It tests both non-nested and nested dataframes
+   * which are written and read back with specified datasource.
+   */
+  protected def withNestedDataFrame(inputDF: DataFrame): Seq[(DataFrame, String, Any => Any)] = {
+    assert(inputDF.schema.fields.length == 1)
+    assert(!inputDF.schema.fields.head.dataType.isInstanceOf[StructType])
+    val df = inputDF.toDF("temp")
+    Seq(
+      (
+        df.withColumnRenamed("temp", "a"),
+        "a", // zero nesting
+        (x: Any) => x),
+      (
+        df.withColumn("a", struct(df("temp") as "b")).drop("temp"),
+        "a.b", // one level nesting
+        (x: Any) => Row(x)),
+      (
+        df.withColumn("a", struct(struct(df("temp") as "c") as "b")).drop("temp"),
+        "a.b.c", // two level nesting
+        (x: Any) => Row(Row(x))
+      ),
+      (
+        df.withColumnRenamed("temp", "a.b"),
+        "`a.b`", // zero nesting with column name containing `dots`
+        (x: Any) => x
+      ),
+      (
+        df.withColumn("a.b", struct(df("temp") as "c.d") ).drop("temp"),
+        "`a.b`.`c.d`", // one level nesting with column names containing `dots`
+        (x: Any) => Row(x)
+      )
+    )
+  }
 }
```
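For reference, a hypothetical sketch of how a suite mixing in this trait could consume the helper: each tuple pairs a generated dataframe with the (possibly backtick-quoted) column path to filter on and a function that wraps an expected value to the matching nesting depth. `checkAnswer`, `withTempPath` and `testImplicits` are assumed to come from the surrounding test infrastructure.

```scala
// Hypothetical usage inside a test body, not part of this commit.
import testImplicits._

val input = Seq(1, 2, 3).toDF("value")
withNestedDataFrame(input).foreach { case (df, colName, resultFun) =>
  withTempPath { file =>
    df.write.format(dataSourceName).save(file.getCanonicalPath)
    val readBack = spark.read.format(dataSourceName).load(file.getCanonicalPath)
    // The dataframe has a single top-level column, so the expected row is the
    // filtered value wrapped to the right nesting depth by resultFun.
    checkAnswer(readBack.filter(s"$colName = 2"), Seq(Row(resultFun(2))))
  }
}
```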

sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcTest.scala

Lines changed: 22 additions & 0 deletions
```diff
@@ -143,4 +143,26 @@ abstract class OrcTest extends QueryTest with FileBasedDataSourceTest with Befor
     FileUtils.copyURLToFile(url, file)
     spark.read.orc(file.getAbsolutePath)
   }
+
+  /**
+   * Takes a sequence of products `data` to generate multi-level nested
+   * dataframes as new test data. It tests both non-nested and nested dataframes
+   * which are written and read back with Orc datasource.
+   *
+   * This is different from [[withOrcDataFrame]] which does not
+   * test nested cases.
+   */
+  protected def withNestedOrcDataFrame[T <: Product: ClassTag: TypeTag](data: Seq[T])
+      (runTest: (DataFrame, String, Any => Any) => Unit): Unit =
+    withNestedOrcDataFrame(spark.createDataFrame(data))(runTest)
+
+  protected def withNestedOrcDataFrame(inputDF: DataFrame)
+      (runTest: (DataFrame, String, Any => Any) => Unit): Unit = {
+    withNestedDataFrame(inputDF).foreach { case (newDF, colName, resultFun) =>
+      withTempPath { file =>
+        newDF.write.format(dataSourceName).save(file.getCanonicalPath)
+        readFile(file.getCanonicalPath, true) { df => runTest(df, colName, resultFun) }
+      }
+    }
+  }
 }
```
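A brief sketch of how an ORC filter test might drive the new helper; the data shape and the assertion are illustrative assumptions, not copied from this commit's test suites.

```scala
// Illustrative only: exercise an equality filter across all nesting variants.
withNestedOrcDataFrame((1 to 4).map(i => Tuple1(i))) { case (df, colName, resultFun) =>
  // df is already the dataframe read back from ORC; colName may be nested and/or quoted.
  checkAnswer(df.where(s"$colName = 3"), Seq(Row(resultFun(3))))
}
```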

sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala

Lines changed: 1 addition & 28 deletions
```diff
@@ -122,34 +122,7 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared

   private def withNestedParquetDataFrame(inputDF: DataFrame)
       (runTest: (DataFrame, String, Any => Any) => Unit): Unit = {
-    assert(inputDF.schema.fields.length == 1)
-    assert(!inputDF.schema.fields.head.dataType.isInstanceOf[StructType])
-    val df = inputDF.toDF("temp")
-    Seq(
-      (
-        df.withColumnRenamed("temp", "a"),
-        "a", // zero nesting
-        (x: Any) => x),
-      (
-        df.withColumn("a", struct(df("temp") as "b")).drop("temp"),
-        "a.b", // one level nesting
-        (x: Any) => Row(x)),
-      (
-        df.withColumn("a", struct(struct(df("temp") as "c") as "b")).drop("temp"),
-        "a.b.c", // two level nesting
-        (x: Any) => Row(Row(x))
-      ),
-      (
-        df.withColumnRenamed("temp", "a.b"),
-        "`a.b`", // zero nesting with column name containing `dots`
-        (x: Any) => x
-      ),
-      (
-        df.withColumn("a.b", struct(df("temp") as "c.d") ).drop("temp"),
-        "`a.b`.`c.d`", // one level nesting with column names containing `dots`
-        (x: Any) => Row(x)
-      )
-    ).foreach { case (newDF, colName, resultFun) =>
+    withNestedDataFrame(inputDF).foreach { case (newDF, colName, resultFun) =>
       withTempPath { file =>
         newDF.write.format(dataSourceName).save(file.getCanonicalPath)
         readParquetFile(file.getCanonicalPath) { df => runTest(df, colName, resultFun) }
```
sql/core/v1.2/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala

Lines changed: 12 additions & 16 deletions
```diff
@@ -27,7 +27,7 @@ import org.apache.orc.storage.serde2.io.HiveDecimalWritable

 import org.apache.spark.SparkException
 import org.apache.spark.sql.catalyst.util.DateTimeUtils.{instantToMicros, localDateToDays, toJavaDate, toJavaTimestamp}
-import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.quoteIfNeeded
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.sources.Filter
 import org.apache.spark.sql.types._

@@ -68,11 +68,9 @@ private[sql] object OrcFilters extends OrcFiltersBase {
    * Create ORC filter as a SearchArgument instance.
    */
   def createFilter(schema: StructType, filters: Seq[Filter]): Option[SearchArgument] = {
-    val dataTypeMap = schema.map(f => quoteIfNeeded(f.name) -> f.dataType).toMap
+    val dataTypeMap = OrcFilters.getSearchableTypeMap(schema, SQLConf.get.caseSensitiveAnalysis)
     // Combines all convertible filters using `And` to produce a single conjunction
-    // TODO (SPARK-25557): ORC doesn't support nested predicate pushdown, so they are removed.
-    val newFilters = filters.filter(!_.containsNestedColumn)
-    val conjunctionOptional = buildTree(convertibleFilters(schema, dataTypeMap, newFilters))
+    val conjunctionOptional = buildTree(convertibleFilters(schema, dataTypeMap, filters))
     conjunctionOptional.map { conjunction =>
       // Then tries to build a single ORC `SearchArgument` for the conjunction predicate.
       // The input predicate is fully convertible. There should not be any empty result in the
@@ -228,40 +226,38 @@ private[sql] object OrcFilters extends OrcFiltersBase {
     // NOTE: For all case branches dealing with leaf predicates below, the additional `startAnd()`
     // call is mandatory. ORC `SearchArgument` builder requires that all leaf predicates must be
     // wrapped by a "parent" predicate (`And`, `Or`, or `Not`).
-    // Since ORC 1.5.0 (ORC-323), we need to quote for column names with `.` characters
-    // in order to distinguish predicate pushdown for nested columns.
     expression match {
-      case EqualTo(name, value) if isSearchableType(dataTypeMap(name)) =>
+      case EqualTo(name, value) if dataTypeMap.contains(name) =>
         val castedValue = castLiteralValue(value, dataTypeMap(name))
         Some(builder.startAnd().equals(name, getType(name), castedValue).end())

-      case EqualNullSafe(name, value) if isSearchableType(dataTypeMap(name)) =>
+      case EqualNullSafe(name, value) if dataTypeMap.contains(name) =>
         val castedValue = castLiteralValue(value, dataTypeMap(name))
         Some(builder.startAnd().nullSafeEquals(name, getType(name), castedValue).end())

-      case LessThan(name, value) if isSearchableType(dataTypeMap(name)) =>
+      case LessThan(name, value) if dataTypeMap.contains(name) =>
         val castedValue = castLiteralValue(value, dataTypeMap(name))
         Some(builder.startAnd().lessThan(name, getType(name), castedValue).end())

-      case LessThanOrEqual(name, value) if isSearchableType(dataTypeMap(name)) =>
+      case LessThanOrEqual(name, value) if dataTypeMap.contains(name) =>
         val castedValue = castLiteralValue(value, dataTypeMap(name))
         Some(builder.startAnd().lessThanEquals(name, getType(name), castedValue).end())

-      case GreaterThan(name, value) if isSearchableType(dataTypeMap(name)) =>
+      case GreaterThan(name, value) if dataTypeMap.contains(name) =>
        val castedValue = castLiteralValue(value, dataTypeMap(name))
        Some(builder.startNot().lessThanEquals(name, getType(name), castedValue).end())

-      case GreaterThanOrEqual(name, value) if isSearchableType(dataTypeMap(name)) =>
+      case GreaterThanOrEqual(name, value) if dataTypeMap.contains(name) =>
        val castedValue = castLiteralValue(value, dataTypeMap(name))
        Some(builder.startNot().lessThan(name, getType(name), castedValue).end())

-      case IsNull(name) if isSearchableType(dataTypeMap(name)) =>
+      case IsNull(name) if dataTypeMap.contains(name) =>
        Some(builder.startAnd().isNull(name, getType(name)).end())

-      case IsNotNull(name) if isSearchableType(dataTypeMap(name)) =>
+      case IsNotNull(name) if dataTypeMap.contains(name) =>
        Some(builder.startNot().isNull(name, getType(name)).end())

-      case In(name, values) if isSearchableType(dataTypeMap(name)) =>
+      case In(name, values) if dataTypeMap.contains(name) =>
        val castedValues = values.map(v => castLiteralValue(v, dataTypeMap(name)))
        Some(builder.startAnd().in(name, getType(name),
          castedValues.map(_.asInstanceOf[AnyRef]): _*).end())
```
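The guard change above is what lets nested predicates through: `dataTypeMap` is now pre-filtered to searchable leaves keyed by their (possibly quoted) dotted paths, so `dataTypeMap.contains(name)` replaces the per-type `isSearchableType` check. A minimal standalone sketch of the same guard pattern against Spark's public `Filter` API (it only decides convertibility; the real code also builds the ORC `SearchArgument` and recurses through `And`/`Or`/`Not`):

```scala
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._

// Simplified convertibility check mirroring the guards above; illustration only.
def isConvertibleLeaf(filter: Filter, dataTypeMap: Map[String, DataType]): Boolean =
  filter match {
    case EqualTo(name, _) => dataTypeMap.contains(name)
    case EqualNullSafe(name, _) => dataTypeMap.contains(name)
    case LessThan(name, _) => dataTypeMap.contains(name)
    case LessThanOrEqual(name, _) => dataTypeMap.contains(name)
    case GreaterThan(name, _) => dataTypeMap.contains(name)
    case GreaterThanOrEqual(name, _) => dataTypeMap.contains(name)
    case IsNull(name) => dataTypeMap.contains(name)
    case IsNotNull(name) => dataTypeMap.contains(name)
    case In(name, _) => dataTypeMap.contains(name)
    case _ => false
  }

// With dataTypeMap = Map("id" -> IntegerType, "a.b" -> IntegerType):
//   isConvertibleLeaf(EqualTo("a.b", 1), dataTypeMap)     => true  (nested column, now pushed)
//   isConvertibleLeaf(EqualTo("missing", 1), dataTypeMap) => false (unknown or unsearchable)
```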
