@@ -432,53 +432,65 @@ object DataSourceStrategy {
* @return a `Some[Filter]` if the input [[Expression]] is convertible, otherwise a `None`.
*/
protected[sql] def translateFilter(predicate: Expression): Option[Filter] = {
// Recursively try to find an attribute name from the top level that can be pushed down.
def attrName(e: Expression): Option[String] = e match {
// In Spark and many data sources such as Parquet, dots are used as a column path delimiter;
// thus, we don't translate such expressions.
case a: Attribute if !a.name.contains(".") =>
Some(a.name)
case s: GetStructField if !s.childSchema(s.ordinal).name.contains(".") =>
attrName(s.child).map(_ + s".${s.childSchema(s.ordinal).name}")
case _ =>
None
}

predicate match {
case expressions.EqualTo(a: Attribute, Literal(v, t)) =>
Some(sources.EqualTo(a.name, convertToScala(v, t)))
case expressions.EqualTo(Literal(v, t), a: Attribute) =>
Some(sources.EqualTo(a.name, convertToScala(v, t)))

case expressions.EqualNullSafe(a: Attribute, Literal(v, t)) =>
Some(sources.EqualNullSafe(a.name, convertToScala(v, t)))
case expressions.EqualNullSafe(Literal(v, t), a: Attribute) =>
Some(sources.EqualNullSafe(a.name, convertToScala(v, t)))

case expressions.GreaterThan(a: Attribute, Literal(v, t)) =>
Some(sources.GreaterThan(a.name, convertToScala(v, t)))
case expressions.GreaterThan(Literal(v, t), a: Attribute) =>
Some(sources.LessThan(a.name, convertToScala(v, t)))

case expressions.LessThan(a: Attribute, Literal(v, t)) =>
Some(sources.LessThan(a.name, convertToScala(v, t)))
case expressions.LessThan(Literal(v, t), a: Attribute) =>
Some(sources.GreaterThan(a.name, convertToScala(v, t)))

case expressions.GreaterThanOrEqual(a: Attribute, Literal(v, t)) =>
Some(sources.GreaterThanOrEqual(a.name, convertToScala(v, t)))
case expressions.GreaterThanOrEqual(Literal(v, t), a: Attribute) =>
Some(sources.LessThanOrEqual(a.name, convertToScala(v, t)))

case expressions.LessThanOrEqual(a: Attribute, Literal(v, t)) =>
Some(sources.LessThanOrEqual(a.name, convertToScala(v, t)))
case expressions.LessThanOrEqual(Literal(v, t), a: Attribute) =>
Some(sources.GreaterThanOrEqual(a.name, convertToScala(v, t)))

case expressions.InSet(a: Attribute, set) =>
val toScala = CatalystTypeConverters.createToScalaConverter(a.dataType)
Some(sources.In(a.name, set.toArray.map(toScala)))
case expressions.EqualTo(e: Expression, Literal(v, t)) =>
[Review comment] This PR will be a good performance improvement for Spark 2.5.0.

attrName(e).map(name => sources.EqualTo(name, convertToScala(v, t)))
case expressions.EqualTo(Literal(v, t), e: Expression) =>
attrName(e).map(name => sources.EqualTo(name, convertToScala(v, t)))

case expressions.EqualNullSafe(e: Expression, Literal(v, t)) =>
attrName(e).map(name => sources.EqualNullSafe(name, convertToScala(v, t)))
case expressions.EqualNullSafe(Literal(v, t), e: Expression) =>
attrName(e).map(name => sources.EqualNullSafe(name, convertToScala(v, t)))

case expressions.GreaterThan(e: Expression, Literal(v, t)) =>
attrName(e).map(name => sources.GreaterThan(name, convertToScala(v, t)))
case expressions.GreaterThan(Literal(v, t), e: Expression) =>
attrName(e).map(name => sources.LessThan(name, convertToScala(v, t)))

case expressions.LessThan(e: Expression, Literal(v, t)) =>
attrName(e).map(name => sources.LessThan(name, convertToScala(v, t)))
case expressions.LessThan(Literal(v, t), e: Expression) =>
attrName(e).map(name => sources.GreaterThan(name, convertToScala(v, t)))

case expressions.GreaterThanOrEqual(e: Expression, Literal(v, t)) =>
attrName(e).map(name => sources.GreaterThanOrEqual(name, convertToScala(v, t)))
case expressions.GreaterThanOrEqual(Literal(v, t), e: Expression) =>
attrName(e).map(name => sources.LessThanOrEqual(name, convertToScala(v, t)))

case expressions.LessThanOrEqual(e: Expression, Literal(v, t)) =>
attrName(e).map(name => sources.LessThanOrEqual(name, convertToScala(v, t)))
case expressions.LessThanOrEqual(Literal(v, t), e: Expression) =>
attrName(e).map(name => sources.GreaterThanOrEqual(name, convertToScala(v, t)))

case expressions.InSet(e: Expression, set) =>
val toScala = CatalystTypeConverters.createToScalaConverter(e.dataType)
attrName(e).map(name => sources.In(name, set.toArray.map(toScala)))

// Because the Optimizer only converts In to InSet when there are more than a certain
// number of items, it is possible we still get an In expression here that needs to be
// pushed down.
case expressions.In(a: Attribute, list) if !list.exists(!_.isInstanceOf[Literal]) =>
case expressions.In(e: Expression, list) if list.forall(_.isInstanceOf[Literal]) =>
val hSet = list.map(e => e.eval(EmptyRow))
val toScala = CatalystTypeConverters.createToScalaConverter(a.dataType)
Some(sources.In(a.name, hSet.toArray.map(toScala)))
val toScala = CatalystTypeConverters.createToScalaConverter(e.dataType)
attrName(e).map(name => sources.In(name, hSet.toArray.map(toScala)))
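// Illustrative (not from this diff): for a struct column `person`, a surviving
// predicate like $"person.name".isin("Alice", "Bob") would now translate to
// Some(sources.In("person.name", Array("Alice", "Bob"))) instead of None.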

case expressions.IsNull(a: Attribute) =>
Some(sources.IsNull(a.name))
case expressions.IsNotNull(a: Attribute) =>
Some(sources.IsNotNull(a.name))
case expressions.IsNull(e: Expression) =>
attrName(e).map(name => sources.IsNull(name))
case expressions.IsNotNull(e: Expression) =>
attrName(e).map(name => sources.IsNotNull(name))

case expressions.And(left, right) =>
// See SPARK-12218 for detailed discussion
@@ -504,14 +516,14 @@ object DataSourceStrategy {
case expressions.Not(child) =>
translateFilter(child).map(sources.Not)

case expressions.StartsWith(a: Attribute, Literal(v: UTF8String, StringType)) =>
Some(sources.StringStartsWith(a.name, v.toString))
case expressions.StartsWith(e: Expression, Literal(v: UTF8String, StringType)) =>
attrName(e).map(name => sources.StringStartsWith(name, v.toString))

case expressions.EndsWith(a: Attribute, Literal(v: UTF8String, StringType)) =>
Some(sources.StringEndsWith(a.name, v.toString))
case expressions.EndsWith(e: Expression, Literal(v: UTF8String, StringType)) =>
attrName(e).map(name => sources.StringEndsWith(name, v.toString))

case expressions.Contains(a: Attribute, Literal(v: UTF8String, StringType)) =>
Some(sources.StringContains(a.name, v.toString))
case expressions.Contains(e: Expression, Literal(v: UTF8String, StringType)) =>
attrName(e).map(name => sources.StringContains(name, v.toString))

case _ => None
}
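
To make the effect of the new attrName helper concrete, here is a minimal, hedged sketch (the dataset path, schema, object name, and values below are hypothetical, not from this PR): a filter on a nested struct field can now be translated into a data source Filter whose column name uses dot notation, instead of returning None and skipping pushdown.

import org.apache.spark.sql.SparkSession

object NestedPushdownExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("nested-pushdown-sketch").getOrCreate()
    // Hypothetical Parquet dataset with a struct column person: STRUCT<name: STRING, age: INT>.
    val people = spark.read.parquet("/tmp/people")
    // The nested comparison below is what translateFilter can now convert:
    // EqualTo(GetStructField(person, 0), Literal("Alice")) becomes
    // sources.EqualTo("person.name", "Alice"), which ParquetFilters can then push down.
    people.filter(people.col("person.name") === "Alice").show()
    spark.stop()
  }
}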
@@ -27,7 +27,7 @@ import scala.collection.JavaConverters.asScalaBufferConverter
import org.apache.parquet.filter2.predicate._
import org.apache.parquet.filter2.predicate.FilterApi._
import org.apache.parquet.io.api.Binary
import org.apache.parquet.schema.{DecimalMetadata, MessageType, OriginalType, PrimitiveComparator}
import org.apache.parquet.schema._
import org.apache.parquet.schema.OriginalType._
import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName
import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._
@@ -54,7 +54,7 @@ private[parquet] class ParquetFilters(
* @param fieldName field name in parquet file
* @param fieldType field type related info in parquet file
*/
private case class ParquetField(
private case class ParquetPrimitiveField(
fieldName: String,
fieldType: ParquetSchemaType)

@@ -364,16 +364,46 @@ private[parquet] class ParquetFilters(
/**
* Returns a map from Parquet field name to its field type info, used when predicate push-down applies.
*/
private def getFieldMap(dataType: MessageType): Map[String, ParquetField] = {
// Here we don't flatten the fields in the nested schema but just look through the
// root fields. Currently, accessing nested fields does not push down filters,
// and creating filters for them is not supported.
val primitiveFields =
dataType.getFields.asScala.filter(_.isPrimitive).map(_.asPrimitiveType()).map { f =>
f.getName -> ParquetField(f.getName,
ParquetSchemaType(f.getOriginalType,
f.getPrimitiveTypeName, f.getTypeLength, f.getDecimalMetadata))
private def getFieldMap(dataType: MessageType): Map[String, ParquetPrimitiveField] = {
def canPushDownField(field: Type): Boolean = {
if (field.getName.contains(".")) {
// Parquet does not allow dots in the column name because dots are used as a column path
// delimiter. Since Parquet 1.8.2 (PARQUET-389), Parquet accepts filter predicates
// with missing columns, so Parquet could return incorrect results when we push down
// filters for columns whose names contain dots. Thus, we do not push down such filters.
// See SPARK-20364.
false
} else {
field match {
case _: PrimitiveType => true
// Parquet only supports predicate push-down for primitive types; as a result, Map and
// List types are filtered out. Note that when g is a `Struct`, `g.getOriginalType` is
// `null`; when g is a `Map`, it is `MAP`; and when g is a `List`, it is `LIST`.
case g: GroupType if g.getOriginalType == null => true
case _ => false
}
}
}

def getFieldMapHelper(
fields: Seq[Type],
baseName: Option[String] = None): Seq[(String, ParquetPrimitiveField)] = {
fields.filter(canPushDownField).flatMap { field =>
val name = baseName.map(_ + "." + field.getName).getOrElse(field.getName)
field match {
case p: PrimitiveType =>
val primitiveField = ParquetPrimitiveField(fieldName = name,
fieldType = ParquetSchemaType(p.getOriginalType,
p.getPrimitiveTypeName, p.getTypeLength, p.getDecimalMetadata))
Some((name, primitiveField))
case g: GroupType =>
getFieldMapHelper(g.getFields.asScala, Some(name))
}
}
}
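// Illustrative example (not part of this diff): given a Parquet schema like
//   message spark_schema {
//     required int32 id;
//     required group person { optional binary name (UTF8); optional group tags (LIST) {} }
//   }
// getFieldMapHelper would yield entries for "id" and "person.name"; the LIST group
// `tags` fails canPushDownField and is skipped, as would any column whose name
// literally contains a dot.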

val primitiveFields = getFieldMapHelper(dataType.getFields.asScala)
if (caseSensitive) {
primitiveFields.toMap
} else {
@@ -393,8 +423,7 @@ private[parquet] class ParquetFilters(
* Converts data sources filters to Parquet filter predicates.
*/
def createFilter(schema: MessageType, predicate: sources.Filter): Option[FilterPredicate] = {
val nameToParquetField = getFieldMap(schema)
createFilterHelper(nameToParquetField, predicate, canPartialPushDownConjuncts = true)
createFilterHelper(schema, predicate, canPartialPushDownConjuncts = true)
}

/**
@@ -407,9 +436,11 @@ private[parquet] class ParquetFilters(
* @return the Parquet-native filter predicates that are eligible for pushdown.
*/
private def createFilterHelper(
nameToParquetField: Map[String, ParquetField],
schema: MessageType,
predicate: sources.Filter,
canPartialPushDownConjuncts: Boolean): Option[FilterPredicate] = {
val nameToParquetField = getFieldMap(schema)

// For decimal types, the filter value's scale must match the scale recorded in the file;
// if it does not match, pushing the filter down would cause incorrect results.
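// For example (illustrative), a java.math.BigDecimal filter value with scale 2 matches
// a DECIMAL(10, 2) Parquet column but not a DECIMAL(10, 4) one, so no Parquet filter is
// built in the latter case.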
def isDecimalMatched(value: Any, decimalMeta: DecimalMetadata): Boolean = value match {
@@ -442,13 +473,8 @@ private[parquet] class ParquetFilters(
})
}

// Parquet does not allow dots in the column name because dots are used as a column path
// delimiter. Since Parquet 1.8.2 (PARQUET-389), Parquet accepts the filter predicates
// with missing columns. The incorrect results could be got from Parquet when we push down
// filters for the column having dots in the names. Thus, we do not push down such filters.
// See SPARK-20364.
def canMakeFilterOn(name: String, value: Any): Boolean = {
nameToParquetField.contains(name) && !name.contains(".") && valueCanMakeFilterOn(name, value)
nameToParquetField.contains(name) && valueCanMakeFilterOn(name, value)
}

// NOTE:
@@ -515,9 +541,9 @@ private[parquet] class ParquetFilters(
// AND before hitting NOT or OR conditions, and in this case, the unsupported predicate
// can be safely removed.
val lhsFilterOption =
createFilterHelper(nameToParquetField, lhs, canPartialPushDownConjuncts)
createFilterHelper(schema, lhs, canPartialPushDownConjuncts)
val rhsFilterOption =
createFilterHelper(nameToParquetField, rhs, canPartialPushDownConjuncts)
createFilterHelper(schema, rhs, canPartialPushDownConjuncts)

(lhsFilterOption, rhsFilterOption) match {
case (Some(lhsFilter), Some(rhsFilter)) => Some(FilterApi.and(lhsFilter, rhsFilter))
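// Illustrative (not from this diff): in sources.And(EqualTo("person.name", "Alice"), f)
// where f cannot be converted (e.g. its column is not in nameToParquetField), the right
// side yields None; because canPartialPushDownConjuncts is true at the top level, only
// the convertible conjunct is pushed rather than dropping the whole AND.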
@@ -529,13 +555,13 @@ private[parquet] class ParquetFilters(
case sources.Or(lhs, rhs) =>
for {
lhsFilter <-
createFilterHelper(nameToParquetField, lhs, canPartialPushDownConjuncts = false)
createFilterHelper(schema, lhs, canPartialPushDownConjuncts = false)
rhsFilter <-
createFilterHelper(nameToParquetField, rhs, canPartialPushDownConjuncts = false)
createFilterHelper(schema, rhs, canPartialPushDownConjuncts = false)
} yield FilterApi.or(lhsFilter, rhsFilter)

case sources.Not(pred) =>
createFilterHelper(nameToParquetField, pred, canPartialPushDownConjuncts = false)
createFilterHelper(schema, pred, canPartialPushDownConjuncts = false)
.map(FilterApi.not)

case sources.In(name, values) if canMakeFilterOn(name, values.head)