diff --git a/assembly/pom.xml b/assembly/pom.xml
index 4e2b773e7d2f3..8269c6985df34 100644
--- a/assembly/pom.xml
+++ b/assembly/pom.xml
@@ -209,6 +209,16 @@
+    <profile>
+      <id>hbase</id>
+      <dependencies>
+        <dependency>
+          <groupId>org.apache.spark</groupId>
+          <artifactId>spark-hbase_${scala.binary.version}</artifactId>
+          <version>${project.version}</version>
+        </dependency>
+      </dependencies>
+    </profile>
      <id>spark-ganglia-lgpl</id>
diff --git a/bin/compute-classpath.cmd b/bin/compute-classpath.cmd
index a4c099fb45b14..23f5b52479452 100644
--- a/bin/compute-classpath.cmd
+++ b/bin/compute-classpath.cmd
@@ -81,6 +81,7 @@ set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%tools\target\scala-%SCALA_VERSION%\clas
set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%sql\catalyst\target\scala-%SCALA_VERSION%\classes
set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%sql\core\target\scala-%SCALA_VERSION%\classes
set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%sql\hive\target\scala-%SCALA_VERSION%\classes
+set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%sql\hbase\target\scala-%SCALA_VERSION%\classes
set SPARK_TEST_CLASSES=%FWDIR%core\target\scala-%SCALA_VERSION%\test-classes
set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%repl\target\scala-%SCALA_VERSION%\test-classes
@@ -91,6 +92,7 @@ set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%streaming\target\scala-%SCALA
set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%sql\catalyst\target\scala-%SCALA_VERSION%\test-classes
set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%sql\core\target\scala-%SCALA_VERSION%\test-classes
set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%sql\hive\target\scala-%SCALA_VERSION%\test-classes
+set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%sql\hbase\target\scala-%SCALA_VERSION%\test-classes
if "x%SPARK_TESTING%"=="x1" (
rem Add test clases to path - note, add SPARK_CLASSES and SPARK_TEST_CLASSES before CLASSPATH
diff --git a/bin/compute-classpath.sh b/bin/compute-classpath.sh
index 298641f2684de..e2144722881c6 100755
--- a/bin/compute-classpath.sh
+++ b/bin/compute-classpath.sh
@@ -59,6 +59,7 @@ if [ -n "$SPARK_PREPEND_CLASSES" ]; then
CLASSPATH="$CLASSPATH:$FWDIR/sql/hive/target/scala-$SPARK_SCALA_VERSION/classes"
CLASSPATH="$CLASSPATH:$FWDIR/sql/hive-thriftserver/target/scala-$SPARK_SCALA_VERSION/classes"
CLASSPATH="$CLASSPATH:$FWDIR/yarn/stable/target/scala-$SPARK_SCALA_VERSION/classes"
+  CLASSPATH="$CLASSPATH:$FWDIR/sql/hbase/target/scala-$SPARK_SCALA_VERSION/classes"
fi
# Use spark-assembly jar from either RELEASE or assembly directory
@@ -130,6 +131,7 @@ if [[ $SPARK_TESTING == 1 ]]; then
CLASSPATH="$CLASSPATH:$FWDIR/sql/catalyst/target/scala-$SPARK_SCALA_VERSION/test-classes"
CLASSPATH="$CLASSPATH:$FWDIR/sql/core/target/scala-$SPARK_SCALA_VERSION/test-classes"
CLASSPATH="$CLASSPATH:$FWDIR/sql/hive/target/scala-$SPARK_SCALA_VERSION/test-classes"
+  CLASSPATH="$CLASSPATH:$FWDIR/sql/hbase/target/scala-$SPARK_SCALA_VERSION/test-classes"
fi
# Add hadoop conf dir if given -- otherwise FileSystem.*, etc fail !
diff --git a/bin/hbase-sql b/bin/hbase-sql
new file mode 100755
index 0000000000000..4ea11a4faaf12
--- /dev/null
+++ b/bin/hbase-sql
@@ -0,0 +1,55 @@
+#!/usr/bin/env bash
+
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+#
+# Shell script for starting the Spark SQL for HBase CLI
+
+# Enter posix mode for bash
+set -o posix
+
+CLASS="org.apache.spark.sql.hbase.HBaseSQLCLIDriver"
+
+# Figure out where Spark is installed
+FWDIR="$(cd "`dirname "$0"`"/..; pwd)"
+
+function usage {
+ echo "Usage: ./bin/hbase-sql [options] [cli option]"
+ pattern="usage"
+ pattern+="\|Spark assembly has been built with hbase"
+ pattern+="\|NOTE: SPARK_PREPEND_CLASSES is set"
+ pattern+="\|Spark Command: "
+ pattern+="\|--help"
+ pattern+="\|======="
+
+ "$FWDIR"/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2
+ echo
+ echo "CLI options:"
+ "$FWDIR"/bin/spark-class $CLASS --help 2>&1 | grep -v "$pattern" 1>&2
+}
+
+if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then
+ usage
+ exit 0
+fi
+
+source "$FWDIR"/bin/utils.sh
+SUBMIT_USAGE_FUNCTION=usage
+gatherSparkSubmitOpts "$@"
+
+exec "$FWDIR"/bin/spark-submit --class $CLASS "${SUBMISSION_OPTS[@]}" spark-internal "${APPLICATION_OPTS[@]}"
diff --git a/examples/pom.xml b/examples/pom.xml
index 8713230e1e8ed..ea7c34dae1dde 100644
--- a/examples/pom.xml
+++ b/examples/pom.xml
@@ -101,140 +101,140 @@
org.eclipse.jetty
jetty-server
-
- org.apache.hbase
- hbase-testing-util
- ${hbase.version}
-
-
-
- org.apache.hbase
- hbase-annotations
-
-
- org.jruby
- jruby-complete
-
-
-
-
- org.apache.hbase
- hbase-protocol
- ${hbase.version}
-
-
- org.apache.hbase
- hbase-common
- ${hbase.version}
-
-
-
- org.apache.hbase
- hbase-annotations
-
-
-
-
- org.apache.hbase
- hbase-client
- ${hbase.version}
-
-
-
- org.apache.hbase
- hbase-annotations
-
-
+
+ org.apache.hbase
+ hbase-testing-util
+ ${hbase.version}
+
+
+
+ org.apache.hbase
+ hbase-annotations
+
+
+ org.jruby
+ jruby-complete
+
+
+
+
+ org.apache.hbase
+ hbase-protocol
+ ${hbase.version}
+
+
+ org.apache.hbase
+ hbase-common
+ ${hbase.version}
+
+
+
+ org.apache.hbase
+ hbase-annotations
+
+
+
+
+ org.apache.hbase
+ hbase-client
+ ${hbase.version}
+
+
+
+ org.apache.hbase
+ hbase-annotations
+
+
io.netty
netty
-
-
-
-
- org.apache.hbase
- hbase-server
- ${hbase.version}
-
-
- org.apache.hadoop
- hadoop-core
-
-
- org.apache.hadoop
- hadoop-client
-
-
- org.apache.hadoop
- hadoop-mapreduce-client-jobclient
-
-
- org.apache.hadoop
- hadoop-mapreduce-client-core
-
-
- org.apache.hadoop
- hadoop-auth
-
-
-
- org.apache.hbase
- hbase-annotations
-
-
- org.apache.hadoop
- hadoop-annotations
-
-
- org.apache.hadoop
- hadoop-hdfs
-
-
- org.apache.hbase
- hbase-hadoop1-compat
-
-
- org.apache.commons
- commons-math
-
-
- com.sun.jersey
- jersey-core
-
-
- org.slf4j
- slf4j-api
-
-
- com.sun.jersey
- jersey-server
-
-
- com.sun.jersey
- jersey-core
-
-
- com.sun.jersey
- jersey-json
-
-
-
- commons-io
- commons-io
-
-
-
-
- org.apache.hbase
- hbase-hadoop-compat
- ${hbase.version}
-
-
- org.apache.hbase
- hbase-hadoop-compat
- ${hbase.version}
- test-jar
- test
-
+
+
+
+
+ org.apache.hbase
+ hbase-server
+ ${hbase.version}
+
+
+ org.apache.hadoop
+ hadoop-core
+
+
+ org.apache.hadoop
+ hadoop-client
+
+
+ org.apache.hadoop
+ hadoop-mapreduce-client-jobclient
+
+
+ org.apache.hadoop
+ hadoop-mapreduce-client-core
+
+
+ org.apache.hadoop
+ hadoop-auth
+
+
+
+ org.apache.hbase
+ hbase-annotations
+
+
+ org.apache.hadoop
+ hadoop-annotations
+
+
+ org.apache.hadoop
+ hadoop-hdfs
+
+
+ org.apache.hbase
+ hbase-hadoop1-compat
+
+
+ org.apache.commons
+ commons-math
+
+
+ com.sun.jersey
+ jersey-core
+
+
+ org.slf4j
+ slf4j-api
+
+
+ com.sun.jersey
+ jersey-server
+
+
+ com.sun.jersey
+ jersey-core
+
+
+ com.sun.jersey
+ jersey-json
+
+
+
+ commons-io
+ commons-io
+
+
+
+
+ org.apache.hbase
+ hbase-hadoop-compat
+ ${hbase.version}
+
+
+ org.apache.hbase
+ hbase-hadoop-compat
+ ${hbase.version}
+ test-jar
+ test
+
org.apache.commons
commons-math3
@@ -416,7 +416,7 @@
-
scala-2.10
diff --git a/pom.xml b/pom.xml
index b7df53d3e5eb1..901616f257995 100644
--- a/pom.xml
+++ b/pom.xml
@@ -156,7 +156,7 @@
      <id>central</id>
      <name>Maven Repository</name>
-      <url>https://repo1.maven.org/maven2</url>
+      <url>http://repo1.maven.org/maven2</url>
        <enabled>true</enabled>
@@ -167,7 +167,7 @@
      <id>apache-repo</id>
      <name>Apache Repository</name>
-      <url>https://repository.apache.org/content/repositories/releases</url>
+      <url>http://repository.apache.org/content/repositories/releases</url>
        <enabled>true</enabled>
@@ -178,7 +178,7 @@
      <id>jboss-repo</id>
      <name>JBoss Repository</name>
-      <url>https://repository.jboss.org/nexus/content/repositories/releases</url>
+      <url>http://repository.jboss.org/nexus/content/repositories/releases</url>
        <enabled>true</enabled>
@@ -189,7 +189,7 @@
      <id>mqtt-repo</id>
      <name>MQTT Repository</name>
-      <url>https://repo.eclipse.org/content/repositories/paho-releases</url>
+      <url>http://repo.eclipse.org/content/repositories/paho-releases</url>
        <enabled>true</enabled>
@@ -200,7 +200,7 @@
      <id>cloudera-repo</id>
      <name>Cloudera Repository</name>
-      <url>https://repository.cloudera.com/artifactory/cloudera-repos</url>
+      <url>http://repository.cloudera.com/artifactory/cloudera-repos</url>
        <enabled>true</enabled>
@@ -222,7 +222,7 @@
      <id>spring-releases</id>
      <name>Spring Release Repository</name>
-      <url>https://repo.spring.io/libs-release</url>
+      <url>http://repo.spring.io/libs-release</url>
        <enabled>true</enabled>
@@ -234,7 +234,7 @@
      <id>spark-staging-1038</id>
      <name>Spark 1.2.0 Staging (1038)</name>
-      <url>https://repository.apache.org/content/repositories/orgapachespark-1038/</url>
+      <url>http://repository.apache.org/content/repositories/orgapachespark-1038/</url>
        <enabled>true</enabled>
@@ -246,7 +246,7 @@
      <id>central</id>
-      <url>https://repo1.maven.org/maven2</url>
+      <url>http://repo1.maven.org/maven2</url>
        <enabled>true</enabled>
@@ -936,6 +936,7 @@
${java.version}
${java.version}
+ true
UTF-8
1024m
true
@@ -1415,7 +1416,6 @@
10.10.1.1
-
scala-2.10
@@ -1431,7 +1431,6 @@
external/kafka
-
scala-2.11
diff --git a/python/pyspark/hbase.py b/python/pyspark/hbase.py
new file mode 100644
index 0000000000000..e68a5f4e74f2a
--- /dev/null
+++ b/python/pyspark/hbase.py
@@ -0,0 +1,84 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from py4j.protocol import Py4JError
+import traceback
+
+from pyspark.sql import *
+
+__all__ = [
+ "StringType", "BinaryType", "BooleanType", "TimestampType", "DecimalType",
+ "DoubleType", "FloatType", "ByteType", "IntegerType", "LongType",
+ "ShortType", "ArrayType", "MapType", "StructField", "StructType",
+ "SQLContext", "HBaseSQLContext", "SchemaRDD", "Row", "_ssql_ctx", "_get_object_id"]
+
+
+class HBaseSQLContext(SQLContext):
+ """A variant of Spark SQL that integrates with data stored in Hive.
+
+ Configuration for Hive is read from hive-site.xml on the classpath.
+ It supports running both SQL and HiveQL commands.
+ """
+
+ def __init__(self, sparkContext, hbaseContext=None):
+ """Create a new HiveContext.
+
+ @param sparkContext: The SparkContext to wrap.
+ @param hbaseContext: An optional JVM Scala HBaseSQLContext. If set, we do not instatiate a new
+ HBaseSQLContext in the JVM, instead we make all calls to this object.
+ """
+ SQLContext.__init__(self, sparkContext)
+
+ if hbaseContext:
+ self._scala_HBaseSQLContext = hbaseContext
+ else:
+ self._scala_HBaseSQLContext = None
+ print("HbaseContext is %s" % self._scala_HBaseSQLContext)
+
+ @property
+ def _ssql_ctx(self):
+ # try:
+ if self._scala_HBaseSQLContext is None:
+ # if not hasattr(self, '_scala_HBaseSQLContext'):
+ print ("loading hbase context ..")
+ self._scala_HBaseSQLContext = self._get_hbase_ctx()
+ self._scala_SQLContext = self._scala_HBaseSQLContext
+ else:
+ print("We already have hbase context")
+
+        print(vars(self))
+ return self._scala_HBaseSQLContext
+ # except Py4JError as e:
+ # import sys
+ # traceback.print_stack(file=sys.stdout)
+ # print ("Nice error .. %s " %e)
+ # print(e)
+ # raise Exception(""
+ # "HbaseSQLContext not found: You must build Spark with Hbase.", e)
+
+ def _get_hbase_ctx(self):
+ print("sc=%s conf=%s" %(self._jsc.sc(), self._jsc.sc().configuration))
+ return self._jvm.HBaseSQLContext(self._jsc.sc())
+
+
+ class HBaseSchemaRDD(SchemaRDD):
+ def createTable(self, tableName, overwrite=False):
+ """Inserts the contents of this SchemaRDD into the specified table.
+
+ Optionally overwriting any existing data.
+ """
+ self._jschema_rdd.createTable(tableName, overwrite)
+
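The wrapper above only reaches the JVM through py4j as `self._jvm.HBaseSQLContext(self._jsc.sc())`, so all it depends on is the constructor shape of the Scala class, which is defined elsewhere in this patch and not shown in this excerpt. A minimal, hypothetical sketch of that contract (not the actual implementation):

// Sketch only: the real HBaseSQLContext lives elsewhere in this patch. This is the
// minimal JVM-side shape the Python wrapper relies on: a class in this package whose
// constructor takes a SparkContext.
package org.apache.spark.sql.hbase

import org.apache.spark.SparkContext
import org.apache.spark.sql.SQLContext

class HBaseSQLContext(sc: SparkContext) extends SQLContext(sc) {
  // HBase catalog, DDL support and planning strategies omitted in this sketch
}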
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
index a2bcd73b6074f..559ef3e1596bc 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
@@ -107,13 +107,16 @@ class SqlParser extends AbstractSparkSQLParser {
protected val WHERE = Keyword("WHERE")
// Use reflection to find the reserved words defined in this class.
+  /* TODO: Reflection here raises a NullPointerException for subclasses of SqlParser,
+   * because subclass keyword vals are not yet initialized while this superclass code runs.
+   * Temporary workaround: add a filter so that only keywords declared by SqlParser itself are collected.
+   */
protected val reservedWords =
this
.getClass
.getMethods
.filter(_.getReturnType == classOf[Keyword])
- .map(_.invoke(this).asInstanceOf[Keyword].str)
-
+      .filter(_.toString.contains("org.apache.spark.sql.catalyst.SqlParser."))
+      .map(_.invoke(this).asInstanceOf[Keyword].str)
+
override val lexical = new SqlLexical(reservedWords)
protected def assignAliases(exprs: Seq[Expression]): Seq[NamedExpression] = {
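The extra filter above matters because `getMethods` also returns keyword getters declared by subclasses of SqlParser, and those vals are still null while the SqlParser constructor is computing `reservedWords`; invoking them reflectively then fails on `.str`. A hypothetical subclass (not part of this patch) illustrating the situation:

import org.apache.spark.sql.catalyst.SqlParser

// Hypothetical dialect adding its own keyword. During SqlParser's constructor the
// UPSERT val below is not yet initialized, so without the declaring-class filter the
// reflective invoke(...) would return null and `.str` would throw.
class HBaseSqlParser extends SqlParser {
  protected val UPSERT = Keyword("UPSERT")
}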
diff --git a/sql/core/pom.xml b/sql/core/pom.xml
index 3bd283fd20156..2d4e8f2493590 100644
--- a/sql/core/pom.xml
+++ b/sql/core/pom.xml
@@ -83,6 +83,140 @@
      <artifactId>scalacheck_${scala.binary.version}</artifactId>
      <scope>test</scope>
+    <dependency>
+      <groupId>org.apache.hbase</groupId>
+      <artifactId>hbase-testing-util</artifactId>
+      <version>0.98.5-hadoop2</version>
+      <exclusions>
+        <exclusion>
+          <groupId>org.apache.hbase</groupId>
+          <artifactId>hbase-annotations</artifactId>
+        </exclusion>
+        <exclusion>
+          <groupId>org.jruby</groupId>
+          <artifactId>jruby-complete</artifactId>
+        </exclusion>
+      </exclusions>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.hbase</groupId>
+      <artifactId>hbase-protocol</artifactId>
+      <version>0.98.5-hadoop2</version>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.hbase</groupId>
+      <artifactId>hbase-common</artifactId>
+      <version>0.98.5-hadoop2</version>
+      <exclusions>
+        <exclusion>
+          <groupId>org.apache.hbase</groupId>
+          <artifactId>hbase-annotations</artifactId>
+        </exclusion>
+      </exclusions>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.hbase</groupId>
+      <artifactId>hbase-client</artifactId>
+      <version>0.98.5-hadoop2</version>
+      <exclusions>
+        <exclusion>
+          <groupId>org.apache.hbase</groupId>
+          <artifactId>hbase-annotations</artifactId>
+        </exclusion>
+        <exclusion>
+          <groupId>io.netty</groupId>
+          <artifactId>netty</artifactId>
+        </exclusion>
+      </exclusions>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.hbase</groupId>
+      <artifactId>hbase-server</artifactId>
+      <version>0.98.5-hadoop2</version>
+      <exclusions>
+        <exclusion>
+          <groupId>org.apache.hadoop</groupId>
+          <artifactId>hadoop-core</artifactId>
+        </exclusion>
+        <exclusion>
+          <groupId>org.apache.hadoop</groupId>
+          <artifactId>hadoop-client</artifactId>
+        </exclusion>
+        <exclusion>
+          <groupId>org.apache.hadoop</groupId>
+          <artifactId>hadoop-mapreduce-client-jobclient</artifactId>
+        </exclusion>
+        <exclusion>
+          <groupId>org.apache.hadoop</groupId>
+          <artifactId>hadoop-mapreduce-client-core</artifactId>
+        </exclusion>
+        <exclusion>
+          <groupId>org.apache.hadoop</groupId>
+          <artifactId>hadoop-auth</artifactId>
+        </exclusion>
+        <exclusion>
+          <groupId>org.apache.hbase</groupId>
+          <artifactId>hbase-annotations</artifactId>
+        </exclusion>
+        <exclusion>
+          <groupId>org.apache.hadoop</groupId>
+          <artifactId>hadoop-annotations</artifactId>
+        </exclusion>
+        <exclusion>
+          <groupId>org.apache.hadoop</groupId>
+          <artifactId>hadoop-hdfs</artifactId>
+        </exclusion>
+        <exclusion>
+          <groupId>org.apache.hbase</groupId>
+          <artifactId>hbase-hadoop1-compat</artifactId>
+        </exclusion>
+        <exclusion>
+          <groupId>org.apache.commons</groupId>
+          <artifactId>commons-math</artifactId>
+        </exclusion>
+        <exclusion>
+          <groupId>com.sun.jersey</groupId>
+          <artifactId>jersey-core</artifactId>
+        </exclusion>
+        <exclusion>
+          <groupId>org.slf4j</groupId>
+          <artifactId>slf4j-api</artifactId>
+        </exclusion>
+        <exclusion>
+          <groupId>com.sun.jersey</groupId>
+          <artifactId>jersey-server</artifactId>
+        </exclusion>
+        <exclusion>
+          <groupId>com.sun.jersey</groupId>
+          <artifactId>jersey-core</artifactId>
+        </exclusion>
+        <exclusion>
+          <groupId>com.sun.jersey</groupId>
+          <artifactId>jersey-json</artifactId>
+        </exclusion>
+        <exclusion>
+          <groupId>commons-io</groupId>
+          <artifactId>commons-io</artifactId>
+        </exclusion>
+      </exclusions>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.hbase</groupId>
+      <artifactId>hbase-hadoop-compat</artifactId>
+      <version>0.98.5-hadoop2</version>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.hbase</groupId>
+      <artifactId>hbase-hadoop-compat</artifactId>
+      <version>0.98.5-hadoop2</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
    <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/hbase/BytesUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/hbase/BytesUtils.scala
new file mode 100644
index 0000000000000..190c772ae5174
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/hbase/BytesUtils.scala
@@ -0,0 +1,141 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements. See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+package org.apache.spark.sql.hbase
+
+import org.apache.hadoop.hbase.util.Bytes
+
+class BytesUtils {
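+  // Note: the fixed-size arrays below are shared per instance and reused by the
+  // non-String toBytes overloads, so a returned byte array is only valid until the
+  // next call on the same instance; copy it (or use a fresh BytesUtils) to keep it.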
+ lazy val booleanArray: HBaseRawType = new HBaseRawType(Bytes.SIZEOF_BOOLEAN)
+ lazy val byteArray: HBaseRawType = new HBaseRawType(Bytes.SIZEOF_BYTE)
+ lazy val charArray: HBaseRawType = new HBaseRawType(Bytes.SIZEOF_CHAR)
+ lazy val doubleArray: HBaseRawType = new HBaseRawType(Bytes.SIZEOF_DOUBLE)
+ lazy val floatArray: HBaseRawType = new HBaseRawType(Bytes.SIZEOF_FLOAT)
+ lazy val intArray: HBaseRawType = new HBaseRawType(Bytes.SIZEOF_INT)
+ lazy val longArray: HBaseRawType = new HBaseRawType(Bytes.SIZEOF_LONG)
+ lazy val shortArray: HBaseRawType = new HBaseRawType(Bytes.SIZEOF_SHORT)
+
+ def toBytes(input: String): HBaseRawType = {
+ Bytes.toBytes(input)
+ }
+
+ def toString(input: HBaseRawType): String = {
+ Bytes.toString(input)
+ }
+
+ def toBytes(input: Byte): HBaseRawType = {
+ // Flip sign bit so that Byte is binary comparable
+ byteArray(0) = (input ^ 0x80).asInstanceOf[Byte]
+ byteArray
+ }
+
+ def toByte(input: HBaseRawType): Byte = {
+ // Flip sign bit back
+ val v: Int = input(0) ^ 0x80
+ v.asInstanceOf[Byte]
+ }
+
+ def toBytes(input: Boolean): HBaseRawType = {
+ booleanArray(0) = 0.asInstanceOf[Byte]
+ if (input) {
+ booleanArray(0) = (-1).asInstanceOf[Byte]
+ }
+ booleanArray
+ }
+
+ def toBoolean(input: HBaseRawType): Boolean = {
+ input(0) != 0
+ }
+
+ def toBytes(input: Double): HBaseRawType = {
+ var l: Long = java.lang.Double.doubleToLongBits(input)
+ l = (l ^ ((l >> java.lang.Long.SIZE - 1) | java.lang.Long.MIN_VALUE)) + 1
+ Bytes.putLong(longArray, 0, l)
+ longArray
+ }
+
+ def toDouble(input: HBaseRawType): Double = {
+ var l: Long = Bytes.toLong(input)
+ l = l - 1
+ l ^= (~l >> java.lang.Long.SIZE - 1) | java.lang.Long.MIN_VALUE
+ java.lang.Double.longBitsToDouble(l)
+ }
+
+ def toBytes(input: Short): HBaseRawType = {
+ shortArray(0) = ((input >> 8) ^ 0x80).asInstanceOf[Byte]
+ shortArray(1) = input.asInstanceOf[Byte]
+ shortArray
+ }
+
+ def toShort(input: HBaseRawType): Short = {
+ // flip sign bit back
+ var v: Int = input(0) ^ 0x80
+ v = (v << 8) + (input(1) & 0xff)
+ v.asInstanceOf[Short]
+ }
+
+ def toBytes(input: Float): HBaseRawType = {
+ var i: Int = java.lang.Float.floatToIntBits(input)
+ i = (i ^ ((i >> Integer.SIZE - 1) | Integer.MIN_VALUE)) + 1
+ toBytes(i)
+ }
+
+ def toFloat(input: HBaseRawType): Float = {
+ var i = toInt(input)
+ i = i - 1
+ i ^= (~i >> Integer.SIZE - 1) | Integer.MIN_VALUE
+ java.lang.Float.intBitsToFloat(i)
+ }
+
+ def toBytes(input: Int): HBaseRawType = {
+ // Flip sign bit so that INTEGER is binary comparable
+ intArray(0) = ((input >> 24) ^ 0x80).asInstanceOf[Byte]
+ intArray(1) = (input >> 16).asInstanceOf[Byte]
+ intArray(2) = (input >> 8).asInstanceOf[Byte]
+ intArray(3) = input.asInstanceOf[Byte]
+ intArray
+ }
+
+ def toInt(input: HBaseRawType): Int = {
+ // Flip sign bit back
+ var v: Int = input(0) ^ 0x80
+ for (i <- 1 to Bytes.SIZEOF_INT - 1) {
+ v = (v << 8) + (input(i) & 0xff)
+ }
+ v
+ }
+
+ def toBytes(input: Long): HBaseRawType = {
+ longArray(0) = ((input >> 56) ^ 0x80).asInstanceOf[Byte]
+ longArray(1) = (input >> 48).asInstanceOf[Byte]
+ longArray(2) = (input >> 40).asInstanceOf[Byte]
+ longArray(3) = (input >> 32).asInstanceOf[Byte]
+ longArray(4) = (input >> 24).asInstanceOf[Byte]
+ longArray(5) = (input >> 16).asInstanceOf[Byte]
+ longArray(6) = (input >> 8).asInstanceOf[Byte]
+ longArray(7) = input.asInstanceOf[Byte]
+ longArray
+ }
+
+ def toLong(input: HBaseRawType): Long = {
+ // Flip sign bit back
+ var v: Long = input(0) ^ 0x80
+ for (i <- 1 to Bytes.SIZEOF_LONG - 1) {
+ v = (v << 8) + (input(i) & 0xff)
+ }
+ v
+ }
+}
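These encodings exist so that the raw bytes compare, as unsigned bytes the way HBase orders row keys, in the same order as the original values; the sign-bit flip is what makes that work for signed types. A minimal sketch of that property (separate BytesUtils instances are used because each instance reuses its internal buffers):

import org.apache.hadoop.hbase.util.Bytes
import org.apache.spark.sql.hbase.BytesUtils

object BytesUtilsOrderingSketch {
  def main(args: Array[String]): Unit = {
    val negative = new BytesUtils().toBytes(-5)   // sign bit flipped on encode
    val positive = new BytesUtils().toBytes(3)
    // unsigned byte comparison (what HBase uses) agrees with numeric order
    assert(Bytes.compareTo(negative, positive) < 0)
    // and the encoding round-trips
    assert(new BytesUtils().toInt(negative) == -5)
  }
}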
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/hbase/DataTypeUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/hbase/DataTypeUtils.scala
new file mode 100755
index 0000000000000..8976a4fd48f38
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/hbase/DataTypeUtils.scala
@@ -0,0 +1,109 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements. See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+package org.apache.spark.sql.hbase
+
+import org.apache.hadoop.hbase.filter.BinaryComparator
+import org.apache.hadoop.hbase.util.Bytes
+import org.apache.spark.sql.catalyst.expressions.{Literal, MutableRow, Row}
+import org.apache.spark.sql.catalyst.types._
+
+/**
+ * Data Type conversion utilities
+ *
+ */
+object DataTypeUtils {
+ // TODO: more data types support?
+ def bytesToData (src: HBaseRawType,
+ dt: DataType,
+ bu: BytesUtils): Any = {
+ dt match {
+ case StringType => bu.toString(src)
+ case IntegerType => bu.toInt(src)
+ case BooleanType => bu.toBoolean(src)
+ case ByteType => src(0)
+ case DoubleType => bu.toDouble(src)
+ case FloatType => bu.toFloat(src)
+ case LongType => bu.toLong(src)
+ case ShortType => bu.toShort(src)
+ case _ => throw new Exception("Unsupported HBase SQL Data Type")
+ }
+ }
+
+ def setRowColumnFromHBaseRawType(row: MutableRow,
+ index: Int,
+ src: HBaseRawType,
+ dt: DataType,
+ bu: BytesUtils): Unit = {
+ dt match {
+ case StringType => row.setString(index, bu.toString(src))
+ case IntegerType => row.setInt(index, bu.toInt(src))
+ case BooleanType => row.setBoolean(index, bu.toBoolean(src))
+ case ByteType => row.setByte(index, bu.toByte(src))
+ case DoubleType => row.setDouble(index, bu.toDouble(src))
+ case FloatType => row.setFloat(index, bu.toFloat(src))
+ case LongType => row.setLong(index, bu.toLong(src))
+ case ShortType => row.setShort(index, bu.toShort(src))
+ case _ => throw new Exception("Unsupported HBase SQL Data Type")
+ }
+ }
+
+ def getRowColumnFromHBaseRawType(row: Row,
+ index: Int,
+ dt: DataType,
+ bu: BytesUtils): HBaseRawType = {
+ dt match {
+ case StringType => bu.toBytes(row.getString(index))
+ case IntegerType => bu.toBytes(row.getInt(index))
+ case BooleanType => bu.toBytes(row.getBoolean(index))
+ case ByteType => bu.toBytes(row.getByte(index))
+ case DoubleType => bu.toBytes(row.getDouble(index))
+ case FloatType => bu.toBytes(row.getFloat(index))
+ case LongType => bu.toBytes(row.getLong(index))
+ case ShortType => bu.toBytes(row.getShort(index))
+ case _ => throw new Exception("Unsupported HBase SQL Data Type")
+ }
+ }
+
+ def getComparator(expression: Literal): BinaryComparator = {
+ expression.dataType match {
+ case DoubleType => {
+ new BinaryComparator(Bytes.toBytes(expression.value.asInstanceOf[Double]))
+ }
+ case FloatType => {
+ new BinaryComparator(Bytes.toBytes(expression.value.asInstanceOf[Float]))
+ }
+ case IntegerType => {
+ new BinaryComparator(Bytes.toBytes(expression.value.asInstanceOf[Int]))
+ }
+ case LongType => {
+ new BinaryComparator(Bytes.toBytes(expression.value.asInstanceOf[Long]))
+ }
+ case ShortType => {
+ new BinaryComparator(Bytes.toBytes(expression.value.asInstanceOf[Short]))
+ }
+ case StringType => {
+ new BinaryComparator(Bytes.toBytes(expression.value.asInstanceOf[String]))
+ }
+ case BooleanType => {
+ new BinaryComparator(Bytes.toBytes(expression.value.asInstanceOf[Boolean]))
+ }
+ case _ => {
+ throw new Exception("Cannot convert the data type using BinaryComparator")
+ }
+ }
+ }
+}
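A small round-trip sketch of how these helpers are meant to compose: encode a value with BytesUtils, then decode it back through bytesToData using the matching catalyst DataType.

import org.apache.spark.sql.catalyst.types.LongType
import org.apache.spark.sql.hbase.{BytesUtils, DataTypeUtils}

object DataTypeUtilsRoundTripSketch {
  def main(args: Array[String]): Unit = {
    val raw = new BytesUtils().toBytes(42L)
    val decoded = DataTypeUtils.bytesToData(raw, LongType, new BytesUtils)
    assert(decoded == 42L)
  }
}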
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/hbase/HBaseKVHelper.scala b/sql/core/src/main/scala/org/apache/spark/sql/hbase/HBaseKVHelper.scala
new file mode 100644
index 0000000000000..8af45a813a8a9
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/hbase/HBaseKVHelper.scala
@@ -0,0 +1,130 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hbase
+
+import org.apache.hadoop.hbase.util.Bytes
+import org.apache.spark.sql.catalyst.types._
+
+import scala.collection.mutable
+import scala.collection.mutable.{ArrayBuffer, ListBuffer}
+
+object HBaseKVHelper {
+ private val delimiter: Byte = 0
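+  // String key components are written followed by this 0x00 delimiter, so decoding
+  // assumes string key values never contain a 0x00 byte; fixed-width types are not delimited.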
+
+ /**
+   * Create a row key from the encoded key columns.
+   * @param buffer a reusable working buffer
+   * @param rawKeyColumns sequence of (encoded bytes, data type) pairs for the key columns, in key order
+   * @return the encoded row key as a byte array
+ */
+ def encodingRawKeyColumns(buffer: ListBuffer[Byte],
+ rawKeyColumns: Seq[(HBaseRawType, DataType)]): HBaseRawType = {
+ var listBuffer = buffer
+ listBuffer.clear()
+ for (rawKeyColumn <- rawKeyColumns) {
+ listBuffer = listBuffer ++ rawKeyColumn._1
+ if (rawKeyColumn._2 == StringType) {
+ listBuffer += delimiter
+ }
+ }
+ listBuffer.toArray
+ }
+
+ /**
+   * Split an encoded row key back into the byte arrays of its key columns.
+   * @param buffer a reusable working buffer
+   * @param rowKey the encoded row key
+   * @param keyColumns the key column metadata, in key order
+   * @return one byte array per key column
+ */
+ def decodingRawKeyColumns(buffer: ListBuffer[HBaseRawType],
+ rowKey: HBaseRawType, keyColumns: Seq[KeyColumn]): Seq[HBaseRawType] = {
+ var listBuffer = buffer
+ listBuffer.clear()
+ var arrayBuffer = ArrayBuffer[Byte]()
+ var index = 0
+ for (keyColumn <- keyColumns) {
+ arrayBuffer.clear()
+ val dataType = keyColumn.dataType
+ if (dataType == StringType) {
+ while (index < rowKey.length && rowKey(index) != delimiter) {
+ arrayBuffer += rowKey(index)
+ index = index + 1
+ }
+ index = index + 1
+ }
+ else {
+ val length = NativeType.defaultSizeOf(dataType.asInstanceOf[NativeType])
+ for (i <- 0 to (length - 1)) {
+ arrayBuffer += rowKey(index)
+ index = index + 1
+ }
+ }
+ listBuffer += arrayBuffer.toArray
+ }
+ listBuffer.toSeq
+ }
+
+ /**
+   * Takes a record and translates it into HBase row key and value bytes by matching it against the metadata.
+   * @param values the record, as a sequence of strings
+   * @param columns metadata describing the KeyColumns and NonKeyColumns
+   * @param keyBytes output parameter: buffer of (encoded key column, data type) pairs
+   * @param valueBytes output parameter: buffer of (column family, column qualifier, value) triples
+ */
+ def string2KV(values: Seq[String],
+ columns: Seq[AbstractColumn],
+ keyBytes: ListBuffer[(Array[Byte], DataType)],
+ valueBytes: ListBuffer[(Array[Byte], Array[Byte], Array[Byte])]) = {
+    assert(values.length == columns.length,
+      s"values length ${values.length} does not equal columns length ${columns.length}")
+ keyBytes.clear()
+ valueBytes.clear()
+ val map = mutable.HashMap[Int, (Array[Byte], DataType)]()
+ var index = 0
+ for (i <- 0 until values.length) {
+ val value = values(i)
+ val column = columns(i)
+ val bytes = string2Bytes(value, column.dataType, new BytesUtils)
+ if (column.isKeyColum()) {
+ map(column.asInstanceOf[KeyColumn].order) = ((bytes, column.dataType))
+ index = index + 1
+ } else {
+ val realCol = column.asInstanceOf[NonKeyColumn]
+ valueBytes += ((Bytes.toBytes(realCol.family), Bytes.toBytes(realCol.qualifier), bytes))
+ }
+ }
+
+ (0 until index).foreach(k => keyBytes += map.get(k).get)
+ }
+
+ private def string2Bytes(v: String, dataType: DataType, bu: BytesUtils): Array[Byte] = {
+ dataType match {
+ // todo: handle some complex types
+ case BooleanType => bu.toBytes(v.toBoolean)
+      case ByteType => bu.toBytes(v.toByte)
+      case DoubleType => bu.toBytes(v.toDouble)
+      case FloatType => bu.toBytes(v.toFloat)
+ case IntegerType => bu.toBytes(v.toInt)
+ case LongType => bu.toBytes(v.toLong)
+ case ShortType => bu.toBytes(v.toShort)
+ case StringType => bu.toBytes(v)
+ }
+ }
+}
+
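A sketch of composing and splitting a row key with these helpers. KeyColumn comes from HBaseMetadata.scala in this patch, and HBaseRawType is assumed to be the package's byte-array alias; the String component is delimiter-terminated while the Int component is fixed-width.

package org.apache.spark.sql.hbase

import scala.collection.mutable.ListBuffer

import org.apache.spark.sql.catalyst.types.{IntegerType, StringType}

object HBaseKVHelperSketch {
  def main(args: Array[String]): Unit = {
    // encode (name = "spark", id = 7) into a single row key
    val rawKey = HBaseKVHelper.encodingRawKeyColumns(
      ListBuffer[Byte](),
      Seq((new BytesUtils().toBytes("spark"), StringType),
          (new BytesUtils().toBytes(7), IntegerType)))

    // split it back using the key-column metadata
    val keyColumns = Seq(KeyColumn("name", StringType, 0), KeyColumn("id", IntegerType, 1))
    val pieces = HBaseKVHelper.decodingRawKeyColumns(ListBuffer[HBaseRawType](), rawKey, keyColumns)
    assert(new BytesUtils().toString(pieces(0)) == "spark")
    assert(new BytesUtils().toInt(pieces(1)) == 7)
  }
}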
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/hbase/HBaseMetadata.scala b/sql/core/src/main/scala/org/apache/spark/sql/hbase/HBaseMetadata.scala
new file mode 100644
index 0000000000000..40d9c56c1436e
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/hbase/HBaseMetadata.scala
@@ -0,0 +1,116 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hbase
+
+import org.apache.hadoop.hbase.{HBaseConfiguration, HColumnDescriptor, HTableDescriptor, TableName}
+import org.apache.hadoop.hbase.client._
+import org.apache.hadoop.hbase.util.Bytes
+
+import org.apache.spark.Logging
+import org.apache.spark.sql.catalyst.types.DataType
+
+/**
+ * Abstract representation of a SQL column.
+ * sqlName: the name of the column
+ * dataType: the data type of the column
+ */
+sealed abstract class AbstractColumn {
+ val sqlName: String
+ val dataType: DataType
+
+ def isKeyColum(): Boolean = false
+
+ override def toString: String = {
+ s"$sqlName , $dataType.typeName"
+ }
+}
+
+case class KeyColumn(val sqlName: String, val dataType: DataType, val order: Int)
+ extends AbstractColumn {
+ override def isKeyColum() = true
+}
+
+case class NonKeyColumn(
+ val sqlName: String,
+ val dataType: DataType,
+ val family: String,
+ val qualifier: String) extends AbstractColumn {
+ @transient lazy val familyRaw = Bytes.toBytes(family)
+ @transient lazy val qualifierRaw = Bytes.toBytes(qualifier)
+
+ override def toString = {
+ s"$sqlName , $dataType.typeName , $family:$qualifier"
+ }
+}
+
+private[hbase] class HBaseMetadata extends Logging with Serializable {
+
+ lazy val configuration = HBaseConfiguration.create()
+
+ lazy val admin = new HBaseAdmin(configuration)
+
+ logDebug(s"HBaseAdmin.configuration zkPort="
+ + s"${admin.getConfiguration.get("hbase.zookeeper.property.clientPort")}")
+
+  private def createHBaseUserTable(tableName: String, allColumns: Seq[AbstractColumn]) {
+    val tableDescriptor = new HTableDescriptor(TableName.valueOf(tableName))
+    // add a column family for every family referenced by a non-key column
+    allColumns.foreach {
+      case nonKeyColumn: NonKeyColumn =>
+        tableDescriptor.addFamily(new HColumnDescriptor(nonKeyColumn.family))
+      case _ =>
+    }
+
+    admin.createTable(tableDescriptor, null)
+  }
+
+ def createTable(
+ tableName: String,
+ hbaseTableName: String,
+ allColumns: Seq[AbstractColumn]) = {
+    // create a new HBase table for the user if it does not already exist
+ if (!checkHBaseTableExists(hbaseTableName)) {
+ createHBaseUserTable(hbaseTableName, allColumns)
+ }
+    // check that the HBase table contains all the referenced column families
+ val nonKeyColumns = allColumns.filter(_.isInstanceOf[NonKeyColumn])
+ nonKeyColumns.foreach {
+ case NonKeyColumn(_, _, family, _) =>
+ if (!checkFamilyExists(hbaseTableName, family)) {
+ throw new Exception(s"The HBase table doesn't contain the Column Family: $family")
+ }
+ }
+
+ HBaseRelation(tableName, "", hbaseTableName, allColumns, Some(configuration))
+ }
+
+ private[hbase] def checkHBaseTableExists(hbaseTableName: String): Boolean = {
+ admin.tableExists(hbaseTableName)
+ }
+
+ private[hbase] def checkFamilyExists(hbaseTableName: String, family: String): Boolean = {
+ val tableDescriptor = admin.getTableDescriptor(TableName.valueOf(hbaseTableName))
+ tableDescriptor.hasFamily(Bytes.toBytes(family))
+ }
+}
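A sketch of how the metadata helper is meant to be driven (it needs an HBase cluster reachable through hbase-site.xml/ZooKeeper at run time); the table, column and family names below are made up for illustration.

package org.apache.spark.sql.hbase

import org.apache.spark.sql.catalyst.types.{IntegerType, StringType}

object HBaseMetadataSketch {
  def main(args: Array[String]): Unit = {
    val columns = Seq(
      KeyColumn("id", IntegerType, 0),                    // row key component
      NonKeyColumn("name", StringType, "cf1", "name"))    // stored as cf1:name

    val meta = new HBaseMetadata
    // creates "people_hbase" if it is missing, verifies family "cf1", returns the relation
    val relation = meta.createTable("people", "people_hbase", columns)
    println(relation.output.map(_.name))
  }
}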
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/hbase/HBasePartition.scala b/sql/core/src/main/scala/org/apache/spark/sql/hbase/HBasePartition.scala
new file mode 100755
index 0000000000000..770dd99f63d4b
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/hbase/HBasePartition.scala
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.hbase
+
+import org.apache.spark.Partition
+import org.apache.spark.sql.catalyst.expressions.Expression
+
+private[hbase] class HBasePartition(
+ val idx: Int, val mappedIndex: Int,
+ val lowerBound: Option[HBaseRawType] = None,
+ val upperBound: Option[HBaseRawType] = None,
+ val server: Option[String] = None,
+ val filterPred: Option[Expression] = None) extends Partition with IndexMappable {
+
+ override def index: Int = idx
+
+ override def hashCode(): Int = idx
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/hbase/HBaseRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/hbase/HBaseRelation.scala
new file mode 100755
index 0000000000000..df493612d5069
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/hbase/HBaseRelation.scala
@@ -0,0 +1,666 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.hbase
+
+import java.util.ArrayList
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.hbase.HBaseConfiguration
+import org.apache.hadoop.hbase.client.{Get, HTable, Put, Result, Scan}
+import org.apache.hadoop.hbase.filter._
+import org.apache.hadoop.hbase.util.Bytes
+import org.apache.spark.Partition
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical.LeafNode
+import org.apache.spark.sql.catalyst.types._
+import org.apache.spark.sql.hbase.PartialPredicateOperations._
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable.{ArrayBuffer, ListBuffer}
+
+
+private[hbase] case class HBaseRelation(
+ tableName: String,
+ hbaseNamespace: String,
+ hbaseTableName: String,
+ allColumns: Seq[AbstractColumn],
+ @transient optConfiguration: Option[Configuration] = None)
+ extends LeafNode {
+
+ @transient lazy val keyColumns = allColumns.filter(_.isInstanceOf[KeyColumn])
+ .asInstanceOf[Seq[KeyColumn]].sortBy(_.order)
+
+ @transient lazy val nonKeyColumns = allColumns.filter(_.isInstanceOf[NonKeyColumn])
+ .asInstanceOf[Seq[NonKeyColumn]]
+
+ @transient lazy val partitionKeys: Seq[AttributeReference] = keyColumns.map(col =>
+ AttributeReference(col.sqlName, col.dataType, nullable = false)())
+
+ @transient lazy val columnMap = allColumns.map {
+ case key: KeyColumn => (key.sqlName, key.order)
+ case nonKey: NonKeyColumn => (nonKey.sqlName, nonKey)
+ }.toMap
+
+ def configuration() = optConfiguration.getOrElse(HBaseConfiguration.create)
+
+ // todo: scwf,remove this later
+ logDebug(s"HBaseRelation config has zkPort="
+ + s"${configuration.get("hbase.zookeeper.property.clientPort")}")
+
+ @transient lazy val htable: HTable = new HTable(configuration, hbaseTableName)
+
+ // todo: scwf, why non key columns
+ lazy val attributes = nonKeyColumns.map(col =>
+ AttributeReference(col.sqlName, col.dataType, nullable = true)())
+
+ // lazy val colFamilies = nonKeyColumns.map(_.family).distinct
+ // lazy val applyFilters = false
+
+ def isNonKey(attr: AttributeReference): Boolean = {
+ attributes.exists(_.exprId == attr.exprId)
+ }
+
+ def keyIndex(attr: AttributeReference): Int = {
+ // -1 if nonexistent
+ partitionKeys.indexWhere(_.exprId == attr.exprId)
+ }
+ def closeHTable() = htable.close
+
+ val output: Seq[Attribute] = {
+ allColumns.map {
+ case column =>
+ (partitionKeys union attributes).find(_.name == column.sqlName).get
+ }
+ }
+
+ lazy val partitions: Seq[HBasePartition] = {
+ val regionLocations = htable.getRegionLocations.asScala.toSeq
+ log.info(s"Number of HBase regions for " +
+ s"table ${htable.getName.getNameAsString}: ${regionLocations.size}")
+ regionLocations.zipWithIndex.map {
+ case p =>
+ new HBasePartition(
+ p._2, p._2,
+ Some(p._1._1.getStartKey),
+ Some(p._1._1.getEndKey),
+ Some(p._1._2.getHostname))
+ }
+ }
+
+ private def generateRange(partition: HBasePartition, pred: Expression,
+ index: Int):
+ (PartitionRange[_]) = {
+ def getData(dt: NativeType,
+ buffer: ListBuffer[HBaseRawType],
+ bound: Option[HBaseRawType]): Option[Any] = {
+ if (Bytes.toStringBinary(bound.get) == "") None
+ else {
+ val bytesUtils = new BytesUtils
+ Some(DataTypeUtils.bytesToData(
+ HBaseKVHelper.decodingRawKeyColumns(buffer, bound.get, keyColumns)(index),
+ dt, bytesUtils).asInstanceOf[dt.JvmType])
+ }
+ }
+
+ val dt = keyColumns(index).dataType.asInstanceOf[NativeType]
+ val isLastKeyIndex = index == (keyColumns.size - 1)
+ val buffer = ListBuffer[HBaseRawType]()
+ val start = getData(dt, buffer, partition.lowerBound)
+ val end = getData(dt, buffer, partition.upperBound)
+ val startInclusive = !start.isEmpty
+ val endInclusive = !end.isEmpty && !isLastKeyIndex
+ new PartitionRange(start, startInclusive, end, endInclusive, partition.index, dt, pred)
+ }
+
+ private def prePruneRanges(ranges: Seq[PartitionRange[_]], keyIndex: Int)
+ : (Seq[PartitionRange[_]], Seq[PartitionRange[_]]) = {
+ require(keyIndex < keyColumns.size, "key index out of range")
+ if (ranges.isEmpty) {
+ (ranges, Nil)
+ } else if (keyIndex == 0) {
+ (Nil, ranges)
+ } else {
+ // the first portion is of those ranges of equal start and end values of the
+ // previous dimensions so they can be subject to further checks on the next dimension
+ val (p1, p2) = ranges.partition(p => p.start == p.end)
+ (p2, p1.map(p => generateRange(partitions(p.id), p.pred, keyIndex)))
+ }
+ }
+
+ def getPrunedPartitions(partitionPred: Option[Expression] = None): Option[Seq[HBasePartition]] = {
+ def getPrunedRanges(pred: Expression): Seq[PartitionRange[_]] = {
+ val predRefs = pred.references.toSeq
+ val boundPruningPred = BindReferences.bindReference(pred, predRefs)
+ val keyIndexToPredIndex = (for {
+ (keyColumn, keyIndex) <- keyColumns.zipWithIndex
+ (predRef, predIndex) <- predRefs.zipWithIndex
+ if (keyColumn.sqlName == predRef.name)
+ } yield (keyIndex, predIndex)).toMap
+
+ val row = new GenericMutableRow(predRefs.size)
+ var notPrunedRanges = partitions.map(generateRange(_, null, 0))
+ var prunedRanges: Seq[PartitionRange[_]] = Nil
+
+ for (keyIndex <- 0 until keyColumns.size; if (!notPrunedRanges.isEmpty)) {
+ val (passedRanges, toBePrunedRanges) = prePruneRanges(notPrunedRanges, keyIndex)
+ prunedRanges = prunedRanges ++ passedRanges
+ notPrunedRanges =
+ if (keyIndexToPredIndex.contains(keyIndex)) {
+ toBePrunedRanges.filter(
+ range => {
+ val predIndex = keyIndexToPredIndex(keyIndex)
+ row.update(predIndex, range)
+ val partialEvalResult = boundPruningPred.partialEval(row)
+ // MAYBE is represented by a null
+ (partialEvalResult == null) || partialEvalResult.asInstanceOf[Boolean]
+ }
+ )
+ } else toBePrunedRanges
+ }
+ prunedRanges ++ notPrunedRanges
+ }
+
+ partitionPred match {
+ case None => Some(partitions)
+ case Some(pred) => if (pred.references.intersect(AttributeSet(partitionKeys)).isEmpty) {
+ Some(partitions)
+ } else {
+ val prunedRanges: Seq[PartitionRange[_]] = getPrunedRanges(pred)
+ println("prunedRanges: " + prunedRanges.length)
+ var idx: Int = -1
+ val result = Some(prunedRanges.map(p => {
+ val par = partitions(p.id)
+ idx = idx + 1
+ if (p.pred == null) {
+ new HBasePartition(idx, par.mappedIndex, par.lowerBound, par.upperBound, par.server)
+ } else {
+ new HBasePartition(idx, par.mappedIndex, par.lowerBound, par.upperBound,
+ par.server, Some(p.pred))
+ }
+ }))
+        result.foreach(p => logDebug(p.toString))
+ result
+ }
+ }
+ }
+
+ def getPrunedPartitions2(partitionPred: Option[Expression] = None)
+ : Option[Seq[HBasePartition]] = {
+ def getPrunedRanges(pred: Expression): Seq[PartitionRange[_]] = {
+ val predRefs = pred.references.toSeq
+ val boundPruningPred = BindReferences.bindReference(pred, predRefs)
+ val keyIndexToPredIndex = (for {
+ (keyColumn, keyIndex) <- keyColumns.zipWithIndex
+ (predRef, predIndex) <- predRefs.zipWithIndex
+ if (keyColumn.sqlName == predRef.name)
+ } yield (keyIndex, predIndex)).toMap
+
+ val row = new GenericMutableRow(predRefs.size)
+ var notPrunedRanges = partitions.map(generateRange(_, boundPruningPred, 0))
+ var prunedRanges: Seq[PartitionRange[_]] = Nil
+
+ for (keyIndex <- 0 until keyColumns.size; if (!notPrunedRanges.isEmpty)) {
+ val (passedRanges, toBePrunedRanges) = prePruneRanges(notPrunedRanges, keyIndex)
+ prunedRanges = prunedRanges ++ passedRanges
+ notPrunedRanges =
+ if (keyIndexToPredIndex.contains(keyIndex)) {
+ toBePrunedRanges.filter(
+ range => {
+ val predIndex = keyIndexToPredIndex(keyIndex)
+ row.update(predIndex, range)
+ val partialEvalResult = range.pred.partialReduce(row)
+ range.pred = if (partialEvalResult.isInstanceOf[Expression]) {
+ // progressively fine tune the constraining predicate
+ partialEvalResult.asInstanceOf[Expression]
+ } else {
+ null
+ }
+ // MAYBE is represented by a to-be-qualified-with expression
+ partialEvalResult.isInstanceOf[Expression] ||
+ partialEvalResult.asInstanceOf[Boolean]
+ }
+ )
+ } else toBePrunedRanges
+ }
+ prunedRanges ++ notPrunedRanges
+ }
+
+ partitionPred match {
+ case None => Some(partitions)
+ case Some(pred) => if (pred.references.intersect(AttributeSet(partitionKeys)).isEmpty) {
+ // the predicate does not apply to the partitions at all; just push down the filtering
+ Some(partitions.map(p=>new HBasePartition(p.idx, p.mappedIndex, p.lowerBound,
+ p.upperBound, p.server, Some(pred))))
+ } else {
+ val prunedRanges: Seq[PartitionRange[_]] = getPrunedRanges(pred)
+ println("prunedRanges: " + prunedRanges.length)
+ var idx: Int = -1
+ val result = Some(prunedRanges.map(p => {
+ val par = partitions(p.id)
+ idx = idx + 1
+ // pruned partitions have the same "physical" partition index, but different
+ // "canonical" index
+ if (p.pred == null) {
+ new HBasePartition(idx, par.mappedIndex, par.lowerBound,
+ par.upperBound, par.server, None)
+ } else {
+ new HBasePartition(idx, par.mappedIndex, par.lowerBound,
+ par.upperBound, par.server, Some(p.pred))
+ }
+ }))
+ // TODO: remove/modify the following debug info
+ // result.foreach(println)
+ result
+ }
+ }
+ }
+
+ /**
+ * Return the start keys of all of the regions in this table,
+   * as a list of ImmutableBytesWritableWrapper.
+ */
+ def getRegionStartKeys() = {
+ val byteKeys: Array[Array[Byte]] = htable.getStartKeys
+ val ret = ArrayBuffer[ImmutableBytesWritableWrapper]()
+ for (byteKey <- byteKeys) {
+ ret += new ImmutableBytesWritableWrapper(byteKey)
+ }
+ ret
+ }
+
+ def buildFilter(
+ projList: Seq[NamedExpression],
+ rowKeyPredicate: Option[Expression],
+ valuePredicate: Option[Expression]) = {
+ val distinctProjList = projList.distinct
+ if (distinctProjList.size == allColumns.size) {
+ Option(new FilterList(new ArrayList[Filter]))
+ } else {
+ val filtersList: List[Filter] = nonKeyColumns.filter {
+        case nkc => distinctProjList.exists(_.name == nkc.sqlName)
+ }.map {
+ case NonKeyColumn(_, _, family, qualifier) => {
+ val columnFilters = new ArrayList[Filter]
+ columnFilters.add(
+ new FamilyFilter(
+ CompareFilter.CompareOp.EQUAL,
+ new BinaryComparator(Bytes.toBytes(family))
+ ))
+ columnFilters.add(
+ new QualifierFilter(
+ CompareFilter.CompareOp.EQUAL,
+ new BinaryComparator(Bytes.toBytes(qualifier))
+ ))
+ new FilterList(FilterList.Operator.MUST_PASS_ALL, columnFilters)
+ }
+ }.toList
+
+ Option(new FilterList(FilterList.Operator.MUST_PASS_ONE, filtersList.asJava))
+ }
+ }
+
+ def buildFilter2(
+ projList: Seq[NamedExpression],
+ pred: Option[Expression]): (Option[FilterList], Option[Expression]) = {
+ var distinctProjList = projList.distinct
+ if (pred.isDefined) {
+ distinctProjList = distinctProjList.filterNot(_.references.subsetOf(pred.get.references))
+ }
+ val projFilterList = if (distinctProjList.size == allColumns.size) {
+ None
+ } else {
+ val filtersList: List[Filter] = nonKeyColumns.filter {
+        case nkc => distinctProjList.exists(_.name == nkc.sqlName)
+ }.map {
+ case NonKeyColumn(_, _, family, qualifier) => {
+ val columnFilters = new ArrayList[Filter]
+ columnFilters.add(
+ new FamilyFilter(
+ CompareFilter.CompareOp.EQUAL,
+ new BinaryComparator(Bytes.toBytes(family))
+ ))
+ columnFilters.add(
+ new QualifierFilter(
+ CompareFilter.CompareOp.EQUAL,
+ new BinaryComparator(Bytes.toBytes(qualifier))
+ ))
+ new FilterList(FilterList.Operator.MUST_PASS_ALL, columnFilters)
+ }
+ }.toList
+
+ Option(new FilterList(FilterList.Operator.MUST_PASS_ONE, filtersList.asJava))
+ }
+
+ if (pred.isDefined) {
+ val predExp: Expression = pred.get
+ // build pred pushdown filters:
+ // 1. push any NOT through AND/OR
+ val notPushedPred = NOTPusher(predExp)
+ // 2. classify the transformed predicate into pushdownable and non-pushdownable predicates
+ val classfier = new ScanPredClassfier(this, 0) // Right now only on primary key dimension
+ val (pushdownFilterPred, otherPred) = classfier(notPushedPred)
+ // 3. build a FilterList mirroring the pushdownable predicate
+ val predPushdownFilterList = buildFilterListFromPred(pushdownFilterPred)
+ // 4. merge the above FilterList with the one from the projection
+ (predPushdownFilterList, otherPred)
+ } else {
+ (projFilterList, None)
+ }
+ }
+
+ private def buildFilterListFromPred(pred: Option[Expression]): Option[FilterList] = {
+ var result: Option[FilterList] = None
+ val filters = new ArrayList[Filter]
+ if (pred.isDefined) {
+ val expression = pred.get
+ expression match {
+ case And(left, right) => {
+ if (left != null) {
+ val leftFilterList = buildFilterListFromPred(Some(left))
+ if (leftFilterList.isDefined) {
+ filters.add(leftFilterList.get)
+ }
+ }
+ if (right != null) {
+ val rightFilterList = buildFilterListFromPred(Some(right))
+ if (rightFilterList.isDefined) {
+ filters.add(rightFilterList.get)
+ }
+ }
+ result = Option(new FilterList(FilterList.Operator.MUST_PASS_ALL, filters))
+ }
+ case Or(left, right) => {
+ if (left != null) {
+ val leftFilterList = buildFilterListFromPred(Some(left))
+ if (leftFilterList.isDefined) {
+ filters.add(leftFilterList.get)
+ }
+ }
+ if (right != null) {
+ val rightFilterList = buildFilterListFromPred(Some(right))
+ if (rightFilterList.isDefined) {
+ filters.add(rightFilterList.get)
+ }
+ }
+ result = Option(new FilterList(FilterList.Operator.MUST_PASS_ONE, filters))
+ }
+ case GreaterThan(left: AttributeReference, right: Literal) => {
+ val keyColumn = keyColumns.find((p: KeyColumn) => p.sqlName.equals(left.name))
+ val nonKeyColumn = nonKeyColumns.find((p: NonKeyColumn) => p.sqlName.equals(left.name))
+ if (keyColumn.isDefined) {
+ val filter = new RowFilter(CompareFilter.CompareOp.GREATER,
+ new BinaryComparator(Bytes.toBytes(right.value.toString)))
+ result = Option(new FilterList(filter))
+ } else if (nonKeyColumn.isDefined) {
+ val column = nonKeyColumn.get
+ val filter = new SingleColumnValueFilter(Bytes.toBytes(column.family),
+ Bytes.toBytes(column.qualifier),
+ CompareFilter.CompareOp.GREATER,
+ DataTypeUtils.getComparator(right))
+ result = Option(new FilterList(filter))
+ }
+ }
+ case GreaterThanOrEqual(left: AttributeReference, right: Literal) => {
+ val keyColumn = keyColumns.find((p: KeyColumn) => p.sqlName.equals(left.name))
+ val nonKeyColumn = nonKeyColumns.find((p: NonKeyColumn) => p.sqlName.equals(left.name))
+ if (keyColumn.isDefined) {
+ val filter = new RowFilter(CompareFilter.CompareOp.GREATER_OR_EQUAL,
+ new BinaryComparator(Bytes.toBytes(right.value.toString)))
+ result = Option(new FilterList(filter))
+ } else if (nonKeyColumn.isDefined) {
+ val column = nonKeyColumn.get
+ val filter = new SingleColumnValueFilter(Bytes.toBytes(column.family),
+ Bytes.toBytes(column.qualifier),
+ CompareFilter.CompareOp.GREATER_OR_EQUAL,
+ DataTypeUtils.getComparator(right))
+ result = Option(new FilterList(filter))
+ }
+ }
+ case EqualTo(left: AttributeReference, right: Literal) => {
+ val keyColumn = keyColumns.find((p: KeyColumn) => p.sqlName.equals(left.name))
+ val nonKeyColumn = nonKeyColumns.find((p: NonKeyColumn) => p.sqlName.equals(left.name))
+ if (keyColumn.isDefined) {
+ val filter = new RowFilter(CompareFilter.CompareOp.EQUAL,
+ new BinaryComparator(Bytes.toBytes(right.value.toString)))
+ result = Option(new FilterList(filter))
+ } else if (nonKeyColumn.isDefined) {
+ val column = nonKeyColumn.get
+ val filter = new SingleColumnValueFilter(Bytes.toBytes(column.family),
+ Bytes.toBytes(column.qualifier),
+ CompareFilter.CompareOp.EQUAL,
+ DataTypeUtils.getComparator(right))
+ result = Option(new FilterList(filter))
+ }
+ }
+ case LessThan(left: AttributeReference, right: Literal) => {
+ val keyColumn = keyColumns.find((p: KeyColumn) => p.sqlName.equals(left.name))
+ val nonKeyColumn = nonKeyColumns.find((p: NonKeyColumn) => p.sqlName.equals(left.name))
+ if (keyColumn.isDefined) {
+ val filter = new RowFilter(CompareFilter.CompareOp.LESS,
+ new BinaryComparator(Bytes.toBytes(right.value.toString)))
+ result = Option(new FilterList(filter))
+ } else if (nonKeyColumn.isDefined) {
+ val column = nonKeyColumn.get
+ val filter = new SingleColumnValueFilter(Bytes.toBytes(column.family),
+ Bytes.toBytes(column.qualifier),
+ CompareFilter.CompareOp.LESS,
+ DataTypeUtils.getComparator(right))
+ result = Option(new FilterList(filter))
+ }
+ }
+ case LessThanOrEqual(left: AttributeReference, right: Literal) => {
+ val keyColumn = keyColumns.find((p: KeyColumn) => p.sqlName.equals(left.name))
+ val nonKeyColumn = nonKeyColumns.find((p: NonKeyColumn) => p.sqlName.equals(left.name))
+ if (keyColumn.isDefined) {
+ val filter = new RowFilter(CompareFilter.CompareOp.LESS_OR_EQUAL,
+ new BinaryComparator(Bytes.toBytes(right.value.toString)))
+ result = Option(new FilterList(filter))
+ } else if (nonKeyColumn.isDefined) {
+ val column = nonKeyColumn.get
+ val filter = new SingleColumnValueFilter(Bytes.toBytes(column.family),
+ Bytes.toBytes(column.qualifier),
+ CompareFilter.CompareOp.LESS_OR_EQUAL,
+ DataTypeUtils.getComparator(right))
+ result = Option(new FilterList(filter))
+ }
+ }
+ }
+ }
+ result
+ }
+
+ def buildPut(row: Row): Put = {
+ // TODO: revisit this using new KeyComposer
+ val rowKey: HBaseRawType = null
+ new Put(rowKey)
+ }
+
+ def buildScan(
+ split: Partition,
+ filters: Option[FilterList],
+ projList: Seq[NamedExpression]): Scan = {
+ val hbPartition = split.asInstanceOf[HBasePartition]
+ val scan = {
+ (hbPartition.lowerBound, hbPartition.upperBound) match {
+ case (Some(lb), Some(ub)) => new Scan(lb, ub)
+ case (Some(lb), None) => new Scan(lb)
+ case _ => new Scan
+ }
+ }
+ if (filters.isDefined && !filters.get.getFilters.isEmpty) {
+ scan.setFilter(filters.get)
+ }
+    // TODO: add the required column families to the Scan based on the projections
+ scan
+ }
+
+  def buildGet(projList: Seq[NamedExpression], rowKey: HBaseRawType): Get = {
+    // TODO: add the projected columns to the Get
+    new Get(rowKey)
+  }
+
+ // /**
+ // * Trait for RowKeyParser's that convert a raw array of bytes into their constituent
+ // * logical column values
+ // *
+ // */
+ // trait AbstractRowKeyParser {
+ //
+ //// def createKey(rawBytes: Seq[HBaseRawType], version: Byte): HBaseRawType
+ ////
+ //// def parseRowKey(rowKey: HBaseRawType): Seq[HBaseRawType]
+ ////
+ //// def parseRowKeyWithMetaData(rkCols: Seq[KeyColumn], rowKey: HBaseRawType)
+ //// : SortedMap[TableName, (KeyColumn, Any)] // TODO change Any
+ // }
+ //
+ // case class RowKeySpec(offsets: Seq[Int], version: Byte = RowKeyParser.Version1)
+ //
+ // // TODO(Bo): replace the implementation with the null-byte terminated string logic
+ // object RowKeyParser extends AbstractRowKeyParser with Serializable {
+ // val Version1 = 1.toByte
+ // val VersionFieldLen = 1
+ // // Length in bytes of the RowKey version field
+ // val DimensionCountLen = 1
+ // // One byte for the number of key dimensions
+ // val MaxDimensions = 255
+ // val OffsetFieldLen = 2
+ //
+ // // Two bytes for the value of each dimension offset.
+ // // Therefore max size of rowkey is 65535. Note: if longer rowkeys desired in future
+ // // then simply define a new RowKey version to support it. Otherwise would be wasteful
+ // // to define as 4 bytes now.
+ // def computeLength(keys: Seq[HBaseRawType]) = {
+ // VersionFieldLen + keys.map(_.length).sum +
+ // OffsetFieldLen * keys.size + DimensionCountLen
+ // }
+ //
+ // override def createKey(keys: Seq[HBaseRawType], version: Byte = Version1): HBaseRawType = {
+ // val barr = new Array[Byte](computeLength(keys))
+ // val arrayx = new AtomicInteger(0)
+ // barr(arrayx.getAndAdd(VersionFieldLen)) = version // VersionByte
+ //
+ // // Remember the starting offset of first data value
+ // val valuesStartIndex = new AtomicInteger(arrayx.get)
+ //
+ // // copy each of the dimension values in turn
+ // keys.foreach { k => copyToArr(barr, k, arrayx.getAndAdd(k.length))}
+ //
+ // // Copy the offsets of each dim value
+ // // The valuesStartIndex is the location of the first data value and thus the first
+ // // value included in the Offsets sequence
+ // keys.foreach { k =>
+ // copyToArr(barr,
+ // short2b(valuesStartIndex.getAndAdd(k.length).toShort),
+ // arrayx.getAndAdd(OffsetFieldLen))
+ // }
+ // barr(arrayx.get) = keys.length.toByte // DimensionCountByte
+ // barr
+ // }
+ //
+ // def copyToArr[T](a: Array[T], b: Array[T], aoffset: Int) = {
+ // b.copyToArray(a, aoffset)
+ // }
+ //
+ // def short2b(sh: Short): Array[Byte] = {
+ // val barr = Array.ofDim[Byte](2)
+ // barr(0) = ((sh >> 8) & 0xff).toByte
+ // barr(1) = (sh & 0xff).toByte
+ // barr
+ // }
+ //
+ // def b2Short(barr: Array[Byte]) = {
+ // val out = (barr(0).toShort << 8) | barr(1).toShort
+ // out
+ // }
+ //
+ // def createKeyFromCatalystRow(schema: StructType, keyCols: Seq[KeyColumn], row: Row) = {
+ // // val rawKeyCols = DataTypeUtils.catalystRowToHBaseRawVals(schema, row, keyCols)
+ // // createKey(rawKeyCols)
+ // null
+ // }
+ //
+ // def getMinimumRowKeyLength = VersionFieldLen + DimensionCountLen
+ //
+ // override def parseRowKey(rowKey: HBaseRawType): Seq[HBaseRawType] = {
+ // assert(rowKey.length >= getMinimumRowKeyLength,
+ // s"RowKey is invalid format - less than minlen . Actual length=${rowKey.length}")
+ // assert(rowKey(0) == Version1, s"Only Version1 supported. Actual=${rowKey(0)}")
+ // val ndims: Int = rowKey(rowKey.length - 1).toInt
+ // val offsetsStart = rowKey.length - DimensionCountLen - ndims * OffsetFieldLen
+ // val rowKeySpec = RowKeySpec(
+ // for (dx <- 0 to ndims - 1)
+ // yield b2Short(rowKey.slice(offsetsStart + dx * OffsetFieldLen,
+ // offsetsStart + (dx + 1) * OffsetFieldLen))
+ // )
+ //
+ // val endOffsets = rowKeySpec.offsets.tail :+ (rowKey.length - DimensionCountLen - 1)
+ // val colsList = rowKeySpec.offsets.zipWithIndex.map { case (off, ix) =>
+ // rowKey.slice(off, endOffsets(ix))
+ // }
+ // colsList
+ // }
+ //
+ //// //TODO
+ //// override def parseRowKeyWithMetaData(rkCols: Seq[KeyColumn], rowKey: HBaseRawType):
+ //// SortedMap[TableName, (KeyColumn, Any)] = {
+ //// import scala.collection.mutable.HashMap
+ ////
+ //// val rowKeyVals = parseRowKey(rowKey)
+ //// val rmap = rowKeyVals.zipWithIndex.foldLeft(new HashMap[ColumnName, (Column, Any)]()) {
+ //// case (m, (cval, ix)) =>
+ //// m.update(rkCols(ix).toColumnName, (rkCols(ix),
+ //// hbaseFieldToRowField(cval, rkCols(ix).dataType)))
+ //// m
+ //// }
+ //// TreeMap(rmap.toArray: _*)(Ordering.by { cn: ColumnName => rmap(cn)._1.ordinal})
+ //// .asInstanceOf[SortedMap[ColumnName, (Column, Any)]]
+ //// }
+ //
+ // def show(bytes: Array[Byte]) = {
+ // val len = bytes.length
+ // // val out = s"Version=${bytes(0).toInt} NumDims=${bytes(len - 1)} "
+ // }
+ //
+ // }
+
+ def buildRow(projections: Seq[(Attribute, Int)],
+ result: Result,
+ row: MutableRow,
+ bytesUtils: BytesUtils): Row = {
+ assert(projections.size == row.length, "Projection size and row size mismatched")
+ // TODO: replace with the new Key method
+ val buffer = ListBuffer[HBaseRawType]()
+ val rowKeys = HBaseKVHelper.decodingRawKeyColumns(buffer, result.getRow, keyColumns)
+ projections.foreach { p =>
+ columnMap.get(p._1.name).get match {
+ case column: NonKeyColumn => {
+ val colValue = result.getValue(column.familyRaw, column.qualifierRaw)
+ DataTypeUtils.setRowColumnFromHBaseRawType(row, p._2, colValue,
+ column.dataType, bytesUtils)
+ }
+ case ki => {
+ val keyIndex = ki.asInstanceOf[Int]
+ val rowKey = rowKeys(keyIndex)
+ DataTypeUtils.setRowColumnFromHBaseRawType(row, p._2, rowKey,
+ keyColumns(keyIndex).dataType, bytesUtils)
+ }
+ }
+ }
+ row
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/hbase/HBaseSQLReaderRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/hbase/HBaseSQLReaderRDD.scala
new file mode 100644
index 0000000000000..23e53e2faaf08
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/hbase/HBaseSQLReaderRDD.scala
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hbase
+
+import org.apache.hadoop.hbase.client.Result
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, GenericMutableRow}
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.{Row, SQLContext}
+import org.apache.spark.{InterruptibleIterator, Logging, Partition, TaskContext}
+
+class HBaseSQLReaderRDD(
+ relation: HBaseRelation,
+ output: Seq[Attribute],
+ rowKeyPred: Option[Expression],
+ valuePred: Option[Expression],
+ partitionPred: Option[Expression],
+ coprocSubPlan: Option[SparkPlan])(@transient sqlContext: SQLContext)
+ extends RDD[Row](sqlContext.sparkContext, Nil) with Logging {
+
+ private final val cachingSize: Int = 100 // TODO: make configurable
+
+ override def getPartitions: Array[Partition] = {
+ relation.getPrunedPartitions(partitionPred).get.toArray
+ }
+
+ override def getPreferredLocations(split: Partition): Seq[String] = {
+ split.asInstanceOf[HBasePartition].server.map {
+ identity
+ }.toSeq
+ }
+
+ override def compute(split: Partition, context: TaskContext): Iterator[Row] = {
+ val filters = relation.buildFilter(output, rowKeyPred, valuePred)
+ val scan = relation.buildScan(split, filters, output)
+ scan.setCaching(cachingSize)
+ logDebug(s"relation.htable scanner conf="
+ + s"${relation.htable.getConfiguration.get("hbase.zookeeper.property.clientPort")}")
+ val scanner = relation.htable.getScanner(scan)
+
+ val row = new GenericMutableRow(output.size)
+ val projections = output.zipWithIndex
+ val bytesUtils = new BytesUtils
+
+ var finished: Boolean = false
+ var gotNext: Boolean = false
+ var result: Result = null
+
+ val iter = new Iterator[Row] {
+ override def hasNext: Boolean = {
+ if (!finished) {
+ if (!gotNext) {
+ result = scanner.next
+ finished = result == null
+ gotNext = true
+ }
+ }
+ if (finished) {
+ close()
+ }
+ !finished
+ }
+
+ override def next(): Row = {
+ if (hasNext) {
+ gotNext = false
+ relation.buildRow(projections, result, row, bytesUtils)
+ } else {
+ null
+ }
+ }
+
+ def close() = {
+ try {
+ scanner.close()
+ } catch {
+ case e: Exception => logWarning("Exception in scanner.close", e)
+ }
+ }
+ }
+ new InterruptibleIterator(context, iter)
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/hbase/IndexMappable.scala b/sql/core/src/main/scala/org/apache/spark/sql/hbase/IndexMappable.scala
new file mode 100755
index 0000000000000..e2d5daac2f505
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/hbase/IndexMappable.scala
@@ -0,0 +1,21 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.hbase
+
+private[hbase] trait IndexMappable {
+ val mappedIndex: Int
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/hbase/NotPusher.scala b/sql/core/src/main/scala/org/apache/spark/sql/hbase/NotPusher.scala
new file mode 100755
index 0000000000000..3b6a0c40641d0
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/hbase/NotPusher.scala
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hbase
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.rules._
+
+/**
+ * Pushes NOT through And/Or
+ */
+object NOTPusher extends Rule[Expression] {
+ def apply(pred: Expression): Expression = pred transformDown {
+ case Not(And(left, right)) => Or(Not(left), Not(right))
+ case Not(Or(left, right)) => And(Not(left), Not(right))
+ case not @ Not(exp) => {
+ // This pattern would normally have been handled by the optimizer, but pushing
+ // NOT down may expose new simplification opportunities
+ exp match {
+ case GreaterThan(l, r) => LessThanOrEqual(l, r)
+ case GreaterThanOrEqual(l, r) => LessThan(l, r)
+ case LessThan(l, r) => GreaterThanOrEqual(l, r)
+ case LessThanOrEqual(l, r) => GreaterThan(l, r)
+ case Not(e) => e
+ case _ => not
+ }
+ }
+ }
+}
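+
+// A minimal usage sketch (illustrative only; `a`, `b` and the literals below are
+// hypothetical attribute references and values built with the catalyst expression API):
+//
+//   val a = AttributeReference("a", IntegerType)()
+//   val b = AttributeReference("b", IntegerType)()
+//   // NOT (a > 1 AND b < 2)  is rewritten to  (a <= 1) OR (b >= 2)
+//   val rewritten = NOTPusher(Not(And(GreaterThan(a, Literal(1)), LessThan(b, Literal(2)))))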
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/hbase/PartialPredEval.scala b/sql/core/src/main/scala/org/apache/spark/sql/hbase/PartialPredEval.scala
new file mode 100755
index 0000000000000..1f29f1cc3682e
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/hbase/PartialPredEval.scala
@@ -0,0 +1,480 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hbase
+
+import org.apache.spark.sql.catalyst.errors.TreeNodeException
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.types.{DataType, NativeType}
+
+
+object PartialPredicateOperations {
+
+ // Partial evaluation is nullness-based, i.e., columns that are not of interest are
+ // assigned nulls, which requires null handling that differs from the normal
+ // evaluation of predicate expressions
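+ // For example, if only the row-key column is populated and the non-key column is
+ // left null (MAYBE), a conjunction can still be pruned when its key part is false.
+ // A hedged sketch, assuming a two-column row with keyCol at ordinal 0 and nonKey at
+ // ordinal 1 (names and ordinals are illustrative):
+ //
+ //   val keyCol = BoundReference(0, IntegerType, nullable = true)
+ //   val nonKey = BoundReference(1, IntegerType, nullable = true)
+ //   val pred = And(EqualTo(keyCol, Literal(5)), GreaterThan(nonKey, Literal(3)))
+ //   pred.partialEval(Row(5, null))   // => null (MAYBE): the range cannot be ruled out
+ //   pred.partialEval(Row(4, null))   // => false: safe to prune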
+ implicit class partialPredicateEvaluator(e: Expression) {
+ def partialEval(input: Row): Any = {
+ e match {
+ case And(left, right) => {
+ val l = left.partialEval(input)
+ if (l == false) {
+ false
+ } else {
+ val r = right.partialEval(input)
+ if (r == false) {
+ false
+ } else {
+ if (l != null && r != null) {
+ true
+ } else {
+ null
+ }
+ }
+ }
+ }
+ case Or(left, right) => {
+ val l = left.partialEval(input)
+ if (l == true) {
+ true
+ } else {
+ val r = right.partialEval(input)
+ if (r == true) {
+ true
+ } else {
+ if (l != null && r != null) {
+ false
+ } else {
+ null
+ }
+ }
+ }
+ }
+ case Not(child) => {
+ child.partialEval(input) match {
+ case null => null
+ case b: Boolean => !b
+ }
+ }
+ case In(value, list) => {
+ val evaluatedValue = value.partialEval(input)
+ if (evaluatedValue == null) {
+ null
+ } else {
+ val evaluatedList = list.map(_.partialEval(input))
+ if (evaluatedList.exists(e => e == evaluatedValue)) {
+ true
+ } else if (evaluatedList.exists(e => e == null)) {
+ null
+ } else {
+ false
+ }
+ }
+ }
+ case InSet(value, hset) => {
+ val evaluatedValue = value.partialEval(input)
+ if (evaluatedValue == null) {
+ null
+ } else {
+ hset.contains(evaluatedValue)
+ }
+ }
+ case l: LeafExpression => l.eval(input)
+ case b: BoundReference => b.eval(input) // Really a LeafExpression but not declared as such
+ case n: NamedExpression => n.eval(input) // Really a LeafExpression but not declared as such
+ case IsNull(child) => {
+ if (child.partialEval(input) == null) {
+ // In partial evaluation, null indicates MAYBE
+ null
+ } else {
+ // Now we only support non-nullable primary key components
+ false
+ }
+ }
+ // TODO: CAST/arithmetic can be treated more nicely
+ case Cast(_, _) => null
+ // case BinaryArithmetic => null
+ case UnaryMinus(_) => null
+ case EqualTo(left, right) => {
+ val cmp = pc2(input, left, right)
+ if (cmp.isDefined) {
+ cmp.get == 0
+ } else {
+ null
+ }
+ }
+ case LessThan(left, right) => {
+ val cmp = pc2(input, left, right)
+ if (cmp.isDefined) {
+ cmp.get < 0
+ } else {
+ null
+ }
+ }
+ case LessThanOrEqual(left, right) => {
+ val cmp = pc2(input, left, right)
+ if (cmp.isDefined) {
+ cmp.get <= 0
+ } else {
+ null
+ }
+ }
+ case GreaterThan(left, right) => {
+ val cmp = pc2(input, left, right)
+ if (cmp.isDefined) {
+ cmp.get > 0
+ } else {
+ null
+ }
+ }
+ case GreaterThanOrEqual(left, right) => {
+ val cmp = pc2(input, left, right)
+ if (cmp.isDefined) {
+ cmp.get >= 0
+ } else {
+ null
+ }
+ }
+ case If(predicate, trueE, falseE) => {
+ val v = predicate.partialEval(input)
+ if (v == null) {
+ null
+ } else if (v.asInstanceOf[Boolean]) {
+ trueE.partialEval(input)
+ } else {
+ falseE.partialEval(input)
+ }
+ }
+ case _ => null
+ }
+ }
+
+ @inline
+ protected def pc2(
+ i: Row,
+ e1: Expression,
+ e2: Expression): Option[Int] = {
+ if (e1.dataType != e2.dataType) {
+ throw new TreeNodeException(e, s"Types do not match ${e1.dataType} != ${e2.dataType}")
+ }
+
+ val evalE1 = e1.partialEval(i)
+ if (evalE1 == null) {
+ None
+ } else {
+ val evalE2 = e2.partialEval(i)
+ if (evalE2 == null) {
+ None
+ } else {
+ e1.dataType match {
+ case nativeType: NativeType => {
+ val pdt = RangeType.primitiveToPODataTypeMap.get(nativeType).getOrElse(null)
+ if (pdt == null) {
+ sys.error(s"Type $i does not have corresponding partial ordered type")
+ } else {
+ pdt.partialOrdering.tryCompare(
+ pdt.toPartiallyOrderingDataType(evalE1, nativeType).asInstanceOf[pdt.JvmType],
+ pdt.toPartiallyOrderingDataType(evalE2, nativeType).asInstanceOf[pdt.JvmType])
+ }
+ }
+ case other => sys.error(s"Type $other does not support partially ordered operations")
+ }
+ }
+ }
+ }
+ }
+
+ // Partial reduction is nullness-based, i.e., columns that are not of interest are
+ // assigned nulls, which requires null handling that differs from the normal
+ // evaluation of predicate expressions.
+ // There are 3 possible results: TRUE, FALSE, and MAYBE, the last represented by a
+ // residual predicate that will be used to further filter the results
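+ // A hedged sketch of the three outcomes, reusing the hypothetical keyCol/nonKey
+ // bound references from the partial-evaluation note above:
+ //
+ //   val pred = And(EqualTo(keyCol, Literal(5)), GreaterThan(nonKey, Literal(3)))
+ //   pred.partialReduce(Row(4, null))  // => false: the whole range can be skipped
+ //   pred.partialReduce(Row(5, null))  // => GreaterThan(nonKey, Literal(3)), the residual
+ //                                     //    predicate to evaluate against full rows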
+ implicit class partialPredicateReducer(e: Expression) {
+ def partialReduce(input: Row): Any = {
+ e match {
+ case And(left, right) => {
+ val l = left.partialReduce(input)
+ if (l == false) {
+ false
+ } else {
+ val r = right.partialReduce(input)
+ if (r == false) {
+ false
+ } else {
+ (l, r) match {
+ case (true, true) => true
+ case (true, _) => r
+ case (_, true) => l
+ case (nl: Expression, nr: Expression) => {
+ if ((nl fastEquals left) && (nr fastEquals right)) {
+ e
+ } else {
+ And(nl, nr)
+ }
+ }
+ case _ => sys.error("unexpected child type(s) in partial reduction")
+ }
+ }
+ }
+ }
+ case Or(left, right) => {
+ val l = left.partialReduce(input)
+ if (l == true) {
+ true
+ } else {
+ val r = right.partialReduce(input)
+ if (r == true) {
+ true
+ } else {
+ (l, r) match {
+ case (false, false) => false
+ case (false, _) => r
+ case (_, false) => l
+ case (nl: Expression, nr: Expression) => {
+ if ((nl fastEquals left) && (nr fastEquals right)) {
+ e
+ } else {
+ Or(nl, nr)
+ }
+ }
+ case _ => sys.error("unexpected child type(s) in partial reduction")
+ }
+ }
+ }
+ }
+ case Not(child) => {
+ child.partialReduce(input) match {
+ case b: Boolean => !b
+ case ec: Expression => if (ec fastEquals child) { e } else { Not(ec) }
+ }
+ }
+ case In(value, list) => {
+ val evaluatedValue = value.partialReduce(input)
+ if (evaluatedValue.isInstanceOf[Expression]) {
+ val evaluatedList = list.map(item => item.partialReduce(input) match {
+ case expr: Expression => expr
+ case d => Literal(d, item.dataType)
+ })
+ In(evaluatedValue.asInstanceOf[Expression], evaluatedList)
+ } else {
+ val evaluatedList = list.map(_.partialReduce(input))
+ if (evaluatedList.exists(e => e == evaluatedValue)) {
+ true
+ } else {
+ val newList = evaluatedList.filter(_.isInstanceOf[Expression])
+ .map(_.asInstanceOf[Expression])
+ if (newList.isEmpty) {
+ false
+ } else {
+ In(Literal(evaluatedValue, value.dataType), newList)
+ }
+ }
+ }
+ }
+ case InSet(value, hset) => {
+ val evaluatedValue = value.partialReduce(input)
+ if (evaluatedValue.isInstanceOf[Expression]) {
+ InSet(evaluatedValue.asInstanceOf[Expression], hset)
+ } else {
+ hset.contains(evaluatedValue)
+ }
+ }
+ case l: LeafExpression => {
+ val res = l.eval(input)
+ if (res == null) l else res
+ }
+ case b: BoundReference => {
+ val res = b.eval(input)
+ // If the result is a MAYBE (null), return the original expression
+ if (res == null) b else res
+ }
+ case n: NamedExpression => {
+ val res = n.eval(input)
+ if (res == null) n else res
+ }
+ case IsNull(child) => e
+ // TODO: CAST/arithmetic could be treated more nicely
+ case Cast(_, _) => e
+ // case BinaryArithmetic => null
+ case UnaryMinus(_) => e
+ case EqualTo(left, right) => {
+ val evalL = left.partialReduce(input)
+ val evalR = right.partialReduce(input)
+ if (evalL.isInstanceOf[Expression] && evalR.isInstanceOf[Expression]) {
+ EqualTo(evalL.asInstanceOf[Expression], evalR.asInstanceOf[Expression])
+ } else if (evalL.isInstanceOf[Expression]) {
+ EqualTo(evalL.asInstanceOf[Expression], right)
+ } else if (evalR.isInstanceOf[Expression]) {
+ EqualTo(left, evalR.asInstanceOf[Expression])
+ } else {
+ val cmp = prc2(input, left.dataType, right.dataType, evalL, evalR)
+ if (cmp.isDefined) {
+ cmp.get == 0
+ } else {
+ e
+ }
+ }
+ }
+ case LessThan(left, right) => {
+ val evalL = left.partialReduce(input)
+ val evalR = right.partialReduce(input)
+ if (evalL.isInstanceOf[Expression] && evalR.isInstanceOf[Expression]) {
+ LessThan(evalL.asInstanceOf[Expression], evalR.asInstanceOf[Expression])
+ } else if (evalL.isInstanceOf[Expression]) {
+ LessThan(evalL.asInstanceOf[Expression], right)
+ } else if (evalR.isInstanceOf[Expression]) {
+ LessThan(left, evalR.asInstanceOf[Expression])
+ } else {
+ val cmp = prc2(input, left.dataType, right.dataType, evalL, evalR)
+ if (cmp.isDefined) {
+ cmp.get < 0
+ } else {
+ e
+ }
+ }
+ }
+ case LessThanOrEqual(left, right) => {
+ val evalL = left.partialReduce(input)
+ val evalR = right.partialReduce(input)
+ if (evalL.isInstanceOf[Expression] && evalR.isInstanceOf[Expression]) {
+ LessThanOrEqual(evalL.asInstanceOf[Expression], evalR.asInstanceOf[Expression])
+ } else if (evalL.isInstanceOf[Expression]) {
+ LessThanOrEqual(evalL.asInstanceOf[Expression], right)
+ } else if (evalR.isInstanceOf[Expression]) {
+ LessThanOrEqual(left, evalR.asInstanceOf[Expression])
+ } else {
+ val cmp = prc2(input, left.dataType, right.dataType, evalL, evalR)
+ if (cmp.isDefined) {
+ cmp.get <= 0
+ } else {
+ e
+ }
+ }
+ }
+ case GreaterThan(left, right) => {
+ val evalL = left.partialReduce(input)
+ val evalR = right.partialReduce(input)
+ if (evalL.isInstanceOf[Expression] && evalR.isInstanceOf[Expression]) {
+ GreaterThan(evalL.asInstanceOf[Expression], evalR.asInstanceOf[Expression])
+ } else if (evalL.isInstanceOf[Expression]) {
+ GreaterThan(evalL.asInstanceOf[Expression], right)
+ } else if (evalR.isInstanceOf[Expression]) {
+ GreaterThan(left, evalR.asInstanceOf[Expression])
+ } else {
+ val cmp = prc2(input, left.dataType, right.dataType, evalL, evalR)
+ if (cmp.isDefined) {
+ cmp.get > 0
+ } else {
+ e
+ }
+ }
+ }
+ case GreaterThanOrEqual(left, right) => {
+ val evalL = left.partialReduce(input)
+ val evalR = right.partialReduce(input)
+ if (evalL.isInstanceOf[Expression] && evalR.isInstanceOf[Expression]) {
+ GreaterThanOrEqual(evalL.asInstanceOf[Expression], evalR.asInstanceOf[Expression])
+ } else if (evalL.isInstanceOf[Expression]) {
+ GreaterThanOrEqual(evalL.asInstanceOf[Expression], right)
+ } else if (evalR.isInstanceOf[Expression]) {
+ GreaterThanOrEqual(left, evalR.asInstanceOf[Expression])
+ } else {
+ val cmp = prc2(input, left.dataType, right.dataType, evalL, evalR)
+ if (cmp.isDefined) {
+ cmp.get >= 0
+ } else {
+ e
+ }
+ }
+ }
+ case If(predicate, trueE, falseE) => {
+ val v = predicate.partialReduce(input)
+ if (v.isInstanceOf[Expression]) {
+ If(v.asInstanceOf[Expression],
+ trueE.partialReduce(input).asInstanceOf[Expression],
+ falseE.partialReduce(input).asInstanceOf[Expression])
+ } else if (v.asInstanceOf[Boolean]) {
+ trueE.partialReduce(input)
+ } else {
+ falseE.partialReduce(input)
+ }
+ }
+ case _ => e
+ }
+ }
+
+ @inline
+ protected def pc2(
+ i: Row,
+ e1: Expression,
+ e2: Expression): Option[Int] = {
+ if (e1.dataType != e2.dataType) {
+ throw new TreeNodeException(e, s"Types do not match ${e1.dataType} != ${e2.dataType}")
+ }
+
+ val evalE1 = e1.partialEval(i)
+ if (evalE1 == null) {
+ None
+ } else {
+ val evalE2 = e2.partialEval(i)
+ if (evalE2 == null) {
+ None
+ } else {
+ e1.dataType match {
+ case nativeType: NativeType => {
+ val pdt = RangeType.primitiveToPODataTypeMap.get(nativeType).getOrElse(null)
+ if (pdt == null) {
+ sys.error(s"Type $i does not have corresponding partial ordered type")
+ } else {
+ pdt.partialOrdering.tryCompare(
+ pdt.toPartiallyOrderingDataType(evalE1, nativeType).asInstanceOf[pdt.JvmType],
+ pdt.toPartiallyOrderingDataType(evalE2, nativeType).asInstanceOf[pdt.JvmType])
+ }
+ }
+ case other => sys.error(s"Type $other does not support partially ordered operations")
+ }
+ }
+ }
+ }
+
+ @inline
+ protected def prc2(
+ i: Row,
+ dataType1: DataType,
+ dataType2: DataType,
+ eval1: Any,
+ eval2: Any): Option[Int] = {
+ if (dataType1 != dataType2) {
+ throw new TreeNodeException(e, s"Types do not match ${dataType1} != ${dataType2}")
+ }
+
+ dataType1 match {
+ case nativeType: NativeType => {
+ val pdt = RangeType.primitiveToPODataTypeMap.get(nativeType).getOrElse(null)
+ if (pdt == null) {
+ sys.error(s"Type $i does not have corresponding partial ordered type")
+ } else {
+ pdt.partialOrdering.tryCompare(
+ pdt.toPartiallyOrderingDataType(eval1, nativeType).asInstanceOf[pdt.JvmType],
+ pdt.toPartiallyOrderingDataType(eval2, nativeType).asInstanceOf[pdt.JvmType])
+ }
+ }
+ case other => sys.error(s"Type $other does not support partially ordered operations")
+ }
+ }
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/hbase/PartiallyOrderingDataType.scala b/sql/core/src/main/scala/org/apache/spark/sql/hbase/PartiallyOrderingDataType.scala
new file mode 100755
index 0000000000000..2c229d432dbaf
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/hbase/PartiallyOrderingDataType.scala
@@ -0,0 +1,31 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.hbase
+
+import org.apache.spark.sql.catalyst.types._
+import scala.math.PartialOrdering
+import scala.reflect.runtime.universe.TypeTag
+
+abstract class PartiallyOrderingDataType extends DataType {
+ private[sql] type JvmType
+
+ def toPartiallyOrderingDataType(s: Any, dt: NativeType): Any
+
+ @transient private[sql] val tag: TypeTag[JvmType]
+
+ private[sql] val partialOrdering: PartialOrdering[JvmType]
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/hbase/RangeType.scala b/sql/core/src/main/scala/org/apache/spark/sql/hbase/RangeType.scala
new file mode 100755
index 0000000000000..4f8c999493b83
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/hbase/RangeType.scala
@@ -0,0 +1,211 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.hbase
+
+import java.sql.Timestamp
+
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.types._
+
+import scala.collection.immutable.HashMap
+import scala.language.implicitConversions
+import scala.math.PartialOrdering
+import scala.reflect.runtime.universe.typeTag
+
+class Range[T](val start: Option[T], // None for open ends
+ val startInclusive: Boolean,
+ val end: Option[T], // None for open ends
+ val endInclusive: Boolean,
+ val dt: NativeType) {
+ require(dt != null && !(start.isDefined && end.isDefined &&
+ ((dt.ordering.equiv(start.get.asInstanceOf[dt.JvmType],
+ end.get.asInstanceOf[dt.JvmType]) &&
+ (!startInclusive || !endInclusive)) ||
+ (dt.ordering.gt(start.get.asInstanceOf[dt.JvmType], end.get.asInstanceOf[dt.JvmType])))),
+ "Inappropriate range parameters")
+}
+
+// HBase ranges:
+// @param id partition id used to map to an HBase physical partition
+class PartitionRange[T](start: Option[T], startInclusive: Boolean,
+ end: Option[T], endInclusive: Boolean,
+ val id: Int, dt: NativeType, var pred: Expression)
+ extends Range[T](start, startInclusive, end, endInclusive, dt)
+
+// A PointRange is a range consisting of a single point. It is a convenience for
+// comparing two values of the same type. An alternative would be to use multiple
+// (overloaded) comparison methods, which could be more natural but would also
+// require more code.
+
+//class PointRange[T](value: T, dt:NativeType)
+// extends Range[T](Some(value), true, Some(value), true, dt)
+
+
+class RangeType[T] extends PartiallyOrderingDataType {
+ private[sql] type JvmType = Range[T]
+ @transient private[sql] val tag = typeTag[JvmType]
+
+ def toPartiallyOrderingDataType(s: Any, dt: NativeType): Any = s match {
+ case i: Int => new Range[Int](Some(i), true, Some(i), true, IntegerType)
+ case l: Long => new Range[Long](Some(l), true, Some(l), true, LongType)
+ case d: Double => new Range[Double](Some(d), true, Some(d), true, DoubleType)
+ case f: Float => new Range[Float](Some(f), true, Some(f), true, FloatType)
+ case b: Byte => new Range[Byte](Some(b), true, Some(b), true, ByteType)
+ case s: Short => new Range[Short](Some(s), true, Some(s), true, ShortType)
+ case s: String => new Range[String](Some(s), true, Some(s), true, StringType)
+ case b: Boolean => new Range[Boolean](Some(b), true, Some(b), true, BooleanType)
+ // TODO: fix the BigDecimal issue; currently this leads to a compile error
+ case d: BigDecimal => new Range[BigDecimal](Some(d), true, Some(d), true, DecimalType.Unlimited)
+ case t: Timestamp => new Range[Timestamp](Some(t), true, Some(t), true, TimestampType)
+ case _ => s
+ }
+
+ val partialOrdering = new PartialOrdering[JvmType] {
+ // Right now we only support comparisons between a range and a point.
+ // In the future, when more generic range comparisons are supported,
+ // these two methods must behave as expected.
+ def tryCompare(a: JvmType, b: JvmType): Option[Int] = {
+ val aRange = a.asInstanceOf[Range[T]]
+ val aStartInclusive = aRange.startInclusive
+ val aStart = aRange.start.getOrElse(null).asInstanceOf[aRange.dt.JvmType]
+ val aEnd = aRange.end.getOrElse(null).asInstanceOf[aRange.dt.JvmType]
+ val aEndInclusive = aRange.endInclusive
+ val bRange = b.asInstanceOf[Range[T]]
+ val bStart = bRange.start.getOrElse(null).asInstanceOf[aRange.dt.JvmType]
+ val bEnd = bRange.end.getOrElse(null).asInstanceOf[aRange.dt.JvmType]
+ val bStartInclusive = bRange.startInclusive
+ val bEndInclusive = bRange.endInclusive
+
+ // Return 1 iff aStart > bEnd, or iff aStart == bEnd and aStartInclusive and
+ // bEndInclusive are not both true
+ if ((aStart != null && bEnd != null)
+ && (aRange.dt.ordering.gt(aStart, bEnd)
+ || (aRange.dt.ordering.equiv(aStart, bEnd) && !(aStartInclusive && bEndInclusive)))) {
+ Some(1)
+ } //Vice versa
+ else if ((bStart != null && aEnd != null)
+ && (aRange.dt.ordering.gt(bStart, aEnd)
+ || (aRange.dt.ordering.equiv(bStart, aEnd) && !(bStartInclusive && aEndInclusive)))) {
+ Some(-1)
+ } else if (aStart != null && aEnd != null && bStart != null && bEnd != null &&
+ aRange.dt.ordering.equiv(bStart, aEnd)
+ && aRange.dt.ordering.equiv(aStart, aEnd)
+ && aRange.dt.ordering.equiv(bStart, bEnd)
+ && (aStartInclusive && aEndInclusive && bStartInclusive && bEndInclusive)) {
+ Some(0)
+ } else {
+ None
+ }
+ }
+
+ def lteq(a: JvmType, b: JvmType): Boolean = {
+ // [(aStart, aEnd)] and [(bStart, bEnd)]
+ // [( and )] denote the inclusive and exclusive boundary possibilities
+ val aRange = a.asInstanceOf[Range[T]]
+ val aStartInclusive = aRange.startInclusive
+ val aEnd = aRange.end.getOrElse(null)
+ val aEndInclusive = aRange.endInclusive
+ val bRange = b.asInstanceOf[Range[T]]
+ val bStart = bRange.start.getOrElse(null)
+ val bStartInclusive = bRange.startInclusive
+ val bEndInclusive = bRange.endInclusive
+
+ // Compare two ranges, return true iff the upper bound of the lower range is lteq to
+ // the lower bound of the upper range. Because the exclusive boundary could be null, which
+ // means the boundary could be infinity, we need to check these conditions further.
+ val result =
+ (aStartInclusive, aEndInclusive, bStartInclusive, bEndInclusive) match {
+ // [(aStart, aEnd] compare to [bStart, bEnd)]
+ case (_, true, true, _) =>
+ aRange.dt.ordering.lteq(aEnd.asInstanceOf[aRange.dt.JvmType],
+ bStart.asInstanceOf[aRange.dt.JvmType])
+ // [(aStart, aEnd] compare to (bStart, bEnd)]
+ case (_, true, false, _) =>
+ bStart != null && aRange.dt.ordering.lteq(aEnd.asInstanceOf[aRange.dt.JvmType],
+ bStart.asInstanceOf[aRange.dt.JvmType])
+ // [(aStart, aEnd) compare to [bStart, bEnd)]
+ case (_, false, true, _) =>
+ aEnd != null && aRange.dt.ordering.lteq(aEnd.asInstanceOf[aRange.dt.JvmType],
+ bStart.asInstanceOf[aRange.dt.JvmType])
+ // [(aStart, aEnd) compare to (bStart, bEnd)]
+ case (_, false, false, _) =>
+ aEnd != null && bStart != null &&
+ aRange.dt.ordering.lteq(aEnd.asInstanceOf[aRange.dt.JvmType],
+ bStart.asInstanceOf[aRange.dt.JvmType])
+ }
+
+ result
+ }
+ }
+}
+
+object RangeType {
+
+ object StringRangeType extends RangeType[String]
+
+ object IntegerRangeType extends RangeType[Int]
+
+ object LongRangeType extends RangeType[Long]
+
+ object DoubleRangeType extends RangeType[Double]
+
+ object FloatRangeType extends RangeType[Float]
+
+ object ByteRangeType extends RangeType[Byte]
+
+ object ShortRangeType extends RangeType[Short]
+
+ object BooleanRangeType extends RangeType[Boolean]
+
+ object DecimalRangeType extends RangeType[BigDecimal]
+
+ object TimestampRangeType extends RangeType[Timestamp]
+
+ val primitiveToPODataTypeMap: HashMap[NativeType, PartiallyOrderingDataType] =
+ HashMap(
+ IntegerType -> IntegerRangeType,
+ LongType -> LongRangeType,
+ DoubleType -> DoubleRangeType,
+ FloatType -> FloatRangeType,
+ ByteType -> ByteRangeType,
+ ShortType -> ShortRangeType,
+ BooleanType -> BooleanRangeType,
+ DecimalType.Unlimited -> DecimalRangeType,
+ TimestampType -> TimestampRangeType,
+ StringType -> StringRangeType
+ )
+}
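+
+// A hedged sketch of how the partial ordering is meant to be used: point values are
+// lifted to single-point ranges and compared against a partition range. The bounds
+// and values below are hypothetical.
+//
+//   val partition = new Range[Int](Some(10), true, Some(20), false, IntegerType)
+//   val point25 = RangeType.IntegerRangeType
+//     .toPartiallyOrderingDataType(25, IntegerType).asInstanceOf[Range[Int]]
+//   RangeType.IntegerRangeType.partialOrdering.tryCompare(point25, partition) // Some(1)
+//   val point15 = RangeType.IntegerRangeType
+//     .toPartiallyOrderingDataType(15, IntegerType).asInstanceOf[Range[Int]]
+//   RangeType.IntegerRangeType.partialOrdering.tryCompare(point15, partition) // None (MAYBE)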
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/hbase/ScanPredClassfier.scala b/sql/core/src/main/scala/org/apache/spark/sql/hbase/ScanPredClassfier.scala
new file mode 100755
index 0000000000000..d0f312242bd99
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/hbase/ScanPredClassfier.scala
@@ -0,0 +1,129 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hbase
+
+import org.apache.spark.sql.catalyst.expressions._
+
+/**
+ * Classifies a predicate into a pair of (pushdownable, non-pushdownable) predicates
+ * for a Scan; the two components of the pair are related by AND
+ */
+class ScanPredClassfier(relation: HBaseRelation, keyIndex: Int) {
+ def apply(pred: Expression): (Option[Expression], Option[Expression]) = {
+ // post-order bottom-up traversal
+ pred match {
+ case And(left, right) => {
+ val (ll, lr) = apply(left)
+ val (rl, rr) = apply(right)
+ (ll, lr, rl, rr) match {
+ // All Nones
+ case (None, None, None, None) => (None, None)
+ // Three Nones
+ case (None, None, None, _) => (None, rr)
+ case (None, None, _, None) => (rl, None)
+ case (None, _, None, None) => (None, lr)
+ case (_, None, None, None) => (ll, None)
+ // two Nones
+ case (None, None, _, _) => (rl, rr)
+ case (None, _, None, _) => (None, Some(And(lr.get, rr.get)))
+ case (None, _, _, None) => (rl, lr)
+ case (_, None, None, _) => (ll, rr)
+ case (_, None, _, None) => (Some(And(ll.get, rl.get)), None)
+ case (_, _, None, None) => (ll, lr)
+ // One None
+ case (None, _, _, _) => (rl, Some(And(lr.get, rr.get)))
+ case (_, None, _, _) => (Some(And(ll.get, rl.get)), rr)
+ case (_, _, None, _) => (ll, Some(And(lr.get, rr.get)))
+ case (_, _, _, None) => (Some(And(ll.get, rl.get)), lr)
+ // No nones
+ case _ => (Some(And(ll.get, rl.get)), Some(And(lr.get, rr.get)))
+ }
+ }
+ case Or(left, right) => {
+ val (ll, lr) = apply(left)
+ val (rl, rr) = apply(right)
+ (ll, lr, rl, rr) match {
+ // All Nones
+ case (None, None, None, None) => (None, None)
+ // Three Nones
+ case (None, None, None, _) => (None, rr)
+ case (None, None, _, None) => (rl, None)
+ case (None, _, None, None) => (None, lr)
+ case (_, None, None, None) => (ll, None)
+ // two Nones
+ case (None, None, _, _) => (rl, rr)
+ case (None, _, None, _) => (None, Some(Or(lr.get, rr.get)))
+ case (None, _, _, None) => (None, Some(Or(lr.get, rl.get)))
+ case (_, None, None, _) => (None, Some(Or(ll.get, rr.get)))
+ case (_, None, _, None) => (Some(Or(ll.get, rl.get)), None)
+ case (_, _, None, None) => (ll, lr)
+ // One None
+ case (None, _, _, _) => (None, Some(pred))
+ // Accept increased evaluation complexity for improved pushdown
+ case (_, None, _, _) => (Some(Or(ll.get, rl.get)), Some(Or(ll.get, rr.get)))
+ case (_, _, None, _) => (None, Some(pred))
+ // Accept increased evaluation complexity for improved pushdown
+ case (_, _, _, None) => (Some(Or(ll.get, rl.get)), Some(Or(lr.get, rl.get)))
+ // No nones
+ // Accept increased evaluation complexity for improved pushdown
+ case _ => (Some(Or(ll.get, rl.get)), Some(And(Or(ll.get, rr.get),
+ And(Or(lr.get, rl.get), Or(lr.get, rr.get)))))
+ }
+ }
+ case EqualTo(left, right) => classifyBinary(left, right, pred)
+ case LessThan(left, right) => classifyBinary(left, right, pred)
+ case LessThanOrEqual(left, right) => classifyBinary(left, right, pred)
+ case GreaterThan(left, right) => classifyBinary(left, right, pred)
+ case GreaterThanOrEqual(left, right) => classifyBinary(left, right, pred)
+ // everything else is treated as non-pushdownable
+ case _ => (None, Some(pred))
+ }
+ }
+
+ // Classifies a binary comparison between an attribute and a literal into
+ // (pushdownable, non-pushdownable) parts
+ private def classifyBinary(left: Expression, right: Expression, pred: Expression)
+ : (Option[Expression], Option[Expression]) = {
+ (left, right) match {
+ case (Literal(_, _), AttributeReference(_, _, _, _)) => {
+ if (relation.isNonKey(right.asInstanceOf[AttributeReference])) {
+ (Some(pred), None)
+ } else {
+ val keyIdx = relation.keyIndex(right.asInstanceOf[AttributeReference])
+ if (keyIdx == keyIndex) {
+ (Some(pred), None)
+ } else {
+ (None, Some(pred))
+ }
+ }
+ }
+ case (AttributeReference(_, _, _, _), Literal(_, _)) => {
+ if (relation.isNonKey(left.asInstanceOf[AttributeReference])) {
+ (Some(pred), None)
+ } else {
+ val keyIdx = relation.keyIndex(left.asInstanceOf[AttributeReference])
+ if (keyIdx == keyIndex) {
+ (Some(pred), None)
+ } else {
+ (None, Some(pred))
+ }
+ }
+ }
+ case _ => (None, Some(pred))
+ }
+ }
+}
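+
+// A hedged usage sketch; `relation` is an HBaseRelation, and `keyCol` / `nonKeyCol`
+// are hypothetical AttributeReferences resolved against that relation:
+//
+//   val classifier = new ScanPredClassfier(relation, keyIndex = 0)
+//   val (pushable, residual) =
+//     classifier(And(EqualTo(keyCol, Literal(5)), GreaterThan(nonKeyCol, Literal(3))))
+//   // `pushable` AND `residual` is equivalent to the original predicate; only
+//   // `pushable` is a candidate for the HBase Scan, `residual` is evaluated afterwards.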
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/hbase/hbase.scala b/sql/core/src/main/scala/org/apache/spark/sql/hbase/hbase.scala
new file mode 100644
index 0000000000000..acea2594d1f27
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/hbase/hbase.scala
@@ -0,0 +1,107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hbase
+
+import org.apache.spark.Logging
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.catalyst.expressions.{And, Attribute, Expression, Row}
+import org.apache.spark.sql.catalyst.types.StructType
+import org.apache.spark.sql.sources.{BaseRelation, CatalystScan, RelationProvider}
+
+/**
+ * Allows creation of HBase backed tables using the syntax
+ * `CREATE TEMPORARY TABLE table_name(field1 field1_type, field2 field2_type...)
+ * USING org.apache.spark.sql.hbase
+ * OPTIONS (
+ * hbase_table "hbase_table_name",
+ * mapping "field1=cf1.column1, field2=cf2.column2...",
+ * primary_key "field_name1, field_name2"
+ * )`.
+ */
+class DefaultSource extends RelationProvider with Logging {
+ /** Returns a new base relation with the given parameters. */
+ override def createRelation(
+ sqlContext: SQLContext,
+ parameters: Map[String, String],
+ schema: Option[StructType]): BaseRelation = {
+
+ assert(schema.nonEmpty, "schema cannot be empty for an HBase source!")
+ assert(parameters.get("hbase_table").nonEmpty, "no option for hbase_table")
+ assert(parameters.get("mapping").nonEmpty, "no option for mapping")
+ assert(parameters.get("primary_key").nonEmpty, "no option for primary_key")
+
+ val hbaseTableName = parameters.getOrElse("hbase_table", "").toLowerCase
+ val mapping = parameters.getOrElse("mapping", "").toLowerCase
+ val primaryKey = parameters.getOrElse("primary_key", "").toLowerCase()
+ val partValue = "([^=]+)=([^=]+)".r
+
+ val fieldByHbaseColumn = mapping.split(",").map {
+ case partValue(key, value) => (key, value)
+ }
+ val keyColumns = primaryKey.split(",").map(_.trim)
+
+ // check the mapping is legal
+ val fieldSet = schema.get.fields.map(_.name).toSet
+ fieldByHbaseColumn.iterator.map(_._1).foreach { field =>
+ assert(fieldSet.contains(field), s"no field named $field in table")
+ }
+ HBaseScanBuilder("", hbaseTableName, keyColumns, fieldByHbaseColumn, schema.get)(sqlContext)
+ }
+}
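+
+// A hedged end-to-end sketch of the DDL above, assuming an existing HBase table
+// "hbase_people" with a column family "cf" (the table, family, and column names
+// are hypothetical):
+//
+//   sqlContext.sql(
+//     """CREATE TEMPORARY TABLE people(rowkey string, name string, age int)
+//       |USING org.apache.spark.sql.hbase
+//       |OPTIONS (
+//       |  hbase_table "hbase_people",
+//       |  mapping "name=cf.name, age=cf.age",
+//       |  primary_key "rowkey"
+//       |)""".stripMargin)
+//   sqlContext.sql("SELECT name, age FROM people WHERE age > 30").collect()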
+
+@DeveloperApi
+case class HBaseScanBuilder(
+ tableName: String,
+ hbaseTableName: String,
+ keyColumns: Seq[String],
+ fieldByHbaseColumn: Seq[(String, String)],
+ schema: StructType)(context: SQLContext) extends CatalystScan with Logging {
+
+ val hbaseMetadata = new HBaseMetadata
+
+ val fieldByHbaseFamilyAndColumn = fieldByHbaseColumn.toMap
+
+ def allColumns() = schema.fields.map { field =>
+ val fieldName = field.name
+ if (keyColumns.contains(fieldName)) {
+ KeyColumn(fieldName, field.dataType, keyColumns.indexOf(fieldName))
+ } else {
+ val familyAndQualifier = fieldByHbaseFamilyAndColumn.getOrElse(fieldName, "").split("\\.")
+ assert(familyAndQualifier.size == 2, "illegal mapping")
+ NonKeyColumn(fieldName, field.dataType, familyAndQualifier(0), familyAndQualifier(1))
+ }
+ }
+
+ val relation = hbaseMetadata.createTable(tableName, hbaseTableName, allColumns)
+
+ override def sqlContext: SQLContext = context
+
+ // TODO: optimization for predicate pushdown
+ override def buildScan(output: Seq[Attribute], predicates: Seq[Expression]): RDD[Row] = {
+ new HBaseSQLReaderRDD(
+ relation,
+ schema.toAttributes,
+ None,
+ None,
+ predicates.reduceLeftOption(And),
+ None
+ )(sqlContext)
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/hbase/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/hbase/package.scala
new file mode 100755
index 0000000000000..28ecf9560497d
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/hbase/package.scala
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql
+
+import org.apache.hadoop.hbase.KeyValue
+import org.apache.hadoop.hbase.client.Put
+import org.apache.hadoop.hbase.io.ImmutableBytesWritable
+
+import scala.collection.mutable.ArrayBuffer
+
+package object hbase {
+ type HBaseRawType = Array[Byte]
+
+ class ImmutableBytesWritableWrapper(rowKey: Array[Byte])
+ extends Serializable {
+
+ def compareTo(that: ImmutableBytesWritableWrapper): Int = {
+ this.toImmutableBytesWritable() compareTo that.toImmutableBytesWritable()
+ }
+
+ def toImmutableBytesWritable() = new ImmutableBytesWritable(rowKey)
+ }
+
+ class PutWrapper(rowKey: Array[Byte]) extends Serializable {
+ val fqv = new ArrayBuffer[(Array[Byte], Array[Byte], Array[Byte])]
+
+ def add(family: Array[Byte], qualifier: Array[Byte], value: Array[Byte]) =
+ fqv += ((family, qualifier, value))
+
+ def toPut() = {
+ val put = new Put(rowKey)
+ fqv.foreach { case (family, qualifier, value) =>
+ put.add(family, qualifier, value)
+ }
+ put
+ }
+ }
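+
+ // A hedged sketch of the intended write-path usage; `htable` and the row key,
+ // family, qualifier and value below are hypothetical:
+ //
+ //   import org.apache.hadoop.hbase.util.Bytes
+ //   val put = new PutWrapper(Bytes.toBytes("row1"))
+ //   put.add(Bytes.toBytes("cf"), Bytes.toBytes("name"), Bytes.toBytes("alice"))
+ //   htable.put(put.toPut())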
+
+ class KeyValueWrapper(
+ rowKey: Array[Byte],
+ family: Array[Byte],
+ qualifier: Array[Byte],
+ value: Array[Byte]) extends Serializable {
+
+ def toKeyValue() = new KeyValue(rowKey, family, qualifier, value)
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala
index fc70c183437f6..f2796161536fa 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala
@@ -18,31 +18,38 @@
package org.apache.spark.sql.json
import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.catalyst.types.StructType
import org.apache.spark.sql.sources._
private[sql] class DefaultSource extends RelationProvider {
/** Returns a new base relation with the given parameters. */
override def createRelation(
sqlContext: SQLContext,
- parameters: Map[String, String]): BaseRelation = {
+ parameters: Map[String, String],
+ schema: Option[StructType]): BaseRelation = {
val fileName = parameters.getOrElse("path", sys.error("Option 'path' not specified"))
val samplingRatio = parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0)
- JSONRelation(fileName, samplingRatio)(sqlContext)
+ JSONRelation(fileName, samplingRatio, schema)(sqlContext)
}
}
-private[sql] case class JSONRelation(fileName: String, samplingRatio: Double)(
+private[sql] case class JSONRelation(
+ fileName: String,
+ samplingRatio: Double,
+ _schema: Option[StructType])(
@transient val sqlContext: SQLContext)
extends TableScan {
private def baseRDD = sqlContext.sparkContext.textFile(fileName)
override val schema =
+ _schema.getOrElse(
JsonRDD.inferSchema(
baseRDD,
samplingRatio,
sqlContext.columnNameOfCorruptRecord)
+ )
override def buildScan() =
JsonRDD.jsonStringToRow(baseRDD, schema, sqlContext.columnNameOfCorruptRecord)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala
index 9b89c3bfb3307..3247621c47297 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala
@@ -22,21 +22,20 @@ import org.apache.hadoop.fs.{FileStatus, FileSystem, Path}
import org.apache.hadoop.conf.{Configurable, Configuration}
import org.apache.hadoop.io.Writable
import org.apache.hadoop.mapreduce.{JobContext, InputSplit, Job}
-
import parquet.hadoop.ParquetInputFormat
import parquet.hadoop.util.ContextUtil
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.{Partition => SparkPartition, Logging}
import org.apache.spark.rdd.{NewHadoopPartition, RDD}
-
-import org.apache.spark.sql.{SQLConf, Row, SQLContext}
-import org.apache.spark.sql.catalyst.expressions.{SpecificMutableRow, And, Expression, Attribute}
-import org.apache.spark.sql.catalyst.types.{IntegerType, StructField, StructType}
+import org.apache.spark.sql.catalyst.expressions.{Row, And, SpecificMutableRow, Expression, Attribute}
+import org.apache.spark.sql.catalyst.types.{StructField, IntegerType, StructType}
import org.apache.spark.sql.sources._
+import org.apache.spark.sql.{SQLConf, SQLContext}
import scala.collection.JavaConversions._
+
/**
* Allows creation of parquet based tables using the syntax
* `CREATE TEMPORARY TABLE ... USING org.apache.spark.sql.parquet`. Currently the only option
@@ -47,11 +46,12 @@ class DefaultSource extends RelationProvider {
/** Returns a new base relation with the given parameters. */
override def createRelation(
sqlContext: SQLContext,
- parameters: Map[String, String]): BaseRelation = {
+ parameters: Map[String, String],
+ schema: Option[StructType]): BaseRelation = {
val path =
parameters.getOrElse("path", sys.error("'path' must be specified for parquet tables."))
- ParquetRelation2(path)(sqlContext)
+ ParquetRelation2(path, schema)(sqlContext)
}
}
@@ -81,7 +81,9 @@ private[parquet] case class Partition(partitionValues: Map[String, Any], files:
* discovery.
*/
@DeveloperApi
-case class ParquetRelation2(path: String)(@transient val sqlContext: SQLContext)
+case class ParquetRelation2(
+ path: String,
+ _schema: Option[StructType])(@transient val sqlContext: SQLContext)
extends CatalystScan with Logging {
def sparkContext = sqlContext.sparkContext
@@ -132,12 +134,13 @@ case class ParquetRelation2(path: String)(@transient val sqlContext: SQLContext)
override val sizeInBytes = partitions.flatMap(_.files).map(_.getLen).sum
- val dataSchema = StructType.fromAttributes( // TODO: Parquet code should not deal with attributes.
- ParquetTypesConverter.readSchemaFromFile(
- partitions.head.files.head.getPath,
- Some(sparkContext.hadoopConfiguration),
- sqlContext.isParquetBinaryAsString))
-
+ val dataSchema = _schema.getOrElse(
+ StructType.fromAttributes( // TODO: Parquet code should not deal with attributes.
+ ParquetTypesConverter.readSchemaFromFile(
+ partitions.head.files.head.getPath,
+ Some(sparkContext.hadoopConfiguration),
+ sqlContext.isParquetBinaryAsString))
+ )
val dataIncludesKey =
partitionKeys.headOption.map(dataSchema.fieldNames.contains(_)).getOrElse(true)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
index ca510cb0b07e3..6b3dc79451a7e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
@@ -17,16 +17,15 @@
package org.apache.spark.sql.sources
-import org.apache.spark.Logging
-import org.apache.spark.sql.SQLContext
-import org.apache.spark.sql.execution.RunnableCommand
-import org.apache.spark.util.Utils
-
import scala.language.implicitConversions
-import scala.util.parsing.combinator.lexical.StdLexical
import scala.util.parsing.combinator.syntactical.StandardTokenParsers
import scala.util.parsing.combinator.PackratParsers
+import org.apache.spark.Logging
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.catalyst.types._
+import org.apache.spark.sql.execution.RunnableCommand
+import org.apache.spark.util.Utils
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.SqlLexical
@@ -49,6 +48,15 @@ private[sql] class DDLParser extends StandardTokenParsers with PackratParsers wi
protected implicit def asParser(k: Keyword): Parser[String] =
lexical.allCaseVersions(k.str).map(x => x : Parser[String]).reduce(_ | _)
+ protected val STRING = Keyword("STRING")
+ protected val SHORT = Keyword("SHORT")
+ protected val DOUBLE = Keyword("DOUBLE")
+ protected val BOOLEAN = Keyword("BOOLEAN")
+ protected val BYTE = Keyword("BYTE")
+ protected val FLOAT = Keyword("FLOAT")
+ protected val INT = Keyword("INT")
+ protected val INTEGER = Keyword("INTEGER")
+ protected val LONG = Keyword("LONG")
protected val CREATE = Keyword("CREATE")
protected val TEMPORARY = Keyword("TEMPORARY")
protected val TABLE = Keyword("TABLE")
@@ -67,16 +75,35 @@ private[sql] class DDLParser extends StandardTokenParsers with PackratParsers wi
protected lazy val ddl: Parser[LogicalPlan] = createTable
/**
- * CREATE TEMPORARY TABLE avroTable
+ * `CREATE TEMPORARY TABLE avroTable
* USING org.apache.spark.sql.avro
- * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")
+ * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")`
+ * or
+ * `CREATE TEMPORARY TABLE avroTable(intField int, stringField string...)
+ * USING org.apache.spark.sql.avro
+ * OPTIONS (path "../hive/src/test/resources/data/files/episodes.avro")`
*/
protected lazy val createTable: Parser[LogicalPlan] =
- CREATE ~ TEMPORARY ~ TABLE ~> ident ~ (USING ~> className) ~ (OPTIONS ~> options) ^^ {
+ ( CREATE ~ TEMPORARY ~ TABLE ~> ident ~ (USING ~> className) ~ (OPTIONS ~> options) ^^ {
case tableName ~ provider ~ opts =>
- CreateTableUsing(tableName, provider, opts)
+ CreateTableUsing(tableName, Seq.empty, provider, opts)
+ }
+ |
+ CREATE ~ TEMPORARY ~ TABLE ~> ident
+ ~ tableCols ~ (USING ~> className) ~ (OPTIONS ~> options) ^^ {
+ case tableName ~ tableColumns ~ provider ~ opts =>
+ CreateTableUsing(tableName, tableColumns, provider, opts)
+ }
+ )
+ protected lazy val tableCol: Parser[(String, String)] =
+ ident ~ (STRING | BYTE | SHORT | INT | INTEGER | LONG | FLOAT | DOUBLE | BOOLEAN) ^^ {
+ case e1 ~ e2 => (e1, e2)
}
+ protected lazy val tableCols: Parser[Seq[(String, String)]] =
+ "(" ~> repsep(tableCol, ",") <~ ")"
+
+
protected lazy val options: Parser[Map[String, String]] =
"(" ~> repsep(pair, ",") <~ ")" ^^ { case s: Seq[(String, String)] => s.toMap }
@@ -87,6 +114,7 @@ private[sql] class DDLParser extends StandardTokenParsers with PackratParsers wi
private[sql] case class CreateTableUsing(
tableName: String,
+ tableCols: Seq[(String, String)],
provider: String,
options: Map[String, String]) extends RunnableCommand {
@@ -100,9 +128,32 @@ private[sql] case class CreateTableUsing(
}
}
val dataSource = clazz.newInstance().asInstanceOf[org.apache.spark.sql.sources.RelationProvider]
- val relation = dataSource.createRelation(sqlContext, options)
+ val relation = dataSource.createRelation(sqlContext, options, toSchema(tableCols))
sqlContext.baseRelationToSchemaRDD(relation).registerTempTable(tableName)
Seq.empty
}
+
+ def toSchema(tableColumns: Seq[(String, String)]): Option[StructType] = {
+ val fields: Seq[StructField] = tableColumns.map { tableColumn =>
+ val columnName = tableColumn._1
+ val columnType = tableColumn._2
+ // todo: support more complex data type
+ columnType.toLowerCase match {
+ case "string" => StructField(columnName, StringType)
+ case "byte" => StructField(columnName, ByteType)
+ case "short" => StructField(columnName, ShortType)
+ case "int" => StructField(columnName, IntegerType)
+ case "integer" => StructField(columnName, IntegerType)
+ case "long" => StructField(columnName, LongType)
+ case "double" => StructField(columnName, DoubleType)
+ case "float" => StructField(columnName, FloatType)
+ case "boolean" => StructField(columnName, BooleanType)
+ }
+ }
+ if (fields.isEmpty) {
+ return None
+ }
+ Some(StructType(fields))
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
index 939b4e15163a6..7dade36cbef75 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
@@ -37,7 +37,10 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, Attribute}
@DeveloperApi
trait RelationProvider {
/** Returns a new base relation with the given parameters. */
- def createRelation(sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation
+ def createRelation(
+ sqlContext: SQLContext,
+ parameters: Map[String, String],
+ schema: Option[StructType]): BaseRelation
}
/**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala
index 939b3c0c66de7..8aa55e2113f36 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala
@@ -24,18 +24,26 @@ import org.apache.spark.sql._
class FilteredScanSource extends RelationProvider {
override def createRelation(
sqlContext: SQLContext,
- parameters: Map[String, String]): BaseRelation = {
- SimpleFilteredScan(parameters("from").toInt, parameters("to").toInt)(sqlContext)
+ parameters: Map[String, String],
+ schema: Option[StructType]): BaseRelation = {
+ SimpleFilteredScan(
+ parameters("from").toInt,
+ parameters("to").toInt,
+ schema: Option[StructType])(sqlContext)
}
}
-case class SimpleFilteredScan(from: Int, to: Int)(@transient val sqlContext: SQLContext)
+case class SimpleFilteredScan(
+ from: Int,
+ to: Int,
+ _schema: Option[StructType])(@transient val sqlContext: SQLContext)
extends PrunedFilteredScan {
- override def schema =
+ override def schema = _schema.getOrElse(
StructType(
StructField("a", IntegerType, nullable = false) ::
StructField("b", IntegerType, nullable = false) :: Nil)
+ )
override def buildScan(requiredColumns: Array[String], filters: Array[Filter]) = {
val rowBuilders = requiredColumns.map {
@@ -80,85 +88,97 @@ class FilteredScanSuite extends DataSourceTest {
| to '10'
|)
""".stripMargin)
+
+ sql(
+ """
+ |CREATE TEMPORARY TABLE oneToTenFiltered_with_schema(a int, b int)
+ |USING org.apache.spark.sql.sources.FilteredScanSource
+ |OPTIONS (
+ | from '1',
+ | to '10'
+ |)
+ """.stripMargin)
}
+ Seq("oneToTenFiltered", "oneToTenFiltered_with_schema").foreach { table =>
- sqlTest(
- "SELECT * FROM oneToTenFiltered",
- (1 to 10).map(i => Row(i, i * 2)).toSeq)
+ sqlTest(
+ s"SELECT * FROM $table",
+ (1 to 10).map(i => Row(i, i * 2)).toSeq)
- sqlTest(
- "SELECT a, b FROM oneToTenFiltered",
- (1 to 10).map(i => Row(i, i * 2)).toSeq)
+ sqlTest(
+ s"SELECT a, b FROM $table",
+ (1 to 10).map(i => Row(i, i * 2)).toSeq)
- sqlTest(
- "SELECT b, a FROM oneToTenFiltered",
- (1 to 10).map(i => Row(i * 2, i)).toSeq)
+ sqlTest(
+ s"SELECT b, a FROM $table",
+ (1 to 10).map(i => Row(i * 2, i)).toSeq)
- sqlTest(
- "SELECT a FROM oneToTenFiltered",
- (1 to 10).map(i => Row(i)).toSeq)
+ sqlTest(
+ s"SELECT a FROM $table",
+ (1 to 10).map(i => Row(i)).toSeq)
- sqlTest(
- "SELECT b FROM oneToTenFiltered",
- (1 to 10).map(i => Row(i * 2)).toSeq)
+ sqlTest(
+ s"SELECT b FROM $table",
+ (1 to 10).map(i => Row(i * 2)).toSeq)
- sqlTest(
- "SELECT a * 2 FROM oneToTenFiltered",
- (1 to 10).map(i => Row(i * 2)).toSeq)
+ sqlTest(
+ s"SELECT a * 2 FROM $table",
+ (1 to 10).map(i => Row(i * 2)).toSeq)
- sqlTest(
- "SELECT A AS b FROM oneToTenFiltered",
- (1 to 10).map(i => Row(i)).toSeq)
+ sqlTest(
+ s"SELECT A AS b FROM $table",
+ (1 to 10).map(i => Row(i)).toSeq)
- sqlTest(
- "SELECT x.b, y.a FROM oneToTenFiltered x JOIN oneToTenFiltered y ON x.a = y.b",
- (1 to 5).map(i => Row(i * 4, i)).toSeq)
+ sqlTest(
+ s"SELECT x.b, y.a FROM $table x JOIN $table y ON x.a = y.b",
+ (1 to 5).map(i => Row(i * 4, i)).toSeq)
- sqlTest(
- "SELECT x.a, y.b FROM oneToTenFiltered x JOIN oneToTenFiltered y ON x.a = y.b",
- (2 to 10 by 2).map(i => Row(i, i)).toSeq)
+ sqlTest(
+ s"SELECT x.a, y.b FROM $table x JOIN $table y ON x.a = y.b",
+ (2 to 10 by 2).map(i => Row(i, i)).toSeq)
- sqlTest(
- "SELECT * FROM oneToTenFiltered WHERE a = 1",
- Seq(1).map(i => Row(i, i * 2)).toSeq)
+ sqlTest(
+ s"SELECT * FROM $table WHERE a = 1",
+ Seq(1).map(i => Row(i, i * 2)).toSeq)
- sqlTest(
- "SELECT * FROM oneToTenFiltered WHERE a IN (1,3,5)",
- Seq(1,3,5).map(i => Row(i, i * 2)).toSeq)
+ sqlTest(
+ s"SELECT * FROM $table WHERE a IN (1,3,5)",
+ Seq(1,3,5).map(i => Row(i, i * 2)).toSeq)
- sqlTest(
- "SELECT * FROM oneToTenFiltered WHERE A = 1",
- Seq(1).map(i => Row(i, i * 2)).toSeq)
+ sqlTest(
+ s"SELECT * FROM $table WHERE A = 1",
+ Seq(1).map(i => Row(i, i * 2)).toSeq)
- sqlTest(
- "SELECT * FROM oneToTenFiltered WHERE b = 2",
- Seq(1).map(i => Row(i, i * 2)).toSeq)
+ sqlTest(
+ s"SELECT * FROM $table WHERE b = 2",
+ Seq(1).map(i => Row(i, i * 2)).toSeq)
- testPushDown("SELECT * FROM oneToTenFiltered WHERE A = 1", 1)
- testPushDown("SELECT a FROM oneToTenFiltered WHERE A = 1", 1)
- testPushDown("SELECT b FROM oneToTenFiltered WHERE A = 1", 1)
- testPushDown("SELECT a, b FROM oneToTenFiltered WHERE A = 1", 1)
- testPushDown("SELECT * FROM oneToTenFiltered WHERE a = 1", 1)
- testPushDown("SELECT * FROM oneToTenFiltered WHERE 1 = a", 1)
+ testPushDown(s"SELECT * FROM $table WHERE A = 1", 1)
+ testPushDown(s"SELECT a FROM $table WHERE A = 1", 1)
+ testPushDown(s"SELECT b FROM $table WHERE A = 1", 1)
+ testPushDown(s"SELECT a, b FROM $table WHERE A = 1", 1)
+ testPushDown(s"SELECT * FROM $table WHERE a = 1", 1)
+ testPushDown(s"SELECT * FROM $table WHERE 1 = a", 1)
- testPushDown("SELECT * FROM oneToTenFiltered WHERE a > 1", 9)
- testPushDown("SELECT * FROM oneToTenFiltered WHERE a >= 2", 9)
+ testPushDown(s"SELECT * FROM $table WHERE a > 1", 9)
+ testPushDown(s"SELECT * FROM $table WHERE a >= 2", 9)
- testPushDown("SELECT * FROM oneToTenFiltered WHERE 1 < a", 9)
- testPushDown("SELECT * FROM oneToTenFiltered WHERE 2 <= a", 9)
+ testPushDown(s"SELECT * FROM $table WHERE 1 < a", 9)
+ testPushDown(s"SELECT * FROM $table WHERE 2 <= a", 9)
- testPushDown("SELECT * FROM oneToTenFiltered WHERE 1 > a", 0)
- testPushDown("SELECT * FROM oneToTenFiltered WHERE 2 >= a", 2)
+ testPushDown(s"SELECT * FROM $table WHERE 1 > a", 0)
+ testPushDown(s"SELECT * FROM $table WHERE 2 >= a", 2)
- testPushDown("SELECT * FROM oneToTenFiltered WHERE a < 1", 0)
- testPushDown("SELECT * FROM oneToTenFiltered WHERE a <= 2", 2)
+ testPushDown(s"SELECT * FROM $table WHERE a < 1", 0)
+ testPushDown(s"SELECT * FROM $table WHERE a <= 2", 2)
- testPushDown("SELECT * FROM oneToTenFiltered WHERE a > 1 AND a < 10", 8)
+ testPushDown(s"SELECT * FROM $table WHERE a > 1 AND a < 10", 8)
- testPushDown("SELECT * FROM oneToTenFiltered WHERE a IN (1,3,5)", 3)
+ testPushDown(s"SELECT * FROM $table WHERE a IN (1,3,5)", 3)
- testPushDown("SELECT * FROM oneToTenFiltered WHERE a = 20", 0)
- testPushDown("SELECT * FROM oneToTenFiltered WHERE b = 1", 10)
+ testPushDown(s"SELECT * FROM $table WHERE a = 20", 0)
+ testPushDown(s"SELECT * FROM $table WHERE b = 1", 10)
+ }
def testPushDown(sqlString: String, expectedCount: Int): Unit = {
test(s"PushDown Returns $expectedCount: $sqlString") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala
index fee2e22611cdc..a2b9199ea90cf 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala
@@ -22,18 +22,23 @@ import org.apache.spark.sql._
class PrunedScanSource extends RelationProvider {
override def createRelation(
sqlContext: SQLContext,
- parameters: Map[String, String]): BaseRelation = {
- SimplePrunedScan(parameters("from").toInt, parameters("to").toInt)(sqlContext)
+ parameters: Map[String, String],
+ schema: Option[StructType]): BaseRelation = {
+ SimplePrunedScan(parameters("from").toInt, parameters("to").toInt, schema)(sqlContext)
}
}
-case class SimplePrunedScan(from: Int, to: Int)(@transient val sqlContext: SQLContext)
+case class SimplePrunedScan(
+ from: Int,
+ to: Int,
+ _schema: Option[StructType])(@transient val sqlContext: SQLContext)
extends PrunedScan {
- override def schema =
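+ // Use the user-provided schema if present, otherwise the default two-column schema.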
+ override def schema = _schema.getOrElse(
StructType(
StructField("a", IntegerType, nullable = false) ::
StructField("b", IntegerType, nullable = false) :: Nil)
+ )
override def buildScan(requiredColumns: Array[String]) = {
val rowBuilders = requiredColumns.map {
@@ -59,54 +64,66 @@ class PrunedScanSuite extends DataSourceTest {
| to '10'
|)
""".stripMargin)
- }
-
- sqlTest(
- "SELECT * FROM oneToTenPruned",
- (1 to 10).map(i => Row(i, i * 2)).toSeq)
-
- sqlTest(
- "SELECT a, b FROM oneToTenPruned",
- (1 to 10).map(i => Row(i, i * 2)).toSeq)
-
- sqlTest(
- "SELECT b, a FROM oneToTenPruned",
- (1 to 10).map(i => Row(i * 2, i)).toSeq)
-
- sqlTest(
- "SELECT a FROM oneToTenPruned",
- (1 to 10).map(i => Row(i)).toSeq)
-
- sqlTest(
- "SELECT a, a FROM oneToTenPruned",
- (1 to 10).map(i => Row(i, i)).toSeq)
-
- sqlTest(
- "SELECT b FROM oneToTenPruned",
- (1 to 10).map(i => Row(i * 2)).toSeq)
-
- sqlTest(
- "SELECT a * 2 FROM oneToTenPruned",
- (1 to 10).map(i => Row(i * 2)).toSeq)
-
- sqlTest(
- "SELECT A AS b FROM oneToTenPruned",
- (1 to 10).map(i => Row(i)).toSeq)
-
- sqlTest(
- "SELECT x.b, y.a FROM oneToTenPruned x JOIN oneToTenPruned y ON x.a = y.b",
- (1 to 5).map(i => Row(i * 4, i)).toSeq)
- sqlTest(
- "SELECT x.a, y.b FROM oneToTenPruned x JOIN oneToTenPruned y ON x.a = y.b",
- (2 to 10 by 2).map(i => Row(i, i)).toSeq)
+ sql(
+ """
+ |CREATE TEMPORARY TABLE oneToTenPruned_with_schema(a int, b int)
+ |USING org.apache.spark.sql.sources.PrunedScanSource
+ |OPTIONS (
+ | from '1',
+ | to '10'
+ |)
+ """.stripMargin)
+ }
- testPruning("SELECT * FROM oneToTenPruned", "a", "b")
- testPruning("SELECT a, b FROM oneToTenPruned", "a", "b")
- testPruning("SELECT b, a FROM oneToTenPruned", "b", "a")
- testPruning("SELECT b, b FROM oneToTenPruned", "b")
- testPruning("SELECT a FROM oneToTenPruned", "a")
- testPruning("SELECT b FROM oneToTenPruned", "b")
+ Seq("oneToTenPruned", "oneToTenPruned_with_schema").foreach { table =>
+ sqlTest(
+ s"SELECT * FROM $table",
+ (1 to 10).map(i => Row(i, i * 2)).toSeq)
+
+ sqlTest(
+ s"SELECT a, b FROM $table",
+ (1 to 10).map(i => Row(i, i * 2)).toSeq)
+
+ sqlTest(
+ s"SELECT b, a FROM $table",
+ (1 to 10).map(i => Row(i * 2, i)).toSeq)
+
+ sqlTest(
+ s"SELECT a FROM $table",
+ (1 to 10).map(i => Row(i)).toSeq)
+
+ sqlTest(
+ s"SELECT a, a FROM $table",
+ (1 to 10).map(i => Row(i, i)).toSeq)
+
+ sqlTest(
+ s"SELECT b FROM $table",
+ (1 to 10).map(i => Row(i * 2)).toSeq)
+
+ sqlTest(
+ s"SELECT a * 2 FROM $table",
+ (1 to 10).map(i => Row(i * 2)).toSeq)
+
+ sqlTest(
+ s"SELECT A AS b FROM $table",
+ (1 to 10).map(i => Row(i)).toSeq)
+
+ sqlTest(
+ s"SELECT x.b, y.a FROM $table x JOIN $table y ON x.a = y.b",
+ (1 to 5).map(i => Row(i * 4, i)).toSeq)
+
+ sqlTest(
+ s"SELECT x.a, y.b FROM $table x JOIN $table y ON x.a = y.b",
+ (2 to 10 by 2).map(i => Row(i, i)).toSeq)
+
+ testPruning(s"SELECT * FROM $table", "a", "b")
+ testPruning(s"SELECT a, b FROM $table", "a", "b")
+ testPruning(s"SELECT b, a FROM $table", "b", "a")
+ testPruning(s"SELECT b, b FROM $table", "b")
+ testPruning(s"SELECT a FROM $table", "a")
+ testPruning(s"SELECT b FROM $table", "b")
+ }
def testPruning(sqlString: String, expectedColumns: String*): Unit = {
test(s"Columns output ${expectedColumns.mkString(",")}: $sqlString") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
index b254b0620c779..8a5ef44fa4be3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
@@ -24,17 +24,21 @@ class DefaultSource extends SimpleScanSource
class SimpleScanSource extends RelationProvider {
override def createRelation(
sqlContext: SQLContext,
- parameters: Map[String, String]): BaseRelation = {
- SimpleScan(parameters("from").toInt, parameters("to").toInt)(sqlContext)
+ parameters: Map[String, String],
+ schema: Option[StructType]): BaseRelation = {
+ SimpleScan(parameters("from").toInt, parameters("to").toInt, schema)(sqlContext)
}
}
-case class SimpleScan(from: Int, to: Int)(@transient val sqlContext: SQLContext)
+case class SimpleScan(
+ from: Int,
+ to: Int,
+ _schema: Option[StructType])(@transient val sqlContext: SQLContext)
extends TableScan {
- override def schema =
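+ // Use the user-provided schema if present, otherwise the default single-column schema.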
+ override def schema = _schema.getOrElse(
StructType(StructField("i", IntegerType, nullable = false) :: Nil)
-
+ )
override def buildScan() = sqlContext.sparkContext.parallelize(from to to).map(Row(_))
}
@@ -51,60 +55,75 @@ class TableScanSuite extends DataSourceTest {
| to '10'
|)
""".stripMargin)
- }
-
- sqlTest(
- "SELECT * FROM oneToTen",
- (1 to 10).map(Row(_)).toSeq)
-
- sqlTest(
- "SELECT i FROM oneToTen",
- (1 to 10).map(Row(_)).toSeq)
-
- sqlTest(
- "SELECT i FROM oneToTen WHERE i < 5",
- (1 to 4).map(Row(_)).toSeq)
-
- sqlTest(
- "SELECT i * 2 FROM oneToTen",
- (1 to 10).map(i => Row(i * 2)).toSeq)
-
- sqlTest(
- "SELECT a.i, b.i FROM oneToTen a JOIN oneToTen b ON a.i = b.i + 1",
- (2 to 10).map(i => Row(i, i - 1)).toSeq)
+ sql(
+ """
+ |CREATE TEMPORARY TABLE oneToTen_with_schema(i int)
+ |USING org.apache.spark.sql.sources.SimpleScanSource
+ |OPTIONS (
+ | from '1',
+ | to '10'
+ |)
+ """.stripMargin)
+ }
- test("Caching") {
- // Cached Query Execution
- cacheTable("oneToTen")
- assertCached(sql("SELECT * FROM oneToTen"))
- checkAnswer(
- sql("SELECT * FROM oneToTen"),
+ Seq("oneToTen", "oneToTen_with_schema").foreach { table =>
+ sqlTest(
+ s"SELECT * FROM $table",
(1 to 10).map(Row(_)).toSeq)
- assertCached(sql("SELECT i FROM oneToTen"))
- checkAnswer(
- sql("SELECT i FROM oneToTen"),
+ sqlTest(
+ s"SELECT i FROM $table",
(1 to 10).map(Row(_)).toSeq)
- assertCached(sql("SELECT i FROM oneToTen WHERE i < 5"))
- checkAnswer(
- sql("SELECT i FROM oneToTen WHERE i < 5"),
+ sqlTest(
+ s"SELECT i FROM $table WHERE i < 5",
(1 to 4).map(Row(_)).toSeq)
- assertCached(sql("SELECT i * 2 FROM oneToTen"))
- checkAnswer(
- sql("SELECT i * 2 FROM oneToTen"),
+ sqlTest(
+ s"SELECT i * 2 FROM $table",
(1 to 10).map(i => Row(i * 2)).toSeq)
- assertCached(sql("SELECT a.i, b.i FROM oneToTen a JOIN oneToTen b ON a.i = b.i + 1"), 2)
- checkAnswer(
- sql("SELECT a.i, b.i FROM oneToTen a JOIN oneToTen b ON a.i = b.i + 1"),
+ sqlTest(
+ s"SELECT a.i, b.i FROM $table a JOIN $table b ON a.i = b.i + 1",
(2 to 10).map(i => Row(i, i - 1)).toSeq)
+ }
+
- // Verify uncaching
- uncacheTable("oneToTen")
- assertCached(sql("SELECT * FROM oneToTen"), 0)
+ Seq("oneToTen", "oneToTen_with_schema").foreach { table =>
+
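+ // Repeat the caching and uncaching checks for both the default-schema and explicit-schema tables.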
+ test(s"Caching $table") {
+ // Cached Query Execution
+ cacheTable(s"$table")
+ assertCached(sql(s"SELECT * FROM $table"))
+ checkAnswer(
+ sql(s"SELECT * FROM $table"),
+ (1 to 10).map(Row(_)).toSeq)
+
+ assertCached(sql(s"SELECT i FROM $table"))
+ checkAnswer(
+ sql(s"SELECT i FROM $table"),
+ (1 to 10).map(Row(_)).toSeq)
+
+ assertCached(sql(s"SELECT i FROM $table WHERE i < 5"))
+ checkAnswer(
+ sql(s"SELECT i FROM $table WHERE i < 5"),
+ (1 to 4).map(Row(_)).toSeq)
+
+ assertCached(sql(s"SELECT i * 2 FROM $table"))
+ checkAnswer(
+ sql(s"SELECT i * 2 FROM $table"),
+ (1 to 10).map(i => Row(i * 2)).toSeq)
+
+ assertCached(sql(s"SELECT a.i, b.i FROM $table a JOIN $table b ON a.i = b.i + 1"), 2)
+ checkAnswer(
+ sql(s"SELECT a.i, b.i FROM $table a JOIN $table b ON a.i = b.i + 1"),
+ (2 to 10).map(i => Row(i, i - 1)).toSeq)
+
+ // Verify uncaching
+ uncacheTable(s"$table")
+ assertCached(sql(s"SELECT * FROM $table"), 0)
+ }
}
test("defaultSource") {