diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index b5cf8c9515bf..9a7d4783222c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -432,53 +432,65 @@ object DataSourceStrategy {
    * @return a `Some[Filter]` if the input [[Expression]] is convertible, otherwise a `None`.
    */
   protected[sql] def translateFilter(predicate: Expression): Option[Filter] = {
+    // Recursively try to find an attribute name from the top level that can be pushed down.
+    def attrName(e: Expression): Option[String] = e match {
+      // In Spark and many data sources such as Parquet, dots are used as a column path delimiter;
+      // thus, we don't translate such expressions.
+      case a: Attribute if !a.name.contains(".") =>
+        Some(a.name)
+      case s: GetStructField if !s.childSchema(s.ordinal).name.contains(".") =>
+        attrName(s.child).map(_ + s".${s.childSchema(s.ordinal).name}")
+      case _ =>
+        None
+    }
+
     predicate match {
-      case expressions.EqualTo(a: Attribute, Literal(v, t)) =>
-        Some(sources.EqualTo(a.name, convertToScala(v, t)))
-      case expressions.EqualTo(Literal(v, t), a: Attribute) =>
-        Some(sources.EqualTo(a.name, convertToScala(v, t)))
-
-      case expressions.EqualNullSafe(a: Attribute, Literal(v, t)) =>
-        Some(sources.EqualNullSafe(a.name, convertToScala(v, t)))
-      case expressions.EqualNullSafe(Literal(v, t), a: Attribute) =>
-        Some(sources.EqualNullSafe(a.name, convertToScala(v, t)))
-
-      case expressions.GreaterThan(a: Attribute, Literal(v, t)) =>
-        Some(sources.GreaterThan(a.name, convertToScala(v, t)))
-      case expressions.GreaterThan(Literal(v, t), a: Attribute) =>
-        Some(sources.LessThan(a.name, convertToScala(v, t)))
-
-      case expressions.LessThan(a: Attribute, Literal(v, t)) =>
-        Some(sources.LessThan(a.name, convertToScala(v, t)))
-      case expressions.LessThan(Literal(v, t), a: Attribute) =>
-        Some(sources.GreaterThan(a.name, convertToScala(v, t)))
-
-      case expressions.GreaterThanOrEqual(a: Attribute, Literal(v, t)) =>
-        Some(sources.GreaterThanOrEqual(a.name, convertToScala(v, t)))
-      case expressions.GreaterThanOrEqual(Literal(v, t), a: Attribute) =>
-        Some(sources.LessThanOrEqual(a.name, convertToScala(v, t)))
-
-      case expressions.LessThanOrEqual(a: Attribute, Literal(v, t)) =>
-        Some(sources.LessThanOrEqual(a.name, convertToScala(v, t)))
-      case expressions.LessThanOrEqual(Literal(v, t), a: Attribute) =>
-        Some(sources.GreaterThanOrEqual(a.name, convertToScala(v, t)))
-
-      case expressions.InSet(a: Attribute, set) =>
-        val toScala = CatalystTypeConverters.createToScalaConverter(a.dataType)
-        Some(sources.In(a.name, set.toArray.map(toScala)))
+      case expressions.EqualTo(e: Expression, Literal(v, t)) =>
+        attrName(e).map(name => sources.EqualTo(name, convertToScala(v, t)))
+      case expressions.EqualTo(Literal(v, t), e: Expression) =>
+        attrName(e).map(name => sources.EqualTo(name, convertToScala(v, t)))
+
+      case expressions.EqualNullSafe(e: Expression, Literal(v, t)) =>
+        attrName(e).map(name => sources.EqualNullSafe(name, convertToScala(v, t)))
+      case expressions.EqualNullSafe(Literal(v, t), e: Expression) =>
+        attrName(e).map(name => sources.EqualNullSafe(name, convertToScala(v, t)))
+
+      case expressions.GreaterThan(e: Expression, Literal(v, t)) =>
+        attrName(e).map(name => sources.GreaterThan(name, convertToScala(v, t)))
+      case expressions.GreaterThan(Literal(v, t), e: Expression) =>
+        attrName(e).map(name => sources.LessThan(name, convertToScala(v, t)))
+
+      case expressions.LessThan(e: Expression, Literal(v, t)) =>
+        attrName(e).map(name => sources.LessThan(name, convertToScala(v, t)))
+      case expressions.LessThan(Literal(v, t), e: Expression) =>
+        attrName(e).map(name => sources.GreaterThan(name, convertToScala(v, t)))
+
+      case expressions.GreaterThanOrEqual(e: Expression, Literal(v, t)) =>
+        attrName(e).map(name => sources.GreaterThanOrEqual(name, convertToScala(v, t)))
+      case expressions.GreaterThanOrEqual(Literal(v, t), e: Expression) =>
+        attrName(e).map(name => sources.LessThanOrEqual(name, convertToScala(v, t)))
+
+      case expressions.LessThanOrEqual(e: Expression, Literal(v, t)) =>
+        attrName(e).map(name => sources.LessThanOrEqual(name, convertToScala(v, t)))
+      case expressions.LessThanOrEqual(Literal(v, t), e: Expression) =>
+        attrName(e).map(name => sources.GreaterThanOrEqual(name, convertToScala(v, t)))
+
+      case expressions.InSet(e: Expression, set) =>
+        val toScala = CatalystTypeConverters.createToScalaConverter(e.dataType)
+        attrName(e).map(name => sources.In(name, set.toArray.map(toScala)))
 
       // Because we only convert In to InSet in Optimizer when there are more than certain
       // items. So it is possible we still get an In expression here that needs to be pushed
       // down.
-      case expressions.In(a: Attribute, list) if !list.exists(!_.isInstanceOf[Literal]) =>
+      case expressions.In(e: Expression, list) if list.forall(_.isInstanceOf[Literal]) =>
         val hSet = list.map(e => e.eval(EmptyRow))
-        val toScala = CatalystTypeConverters.createToScalaConverter(a.dataType)
-        Some(sources.In(a.name, hSet.toArray.map(toScala)))
+        val toScala = CatalystTypeConverters.createToScalaConverter(e.dataType)
+        attrName(e).map(name => sources.In(name, hSet.toArray.map(toScala)))
 
-      case expressions.IsNull(a: Attribute) =>
-        Some(sources.IsNull(a.name))
-      case expressions.IsNotNull(a: Attribute) =>
-        Some(sources.IsNotNull(a.name))
+      case expressions.IsNull(e: Expression) =>
+        attrName(e).map(name => sources.IsNull(name))
+      case expressions.IsNotNull(e: Expression) =>
+        attrName(e).map(name => sources.IsNotNull(name))
 
       case expressions.And(left, right) =>
         // See SPARK-12218 for detailed discussion
@@ -504,14 +516,14 @@ object DataSourceStrategy {
       case expressions.Not(child) =>
        translateFilter(child).map(sources.Not)
 
-      case expressions.StartsWith(a: Attribute, Literal(v: UTF8String, StringType)) =>
-        Some(sources.StringStartsWith(a.name, v.toString))
+      case expressions.StartsWith(e: Expression, Literal(v: UTF8String, StringType)) =>
+        attrName(e).map(name => sources.StringStartsWith(name, v.toString))
 
-      case expressions.EndsWith(a: Attribute, Literal(v: UTF8String, StringType)) =>
-        Some(sources.StringEndsWith(a.name, v.toString))
+      case expressions.EndsWith(e: Expression, Literal(v: UTF8String, StringType)) =>
+        attrName(e).map(name => sources.StringEndsWith(name, v.toString))
 
-      case expressions.Contains(a: Attribute, Literal(v: UTF8String, StringType)) =>
-        Some(sources.StringContains(a.name, v.toString))
+      case expressions.Contains(e: Expression, Literal(v: UTF8String, StringType)) =>
+        attrName(e).map(name => sources.StringContains(name, v.toString))
 
       case _ => None
     }
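To make the intent of the new `attrName` recursion concrete, here is a small self-contained sketch of the same logic. It is a model, not Catalyst's real API: `Attr` and `GetField` stand in for `Attribute` and `GetStructField`, and the dot check mirrors the comment in the patch.

```scala
// Stand-ins for Catalyst's Attribute and GetStructField (illustrative only).
sealed trait Expr
case class Attr(name: String) extends Expr
case class GetField(child: Expr, fieldName: String) extends Expr

object AttrNameSketch {
  // Walk a chain of field accesses down to the root attribute, building a
  // dotted path; refuse any component whose own name contains a dot.
  def attrName(e: Expr): Option[String] = e match {
    case Attr(name) if !name.contains(".") => Some(name)
    case GetField(child, field) if !field.contains(".") =>
      attrName(child).map(_ + "." + field)
    case _ => None
  }

  def main(args: Array[String]): Unit = {
    println(attrName(GetField(GetField(Attr("b"), "c"), "cint"))) // Some(b.c.cint)
    println(attrName(GetField(Attr("a"), "x.y")))                 // None: dot in field name
    println(attrName(Attr("col.umn")))                            // None: dot in column name
  }
}
```

Any predicate whose column side resolves to `None` here falls through to the `case _ => None` branch above and is simply not pushed down.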
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala
index 21ab9c78e53d..0e57c7de9222 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala
@@ -27,7 +27,7 @@ import scala.collection.JavaConverters.asScalaBufferConverter
 import org.apache.parquet.filter2.predicate._
 import org.apache.parquet.filter2.predicate.FilterApi._
 import org.apache.parquet.io.api.Binary
-import org.apache.parquet.schema.{DecimalMetadata, MessageType, OriginalType, PrimitiveComparator}
+import org.apache.parquet.schema._
 import org.apache.parquet.schema.OriginalType._
 import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName
 import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._
@@ -54,7 +54,7 @@ private[parquet] class ParquetFilters(
    * @param fieldName field name in parquet file
    * @param fieldType field type related info in parquet file
    */
-  private case class ParquetField(
+  private case class ParquetPrimitiveField(
       fieldName: String,
       fieldType: ParquetSchemaType)
 
@@ -364,16 +364,46 @@ private[parquet] class ParquetFilters(
   /**
    * Returns a map, which contains parquet field name and data type, if predicate push down applies.
    */
-  private def getFieldMap(dataType: MessageType): Map[String, ParquetField] = {
-    // Here we don't flatten the fields in the nested schema but just look up through
-    // root fields. Currently, accessing to nested fields does not push down filters
-    // and it does not support to create filters for them.
-    val primitiveFields =
-      dataType.getFields.asScala.filter(_.isPrimitive).map(_.asPrimitiveType()).map { f =>
-        f.getName -> ParquetField(f.getName,
-          ParquetSchemaType(f.getOriginalType,
-            f.getPrimitiveTypeName, f.getTypeLength, f.getDecimalMetadata))
+  private def getFieldMap(dataType: MessageType): Map[String, ParquetPrimitiveField] = {
+    def canPushDownField(field: Type): Boolean = {
+      if (field.getName.contains(".")) {
+        // Parquet does not allow dots in the column name because dots are used as a column
+        // path delimiter. Since Parquet 1.8.2 (PARQUET-389), Parquet accepts filter
+        // predicates with missing columns. Parquet can return incorrect results when we push
+        // down filters for columns with dots in their names, so we do not push down such
+        // filters. See SPARK-20364.
+        false
+      } else {
+        field match {
+          case _: PrimitiveType => true
+          // Parquet only supports push-down for primitive types; as a result, Map and List
+          // types are filtered out. Note that when `g` is a `Struct`, `g.getOriginalType`
+          // is `null`; when `g` is a `Map`, it is `MAP`; and when `g` is a `List`, it
+          // is `LIST`.
+          case g: GroupType if g.getOriginalType == null => true
+          case _ => false
+        }
+      }
+    }
+
+    def getFieldMapHelper(
+        fields: Seq[Type],
+        baseName: Option[String] = None): Seq[(String, ParquetPrimitiveField)] = {
+      fields.filter(canPushDownField).flatMap { field =>
+        val name = baseName.map(_ + "." + field.getName).getOrElse(field.getName)
+        field match {
+          case p: PrimitiveType =>
+            val primitiveField = ParquetPrimitiveField(fieldName = name,
+              fieldType = ParquetSchemaType(p.getOriginalType,
+                p.getPrimitiveTypeName, p.getTypeLength, p.getDecimalMetadata))
+            Some((name, primitiveField))
+          case g: GroupType =>
+            getFieldMapHelper(g.getFields.asScala, Some(name))
+        }
       }
+    }
+
+    val primitiveFields = getFieldMapHelper(dataType.getFields.asScala)
     if (caseSensitive) {
       primitiveFields.toMap
     } else {
@@ -393,8 +423,7 @@ private[parquet] class ParquetFilters(
    * Converts data sources filters to Parquet filter predicates.
    */
   def createFilter(schema: MessageType, predicate: sources.Filter): Option[FilterPredicate] = {
-    val nameToParquetField = getFieldMap(schema)
-    createFilterHelper(nameToParquetField, predicate, canPartialPushDownConjuncts = true)
+    createFilterHelper(schema, predicate, canPartialPushDownConjuncts = true)
   }
 
   /**
@@ -407,9 +436,11 @@
    * @return the Parquet-native filter predicates that are eligible for pushdown.
    */
   private def createFilterHelper(
-      nameToParquetField: Map[String, ParquetField],
+      schema: MessageType,
       predicate: sources.Filter,
       canPartialPushDownConjuncts: Boolean): Option[FilterPredicate] = {
+    val nameToParquetField = getFieldMap(schema)
+
     // Decimal type must make sure that filter value's scale matched the file.
     // If doesn't matched, which would cause data corruption.
     def isDecimalMatched(value: Any, decimalMeta: DecimalMetadata): Boolean = value match {
@@ -442,13 +473,8 @@ private[parquet] class ParquetFilters(
       })
     }
 
-    // Parquet does not allow dots in the column name because dots are used as a column path
-    // delimiter. Since Parquet 1.8.2 (PARQUET-389), Parquet accepts the filter predicates
-    // with missing columns. The incorrect results could be got from Parquet when we push down
-    // filters for the column having dots in the names. Thus, we do not push down such filters.
-    // See SPARK-20364.
     def canMakeFilterOn(name: String, value: Any): Boolean = {
-      nameToParquetField.contains(name) && !name.contains(".") && valueCanMakeFilterOn(name, value)
+      nameToParquetField.contains(name) && valueCanMakeFilterOn(name, value)
     }
 
     // NOTE:
@@ -515,9 +541,9 @@
         // AND before hitting NOT or OR conditions, and in this case, the unsupported predicate
         // can be safely removed.
         val lhsFilterOption =
-          createFilterHelper(nameToParquetField, lhs, canPartialPushDownConjuncts)
+          createFilterHelper(schema, lhs, canPartialPushDownConjuncts)
         val rhsFilterOption =
-          createFilterHelper(nameToParquetField, rhs, canPartialPushDownConjuncts)
+          createFilterHelper(schema, rhs, canPartialPushDownConjuncts)
 
         (lhsFilterOption, rhsFilterOption) match {
           case (Some(lhsFilter), Some(rhsFilter)) => Some(FilterApi.and(lhsFilter, rhsFilter))
@@ -529,13 +555,13 @@ private[parquet] class ParquetFilters(
       case sources.Or(lhs, rhs) =>
         for {
           lhsFilter <-
-            createFilterHelper(nameToParquetField, lhs, canPartialPushDownConjuncts = false)
+            createFilterHelper(schema, lhs, canPartialPushDownConjuncts = false)
           rhsFilter <-
-            createFilterHelper(nameToParquetField, rhs, canPartialPushDownConjuncts = false)
+            createFilterHelper(schema, rhs, canPartialPushDownConjuncts = false)
         } yield FilterApi.or(lhsFilter, rhsFilter)
 
       case sources.Not(pred) =>
-        createFilterHelper(nameToParquetField, pred, canPartialPushDownConjuncts = false)
+        createFilterHelper(schema, pred, canPartialPushDownConjuncts = false)
           .map(FilterApi.not)
 
       case sources.In(name, values) if canMakeFilterOn(name, values.head)
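The schema side gets the mirror-image treatment: `getFieldMap` now flattens nested groups into the same dotted paths that `attrName` produces. The sketch below models that flattening outside Parquet; `Primitive` and `Group` are stand-ins for Parquet's `Type` hierarchy, and the Map/List `getOriginalType` check is omitted for brevity.

```scala
// Simplified model of getFieldMapHelper (not Parquet's real Type API).
sealed trait FieldType
case object Primitive extends FieldType
case class Group(fields: Seq[(String, FieldType)]) extends FieldType

object FieldMapSketch {
  // Collect dotted paths to primitive leaves, skipping any field whose
  // name contains a dot (see SPARK-20364 in the patch comments).
  def flatten(
      fields: Seq[(String, FieldType)],
      base: Option[String] = None): Seq[String] =
    fields.filterNot { case (name, _) => name.contains(".") }.flatMap {
      case (name, tpe) =>
        val path = base.map(_ + "." + name).getOrElse(name)
        tpe match {
          case Primitive    => Seq(path)
          case Group(inner) => flatten(inner, Some(path))
        }
    }

  def main(args: Array[String]): Unit = {
    val schema = Seq(
      "id" -> Primitive,
      "b" -> Group(Seq("c" -> Group(Seq("cint" -> Primitive)))),
      "a.b" -> Primitive) // dotted name: dropped
    println(flatten(schema)) // List(id, b.c.cint)
  }
}
```

Because both sides now speak in dotted paths, `canMakeFilterOn` no longer needs its own `!name.contains(".")` guard; the dot filtering happens once, while the map is built.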
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala
index f20aded169e4..e64e47874903 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala
@@ -22,76 +22,176 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.sources
 import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
 
 class DataSourceStrategySuite extends PlanTest with SharedSQLContext {
 
   test("translate simple expression") {
-    val attrInt = 'cint.int
-    val attrStr = 'cstr.string
+    val fields = StructField("cint", IntegerType, nullable = true) ::
+      StructField("cstr", StringType, nullable = true) :: Nil
 
-    testTranslateFilter(EqualTo(attrInt, 1), Some(sources.EqualTo("cint", 1)))
-    testTranslateFilter(EqualTo(1, attrInt), Some(sources.EqualTo("cint", 1)))
+    val attrNested1 = 'a.struct(StructType(fields))
+    val attrNested2 = 'b.struct(StructType(
+      StructField("c", StructType(fields), nullable = true) :: Nil))
+
+    val attrIntNested1 = GetStructField(attrNested1, 0, None)
+    val attrStrNested1 = GetStructField(attrNested1, 1, None)
+
+    val attrIntNested2 = GetStructField(GetStructField(attrNested2, 0, None), 0, None)
+    val attrStrNested2 = GetStructField(GetStructField(attrNested2, 0, None), 1, None)
+
+    Seq(('cint.int, 'cstr.string, "cint", "cstr"), // no nesting
+      (attrIntNested1, attrStrNested1, "a.cint", "a.cstr"), // one level nesting
+      (attrIntNested2, attrStrNested2, "b.c.cint", "b.c.cstr") // two level nesting
+    ).foreach { case (attrInt, attrStr, attrIntString, attrStrString) =>
+      testTranslateSimpleExpression(
+        attrInt, attrStr, attrIntString, attrStrString, isResultNone = false)
+    }
+  }
+
+  test("translate complex expression") {
+    val fields = StructField("cint", IntegerType, nullable = true) :: Nil
+
+    val attrNested1 = 'a.struct(StructType(fields))
+    val attrNested2 = 'b.struct(StructType(
+      StructField("c", StructType(fields), nullable = true) :: Nil))
+
+    val attrIntNested1 = GetStructField(attrNested1, 0, None)
+    val attrIntNested2 = GetStructField(GetStructField(attrNested2, 0, None), 0, None)
+
+    Seq(('cint.int, "cint"), // no nesting
+      (attrIntNested1, "a.cint"), // one level nesting
+      (attrIntNested2, "b.c.cint") // two level nesting
+    ).foreach { case (attrInt, attrIntString) =>
+      testTranslateComplexExpression(attrInt, attrIntString, isResultNone = false)
+    }
+  }
+
+  test("column name containing dot cannot be pushed down") {
+    val fieldsWithoutDot = StructField("cint", IntegerType, nullable = true) ::
+      StructField("cstr", StringType, nullable = true) :: Nil
+
+    val fieldsWithDot = StructField("column.cint", IntegerType, nullable = true) ::
+      StructField("column.cstr", StringType, nullable = true) :: Nil
+
+    val attrNested1 = 'a.struct(StructType(fieldsWithDot))
+    val attrIntNested1 = GetStructField(attrNested1, 0, None)
+    val attrStrNested1 = GetStructField(attrNested1, 1, None)
+
+    val attrNested2 = 'b.struct(StructType(
+      StructField("c", StructType(fieldsWithDot), nullable = true) :: Nil))
+    val attrIntNested2 = GetStructField(GetStructField(attrNested2, 0, None), 0, None)
+    val attrStrNested2 = GetStructField(GetStructField(attrNested2, 0, None), 1, None)
+
+    val attrNestedWithDotInTopLevel = Symbol("column.a").struct(StructType(fieldsWithoutDot))
+    val attrIntNested1WithDotInTopLevel = GetStructField(attrNestedWithDotInTopLevel, 0, None)
+    val attrStrNested1WithDotInTopLevel = GetStructField(attrNestedWithDotInTopLevel, 1, None)
+
+    Seq((Symbol("column.cint").int, Symbol("column.cstr").string), // no nesting
+      (attrIntNested1, attrStrNested1), // one level nesting
+      (attrIntNested1WithDotInTopLevel, attrStrNested1WithDotInTopLevel), // one level nesting
+      (attrIntNested2, attrStrNested2) // two level nesting
+    ).foreach { case (attrInt, attrStr) =>
+      testTranslateSimpleExpression(
+        attrInt, attrStr, "", "", isResultNone = true)
+      testTranslateComplexExpression(attrInt, "", isResultNone = true)
+    }
+  }
+
+  // `isResultNone` is set when testing an invalid input expression
+  // (one containing dots), which translates to None.
+  private def testTranslateSimpleExpression(
+      attrInt: Expression, attrStr: Expression,
+      attrIntString: String, attrStrString: String, isResultNone: Boolean): Unit = {
+
+    def result(result: sources.Filter): Option[sources.Filter] = {
+      if (isResultNone) {
+        None
+      } else {
+        Some(result)
+      }
+    }
+
+    testTranslateFilter(EqualTo(attrInt, 1), result(sources.EqualTo(attrIntString, 1)))
+    testTranslateFilter(EqualTo(1, attrInt), result(sources.EqualTo(attrIntString, 1)))
 
     testTranslateFilter(EqualNullSafe(attrStr, Literal(null)),
-      Some(sources.EqualNullSafe("cstr", null)))
+      result(sources.EqualNullSafe(attrStrString, null)))
     testTranslateFilter(EqualNullSafe(Literal(null), attrStr),
-      Some(sources.EqualNullSafe("cstr", null)))
+      result(sources.EqualNullSafe(attrStrString, null)))
 
-    testTranslateFilter(GreaterThan(attrInt, 1), Some(sources.GreaterThan("cint", 1)))
-    testTranslateFilter(GreaterThan(1, attrInt), Some(sources.LessThan("cint", 1)))
+    testTranslateFilter(GreaterThan(attrInt, 1), result(sources.GreaterThan(attrIntString, 1)))
+    testTranslateFilter(GreaterThan(1, attrInt), result(sources.LessThan(attrIntString, 1)))
 
-    testTranslateFilter(LessThan(attrInt, 1), Some(sources.LessThan("cint", 1)))
-    testTranslateFilter(LessThan(1, attrInt), Some(sources.GreaterThan("cint", 1)))
+    testTranslateFilter(LessThan(attrInt, 1), result(sources.LessThan(attrIntString, 1)))
+    testTranslateFilter(LessThan(1, attrInt), result(sources.GreaterThan(attrIntString, 1)))
 
-    testTranslateFilter(GreaterThanOrEqual(attrInt, 1), Some(sources.GreaterThanOrEqual("cint", 1)))
-    testTranslateFilter(GreaterThanOrEqual(1, attrInt), Some(sources.LessThanOrEqual("cint", 1)))
+    testTranslateFilter(GreaterThanOrEqual(attrInt, 1),
+      result(sources.GreaterThanOrEqual(attrIntString, 1)))
+    testTranslateFilter(GreaterThanOrEqual(1, attrInt),
+      result(sources.LessThanOrEqual(attrIntString, 1)))
 
-    testTranslateFilter(LessThanOrEqual(attrInt, 1), Some(sources.LessThanOrEqual("cint", 1)))
-    testTranslateFilter(LessThanOrEqual(1, attrInt), Some(sources.GreaterThanOrEqual("cint", 1)))
+    testTranslateFilter(LessThanOrEqual(attrInt, 1),
+      result(sources.LessThanOrEqual(attrIntString, 1)))
+    testTranslateFilter(LessThanOrEqual(1, attrInt),
+      result(sources.GreaterThanOrEqual(attrIntString, 1)))
 
-    testTranslateFilter(InSet(attrInt, Set(1, 2, 3)), Some(sources.In("cint", Array(1, 2, 3))))
+    testTranslateFilter(InSet(attrInt, Set(1, 2, 3)),
+      result(sources.In(attrIntString, Array(1, 2, 3))))
 
-    testTranslateFilter(In(attrInt, Seq(1, 2, 3)), Some(sources.In("cint", Array(1, 2, 3))))
+    testTranslateFilter(In(attrInt, Seq(1, 2, 3)),
+      result(sources.In(attrIntString, Array(1, 2, 3))))
 
-    testTranslateFilter(IsNull(attrInt), Some(sources.IsNull("cint")))
-    testTranslateFilter(IsNotNull(attrInt), Some(sources.IsNotNull("cint")))
+    testTranslateFilter(IsNull(attrInt), result(sources.IsNull(attrIntString)))
+    testTranslateFilter(IsNotNull(attrInt), result(sources.IsNotNull(attrIntString)))
 
-    // cint > 1 AND cint < 10
+    // attrInt > 1 AND attrInt < 10
     testTranslateFilter(And(
       GreaterThan(attrInt, 1),
       LessThan(attrInt, 10)),
-      Some(sources.And(
-        sources.GreaterThan("cint", 1),
-        sources.LessThan("cint", 10))))
+      result(sources.And(
+        sources.GreaterThan(attrIntString, 1),
+        sources.LessThan(attrIntString, 10))))
 
-    // cint >= 8 OR cint <= 2
+    // attrInt >= 8 OR attrInt <= 2
     testTranslateFilter(Or(
       GreaterThanOrEqual(attrInt, 8),
       LessThanOrEqual(attrInt, 2)),
-      Some(sources.Or(
-        sources.GreaterThanOrEqual("cint", 8),
-        sources.LessThanOrEqual("cint", 2))))
+      result(sources.Or(
+        sources.GreaterThanOrEqual(attrIntString, 8),
+        sources.LessThanOrEqual(attrIntString, 2))))
 
     testTranslateFilter(Not(GreaterThanOrEqual(attrInt, 8)),
-      Some(sources.Not(sources.GreaterThanOrEqual("cint", 8))))
+      result(sources.Not(sources.GreaterThanOrEqual(attrIntString, 8))))
 
-    testTranslateFilter(StartsWith(attrStr, "a"), Some(sources.StringStartsWith("cstr", "a")))
+    testTranslateFilter(StartsWith(attrStr, "a"),
+      result(sources.StringStartsWith(attrStrString, "a")))
 
-    testTranslateFilter(EndsWith(attrStr, "a"), Some(sources.StringEndsWith("cstr", "a")))
+    testTranslateFilter(EndsWith(attrStr, "a"), result(sources.StringEndsWith(attrStrString, "a")))
 
-    testTranslateFilter(Contains(attrStr, "a"), Some(sources.StringContains("cstr", "a")))
+    testTranslateFilter(Contains(attrStr, "a"), result(sources.StringContains(attrStrString, "a")))
   }
 
-  test("translate complex expression") {
-    val attrInt = 'cint.int
+  // `isResultNone` is set when testing an invalid input expression
+  // (one containing dots), which translates to None.
+  private def testTranslateComplexExpression(
+      attrInt: Expression, attrIntString: String, isResultNone: Boolean): Unit = {
+
+    def result(result: sources.Filter): Option[sources.Filter] = {
+      if (isResultNone) {
+        None
+      } else {
+        Some(result)
+      }
+    }
 
-    // ABS(cint) - 2 <= 1
+    // ABS(attrInt) - 2 <= 1
     testTranslateFilter(LessThanOrEqual(
       // Expressions are not supported
       // Functions such as 'Abs' are not supported
       Subtract(Abs(attrInt), 2), 1), None)
 
-    // (cin1 > 1 AND cint < 10) OR (cint > 50 AND cint > 100)
+    // (attrInt > 1 AND attrInt < 10) OR (attrInt > 50 AND attrInt > 100)
     testTranslateFilter(Or(
       And(
         GreaterThan(attrInt, 1),
@@ -100,16 +200,16 @@ class DataSourceStrategySuite extends PlanTest with SharedSQLContext {
         LessThan(attrInt, 10)),
       And(
         GreaterThan(attrInt, 50),
         LessThan(attrInt, 100))),
-      Some(sources.Or(
+      result(sources.Or(
         sources.And(
-          sources.GreaterThan("cint", 1),
-          sources.LessThan("cint", 10)),
+          sources.GreaterThan(attrIntString, 1),
+          sources.LessThan(attrIntString, 10)),
         sources.And(
-          sources.GreaterThan("cint", 50),
-          sources.LessThan("cint", 100)))))
+          sources.GreaterThan(attrIntString, 50),
+          sources.LessThan(attrIntString, 100)))))
 
     // SPARK-22548 Incorrect nested AND expression pushed down to JDBC data source
-    // (cint > 1 AND ABS(cint) < 10) OR (cint < 50 AND cint > 100)
+    // (attrInt > 1 AND ABS(attrInt) < 10) OR (attrInt < 50 AND attrInt > 100)
     testTranslateFilter(Or(
       And(
         GreaterThan(attrInt, 1),
@@ -120,7 +220,7 @@ class DataSourceStrategySuite extends PlanTest with SharedSQLContext {
         GreaterThan(attrInt, 50),
         LessThan(attrInt, 100))), None)
 
-    // NOT ((cint <= 1 OR ABS(cint) >= 10) AND (cint <= 50 OR cint >= 100))
+    // NOT ((attrInt <= 1 OR ABS(attrInt) >= 10) AND (attrInt <= 50 OR attrInt >= 100))
     testTranslateFilter(Not(And(
       Or(
         LessThanOrEqual(attrInt, 1),
@@ -131,7 +231,7 @@ class DataSourceStrategySuite extends PlanTest with SharedSQLContext {
         LessThanOrEqual(attrInt, 50),
         GreaterThanOrEqual(attrInt, 100)))), None)
 
-    // (cint = 1 OR cint = 10) OR (cint > 0 OR cint < -10)
+    // (attrInt = 1 OR attrInt = 10) OR (attrInt > 0 OR attrInt < -10)
     testTranslateFilter(Or(
       Or(
         EqualTo(attrInt, 1),
@@ -140,15 +240,15 @@ class DataSourceStrategySuite extends PlanTest with SharedSQLContext {
         EqualTo(attrInt, 10)),
       Or(
         GreaterThan(attrInt, 0),
         LessThan(attrInt, -10))),
-      Some(sources.Or(
+      result(sources.Or(
         sources.Or(
-          sources.EqualTo("cint", 1),
-          sources.EqualTo("cint", 10)),
+          sources.EqualTo(attrIntString, 1),
+          sources.EqualTo(attrIntString, 10)),
         sources.Or(
-          sources.GreaterThan("cint", 0),
-          sources.LessThan("cint", -10)))))
+          sources.GreaterThan(attrIntString, 0),
+          sources.LessThan(attrIntString, -10)))))
 
-    // (cint = 1 OR ABS(cint) = 10) OR (cint > 0 OR cint < -10)
+    // (attrInt = 1 OR ABS(attrInt) = 10) OR (attrInt > 0 OR attrInt < -10)
     testTranslateFilter(Or(
       Or(
         EqualTo(attrInt, 1),
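For reference while reading these hunks: `testTranslateFilter` is an existing helper in this suite that the diff does not touch. It behaves roughly like the sketch below (a paraphrase, not a quoted excerpt):

```scala
// Roughly what the suite's existing helper does: translate the Catalyst
// predicate and compare against the expected data source filter, which is
// None whenever translation is unsupported.
private def testTranslateFilter(expr: Expression, expected: Option[sources.Filter]): Unit = {
  assertResult(expected) {
    DataSourceStrategy.translateFilter(expr)
  }
}
```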
@@ -162,7 +262,7 @@ class DataSourceStrategySuite extends PlanTest with SharedSQLContext {
     // In end-to-end testing, conjunctive predicate should has been split
     // before reaching DataSourceStrategy.translateFilter.
     // This is for UT purpose to test each [[case]].
-    // (cint > 1 AND cint < 10) AND (cint = 6 AND cint IS NOT NULL)
+    // (attrInt > 1 AND attrInt < 10) AND (attrInt = 6 AND attrInt IS NOT NULL)
     testTranslateFilter(And(
       And(
         GreaterThan(attrInt, 1),
@@ -171,15 +271,15 @@ class DataSourceStrategySuite extends PlanTest with SharedSQLContext {
         LessThan(attrInt, 10)),
       And(
         EqualTo(attrInt, 6),
         IsNotNull(attrInt))),
-      Some(sources.And(
+      result(sources.And(
         sources.And(
-          sources.GreaterThan("cint", 1),
-          sources.LessThan("cint", 10)),
+          sources.GreaterThan(attrIntString, 1),
+          sources.LessThan(attrIntString, 10)),
         sources.And(
-          sources.EqualTo("cint", 6),
-          sources.IsNotNull("cint")))))
+          sources.EqualTo(attrIntString, 6),
+          sources.IsNotNull(attrIntString)))))
 
-    // (cint > 1 AND cint < 10) AND (ABS(cint) = 6 AND cint IS NOT NULL)
+    // (attrInt > 1 AND attrInt < 10) AND (ABS(attrInt) = 6 AND attrInt IS NOT NULL)
     testTranslateFilter(And(
       And(
         GreaterThan(attrInt, 1),
@@ -190,7 +290,7 @@ class DataSourceStrategySuite extends PlanTest with SharedSQLContext {
         EqualTo(Abs(attrInt), 6),
         IsNotNull(attrInt))), None)
 
-    // (cint > 1 OR cint < 10) AND (cint = 6 OR cint IS NOT NULL)
+    // (attrInt > 1 OR attrInt < 10) AND (attrInt = 6 OR attrInt IS NOT NULL)
     testTranslateFilter(And(
       Or(
         GreaterThan(attrInt, 1),
@@ -199,15 +299,15 @@ class DataSourceStrategySuite extends PlanTest with SharedSQLContext {
         LessThan(attrInt, 10)),
       Or(
         EqualTo(attrInt, 6),
         IsNotNull(attrInt))),
-      Some(sources.And(
+      result(sources.And(
         sources.Or(
-          sources.GreaterThan("cint", 1),
-          sources.LessThan("cint", 10)),
+          sources.GreaterThan(attrIntString, 1),
+          sources.LessThan(attrIntString, 10)),
         sources.Or(
-          sources.EqualTo("cint", 6),
-          sources.IsNotNull("cint")))))
+          sources.EqualTo(attrIntString, 6),
+          sources.IsNotNull(attrIntString)))))
 
-    // (cint > 1 OR cint < 10) AND (cint = 6 OR cint IS NOT NULL)
+    // (attrInt > 1 OR attrInt < 10) AND (attrInt = 6 OR attrInt IS NOT NULL)
     testTranslateFilter(And(
       Or(
         GreaterThan(attrInt, 1),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
index 9cfc943cd2b3..239c43e7e10b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
@@ -26,7 +26,7 @@ import org.apache.parquet.filter2.predicate.FilterApi._
 import org.apache.parquet.filter2.predicate.Operators.{Column => _, _}
 
 import org.apache.spark.SparkException
-import org.apache.spark.sql._
+import org.apache.spark.sql.{Column, _}
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.planning.PhysicalOperation
@@ -83,8 +83,6 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex
       filterClass: Class[_ <: FilterPredicate],
       checker: (DataFrame, Seq[Row]) => Unit,
       expected: Seq[Row]): Unit = {
-    val output = predicate.collect { case a: Attribute => a }.distinct
-
     withSQLConf(
       SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true",
       SQLConf.PARQUET_FILTER_PUSHDOWN_DATE_ENABLED.key -> "true",
@@ -93,7 +91,9 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex
       SQLConf.PARQUET_FILTER_PUSHDOWN_STRING_STARTSWITH_ENABLED.key -> "true",
       SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") {
       val query = df
-        .select(output.map(e => Column(e)): _*)
+        // The following select would flatten the nested data structure,
+        // so it is commented out for now until we find a better approach.
+        // .select(output.map(e => Column(e)): _*)
         .where(Column(predicate))
 
       var maybeRelation: Option[HadoopFsRelation] = None
@@ -150,7 +150,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex
     checkBinaryFilterPredicate(predicate, filterClass, Seq(Row(expected)))(df)
   }
 
-  private def testTimestampPushdown(data: Seq[Timestamp]): Unit = {
+  private def testTimestampPushDown(data: Seq[Timestamp]): Unit = {
     assert(data.size === 4)
     val ts1 = data.head
     val ts2 = data(1)
@@ -215,15 +215,44 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex
     }
   }
 
+  case class N1[T](a: Option[T])
+
+  case class N2[T](b: Option[T])
+
   test("filter pushdown - boolean") {
-    withParquetDataFrame((true :: false :: Nil).map(b => Tuple1.apply(Option(b)))) { implicit df =>
+    val data0 = (true :: false :: Nil).map(b => Tuple1.apply(Option(b)))
+    val data1 = data0.map(x => N1(Some(x)))
+    val data2 = data1.map(x => N2(Some(x)))
+
+    // zero nesting
+    withParquetDataFrame(data0) { implicit df =>
       checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row])
       checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], Seq(Row(true), Row(false)))
-
       checkFilterPredicate('_1 === true, classOf[Eq[_]], true)
       checkFilterPredicate('_1 <=> true, classOf[Eq[_]], true)
       checkFilterPredicate('_1 =!= true, classOf[NotEq[_]], false)
     }
+
+    // one level nesting
+    withParquetDataFrame(data1) { implicit df =>
+      val col = Symbol("a._1")
+      checkFilterPredicate(col.isNull, classOf[Eq[_]], Seq.empty[Row])
+      checkFilterPredicate(col.isNotNull, classOf[NotEq[_]], Seq(Row(Row(true)), Row(Row(false))))
+      checkFilterPredicate(col === true, classOf[Eq[_]], Seq(Row(Row(true))))
+      checkFilterPredicate(col <=> true, classOf[Eq[_]], Seq(Row(Row(true))))
+      checkFilterPredicate(col =!= true, classOf[NotEq[_]], Seq(Row(Row(false))))
+    }
+
+    // two level nesting
+    withParquetDataFrame(data2) { implicit df =>
+      val col = Symbol("b.a._1")
+      checkFilterPredicate(col.isNull, classOf[Eq[_]], Seq.empty[Row])
+      checkFilterPredicate(col.isNotNull, classOf[NotEq[_]],
+        Seq(Row(Row(Row(true))), Row(Row(Row(false)))))
+      checkFilterPredicate(col === true, classOf[Eq[_]], Seq(Row(Row(Row(true)))))
+      checkFilterPredicate(col <=> true, classOf[Eq[_]], Seq(Row(Row(Row(true)))))
+      checkFilterPredicate(col =!= true, classOf[NotEq[_]], Seq(Row(Row(Row(false)))))
+    }
   }
 
   test("filter pushdown - tinyint") {
@@ -498,7 +527,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex
       Timestamp.valueOf("2018-06-17 08:28:53.123"))
     withSQLConf(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key ->
         ParquetOutputTimestampType.TIMESTAMP_MILLIS.toString) {
-      testTimestampPushdown(millisData)
+      testTimestampPushDown(millisData)
     }
 
     // spark.sql.parquet.outputTimestampType = TIMESTAMP_MICROS
@@ -508,7 +537,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex
       Timestamp.valueOf("2018-06-17 08:28:53.123456"))
     withSQLConf(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key ->
         ParquetOutputTimestampType.TIMESTAMP_MICROS.toString) {
-      testTimestampPushdown(microsData)
+      testTimestampPushDown(microsData)
     }
 
     // spark.sql.parquet.outputTimestampType = INT96 doesn't support pushdown
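Taken together, the four files let a nested-column predicate travel all the way from the DataFrame API to a Parquet row-group filter. The end-to-end sketch below shows the kind of query this enables; the output path is made up, and the example assumes a Spark build with this patch applied.

```scala
import org.apache.spark.sql.SparkSession

object NestedPushDownExample {
  // A two-level nested schema, analogous to the suite's attrNested2:
  // b: struct<c: struct<cint: int>>
  case class Inner(cint: Int)
  case class Mid(c: Inner)
  case class Outer(b: Mid)

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("nested-filter-pushdown")
      .master("local[*]")
      .getOrCreate()
    import spark.implicits._

    val path = "/tmp/nested_pushdown_example" // hypothetical location
    Seq(Outer(Mid(Inner(1))), Outer(Mid(Inner(2))))
      .toDF().write.mode("overwrite").parquet(path)

    // With this patch, the predicate below translates to
    // sources.GreaterThan("b.c.cint", 1) and reaches ParquetFilters;
    // without it, translateFilter returns None and nothing is pushed down.
    val filtered = spark.read.parquet(path).where($"b.c.cint" > 1)
    filtered.explain() // PushedFilters should list GreaterThan(b.c.cint,1)
    filtered.show()

    spark.stop()
  }
}
```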