Commit 6d22666

Extending ParquetFilters

1 parent 93e8192 commit 6d22666

4 files changed: +230 −76 lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala

Lines changed: 0 additions & 6 deletions

@@ -174,12 +174,6 @@ case class GreaterThanOrEqual(left: Expression, right: Expression) extends BinaryComparison
   override def eval(input: Row): Any = c2(input, left, right, _.gteq(_, _))
 }
 
-// A simple filter condition on a single column
-/*case class ColumnFilterPredicate(val comparison: BinaryComparison) extends BinaryComparison {
-  override def eval(input: Row): Any = comparison.eval(input)
-
-} */
-
 case class If(predicate: Expression, trueValue: Expression, falseValue: Expression)
   extends Expression {
sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala

Lines changed: 164 additions & 56 deletions

@@ -17,42 +17,125 @@
 
 package org.apache.spark.sql.parquet
 
+import org.apache.hadoop.conf.Configuration
+
 import parquet.filter._
+import parquet.filter.ColumnPredicates._
 import parquet.column.ColumnReader
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.Equals
-import org.apache.spark.sql.execution.SparkSqlSerializer
-import org.apache.hadoop.conf.Configuration
-import org.apache.spark.sql.catalyst.types.{IntegerType, BooleanType, NativeType}
-import scala.reflect.runtime.universe.{typeTag, TypeTag}
-import scala.reflect.ClassTag
+
 import com.google.common.io.BaseEncoding
-import parquet.filter
-import parquet.filter.ColumnPredicates.BooleanPredicateFunction
 
-// Implicits
-import collection.JavaConversions._
+import org.apache.spark.sql.catalyst.types._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.execution.SparkSqlSerializer
 
 object ParquetFilters {
   val PARQUET_FILTER_DATA = "org.apache.spark.sql.parquet.row.filter"
 
   def createFilter(filterExpressions: Seq[Expression]): UnboundRecordFilter = {
     def createEqualityFilter(name: String, literal: Literal) = literal.dataType match {
-      case BooleanType => new ComparisonFilter(name, literal.value.asInstanceOf[Boolean])
-      case IntegerType => new ComparisonFilter(name, _ == literal.value.asInstanceOf[Int])
+      case BooleanType =>
+        ComparisonFilter.createBooleanFilter(name, literal.value.asInstanceOf[Boolean])
+      case IntegerType =>
+        ComparisonFilter.createIntFilter(name, (x: Int) => x == literal.value.asInstanceOf[Int])
+      case LongType =>
+        ComparisonFilter.createLongFilter(name, (x: Long) => x == literal.value.asInstanceOf[Long])
+      case DoubleType =>
+        ComparisonFilter.createDoubleFilter(
+          name,
+          (x: Double) => x == literal.value.asInstanceOf[Double])
+      case FloatType =>
+        ComparisonFilter.createFloatFilter(
+          name,
+          (x: Float) => x == literal.value.asInstanceOf[Float])
+      case StringType =>
+        ComparisonFilter.createStringFilter(name, literal.value.asInstanceOf[String])
     }
-
-    val filters: Seq[UnboundRecordFilter] = filterExpressions.map {
-      case Equals(left: Literal, right: NamedExpression) => {
-        val name: String = right.name
-        createEqualityFilter(name, left)
-      }
-      case Equals(left: NamedExpression, right: Literal) => {
-        val name: String = left.name
-        createEqualityFilter(name, right)
-      }
+    def createLessThanFilter(name: String, literal: Literal) = literal.dataType match {
+      case IntegerType =>
+        ComparisonFilter.createIntFilter(name, (x: Int) => x < literal.value.asInstanceOf[Int])
+      case LongType =>
+        ComparisonFilter.createLongFilter(name, (x: Long) => x < literal.value.asInstanceOf[Long])
+      case DoubleType =>
+        ComparisonFilter.createDoubleFilter(
+          name,
+          (x: Double) => x < literal.value.asInstanceOf[Double])
+      case FloatType =>
+        ComparisonFilter.createFloatFilter(
+          name,
+          (x: Float) => x < literal.value.asInstanceOf[Float])
     }
-
+    def createLessThanOrEqualFilter(name: String, literal: Literal) = literal.dataType match {
+      case IntegerType =>
+        ComparisonFilter.createIntFilter(name, (x: Int) => x <= literal.value.asInstanceOf[Int])
+      case LongType =>
+        ComparisonFilter.createLongFilter(name, (x: Long) => x <= literal.value.asInstanceOf[Long])
+      case DoubleType =>
+        ComparisonFilter.createDoubleFilter(
+          name,
+          (x: Double) => x <= literal.value.asInstanceOf[Double])
+      case FloatType =>
+        ComparisonFilter.createFloatFilter(
+          name,
+          (x: Float) => x <= literal.value.asInstanceOf[Float])
+    }
+    // TODO: combine these two types somehow?
+    def createGreaterThanFilter(name: String, literal: Literal) = literal.dataType match {
+      case IntegerType =>
+        ComparisonFilter.createIntFilter(name, (x: Int) => x > literal.value.asInstanceOf[Int])
+      case LongType =>
+        ComparisonFilter.createLongFilter(name, (x: Long) => x > literal.value.asInstanceOf[Long])
+      case DoubleType =>
+        ComparisonFilter.createDoubleFilter(
+          name,
+          (x: Double) => x > literal.value.asInstanceOf[Double])
+      case FloatType =>
+        ComparisonFilter.createFloatFilter(
+          name,
+          (x: Float) => x > literal.value.asInstanceOf[Float])
+    }
+    def createGreaterThanOrEqualFilter(name: String, literal: Literal) = literal.dataType match {
+      case IntegerType =>
+        ComparisonFilter.createIntFilter(name, (x: Int) => x >= literal.value.asInstanceOf[Int])
+      case LongType =>
+        ComparisonFilter.createLongFilter(name, (x: Long) => x >= literal.value.asInstanceOf[Long])
+      case DoubleType =>
+        ComparisonFilter.createDoubleFilter(
+          name,
+          (x: Double) => x >= literal.value.asInstanceOf[Double])
+      case FloatType =>
+        ComparisonFilter.createFloatFilter(
+          name,
+          (x: Float) => x >= literal.value.asInstanceOf[Float])
+    }
+    // TODO: can we actually rely on the predicate being normalized as in expression < literal?
+    // That would simplify this pattern matching
+    // TODO: we currently only filter on non-nullable (Parquet REQUIRED) attributes until
+    // https://github.com/Parquet/parquet-mr/issues/371
+    // has been resolved
+    val filters: Seq[UnboundRecordFilter] = filterExpressions.collect {
+      case Equals(left: Literal, right: NamedExpression) if !right.nullable =>
+        createEqualityFilter(right.name, left)
+      case Equals(left: NamedExpression, right: Literal) if !left.nullable =>
+        createEqualityFilter(left.name, right)
+      case LessThan(left: Literal, right: NamedExpression) if !right.nullable =>
+        createLessThanFilter(right.name, left)
+      case LessThan(left: NamedExpression, right: Literal) if !left.nullable =>
+        createLessThanFilter(left.name, right)
+      case LessThanOrEqual(left: Literal, right: NamedExpression) if !right.nullable =>
+        createLessThanOrEqualFilter(right.name, left)
+      case LessThanOrEqual(left: NamedExpression, right: Literal) if !left.nullable =>
+        createLessThanOrEqualFilter(left.name, right)
+      case GreaterThan(left: Literal, right: NamedExpression) if !right.nullable =>
+        createGreaterThanFilter(right.name, left)
+      case GreaterThan(left: NamedExpression, right: Literal) if !left.nullable =>
+        createGreaterThanFilter(left.name, right)
+      case GreaterThanOrEqual(left: Literal, right: NamedExpression) if !right.nullable =>
+        createGreaterThanOrEqualFilter(right.name, left)
+      case GreaterThanOrEqual(left: NamedExpression, right: Literal) if !left.nullable =>
+        createGreaterThanOrEqualFilter(left.name, right)
+    }
+    // TODO: How about disjunctions? (Or-ed)
     if (filters.length > 0) filters.reduce(AndRecordFilter.and) else null
   }
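Note that the `collect` above silently drops anything it cannot push down. As a reader aid, here is a minimal Scala sketch of which Catalyst predicates would (and would not) produce a record filter; the column name and its non-nullable schema are invented for illustration, and the constructor signatures are assumed to match this revision of catalyst:

    // Hypothetical example; `myint` is not part of this commit.
    import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Equals, GreaterThan, Literal}
    import org.apache.spark.sql.catalyst.types.IntegerType

    val myint = AttributeReference("myint", IntegerType, nullable = false)()

    // Converted: a non-nullable named attribute compared against a literal.
    // The two resulting filters are AND-ed by the reduce at the end of createFilter.
    val pushed = ParquetFilters.createFilter(Seq(
      Equals(myint, Literal(5, IntegerType)),
      GreaterThan(myint, Literal(3, IntegerType))))

    // Not converted (falls through `collect`), so createFilter returns null here:
    // an attribute-to-attribute comparison carries no literal to compare against.
    val dropped = ParquetFilters.createFilter(Seq(GreaterThan(myint, myint)))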

@@ -83,47 +166,72 @@ class ComparisonFilter(
     private val columnName: String,
     private var filter: UnboundRecordFilter)
   extends UnboundRecordFilter {
-  def this(columnName: String, value: Boolean) =
-    this(
+  override def bind(readers: java.lang.Iterable[ColumnReader]): RecordFilter = {
+    filter.bind(readers)
+  }
+}
+
+object ComparisonFilter {
+  def createBooleanFilter(columnName: String, value: Boolean): UnboundRecordFilter =
+    new ComparisonFilter(
       columnName,
       ColumnRecordFilter.column(
         columnName,
         ColumnPredicates.applyFunctionToBoolean(
-          new ColumnPredicates.BooleanPredicateFunction {
+          new BooleanPredicateFunction {
             def functionToApply(input: Boolean): Boolean = input == value
-          })))
-  def this(columnName: String, func: Int => Boolean) =
-    this(
+          }
+        )))
+  def createStringFilter(columnName: String, value: String): UnboundRecordFilter =
+    new ComparisonFilter(
+      columnName,
+      ColumnRecordFilter.column(
+        columnName,
+        ColumnPredicates.applyFunctionToString (
+          new ColumnPredicates.PredicateFunction[String] {
+            def functionToApply(input: String): Boolean = input == value
+          }
+        )))
+  def createIntFilter(columnName: String, func: Int => Boolean): UnboundRecordFilter =
+    new ComparisonFilter(
       columnName,
       ColumnRecordFilter.column(
         columnName,
         ColumnPredicates.applyFunctionToInteger(
-          new ColumnPredicates.IntegerPredicateFunction {
-            def functionToApply(input: Int) = if (input != null) func(input) else false
-          })))
-  override def bind(readers: java.lang.Iterable[ColumnReader]): RecordFilter = {
-    filter.bind(readers)
-  }
+          new IntegerPredicateFunction {
+            def functionToApply(input: Int) = func(input)
+          }
+        )))
+  def createLongFilter(columnName: String, func: Long => Boolean): UnboundRecordFilter =
+    new ComparisonFilter(
+      columnName,
+      ColumnRecordFilter.column(
+        columnName,
+        ColumnPredicates.applyFunctionToLong(
+          new LongPredicateFunction {
+            def functionToApply(input: Long) = func(input)
+          }
+        )))
+  def createDoubleFilter(columnName: String, func: Double => Boolean): UnboundRecordFilter =
+    new ComparisonFilter(
+      columnName,
+      ColumnRecordFilter.column(
+        columnName,
+        ColumnPredicates.applyFunctionToDouble(
+          new DoublePredicateFunction {
+            def functionToApply(input: Double) = func(input)
+          }
+        )))
+  def createFloatFilter(columnName: String, func: Float => Boolean): UnboundRecordFilter =
+    new ComparisonFilter(
+      columnName,
+      ColumnRecordFilter.column(
+        columnName,
+        ColumnPredicates.applyFunctionToFloat(
+          new FloatPredicateFunction {
+            def functionToApply(input: Float) = func(input)
+          }
+        )))
 }
 
-/*class EqualityFilter(
-  private val columnName: String,
-  private var filter: UnboundRecordFilter)
-  extends UnboundRecordFilter {
-  def this(columnName: String, value: Boolean) =
-    this(columnName, ColumnRecordFilter.column(columnName, ColumnPredicates.equalTo(value)))
-  def this(columnName: String, value: Int) =
-    this(columnName, ColumnRecordFilter.column(columnName, ColumnPredicates.equalTo(value)))
-  def this(columnName: String, value: Long) =
-    this(columnName, ColumnRecordFilter.column(columnName, ColumnPredicates.equalTo(value)))
-  def this(columnName: String, value: Double) =
-    this(columnName, ColumnRecordFilter.column(columnName, ColumnPredicates.equalTo(value)))
-  def this(columnName: String, value: Float) =
-    this(columnName, ColumnRecordFilter.column(columnName, ColumnPredicates.equalTo(value)))
-  def this(columnName: String, value: String) =
-    this(columnName, ColumnRecordFilter.column(columnName, ColumnPredicates.equalTo(value)))
-  override def bind(readers: java.lang.Iterable[ColumnReader]): RecordFilter = {
-    filter.bind(readers)
-  }
-}*/
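The new factory methods can also be exercised on their own, which makes the wrapping easier to see. A small sketch, assuming only the parquet.filter API already imported in this file (column names invented):

    import parquet.filter.{AndRecordFilter, UnboundRecordFilter}

    // A range predicate on a required int32 column, built from the new factories.
    val myintFilter: UnboundRecordFilter =
      ComparisonFilter.createIntFilter("myint", (x: Int) => x >= 4 && x <= 5)

    // Filters compose the same way createFilter reduces them: via AndRecordFilter.and.
    val combined: UnboundRecordFilter =
      AndRecordFilter.and(myintFilter, ComparisonFilter.createStringFilter("mystring", "abc"))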

sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala

Lines changed: 53 additions & 14 deletions

@@ -27,6 +27,29 @@ import parquet.schema.{MessageType, MessageTypeParser}
 
 import org.apache.spark.sql.catalyst.expressions.GenericRow
 import org.apache.spark.util.Utils
+import parquet.hadoop.metadata.CompressionCodecName
+import parquet.hadoop.api.WriteSupport
+import parquet.example.data.{GroupWriter, Group}
+import parquet.io.api.RecordConsumer
+import parquet.hadoop.api.WriteSupport.WriteContext
+import parquet.example.data.simple.SimpleGroup
+
+// Write support class for nested groups:
+// ParquetWriter initializes GroupWriteSupport with an empty configuration
+// (it is after all not intended to be used in this way?)
+// and members are private so we need to make our own
+private class TestGroupWriteSupport(schema: MessageType) extends WriteSupport[Group] {
+  var groupWriter: GroupWriter = null
+  override def prepareForWrite(recordConsumer: RecordConsumer): Unit = {
+    groupWriter = new GroupWriter(recordConsumer, schema)
+  }
+  override def init(configuration: Configuration): WriteContext = {
+    new WriteContext(schema, new java.util.HashMap[String, String]())
+  }
+  override def write(record: Group) {
+    groupWriter.write(record)
+  }
+}
 
 private[sql] object ParquetTestData {
 
@@ -75,26 +98,42 @@ private[sql] object ParquetTestData {
     val configuration: Configuration = ContextUtil.getConfiguration(job)
     val schema: MessageType = MessageTypeParser.parseMessageType(testSchema)
 
-    val writeSupport = new RowWriteSupport()
-    writeSupport.setSchema(schema, configuration)
-    val writer = new ParquetWriter(path, writeSupport)
+    //val writeSupport = new MutableRowWriteSupport()
+    //writeSupport.setSchema(schema, configuration)
+    //val writer = new ParquetWriter(path, writeSupport)
+    val writeSupport = new TestGroupWriteSupport(schema)
+    //val writer = //new ParquetWriter[Group](path, writeSupport)
+    val writer = new ParquetWriter[Group](path, writeSupport)
+
     for(i <- 0 until 15) {
-      val data = new Array[Any](6)
+      val record = new SimpleGroup(schema)
+      //val data = new Array[Any](6)
       if (i % 3 == 0) {
-        data.update(0, true)
+        //data.update(0, true)
+        record.add(0, true)
       } else {
-        data.update(0, false)
+        //data.update(0, false)
+        record.add(0, false)
      }
-      //if (i % 5 == 0) {
+      if (i % 5 == 0) {
+        record.add(1, 5)
       // data.update(1, 5)
-      //} else {
-      data.update(1, null) // optional
+      } else {
+        if (i % 5 == 1) record.add(1, 4)
+      }
+      //else {
+      // data.update(1, null) // optional
      //}
-      data.update(2, "abc")
-      data.update(3, i.toLong << 33)
-      data.update(4, 2.5F)
-      data.update(5, 4.5D)
-      writer.write(new GenericRow(data.toArray))
+      //data.update(2, "abc")
+      record.add(2, "abc")
+      //data.update(3, i.toLong << 33)
+      record.add(3, i.toLong << 33)
+      //data.update(4, 2.5F)
+      record.add(4, 2.5F)
+      //data.update(5, 4.5D)
+      record.add(5, 4.5D)
+      //writer.write(new GenericRow(data.toArray))
+      writer.write(record)
    }
    writer.close()
  }
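Because TestGroupWriteSupport bypasses GroupWriteSupport's configuration-based schema lookup, it can be used stand-alone (within this file, since the class is private). A minimal sketch of writing one record through it; the schema string and output path are invented for the example:

    import org.apache.hadoop.fs.Path
    import parquet.example.data.Group
    import parquet.example.data.simple.SimpleGroup
    import parquet.hadoop.ParquetWriter
    import parquet.schema.MessageTypeParser

    val schema = MessageTypeParser.parseMessageType(
      "message example { required boolean mybool; required int32 myint; }")
    // The schema is handed to the write support directly, not via the Configuration.
    val writer = new ParquetWriter[Group](new Path("/tmp/example.parquet"),
      new TestGroupWriteSupport(schema))

    val record = new SimpleGroup(schema)
    record.add(0, true) // field index 0: mybool
    record.add(1, 5)    // field index 1: myint
    writer.write(record)
    writer.close()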

sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala

Lines changed: 13 additions & 0 deletions

@@ -241,7 +241,20 @@ class ParquetQuerySuite extends QueryTest with FunSuite with BeforeAndAfterAll {
 
   test("SELECT WHERE") {
     val result = sql("SELECT * FROM testsource WHERE myint = 5").collect()
+    /*test("SELECT WHERE") {
+    //val result = parquetFile("/home/andre/input.adam").registerAsTable("adamtable")
+    //sql("SELECT * FROM adamtable WHERE mapq = 0").collect()
+    //assert(result != null)
+    //val result = sql("SELECT * FROM testsource WHERE myint = 5").collect()
+    // TODO: ADD larger case SchemaRDD with filtering on REQUIRED field!
+    implicit def anyToRow(value: Any): Row = value.asInstanceOf[Row]
+    TestSQLContext
+      .parquetFile(ParquetTestData.testNestedDir1.toString)
+      .toSchemaRDD.registerAsTable("xtmptable")
+    val result = sql("SELECT * FROM xtmptable WHERE owner = \"Julien Le Dem\"").collect()
+>>>>>>> Extending ParquetFilters
     assert(result != null)
+    }*/
   }
 }
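For completeness: outside of Spark's ParquetTableScan, an UnboundRecordFilter built by ComparisonFilter can be handed straight to parquet-mr's reader. A hedged sketch, assuming the (Path, ReadSupport, UnboundRecordFilter) reader constructor and the example GroupReadSupport from parquet-mr; the file path is invented:

    import org.apache.hadoop.fs.Path
    import parquet.example.data.Group
    import parquet.hadoop.ParquetReader
    import parquet.hadoop.example.GroupReadSupport

    // Only records whose required `myint` column equals 5 are assembled.
    val reader = new ParquetReader[Group](
      new Path("/tmp/example.parquet"),
      new GroupReadSupport(),
      ComparisonFilter.createIntFilter("myint", (x: Int) => x == 5))

    // read() returns null once the (filtered) stream is exhausted.
    var group = reader.read()
    while (group != null) {
      println(group)
      group = reader.read()
    }
    reader.close()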