diff --git a/assembly/pom.xml b/assembly/pom.xml index 4e2b773e7d2f3..8269c6985df34 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -209,6 +209,16 @@ + + hbase + + + org.apache.spark + spark-hbase_${scala.binary.version} + ${project.version} + + + spark-ganglia-lgpl diff --git a/bin/compute-classpath.cmd b/bin/compute-classpath.cmd index a4c099fb45b14..23f5b52479452 100644 --- a/bin/compute-classpath.cmd +++ b/bin/compute-classpath.cmd @@ -81,6 +81,7 @@ set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%tools\target\scala-%SCALA_VERSION%\clas set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%sql\catalyst\target\scala-%SCALA_VERSION%\classes set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%sql\core\target\scala-%SCALA_VERSION%\classes set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%sql\hive\target\scala-%SCALA_VERSION%\classes +set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%sql\hbase\target\scala-%SCALA_VERSION%\classes set SPARK_TEST_CLASSES=%FWDIR%core\target\scala-%SCALA_VERSION%\test-classes set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%repl\target\scala-%SCALA_VERSION%\test-classes @@ -91,6 +92,7 @@ set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%streaming\target\scala-%SCALA set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%sql\catalyst\target\scala-%SCALA_VERSION%\test-classes set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%sql\core\target\scala-%SCALA_VERSION%\test-classes set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%sql\hive\target\scala-%SCALA_VERSION%\test-classes +set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%sql\hbase\target\scala-%SCALA_VERSION%\test-classes if "x%SPARK_TESTING%"=="x1" ( rem Add test clases to path - note, add SPARK_CLASSES and SPARK_TEST_CLASSES before CLASSPATH diff --git a/bin/compute-classpath.sh b/bin/compute-classpath.sh index 298641f2684de..e2144722881c6 100755 --- a/bin/compute-classpath.sh +++ b/bin/compute-classpath.sh @@ -59,6 +59,7 @@ if [ -n "$SPARK_PREPEND_CLASSES" ]; then CLASSPATH="$CLASSPATH:$FWDIR/sql/hive/target/scala-$SPARK_SCALA_VERSION/classes" CLASSPATH="$CLASSPATH:$FWDIR/sql/hive-thriftserver/target/scala-$SPARK_SCALA_VERSION/classes" CLASSPATH="$CLASSPATH:$FWDIR/yarn/stable/target/scala-$SPARK_SCALA_VERSION/classes" + CLASSPATH="$CLASSPATH:$FWDIR/sql/hbase/target/scala-$SCALA_VERSION/classes" fi # Use spark-assembly jar from either RELEASE or assembly directory @@ -130,6 +131,7 @@ if [[ $SPARK_TESTING == 1 ]]; then CLASSPATH="$CLASSPATH:$FWDIR/sql/catalyst/target/scala-$SPARK_SCALA_VERSION/test-classes" CLASSPATH="$CLASSPATH:$FWDIR/sql/core/target/scala-$SPARK_SCALA_VERSION/test-classes" CLASSPATH="$CLASSPATH:$FWDIR/sql/hive/target/scala-$SPARK_SCALA_VERSION/test-classes" + CLASSPATH="$CLASSPATH:$FWDIR/sql/hbase/target/scala-$SCALA_VERSION/classes" fi # Add hadoop conf dir if given -- otherwise FileSystem.*, etc fail ! diff --git a/bin/hbase-sql b/bin/hbase-sql new file mode 100755 index 0000000000000..4ea11a4faaf12 --- /dev/null +++ b/bin/hbase-sql @@ -0,0 +1,55 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# +# Shell script for starting the Spark SQL for HBase CLI + +# Enter posix mode for bash +set -o posix + +CLASS="org.apache.spark.sql.hbase.HBaseSQLCLIDriver" + +# Figure out where Spark is installed +FWDIR="$(cd "`dirname "$0"`"/..; pwd)" + +function usage { + echo "Usage: ./bin/hbase-sql [options] [cli option]" + pattern="usage" + pattern+="\|Spark assembly has been built with hbase" + pattern+="\|NOTE: SPARK_PREPEND_CLASSES is set" + pattern+="\|Spark Command: " + pattern+="\|--help" + pattern+="\|=======" + + "$FWDIR"/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2 + echo + echo "CLI options:" + "$FWDIR"/bin/spark-class $CLASS --help 2>&1 | grep -v "$pattern" 1>&2 +} + +if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then + usage + exit 0 +fi + +source "$FWDIR"/bin/utils.sh +SUBMIT_USAGE_FUNCTION=usage +gatherSparkSubmitOpts "$@" + +exec "$FWDIR"/bin/spark-submit --class $CLASS "${SUBMISSION_OPTS[@]}" spark-internal "${APPLICATION_OPTS[@]}" diff --git a/examples/pom.xml b/examples/pom.xml index 8713230e1e8ed..ea7c34dae1dde 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -101,140 +101,140 @@ org.eclipse.jetty jetty-server - - org.apache.hbase - hbase-testing-util - ${hbase.version} - - - - org.apache.hbase - hbase-annotations - - - org.jruby - jruby-complete - - - - - org.apache.hbase - hbase-protocol - ${hbase.version} - - - org.apache.hbase - hbase-common - ${hbase.version} - - - - org.apache.hbase - hbase-annotations - - - - - org.apache.hbase - hbase-client - ${hbase.version} - - - - org.apache.hbase - hbase-annotations - - + + org.apache.hbase + hbase-testing-util + ${hbase.version} + + + + org.apache.hbase + hbase-annotations + + + org.jruby + jruby-complete + + + + + org.apache.hbase + hbase-protocol + ${hbase.version} + + + org.apache.hbase + hbase-common + ${hbase.version} + + + + org.apache.hbase + hbase-annotations + + + + + org.apache.hbase + hbase-client + ${hbase.version} + + + + org.apache.hbase + hbase-annotations + + io.netty netty - - - - - org.apache.hbase - hbase-server - ${hbase.version} - - - org.apache.hadoop - hadoop-core - - - org.apache.hadoop - hadoop-client - - - org.apache.hadoop - hadoop-mapreduce-client-jobclient - - - org.apache.hadoop - hadoop-mapreduce-client-core - - - org.apache.hadoop - hadoop-auth - - - - org.apache.hbase - hbase-annotations - - - org.apache.hadoop - hadoop-annotations - - - org.apache.hadoop - hadoop-hdfs - - - org.apache.hbase - hbase-hadoop1-compat - - - org.apache.commons - commons-math - - - com.sun.jersey - jersey-core - - - org.slf4j - slf4j-api - - - com.sun.jersey - jersey-server - - - com.sun.jersey - jersey-core - - - com.sun.jersey - jersey-json - - - - commons-io - commons-io - - - - - org.apache.hbase - hbase-hadoop-compat - ${hbase.version} - - - org.apache.hbase - hbase-hadoop-compat - ${hbase.version} - test-jar - test - + + + + + org.apache.hbase + hbase-server + ${hbase.version} + + + org.apache.hadoop + hadoop-core + + + org.apache.hadoop + hadoop-client + + + org.apache.hadoop + hadoop-mapreduce-client-jobclient + + + org.apache.hadoop + hadoop-mapreduce-client-core + + + org.apache.hadoop + 
hadoop-auth + + + + org.apache.hbase + hbase-annotations + + + org.apache.hadoop + hadoop-annotations + + + org.apache.hadoop + hadoop-hdfs + + + org.apache.hbase + hbase-hadoop1-compat + + + org.apache.commons + commons-math + + + com.sun.jersey + jersey-core + + + org.slf4j + slf4j-api + + + com.sun.jersey + jersey-server + + + com.sun.jersey + jersey-core + + + com.sun.jersey + jersey-json + + + + commons-io + commons-io + + + + + org.apache.hbase + hbase-hadoop-compat + ${hbase.version} + + + org.apache.hbase + hbase-hadoop-compat + ${hbase.version} + test-jar + test + org.apache.commons commons-math3 @@ -416,7 +416,7 @@ - scala-2.10 diff --git a/pom.xml b/pom.xml index b7df53d3e5eb1..901616f257995 100644 --- a/pom.xml +++ b/pom.xml @@ -156,7 +156,7 @@ central Maven Repository - https://repo1.maven.org/maven2 + http://repo1.maven.org/maven2 true @@ -167,7 +167,7 @@ apache-repo Apache Repository - https://repository.apache.org/content/repositories/releases + http://repository.apache.org/content/repositories/releases true @@ -178,7 +178,7 @@ jboss-repo JBoss Repository - https://repository.jboss.org/nexus/content/repositories/releases + http://repository.jboss.org/nexus/content/repositories/releases true @@ -189,7 +189,7 @@ mqtt-repo MQTT Repository - https://repo.eclipse.org/content/repositories/paho-releases + http://repo.eclipse.org/content/repositories/paho-releases true @@ -200,7 +200,7 @@ cloudera-repo Cloudera Repository - https://repository.cloudera.com/artifactory/cloudera-repos + http://repository.cloudera.com/artifactory/cloudera-repos true @@ -222,7 +222,7 @@ spring-releases Spring Release Repository - https://repo.spring.io/libs-release + http://repo.spring.io/libs-release true @@ -234,7 +234,7 @@ spark-staging-1038 Spark 1.2.0 Staging (1038) - https://repository.apache.org/content/repositories/orgapachespark-1038/ + http://repository.apache.org/content/repositories/orgapachespark-1038/ true @@ -246,7 +246,7 @@ central - https://repo1.maven.org/maven2 + http://repo1.maven.org/maven2 true @@ -936,6 +936,7 @@ ${java.version} ${java.version} + true UTF-8 1024m true @@ -1415,7 +1416,6 @@ 10.10.1.1 - scala-2.10 @@ -1431,7 +1431,6 @@ external/kafka - scala-2.11 diff --git a/python/pyspark/hbase.py b/python/pyspark/hbase.py new file mode 100644 index 0000000000000..e68a5f4e74f2a --- /dev/null +++ b/python/pyspark/hbase.py @@ -0,0 +1,84 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +from py4j.protocol import Py4JError +import traceback + +from pyspark.sql import * + +__all__ = [ + "StringType", "BinaryType", "BooleanType", "TimestampType", "DecimalType", + "DoubleType", "FloatType", "ByteType", "IntegerType", "LongType", + "ShortType", "ArrayType", "MapType", "StructField", "StructType", + "SQLContext", "HBaseSQLContext", "SchemaRDD", "Row", "_ssql_ctx", "_get_object_id"] + + +class HBaseSQLContext(SQLContext): + """A variant of Spark SQL that integrates with data stored in Hive. + + Configuration for Hive is read from hive-site.xml on the classpath. + It supports running both SQL and HiveQL commands. + """ + + def __init__(self, sparkContext, hbaseContext=None): + """Create a new HiveContext. + + @param sparkContext: The SparkContext to wrap. + @param hbaseContext: An optional JVM Scala HBaseSQLContext. If set, we do not instatiate a new + HBaseSQLContext in the JVM, instead we make all calls to this object. + """ + SQLContext.__init__(self, sparkContext) + + if hbaseContext: + self._scala_HBaseSQLContext = hbaseContext + else: + self._scala_HBaseSQLContext = None + print("HbaseContext is %s" % self._scala_HBaseSQLContext) + + @property + def _ssql_ctx(self): + # try: + if self._scala_HBaseSQLContext is None: + # if not hasattr(self, '_scala_HBaseSQLContext'): + print ("loading hbase context ..") + self._scala_HBaseSQLContext = self._get_hbase_ctx() + self._scala_SQLContext = self._scala_HBaseSQLContext + else: + print("We already have hbase context") + + print vars(self) + return self._scala_HBaseSQLContext + # except Py4JError as e: + # import sys + # traceback.print_stack(file=sys.stdout) + # print ("Nice error .. %s " %e) + # print(e) + # raise Exception("" + # "HbaseSQLContext not found: You must build Spark with Hbase.", e) + + def _get_hbase_ctx(self): + print("sc=%s conf=%s" %(self._jsc.sc(), self._jsc.sc().configuration)) + return self._jvm.HBaseSQLContext(self._jsc.sc()) + + + class HBaseSchemaRDD(SchemaRDD): + def createTable(self, tableName, overwrite=False): + """Inserts the contents of this SchemaRDD into the specified table. + + Optionally overwriting any existing data. + """ + self._jschema_rdd.createTable(tableName, overwrite) + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index a2bcd73b6074f..559ef3e1596bc 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -107,13 +107,16 @@ class SqlParser extends AbstractSparkSQLParser { protected val WHERE = Keyword("WHERE") // Use reflection to find the reserved words defined in this class. + /* TODO: It will cause the null exception for the subClass of SqlParser. 
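Reflection over getMethods also returns Keyword vals declared in subclasses, and those fields are not yet initialized when this superclass constructor runs, so invoking them reflectively yields null.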
+ * Temporary solution: Add one more filter to restrain the class must be SqlParser + */ protected val reservedWords = this .getClass .getMethods .filter(_.getReturnType == classOf[Keyword]) - .map(_.invoke(this).asInstanceOf[Keyword].str) - + .filter(_.toString.contains("org.apache.spark.sql.catalyst.SqlParser.".toCharArray)) + .map{_.invoke(this).asInstanceOf[Keyword].str} override val lexical = new SqlLexical(reservedWords) protected def assignAliases(exprs: Seq[Expression]): Seq[NamedExpression] = { diff --git a/sql/core/pom.xml b/sql/core/pom.xml index 3bd283fd20156..2d4e8f2493590 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -83,6 +83,140 @@ scalacheck_${scala.binary.version} test + + org.apache.hbase + hbase-testing-util + 0.98.5-hadoop2 + + + + org.apache.hbase + hbase-annotations + + + org.jruby + jruby-complete + + + + + org.apache.hbase + hbase-protocol + 0.98.5-hadoop2 + + + org.apache.hbase + hbase-common + 0.98.5-hadoop2 + + + + org.apache.hbase + hbase-annotations + + + + + org.apache.hbase + hbase-client + 0.98.5-hadoop2 + + + + org.apache.hbase + hbase-annotations + + + io.netty + netty + + + + + org.apache.hbase + hbase-server + 0.98.5-hadoop2 + + + org.apache.hadoop + hadoop-core + + + org.apache.hadoop + hadoop-client + + + org.apache.hadoop + hadoop-mapreduce-client-jobclient + + + org.apache.hadoop + hadoop-mapreduce-client-core + + + org.apache.hadoop + hadoop-auth + + + + org.apache.hbase + hbase-annotations + + + org.apache.hadoop + hadoop-annotations + + + org.apache.hadoop + hadoop-hdfs + + + org.apache.hbase + hbase-hadoop1-compat + + + org.apache.commons + commons-math + + + com.sun.jersey + jersey-core + + + org.slf4j + slf4j-api + + + com.sun.jersey + jersey-server + + + com.sun.jersey + jersey-core + + + com.sun.jersey + jersey-json + + + + commons-io + commons-io + + + + + org.apache.hbase + hbase-hadoop-compat + 0.98.5-hadoop2 + + + org.apache.hbase + hbase-hadoop-compat + 0.98.5-hadoop2 + test-jar + test + target/scala-${scala.binary.version}/classes diff --git a/sql/core/src/main/scala/org/apache/spark/sql/hbase/BytesUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/hbase/BytesUtils.scala new file mode 100644 index 0000000000000..190c772ae5174 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/hbase/BytesUtils.scala @@ -0,0 +1,141 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ +package org.apache.spark.sql.hbase + +import org.apache.hadoop.hbase.util.Bytes + +class BytesUtils { + lazy val booleanArray: HBaseRawType = new HBaseRawType(Bytes.SIZEOF_BOOLEAN) + lazy val byteArray: HBaseRawType = new HBaseRawType(Bytes.SIZEOF_BYTE) + lazy val charArray: HBaseRawType = new HBaseRawType(Bytes.SIZEOF_CHAR) + lazy val doubleArray: HBaseRawType = new HBaseRawType(Bytes.SIZEOF_DOUBLE) + lazy val floatArray: HBaseRawType = new HBaseRawType(Bytes.SIZEOF_FLOAT) + lazy val intArray: HBaseRawType = new HBaseRawType(Bytes.SIZEOF_INT) + lazy val longArray: HBaseRawType = new HBaseRawType(Bytes.SIZEOF_LONG) + lazy val shortArray: HBaseRawType = new HBaseRawType(Bytes.SIZEOF_SHORT) + + def toBytes(input: String): HBaseRawType = { + Bytes.toBytes(input) + } + + def toString(input: HBaseRawType): String = { + Bytes.toString(input) + } + + def toBytes(input: Byte): HBaseRawType = { + // Flip sign bit so that Byte is binary comparable + byteArray(0) = (input ^ 0x80).asInstanceOf[Byte] + byteArray + } + + def toByte(input: HBaseRawType): Byte = { + // Flip sign bit back + val v: Int = input(0) ^ 0x80 + v.asInstanceOf[Byte] + } + + def toBytes(input: Boolean): HBaseRawType = { + booleanArray(0) = 0.asInstanceOf[Byte] + if (input) { + booleanArray(0) = (-1).asInstanceOf[Byte] + } + booleanArray + } + + def toBoolean(input: HBaseRawType): Boolean = { + input(0) != 0 + } + + def toBytes(input: Double): HBaseRawType = { + var l: Long = java.lang.Double.doubleToLongBits(input) + l = (l ^ ((l >> java.lang.Long.SIZE - 1) | java.lang.Long.MIN_VALUE)) + 1 + Bytes.putLong(longArray, 0, l) + longArray + } + + def toDouble(input: HBaseRawType): Double = { + var l: Long = Bytes.toLong(input) + l = l - 1 + l ^= (~l >> java.lang.Long.SIZE - 1) | java.lang.Long.MIN_VALUE + java.lang.Double.longBitsToDouble(l) + } + + def toBytes(input: Short): HBaseRawType = { + shortArray(0) = ((input >> 8) ^ 0x80).asInstanceOf[Byte] + shortArray(1) = input.asInstanceOf[Byte] + shortArray + } + + def toShort(input: HBaseRawType): Short = { + // flip sign bit back + var v: Int = input(0) ^ 0x80 + v = (v << 8) + (input(1) & 0xff) + v.asInstanceOf[Short] + } + + def toBytes(input: Float): HBaseRawType = { + var i: Int = java.lang.Float.floatToIntBits(input) + i = (i ^ ((i >> Integer.SIZE - 1) | Integer.MIN_VALUE)) + 1 + toBytes(i) + } + + def toFloat(input: HBaseRawType): Float = { + var i = toInt(input) + i = i - 1 + i ^= (~i >> Integer.SIZE - 1) | Integer.MIN_VALUE + java.lang.Float.intBitsToFloat(i) + } + + def toBytes(input: Int): HBaseRawType = { + // Flip sign bit so that INTEGER is binary comparable + intArray(0) = ((input >> 24) ^ 0x80).asInstanceOf[Byte] + intArray(1) = (input >> 16).asInstanceOf[Byte] + intArray(2) = (input >> 8).asInstanceOf[Byte] + intArray(3) = input.asInstanceOf[Byte] + intArray + } + + def toInt(input: HBaseRawType): Int = { + // Flip sign bit back + var v: Int = input(0) ^ 0x80 + for (i <- 1 to Bytes.SIZEOF_INT - 1) { + v = (v << 8) + (input(i) & 0xff) + } + v + } + + def toBytes(input: Long): HBaseRawType = { + longArray(0) = ((input >> 56) ^ 0x80).asInstanceOf[Byte] + longArray(1) = (input >> 48).asInstanceOf[Byte] + longArray(2) = (input >> 40).asInstanceOf[Byte] + longArray(3) = (input >> 32).asInstanceOf[Byte] + longArray(4) = (input >> 24).asInstanceOf[Byte] + longArray(5) = (input >> 16).asInstanceOf[Byte] + longArray(6) = (input >> 8).asInstanceOf[Byte] + longArray(7) = input.asInstanceOf[Byte] + longArray + } + + def toLong(input: HBaseRawType): Long = { + // Flip sign bit 
back + var v: Long = input(0) ^ 0x80 + for (i <- 1 to Bytes.SIZEOF_LONG - 1) { + v = (v << 8) + (input(i) & 0xff) + } + v + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/hbase/DataTypeUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/hbase/DataTypeUtils.scala new file mode 100755 index 0000000000000..8976a4fd48f38 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/hbase/DataTypeUtils.scala @@ -0,0 +1,109 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ +package org.apache.spark.sql.hbase + +import org.apache.hadoop.hbase.filter.BinaryComparator +import org.apache.hadoop.hbase.util.Bytes +import org.apache.spark.sql.catalyst.expressions.{Literal, MutableRow, Row} +import org.apache.spark.sql.catalyst.types._ + +/** + * Data Type conversion utilities + * + */ +object DataTypeUtils { + // TODO: more data types support? + def bytesToData (src: HBaseRawType, + dt: DataType, + bu: BytesUtils): Any = { + dt match { + case StringType => bu.toString(src) + case IntegerType => bu.toInt(src) + case BooleanType => bu.toBoolean(src) + case ByteType => src(0) + case DoubleType => bu.toDouble(src) + case FloatType => bu.toFloat(src) + case LongType => bu.toLong(src) + case ShortType => bu.toShort(src) + case _ => throw new Exception("Unsupported HBase SQL Data Type") + } + } + + def setRowColumnFromHBaseRawType(row: MutableRow, + index: Int, + src: HBaseRawType, + dt: DataType, + bu: BytesUtils): Unit = { + dt match { + case StringType => row.setString(index, bu.toString(src)) + case IntegerType => row.setInt(index, bu.toInt(src)) + case BooleanType => row.setBoolean(index, bu.toBoolean(src)) + case ByteType => row.setByte(index, bu.toByte(src)) + case DoubleType => row.setDouble(index, bu.toDouble(src)) + case FloatType => row.setFloat(index, bu.toFloat(src)) + case LongType => row.setLong(index, bu.toLong(src)) + case ShortType => row.setShort(index, bu.toShort(src)) + case _ => throw new Exception("Unsupported HBase SQL Data Type") + } + } + + def getRowColumnFromHBaseRawType(row: Row, + index: Int, + dt: DataType, + bu: BytesUtils): HBaseRawType = { + dt match { + case StringType => bu.toBytes(row.getString(index)) + case IntegerType => bu.toBytes(row.getInt(index)) + case BooleanType => bu.toBytes(row.getBoolean(index)) + case ByteType => bu.toBytes(row.getByte(index)) + case DoubleType => bu.toBytes(row.getDouble(index)) + case FloatType => bu.toBytes(row.getFloat(index)) + case LongType => bu.toBytes(row.getLong(index)) + case ShortType => bu.toBytes(row.getShort(index)) + case _ => throw new Exception("Unsupported HBase SQL Data Type") + } + } + + def getComparator(expression: Literal): BinaryComparator = { + expression.dataType match { + case DoubleType => { + new 
BinaryComparator(Bytes.toBytes(expression.value.asInstanceOf[Double])) + } + case FloatType => { + new BinaryComparator(Bytes.toBytes(expression.value.asInstanceOf[Float])) + } + case IntegerType => { + new BinaryComparator(Bytes.toBytes(expression.value.asInstanceOf[Int])) + } + case LongType => { + new BinaryComparator(Bytes.toBytes(expression.value.asInstanceOf[Long])) + } + case ShortType => { + new BinaryComparator(Bytes.toBytes(expression.value.asInstanceOf[Short])) + } + case StringType => { + new BinaryComparator(Bytes.toBytes(expression.value.asInstanceOf[String])) + } + case BooleanType => { + new BinaryComparator(Bytes.toBytes(expression.value.asInstanceOf[Boolean])) + } + case _ => { + throw new Exception("Cannot convert the data type using BinaryComparator") + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/hbase/HBaseKVHelper.scala b/sql/core/src/main/scala/org/apache/spark/sql/hbase/HBaseKVHelper.scala new file mode 100644 index 0000000000000..8af45a813a8a9 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/hbase/HBaseKVHelper.scala @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hbase + +import org.apache.hadoop.hbase.util.Bytes +import org.apache.spark.sql.catalyst.types._ + +import scala.collection.mutable +import scala.collection.mutable.{ArrayBuffer, ListBuffer} + +object HBaseKVHelper { + private val delimiter: Byte = 0 + + /** + * create row key based on key columns information + * @param buffer an input buffer + * @param rawKeyColumns sequence of byte array and data type representing the key columns + * @return array of bytes + */ + def encodingRawKeyColumns(buffer: ListBuffer[Byte], + rawKeyColumns: Seq[(HBaseRawType, DataType)]): HBaseRawType = { + var listBuffer = buffer + listBuffer.clear() + for (rawKeyColumn <- rawKeyColumns) { + listBuffer = listBuffer ++ rawKeyColumn._1 + if (rawKeyColumn._2 == StringType) { + listBuffer += delimiter + } + } + listBuffer.toArray + } + + /** + * get the sequence of key columns from the byte array + * @param buffer an input buffer + * @param rowKey array of bytes + * @param keyColumns the sequence of key columns + * @return sequence of byte array + */ + def decodingRawKeyColumns(buffer: ListBuffer[HBaseRawType], + rowKey: HBaseRawType, keyColumns: Seq[KeyColumn]): Seq[HBaseRawType] = { + var listBuffer = buffer + listBuffer.clear() + var arrayBuffer = ArrayBuffer[Byte]() + var index = 0 + for (keyColumn <- keyColumns) { + arrayBuffer.clear() + val dataType = keyColumn.dataType + if (dataType == StringType) { + while (index < rowKey.length && rowKey(index) != delimiter) { + arrayBuffer += rowKey(index) + index = index + 1 + } + index = index + 1 + } + else { + val length = NativeType.defaultSizeOf(dataType.asInstanceOf[NativeType]) + for (i <- 0 to (length - 1)) { + arrayBuffer += rowKey(index) + index = index + 1 + } + } + listBuffer += arrayBuffer.toArray + } + listBuffer.toSeq + } + + /** + * Takes a record, translate it into HBase row key column and value by matching with metadata + * @param values record that as a sequence of string + * @param columns metadata that contains KeyColumn and NonKeyColumn + * @param keyBytes output paramater, array of (key column and its type); + * @param valueBytes array of (column family, column qualifier, value) + */ + def string2KV(values: Seq[String], + columns: Seq[AbstractColumn], + keyBytes: ListBuffer[(Array[Byte], DataType)], + valueBytes: ListBuffer[(Array[Byte], Array[Byte], Array[Byte])]) = { + assert(values.length == columns.length, + s"values length ${values.length} not equals lolumns length ${columns.length}") + keyBytes.clear() + valueBytes.clear() + val map = mutable.HashMap[Int, (Array[Byte], DataType)]() + var index = 0 + for (i <- 0 until values.length) { + val value = values(i) + val column = columns(i) + val bytes = string2Bytes(value, column.dataType, new BytesUtils) + if (column.isKeyColum()) { + map(column.asInstanceOf[KeyColumn].order) = ((bytes, column.dataType)) + index = index + 1 + } else { + val realCol = column.asInstanceOf[NonKeyColumn] + valueBytes += ((Bytes.toBytes(realCol.family), Bytes.toBytes(realCol.qualifier), bytes)) + } + } + + (0 until index).foreach(k => keyBytes += map.get(k).get) + } + + private def string2Bytes(v: String, dataType: DataType, bu: BytesUtils): Array[Byte] = { + dataType match { + // todo: handle some complex types + case BooleanType => bu.toBytes(v.toBoolean) + case ByteType => bu.toBytes(v) + case DoubleType => bu.toBytes(v.toDouble) + case FloatType => bu.toBytes((v.toFloat)) + case IntegerType => bu.toBytes(v.toInt) + case LongType => bu.toBytes(v.toLong) + case ShortType => 
bu.toBytes(v.toShort) + case StringType => bu.toBytes(v) + } + } +} + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/hbase/HBaseMetadata.scala b/sql/core/src/main/scala/org/apache/spark/sql/hbase/HBaseMetadata.scala new file mode 100644 index 0000000000000..40d9c56c1436e --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/hbase/HBaseMetadata.scala @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hbase + +import java.io._ +import org.apache.spark.sql.catalyst.types.DataType + +import scala.Some + +import org.apache.hadoop.hbase.{HColumnDescriptor, TableName, HTableDescriptor, HBaseConfiguration} +import org.apache.hadoop.hbase.client._ +import org.apache.hadoop.hbase.util.Bytes + +import org.apache.spark.sql.hbase._ +import org.apache.spark.Logging +import org.apache.spark.sql.hbase.HBaseRelation +import org.apache.spark.sql.hbase.NonKeyColumn + +/** + * Column represent the sql column + * sqlName the name of the column + * dataType the data type of the column + */ +sealed abstract class AbstractColumn { + val sqlName: String + val dataType: DataType + + def isKeyColum(): Boolean = false + + override def toString: String = { + s"$sqlName , $dataType.typeName" + } +} + +case class KeyColumn(val sqlName: String, val dataType: DataType, val order: Int) + extends AbstractColumn { + override def isKeyColum() = true +} + +case class NonKeyColumn( + val sqlName: String, + val dataType: DataType, + val family: String, + val qualifier: String) extends AbstractColumn { + @transient lazy val familyRaw = Bytes.toBytes(family) + @transient lazy val qualifierRaw = Bytes.toBytes(qualifier) + + override def toString = { + s"$sqlName , $dataType.typeName , $family:$qualifier" + } +} + +private[hbase] class HBaseMetadata extends Logging with Serializable { + + lazy val configuration = HBaseConfiguration.create() + + lazy val admin = new HBaseAdmin(configuration) + + logDebug(s"HBaseAdmin.configuration zkPort=" + + s"${admin.getConfiguration.get("hbase.zookeeper.property.clientPort")}") + + private def createHBaseUserTable(tableName: String, allColumns: Seq[AbstractColumn]) { + val tableDescriptor = new HTableDescriptor(TableName.valueOf(tableName)) + allColumns.map(x => + if (x.isInstanceOf[NonKeyColumn]) { + val nonKeyColumn = x.asInstanceOf[NonKeyColumn] + tableDescriptor.addFamily(new HColumnDescriptor(nonKeyColumn.family)) + }) + + admin.createTable(tableDescriptor, null); + } + + def createTable( + tableName: String, + hbaseTableName: String, + allColumns: Seq[AbstractColumn]) = { + // create a new hbase table for the user if not exist + if (!checkHBaseTableExists(hbaseTableName)) { + createHBaseUserTable(hbaseTableName, allColumns) + } + // check hbase table contain all the 
families + val nonKeyColumns = allColumns.filter(_.isInstanceOf[NonKeyColumn]) + nonKeyColumns.foreach { + case NonKeyColumn(_, _, family, _) => + if (!checkFamilyExists(hbaseTableName, family)) { + throw new Exception(s"The HBase table doesn't contain the Column Family: $family") + } + } + + HBaseRelation(tableName, "", hbaseTableName, allColumns, Some(configuration)) + } + + private[hbase] def checkHBaseTableExists(hbaseTableName: String): Boolean = { + admin.tableExists(hbaseTableName) + } + + private[hbase] def checkFamilyExists(hbaseTableName: String, family: String): Boolean = { + val tableDescriptor = admin.getTableDescriptor(TableName.valueOf(hbaseTableName)) + tableDescriptor.hasFamily(Bytes.toBytes(family)) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/hbase/HBasePartition.scala b/sql/core/src/main/scala/org/apache/spark/sql/hbase/HBasePartition.scala new file mode 100755 index 0000000000000..770dd99f63d4b --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/hbase/HBasePartition.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.hbase + +import org.apache.spark.Partition +import org.apache.spark.sql.catalyst.expressions.Expression + +private[hbase] class HBasePartition( + val idx: Int, val mappedIndex: Int, + val lowerBound: Option[HBaseRawType] = None, + val upperBound: Option[HBaseRawType] = None, + val server: Option[String] = None, + val filterPred: Option[Expression] = None) extends Partition with IndexMappable { + + override def index: Int = idx + + override def hashCode(): Int = idx +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/hbase/HBaseRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/hbase/HBaseRelation.scala new file mode 100755 index 0000000000000..df493612d5069 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/hbase/HBaseRelation.scala @@ -0,0 +1,666 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.hbase + +import java.util.ArrayList + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.hbase.HBaseConfiguration +import org.apache.hadoop.hbase.client.{Get, HTable, Put, Result, Scan} +import org.apache.hadoop.hbase.filter._ +import org.apache.hadoop.hbase.util.Bytes +import org.apache.spark.Partition +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical.LeafNode +import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.hbase.PartialPredicateOperations._ + +import scala.collection.JavaConverters._ +import scala.collection.mutable.{ArrayBuffer, ListBuffer} + + +private[hbase] case class HBaseRelation( + tableName: String, + hbaseNamespace: String, + hbaseTableName: String, + allColumns: Seq[AbstractColumn], + @transient optConfiguration: Option[Configuration] = None) + extends LeafNode { + + @transient lazy val keyColumns = allColumns.filter(_.isInstanceOf[KeyColumn]) + .asInstanceOf[Seq[KeyColumn]].sortBy(_.order) + + @transient lazy val nonKeyColumns = allColumns.filter(_.isInstanceOf[NonKeyColumn]) + .asInstanceOf[Seq[NonKeyColumn]] + + @transient lazy val partitionKeys: Seq[AttributeReference] = keyColumns.map(col => + AttributeReference(col.sqlName, col.dataType, nullable = false)()) + + @transient lazy val columnMap = allColumns.map { + case key: KeyColumn => (key.sqlName, key.order) + case nonKey: NonKeyColumn => (nonKey.sqlName, nonKey) + }.toMap + + def configuration() = optConfiguration.getOrElse(HBaseConfiguration.create) + + // todo: scwf,remove this later + logDebug(s"HBaseRelation config has zkPort=" + + s"${configuration.get("hbase.zookeeper.property.clientPort")}") + + @transient lazy val htable: HTable = new HTable(configuration, hbaseTableName) + + // todo: scwf, why non key columns + lazy val attributes = nonKeyColumns.map(col => + AttributeReference(col.sqlName, col.dataType, nullable = true)()) + + // lazy val colFamilies = nonKeyColumns.map(_.family).distinct + // lazy val applyFilters = false + + def isNonKey(attr: AttributeReference): Boolean = { + attributes.exists(_.exprId == attr.exprId) + } + + def keyIndex(attr: AttributeReference): Int = { + // -1 if nonexistent + partitionKeys.indexWhere(_.exprId == attr.exprId) + } + def closeHTable() = htable.close + + val output: Seq[Attribute] = { + allColumns.map { + case column => + (partitionKeys union attributes).find(_.name == column.sqlName).get + } + } + + lazy val partitions: Seq[HBasePartition] = { + val regionLocations = htable.getRegionLocations.asScala.toSeq + log.info(s"Number of HBase regions for " + + s"table ${htable.getName.getNameAsString}: ${regionLocations.size}") + regionLocations.zipWithIndex.map { + case p => + new HBasePartition( + p._2, p._2, + Some(p._1._1.getStartKey), + Some(p._1._1.getEndKey), + Some(p._1._2.getHostname)) + } + } + + private def generateRange(partition: HBasePartition, pred: Expression, + index: Int): + (PartitionRange[_]) = { + def getData(dt: NativeType, + buffer: ListBuffer[HBaseRawType], + bound: Option[HBaseRawType]): Option[Any] = { + if (Bytes.toStringBinary(bound.get) == "") None + else { + val bytesUtils = new BytesUtils + Some(DataTypeUtils.bytesToData( + HBaseKVHelper.decodingRawKeyColumns(buffer, bound.get, keyColumns)(index), + dt, bytesUtils).asInstanceOf[dt.JvmType]) + } + } + + val dt = keyColumns(index).dataType.asInstanceOf[NativeType] + val isLastKeyIndex = index == (keyColumns.size - 1) + val buffer = ListBuffer[HBaseRawType]() + val 
start = getData(dt, buffer, partition.lowerBound) + val end = getData(dt, buffer, partition.upperBound) + val startInclusive = !start.isEmpty + val endInclusive = !end.isEmpty && !isLastKeyIndex + new PartitionRange(start, startInclusive, end, endInclusive, partition.index, dt, pred) + } + + private def prePruneRanges(ranges: Seq[PartitionRange[_]], keyIndex: Int) + : (Seq[PartitionRange[_]], Seq[PartitionRange[_]]) = { + require(keyIndex < keyColumns.size, "key index out of range") + if (ranges.isEmpty) { + (ranges, Nil) + } else if (keyIndex == 0) { + (Nil, ranges) + } else { + // the first portion is of those ranges of equal start and end values of the + // previous dimensions so they can be subject to further checks on the next dimension + val (p1, p2) = ranges.partition(p => p.start == p.end) + (p2, p1.map(p => generateRange(partitions(p.id), p.pred, keyIndex))) + } + } + + def getPrunedPartitions(partitionPred: Option[Expression] = None): Option[Seq[HBasePartition]] = { + def getPrunedRanges(pred: Expression): Seq[PartitionRange[_]] = { + val predRefs = pred.references.toSeq + val boundPruningPred = BindReferences.bindReference(pred, predRefs) + val keyIndexToPredIndex = (for { + (keyColumn, keyIndex) <- keyColumns.zipWithIndex + (predRef, predIndex) <- predRefs.zipWithIndex + if (keyColumn.sqlName == predRef.name) + } yield (keyIndex, predIndex)).toMap + + val row = new GenericMutableRow(predRefs.size) + var notPrunedRanges = partitions.map(generateRange(_, null, 0)) + var prunedRanges: Seq[PartitionRange[_]] = Nil + + for (keyIndex <- 0 until keyColumns.size; if (!notPrunedRanges.isEmpty)) { + val (passedRanges, toBePrunedRanges) = prePruneRanges(notPrunedRanges, keyIndex) + prunedRanges = prunedRanges ++ passedRanges + notPrunedRanges = + if (keyIndexToPredIndex.contains(keyIndex)) { + toBePrunedRanges.filter( + range => { + val predIndex = keyIndexToPredIndex(keyIndex) + row.update(predIndex, range) + val partialEvalResult = boundPruningPred.partialEval(row) + // MAYBE is represented by a null + (partialEvalResult == null) || partialEvalResult.asInstanceOf[Boolean] + } + ) + } else toBePrunedRanges + } + prunedRanges ++ notPrunedRanges + } + + partitionPred match { + case None => Some(partitions) + case Some(pred) => if (pred.references.intersect(AttributeSet(partitionKeys)).isEmpty) { + Some(partitions) + } else { + val prunedRanges: Seq[PartitionRange[_]] = getPrunedRanges(pred) + println("prunedRanges: " + prunedRanges.length) + var idx: Int = -1 + val result = Some(prunedRanges.map(p => { + val par = partitions(p.id) + idx = idx + 1 + if (p.pred == null) { + new HBasePartition(idx, par.mappedIndex, par.lowerBound, par.upperBound, par.server) + } else { + new HBasePartition(idx, par.mappedIndex, par.lowerBound, par.upperBound, + par.server, Some(p.pred)) + } + })) + result.foreach(println) + result + } + } + } + + def getPrunedPartitions2(partitionPred: Option[Expression] = None) + : Option[Seq[HBasePartition]] = { + def getPrunedRanges(pred: Expression): Seq[PartitionRange[_]] = { + val predRefs = pred.references.toSeq + val boundPruningPred = BindReferences.bindReference(pred, predRefs) + val keyIndexToPredIndex = (for { + (keyColumn, keyIndex) <- keyColumns.zipWithIndex + (predRef, predIndex) <- predRefs.zipWithIndex + if (keyColumn.sqlName == predRef.name) + } yield (keyIndex, predIndex)).toMap + + val row = new GenericMutableRow(predRefs.size) + var notPrunedRanges = partitions.map(generateRange(_, boundPruningPred, 0)) + var prunedRanges: Seq[PartitionRange[_]] = Nil + + 
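// Walk the key columns one dimension at a time: ranges whose bound predicate partially evaluates to false are dropped, while MAYBE results carry a refined residual predicate into the next dimension. +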
for (keyIndex <- 0 until keyColumns.size; if (!notPrunedRanges.isEmpty)) { + val (passedRanges, toBePrunedRanges) = prePruneRanges(notPrunedRanges, keyIndex) + prunedRanges = prunedRanges ++ passedRanges + notPrunedRanges = + if (keyIndexToPredIndex.contains(keyIndex)) { + toBePrunedRanges.filter( + range => { + val predIndex = keyIndexToPredIndex(keyIndex) + row.update(predIndex, range) + val partialEvalResult = range.pred.partialReduce(row) + range.pred = if (partialEvalResult.isInstanceOf[Expression]) { + // progressively fine tune the constraining predicate + partialEvalResult.asInstanceOf[Expression] + } else { + null + } + // MAYBE is represented by a to-be-qualified-with expression + partialEvalResult.isInstanceOf[Expression] || + partialEvalResult.asInstanceOf[Boolean] + } + ) + } else toBePrunedRanges + } + prunedRanges ++ notPrunedRanges + } + + partitionPred match { + case None => Some(partitions) + case Some(pred) => if (pred.references.intersect(AttributeSet(partitionKeys)).isEmpty) { + // the predicate does not apply to the partitions at all; just push down the filtering + Some(partitions.map(p=>new HBasePartition(p.idx, p.mappedIndex, p.lowerBound, + p.upperBound, p.server, Some(pred)))) + } else { + val prunedRanges: Seq[PartitionRange[_]] = getPrunedRanges(pred) + println("prunedRanges: " + prunedRanges.length) + var idx: Int = -1 + val result = Some(prunedRanges.map(p => { + val par = partitions(p.id) + idx = idx + 1 + // pruned partitions have the same "physical" partition index, but different + // "canonical" index + if (p.pred == null) { + new HBasePartition(idx, par.mappedIndex, par.lowerBound, + par.upperBound, par.server, None) + } else { + new HBasePartition(idx, par.mappedIndex, par.lowerBound, + par.upperBound, par.server, Some(p.pred)) + } + })) + // TODO: remove/modify the following debug info + // result.foreach(println) + result + } + } + } + + /** + * Return the start keys of all of the regions in this table, + * as a list of SparkImmutableBytesWritable. 
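Each start key is wrapped in an ImmutableBytesWritableWrapper.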
+ */ + def getRegionStartKeys() = { + val byteKeys: Array[Array[Byte]] = htable.getStartKeys + val ret = ArrayBuffer[ImmutableBytesWritableWrapper]() + for (byteKey <- byteKeys) { + ret += new ImmutableBytesWritableWrapper(byteKey) + } + ret + } + + def buildFilter( + projList: Seq[NamedExpression], + rowKeyPredicate: Option[Expression], + valuePredicate: Option[Expression]) = { + val distinctProjList = projList.distinct + if (distinctProjList.size == allColumns.size) { + Option(new FilterList(new ArrayList[Filter])) + } else { + val filtersList: List[Filter] = nonKeyColumns.filter { + case nkc => distinctProjList.exists(nkc == _.name) + }.map { + case NonKeyColumn(_, _, family, qualifier) => { + val columnFilters = new ArrayList[Filter] + columnFilters.add( + new FamilyFilter( + CompareFilter.CompareOp.EQUAL, + new BinaryComparator(Bytes.toBytes(family)) + )) + columnFilters.add( + new QualifierFilter( + CompareFilter.CompareOp.EQUAL, + new BinaryComparator(Bytes.toBytes(qualifier)) + )) + new FilterList(FilterList.Operator.MUST_PASS_ALL, columnFilters) + } + }.toList + + Option(new FilterList(FilterList.Operator.MUST_PASS_ONE, filtersList.asJava)) + } + } + + def buildFilter2( + projList: Seq[NamedExpression], + pred: Option[Expression]): (Option[FilterList], Option[Expression]) = { + var distinctProjList = projList.distinct + if (pred.isDefined) { + distinctProjList = distinctProjList.filterNot(_.references.subsetOf(pred.get.references)) + } + val projFilterList = if (distinctProjList.size == allColumns.size) { + None + } else { + val filtersList: List[Filter] = nonKeyColumns.filter { + case nkc => distinctProjList.exists(nkc == _.name) + }.map { + case NonKeyColumn(_, _, family, qualifier) => { + val columnFilters = new ArrayList[Filter] + columnFilters.add( + new FamilyFilter( + CompareFilter.CompareOp.EQUAL, + new BinaryComparator(Bytes.toBytes(family)) + )) + columnFilters.add( + new QualifierFilter( + CompareFilter.CompareOp.EQUAL, + new BinaryComparator(Bytes.toBytes(qualifier)) + )) + new FilterList(FilterList.Operator.MUST_PASS_ALL, columnFilters) + } + }.toList + + Option(new FilterList(FilterList.Operator.MUST_PASS_ONE, filtersList.asJava)) + } + + if (pred.isDefined) { + val predExp: Expression = pred.get + // build pred pushdown filters: + // 1. push any NOT through AND/OR + val notPushedPred = NOTPusher(predExp) + // 2. classify the transformed predicate into pushdownable and non-pushdownable predicates + val classfier = new ScanPredClassfier(this, 0) // Right now only on primary key dimension + val (pushdownFilterPred, otherPred) = classfier(notPushedPred) + // 3. build a FilterList mirroring the pushdownable predicate + val predPushdownFilterList = buildFilterListFromPred(pushdownFilterPred) + // 4. 
merge the above FilterList with the one from the projection + (predPushdownFilterList, otherPred) + } else { + (projFilterList, None) + } + } + + private def buildFilterListFromPred(pred: Option[Expression]): Option[FilterList] = { + var result: Option[FilterList] = None + val filters = new ArrayList[Filter] + if (pred.isDefined) { + val expression = pred.get + expression match { + case And(left, right) => { + if (left != null) { + val leftFilterList = buildFilterListFromPred(Some(left)) + if (leftFilterList.isDefined) { + filters.add(leftFilterList.get) + } + } + if (right != null) { + val rightFilterList = buildFilterListFromPred(Some(right)) + if (rightFilterList.isDefined) { + filters.add(rightFilterList.get) + } + } + result = Option(new FilterList(FilterList.Operator.MUST_PASS_ALL, filters)) + } + case Or(left, right) => { + if (left != null) { + val leftFilterList = buildFilterListFromPred(Some(left)) + if (leftFilterList.isDefined) { + filters.add(leftFilterList.get) + } + } + if (right != null) { + val rightFilterList = buildFilterListFromPred(Some(right)) + if (rightFilterList.isDefined) { + filters.add(rightFilterList.get) + } + } + result = Option(new FilterList(FilterList.Operator.MUST_PASS_ONE, filters)) + } + case GreaterThan(left: AttributeReference, right: Literal) => { + val keyColumn = keyColumns.find((p: KeyColumn) => p.sqlName.equals(left.name)) + val nonKeyColumn = nonKeyColumns.find((p: NonKeyColumn) => p.sqlName.equals(left.name)) + if (keyColumn.isDefined) { + val filter = new RowFilter(CompareFilter.CompareOp.GREATER, + new BinaryComparator(Bytes.toBytes(right.value.toString))) + result = Option(new FilterList(filter)) + } else if (nonKeyColumn.isDefined) { + val column = nonKeyColumn.get + val filter = new SingleColumnValueFilter(Bytes.toBytes(column.family), + Bytes.toBytes(column.qualifier), + CompareFilter.CompareOp.GREATER, + DataTypeUtils.getComparator(right)) + result = Option(new FilterList(filter)) + } + } + case GreaterThanOrEqual(left: AttributeReference, right: Literal) => { + val keyColumn = keyColumns.find((p: KeyColumn) => p.sqlName.equals(left.name)) + val nonKeyColumn = nonKeyColumns.find((p: NonKeyColumn) => p.sqlName.equals(left.name)) + if (keyColumn.isDefined) { + val filter = new RowFilter(CompareFilter.CompareOp.GREATER_OR_EQUAL, + new BinaryComparator(Bytes.toBytes(right.value.toString))) + result = Option(new FilterList(filter)) + } else if (nonKeyColumn.isDefined) { + val column = nonKeyColumn.get + val filter = new SingleColumnValueFilter(Bytes.toBytes(column.family), + Bytes.toBytes(column.qualifier), + CompareFilter.CompareOp.GREATER_OR_EQUAL, + DataTypeUtils.getComparator(right)) + result = Option(new FilterList(filter)) + } + } + case EqualTo(left: AttributeReference, right: Literal) => { + val keyColumn = keyColumns.find((p: KeyColumn) => p.sqlName.equals(left.name)) + val nonKeyColumn = nonKeyColumns.find((p: NonKeyColumn) => p.sqlName.equals(left.name)) + if (keyColumn.isDefined) { + val filter = new RowFilter(CompareFilter.CompareOp.EQUAL, + new BinaryComparator(Bytes.toBytes(right.value.toString))) + result = Option(new FilterList(filter)) + } else if (nonKeyColumn.isDefined) { + val column = nonKeyColumn.get + val filter = new SingleColumnValueFilter(Bytes.toBytes(column.family), + Bytes.toBytes(column.qualifier), + CompareFilter.CompareOp.EQUAL, + DataTypeUtils.getComparator(right)) + result = Option(new FilterList(filter)) + } + } + case LessThan(left: AttributeReference, right: Literal) => { + val keyColumn = 
keyColumns.find((p: KeyColumn) => p.sqlName.equals(left.name)) + val nonKeyColumn = nonKeyColumns.find((p: NonKeyColumn) => p.sqlName.equals(left.name)) + if (keyColumn.isDefined) { + val filter = new RowFilter(CompareFilter.CompareOp.LESS, + new BinaryComparator(Bytes.toBytes(right.value.toString))) + result = Option(new FilterList(filter)) + } else if (nonKeyColumn.isDefined) { + val column = nonKeyColumn.get + val filter = new SingleColumnValueFilter(Bytes.toBytes(column.family), + Bytes.toBytes(column.qualifier), + CompareFilter.CompareOp.LESS, + DataTypeUtils.getComparator(right)) + result = Option(new FilterList(filter)) + } + } + case LessThanOrEqual(left: AttributeReference, right: Literal) => { + val keyColumn = keyColumns.find((p: KeyColumn) => p.sqlName.equals(left.name)) + val nonKeyColumn = nonKeyColumns.find((p: NonKeyColumn) => p.sqlName.equals(left.name)) + if (keyColumn.isDefined) { + val filter = new RowFilter(CompareFilter.CompareOp.LESS_OR_EQUAL, + new BinaryComparator(Bytes.toBytes(right.value.toString))) + result = Option(new FilterList(filter)) + } else if (nonKeyColumn.isDefined) { + val column = nonKeyColumn.get + val filter = new SingleColumnValueFilter(Bytes.toBytes(column.family), + Bytes.toBytes(column.qualifier), + CompareFilter.CompareOp.LESS_OR_EQUAL, + DataTypeUtils.getComparator(right)) + result = Option(new FilterList(filter)) + } + } + } + } + result + } + + def buildPut(row: Row): Put = { + // TODO: revisit this using new KeyComposer + val rowKey: HBaseRawType = null + new Put(rowKey) + } + + def buildScan( + split: Partition, + filters: Option[FilterList], + projList: Seq[NamedExpression]): Scan = { + val hbPartition = split.asInstanceOf[HBasePartition] + val scan = { + (hbPartition.lowerBound, hbPartition.upperBound) match { + case (Some(lb), Some(ub)) => new Scan(lb, ub) + case (Some(lb), None) => new Scan(lb) + case _ => new Scan + } + } + if (filters.isDefined && !filters.get.getFilters.isEmpty) { + scan.setFilter(filters.get) + } + // TODO: add add Family to SCAN from projections + scan + } + + def buildGet(projList: Seq[NamedExpression], rowKey: HBaseRawType) { + new Get(rowKey) + // TODO: add columns to the Get + } + + // /** + // * Trait for RowKeyParser's that convert a raw array of bytes into their constituent + // * logical column values + // * + // */ + // trait AbstractRowKeyParser { + // + //// def createKey(rawBytes: Seq[HBaseRawType], version: Byte): HBaseRawType + //// + //// def parseRowKey(rowKey: HBaseRawType): Seq[HBaseRawType] + //// + //// def parseRowKeyWithMetaData(rkCols: Seq[KeyColumn], rowKey: HBaseRawType) + //// : SortedMap[TableName, (KeyColumn, Any)] // TODO change Any + // } + // + // case class RowKeySpec(offsets: Seq[Int], version: Byte = RowKeyParser.Version1) + // + // // TODO(Bo): replace the implementation with the null-byte terminated string logic + // object RowKeyParser extends AbstractRowKeyParser with Serializable { + // val Version1 = 1.toByte + // val VersionFieldLen = 1 + // // Length in bytes of the RowKey version field + // val DimensionCountLen = 1 + // // One byte for the number of key dimensions + // val MaxDimensions = 255 + // val OffsetFieldLen = 2 + // + // // Two bytes for the value of each dimension offset. + // // Therefore max size of rowkey is 65535. Note: if longer rowkeys desired in future + // // then simply define a new RowKey version to support it. Otherwise would be wasteful + // // to define as 4 bytes now. 
+ // def computeLength(keys: Seq[HBaseRawType]) = { + // VersionFieldLen + keys.map(_.length).sum + + // OffsetFieldLen * keys.size + DimensionCountLen + // } + // + // override def createKey(keys: Seq[HBaseRawType], version: Byte = Version1): HBaseRawType = { + // val barr = new Array[Byte](computeLength(keys)) + // val arrayx = new AtomicInteger(0) + // barr(arrayx.getAndAdd(VersionFieldLen)) = version // VersionByte + // + // // Remember the starting offset of first data value + // val valuesStartIndex = new AtomicInteger(arrayx.get) + // + // // copy each of the dimension values in turn + // keys.foreach { k => copyToArr(barr, k, arrayx.getAndAdd(k.length))} + // + // // Copy the offsets of each dim value + // // The valuesStartIndex is the location of the first data value and thus the first + // // value included in the Offsets sequence + // keys.foreach { k => + // copyToArr(barr, + // short2b(valuesStartIndex.getAndAdd(k.length).toShort), + // arrayx.getAndAdd(OffsetFieldLen)) + // } + // barr(arrayx.get) = keys.length.toByte // DimensionCountByte + // barr + // } + // + // def copyToArr[T](a: Array[T], b: Array[T], aoffset: Int) = { + // b.copyToArray(a, aoffset) + // } + // + // def short2b(sh: Short): Array[Byte] = { + // val barr = Array.ofDim[Byte](2) + // barr(0) = ((sh >> 8) & 0xff).toByte + // barr(1) = (sh & 0xff).toByte + // barr + // } + // + // def b2Short(barr: Array[Byte]) = { + // val out = (barr(0).toShort << 8) | barr(1).toShort + // out + // } + // + // def createKeyFromCatalystRow(schema: StructType, keyCols: Seq[KeyColumn], row: Row) = { + // // val rawKeyCols = DataTypeUtils.catalystRowToHBaseRawVals(schema, row, keyCols) + // // createKey(rawKeyCols) + // null + // } + // + // def getMinimumRowKeyLength = VersionFieldLen + DimensionCountLen + // + // override def parseRowKey(rowKey: HBaseRawType): Seq[HBaseRawType] = { + // assert(rowKey.length >= getMinimumRowKeyLength, + // s"RowKey is invalid format - less than minlen . Actual length=${rowKey.length}") + // assert(rowKey(0) == Version1, s"Only Version1 supported. 
Actual=${rowKey(0)}") + // val ndims: Int = rowKey(rowKey.length - 1).toInt + // val offsetsStart = rowKey.length - DimensionCountLen - ndims * OffsetFieldLen + // val rowKeySpec = RowKeySpec( + // for (dx <- 0 to ndims - 1) + // yield b2Short(rowKey.slice(offsetsStart + dx * OffsetFieldLen, + // offsetsStart + (dx + 1) * OffsetFieldLen)) + // ) + // + // val endOffsets = rowKeySpec.offsets.tail :+ (rowKey.length - DimensionCountLen - 1) + // val colsList = rowKeySpec.offsets.zipWithIndex.map { case (off, ix) => + // rowKey.slice(off, endOffsets(ix)) + // } + // colsList + // } + // + //// //TODO + //// override def parseRowKeyWithMetaData(rkCols: Seq[KeyColumn], rowKey: HBaseRawType): + //// SortedMap[TableName, (KeyColumn, Any)] = { + //// import scala.collection.mutable.HashMap + //// + //// val rowKeyVals = parseRowKey(rowKey) + //// val rmap = rowKeyVals.zipWithIndex.foldLeft(new HashMap[ColumnName, (Column, Any)]()) { + //// case (m, (cval, ix)) => + //// m.update(rkCols(ix).toColumnName, (rkCols(ix), + //// hbaseFieldToRowField(cval, rkCols(ix).dataType))) + //// m + //// } + //// TreeMap(rmap.toArray: _*)(Ordering.by { cn: ColumnName => rmap(cn)._1.ordinal}) + //// .asInstanceOf[SortedMap[ColumnName, (Column, Any)]] + //// } + // + // def show(bytes: Array[Byte]) = { + // val len = bytes.length + // // val out = s"Version=${bytes(0).toInt} NumDims=${bytes(len - 1)} " + // } + // + // } + + def buildRow(projections: Seq[(Attribute, Int)], + result: Result, + row: MutableRow, + bytesUtils: BytesUtils): Row = { + assert(projections.size == row.length, "Projection size and row size mismatched") + // TODO: replaced with the new Key method + val buffer = ListBuffer[HBaseRawType]() + val rowKeys = HBaseKVHelper.decodingRawKeyColumns(buffer, result.getRow, keyColumns) + projections.foreach { p => + columnMap.get(p._1.name).get match { + case column: NonKeyColumn => { + val colValue = result.getValue(column.familyRaw, column.qualifierRaw) + DataTypeUtils.setRowColumnFromHBaseRawType(row, p._2, colValue, + column.dataType, bytesUtils) + } + case ki => { + val keyIndex = ki.asInstanceOf[Int] + val rowKey = rowKeys(keyIndex) + DataTypeUtils.setRowColumnFromHBaseRawType(row, p._2, rowKey, + keyColumns(keyIndex).dataType, bytesUtils) + } + } + } + row + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/hbase/HBaseSQLReaderRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/hbase/HBaseSQLReaderRDD.scala new file mode 100644 index 0000000000000..23e53e2faaf08 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/hbase/HBaseSQLReaderRDD.scala @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hbase + +import org.apache.hadoop.hbase.client.Result +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, GenericMutableRow} +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.{InterruptibleIterator, Logging, Partition, TaskContext} + +class HBaseSQLReaderRDD( + relation: HBaseRelation, + output: Seq[Attribute], + rowKeyPred: Option[Expression], + valuePred: Option[Expression], + partitionPred: Option[Expression], + coprocSubPlan: Option[SparkPlan])(@transient sqlContext: SQLContext) + extends RDD[Row](sqlContext.sparkContext, Nil) with Logging { + + private final val cachingSize: Int = 100 // Todo: be made configurable + + override def getPartitions: Array[Partition] = { + relation.getPrunedPartitions(partitionPred).get.toArray + } + + override def getPreferredLocations(split: Partition): Seq[String] = { + split.asInstanceOf[HBasePartition].server.map { + identity + }.toSeq + } + + override def compute(split: Partition, context: TaskContext): Iterator[Row] = { + val filters = relation.buildFilter(output, rowKeyPred, valuePred) + val scan = relation.buildScan(split, filters, output) + scan.setCaching(cachingSize) + logDebug(s"relation.htable scanner conf=" + + s"${relation.htable.getConfiguration.get("hbase.zookeeper.property.clientPort")}") + val scanner = relation.htable.getScanner(scan) + + val row = new GenericMutableRow(output.size) + val projections = output.zipWithIndex + val bytesUtils = new BytesUtils + + var finished: Boolean = false + var gotNext: Boolean = false + var result: Result = null + + val iter = new Iterator[Row] { + override def hasNext: Boolean = { + if (!finished) { + if (!gotNext) { + result = scanner.next + finished = result == null + gotNext = true + } + } + if (finished) { + close + } + !finished + } + + override def next(): Row = { + if (hasNext) { + gotNext = false + relation.buildRow(projections, result, row, bytesUtils) + } else { + null + } + } + + def close() = { + try { + scanner.close() + } catch { + case e: Exception => logWarning("Exception in scanner.close", e) + } + } + } + new InterruptibleIterator(context, iter) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/hbase/IndexMappable.scala b/sql/core/src/main/scala/org/apache/spark/sql/hbase/IndexMappable.scala new file mode 100755 index 0000000000000..e2d5daac2f505 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/hbase/IndexMappable.scala @@ -0,0 +1,21 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.hbase + +private[hbase] trait IndexMappable { + val mappedIndex: Int +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/hbase/NotPusher.scala b/sql/core/src/main/scala/org/apache/spark/sql/hbase/NotPusher.scala new file mode 100755 index 0000000000000..3b6a0c40641d0 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/hbase/NotPusher.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hbase + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.rules._ + +/** + * Pushes NOT through And/Or + */ +object NOTPusher extends Rule[Expression] { + def apply(pred: Expression): Expression = pred transformDown { + case Not(And(left, right)) => Or(Not(left), Not(right)) + case Not(Or(left, right)) => And(Not(left), Not(right)) + case not @ Not(exp) => { + // This pattern has been caught by optimizer but after NOT pushdown + // more opportunities may present + exp match { + case GreaterThan(l, r) => LessThanOrEqual(l, r) + case GreaterThanOrEqual(l, r) => LessThan(l, r) + case LessThan(l, r) => GreaterThanOrEqual(l, r) + case LessThanOrEqual(l, r) => GreaterThan(l, r) + case Not(e) => e + case _ => not + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/hbase/PartialPredEval.scala b/sql/core/src/main/scala/org/apache/spark/sql/hbase/PartialPredEval.scala new file mode 100755 index 0000000000000..1f29f1cc3682e --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/hbase/PartialPredEval.scala @@ -0,0 +1,480 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hbase + +import org.apache.spark.sql.catalyst.errors.TreeNodeException +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.types.{DataType, NativeType} + + +object PartialPredicateOperations { + + // Partial evaluation is nullness-based, i.e., uninterested columns are assigned nulls, + // which necessitates changes of the null handling from the normal evaluations + // of predicate expressions + implicit class partialPredicateEvaluator(e: Expression) { + def partialEval(input: Row): Any = { + e match { + case And(left, right) => { + val l = left.partialEval(input) + if (l == false) { + false + } else { + val r = right.partialEval(input) + if (r == false) { + false + } else { + if (l != null && r != null) { + true + } else { + null + } + } + } + } + case Or(left, right) => { + val l = left.partialEval(input) + if (l == true) { + true + } else { + val r = right.partialEval(input) + if (r == true) { + true + } else { + if (l != null && r != null) { + false + } else { + null + } + } + } + } + case Not(child) => { + child.partialEval(input) match { + case null => null + case b: Boolean => !b + } + } + case In(value, list) => { + val evaluatedValue = value.partialEval(input) + if (evaluatedValue == null) { + null + } else { + val evaluatedList = list.map(_.partialEval(input)) + if (evaluatedList.exists(e=> e == evaluatedValue)) { + true + } else if (evaluatedList.exists(e=> e == null)) { + null + } else { + false + } + } + } + case InSet(value, hset) => { + val evaluatedValue = value.partialEval(input) + if (evaluatedValue == null) { + null + } else { + hset.contains(evaluatedValue) + } + } + case l: LeafExpression => l.eval(input) + case b: BoundReference => b.eval(input) //Really a LeafExpression but not declared as such + case n: NamedExpression => n.eval(input) //Really a LeafExpression but not declared as such + case IsNull(child) => { + if (child.partialEval(input) == null) { + // In partial evaluation, null indicates MAYBE + null + } else { + // Now we only support non-nullable primary key components + false + } + } + // TODO: CAST/Arithithmetic can be treated more nicely + case Cast(_, _) => null + // case BinaryArithmetic => null + case UnaryMinus(_) => null + case EqualTo(left, right) => { + val cmp = pc2(input, left, right) + if (cmp.isDefined) { + cmp.get == 0 + } else { + null + } + } + case LessThan(left, right) => { + val cmp = pc2(input, left, right) + if (cmp.isDefined) { + cmp.get < 0 + } else { + null + } + } + case LessThanOrEqual(left, right) => { + val cmp = pc2(input, left, right) + if (cmp.isDefined) { + cmp.get <= 0 + } else { + null + } + } + case GreaterThan(left, right) => { + val cmp = pc2(input, left, right) + if (cmp.isDefined) { + cmp.get > 0 + } else { + null + } + } + case GreaterThanOrEqual(left, right) => { + val cmp = pc2(input, left, right) + if (cmp.isDefined) { + cmp.get >= 0 + } else { + null + } + } + case If(predicate, trueE, falseE) => { + val v = predicate.partialEval(input) + if (v == null) { + null + } else if (v.asInstanceOf[Boolean]) { + trueE.partialEval(input) + } else { + falseE.partialEval(input) + } + } + case _ => null + } + } + + @inline + protected def pc2( + i: Row, + e1: Expression, + e2: Expression): Option[Int] = { + if (e1.dataType != e2.dataType) { + throw new TreeNodeException(e, s"Types do not match ${e1.dataType} != ${e2.dataType}") + } + + val evalE1 = e1.partialEval(i) + if (evalE1 == null) { + None + } else { + val evalE2 = e2.partialEval(i) + if (evalE2 
== null) { + None + } else { + e1.dataType match { + case nativeType: NativeType => { + val pdt = RangeType.primitiveToPODataTypeMap.get(nativeType).getOrElse(null) + if (pdt == null) { + sys.error(s"Type $i does not have corresponding partial ordered type") + } else { + pdt.partialOrdering.tryCompare( + pdt.toPartiallyOrderingDataType(evalE1, nativeType).asInstanceOf[pdt.JvmType], + pdt.toPartiallyOrderingDataType(evalE2, nativeType).asInstanceOf[pdt.JvmType]) + } + } + case other => sys.error(s"Type $other does not support partially ordered operations") + } + } + } + } + } + + // Partial reduction is nullness-based, i.e., uninterested columns are assigned nulls, + // which necessitates changes of the null handling from the normal evaluations + // of predicate expressions + // There are 3 possible results: TRUE, FALSE, and MAYBE represented by a predicate + // which will be used to further filter the results + implicit class partialPredicateReducer(e: Expression) { + def partialReduce(input: Row): Any = { + e match { + case And(left, right) => { + val l = left.partialReduce(input) + if (l == false) { + false + } else { + val r = right.partialReduce(input) + if (r == false) { + false + } else { + (l, r) match { + case (true, true) => true + case (true, _) => r + case (_, true) => l + case (nl: Expression, nr: Expression) => { + if ((nl fastEquals left) && (nr fastEquals right)) { + e + } else { + And(nl, nr) + } + } + case _ => sys.error("unexpected child type(s) in partial reduction") + } + } + } + } + case Or(left, right) => { + val l = left.partialReduce(input) + if (l == true) { + true + } else { + val r = right.partialReduce(input) + if (r == true) { + true + } else { + (l, r) match { + case (false, false) => false + case (false, _) => r + case (_, false) => l + case (nl: Expression, nr: Expression) => { + if ((nl fastEquals left) && (nr fastEquals right)) { + e + } else { + Or(nl, nr) + } + } + case _ => sys.error("unexpected child type(s) in partial reduction") + } + } + } + } + case Not(child) => { + child.partialReduce(input) match { + case b: Boolean => !b + case ec: Expression => if (ec fastEquals child) { e } else { Not(ec) } + } + } + case In(value, list) => { + val evaluatedValue = value.partialReduce(input) + if (evaluatedValue.isInstanceOf[Expression]) { + val evaluatedList = list.map(e=>e.partialReduce(input) match { + case e: Expression => e + case d => Literal(d, e.dataType) + }) + In(evaluatedValue.asInstanceOf[Expression], evaluatedList) + } else { + val evaluatedList = list.map(_.partialReduce(input)) + if (evaluatedList.exists(e=> e == evaluatedValue)) { + true + } else { + val newList = evaluatedList.filter(_.isInstanceOf[Expression]) + .map(_.asInstanceOf[Expression]) + if (newList.isEmpty) { + false + } else { + In(Literal(evaluatedValue, value.dataType), newList) + } + } + } + } + case InSet(value, hset) => { + val evaluatedValue = value.partialReduce(input) + if (evaluatedValue.isInstanceOf[Expression]) { + InSet(evaluatedValue.asInstanceOf[Expression], hset) + } else { + hset.contains(evaluatedValue) + } + } + case l: LeafExpression => { + val res = l.eval(input) + if (res == null) { l } else {res} + } + case b: BoundReference => { + val res = b.eval(input) + // If the result is a MAYBE, returns the original expression + if (res == null) { b } else {res} + } + case n: NamedExpression => { + val res = n.eval(input) + if(res == null) { n } else { res } + } + case IsNull(child) => e + // TODO: CAST/Arithithmetic could be treated more nicely + case Cast(_, _) => e 
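+        // NOTE (editorial, illustrative only; not part of the original patch): returning `e`
+        // for Cast (and for UnaryMinus below) marks the result as MAYBE, i.e. the expression
+        // itself is kept and later used to further filter the results rather than being decided
+        // from the partially populated row. For example, partially reducing
+        // `Cast(someKeyCol, LongType)` (a hypothetical column) simply returns the Cast unchanged.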
+ // case BinaryArithmetic => null + case UnaryMinus(_) => e + case EqualTo(left, right) => { + val evalL = left.partialReduce(input) + val evalR = right.partialReduce(input) + if (evalL.isInstanceOf[Expression] && evalR.isInstanceOf[Expression]) { + EqualTo(evalL.asInstanceOf[Expression], evalR.asInstanceOf[Expression]) + } else if (evalL.isInstanceOf[Expression]) { + EqualTo(evalL.asInstanceOf[Expression], right) + } else if (evalR.isInstanceOf[Expression]) { + EqualTo(left.asInstanceOf[Expression], evalR.asInstanceOf[Expression]) + } else { + val cmp = prc2(input, left.dataType, right.dataType, evalL, evalR) + if (cmp.isDefined) { + cmp.get == 0 + } else { + e + } + } + } + case LessThan(left, right) => { + val evalL = left.partialReduce(input) + val evalR = right.partialReduce(input) + if (evalL.isInstanceOf[Expression] && evalR.isInstanceOf[Expression]) { + EqualTo(evalL.asInstanceOf[Expression], evalR.asInstanceOf[Expression]) + } else if (evalL.isInstanceOf[Expression]) { + EqualTo(evalL.asInstanceOf[Expression], right) + } else if (evalR.isInstanceOf[Expression]) { + EqualTo(left, evalR.asInstanceOf[Expression]) + } else { + val cmp = prc2(input, left.dataType, right.dataType, evalL, evalR) + if (cmp.isDefined) { + cmp.get < 0 + } else { + e + } + } + } + case LessThanOrEqual(left, right) => { + val evalL = left.partialReduce(input) + val evalR = right.partialReduce(input) + if (evalL.isInstanceOf[Expression] && evalR.isInstanceOf[Expression]) { + EqualTo(evalL.asInstanceOf[Expression], evalR.asInstanceOf[Expression]) + } else if (evalL.isInstanceOf[Expression]) { + EqualTo(evalL.asInstanceOf[Expression], right) + } else if (evalR.isInstanceOf[Expression]) { + EqualTo(left, evalR.asInstanceOf[Expression]) + } else { + val cmp = prc2(input, left.dataType, right.dataType, evalL, evalR) + if (cmp.isDefined) { + cmp.get <= 0 + } else { + e + } + } + } + case GreaterThan(left, right) => { + val evalL = left.partialReduce(input) + val evalR = right.partialReduce(input) + if (evalL.isInstanceOf[Expression] && evalR.isInstanceOf[Expression]) { + EqualTo(evalL.asInstanceOf[Expression], evalR.asInstanceOf[Expression]) + } else if (evalL.isInstanceOf[Expression]) { + EqualTo(evalL.asInstanceOf[Expression], right) + } else if (evalR.isInstanceOf[Expression]) { + EqualTo(left, evalR.asInstanceOf[Expression]) + } else { + val cmp = prc2(input, left.dataType, right.dataType, evalL, evalR) + if (cmp.isDefined) { + cmp.get > 0 + } else { + e + } + } + } + case GreaterThanOrEqual(left, right) => { + val evalL = left.partialReduce(input) + val evalR = right.partialReduce(input) + if (evalL.isInstanceOf[Expression] && evalR.isInstanceOf[Expression]) { + EqualTo(evalL.asInstanceOf[Expression], evalR.asInstanceOf[Expression]) + } else if (evalL.isInstanceOf[Expression]) { + EqualTo(evalL.asInstanceOf[Expression], right) + } else if (evalR.isInstanceOf[Expression]) { + EqualTo(left, evalR.asInstanceOf[Expression]) + } else { + val cmp = prc2(input, left.dataType, right.dataType, evalL, evalR) + if (cmp.isDefined) { + cmp.get >= 0 + } else { + e + } + } + } + case If(predicate, trueE, falseE) => { + val v = predicate.partialReduce(input) + if (v.isInstanceOf[Expression]) { + If(v.asInstanceOf[Expression], + trueE.partialReduce(input).asInstanceOf[Expression], + falseE.partialReduce(input).asInstanceOf[Expression]) + } else if (v.asInstanceOf[Boolean]) { + trueE.partialReduce(input) + } else { + falseE.partialReduce(input) + } + } + case _ => e + } + } + + @inline + protected def pc2( + i: Row, + e1: 
Expression, + e2: Expression): Option[Int] = { + if (e1.dataType != e2.dataType) { + throw new TreeNodeException(e, s"Types do not match ${e1.dataType} != ${e2.dataType}") + } + + val evalE1 = e1.partialEval(i) + if (evalE1 == null) { + None + } else { + val evalE2 = e2.partialEval(i) + if (evalE2 == null) { + None + } else { + e1.dataType match { + case nativeType: NativeType => { + val pdt = RangeType.primitiveToPODataTypeMap.get(nativeType).getOrElse(null) + if (pdt == null) { + sys.error(s"Type $i does not have corresponding partial ordered type") + } else { + pdt.partialOrdering.tryCompare( + pdt.toPartiallyOrderingDataType(evalE1, nativeType).asInstanceOf[pdt.JvmType], + pdt.toPartiallyOrderingDataType(evalE2, nativeType).asInstanceOf[pdt.JvmType]) + } + } + case other => sys.error(s"Type $other does not support partially ordered operations") + } + } + } + } + + @inline + protected def prc2( + i: Row, + dataType1: DataType, + dataType2: DataType, + eval1: Any, + eval2: Any): Option[Int] = { + if (dataType1 != dataType2) { + throw new TreeNodeException(e, s"Types do not match ${dataType1} != ${dataType2}") + } + + dataType1 match { + case nativeType: NativeType => { + val pdt = RangeType.primitiveToPODataTypeMap.get(nativeType).getOrElse(null) + if (pdt == null) { + sys.error(s"Type $i does not have corresponding partial ordered type") + } else { + pdt.partialOrdering.tryCompare( + pdt.toPartiallyOrderingDataType(eval1, nativeType).asInstanceOf[pdt.JvmType], + pdt.toPartiallyOrderingDataType(eval2, nativeType).asInstanceOf[pdt.JvmType]) + } + } + case other => sys.error(s"Type $other does not support partially ordered operations") + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/hbase/PartiallyOrderingDataType.scala b/sql/core/src/main/scala/org/apache/spark/sql/hbase/PartiallyOrderingDataType.scala new file mode 100755 index 0000000000000..2c229d432dbaf --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/hbase/PartiallyOrderingDataType.scala @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.hbase + +import org.apache.spark.sql.catalyst.types._ +import scala.math.PartialOrdering +import scala.reflect.runtime.universe.TypeTag + +abstract class PartiallyOrderingDataType extends DataType { + private[sql] type JvmType + + def toPartiallyOrderingDataType(s: Any, dt: NativeType): Any + + @transient private[sql] val tag: TypeTag[JvmType] + + private[sql] val partialOrdering: PartialOrdering[JvmType] +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/hbase/RangeType.scala b/sql/core/src/main/scala/org/apache/spark/sql/hbase/RangeType.scala new file mode 100755 index 0000000000000..4f8c999493b83 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/hbase/RangeType.scala @@ -0,0 +1,211 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.hbase + +import java.sql.Timestamp + +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.types._ + +import scala.collection.immutable.HashMap +import scala.language.implicitConversions +import scala.math.PartialOrdering +import scala.reflect.runtime.universe.typeTag + +class Range[T](val start: Option[T], // None for open ends + val startInclusive: Boolean, + val end: Option[T], // None for open ends + val endInclusive: Boolean, + val dt: NativeType) { + require(dt != null && !(start.isDefined && end.isDefined && + ((dt.ordering.eq(start.get, end.get) && + (!startInclusive || !endInclusive)) || + (dt.ordering.gt(start.get.asInstanceOf[dt.JvmType], end.get.asInstanceOf[dt.JvmType])))), + "Inappropriate range parameters") +} + +// HBase ranges: +// @param +// id: partition id to be used to map to a HBase physical partition +class PartitionRange[T](start: Option[T], startInclusive: Boolean, + end: Option[T], endInclusive: Boolean, + val id: Int, dt: NativeType, var pred: Expression) + extends Range[T](start, startInclusive, end, endInclusive, dt) + +// A PointRange is a range of a single point. It is used for convenience when +// do comparison on two values of the same type. 
An alternatively would be to +// use multiple (overloaded) comparison methods, which could be more natural +// but also more codes + +//class PointRange[T](value: T, dt:NativeType) +// extends Range[T](Some(value), true, Some(value), true, dt) + + +class RangeType[T] extends PartiallyOrderingDataType { + private[sql] type JvmType = Range[T] + @transient private[sql] val tag = typeTag[JvmType] + + def toPartiallyOrderingDataType(s: Any, dt: NativeType): Any = s match { + case i: Int => new Range[Int](Some(i), true, Some(i), true, IntegerType) + case l: Long => new Range[Long](Some(l), true, Some(l), true, LongType) + case d: Double => new Range[Double](Some(d), true, Some(d), true, DoubleType) + case f: Float => new Range[Float](Some(f), true, Some(f), true, FloatType) + case b: Byte => new Range[Byte](Some(b), true, Some(b), true, ByteType) + case s: Short => new Range[Short](Some(s), true, Some(s), true, ShortType) + case s: String => new Range[String](Some(s), true, Some(s), true, StringType) + case b: Boolean => new Range[Boolean](Some(b), true, Some(b), true, BooleanType) + // todo: fix bigdecimal issue, now this will leads to comile error + case d: BigDecimal => new Range[BigDecimal](Some(d), true, Some(d), true, DecimalType.Unlimited) + case t: Timestamp => new Range[Timestamp](Some(t), true, Some(t), true, TimestampType) + case _ => s + } + + val partialOrdering = new PartialOrdering[JvmType] { + // Right now we just support comparisons between a range and a point + // In the future when more generic range comparisons, these two methods + // must be functional as expected + def tryCompare(a: JvmType, b: JvmType): Option[Int] = { + val aRange = a.asInstanceOf[Range[T]] + val aStartInclusive = aRange.startInclusive + val aStart = aRange.start.getOrElse(null).asInstanceOf[aRange.dt.JvmType] + val aEnd = aRange.end.getOrElse(null).asInstanceOf[aRange.dt.JvmType] + val aEndInclusive = aRange.endInclusive + val bRange = b.asInstanceOf[Range[T]] + val bStart = bRange.start.getOrElse(null).asInstanceOf[aRange.dt.JvmType] + val bEnd = bRange.end.getOrElse(null).asInstanceOf[aRange.dt.JvmType] + val bStartInclusive = bRange.startInclusive + val bEndInclusive = bRange.endInclusive + + // return 1 iff aStart > bEnd + // return 1 iff aStart = bEnd, aStartInclusive & bEndInclusive are not true at same position + if ((aStart != null && bEnd != null) + && (aRange.dt.ordering.gt(aStart, bEnd) + || (aRange.dt.ordering.equiv(aStart, bEnd) && !(aStartInclusive && bEndInclusive)))) { + Some(1) + } //Vice versa + else if ((bStart != null && aEnd != null) + && (aRange.dt.ordering.gt(bStart, aEnd) + || (aRange.dt.ordering.equiv(bStart, aEnd) && !(bStartInclusive && aEndInclusive)))) { + Some(-1) + } else if (aStart != null && aEnd != null && bStart != null && bEnd != null && + aRange.dt.ordering.equiv(bStart, aEnd) + && aRange.dt.ordering.equiv(aStart, aEnd) + && aRange.dt.ordering.equiv(bStart, bEnd) + && (aStartInclusive && aEndInclusive && bStartInclusive && bEndInclusive)) { + Some(0) + } else { + None + } + } + + def lteq(a: JvmType, b: JvmType): Boolean = { + // [(aStart, aEnd)] and [(bStart, bEnd)] + // [( and )] mean the possibilities of the inclusive and exclusive condition + val aRange = a.asInstanceOf[Range[T]] + val aStartInclusive = aRange.startInclusive + val aEnd = aRange.end.getOrElse(null) + val aEndInclusive = aRange.endInclusive + val bRange = b.asInstanceOf[Range[T]] + val bStart = bRange.start.getOrElse(null) + val bStartInclusive = bRange.startInclusive + val bEndInclusive = 
bRange.endInclusive + + // Compare two ranges, return true iff the upper bound of the lower range is lteq to + // the lower bound of the upper range. Because the exclusive boundary could be null, which + // means the boundary could be infinity, we need to further check this conditions. + val result = + (aStartInclusive, aEndInclusive, bStartInclusive, bEndInclusive) match { + // [(aStart, aEnd] compare to [bStart, bEnd)] + case (_, true, true, _) => { + if (aRange.dt.ordering.lteq(aEnd.asInstanceOf[aRange.dt.JvmType], + bStart.asInstanceOf[aRange.dt.JvmType])) { + true + } else { + false + } + } + // [(aStart, aEnd] compare to (bStart, bEnd)] + case (_, true, false, _) => { + if (bStart != null && aRange.dt.ordering.lteq(aEnd.asInstanceOf[aRange.dt.JvmType], + bStart.asInstanceOf[aRange.dt.JvmType])) { + true + } else { + false + } + } + // [(aStart, aEnd) compare to [bStart, bEnd)] + case (_, false, true, _) => { + if (aEnd != null && aRange.dt.ordering.lteq(aEnd.asInstanceOf[aRange.dt.JvmType], + bStart.asInstanceOf[aRange.dt.JvmType])) { + true + } else { + false + } + } + // [(aStart, aEnd) compare to (bStart, bEnd)] + case (_, false, false, _) => { + if (aEnd != null && bStart != null && + aRange.dt.ordering.lteq(aEnd.asInstanceOf[aRange.dt.JvmType], + bStart.asInstanceOf[aRange.dt.JvmType])) { + true + } else { + false + } + } + } + + result + } + } +} + +object RangeType { + + object StringRangeType extends RangeType[String] + + object IntegerRangeType extends RangeType[Int] + + object LongRangeType extends RangeType[Long] + + object DoubleRangeType extends RangeType[Double] + + object FloatRangeType extends RangeType[Float] + + object ByteRangeType extends RangeType[Byte] + + object ShortRangeType extends RangeType[Short] + + object BooleanRangeType extends RangeType[Boolean] + + object DecimalRangeType extends RangeType[BigDecimal] + + object TimestampRangeType extends RangeType[Timestamp] + + val primitiveToPODataTypeMap: HashMap[NativeType, PartiallyOrderingDataType] = + HashMap( + IntegerType -> IntegerRangeType, + LongType -> LongRangeType, + DoubleType -> DoubleRangeType, + FloatType -> FloatRangeType, + ByteType -> ByteRangeType, + ShortType -> ShortRangeType, + BooleanType -> BooleanRangeType, + DecimalType.Unlimited -> DecimalRangeType, + TimestampType -> TimestampRangeType, + StringType -> StringRangeType + ) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/hbase/ScanPredClassfier.scala b/sql/core/src/main/scala/org/apache/spark/sql/hbase/ScanPredClassfier.scala new file mode 100755 index 0000000000000..d0f312242bd99 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/hbase/ScanPredClassfier.scala @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hbase + +import org.apache.spark.sql.catalyst.expressions._ + +/** + * Classfies a predicate into a pair of (push-downable, non-push-downable) predicates + * for a Scan; the logic relationship between the two components of the pair is AND + */ +class ScanPredClassfier(relation: HBaseRelation, keyIndex: Int) { + def apply(pred: Expression): (Option[Expression], Option[Expression]) = { + // post-order bottom-up traversal + pred match { + case And(left, right) => { + val (ll, lr) = apply(left) + val (rl, rr) = apply(right) + (ll, lr, rl, rr) match { + // All Nones + case (None, None, None, None) => (None, None) + // Three Nones + case (None, None, None, _) => (None, rr) + case (None, None, _, None) => (rl, None) + case (None, _, None, None) => (None, lr) + case (_, None, None, None) => (ll, None) + // two Nones + case (None, None, _, _) => (rl, rr) + case (None, _, None, _) => (None, Some(And(rl.get, rr.get))) + case (None, _, _, None) => (rl, lr) + case (_, None, None, _) => (ll, rr) + case (_, None, _, None) => (Some(And(ll.get, rl.get)), None) + case (_, _, None, None) => (ll, lr) + // One None + case (None, _, _, _) => (rl, Some(And(lr.get, rr.get))) + case (_, None, _, _) => (Some(And(ll.get, rl.get)), rr) + case (_, _, None, _) => (ll, Some(And(lr.get, rr.get))) + case (_, _, _, None) => (Some(And(ll.get, rl.get)), lr) + // No nones + case _ => (Some(And(ll.get, rl.get)), Some(And(lr.get, rr.get))) + } + } + case Or(left, right) => { + val (ll, lr) = apply(left) + val (rl, rr) = apply(right) + (ll, lr, rl, rr) match { + // All Nones + case (None, None, None, None) => (None, None) + // Three Nones + case (None, None, None, _) => (None, rr) + case (None, None, _, None) => (rl, None) + case (None, _, None, None) => (None, lr) + case (_, None, None, None) => (ll, None) + // two Nones + case (None, None, _, _) => (rl, rr) + case (None, _, None, _) => (None, Some(Or(rl.get, rr.get))) + case (None, _, _, None) => (None, Some(Or(lr.get, rl.get))) + case (_, None, None, _) => (None, Some(Or(ll.get, rr.get))) + case (_, None, _, None) => (Some(Or(ll.get, rl.get)), None) + case (_, _, None, None) => (ll, lr) + // One None + case (None, _, _, _) => (None, Some(pred)) + // Accept increased evaluation complexity for improved pushed down + case (_, None, _, _) => (Some(Or(ll.get, rl.get)), Some(Or(ll.get, rr.get))) + case (_, _, None, _) => (None, Some(pred)) + // Accept increased evaluation complexity for improved pushed down + case (_, _, _, None) => (Some(Or(ll.get, rl.get)), Some(Or(lr.get, rl.get))) + // No nones + // Accept increased evaluation complexity for improved pushed down + case _ => (Some(Or(ll.get, rl.get)), Some(And(Or(ll.get, rr.get), + And(Or(lr.get, rl.get), Or(lr.get, rr.get))))) + } + } + case EqualTo(left, right) => classifyBinary(left, right, pred) + case LessThan(left, right) => classifyBinary(left, right, pred) + case LessThanOrEqual(left, right) => classifyBinary(left, right, pred) + case GreaterThan(left, right) => classifyBinary(left, right, pred) + case GreaterThanOrEqual(left, right) => classifyBinary(left, right, pred) + // everything else are treated as non pushdownable + case _ => (None, Some(pred)) + } + } + + // returns true if the binary operator of the two args can be pushed down + private def classifyBinary(left: Expression, right: Expression, pred: Expression) + : (Option[Expression], Option[Expression]) = { + (left, right) match { + case (Literal(_, _), AttributeReference(_, _, _, _)) => { + if 
(relation.isNonKey(right.asInstanceOf[AttributeReference])) {
+        (Some(pred), None)
+      } else {
+        val keyIdx = relation.keyIndex(right.asInstanceOf[AttributeReference])
+        if (keyIdx == keyIndex) {
+          (Some(pred), None)
+        } else {
+          (None, Some(pred))
+        }
+      }
+    }
+      case (AttributeReference(_, _, _, _), Literal(_, _)) => {
+        if (relation.isNonKey(left.asInstanceOf[AttributeReference])) {
+          (Some(pred), None)
+        } else {
+          val keyIdx = relation.keyIndex(left.asInstanceOf[AttributeReference])
+          if (keyIdx == keyIndex) {
+            (Some(pred), None)
+          } else {
+            (None, Some(pred))
+          }
+        }
+      }
+      case _ => (None, Some(pred))
+    }
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/hbase/hbase.scala b/sql/core/src/main/scala/org/apache/spark/sql/hbase/hbase.scala
new file mode 100644
index 0000000000000..acea2594d1f27
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/hbase/hbase.scala
@@ -0,0 +1,107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hbase
+
+import org.apache.spark.Logging
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.catalyst.expressions.{And, Attribute, Expression, Row}
+import org.apache.spark.sql.catalyst.types.StructType
+import org.apache.spark.sql.sources.{BaseRelation, CatalystScan, RelationProvider}
+
+/**
+ * Allows creation of HBase based tables using the syntax
+ * `CREATE TEMPORARY TABLE table_name(field1 field1_type, field2 field2_type...)
+ * USING org.apache.spark.sql.hbase
+ * OPTIONS (
+ *   hbase_table "hbase_table_name",
+ *   mapping "field1=cf1.column1, field2=cf2.column2...",
+ *   primary_key "field_name1, field_name2"
+ * )`.
+ */
+class DefaultSource extends RelationProvider with Logging {
+  /** Returns a new base relation with the given parameters.
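+   *
+   * Illustrative usage only (table and column names below are hypothetical and not part of
+   * the original patch); the options follow the syntax documented on the class above:
+   * {{{
+   *   CREATE TEMPORARY TABLE logs(id INT, msg STRING)
+   *   USING org.apache.spark.sql.hbase
+   *   OPTIONS (
+   *     hbase_table "hbase_logs",
+   *     mapping "msg=cf1.msg",
+   *     primary_key "id"
+   *   )
+   * }}}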
*/ + override def createRelation( + sqlContext: SQLContext, + parameters: Map[String, String], + schema: Option[StructType]): BaseRelation = { + + assert(schema.nonEmpty, "schema can not be empty for hbase rouce!") + assert(parameters.get("hbase_table").nonEmpty, "no option for hbase.table") + assert(parameters.get("mapping").nonEmpty, "no option for mapping") + assert(parameters.get("primary_key").nonEmpty, "no option for mapping") + + val hbaseTableName = parameters.getOrElse("hbase_table", "").toLowerCase + val mapping = parameters.getOrElse("mapping", "").toLowerCase + val primaryKey = parameters.getOrElse("primary_key", "").toLowerCase() + val partValue = "([^=]+)=([^=]+)".r + + val fieldByHbaseColumn = mapping.split(",").map { + case partValue(key, value) => (key, value) + } + val keyColumns = primaryKey.split(",").map(_.trim) + + // check the mapping is legal + val fieldSet = schema.get.fields.map(_.name).toSet + fieldByHbaseColumn.iterator.map(_._1).foreach { field => + assert(fieldSet.contains(field), s"no field named $field in table") + } + HBaseScanBuilder("", hbaseTableName, keyColumns, fieldByHbaseColumn, schema.get)(sqlContext) + } +} + +@DeveloperApi +case class HBaseScanBuilder( + tableName: String, + hbaseTableName: String, + keyColumns: Seq[String], + fieldByHbaseColumn: Seq[(String, String)], + schema: StructType)(context: SQLContext) extends CatalystScan with Logging { + + val hbaseMetadata = new HBaseMetadata + + val filedByHbaseFamilyAndColumn = fieldByHbaseColumn.toMap + + def allColumns() = schema.fields.map{ field => + val fieldName = field.name + if(keyColumns.contains(fieldName)) { + KeyColumn(fieldName, field.dataType, keyColumns.indexOf(fieldName)) + } else { + val familyAndQuilifier = filedByHbaseFamilyAndColumn.getOrElse(fieldName, "").split("\\.") + assert(familyAndQuilifier.size == 2, "illegal mapping") + NonKeyColumn(fieldName, field.dataType, familyAndQuilifier(0), familyAndQuilifier(1)) + } + } + + val relation = hbaseMetadata.createTable(tableName, hbaseTableName, allColumns) + + override def sqlContext: SQLContext = context + + // todo: optimization for predict push down + override def buildScan(output: Seq[Attribute], predicates: Seq[Expression]): RDD[Row] = { + new HBaseSQLReaderRDD( + relation, + schema.toAttributes, + None, + None, + predicates.reduceLeftOption(And), + None + )(sqlContext) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/hbase/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/hbase/package.scala new file mode 100755 index 0000000000000..28ecf9560497d --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/hbase/package.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql + +import org.apache.hadoop.hbase.KeyValue +import org.apache.hadoop.hbase.client.Put +import org.apache.hadoop.hbase.io.ImmutableBytesWritable + +import scala.collection.mutable.ArrayBuffer + +package object hbase { + type HBaseRawType = Array[Byte] + + class ImmutableBytesWritableWrapper(rowKey: Array[Byte]) + extends Serializable { + + def compareTo(that: ImmutableBytesWritableWrapper): Int = { + this.toImmutableBytesWritable() compareTo that.toImmutableBytesWritable() + } + + def toImmutableBytesWritable() = new ImmutableBytesWritable(rowKey) + } + + class PutWrapper(rowKey: Array[Byte]) extends Serializable { + val fqv = new ArrayBuffer[(Array[Byte], Array[Byte], Array[Byte])] + + def add(family: Array[Byte], qualifier: Array[Byte], value: Array[Byte]) = + fqv += ((family, qualifier, value)) + + def toPut() = { + val put = new Put(rowKey) + fqv.foreach { fqv => + put.add(fqv._1, fqv._2, fqv._3) + } + put + } + } + + class KeyValueWrapper( + rowKey: Array[Byte], + family: Array[Byte], + qualifier: Array[Byte], + value: Array[Byte]) extends Serializable { + + def toKeyValue() = new KeyValue(rowKey, family, qualifier, value) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala index fc70c183437f6..f2796161536fa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala @@ -18,31 +18,38 @@ package org.apache.spark.sql.json import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.catalyst.types.StructType import org.apache.spark.sql.sources._ private[sql] class DefaultSource extends RelationProvider { /** Returns a new base relation with the given parameters. 
*/ override def createRelation( sqlContext: SQLContext, - parameters: Map[String, String]): BaseRelation = { + parameters: Map[String, String], + schema: Option[StructType]): BaseRelation = { val fileName = parameters.getOrElse("path", sys.error("Option 'path' not specified")) val samplingRatio = parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0) - JSONRelation(fileName, samplingRatio)(sqlContext) + JSONRelation(fileName, samplingRatio, schema)(sqlContext) } } -private[sql] case class JSONRelation(fileName: String, samplingRatio: Double)( +private[sql] case class JSONRelation( + fileName: String, + samplingRatio: Double, + _schema: Option[StructType])( @transient val sqlContext: SQLContext) extends TableScan { private def baseRDD = sqlContext.sparkContext.textFile(fileName) override val schema = + _schema.getOrElse( JsonRDD.inferSchema( baseRDD, samplingRatio, sqlContext.columnNameOfCorruptRecord) + ) override def buildScan() = JsonRDD.jsonStringToRow(baseRDD, schema, sqlContext.columnNameOfCorruptRecord) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala index 9b89c3bfb3307..3247621c47297 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala @@ -22,21 +22,20 @@ import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} import org.apache.hadoop.conf.{Configurable, Configuration} import org.apache.hadoop.io.Writable import org.apache.hadoop.mapreduce.{JobContext, InputSplit, Job} - import parquet.hadoop.ParquetInputFormat import parquet.hadoop.util.ContextUtil import org.apache.spark.annotation.DeveloperApi import org.apache.spark.{Partition => SparkPartition, Logging} import org.apache.spark.rdd.{NewHadoopPartition, RDD} - -import org.apache.spark.sql.{SQLConf, Row, SQLContext} -import org.apache.spark.sql.catalyst.expressions.{SpecificMutableRow, And, Expression, Attribute} -import org.apache.spark.sql.catalyst.types.{IntegerType, StructField, StructType} +import org.apache.spark.sql.catalyst.expressions.{Row, And, SpecificMutableRow, Expression, Attribute} +import org.apache.spark.sql.catalyst.types.{StructField, IntegerType, StructType} import org.apache.spark.sql.sources._ +import org.apache.spark.sql.{SQLConf, SQLContext} import scala.collection.JavaConversions._ + /** * Allows creation of parquet based tables using the syntax * `CREATE TEMPORARY TABLE ... USING org.apache.spark.sql.parquet`. Currently the only option @@ -47,11 +46,12 @@ class DefaultSource extends RelationProvider { /** Returns a new base relation with the given parameters. */ override def createRelation( sqlContext: SQLContext, - parameters: Map[String, String]): BaseRelation = { + parameters: Map[String, String], + schema: Option[StructType]): BaseRelation = { val path = parameters.getOrElse("path", sys.error("'path' must be specified for parquet tables.")) - ParquetRelation2(path)(sqlContext) + ParquetRelation2(path, schema)(sqlContext) } } @@ -81,7 +81,9 @@ private[parquet] case class Partition(partitionValues: Map[String, Any], files: * discovery. 
*/ @DeveloperApi -case class ParquetRelation2(path: String)(@transient val sqlContext: SQLContext) +case class ParquetRelation2( + path: String, + _schema: Option[StructType])(@transient val sqlContext: SQLContext) extends CatalystScan with Logging { def sparkContext = sqlContext.sparkContext @@ -132,12 +134,13 @@ case class ParquetRelation2(path: String)(@transient val sqlContext: SQLContext) override val sizeInBytes = partitions.flatMap(_.files).map(_.getLen).sum - val dataSchema = StructType.fromAttributes( // TODO: Parquet code should not deal with attributes. - ParquetTypesConverter.readSchemaFromFile( - partitions.head.files.head.getPath, - Some(sparkContext.hadoopConfiguration), - sqlContext.isParquetBinaryAsString)) - + val dataSchema = _schema.getOrElse( + StructType.fromAttributes( // TODO: Parquet code should not deal with attributes. + ParquetTypesConverter.readSchemaFromFile( + partitions.head.files.head.getPath, + Some(sparkContext.hadoopConfiguration), + sqlContext.isParquetBinaryAsString)) + ) val dataIncludesKey = partitionKeys.headOption.map(dataSchema.fieldNames.contains(_)).getOrElse(true) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala index ca510cb0b07e3..6b3dc79451a7e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala @@ -17,16 +17,15 @@ package org.apache.spark.sql.sources -import org.apache.spark.Logging -import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.execution.RunnableCommand -import org.apache.spark.util.Utils - import scala.language.implicitConversions -import scala.util.parsing.combinator.lexical.StdLexical import scala.util.parsing.combinator.syntactical.StandardTokenParsers import scala.util.parsing.combinator.PackratParsers +import org.apache.spark.Logging +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.execution.RunnableCommand +import org.apache.spark.util.Utils import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.SqlLexical @@ -49,6 +48,15 @@ private[sql] class DDLParser extends StandardTokenParsers with PackratParsers wi protected implicit def asParser(k: Keyword): Parser[String] = lexical.allCaseVersions(k.str).map(x => x : Parser[String]).reduce(_ | _) + protected val STRING = Keyword("STRING") + protected val SHORT = Keyword("SHORT") + protected val DOUBLE = Keyword("DOUBLE") + protected val BOOLEAN = Keyword("BOOLEAN") + protected val BYTE = Keyword("BYTE") + protected val FLOAT = Keyword("FLOAT") + protected val INT = Keyword("INT") + protected val INTEGER = Keyword("INTEGER") + protected val LONG = Keyword("LONG") protected val CREATE = Keyword("CREATE") protected val TEMPORARY = Keyword("TEMPORARY") protected val TABLE = Keyword("TABLE") @@ -67,16 +75,35 @@ private[sql] class DDLParser extends StandardTokenParsers with PackratParsers wi protected lazy val ddl: Parser[LogicalPlan] = createTable /** - * CREATE TEMPORARY TABLE avroTable + * `CREATE TEMPORARY TABLE avroTable * USING org.apache.spark.sql.avro - * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro") + * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")` + * or + * `CREATE TEMPORARY TABLE avroTable(intField int, stringField string...) 
+ * USING org.apache.spark.sql.avro + * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")` */ protected lazy val createTable: Parser[LogicalPlan] = - CREATE ~ TEMPORARY ~ TABLE ~> ident ~ (USING ~> className) ~ (OPTIONS ~> options) ^^ { + ( CREATE ~ TEMPORARY ~ TABLE ~> ident ~ (USING ~> className) ~ (OPTIONS ~> options) ^^ { case tableName ~ provider ~ opts => - CreateTableUsing(tableName, provider, opts) + CreateTableUsing(tableName, Seq.empty, provider, opts) + } + | + CREATE ~ TEMPORARY ~ TABLE ~> ident + ~ tableCols ~ (USING ~> className) ~ (OPTIONS ~> options) ^^ { + case tableName ~ tableColumns ~ provider ~ opts => + CreateTableUsing(tableName, tableColumns, provider, opts) + } + ) + protected lazy val tableCol: Parser[(String, String)] = + ident ~ (STRING | BYTE | SHORT | INT | INTEGER | LONG | FLOAT | DOUBLE | BOOLEAN) ^^ { + case e1 ~ e2 => (e1, e2) } + protected lazy val tableCols: Parser[Seq[(String, String)]] = + "(" ~> repsep(tableCol, ",") <~ ")" + + protected lazy val options: Parser[Map[String, String]] = "(" ~> repsep(pair, ",") <~ ")" ^^ { case s: Seq[(String, String)] => s.toMap } @@ -87,6 +114,7 @@ private[sql] class DDLParser extends StandardTokenParsers with PackratParsers wi private[sql] case class CreateTableUsing( tableName: String, + tableCols: Seq[(String, String)], provider: String, options: Map[String, String]) extends RunnableCommand { @@ -100,9 +128,32 @@ private[sql] case class CreateTableUsing( } } val dataSource = clazz.newInstance().asInstanceOf[org.apache.spark.sql.sources.RelationProvider] - val relation = dataSource.createRelation(sqlContext, options) + val relation = dataSource.createRelation(sqlContext, options, toSchema(tableCols)) sqlContext.baseRelationToSchemaRDD(relation).registerTempTable(tableName) Seq.empty } + + def toSchema(tableColumns: Seq[(String, String)]): Option[StructType] = { + val fields: Seq[StructField] = tableColumns.map { tableColumn => + val columnName = tableColumn._1 + val columnType = tableColumn._2 + // todo: support more complex data type + columnType.toLowerCase match { + case "string" => StructField(columnName, StringType) + case "byte" => StructField(columnName, ByteType) + case "short" => StructField(columnName, ShortType) + case "int" => StructField(columnName, IntegerType) + case "integer" => StructField(columnName, IntegerType) + case "long" => StructField(columnName, LongType) + case "double" => StructField(columnName, DoubleType) + case "float" => StructField(columnName, FloatType) + case "boolean" => StructField(columnName, BooleanType) + } + } + if (fields.isEmpty) { + return None + } + Some(StructType(fields)) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index 939b4e15163a6..7dade36cbef75 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -37,7 +37,10 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, Attribute} @DeveloperApi trait RelationProvider { /** Returns a new base relation with the given parameters. 
   */
-  def createRelation(sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation
+  def createRelation(
+      sqlContext: SQLContext,
+      parameters: Map[String, String],
+      schema: Option[StructType]): BaseRelation
 }

 /**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala
index 939b3c0c66de7..8aa55e2113f36 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala
@@ -24,18 +24,26 @@ import org.apache.spark.sql._
 class FilteredScanSource extends RelationProvider {
   override def createRelation(
       sqlContext: SQLContext,
-      parameters: Map[String, String]): BaseRelation = {
-    SimpleFilteredScan(parameters("from").toInt, parameters("to").toInt)(sqlContext)
+      parameters: Map[String, String],
+      schema: Option[StructType]): BaseRelation = {
+    SimpleFilteredScan(
+      parameters("from").toInt,
+      parameters("to").toInt,
+      schema: Option[StructType])(sqlContext)
   }
 }

-case class SimpleFilteredScan(from: Int, to: Int)(@transient val sqlContext: SQLContext)
+case class SimpleFilteredScan(
+    from: Int,
+    to: Int,
+    _schema: Option[StructType])(@transient val sqlContext: SQLContext)
   extends PrunedFilteredScan {

-  override def schema =
+  override def schema = _schema.getOrElse(
     StructType(
       StructField("a", IntegerType, nullable = false) ::
       StructField("b", IntegerType, nullable = false) :: Nil)
+  )

   override def buildScan(requiredColumns: Array[String], filters: Array[Filter]) = {
     val rowBuilders = requiredColumns.map {
@@ -80,85 +88,97 @@ class FilteredScanSuite extends DataSourceTest {
         |  to '10'
         |)
       """.stripMargin)
+
+    sql(
+      """
+        |CREATE TEMPORARY TABLE oneToTenFiltered_with_schema(a int, b int)
+        |USING org.apache.spark.sql.sources.FilteredScanSource
+        |OPTIONS (
+        |  from '1',
+        |  to '10'
+        |)
+      """.stripMargin)
   }

+  Seq("oneToTenFiltered", "oneToTenFiltered_with_schema").foreach { table =>
-  sqlTest(
-    "SELECT * FROM oneToTenFiltered",
-    (1 to 10).map(i => Row(i, i * 2)).toSeq)
+    sqlTest(
+      s"SELECT * FROM $table",
+      (1 to 10).map(i => Row(i, i * 2)).toSeq)

-  sqlTest(
-    "SELECT a, b FROM oneToTenFiltered",
-    (1 to 10).map(i => Row(i, i * 2)).toSeq)
+    sqlTest(
+      s"SELECT a, b FROM $table",
+      (1 to 10).map(i => Row(i, i * 2)).toSeq)

-  sqlTest(
-    "SELECT b, a FROM oneToTenFiltered",
-    (1 to 10).map(i => Row(i * 2, i)).toSeq)
+    sqlTest(
+      s"SELECT b, a FROM $table",
+      (1 to 10).map(i => Row(i * 2, i)).toSeq)

-  sqlTest(
-    "SELECT a FROM oneToTenFiltered",
-    (1 to 10).map(i => Row(i)).toSeq)
+    sqlTest(
+      s"SELECT a FROM $table",
+      (1 to 10).map(i => Row(i)).toSeq)

-  sqlTest(
-    "SELECT b FROM oneToTenFiltered",
-    (1 to 10).map(i => Row(i * 2)).toSeq)
+    sqlTest(
+      s"SELECT b FROM $table",
+      (1 to 10).map(i => Row(i * 2)).toSeq)

-  sqlTest(
-    "SELECT a * 2 FROM oneToTenFiltered",
-    (1 to 10).map(i => Row(i * 2)).toSeq)
+    sqlTest(
+      s"SELECT a * 2 FROM $table",
+      (1 to 10).map(i => Row(i * 2)).toSeq)

-  sqlTest(
-    "SELECT A AS b FROM oneToTenFiltered",
-    (1 to 10).map(i => Row(i)).toSeq)
+    sqlTest(
+      s"SELECT A AS b FROM $table",
+      (1 to 10).map(i => Row(i)).toSeq)

-  sqlTest(
-    "SELECT x.b, y.a FROM oneToTenFiltered x JOIN oneToTenFiltered y ON x.a = y.b",
-    (1 to 5).map(i => Row(i * 4, i)).toSeq)
+    sqlTest(
+      s"SELECT x.b, y.a FROM $table x JOIN $table y ON x.a = y.b",
+      (1 to 5).map(i => Row(i * 4, i)).toSeq)

-  sqlTest(
-    "SELECT x.a, y.b FROM oneToTenFiltered x JOIN oneToTenFiltered y ON x.a = y.b",
-    (2 to 10 by 2).map(i => Row(i, i)).toSeq)
+    sqlTest(
+      s"SELECT x.a, y.b FROM $table x JOIN $table y ON x.a = y.b",
+      (2 to 10 by 2).map(i => Row(i, i)).toSeq)

-  sqlTest(
-    "SELECT * FROM oneToTenFiltered WHERE a = 1",
-    Seq(1).map(i => Row(i, i * 2)).toSeq)
+    sqlTest(
+      s"SELECT * FROM $table WHERE a = 1",
+      Seq(1).map(i => Row(i, i * 2)).toSeq)

-  sqlTest(
-    "SELECT * FROM oneToTenFiltered WHERE a IN (1,3,5)",
-    Seq(1,3,5).map(i => Row(i, i * 2)).toSeq)
+    sqlTest(
+      s"SELECT * FROM $table WHERE a IN (1,3,5)",
+      Seq(1,3,5).map(i => Row(i, i * 2)).toSeq)

-  sqlTest(
-    "SELECT * FROM oneToTenFiltered WHERE A = 1",
-    Seq(1).map(i => Row(i, i * 2)).toSeq)
+    sqlTest(
+      s"SELECT * FROM $table WHERE A = 1",
+      Seq(1).map(i => Row(i, i * 2)).toSeq)

-  sqlTest(
-    "SELECT * FROM oneToTenFiltered WHERE b = 2",
-    Seq(1).map(i => Row(i, i * 2)).toSeq)
+    sqlTest(
+      s"SELECT * FROM $table WHERE b = 2",
+      Seq(1).map(i => Row(i, i * 2)).toSeq)

-  testPushDown("SELECT * FROM oneToTenFiltered WHERE A = 1", 1)
-  testPushDown("SELECT a FROM oneToTenFiltered WHERE A = 1", 1)
-  testPushDown("SELECT b FROM oneToTenFiltered WHERE A = 1", 1)
-  testPushDown("SELECT a, b FROM oneToTenFiltered WHERE A = 1", 1)
-  testPushDown("SELECT * FROM oneToTenFiltered WHERE a = 1", 1)
-  testPushDown("SELECT * FROM oneToTenFiltered WHERE 1 = a", 1)
+    testPushDown(s"SELECT * FROM $table WHERE A = 1", 1)
+    testPushDown(s"SELECT a FROM $table WHERE A = 1", 1)
+    testPushDown(s"SELECT b FROM $table WHERE A = 1", 1)
+    testPushDown(s"SELECT a, b FROM $table WHERE A = 1", 1)
+    testPushDown(s"SELECT * FROM $table WHERE a = 1", 1)
+    testPushDown(s"SELECT * FROM $table WHERE 1 = a", 1)

-  testPushDown("SELECT * FROM oneToTenFiltered WHERE a > 1", 9)
-  testPushDown("SELECT * FROM oneToTenFiltered WHERE a >= 2", 9)
+    testPushDown(s"SELECT * FROM $table WHERE a > 1", 9)
+    testPushDown(s"SELECT * FROM $table WHERE a >= 2", 9)

-  testPushDown("SELECT * FROM oneToTenFiltered WHERE 1 < a", 9)
-  testPushDown("SELECT * FROM oneToTenFiltered WHERE 2 <= a", 9)
+    testPushDown(s"SELECT * FROM $table WHERE 1 < a", 9)
+    testPushDown(s"SELECT * FROM $table WHERE 2 <= a", 9)

-  testPushDown("SELECT * FROM oneToTenFiltered WHERE 1 > a", 0)
-  testPushDown("SELECT * FROM oneToTenFiltered WHERE 2 >= a", 2)
+    testPushDown(s"SELECT * FROM $table WHERE 1 > a", 0)
+    testPushDown(s"SELECT * FROM $table WHERE 2 >= a", 2)

-  testPushDown("SELECT * FROM oneToTenFiltered WHERE a < 1", 0)
-  testPushDown("SELECT * FROM oneToTenFiltered WHERE a <= 2", 2)
+    testPushDown(s"SELECT * FROM $table WHERE a < 1", 0)
+    testPushDown(s"SELECT * FROM $table WHERE a <= 2", 2)

-  testPushDown("SELECT * FROM oneToTenFiltered WHERE a > 1 AND a < 10", 8)
+    testPushDown(s"SELECT * FROM $table WHERE a > 1 AND a < 10", 8)

-  testPushDown("SELECT * FROM oneToTenFiltered WHERE a IN (1,3,5)", 3)
+    testPushDown(s"SELECT * FROM $table WHERE a IN (1,3,5)", 3)

-  testPushDown("SELECT * FROM oneToTenFiltered WHERE a = 20", 0)
-  testPushDown("SELECT * FROM oneToTenFiltered WHERE b = 1", 10)
+    testPushDown(s"SELECT * FROM $table WHERE a = 20", 0)
+    testPushDown(s"SELECT * FROM $table WHERE b = 1", 10)
+  }

   def testPushDown(sqlString: String, expectedCount: Int): Unit = {
     test(s"PushDown Returns $expectedCount: $sqlString") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala
index fee2e22611cdc..a2b9199ea90cf 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala
@@ -22,18 +22,23 @@ import org.apache.spark.sql._
 class PrunedScanSource extends RelationProvider {
   override def createRelation(
       sqlContext: SQLContext,
-      parameters: Map[String, String]): BaseRelation = {
-    SimplePrunedScan(parameters("from").toInt, parameters("to").toInt)(sqlContext)
+      parameters: Map[String, String],
+      schema: Option[StructType]): BaseRelation = {
+    SimplePrunedScan(parameters("from").toInt, parameters("to").toInt, schema)(sqlContext)
   }
 }

-case class SimplePrunedScan(from: Int, to: Int)(@transient val sqlContext: SQLContext)
+case class SimplePrunedScan(
+    from: Int,
+    to: Int,
+    _schema: Option[StructType])(@transient val sqlContext: SQLContext)
   extends PrunedScan {

-  override def schema =
+  override def schema = _schema.getOrElse(
     StructType(
       StructField("a", IntegerType, nullable = false) ::
       StructField("b", IntegerType, nullable = false) :: Nil)
+  )

   override def buildScan(requiredColumns: Array[String]) = {
     val rowBuilders = requiredColumns.map {
@@ -59,54 +64,66 @@ class PrunedScanSuite extends DataSourceTest {
         |  to '10'
         |)
       """.stripMargin)
-  }
-
-  sqlTest(
-    "SELECT * FROM oneToTenPruned",
-    (1 to 10).map(i => Row(i, i * 2)).toSeq)
-
-  sqlTest(
-    "SELECT a, b FROM oneToTenPruned",
-    (1 to 10).map(i => Row(i, i * 2)).toSeq)
-
-  sqlTest(
-    "SELECT b, a FROM oneToTenPruned",
-    (1 to 10).map(i => Row(i * 2, i)).toSeq)
-
-  sqlTest(
-    "SELECT a FROM oneToTenPruned",
-    (1 to 10).map(i => Row(i)).toSeq)
-
-  sqlTest(
-    "SELECT a, a FROM oneToTenPruned",
-    (1 to 10).map(i => Row(i, i)).toSeq)
-
-  sqlTest(
-    "SELECT b FROM oneToTenPruned",
-    (1 to 10).map(i => Row(i * 2)).toSeq)
-
-  sqlTest(
-    "SELECT a * 2 FROM oneToTenPruned",
-    (1 to 10).map(i => Row(i * 2)).toSeq)
-
-  sqlTest(
-    "SELECT A AS b FROM oneToTenPruned",
-    (1 to 10).map(i => Row(i)).toSeq)
-
-  sqlTest(
-    "SELECT x.b, y.a FROM oneToTenPruned x JOIN oneToTenPruned y ON x.a = y.b",
-    (1 to 5).map(i => Row(i * 4, i)).toSeq)
-  sqlTest(
-    "SELECT x.a, y.b FROM oneToTenPruned x JOIN oneToTenPruned y ON x.a = y.b",
-    (2 to 10 by 2).map(i => Row(i, i)).toSeq)
+    sql(
+      """
+        |CREATE TEMPORARY TABLE oneToTenPruned_with_schema(a int, b int)
+        |USING org.apache.spark.sql.sources.PrunedScanSource
+        |OPTIONS (
+        |  from '1',
+        |  to '10'
+        |)
+      """.stripMargin)
+  }

-  testPruning("SELECT * FROM oneToTenPruned", "a", "b")
-  testPruning("SELECT a, b FROM oneToTenPruned", "a", "b")
-  testPruning("SELECT b, a FROM oneToTenPruned", "b", "a")
-  testPruning("SELECT b, b FROM oneToTenPruned", "b")
-  testPruning("SELECT a FROM oneToTenPruned", "a")
-  testPruning("SELECT b FROM oneToTenPruned", "b")
+  Seq("oneToTenPruned", "oneToTenPruned_with_schema").foreach { table =>
+    sqlTest(
+      s"SELECT * FROM $table",
+      (1 to 10).map(i => Row(i, i * 2)).toSeq)
+
+    sqlTest(
+      s"SELECT a, b FROM $table",
+      (1 to 10).map(i => Row(i, i * 2)).toSeq)
+
+    sqlTest(
+      s"SELECT b, a FROM $table",
+      (1 to 10).map(i => Row(i * 2, i)).toSeq)
+
+    sqlTest(
+      s"SELECT a FROM $table",
+      (1 to 10).map(i => Row(i)).toSeq)
+
+    sqlTest(
+      s"SELECT a, a FROM $table",
+      (1 to 10).map(i => Row(i, i)).toSeq)
+
+    sqlTest(
+      s"SELECT b FROM $table",
+      (1 to 10).map(i => Row(i * 2)).toSeq)
+
+    sqlTest(
+      s"SELECT a * 2 FROM $table",
+      (1 to 10).map(i => Row(i * 2)).toSeq)
+
+    sqlTest(
+      s"SELECT A AS b FROM $table",
+      (1 to 10).map(i => Row(i)).toSeq)
+
+    sqlTest(
+      s"SELECT x.b, y.a FROM $table x JOIN $table y ON x.a = y.b",
+      (1 to 5).map(i => Row(i * 4, i)).toSeq)
+
+    sqlTest(
+      s"SELECT x.a, y.b FROM $table x JOIN $table y ON x.a = y.b",
+      (2 to 10 by 2).map(i => Row(i, i)).toSeq)
+
+    testPruning(s"SELECT * FROM $table", "a", "b")
+    testPruning(s"SELECT a, b FROM $table", "a", "b")
+    testPruning(s"SELECT b, a FROM $table", "b", "a")
+    testPruning(s"SELECT b, b FROM $table", "b")
+    testPruning(s"SELECT a FROM $table", "a")
+    testPruning(s"SELECT b FROM $table", "b")
+  }

   def testPruning(sqlString: String, expectedColumns: String*): Unit = {
     test(s"Columns output ${expectedColumns.mkString(",")}: $sqlString") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
index b254b0620c779..8a5ef44fa4be3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
@@ -24,17 +24,21 @@ class DefaultSource extends SimpleScanSource
 class SimpleScanSource extends RelationProvider {
   override def createRelation(
       sqlContext: SQLContext,
-      parameters: Map[String, String]): BaseRelation = {
-    SimpleScan(parameters("from").toInt, parameters("to").toInt)(sqlContext)
+      parameters: Map[String, String],
+      schema: Option[StructType]): BaseRelation = {
+    SimpleScan(parameters("from").toInt, parameters("to").toInt, schema)(sqlContext)
   }
 }

-case class SimpleScan(from: Int, to: Int)(@transient val sqlContext: SQLContext)
+case class SimpleScan(
+    from: Int,
+    to: Int,
+    _schema: Option[StructType])(@transient val sqlContext: SQLContext)
   extends TableScan {

-  override def schema =
+  override def schema = _schema.getOrElse(
     StructType(StructField("i", IntegerType, nullable = false) :: Nil)
-
+  )
   override def buildScan() = sqlContext.sparkContext.parallelize(from to to).map(Row(_))
 }

@@ -51,60 +55,75 @@ class TableScanSuite extends DataSourceTest {
         |  to '10'
         |)
       """.stripMargin)
-  }
-
-  sqlTest(
-    "SELECT * FROM oneToTen",
-    (1 to 10).map(Row(_)).toSeq)
-
-  sqlTest(
-    "SELECT i FROM oneToTen",
-    (1 to 10).map(Row(_)).toSeq)
-
-  sqlTest(
-    "SELECT i FROM oneToTen WHERE i < 5",
-    (1 to 4).map(Row(_)).toSeq)
-
-  sqlTest(
-    "SELECT i * 2 FROM oneToTen",
-    (1 to 10).map(i => Row(i * 2)).toSeq)
-
-  sqlTest(
-    "SELECT a.i, b.i FROM oneToTen a JOIN oneToTen b ON a.i = b.i + 1",
-    (2 to 10).map(i => Row(i, i - 1)).toSeq)
+    sql(
+      """
+        |CREATE TEMPORARY TABLE oneToTen_with_schema(i int)
+        |USING org.apache.spark.sql.sources.SimpleScanSource
+        |OPTIONS (
+        |  from '1',
+        |  to '10'
+        |)
+      """.stripMargin)
+  }

-  test("Caching") {
-    // Cached Query Execution
-    cacheTable("oneToTen")
-    assertCached(sql("SELECT * FROM oneToTen"))
-    checkAnswer(
-      sql("SELECT * FROM oneToTen"),
+  Seq("oneToTen", "oneToTen_with_schema").foreach { table =>
+    sqlTest(
+      s"SELECT * FROM $table",
       (1 to 10).map(Row(_)).toSeq)

-    assertCached(sql("SELECT i FROM oneToTen"))
-    checkAnswer(
-      sql("SELECT i FROM oneToTen"),
+    sqlTest(
+      s"SELECT i FROM $table",
       (1 to 10).map(Row(_)).toSeq)

-    assertCached(sql("SELECT i FROM oneToTen WHERE i < 5"))
-    checkAnswer(
-      sql("SELECT i FROM oneToTen WHERE i < 5"),
+    sqlTest(
+      s"SELECT i FROM $table WHERE i < 5",
       (1 to 4).map(Row(_)).toSeq)

-    assertCached(sql("SELECT i * 2 FROM oneToTen"))
-    checkAnswer(
-      sql("SELECT i * 2 FROM oneToTen"),
+    sqlTest(
+      s"SELECT i * 2 FROM $table",
       (1 to 10).map(i => Row(i * 2)).toSeq)

-    assertCached(sql("SELECT a.i, b.i FROM oneToTen a JOIN oneToTen b ON a.i = b.i + 1"), 2)
-    checkAnswer(
-      sql("SELECT a.i, b.i FROM oneToTen a JOIN oneToTen b ON a.i = b.i + 1"),
+    sqlTest(
+      s"SELECT a.i, b.i FROM $table a JOIN $table b ON a.i = b.i + 1",
       (2 to 10).map(i => Row(i, i - 1)).toSeq)
+  }
+

-    // Verify uncaching
-    uncacheTable("oneToTen")
-    assertCached(sql("SELECT * FROM oneToTen"), 0)
+  Seq("oneToTen", "oneToTen_with_schema").foreach { table =>
+
+    test(s"Caching $table") {
+      // Cached Query Execution
+      cacheTable(s"$table")
+      assertCached(sql(s"SELECT * FROM $table"))
+      checkAnswer(
+        sql(s"SELECT * FROM $table"),
+        (1 to 10).map(Row(_)).toSeq)
+
+      assertCached(sql(s"SELECT i FROM $table"))
+      checkAnswer(
+        sql(s"SELECT i FROM $table"),
+        (1 to 10).map(Row(_)).toSeq)
+
+      assertCached(sql(s"SELECT i FROM $table WHERE i < 5"))
+      checkAnswer(
+        sql(s"SELECT i FROM $table WHERE i < 5"),
+        (1 to 4).map(Row(_)).toSeq)
+
+      assertCached(sql(s"SELECT i * 2 FROM $table"))
+      checkAnswer(
+        sql(s"SELECT i * 2 FROM $table"),
+        (1 to 10).map(i => Row(i * 2)).toSeq)
+
+      assertCached(sql(s"SELECT a.i, b.i FROM $table a JOIN $table b ON a.i = b.i + 1"), 2)
+      checkAnswer(
+        sql(s"SELECT a.i, b.i FROM $table a JOIN $table b ON a.i = b.i + 1"),
+        (2 to 10).map(i => Row(i, i - 1)).toSeq)
+
+      // Verify uncaching
+      uncacheTable(s"$table")
+      assertCached(sql(s"SELECT * FROM $table"), 0)
+    }
   }

   test("defaultSource") {