diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R
index e33d0d8e29d4..97e0c9edeab4 100644
--- a/R/pkg/R/DataFrame.R
+++ b/R/pkg/R/DataFrame.R
@@ -2642,6 +2642,7 @@ generateAliasesForIntersectedCols <- function (x, intersectedColNames, suffix) {
#'
#' Return a new SparkDataFrame containing the union of rows in this SparkDataFrame
#' and another SparkDataFrame. This is equivalent to \code{UNION ALL} in SQL.
+#' Input SparkDataFrames can have different schemas (names and data types).
#'
#' Note: This does not remove duplicate rows across the two SparkDataFrames.
#'
@@ -2685,7 +2686,8 @@ setMethod("unionAll",
#' Union two or more SparkDataFrames
#'
-#' Union two or more SparkDataFrames. This is equivalent to \code{UNION ALL} in SQL.
+#' Union two or more SparkDataFrames by row. As in R's \code{rbind}, this method
+#' requires that the input SparkDataFrames have the same column names.
#'
#' Note: This does not remove duplicate rows across the two SparkDataFrames.
#'
@@ -2709,6 +2711,10 @@ setMethod("unionAll",
setMethod("rbind",
signature(... = "SparkDataFrame"),
function(x, ..., deparse.level = 1) {
+ nm <- lapply(list(x, ...), names)
+ if (length(unique(nm)) != 1) {
+ stop("Names of input data frames are different.")
+ }
if (nargs() == 3) {
union(x, ...)
} else {
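
For context, here is a minimal Scala sketch (illustrative only, not part of this patch) of the union semantics the R documentation above describes: `union`/`unionAll` resolve columns by position, like SQL `UNION ALL`, while the R `rbind` wrapper now additionally insists on matching column names.

```scala
import org.apache.spark.sql.SparkSession

object UnionSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("union-sketch").master("local[*]").getOrCreate()
    import spark.implicits._

    val a = Seq((1, "x")).toDF("id", "name")
    val b = Seq((2, "y")).toDF("key", "label") // different names, same types

    // Columns are matched by position, not by name, so this succeeds even though
    // the column names differ. Duplicate rows are NOT removed.
    a.union(b).show()

    spark.stop()
  }
}
```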
diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R
index 7c096597fea6..9735fe320155 100644
--- a/R/pkg/inst/tests/testthat/test_sparkSQL.R
+++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R
@@ -1850,6 +1850,13 @@ test_that("union(), rbind(), except(), and intersect() on a DataFrame", {
expect_equal(count(unioned2), 12)
expect_equal(first(unioned2)$name, "Michael")
+ df3 <- df2
+ names(df3)[1] <- "newName"
+ expect_error(rbind(df, df3),
+ "Names of input data frames are different.")
+ expect_error(rbind(df, df2, df3),
+ "Names of input data frames are different.")
+
excepted <- arrange(except(df, df2), desc(df$age))
expect_is(unioned, "SparkDataFrame")
expect_equal(count(excepted), 2)
@@ -2585,8 +2592,8 @@ test_that("coalesce, repartition, numPartitions", {
df2 <- repartition(df1, 10)
expect_equal(getNumPartitions(df2), 10)
- expect_equal(getNumPartitions(coalesce(df2, 13)), 5)
- expect_equal(getNumPartitions(coalesce(df2, 7)), 5)
+ expect_equal(getNumPartitions(coalesce(df2, 13)), 10)
+ expect_equal(getNumPartitions(coalesce(df2, 7)), 7)
expect_equal(getNumPartitions(coalesce(df2, 3)), 3)
})
diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
index 10a7cb1d0665..4c28075bd938 100644
--- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
+++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
@@ -850,11 +850,23 @@ public UTF8String translate(Map dict) {
return fromString(sb.toString());
}
- private int getDigit(byte b) {
- if (b >= '0' && b <= '9') {
- return b - '0';
- }
- throw new NumberFormatException(toString());
+ /**
+ * Wrapper over `long` to allow result of parsing long from string to be accessed via reference.
+ * This is done solely for better performance and is not expected to be used by end users.
+ */
+ public static class LongWrapper {
+ public long value = 0;
+ }
+
+ /**
+ * Wrapper over `int` to allow result of parsing integer from string to be accessed via reference.
+ * This is done solely for better performance and is not expected to be used by end users.
+ *
+ * {@link LongWrapper} could have been used here, but using `int` directly saves the extra
+ * cost of converting from `long` to `int`.
+ */
+ public static class IntWrapper {
+ public int value = 0;
}
/**
@@ -862,14 +874,18 @@ private int getDigit(byte b) {
*
* Note that, in this method we accumulate the result in negative format, and convert it to
* positive format at the end, if this string is not started with '-'. This is because min value
- * is bigger than max value in digits, e.g. Integer.MAX_VALUE is '2147483647' and
- * Integer.MIN_VALUE is '-2147483648'.
+ * is bigger than max value in digits, e.g. Long.MAX_VALUE is '9223372036854775807' and
+ * Long.MIN_VALUE is '-9223372036854775808'.
*
* This code is mostly copied from LazyLong.parseLong in Hive.
+ *
+ * @param toLongResult If a valid `long` was parsed from this UTF8String, its value will be
+ *                     set in `toLongResult`
+ * @return true if the parsing was successful, false otherwise
*/
- public long toLong() {
+ public boolean toLong(LongWrapper toLongResult) {
if (numBytes == 0) {
- throw new NumberFormatException("Empty string");
+ return false;
}
byte b = getByte(0);
@@ -878,7 +894,7 @@ public long toLong() {
if (negative || b == '+') {
offset++;
if (numBytes == 1) {
- throw new NumberFormatException(toString());
+ return false;
}
}
@@ -897,20 +913,25 @@ public long toLong() {
break;
}
- int digit = getDigit(b);
+ int digit;
+ if (b >= '0' && b <= '9') {
+ digit = b - '0';
+ } else {
+ return false;
+ }
+
// We are going to process the new digit and accumulate the result. However, before doing
// this, if the result is already smaller than the stopValue(Long.MIN_VALUE / radix), then
- // result * 10 will definitely be smaller than minValue, and we can stop and throw exception.
+ // result * 10 will definitely be smaller than minValue, and we can stop.
if (result < stopValue) {
- throw new NumberFormatException(toString());
+ return false;
}
result = result * radix - digit;
// Since the previous result is less than or equal to stopValue(Long.MIN_VALUE / radix), we
- // can just use `result > 0` to check overflow. If result overflows, we should stop and throw
- // exception.
+ // can just use `result > 0` to check overflow. If result overflows, we should stop.
if (result > 0) {
- throw new NumberFormatException(toString());
+ return false;
}
}
@@ -918,8 +939,9 @@ public long toLong() {
// part will not change the number, but we will verify that the fractional part
// is well formed.
while (offset < numBytes) {
- if (getDigit(getByte(offset)) == -1) {
- throw new NumberFormatException(toString());
+ byte currentByte = getByte(offset);
+ if (currentByte < '0' || currentByte > '9') {
+ return false;
}
offset++;
}
@@ -927,11 +949,12 @@ public long toLong() {
if (!negative) {
result = -result;
if (result < 0) {
- throw new NumberFormatException(toString());
+ return false;
}
}
- return result;
+ toLongResult.value = result;
+ return true;
}
/**
@@ -946,10 +969,14 @@ public long toLong() {
*
* Note that, this method is almost same as `toLong`, but we leave it duplicated for performance
* reasons, like Hive does.
+ *
+ * @param intWrapper If a valid `int` was parsed from this UTF8String, its value will be
+ *                   set in `intWrapper`
+ * @return true if the parsing was successful, false otherwise
*/
- public int toInt() {
+ public boolean toInt(IntWrapper intWrapper) {
if (numBytes == 0) {
- throw new NumberFormatException("Empty string");
+ return false;
}
byte b = getByte(0);
@@ -958,7 +985,7 @@ public int toInt() {
if (negative || b == '+') {
offset++;
if (numBytes == 1) {
- throw new NumberFormatException(toString());
+ return false;
}
}
@@ -977,20 +1004,25 @@ public int toInt() {
break;
}
- int digit = getDigit(b);
+ int digit;
+ if (b >= '0' && b <= '9') {
+ digit = b - '0';
+ } else {
+ return false;
+ }
+
// We are going to process the new digit and accumulate the result. However, before doing
// this, if the result is already smaller than the stopValue(Integer.MIN_VALUE / radix), then
- // result * 10 will definitely be smaller than minValue, and we can stop and throw exception.
+ // result * 10 will definitely be smaller than minValue, and we can stop.
if (result < stopValue) {
- throw new NumberFormatException(toString());
+ return false;
}
result = result * radix - digit;
// Since the previous result is less than or equal to stopValue(Integer.MIN_VALUE / radix),
- // we can just use `result > 0` to check overflow. If result overflows, we should stop and
- // throw exception.
+ // we can just use `result > 0` to check overflow. If result overflows, we should stop
if (result > 0) {
- throw new NumberFormatException(toString());
+ return false;
}
}
@@ -998,8 +1030,9 @@ public int toInt() {
// part will not change the number, but we will verify that the fractional part
// is well formed.
while (offset < numBytes) {
- if (getDigit(getByte(offset)) == -1) {
- throw new NumberFormatException(toString());
+ byte currentByte = getByte(offset);
+ if (currentByte < '0' || currentByte > '9') {
+ return false;
}
offset++;
}
@@ -1007,31 +1040,33 @@ public int toInt() {
if (!negative) {
result = -result;
if (result < 0) {
- throw new NumberFormatException(toString());
+ return false;
}
}
-
- return result;
+ intWrapper.value = result;
+ return true;
}
- public short toShort() {
- int intValue = toInt();
- short result = (short) intValue;
- if (result != intValue) {
- throw new NumberFormatException(toString());
+ public boolean toShort(IntWrapper intWrapper) {
+ if (toInt(intWrapper)) {
+ int intValue = intWrapper.value;
+ short result = (short) intValue;
+ if (result == intValue) {
+ return true;
+ }
}
-
- return result;
+ return false;
}
- public byte toByte() {
- int intValue = toInt();
- byte result = (byte) intValue;
- if (result != intValue) {
- throw new NumberFormatException(toString());
+ public boolean toByte(IntWrapper intWrapper) {
+ if (toInt(intWrapper)) {
+ int intValue = intWrapper.value;
+ byte result = (byte) intValue;
+ if (result == intValue) {
+ return true;
+ }
}
-
- return result;
+ return false;
}
@Override
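
A hedged Scala sketch (illustrative only) of how a caller might use the new wrapper-based API instead of catching `NumberFormatException`; the helper name `safeToInt` is made up for illustration:

```scala
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.unsafe.types.UTF8String.IntWrapper

// Hypothetical helper: parse without using exceptions for control flow.
def safeToInt(s: String): Option[Int] = {
  val wrapper = new IntWrapper
  if (UTF8String.fromString(s).toInt(wrapper)) Some(wrapper.value) else None
}

// safeToInt("123")    => Some(123)
// safeToInt("11.99")  => Some(11)   // fractional part is validated but truncated
// safeToInt("abc")    => None       // previously a thrown NumberFormatException
```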
diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
index 6f6e0ef0e485..c376371abdf9 100644
--- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
+++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
@@ -22,9 +22,7 @@
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.charset.StandardCharsets;
-import java.util.Arrays;
-import java.util.HashMap;
-import java.util.HashSet;
+import java.util.*;
import com.google.common.collect.ImmutableMap;
import org.apache.spark.unsafe.Platform;
@@ -608,4 +606,128 @@ public void writeToOutputStreamIntArray() throws IOException {
.writeTo(outputStream);
assertEquals("大千世界", outputStream.toString("UTF-8"));
}
+
+ @Test
+ public void testToShort() throws IOException {
+ Map<String, Short> inputToExpectedOutput = new HashMap<>();
+ inputToExpectedOutput.put("1", (short) 1);
+ inputToExpectedOutput.put("+1", (short) 1);
+ inputToExpectedOutput.put("-1", (short) -1);
+ inputToExpectedOutput.put("0", (short) 0);
+ inputToExpectedOutput.put("1111.12345678901234567890", (short) 1111);
+ inputToExpectedOutput.put(String.valueOf(Short.MAX_VALUE), Short.MAX_VALUE);
+ inputToExpectedOutput.put(String.valueOf(Short.MIN_VALUE), Short.MIN_VALUE);
+
+ Random rand = new Random();
+ for (int i = 0; i < 10; i++) {
+ short value = (short) rand.nextInt();
+ inputToExpectedOutput.put(String.valueOf(value), value);
+ }
+
+ IntWrapper wrapper = new IntWrapper();
+ for (Map.Entry<String, Short> entry : inputToExpectedOutput.entrySet()) {
+ assertTrue(entry.getKey(), UTF8String.fromString(entry.getKey()).toShort(wrapper));
+ assertEquals((short) entry.getValue(), wrapper.value);
+ }
+
+ List<String> negativeInputs =
+ Arrays.asList("", " ", "null", "NULL", "\n", "~1212121", "3276700");
+
+ for (String negativeInput : negativeInputs) {
+ assertFalse(negativeInput, UTF8String.fromString(negativeInput).toShort(wrapper));
+ }
+ }
+
+ @Test
+ public void testToByte() throws IOException {
+ Map<String, Byte> inputToExpectedOutput = new HashMap<>();
+ inputToExpectedOutput.put("1", (byte) 1);
+ inputToExpectedOutput.put("+1", (byte) 1);
+ inputToExpectedOutput.put("-1", (byte) -1);
+ inputToExpectedOutput.put("0", (byte) 0);
+ inputToExpectedOutput.put("111.12345678901234567890", (byte) 111);
+ inputToExpectedOutput.put(String.valueOf(Byte.MAX_VALUE), Byte.MAX_VALUE);
+ inputToExpectedOutput.put(String.valueOf(Byte.MIN_VALUE), Byte.MIN_VALUE);
+
+ Random rand = new Random();
+ for (int i = 0; i < 10; i++) {
+ byte value = (byte) rand.nextInt();
+ inputToExpectedOutput.put(String.valueOf(value), value);
+ }
+
+ IntWrapper intWrapper = new IntWrapper();
+ for (Map.Entry<String, Byte> entry : inputToExpectedOutput.entrySet()) {
+ assertTrue(entry.getKey(), UTF8String.fromString(entry.getKey()).toByte(intWrapper));
+ assertEquals((byte) entry.getValue(), intWrapper.value);
+ }
+
+ List<String> negativeInputs =
+ Arrays.asList("", " ", "null", "NULL", "\n", "~1212121", "12345678901234567890");
+
+ for (String negativeInput : negativeInputs) {
+ assertFalse(negativeInput, UTF8String.fromString(negativeInput).toByte(intWrapper));
+ }
+ }
+
+ @Test
+ public void testToInt() throws IOException {
+ Map<String, Integer> inputToExpectedOutput = new HashMap<>();
+ inputToExpectedOutput.put("1", 1);
+ inputToExpectedOutput.put("+1", 1);
+ inputToExpectedOutput.put("-1", -1);
+ inputToExpectedOutput.put("0", 0);
+ inputToExpectedOutput.put("11111.1234567", 11111);
+ inputToExpectedOutput.put(String.valueOf(Integer.MAX_VALUE), Integer.MAX_VALUE);
+ inputToExpectedOutput.put(String.valueOf(Integer.MIN_VALUE), Integer.MIN_VALUE);
+
+ Random rand = new Random();
+ for (int i = 0; i < 10; i++) {
+ int value = rand.nextInt();
+ inputToExpectedOutput.put(String.valueOf(value), value);
+ }
+
+ IntWrapper intWrapper = new IntWrapper();
+ for (Map.Entry<String, Integer> entry : inputToExpectedOutput.entrySet()) {
+ assertTrue(entry.getKey(), UTF8String.fromString(entry.getKey()).toInt(intWrapper));
+ assertEquals((int) entry.getValue(), intWrapper.value);
+ }
+
+ List<String> negativeInputs =
+ Arrays.asList("", " ", "null", "NULL", "\n", "~1212121", "12345678901234567890");
+
+ for (String negativeInput : negativeInputs) {
+ assertFalse(negativeInput, UTF8String.fromString(negativeInput).toInt(intWrapper));
+ }
+ }
+
+ @Test
+ public void testToLong() throws IOException {
+ Map<String, Long> inputToExpectedOutput = new HashMap<>();
+ inputToExpectedOutput.put("1", 1L);
+ inputToExpectedOutput.put("+1", 1L);
+ inputToExpectedOutput.put("-1", -1L);
+ inputToExpectedOutput.put("0", 0L);
+ inputToExpectedOutput.put("1076753423.12345678901234567890", 1076753423L);
+ inputToExpectedOutput.put(String.valueOf(Long.MAX_VALUE), Long.MAX_VALUE);
+ inputToExpectedOutput.put(String.valueOf(Long.MIN_VALUE), Long.MIN_VALUE);
+
+ Random rand = new Random();
+ for (int i = 0; i < 10; i++) {
+ long value = rand.nextLong();
+ inputToExpectedOutput.put(String.valueOf(value), value);
+ }
+
+ LongWrapper wrapper = new LongWrapper();
+ for (Map.Entry<String, Long> entry : inputToExpectedOutput.entrySet()) {
+ assertTrue(entry.getKey(), UTF8String.fromString(entry.getKey()).toLong(wrapper));
+ assertEquals((long) entry.getValue(), wrapper.value);
+ }
+
+ List<String> negativeInputs = Arrays.asList("", " ", "null", "NULL", "\n", "~1212121",
+ "1234567890123456789012345678901234");
+
+ for (String negativeInput : negativeInputs) {
+ assertFalse(negativeInput, UTF8String.fromString(negativeInput).toLong(wrapper));
+ }
+ }
}
diff --git a/conf/spark-env.sh.template b/conf/spark-env.sh.template
index 5c1e876ef9af..94bd2c477a35 100755
--- a/conf/spark-env.sh.template
+++ b/conf/spark-env.sh.template
@@ -25,12 +25,10 @@
# - HADOOP_CONF_DIR, to point Spark towards Hadoop configuration files
# - SPARK_LOCAL_IP, to set the IP address Spark binds to on this node
# - SPARK_PUBLIC_DNS, to set the public dns name of the driver program
-# - SPARK_CLASSPATH, default classpath entries to append
# Options read by executors and drivers running inside the cluster
# - SPARK_LOCAL_IP, to set the IP address Spark binds to on this node
# - SPARK_PUBLIC_DNS, to set the public DNS name of the driver program
-# - SPARK_CLASSPATH, default classpath entries to append
# - SPARK_LOCAL_DIRS, storage directories to use on this node for shuffle and RDD data
# - MESOS_NATIVE_JAVA_LIBRARY, to point to your libmesos.so if you use Mesos
@@ -48,7 +46,6 @@
# - SPARK_WORKER_CORES, to set the number of cores to use on this machine
# - SPARK_WORKER_MEMORY, to set how much total memory workers have to give executors (e.g. 1000m, 2g)
# - SPARK_WORKER_PORT / SPARK_WORKER_WEBUI_PORT, to use non-default ports for the worker
-# - SPARK_WORKER_INSTANCES, to set the number of worker processes per node
# - SPARK_WORKER_DIR, to set the working directory of worker processes
# - SPARK_WORKER_OPTS, to set config properties only for the worker (e.g. "-Dx=y")
# - SPARK_DAEMON_MEMORY, to allocate to the master, worker and history server themselves (default: 1g).
diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala
index fe912e639bcb..2a2ce0504dbb 100644
--- a/core/src/main/scala/org/apache/spark/SparkConf.scala
+++ b/core/src/main/scala/org/apache/spark/SparkConf.scala
@@ -518,71 +518,6 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria
}
}
- // Check for legacy configs
- sys.env.get("SPARK_JAVA_OPTS").foreach { value =>
- val warning =
- s"""
- |SPARK_JAVA_OPTS was detected (set to '$value').
- |This is deprecated in Spark 1.0+.
- |
- |Please instead use:
- | - ./spark-submit with conf/spark-defaults.conf to set defaults for an application
- | - ./spark-submit with --driver-java-options to set -X options for a driver
- | - spark.executor.extraJavaOptions to set -X options for executors
- | - SPARK_DAEMON_JAVA_OPTS to set java options for standalone daemons (master or worker)
- """.stripMargin
- logWarning(warning)
-
- for (key <- Seq(executorOptsKey, driverOptsKey)) {
- if (getOption(key).isDefined) {
- throw new SparkException(s"Found both $key and SPARK_JAVA_OPTS. Use only the former.")
- } else {
- logWarning(s"Setting '$key' to '$value' as a work-around.")
- set(key, value)
- }
- }
- }
-
- sys.env.get("SPARK_CLASSPATH").foreach { value =>
- val warning =
- s"""
- |SPARK_CLASSPATH was detected (set to '$value').
- |This is deprecated in Spark 1.0+.
- |
- |Please instead use:
- | - ./spark-submit with --driver-class-path to augment the driver classpath
- | - spark.executor.extraClassPath to augment the executor classpath
- """.stripMargin
- logWarning(warning)
-
- for (key <- Seq(executorClasspathKey, driverClassPathKey)) {
- if (getOption(key).isDefined) {
- throw new SparkException(s"Found both $key and SPARK_CLASSPATH. Use only the former.")
- } else {
- logWarning(s"Setting '$key' to '$value' as a work-around.")
- set(key, value)
- }
- }
- }
-
- if (!contains(sparkExecutorInstances)) {
- sys.env.get("SPARK_WORKER_INSTANCES").foreach { value =>
- val warning =
- s"""
- |SPARK_WORKER_INSTANCES was detected (set to '$value').
- |This is deprecated in Spark 1.0+.
- |
- |Please instead use:
- | - ./spark-submit with --num-executors to specify the number of executors
- | - Or set SPARK_EXECUTOR_INSTANCES
- | - spark.executor.instances to configure the number of instances in the spark config.
- """.stripMargin
- logWarning(warning)
-
- set("spark.executor.instances", value)
- }
- }
-
if (contains("spark.master") && get("spark.master").startsWith("yarn-")) {
val warning = s"spark.master ${get("spark.master")} is deprecated in Spark 2.0+, please " +
"instead use \"yarn\" with specified deploy mode."
diff --git a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala
index 320af5cf9755..c6307da61c7e 100644
--- a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala
@@ -43,8 +43,7 @@ import org.apache.spark.util.{ThreadUtils, Utils}
* Execute using
* ./bin/spark-class org.apache.spark.deploy.FaultToleranceTest
*
- * Make sure that the environment includes the following properties in SPARK_DAEMON_JAVA_OPTS
- * *and* SPARK_JAVA_OPTS:
+ * Make sure that the environment includes the following properties in SPARK_DAEMON_JAVA_OPTS:
* - spark.deploy.recoveryMode=ZOOKEEPER
* - spark.deploy.zookeeper.url=172.17.42.1:2181
* Note that 172.17.42.1 is the default docker ip for the host and 2181 is the default ZK port.
diff --git a/core/src/main/scala/org/apache/spark/launcher/WorkerCommandBuilder.scala b/core/src/main/scala/org/apache/spark/launcher/WorkerCommandBuilder.scala
index 3fd812e9fcfe..4216b2627309 100644
--- a/core/src/main/scala/org/apache/spark/launcher/WorkerCommandBuilder.scala
+++ b/core/src/main/scala/org/apache/spark/launcher/WorkerCommandBuilder.scala
@@ -39,7 +39,6 @@ private[spark] class WorkerCommandBuilder(sparkHome: String, memoryMb: Int, comm
val cmd = buildJavaCommand(command.classPathEntries.mkString(File.pathSeparator))
cmd.add(s"-Xmx${memoryMb}M")
command.javaOpts.foreach(cmd.add)
- addOptionString(cmd, getenv("SPARK_JAVA_OPTS"))
cmd
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala
index 78aa5c40010c..c98b87148e40 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala
@@ -19,6 +19,7 @@ package org.apache.spark.scheduler
import java.io.{DataInputStream, DataOutputStream}
import java.nio.ByteBuffer
+import java.nio.charset.StandardCharsets
import java.util.Properties
import scala.collection.JavaConverters._
@@ -86,7 +87,10 @@ private[spark] object TaskDescription {
dataOut.writeInt(taskDescription.properties.size())
taskDescription.properties.asScala.foreach { case (key, value) =>
dataOut.writeUTF(key)
- dataOut.writeUTF(value)
+ // SPARK-19796 -- writeUTF is limited to strings of at most 65535 UTF-8 bytes, which property
+ // values (e.g. long job descriptions) can exceed
+ val bytes = value.getBytes(StandardCharsets.UTF_8)
+ dataOut.writeInt(bytes.length)
+ dataOut.write(bytes)
}
// Write the task. The task is already serialized, so write it directly to the byte buffer.
@@ -124,7 +128,11 @@ private[spark] object TaskDescription {
val properties = new Properties()
val numProperties = dataIn.readInt()
for (i <- 0 until numProperties) {
- properties.setProperty(dataIn.readUTF(), dataIn.readUTF())
+ val key = dataIn.readUTF()
+ val valueLength = dataIn.readInt()
+ val valueBytes = new Array[Byte](valueLength)
+ dataIn.readFully(valueBytes)
+ properties.setProperty(key, new String(valueBytes, StandardCharsets.UTF_8))
}
// Create a sub-buffer for the serialized task into its own buffer (to be deserialized later).
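
As a standalone illustration of the encoding change above (a sketch, not the patched methods themselves): `DataOutputStream.writeUTF` rejects strings whose UTF-8 encoding exceeds 65535 bytes, so long property values are written as an explicit length followed by raw UTF-8 bytes.

```scala
import java.io.{DataInputStream, DataOutputStream}
import java.nio.charset.StandardCharsets

// Length-prefixed string value, mirroring the serialization used above.
def writeValue(out: DataOutputStream, value: String): Unit = {
  val bytes = value.getBytes(StandardCharsets.UTF_8)
  out.writeInt(bytes.length) // 4-byte length prefix instead of writeUTF's 2-byte limit
  out.write(bytes)
}

def readValue(in: DataInputStream): String = {
  val bytes = new Array[Byte](in.readInt())
  in.readFully(bytes)
  new String(bytes, StandardCharsets.UTF_8)
}
```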
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala
index 59680139e7af..9843eab4f134 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala
@@ -70,11 +70,13 @@ class TaskInfo(
var killed = false
- private[spark] def markGettingResult(time: Long = System.currentTimeMillis) {
+ private[spark] def markGettingResult(time: Long) {
gettingResultTime = time
}
- private[spark] def markFinished(state: TaskState, time: Long = System.currentTimeMillis) {
+ private[spark] def markFinished(state: TaskState, time: Long) {
+ // finishTime should be greater than 0; otherwise "finished" below will return false.
+ assert(time > 0)
finishTime = time
if (state == TaskState.FAILED) {
failed = true
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
index 19ebaf817e24..11633bef3cfc 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -667,7 +667,7 @@ private[spark] class TaskSetManager(
*/
def handleTaskGettingResult(tid: Long): Unit = {
val info = taskInfos(tid)
- info.markGettingResult()
+ info.markGettingResult(clock.getTimeMillis())
sched.dagScheduler.taskGettingResult(info)
}
@@ -695,7 +695,7 @@ private[spark] class TaskSetManager(
def handleSuccessfulTask(tid: Long, result: DirectTaskResult[_]): Unit = {
val info = taskInfos(tid)
val index = info.index
- info.markFinished(TaskState.FINISHED)
+ info.markFinished(TaskState.FINISHED, clock.getTimeMillis())
removeRunningTask(tid)
// This method is called by "TaskSchedulerImpl.handleSuccessfulTask" which holds the
// "TaskSchedulerImpl" lock until exiting. To avoid the SPARK-7655 issue, we should not
@@ -739,7 +739,7 @@ private[spark] class TaskSetManager(
return
}
removeRunningTask(tid)
- info.markFinished(state)
+ info.markFinished(state, clock.getTimeMillis())
val index = info.index
copiesRunning(index) -= 1
var accumUpdates: Seq[AccumulatorV2[_, _]] = Seq.empty
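
A small sketch (simplified stand-ins, not the real Spark classes) of why passing an explicit time from the scheduler's clock helps: a manual clock can be advanced deterministically in tests, whereas a hidden `System.currentTimeMillis` default cannot.

```scala
// Simplified stand-ins for ManualClock and TaskInfo, for illustration only.
class SimpleManualClock(private var now: Long = 0L) {
  def getTimeMillis(): Long = now
  def advance(ms: Long): Unit = { now += ms }
}

class SimpleTaskInfo {
  var finishTime: Long = 0L
  def markFinished(time: Long): Unit = {
    assert(time > 0) // mirrors the assertion added to TaskInfo.markFinished
    finishTime = time
  }
  def finished: Boolean = finishTime != 0
}

// In a test: advance the clock first so the recorded time is non-zero and stable.
// val clock = new SimpleManualClock(); clock.advance(1)
// info.markFinished(clock.getTimeMillis())
```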
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
index 94abe30bb12f..7e2cfaccfc7b 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
@@ -222,12 +222,18 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
// Make fake resource offers on all executors
private def makeOffers() {
- // Filter out executors under killing
- val activeExecutors = executorDataMap.filterKeys(executorIsAlive)
- val workOffers = activeExecutors.map { case (id, executorData) =>
- new WorkerOffer(id, executorData.executorHost, executorData.freeCores)
- }.toIndexedSeq
- launchTasks(scheduler.resourceOffers(workOffers))
+ // Make sure no executor is killed while a task is being launched on it
+ val taskDescs = CoarseGrainedSchedulerBackend.this.synchronized {
+ // Filter out executors under killing
+ val activeExecutors = executorDataMap.filterKeys(executorIsAlive)
+ val workOffers = activeExecutors.map { case (id, executorData) =>
+ new WorkerOffer(id, executorData.executorHost, executorData.freeCores)
+ }.toIndexedSeq
+ scheduler.resourceOffers(workOffers)
+ }
+ if (!taskDescs.isEmpty) {
+ launchTasks(taskDescs)
+ }
}
override def onDisconnected(remoteAddress: RpcAddress): Unit = {
@@ -240,12 +246,20 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
// Make fake resource offers on just one executor
private def makeOffers(executorId: String) {
- // Filter out executors under killing
- if (executorIsAlive(executorId)) {
- val executorData = executorDataMap(executorId)
- val workOffers = IndexedSeq(
- new WorkerOffer(executorId, executorData.executorHost, executorData.freeCores))
- launchTasks(scheduler.resourceOffers(workOffers))
+ // Make sure no executor is killed while a task is being launched on it
+ val taskDescs = CoarseGrainedSchedulerBackend.this.synchronized {
+ // Filter out executors under killing
+ if (executorIsAlive(executorId)) {
+ val executorData = executorDataMap(executorId)
+ val workOffers = IndexedSeq(
+ new WorkerOffer(executorId, executorData.executorHost, executorData.freeCores))
+ scheduler.resourceOffers(workOffers)
+ } else {
+ Seq.empty
+ }
+ }
+ if (!taskDescs.isEmpty) {
+ launchTasks(taskDescs)
}
}
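
The intent of the change above, as a generic sketch (names are illustrative, not the real backend types): compute the offers while holding the backend lock so an executor cannot be killed between being offered and being launched on, then do the actual launching outside the lock.

```scala
import scala.collection.mutable

// Illustrative-only skeleton of the locking pattern.
class OfferingBackend {
  private val aliveExecutors = mutable.Set[String]()

  def makeOffers(schedule: Seq[String] => Seq[String],
                 launch: Seq[String] => Unit): Unit = {
    // Decide what to launch under the lock, so the set of alive executors
    // cannot change out from under the scheduler.
    val toLaunch = synchronized {
      schedule(aliveExecutors.toSeq)
    }
    // Launching can be slow (serialization, RPC), so it happens outside the lock.
    if (toLaunch.nonEmpty) {
      launch(toLaunch)
    }
  }
}
```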
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskDescriptionSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskDescriptionSuite.scala
index 9f1fe0515732..97487ce1d2ca 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskDescriptionSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskDescriptionSuite.scala
@@ -17,6 +17,7 @@
package org.apache.spark.scheduler
+import java.io.{ByteArrayOutputStream, DataOutputStream, UTFDataFormatException}
import java.nio.ByteBuffer
import java.util.Properties
@@ -36,6 +37,21 @@ class TaskDescriptionSuite extends SparkFunSuite {
val originalProperties = new Properties()
originalProperties.put("property1", "18")
originalProperties.put("property2", "test value")
+ // SPARK-19796 -- large property values (like a large job description for a long sql query)
+ // can cause problems for DataOutputStream; make sure we handle them correctly
+ val sb = new StringBuilder()
+ (0 to 10000).foreach(_ => sb.append("1234567890"))
+ val largeString = sb.toString()
+ originalProperties.put("property3", largeString)
+ // make sure we've got a good test case
+ intercept[UTFDataFormatException] {
+ val out = new DataOutputStream(new ByteArrayOutputStream())
+ try {
+ out.writeUTF(largeString)
+ } finally {
+ out.close()
+ }
+ }
// Create a dummy byte buffer for the task.
val taskBuffer = ByteBuffer.wrap(Array[Byte](1, 2, 3, 4))
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
index 2c2cda9f318e..f36bcd8504b0 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
@@ -192,6 +192,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg
val taskOption = manager.resourceOffer("exec1", "host1", NO_PREF)
assert(taskOption.isDefined)
+ clock.advance(1)
// Tell it the task has finished
manager.handleSuccessfulTask(0, createTaskResult(0, accumUpdates))
assert(sched.endedTasks(0) === Success)
@@ -377,6 +378,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg
sched = new FakeTaskScheduler(sc, ("exec1", "host1"))
val taskSet = FakeTask.createTaskSet(1)
val clock = new ManualClock
+ clock.advance(1)
val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = clock)
assert(manager.resourceOffer("exec1", "host1", ANY).get.index === 0)
@@ -394,6 +396,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg
sched = new FakeTaskScheduler(sc, ("exec1", "host1"))
val taskSet = FakeTask.createTaskSet(1)
val clock = new ManualClock
+ clock.advance(1)
val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = clock)
// Fail the task MAX_TASK_FAILURES times, and check that the task set is aborted
@@ -427,6 +430,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg
// affinity to exec1 on host1 - which we will fail.
val taskSet = FakeTask.createTaskSet(1, Seq(TaskLocation("host1", "exec1")))
val clock = new ManualClock
+ clock.advance(1)
// We don't directly use the application blacklist, but its presence triggers blacklisting
// within the taskset.
val mockListenerBus = mock(classOf[LiveListenerBus])
@@ -551,7 +555,9 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg
Seq(TaskLocation("host1", "execB")),
Seq(TaskLocation("host2", "execC")),
Seq())
- val manager = new TaskSetManager(sched, taskSet, 1, clock = new ManualClock)
+ val clock = new ManualClock()
+ clock.advance(1)
+ val manager = new TaskSetManager(sched, taskSet, 1, clock = clock)
sched.addExecutor("execA", "host1")
manager.executorAdded()
sched.addExecutor("execC", "host2")
@@ -904,6 +910,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg
assert(task.executorId === k)
}
assert(sched.startedTasks.toSet === Set(0, 1, 2, 3))
+ clock.advance(1)
// Complete the 3 tasks and leave 1 task in running
for (id <- Set(0, 1, 2)) {
manager.handleSuccessfulTask(id, createTaskResult(id, accumUpdatesByTask(id)))
@@ -961,6 +968,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg
tasks += task
}
assert(sched.startedTasks.toSet === (0 until 5).toSet)
+ clock.advance(1)
// Complete 3 tasks and leave 2 tasks in running
for (id <- Set(0, 1, 2)) {
manager.handleSuccessfulTask(id, createTaskResult(id, accumUpdatesByTask(id)))
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala
index ccede34b8cb4..75dc04038deb 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala
@@ -489,12 +489,12 @@ class BlockManagerProactiveReplicationSuite extends BlockManagerReplicationBehav
Thread.sleep(200)
}
- // giving enough time for replication complete and locks released
- Thread.sleep(500)
-
- val newLocations = master.getLocations(blockId).toSet
+ val newLocations = eventually(timeout(5 seconds), interval(10 millis)) {
+ val _newLocations = master.getLocations(blockId).toSet
+ assert(_newLocations.size === replicationFactor)
+ _newLocations
+ }
logInfo(s"New locations : $newLocations")
- assert(newLocations.size === replicationFactor)
// there should only be one common block manager between initial and new locations
assert(newLocations.intersect(blockLocations.toSet).size === 1)
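
For reference, a hedged sketch of the ScalaTest `eventually` pattern used above: the block is retried at the given interval until its assertions pass or the timeout expires, replacing the fixed `Thread.sleep`. The `fetchLocations` parameter here is a hypothetical placeholder.

```scala
import org.scalatest.concurrent.Eventually._
import org.scalatest.time.SpanSugar._

// Poll until the replication count is reached (or fail after 5 seconds).
def waitForReplication(fetchLocations: () => Set[String], replicationFactor: Int): Set[String] = {
  eventually(timeout(5.seconds), interval(10.millis)) {
    val locations = fetchLocations()
    assert(locations.size == replicationFactor) // retried until it holds
    locations
  }
}
```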
diff --git a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala
index 11482d187aec..38030e066080 100644
--- a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala
@@ -77,7 +77,7 @@ class StagePageSuite extends SparkFunSuite with LocalSparkContext {
val taskInfo = new TaskInfo(taskId, taskId, 0, 0, "0", "localhost", TaskLocality.ANY, false)
jobListener.onStageSubmitted(SparkListenerStageSubmitted(stageInfo))
jobListener.onTaskStart(SparkListenerTaskStart(0, 0, taskInfo))
- taskInfo.markFinished(TaskState.FINISHED)
+ taskInfo.markFinished(TaskState.FINISHED, System.currentTimeMillis())
val taskMetrics = TaskMetrics.empty
taskMetrics.incPeakExecutionMemory(peakExecutionMemory)
jobListener.onTaskEnd(
diff --git a/docs/ml-features.md b/docs/ml-features.md
index 57605bafbf4c..dad1c6db18f8 100644
--- a/docs/ml-features.md
+++ b/docs/ml-features.md
@@ -503,6 +503,7 @@ for more details on the API.
`StringIndexer` encodes a string column of labels to a column of label indices.
The indices are in `[0, numLabels)`, ordered by label frequencies, so the most frequent label gets index `0`.
+Unseen labels will be put at index `numLabels` if the user chooses to keep them.
If the input column is numeric, we cast it to string and index the string
values. When downstream pipeline components such as `Estimator` or
`Transformer` make use of this string-indexed label, you must set the input
@@ -542,12 +543,13 @@ column, we should get the following:
"a" gets index `0` because it is the most frequent, followed by "c" with index `1` and "b" with
index `2`.
-Additionally, there are two strategies regarding how `StringIndexer` will handle
+Additionally, there are three strategies regarding how `StringIndexer` will handle
unseen labels when you have fit a `StringIndexer` on one dataset and then use it
to transform another:
- throw an exception (which is the default)
- skip the row containing the unseen label entirely
+- put unseen labels in a special additional bucket, at index `numLabels`
**Examples**
@@ -561,6 +563,7 @@ Let's go back to our previous example but this time reuse our previously defined
1 | b
2 | c
3 | d
+ 4 | e
~~~~
If you've not set how `StringIndexer` handles unseen labels or set it to
@@ -576,7 +579,22 @@ will be generated:
2 | c | 1.0
~~~~
-Notice that the row containing "d" does not appear.
+Notice that the rows containing "d" or "e" do not appear.
+
+If you call `setHandleInvalid("keep")`, the following dataset
+will be generated:
+
+~~~~
+ id | category | categoryIndex
+----|----------|---------------
+ 0 | a | 0.0
+ 1 | b | 2.0
+ 2 | c | 1.0
+ 3 | d | 3.0
+ 4 | e | 3.0
+~~~~
+
+Notice that the rows containing "d" or "e" are mapped to index `3.0`.
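
A short Scala sketch of the "keep" strategy described above (assuming a Spark version where this option is available); the DataFrames named `trainingDF`/`testDF` are placeholders:

```scala
import org.apache.spark.ml.feature.StringIndexer

val indexer = new StringIndexer()
  .setInputCol("category")
  .setOutputCol("categoryIndex")
  .setHandleInvalid("keep") // unseen labels go into one extra bucket at index numLabels

// val model   = indexer.fit(trainingDF)   // trainingDF contains only a, b, c
// val indexed = model.transform(testDF)   // rows with d or e get categoryIndex 3.0
```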
diff --git a/docs/quick-start.md b/docs/quick-start.md
index aa4319a23325..b88ae5f6bb31 100644
--- a/docs/quick-start.md
+++ b/docs/quick-start.md
@@ -10,12 +10,13 @@ description: Quick start tutorial for Spark SPARK_VERSION_SHORT
This tutorial provides a quick introduction to using Spark. We will first introduce the API through Spark's
interactive shell (in Python or Scala),
then show how to write applications in Java, Scala, and Python.
-See the [programming guide](programming-guide.html) for a more complete reference.
To follow along with this guide, first download a packaged release of Spark from the
[Spark website](http://spark.apache.org/downloads.html). Since we won't be using HDFS,
you can download a package for any version of Hadoop.
+Note that, before Spark 2.0, the main programming interface of Spark was the Resilient Distributed Dataset (RDD). After Spark 2.0, RDDs are replaced by Dataset, which is strongly typed like an RDD but with richer optimizations under the hood. The RDD interface is still supported, and you can find a more complete reference in the [RDD programming guide](rdd-programming-guide.html). However, we highly recommend switching to Dataset, which has better performance than RDD. See the [SQL programming guide](sql-programming-guide.html) for more information about Dataset.
+
# Interactive Analysis with the Spark Shell
## Basics
@@ -29,28 +30,28 @@ or Python. Start it by running the following in the Spark directory:
./bin/spark-shell
-Spark's primary abstraction is a distributed collection of items called a Resilient Distributed Dataset (RDD). RDDs can be created from Hadoop InputFormats (such as HDFS files) or by transforming other RDDs. Let's make a new RDD from the text of the README file in the Spark source directory:
+Spark's primary abstraction is a distributed collection of items called a Dataset. Datasets can be created from Hadoop InputFormats (such as HDFS files) or by transforming other Datasets. Let's make a new Dataset from the text of the README file in the Spark source directory:
{% highlight scala %}
-scala> val textFile = sc.textFile("README.md")
-textFile: org.apache.spark.rdd.RDD[String] = README.md MapPartitionsRDD[1] at textFile at <console>:25
+scala> val textFile = spark.read.textFile("README.md")
+textFile: org.apache.spark.sql.Dataset[String] = [value: string]
{% endhighlight %}
-RDDs have _[actions](programming-guide.html#actions)_, which return values, and _[transformations](programming-guide.html#transformations)_, which return pointers to new RDDs. Let's start with a few actions:
+You can get values from a Dataset directly, by calling some actions, or transform the Dataset to get a new one. For more details, please read the _[API doc](api/scala/index.html#org.apache.spark.sql.Dataset)_.
{% highlight scala %}
-scala> textFile.count() // Number of items in this RDD
+scala> textFile.count() // Number of items in this Dataset
res0: Long = 126 // May be different from yours as README.md will change over time, similar to other outputs
-scala> textFile.first() // First item in this RDD
+scala> textFile.first() // First item in this Dataset
res1: String = # Apache Spark
{% endhighlight %}
-Now let's use a transformation. We will use the [`filter`](programming-guide.html#transformations) transformation to return a new RDD with a subset of the items in the file.
+Now let's transform this Dataset to a new one. We call `filter` to return a new Dataset with a subset of the items in the file.
{% highlight scala %}
scala> val linesWithSpark = textFile.filter(line => line.contains("Spark"))
-linesWithSpark: org.apache.spark.rdd.RDD[String] = MapPartitionsRDD[2] at filter at <console>:27
+linesWithSpark: org.apache.spark.sql.Dataset[String] = [value: string]
{% endhighlight %}
We can chain together transformations and actions:
@@ -65,32 +66,32 @@ res3: Long = 15
./bin/pyspark
-Spark's primary abstraction is a distributed collection of items called a Resilient Distributed Dataset (RDD). RDDs can be created from Hadoop InputFormats (such as HDFS files) or by transforming other RDDs. Let's make a new RDD from the text of the README file in the Spark source directory:
+Spark's primary abstraction is a distributed collection of items called a Dataset. Datasets can be created from Hadoop InputFormats (such as HDFS files) or by transforming other Datasets. Due to Python's dynamic nature, we don't need the Dataset to be strongly-typed in Python. As a result, all Datasets in Python are Dataset[Row], and we call it `DataFrame` to be consistent with the data frame concept in Pandas and R. Let's make a new DataFrame from the text of the README file in the Spark source directory:
{% highlight python %}
->>> textFile = sc.textFile("README.md")
+>>> textFile = spark.read.text("README.md")
{% endhighlight %}
-RDDs have _[actions](programming-guide.html#actions)_, which return values, and _[transformations](programming-guide.html#transformations)_, which return pointers to new RDDs. Let's start with a few actions:
+You can get values from a DataFrame directly, by calling some actions, or transform the DataFrame to get a new one. For more details, please read the _[API doc](api/python/index.html#pyspark.sql.DataFrame)_.
{% highlight python %}
->>> textFile.count() # Number of items in this RDD
+>>> textFile.count() # Number of rows in this DataFrame
126
->>> textFile.first() # First item in this RDD
-u'# Apache Spark'
+>>> textFile.first() # First row in this DataFrame
+Row(value=u'# Apache Spark')
{% endhighlight %}
-Now let's use a transformation. We will use the [`filter`](programming-guide.html#transformations) transformation to return a new RDD with a subset of the items in the file.
+Now let's transform this DataFrame to a new one. We call `filter` to return a new DataFrame with a subset of the lines in the file.
{% highlight python %}
->>> linesWithSpark = textFile.filter(lambda line: "Spark" in line)
+>>> linesWithSpark = textFile.filter(textFile.value.contains("Spark"))
{% endhighlight %}
We can chain together transformations and actions:
{% highlight python %}
->>> textFile.filter(lambda line: "Spark" in line).count() # How many lines contain "Spark"?
+>>> textFile.filter(textFile.value.contains("Spark")).count() # How many lines contain "Spark"?
15
{% endhighlight %}
@@ -98,8 +99,8 @@ We can chain together transformations and actions:
-## More on RDD Operations
-RDD actions and transformations can be used for more complex computations. Let's say we want to find the line with the most words:
+## More on Dataset Operations
+Dataset actions and transformations can be used for more complex computations. Let's say we want to find the line with the most words:
@@ -109,7 +110,7 @@ scala> textFile.map(line => line.split(" ").size).reduce((a, b) => if (a > b) a
res4: Long = 15
{% endhighlight %}
-This first maps a line to an integer value, creating a new RDD. `reduce` is called on that RDD to find the largest line count. The arguments to `map` and `reduce` are Scala function literals (closures), and can use any language feature or Scala/Java library. For example, we can easily call functions declared elsewhere. We'll use `Math.max()` function to make this code easier to understand:
+This first maps a line to an integer value, creating a new Dataset. `reduce` is called on that Dataset to find the largest word count. The arguments to `map` and `reduce` are Scala function literals (closures), and can use any language feature or Scala/Java library. For example, we can easily call functions declared elsewhere. We'll use `Math.max()` function to make this code easier to understand:
{% highlight scala %}
scala> import java.lang.Math
@@ -122,11 +123,11 @@ res5: Int = 15
One common data flow pattern is MapReduce, as popularized by Hadoop. Spark can implement MapReduce flows easily:
{% highlight scala %}
-scala> val wordCounts = textFile.flatMap(line => line.split(" ")).map(word => (word, 1)).reduceByKey((a, b) => a + b)
-wordCounts: org.apache.spark.rdd.RDD[(String, Int)] = ShuffledRDD[8] at reduceByKey at <console>:28
+scala> val wordCounts = textFile.flatMap(line => line.split(" ")).groupByKey(identity).count()
+wordCounts: org.apache.spark.sql.Dataset[(String, Long)] = [value: string, count(1): bigint]
{% endhighlight %}
-Here, we combined the [`flatMap`](programming-guide.html#transformations), [`map`](programming-guide.html#transformations), and [`reduceByKey`](programming-guide.html#transformations) transformations to compute the per-word counts in the file as an RDD of (String, Int) pairs. To collect the word counts in our shell, we can use the [`collect`](programming-guide.html#actions) action:
+Here, we call `flatMap` to transform a Dataset of lines to a Dataset of words, and then combine `groupByKey` and `count` to compute the per-word counts in the file as a Dataset of (String, Long) pairs. To collect the word counts in our shell, we can call `collect`:
{% highlight scala %}
scala> wordCounts.collect()
@@ -137,37 +138,24 @@ res6: Array[(String, Int)] = Array((means,1), (under,2), (this,3), (Because,1),
{% highlight python %}
->>> textFile.map(lambda line: len(line.split())).reduce(lambda a, b: a if (a > b) else b)
-15
+>>> from pyspark.sql.functions import *
+>>> textFile.select(size(split(textFile.value, "\s+")).name("numWords")).agg(max(col("numWords"))).collect()
+[Row(max(numWords)=15)]
{% endhighlight %}
-This first maps a line to an integer value, creating a new RDD. `reduce` is called on that RDD to find the largest line count. The arguments to `map` and `reduce` are Python [anonymous functions (lambdas)](https://docs.python.org/2/reference/expressions.html#lambda),
-but we can also pass any top-level Python function we want.
-For example, we'll define a `max` function to make this code easier to understand:
-
-{% highlight python %}
->>> def max(a, b):
-... if a > b:
-... return a
-... else:
-... return b
-...
-
->>> textFile.map(lambda line: len(line.split())).reduce(max)
-15
-{% endhighlight %}
+This first maps a line to an integer value and aliases it as "numWords", creating a new DataFrame. `agg` is called on that DataFrame to find the largest word count. The arguments to `select` and `agg` are both _[Column](api/python/index.html#pyspark.sql.Column)_; we can use `df.colName` to get a column from a DataFrame. We can also import pyspark.sql.functions, which provides many convenient functions to build a new Column from an old one.
One common data flow pattern is MapReduce, as popularized by Hadoop. Spark can implement MapReduce flows easily:
{% highlight python %}
->>> wordCounts = textFile.flatMap(lambda line: line.split()).map(lambda word: (word, 1)).reduceByKey(lambda a, b: a+b)
+>>> wordCounts = textFile.select(explode(split(textFile.value, "\s+")).alias("word")).groupBy("word").count()
{% endhighlight %}
-Here, we combined the [`flatMap`](programming-guide.html#transformations), [`map`](programming-guide.html#transformations), and [`reduceByKey`](programming-guide.html#transformations) transformations to compute the per-word counts in the file as an RDD of (string, int) pairs. To collect the word counts in our shell, we can use the [`collect`](programming-guide.html#actions) action:
+Here, we use the `explode` function in `select` to transform a Dataset of lines to a Dataset of words, and then combine `groupBy` and `count` to compute the per-word counts in the file as a DataFrame of two columns: "word" and "count". To collect the word counts in our shell, we can call `collect`:
{% highlight python %}
>>> wordCounts.collect()
-[(u'and', 9), (u'A', 1), (u'webpage', 1), (u'README', 1), (u'Note', 1), (u'"local"', 1), (u'variable', 1), ...]
+[Row(word=u'online', count=1), Row(word=u'graphs', count=1), ...]
{% endhighlight %}
@@ -181,7 +169,7 @@ Spark also supports pulling data sets into a cluster-wide in-memory cache. This
{% highlight scala %}
scala> linesWithSpark.cache()
-res7: linesWithSpark.type = MapPartitionsRDD[2] at filter at <console>:27
+res7: linesWithSpark.type = [value: string]
scala> linesWithSpark.count()
res8: Long = 15
@@ -193,7 +181,7 @@ res9: Long = 15
It may seem silly to use Spark to explore and cache a 100-line text file. The interesting part is
that these same functions can be used on very large data sets, even when they are striped across
tens or hundreds of nodes. You can also do this interactively by connecting `bin/spark-shell` to
-a cluster, as described in the [programming guide](programming-guide.html#initializing-spark).
+a cluster, as described in the [RDD programming guide](rdd-programming-guide.html#using-the-shell).
@@ -211,7 +199,7 @@ a cluster, as described in the [programming guide](programming-guide.html#initia
It may seem silly to use Spark to explore and cache a 100-line text file. The interesting part is
that these same functions can be used on very large data sets, even when they are striped across
tens or hundreds of nodes. You can also do this interactively by connecting `bin/pyspark` to
-a cluster, as described in the [programming guide](programming-guide.html#initializing-spark).
+a cluster, as described in the [RDD programming guide](rdd-programming-guide.html#using-the-shell).
@@ -228,20 +216,17 @@ named `SimpleApp.scala`:
{% highlight scala %}
/* SimpleApp.scala */
-import org.apache.spark.SparkContext
-import org.apache.spark.SparkContext._
-import org.apache.spark.SparkConf
+import org.apache.spark.sql.SparkSession
object SimpleApp {
def main(args: Array[String]) {
val logFile = "YOUR_SPARK_HOME/README.md" // Should be some file on your system
- val conf = new SparkConf().setAppName("Simple Application")
- val sc = new SparkContext(conf)
- val logData = sc.textFile(logFile, 2).cache()
+ val spark = SparkSession.builder.appName("Simple Application").getOrCreate()
+ val logData = spark.read.textFile(logFile).cache()
val numAs = logData.filter(line => line.contains("a")).count()
val numBs = logData.filter(line => line.contains("b")).count()
println(s"Lines with a: $numAs, Lines with b: $numBs")
- sc.stop()
+ spark.stop()
}
}
{% endhighlight %}
@@ -251,16 +236,13 @@ Subclasses of `scala.App` may not work correctly.
This program just counts the number of lines containing 'a' and the number containing 'b' in the
Spark README. Note that you'll need to replace YOUR_SPARK_HOME with the location where Spark is
-installed. Unlike the earlier examples with the Spark shell, which initializes its own SparkContext,
-we initialize a SparkContext as part of the program.
+installed. Unlike the earlier examples with the Spark shell, which initializes its own SparkSession,
+we initialize a SparkSession as part of the program.
-We pass the SparkContext constructor a
-[SparkConf](api/scala/index.html#org.apache.spark.SparkConf)
-object which contains information about our
-application.
+We call `SparkSession.builder` to construct a `SparkSession`, then set the application name, and finally call `getOrCreate` to get the `SparkSession` instance.
-Our application depends on the Spark API, so we'll also include an sbt configuration file,
-`build.sbt`, which explains that Spark is a dependency. This file also adds a repository that
+Our application depends on the Spark API, so we'll also include an sbt configuration file,
+`build.sbt`, which explains that Spark is a dependency. This file also adds a repository that
Spark depends on:
{% highlight scala %}
@@ -270,7 +252,7 @@ version := "1.0"
scalaVersion := "{{site.SCALA_VERSION}}"
-libraryDependencies += "org.apache.spark" %% "spark-core" % "{{site.SPARK_VERSION}}"
+libraryDependencies += "org.apache.spark" %% "spark-sql" % "{{site.SPARK_VERSION}}"
{% endhighlight %}
For sbt to work correctly, we'll need to layout `SimpleApp.scala` and `build.sbt`
@@ -309,34 +291,28 @@ We'll create a very simple Spark application, `SimpleApp.java`:
{% highlight java %}
/* SimpleApp.java */
-import org.apache.spark.api.java.*;
-import org.apache.spark.SparkConf;
-import org.apache.spark.api.java.function.Function;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.SparkSession;
public class SimpleApp {
public static void main(String[] args) {
String logFile = "YOUR_SPARK_HOME/README.md"; // Should be some file on your system
- SparkConf conf = new SparkConf().setAppName("Simple Application");
- JavaSparkContext sc = new JavaSparkContext(conf);
- JavaRDD<String> logData = sc.textFile(logFile).cache();
+ SparkSession spark = SparkSession.builder().appName("Simple Application").getOrCreate();
+ Dataset<String> logData = spark.read().textFile(logFile).cache();
long numAs = logData.filter(s -> s.contains("a")).count();
long numBs = logData.filter(s -> s.contains("b")).count();
System.out.println("Lines with a: " + numAs + ", lines with b: " + numBs);
-
- sc.stop();
+
+ spark.stop();
}
}
{% endhighlight %}
-This program just counts the number of lines containing 'a' and the number containing 'b' in a text
-file. Note that you'll need to replace YOUR_SPARK_HOME with the location where Spark is installed.
-As with the Scala example, we initialize a SparkContext, though we use the special
-`JavaSparkContext` class to get a Java-friendly one. We also create RDDs (represented by
-`JavaRDD`) and run transformations on them. Finally, we pass functions to Spark by creating classes
-that extend `spark.api.java.function.Function`. The
-[Spark programming guide](programming-guide.html) describes these differences in more detail.
+This program just counts the number of lines containing 'a' and the number containing 'b' in the
+Spark README. Note that you'll need to replace YOUR_SPARK_HOME with the location where Spark is
+installed. Unlike the earlier examples with the Spark shell, which initializes its own SparkSession,
+we initialize a SparkSession as part of the program.
To build the program, we also write a Maven `pom.xml` file that lists Spark as a dependency.
Note that Spark artifacts are tagged with a Scala version.
@@ -352,7 +328,7 @@ Note that Spark artifacts are tagged with a Scala version.
<groupId>org.apache.spark</groupId>
- <artifactId>spark-core_{{site.SCALA_BINARY_VERSION}}</artifactId>
+ <artifactId>spark-sql_{{site.SCALA_BINARY_VERSION}}</artifactId>
<version>{{site.SPARK_VERSION}}</version>
@@ -395,27 +371,25 @@ As an example, we'll create a simple Spark application, `SimpleApp.py`:
{% highlight python %}
"""SimpleApp.py"""
-from pyspark import SparkContext
+from pyspark.sql import SparkSession
logFile = "YOUR_SPARK_HOME/README.md" # Should be some file on your system
-sc = SparkContext("local", "Simple App")
-logData = sc.textFile(logFile).cache()
+spark = SparkSession.builder.appName("SimpleApp").getOrCreate()
+logData = spark.read.text(logFile).cache()
-numAs = logData.filter(lambda s: 'a' in s).count()
-numBs = logData.filter(lambda s: 'b' in s).count()
+numAs = logData.filter(logData.value.contains('a')).count()
+numBs = logData.filter(logData.value.contains('b')).count()
print("Lines with a: %i, lines with b: %i" % (numAs, numBs))
-sc.stop()
+spark.stop()
{% endhighlight %}
This program just counts the number of lines containing 'a' and the number containing 'b' in a
text file.
Note that you'll need to replace YOUR_SPARK_HOME with the location where Spark is installed.
-As with the Scala and Java examples, we use a SparkContext to create RDDs.
-We can pass Python functions to Spark, which are automatically serialized along with any variables
-that they reference.
+As with the Scala and Java examples, we use a SparkSession to create Datasets.
For applications that use custom classes or third-party libraries, we can also add code
dependencies to `spark-submit` through its `--py-files` argument by packaging them into a
.zip file (see `spark-submit --help` for details).
@@ -438,8 +412,7 @@ Lines with a: 46, Lines with b: 23
# Where to Go from Here
Congratulations on running your first Spark application!
-* For an in-depth overview of the API, start with the [Spark programming guide](programming-guide.html),
- or see "Programming Guides" menu for other components.
+* For an in-depth overview of the API, start with the [RDD programming guide](rdd-programming-guide.html) and the [SQL programming guide](sql-programming-guide.html), or see "Programming Guides" menu for other components.
* For running applications on a cluster, head to the [deployment overview](cluster-overview.html).
* Finally, Spark includes several samples in the `examples` directory
([Scala]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/scala/org/apache/spark/examples),
diff --git a/docs/programming-guide.md b/docs/rdd-programming-guide.md
similarity index 99%
rename from docs/programming-guide.md
rename to docs/rdd-programming-guide.md
index 6740dbe0014b..e2bf2d7ca77c 100644
--- a/docs/programming-guide.md
+++ b/docs/rdd-programming-guide.md
@@ -24,7 +24,7 @@ along with if you launch Spark's interactive shell -- either `bin/spark-shell` f
-Spark {{site.SPARK_VERSION}} is built and distributed to work with Scala {{site.SCALA_BINARY_VERSION}}
+Spark {{site.SPARK_VERSION}} is built and distributed to work with Scala {{site.SCALA_BINARY_VERSION}}
by default. (Spark can be built to work with other versions of Scala, too.) To write
applications in Scala, you will need to use a compatible Scala version (e.g. {{site.SCALA_BINARY_VERSION}}.X).
@@ -76,10 +76,10 @@ In addition, if you wish to access an HDFS cluster, you need to add a dependency
Finally, you need to import some Spark classes into your program. Add the following lines:
-{% highlight scala %}
-import org.apache.spark.api.java.JavaSparkContext
-import org.apache.spark.api.java.JavaRDD
-import org.apache.spark.SparkConf
+{% highlight java %}
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.SparkConf;
{% endhighlight %}
@@ -244,13 +244,13 @@ use IPython, set the `PYSPARK_DRIVER_PYTHON` variable to `ipython` when running
$ PYSPARK_DRIVER_PYTHON=ipython ./bin/pyspark
{% endhighlight %}
-To use the Jupyter notebook (previously known as the IPython notebook),
+To use the Jupyter notebook (previously known as the IPython notebook),
{% highlight bash %}
$ PYSPARK_DRIVER_PYTHON=jupyter ./bin/pyspark
{% endhighlight %}
-You can customize the `ipython` or `jupyter` commands by setting `PYSPARK_DRIVER_PYTHON_OPTS`.
+You can customize the `ipython` or `jupyter` commands by setting `PYSPARK_DRIVER_PYTHON_OPTS`.
After the Jupyter Notebook server is launched, you can create a new "Python 2" notebook from
the "Files" tab. Inside the notebook, you can input the command `%pylab inline` as part of
@@ -457,7 +457,7 @@ If required, a Hadoop configuration can be passed in as a Python dict. Here is a
Elasticsearch ESInputFormat:
{% highlight python %}
-$ SPARK_CLASSPATH=/path/to/elasticsearch-hadoop.jar ./bin/pyspark
+$ ./bin/pyspark --jars /path/to/elasticsearch-hadoop.jar
>>> conf = {"es.resource" : "index/type"} # assume Elasticsearch is running on localhost defaults
>>> rdd = sc.newAPIHadoopRDD("org.elasticsearch.hadoop.mr.EsInputFormat",
"org.apache.hadoop.io.NullWritable",
@@ -811,7 +811,7 @@ The variables within the closure sent to each executor are now copies and thus,
In local mode, in some circumstances the `foreach` function will actually execute within the same JVM as the driver and will reference the same original **counter**, and may actually update it.
-To ensure well-defined behavior in these sorts of scenarios one should use an [`Accumulator`](#accumulators). Accumulators in Spark are used specifically to provide a mechanism for safely updating a variable when execution is split up across worker nodes in a cluster. The Accumulators section of this guide discusses these in more detail.
+To ensure well-defined behavior in these sorts of scenarios one should use an [`Accumulator`](#accumulators). Accumulators in Spark are used specifically to provide a mechanism for safely updating a variable when execution is split up across worker nodes in a cluster. The Accumulators section of this guide discusses these in more detail.
In general, closures - constructs like loops or locally defined methods, should not be used to mutate some global state. Spark does not define or guarantee the behavior of mutations to objects referenced from outside of closures. Some code that does this may work in local mode, but that's just by accident and such code will not behave as expected in distributed mode. Use an Accumulator instead if some global aggregation is needed.
@@ -1230,8 +1230,8 @@ storage levels is:
-**Note:** *In Python, stored objects will always be serialized with the [Pickle](https://docs.python.org/2/library/pickle.html) library,
-so it does not matter whether you choose a serialized level. The available storage levels in Python include `MEMORY_ONLY`, `MEMORY_ONLY_2`,
+**Note:** *In Python, stored objects will always be serialized with the [Pickle](https://docs.python.org/2/library/pickle.html) library,
+so it does not matter whether you choose a serialized level. The available storage levels in Python include `MEMORY_ONLY`, `MEMORY_ONLY_2`,
`MEMORY_AND_DISK`, `MEMORY_AND_DISK_2`, `DISK_ONLY`, and `DISK_ONLY_2`.*
Spark also automatically persists some intermediate data in shuffle operations (e.g. `reduceByKey`), even without users calling `persist`. This is done to avoid recomputing the entire input if a node fails during the shuffle. We still recommend users call `persist` on the resulting RDD if they plan to reuse it.
@@ -1346,7 +1346,7 @@ As a user, you can create named or unnamed accumulators. As seen in the image be
-Tracking accumulators in the UI can be useful for understanding the progress of
+Tracking accumulators in the UI can be useful for understanding the progress of
running stages (NOTE: this is not yet supported in Python).
@@ -1355,7 +1355,7 @@ running stages (NOTE: this is not yet supported in Python).
A numeric accumulator can be created by calling `SparkContext.longAccumulator()` or `SparkContext.doubleAccumulator()`
to accumulate values of type Long or Double, respectively. Tasks running on a cluster can then add to it using
-the `add` method. However, they cannot read its value. Only the driver program can read the accumulator's value,
+the `add` method. However, they cannot read its value. Only the driver program can read the accumulator's value,
using its `value` method.
The code below shows an accumulator being used to add up the elements of an array:
@@ -1409,7 +1409,7 @@ Note that, when programmers define their own type of AccumulatorV2, the resultin
A numeric accumulator can be created by calling `SparkContext.longAccumulator()` or `SparkContext.doubleAccumulator()`
to accumulate values of type Long or Double, respectively. Tasks running on a cluster can then add to it using
-the `add` method. However, they cannot read its value. Only the driver program can read the accumulator's value,
+the `add` method. However, they cannot read its value. Only the driver program can read the accumulator's value,
using its `value` method.
The code below shows an accumulator being used to add up the elements of an array:
diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md
index 6af47b6efba2..995ac77a4fb3 100644
--- a/docs/structured-streaming-programming-guide.md
+++ b/docs/structured-streaming-programming-guide.md
@@ -1052,10 +1052,18 @@ Here are the details of all the sinks in Spark.
Append |
path: path to the output directory, must be specified.
+
maxFilesPerTrigger: maximum number of new files to be considered in every trigger (default: no max)
- latestFirst: whether to processs the latest new files first, useful when there is a large backlog of files(default: false)
-
+ latestFirst: whether to process the latest new files first, useful when there is a large backlog of files (default: false)
+
+ fileNameOnly: whether to check new files based on only the filename instead of on the full path (default: false). With this set to `true`, the following files would be considered as the same file, because their filenames, "dataset.txt", are the same:
+
+ · "file:///dataset.txt"
+ · "s3://a/dataset.txt"
+ · "s3n://a/b/dataset.txt"
+ · "s3a://a/b/c/dataset.txt"
+
For file-format-specific options, see the related methods in DataFrameWriter
(Scala/Java/Python).
E.g. for "parquet" format options see DataFrameWriter.parquet()
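For illustration, a minimal sketch of setting these file options on a file stream read, assuming they are passed as plain string options; the format, path, and values below are placeholders:

{% highlight scala %}
// Hypothetical sketch: all option values and the input path are placeholders.
val files = spark.readStream
  .format("text")
  .option("maxFilesPerTrigger", "100")  // cap on new files picked up per trigger
  .option("latestFirst", "true")        // drain the newest files first on a large backlog
  .option("fileNameOnly", "true")       // "s3://a/dataset.txt" and "s3n://a/b/dataset.txt" dedupe as "dataset.txt"
  .load("path/to/input/dir")
{% endhighlight %}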
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala
index 15b28256e825..6d76904fb0e5 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala
@@ -273,19 +273,7 @@ private[kafka010] case class CachedKafkaConsumer private(
message: String,
cause: Throwable = null): Unit = {
val finalMessage = s"$message ${additionalMessage(failOnDataLoss)}"
- if (failOnDataLoss) {
- if (cause != null) {
- throw new IllegalStateException(finalMessage)
- } else {
- throw new IllegalStateException(finalMessage, cause)
- }
- } else {
- if (cause != null) {
- logWarning(finalMessage)
- } else {
- logWarning(finalMessage, cause)
- }
- }
+ reportDataLoss0(failOnDataLoss, finalMessage, cause)
}
private def close(): Unit = consumer.close()
@@ -398,4 +386,23 @@ private[kafka010] object CachedKafkaConsumer extends Logging {
consumer
}
}
+
+ private def reportDataLoss0(
+ failOnDataLoss: Boolean,
+ finalMessage: String,
+ cause: Throwable = null): Unit = {
+ if (failOnDataLoss) {
+ if (cause != null) {
+ throw new IllegalStateException(finalMessage, cause)
+ } else {
+ throw new IllegalStateException(finalMessage)
+ }
+ } else {
+ if (cause != null) {
+ logWarning(finalMessage, cause)
+ } else {
+ logWarning(finalMessage)
+ }
+ }
+ }
}
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSink.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSink.scala
new file mode 100644
index 000000000000..08914d82fffd
--- /dev/null
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSink.scala
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.kafka010
+
+import java.{util => ju}
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.{DataFrame, SQLContext}
+import org.apache.spark.sql.execution.streaming.Sink
+
+private[kafka010] class KafkaSink(
+ sqlContext: SQLContext,
+ executorKafkaParams: ju.Map[String, Object],
+ topic: Option[String]) extends Sink with Logging {
+ @volatile private var latestBatchId = -1L
+
+ override def toString(): String = "KafkaSink"
+
+ override def addBatch(batchId: Long, data: DataFrame): Unit = {
+ if (batchId <= latestBatchId) {
+ logInfo(s"Skipping already committed batch $batchId")
+ } else {
+ KafkaWriter.write(sqlContext.sparkSession,
+ data.queryExecution, executorKafkaParams, topic)
+ latestBatchId = batchId
+ }
+ }
+}
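For illustration, a minimal sketch of reaching this sink through the `DataStreamWriter` API; the source DataFrame, broker address, checkpoint path, and topic name are placeholders:

{% highlight scala %}
// Hypothetical streaming write: KafkaSink is selected via format("kafka").
val query = streamingDF
  .selectExpr("CAST(key AS STRING) AS key", "CAST(value AS STRING) AS value")
  .writeStream
  .format("kafka")
  .option("kafka.bootstrap.servers", "host1:9092")
  .option("checkpointLocation", "/tmp/kafka-sink-checkpoint")
  .option("topic", "events")
  .start()
{% endhighlight %}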
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala
index 6a7456719875..febe3c217122 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala
@@ -23,12 +23,14 @@ import java.util.UUID
import scala.collection.JavaConverters._
import org.apache.kafka.clients.consumer.ConsumerConfig
-import org.apache.kafka.common.serialization.ByteArrayDeserializer
+import org.apache.kafka.clients.producer.ProducerConfig
+import org.apache.kafka.common.serialization.{ByteArrayDeserializer, ByteArraySerializer}
import org.apache.spark.internal.Logging
-import org.apache.spark.sql.SQLContext
-import org.apache.spark.sql.execution.streaming.Source
+import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SQLContext}
+import org.apache.spark.sql.execution.streaming.{Sink, Source}
import org.apache.spark.sql.sources._
+import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.StructType
/**
@@ -36,8 +38,12 @@ import org.apache.spark.sql.types.StructType
* IllegalArgumentException when the Kafka Dataset is created, so that it can catch
* missing options even before the query is started.
*/
-private[kafka010] class KafkaSourceProvider extends DataSourceRegister with StreamSourceProvider
- with RelationProvider with Logging {
+private[kafka010] class KafkaSourceProvider extends DataSourceRegister
+ with StreamSourceProvider
+ with StreamSinkProvider
+ with RelationProvider
+ with CreatableRelationProvider
+ with Logging {
import KafkaSourceProvider._
override def shortName(): String = "kafka"
@@ -152,6 +158,72 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister with Stre
endingRelationOffsets)
}
+ override def createSink(
+ sqlContext: SQLContext,
+ parameters: Map[String, String],
+ partitionColumns: Seq[String],
+ outputMode: OutputMode): Sink = {
+ val defaultTopic = parameters.get(TOPIC_OPTION_KEY).map(_.trim)
+ val specifiedKafkaParams = kafkaParamsForProducer(parameters)
+ new KafkaSink(sqlContext,
+ new ju.HashMap[String, Object](specifiedKafkaParams.asJava), defaultTopic)
+ }
+
+ override def createRelation(
+ outerSQLContext: SQLContext,
+ mode: SaveMode,
+ parameters: Map[String, String],
+ data: DataFrame): BaseRelation = {
+ mode match {
+ case SaveMode.Overwrite | SaveMode.Ignore =>
+ throw new AnalysisException(s"Save mode $mode not allowed for Kafka. " +
+ s"Allowed save modes are ${SaveMode.Append} and " +
+ s"${SaveMode.ErrorIfExists} (default).")
+ case _ => // good
+ }
+ val topic = parameters.get(TOPIC_OPTION_KEY).map(_.trim)
+ val specifiedKafkaParams = kafkaParamsForProducer(parameters)
+ KafkaWriter.write(outerSQLContext.sparkSession, data.queryExecution,
+ new ju.HashMap[String, Object](specifiedKafkaParams.asJava), topic)
+
+ /* This method is supposed to return a relation that reads the data that was written.
+ * We cannot support this for Kafka. Therefore, in order to make things consistent,
+ * we return an empty base relation.
+ */
+ new BaseRelation {
+ override def sqlContext: SQLContext = unsupportedException
+ override def schema: StructType = unsupportedException
+ override def needConversion: Boolean = unsupportedException
+ override def sizeInBytes: Long = unsupportedException
+ override def unhandledFilters(filters: Array[Filter]): Array[Filter] = unsupportedException
+ private def unsupportedException =
+ throw new UnsupportedOperationException("BaseRelation from Kafka write " +
+ "operation is not usable.")
+ }
+ }
+
+ private def kafkaParamsForProducer(parameters: Map[String, String]): Map[String, String] = {
+ val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase, v) }
+ if (caseInsensitiveParams.contains(s"kafka.${ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG}")) {
+ throw new IllegalArgumentException(
+ s"Kafka option '${ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG}' is not supported as keys "
+ + "are serialized with ByteArraySerializer.")
+ }
+
+ if (caseInsensitiveParams.contains(s"kafka.${ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG}"))
+ {
+ throw new IllegalArgumentException(
+ s"Kafka option '${ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG}' is not supported as "
+ + "value are serialized with ByteArraySerializer.")
+ }
+ parameters
+ .keySet
+ .filter(_.toLowerCase.startsWith("kafka."))
+ .map { k => k.drop(6).toString -> parameters(k) }
+ .toMap + (ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG -> classOf[ByteArraySerializer].getName,
+ ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG -> classOf[ByteArraySerializer].getName)
+ }
+
private def kafkaParamsForDriver(specifiedKafkaParams: Map[String, String]) =
ConfigUpdater("source", specifiedKafkaParams)
.set(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, deserClassName)
@@ -381,6 +453,7 @@ private[kafka010] object KafkaSourceProvider {
private val STARTING_OFFSETS_OPTION_KEY = "startingoffsets"
private val ENDING_OFFSETS_OPTION_KEY = "endingoffsets"
private val FAIL_ON_DATA_LOSS_OPTION_KEY = "failondataloss"
+ val TOPIC_OPTION_KEY = "topic"
private val deserClassName = classOf[ByteArrayDeserializer].getName
}
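For illustration, a minimal sketch of the batch path through `createRelation`: options prefixed with `kafka.` are stripped and handed to the producer, and only the Append and ErrorIfExists save modes are accepted. The DataFrame, column names, and option values are placeholders:

{% highlight scala %}
import org.apache.spark.sql.SaveMode

// Hypothetical batch write; "kafka.acks" is forwarded to the producer as "acks".
batchDF
  .selectExpr("CAST(id AS STRING) AS key", "CAST(payload AS STRING) AS value")
  .write
  .format("kafka")
  .option("kafka.bootstrap.servers", "host1:9092")
  .option("kafka.acks", "all")
  .option("topic", "events")
  .mode(SaveMode.ErrorIfExists)   // the default; SaveMode.Append is also accepted
  .save()
{% endhighlight %}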
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala
new file mode 100644
index 000000000000..6e160cbe2db5
--- /dev/null
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala
@@ -0,0 +1,123 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.kafka010
+
+import java.{util => ju}
+
+import org.apache.kafka.clients.producer.{KafkaProducer, _}
+import org.apache.kafka.common.serialization.ByteArraySerializer
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Literal, UnsafeProjection}
+import org.apache.spark.sql.types.{BinaryType, StringType}
+
+/**
+ * Writes out data in a single Spark task, without any concerns about how
+ * to commit or abort tasks. Exceptions thrown by the implementation of this class will
+ * automatically trigger task aborts.
+ */
+private[kafka010] class KafkaWriteTask(
+ producerConfiguration: ju.Map[String, Object],
+ inputSchema: Seq[Attribute],
+ topic: Option[String]) {
+ // used to synchronize with Kafka callbacks
+ @volatile private var failedWrite: Exception = null
+ private val projection = createProjection
+ private var producer: KafkaProducer[Array[Byte], Array[Byte]] = _
+
+ /**
+ * Writes key value data out to topics.
+ */
+ def execute(iterator: Iterator[InternalRow]): Unit = {
+ producer = new KafkaProducer[Array[Byte], Array[Byte]](producerConfiguration)
+ while (iterator.hasNext && failedWrite == null) {
+ val currentRow = iterator.next()
+ val projectedRow = projection(currentRow)
+ val topic = projectedRow.getUTF8String(0)
+ val key = projectedRow.getBinary(1)
+ val value = projectedRow.getBinary(2)
+ if (topic == null) {
+ throw new NullPointerException(s"null topic present in the data. Use the " +
+ s"${KafkaSourceProvider.TOPIC_OPTION_KEY} option for setting a default topic.")
+ }
+ val record = new ProducerRecord[Array[Byte], Array[Byte]](topic.toString, key, value)
+ val callback = new Callback() {
+ override def onCompletion(recordMetadata: RecordMetadata, e: Exception): Unit = {
+ if (failedWrite == null && e != null) {
+ failedWrite = e
+ }
+ }
+ }
+ producer.send(record, callback)
+ }
+ }
+
+ def close(): Unit = {
+ if (producer != null) {
+ checkForErrors
+ producer.close()
+ checkForErrors
+ producer = null
+ }
+ }
+
+ private def createProjection: UnsafeProjection = {
+ val topicExpression = topic.map(Literal(_)).orElse {
+ inputSchema.find(_.name == KafkaWriter.TOPIC_ATTRIBUTE_NAME)
+ }.getOrElse {
+ throw new IllegalStateException(s"topic option required when no " +
+ s"'${KafkaWriter.TOPIC_ATTRIBUTE_NAME}' attribute is present")
+ }
+ topicExpression.dataType match {
+ case StringType => // good
+ case t =>
+ throw new IllegalStateException(s"${KafkaWriter.TOPIC_ATTRIBUTE_NAME} " +
+ s"attribute unsupported type $t. ${KafkaWriter.TOPIC_ATTRIBUTE_NAME} " +
+ s"must be a ${StringType}")
+ }
+ val keyExpression = inputSchema.find(_.name == KafkaWriter.KEY_ATTRIBUTE_NAME)
+ .getOrElse(Literal(null, BinaryType))
+ keyExpression.dataType match {
+ case StringType | BinaryType => // good
+ case t =>
+ throw new IllegalStateException(s"${KafkaWriter.KEY_ATTRIBUTE_NAME} " +
+ s"attribute unsupported type $t")
+ }
+ val valueExpression = inputSchema
+ .find(_.name == KafkaWriter.VALUE_ATTRIBUTE_NAME).getOrElse(
+ throw new IllegalStateException(s"Required attribute " +
+ s"'${KafkaWriter.VALUE_ATTRIBUTE_NAME}' not found")
+ )
+ valueExpression.dataType match {
+ case StringType | BinaryType => // good
+ case t =>
+ throw new IllegalStateException(s"${KafkaWriter.VALUE_ATTRIBUTE_NAME} " +
+ s"attribute unsupported type $t")
+ }
+ UnsafeProjection.create(
+ Seq(topicExpression, Cast(keyExpression, BinaryType),
+ Cast(valueExpression, BinaryType)), inputSchema)
+ }
+
+ private def checkForErrors: Unit = {
+ if (failedWrite != null) {
+ throw failedWrite
+ }
+ }
+}
+
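For illustration, a minimal sketch of the column shape this write task ends up projecting; the source DataFrame and its column names are placeholders, and the key column may be omitted entirely:

{% highlight scala %}
// Hypothetical input shaping: topic must be a string column (or come from the
// "topic" option); key and value may be string or binary and are cast to binary internally.
val shaped = wordCounts.selectExpr(
  "'events' AS topic",
  "CAST(word AS STRING) AS key",
  "CAST(count AS STRING) AS value")
{% endhighlight %}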
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala
new file mode 100644
index 000000000000..a637d52c933a
--- /dev/null
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.kafka010
+
+import java.{util => ju}
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.{AnalysisException, SparkSession}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.execution.{QueryExecution, SQLExecution}
+import org.apache.spark.sql.types.{BinaryType, StringType}
+import org.apache.spark.util.Utils
+
+/**
+ * The [[KafkaWriter]] class is used to write data from a batch query
+ * or structured streaming query, given by a [[QueryExecution]], to Kafka.
+ * The data is assumed to have a value column, and optional topic and key
+ * columns. If the topic column is missing, then the topic must come from
+ * the 'topic' configuration option. If the key column is missing, then a
+ * null valued key field will be added to the
+ * [[org.apache.kafka.clients.producer.ProducerRecord]].
+ */
+private[kafka010] object KafkaWriter extends Logging {
+ val TOPIC_ATTRIBUTE_NAME: String = "topic"
+ val KEY_ATTRIBUTE_NAME: String = "key"
+ val VALUE_ATTRIBUTE_NAME: String = "value"
+
+ override def toString: String = "KafkaWriter"
+
+ def validateQuery(
+ queryExecution: QueryExecution,
+ kafkaParameters: ju.Map[String, Object],
+ topic: Option[String] = None): Unit = {
+ val schema = queryExecution.logical.output
+ schema.find(_.name == TOPIC_ATTRIBUTE_NAME).getOrElse(
+ if (topic == None) {
+ throw new AnalysisException(s"topic option required when no " +
+ s"'$TOPIC_ATTRIBUTE_NAME' attribute is present. Use the " +
+ s"${KafkaSourceProvider.TOPIC_OPTION_KEY} option for setting a topic.")
+ } else {
+ Literal(topic.get, StringType)
+ }
+ ).dataType match {
+ case StringType => // good
+ case _ =>
+ throw new AnalysisException(s"Topic type must be a String")
+ }
+ schema.find(_.name == KEY_ATTRIBUTE_NAME).getOrElse(
+ Literal(null, StringType)
+ ).dataType match {
+ case StringType | BinaryType => // good
+ case _ =>
+ throw new AnalysisException(s"$KEY_ATTRIBUTE_NAME attribute type " +
+ s"must be a String or BinaryType")
+ }
+ schema.find(_.name == VALUE_ATTRIBUTE_NAME).getOrElse(
+ throw new AnalysisException(s"Required attribute '$VALUE_ATTRIBUTE_NAME' not found")
+ ).dataType match {
+ case StringType | BinaryType => // good
+ case _ =>
+ throw new AnalysisException(s"$VALUE_ATTRIBUTE_NAME attribute type " +
+ s"must be a String or BinaryType")
+ }
+ }
+
+ def write(
+ sparkSession: SparkSession,
+ queryExecution: QueryExecution,
+ kafkaParameters: ju.Map[String, Object],
+ topic: Option[String] = None): Unit = {
+ val schema = queryExecution.logical.output
+ validateQuery(queryExecution, kafkaParameters, topic)
+ SQLExecution.withNewExecutionId(sparkSession, queryExecution) {
+ queryExecution.toRdd.foreachPartition { iter =>
+ val writeTask = new KafkaWriteTask(kafkaParameters, schema, topic)
+ Utils.tryWithSafeFinally(block = writeTask.execute(iter))(
+ finallyBlock = writeTask.close())
+ }
+ }
+ }
+}
diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumerSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumerSuite.scala
new file mode 100644
index 000000000000..7aa7dd096c07
--- /dev/null
+++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumerSuite.scala
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.kafka010
+
+import org.scalatest.PrivateMethodTester
+
+import org.apache.spark.sql.test.SharedSQLContext
+
+class CachedKafkaConsumerSuite extends SharedSQLContext with PrivateMethodTester {
+
+ test("SPARK-19886: Report error cause correctly in reportDataLoss") {
+ val cause = new Exception("D'oh!")
+ val reportDataLoss = PrivateMethod[Unit]('reportDataLoss0)
+ val e = intercept[IllegalStateException] {
+ CachedKafkaConsumer.invokePrivate(reportDataLoss(true, "message", cause))
+ }
+ assert(e.getCause === cause)
+ }
+}
diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala
new file mode 100644
index 000000000000..490535623cb3
--- /dev/null
+++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala
@@ -0,0 +1,412 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.kafka010
+
+import java.util.concurrent.atomic.AtomicInteger
+
+import org.apache.kafka.clients.producer.ProducerConfig
+import org.apache.kafka.common.serialization.ByteArraySerializer
+import org.scalatest.time.SpanSugar._
+
+import org.apache.spark.SparkException
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, SpecificInternalRow, UnsafeProjection}
+import org.apache.spark.sql.execution.streaming.MemoryStream
+import org.apache.spark.sql.streaming._
+import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.types.{BinaryType, DataType}
+
+class KafkaSinkSuite extends StreamTest with SharedSQLContext {
+ import testImplicits._
+
+ protected var testUtils: KafkaTestUtils = _
+
+ override val streamingTimeout = 30.seconds
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+ testUtils = new KafkaTestUtils(
+ withBrokerProps = Map("auto.create.topics.enable" -> "false"))
+ testUtils.setup()
+ }
+
+ override def afterAll(): Unit = {
+ if (testUtils != null) {
+ testUtils.teardown()
+ testUtils = null
+ super.afterAll()
+ }
+ }
+
+ test("batch - write to kafka") {
+ val topic = newTopic()
+ testUtils.createTopic(topic)
+ val df = Seq("1", "2", "3", "4", "5").map(v => (topic, v)).toDF("topic", "value")
+ df.write
+ .format("kafka")
+ .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+ .option("topic", topic)
+ .save()
+ checkAnswer(
+ createKafkaReader(topic).selectExpr("CAST(value as STRING) value"),
+ Row("1") :: Row("2") :: Row("3") :: Row("4") :: Row("5") :: Nil)
+ }
+
+ test("batch - null topic field value, and no topic option") {
+ val df = Seq[(String, String)](null.asInstanceOf[String] -> "1").toDF("topic", "value")
+ val ex = intercept[SparkException] {
+ df.write
+ .format("kafka")
+ .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+ .save()
+ }
+ assert(ex.getMessage.toLowerCase.contains(
+ "null topic present in the data"))
+ }
+
+ test("batch - unsupported save modes") {
+ val topic = newTopic()
+ testUtils.createTopic(topic)
+ val df = Seq[(String, String)](null.asInstanceOf[String] -> "1").toDF("topic", "value")
+
+ // Test bad save mode Ignore
+ var ex = intercept[AnalysisException] {
+ df.write
+ .format("kafka")
+ .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+ .mode(SaveMode.Ignore)
+ .save()
+ }
+ assert(ex.getMessage.toLowerCase.contains(
+ s"save mode ignore not allowed for kafka"))
+
+ // Test bad save mode Overwrite
+ ex = intercept[AnalysisException] {
+ df.write
+ .format("kafka")
+ .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+ .mode(SaveMode.Overwrite)
+ .save()
+ }
+ assert(ex.getMessage.toLowerCase.contains(
+ s"save mode overwrite not allowed for kafka"))
+ }
+
+ test("streaming - write to kafka with topic field") {
+ val input = MemoryStream[String]
+ val topic = newTopic()
+ testUtils.createTopic(topic)
+
+ val writer = createKafkaWriter(
+ input.toDF(),
+ withTopic = None,
+ withOutputMode = Some(OutputMode.Append))(
+ withSelectExpr = s"'$topic' as topic", "value")
+
+ val reader = createKafkaReader(topic)
+ .selectExpr("CAST(key as STRING) key", "CAST(value as STRING) value")
+ .selectExpr("CAST(key as INT) key", "CAST(value as INT) value")
+ .as[(Int, Int)]
+ .map(_._2)
+
+ try {
+ input.addData("1", "2", "3", "4", "5")
+ failAfter(streamingTimeout) {
+ writer.processAllAvailable()
+ }
+ checkDatasetUnorderly(reader, 1, 2, 3, 4, 5)
+ input.addData("6", "7", "8", "9", "10")
+ failAfter(streamingTimeout) {
+ writer.processAllAvailable()
+ }
+ checkDatasetUnorderly(reader, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
+ } finally {
+ writer.stop()
+ }
+ }
+
+ test("streaming - write aggregation w/o topic field, with topic option") {
+ val input = MemoryStream[String]
+ val topic = newTopic()
+ testUtils.createTopic(topic)
+
+ val writer = createKafkaWriter(
+ input.toDF().groupBy("value").count(),
+ withTopic = Some(topic),
+ withOutputMode = Some(OutputMode.Update()))(
+ withSelectExpr = "CAST(value as STRING) key", "CAST(count as STRING) value")
+
+ val reader = createKafkaReader(topic)
+ .selectExpr("CAST(key as STRING) key", "CAST(value as STRING) value")
+ .selectExpr("CAST(key as INT) key", "CAST(value as INT) value")
+ .as[(Int, Int)]
+
+ try {
+ input.addData("1", "2", "2", "3", "3", "3")
+ failAfter(streamingTimeout) {
+ writer.processAllAvailable()
+ }
+ checkDatasetUnorderly(reader, (1, 1), (2, 2), (3, 3))
+ input.addData("1", "2", "3")
+ failAfter(streamingTimeout) {
+ writer.processAllAvailable()
+ }
+ checkDatasetUnorderly(reader, (1, 1), (2, 2), (3, 3), (1, 2), (2, 3), (3, 4))
+ } finally {
+ writer.stop()
+ }
+ }
+
+ test("streaming - aggregation with topic field and topic option") {
+ /* The purpose of this test is to ensure that the topic option
+ * overrides the topic field. We begin by writing some data that
+ * includes a topic field (e.g., 'foo') along with a topic option.
+ * When we then read from the topic specified in the option, we should
+ * see that data, i.e., it was written to the topic given by the option
+ * and not to the topic named in the data (e.g., 'foo').
+ */
+ val input = MemoryStream[String]
+ val topic = newTopic()
+ testUtils.createTopic(topic)
+
+ val writer = createKafkaWriter(
+ input.toDF().groupBy("value").count(),
+ withTopic = Some(topic),
+ withOutputMode = Some(OutputMode.Update()))(
+ withSelectExpr = "'foo' as topic",
+ "CAST(value as STRING) key", "CAST(count as STRING) value")
+
+ val reader = createKafkaReader(topic)
+ .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
+ .selectExpr("CAST(key AS INT)", "CAST(value AS INT)")
+ .as[(Int, Int)]
+
+ try {
+ input.addData("1", "2", "2", "3", "3", "3")
+ failAfter(streamingTimeout) {
+ writer.processAllAvailable()
+ }
+ checkDatasetUnorderly(reader, (1, 1), (2, 2), (3, 3))
+ input.addData("1", "2", "3")
+ failAfter(streamingTimeout) {
+ writer.processAllAvailable()
+ }
+ checkDatasetUnorderly(reader, (1, 1), (2, 2), (3, 3), (1, 2), (2, 3), (3, 4))
+ } finally {
+ writer.stop()
+ }
+ }
+
+
+ test("streaming - write data with bad schema") {
+ val input = MemoryStream[String]
+ val topic = newTopic()
+ testUtils.createTopic(topic)
+
+ /* No topic field or topic option */
+ var writer: StreamingQuery = null
+ var ex: Exception = null
+ try {
+ ex = intercept[StreamingQueryException] {
+ writer = createKafkaWriter(input.toDF())(
+ withSelectExpr = "value as key", "value"
+ )
+ input.addData("1", "2", "3", "4", "5")
+ writer.processAllAvailable()
+ }
+ } finally {
+ writer.stop()
+ }
+ assert(ex.getMessage
+ .toLowerCase
+ .contains("topic option required when no 'topic' attribute is present"))
+
+ try {
+ /* No value field */
+ ex = intercept[StreamingQueryException] {
+ writer = createKafkaWriter(input.toDF())(
+ withSelectExpr = s"'$topic' as topic", "value as key"
+ )
+ input.addData("1", "2", "3", "4", "5")
+ writer.processAllAvailable()
+ }
+ } finally {
+ writer.stop()
+ }
+ assert(ex.getMessage.toLowerCase.contains("required attribute 'value' not found"))
+ }
+
+ test("streaming - write data with valid schema but wrong types") {
+ val input = MemoryStream[String]
+ val topic = newTopic()
+ testUtils.createTopic(topic)
+
+ var writer: StreamingQuery = null
+ var ex: Exception = null
+ try {
+ /* topic field wrong type */
+ ex = intercept[StreamingQueryException] {
+ writer = createKafkaWriter(input.toDF())(
+ withSelectExpr = s"CAST('1' as INT) as topic", "value"
+ )
+ input.addData("1", "2", "3", "4", "5")
+ writer.processAllAvailable()
+ }
+ } finally {
+ writer.stop()
+ }
+ assert(ex.getMessage.toLowerCase.contains("topic type must be a string"))
+
+ try {
+ /* value field wrong type */
+ ex = intercept[StreamingQueryException] {
+ writer = createKafkaWriter(input.toDF())(
+ withSelectExpr = s"'$topic' as topic", "CAST(value as INT) as value"
+ )
+ input.addData("1", "2", "3", "4", "5")
+ writer.processAllAvailable()
+ }
+ } finally {
+ writer.stop()
+ }
+ assert(ex.getMessage.toLowerCase.contains(
+ "value attribute type must be a string or binarytype"))
+
+ try {
+ ex = intercept[StreamingQueryException] {
+ /* key field wrong type */
+ writer = createKafkaWriter(input.toDF())(
+ withSelectExpr = s"'$topic' as topic", "CAST(value as INT) as key", "value"
+ )
+ input.addData("1", "2", "3", "4", "5")
+ writer.processAllAvailable()
+ }
+ } finally {
+ writer.stop()
+ }
+ assert(ex.getMessage.toLowerCase.contains(
+ "key attribute type must be a string or binarytype"))
+ }
+
+ test("streaming - write to non-existing topic") {
+ val input = MemoryStream[String]
+ val topic = newTopic()
+
+ var writer: StreamingQuery = null
+ var ex: Exception = null
+ try {
+ ex = intercept[StreamingQueryException] {
+ writer = createKafkaWriter(input.toDF(), withTopic = Some(topic))()
+ input.addData("1", "2", "3", "4", "5")
+ writer.processAllAvailable()
+ }
+ } finally {
+ writer.stop()
+ }
+ assert(ex.getMessage.toLowerCase.contains("job aborted"))
+ }
+
+ test("streaming - exception on config serializer") {
+ val input = MemoryStream[String]
+ var writer: StreamingQuery = null
+ var ex: Exception = null
+ ex = intercept[IllegalArgumentException] {
+ writer = createKafkaWriter(
+ input.toDF(),
+ withOptions = Map("kafka.key.serializer" -> "foo"))()
+ }
+ assert(ex.getMessage.toLowerCase.contains(
+ "kafka option 'key.serializer' is not supported"))
+
+ ex = intercept[IllegalArgumentException] {
+ writer = createKafkaWriter(
+ input.toDF(),
+ withOptions = Map("kafka.value.serializer" -> "foo"))()
+ }
+ assert(ex.getMessage.toLowerCase.contains(
+ "kafka option 'value.serializer' is not supported"))
+ }
+
+ test("generic - write big data with small producer buffer") {
+ /* This test ensures that we understand the semantics of Kafka when
+ * it comes to blocking on a call to send when the send buffer is full.
+ * This test will configure the smallest possible producer buffer and
+ * indicate that we should block when it is full. Thus, no exception should
+ * be thrown in the case of a full buffer.
+ */
+ val topic = newTopic()
+ testUtils.createTopic(topic, 1)
+ val options = new java.util.HashMap[String, Object]
+ options.put("bootstrap.servers", testUtils.brokerAddress)
+ options.put("buffer.memory", "16384") // min buffer size
+ options.put("block.on.buffer.full", "true")
+ options.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, classOf[ByteArraySerializer].getName)
+ options.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, classOf[ByteArraySerializer].getName)
+ val inputSchema = Seq(AttributeReference("value", BinaryType)())
+ val data = new Array[Byte](15000) // large value
+ val writeTask = new KafkaWriteTask(options, inputSchema, Some(topic))
+ try {
+ val fieldTypes: Array[DataType] = Array(BinaryType)
+ val converter = UnsafeProjection.create(fieldTypes)
+ val row = new SpecificInternalRow(fieldTypes)
+ row.update(0, data)
+ val iter = Seq.fill(1000)(converter.apply(row)).iterator
+ writeTask.execute(iter)
+ } finally {
+ writeTask.close()
+ }
+ }
+
+ private val topicId = new AtomicInteger(0)
+
+ private def newTopic(): String = s"topic-${topicId.getAndIncrement()}"
+
+ private def createKafkaReader(topic: String): DataFrame = {
+ spark.read
+ .format("kafka")
+ .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+ .option("startingOffsets", "earliest")
+ .option("endingOffsets", "latest")
+ .option("subscribe", topic)
+ .load()
+ }
+
+ private def createKafkaWriter(
+ input: DataFrame,
+ withTopic: Option[String] = None,
+ withOutputMode: Option[OutputMode] = None,
+ withOptions: Map[String, String] = Map[String, String]())
+ (withSelectExpr: String*): StreamingQuery = {
+ var stream: DataStreamWriter[Row] = null
+ withTempDir { checkpointDir =>
+ var df = input.toDF()
+ if (withSelectExpr.length > 0) {
+ df = df.selectExpr(withSelectExpr: _*)
+ }
+ stream = df.writeStream
+ .format("kafka")
+ .option("checkpointLocation", checkpointDir.getCanonicalPath)
+ .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+ .queryName("kafkaStream")
+ withTopic.foreach(stream.option("topic", _))
+ withOutputMode.foreach(stream.outputMode(_))
+ withOptions.foreach(opt => stream.option(opt._1, opt._2))
+ }
+ stream.start()
+ }
+}
diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala
index 23c4d99e50f5..0f1790bddcc3 100644
--- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala
+++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala
@@ -36,7 +36,11 @@ import org.apache.spark.util.NextIterator
/** Class representing a range of Kinesis sequence numbers. Both sequence numbers are inclusive. */
private[kinesis]
case class SequenceNumberRange(
- streamName: String, shardId: String, fromSeqNumber: String, toSeqNumber: String)
+ streamName: String,
+ shardId: String,
+ fromSeqNumber: String,
+ toSeqNumber: String,
+ recordCount: Int)
/** Class representing an array of Kinesis sequence number ranges */
private[kinesis]
@@ -136,6 +140,8 @@ class KinesisSequenceRangeIterator(
private val client = new AmazonKinesisClient(credentials)
private val streamName = range.streamName
private val shardId = range.shardId
+ // AWS limits each get call to a maximum of 10k records
+ private val maxGetRecordsLimit = 10000
private var toSeqNumberReceived = false
private var lastSeqNumber: String = null
@@ -153,12 +159,14 @@ class KinesisSequenceRangeIterator(
// If the internal iterator has not been initialized,
// then fetch records from starting sequence number
- internalIterator = getRecords(ShardIteratorType.AT_SEQUENCE_NUMBER, range.fromSeqNumber)
+ internalIterator = getRecords(ShardIteratorType.AT_SEQUENCE_NUMBER, range.fromSeqNumber,
+ range.recordCount)
} else if (!internalIterator.hasNext) {
// If the internal iterator does not have any more records,
// then fetch more records after the last consumed sequence number
- internalIterator = getRecords(ShardIteratorType.AFTER_SEQUENCE_NUMBER, lastSeqNumber)
+ internalIterator = getRecords(ShardIteratorType.AFTER_SEQUENCE_NUMBER, lastSeqNumber,
+ range.recordCount)
}
if (!internalIterator.hasNext) {
@@ -191,9 +199,12 @@ class KinesisSequenceRangeIterator(
/**
* Get records starting from or after the given sequence number.
*/
- private def getRecords(iteratorType: ShardIteratorType, seqNum: String): Iterator[Record] = {
+ private def getRecords(
+ iteratorType: ShardIteratorType,
+ seqNum: String,
+ recordCount: Int): Iterator[Record] = {
val shardIterator = getKinesisIterator(iteratorType, seqNum)
- val result = getRecordsAndNextKinesisIterator(shardIterator)
+ val result = getRecordsAndNextKinesisIterator(shardIterator, recordCount)
result._1
}
@@ -202,10 +213,12 @@ class KinesisSequenceRangeIterator(
* to get records from Kinesis), and get the next shard iterator for next consumption.
*/
private def getRecordsAndNextKinesisIterator(
- shardIterator: String): (Iterator[Record], String) = {
+ shardIterator: String,
+ recordCount: Int): (Iterator[Record], String) = {
val getRecordsRequest = new GetRecordsRequest
getRecordsRequest.setRequestCredentials(credentials)
getRecordsRequest.setShardIterator(shardIterator)
+ getRecordsRequest.setLimit(Math.min(recordCount, this.maxGetRecordsLimit))
val getRecordsResult = retryOrTimeout[GetRecordsResult](
s"getting records using shard iterator") {
client.getRecords(getRecordsRequest)
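For illustration, a minimal sketch of how the new `recordCount` field interacts with that cap (the class is package-private, so this would live in the kinesis package); the stream, shard, and sequence numbers are placeholders:

{% highlight scala %}
// Hypothetical illustration: each range now carries how many records it spans,
// and every GetRecords request is limited to at most 10,000 records.
val range = SequenceNumberRange("myStream", "shardId-000000000000", "seqStart", "seqEnd", 500)
val requestLimit = math.min(range.recordCount, 10000)  // value passed to setLimit()
{% endhighlight %}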
diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala
index 13fc54e531dd..320728f4bb22 100644
--- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala
+++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala
@@ -210,7 +210,8 @@ private[kinesis] class KinesisReceiver[T](
if (records.size > 0) {
val dataIterator = records.iterator().asScala.map(messageHandler)
val metadata = SequenceNumberRange(streamName, shardId,
- records.get(0).getSequenceNumber(), records.get(records.size() - 1).getSequenceNumber())
+ records.get(0).getSequenceNumber(), records.get(records.size() - 1).getSequenceNumber(),
+ records.size())
blockGenerator.addMultipleDataWithCallback(dataIterator, metadata)
}
}
diff --git a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala
index 18a5a1509a33..2c7b9c58e6fa 100644
--- a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala
+++ b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala
@@ -51,7 +51,7 @@ abstract class KinesisBackedBlockRDDTests(aggregateTestData: Boolean)
shardIdToSeqNumbers = shardIdToDataAndSeqNumbers.mapValues { _.map { _._2 }}
shardIdToRange = shardIdToSeqNumbers.map { case (shardId, seqNumbers) =>
val seqNumRange = SequenceNumberRange(
- testUtils.streamName, shardId, seqNumbers.head, seqNumbers.last)
+ testUtils.streamName, shardId, seqNumbers.head, seqNumbers.last, seqNumbers.size)
(shardId, seqNumRange)
}
allRanges = shardIdToRange.values.toSeq
@@ -181,7 +181,7 @@ abstract class KinesisBackedBlockRDDTests(aggregateTestData: Boolean)
// Create the necessary ranges to use in the RDD
val fakeRanges = Array.fill(numPartitions - numPartitionsInKinesis)(
- SequenceNumberRanges(SequenceNumberRange("fakeStream", "fakeShardId", "xxx", "yyy")))
+ SequenceNumberRanges(SequenceNumberRange("fakeStream", "fakeShardId", "xxx", "yyy", 1)))
val realRanges = Array.tabulate(numPartitionsInKinesis) { i =>
val range = shardIdToRange(shardIds(i + (numPartitions - numPartitionsInKinesis)))
SequenceNumberRanges(Array(range))
diff --git a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala
index 387a96f26b30..afb55c84f81f 100644
--- a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala
+++ b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala
@@ -119,13 +119,13 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFun
// Generate block info data for testing
val seqNumRanges1 = SequenceNumberRanges(
- SequenceNumberRange("fakeStream", "fakeShardId", "xxx", "yyy"))
+ SequenceNumberRange("fakeStream", "fakeShardId", "xxx", "yyy", 67))
val blockId1 = StreamBlockId(kinesisStream.id, 123)
val blockInfo1 = ReceivedBlockInfo(
0, None, Some(seqNumRanges1), new BlockManagerBasedStoreResult(blockId1, None))
val seqNumRanges2 = SequenceNumberRanges(
- SequenceNumberRange("fakeStream", "fakeShardId", "aaa", "bbb"))
+ SequenceNumberRange("fakeStream", "fakeShardId", "aaa", "bbb", 89))
val blockId2 = StreamBlockId(kinesisStream.id, 345)
val blockInfo2 = ReceivedBlockInfo(
0, None, Some(seqNumRanges2), new BlockManagerBasedStoreResult(blockId2, None))
diff --git a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java
index bc8d6037a367..6c0c3ebcaebf 100644
--- a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java
+++ b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java
@@ -135,7 +135,6 @@ List buildClassPath(String appClassPath) throws IOException {
String sparkHome = getSparkHome();
Set cp = new LinkedHashSet<>();
- addToClassPath(cp, getenv("SPARK_CLASSPATH"));
addToClassPath(cp, appClassPath);
addToClassPath(cp, getConfDir());
diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java
index 81786841de22..7cf5b7379503 100644
--- a/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java
+++ b/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java
@@ -66,7 +66,6 @@ public List buildCommand(Map env)
memKey = "SPARK_DAEMON_MEMORY";
break;
case "org.apache.spark.executor.CoarseGrainedExecutorBackend":
- javaOptsKeys.add("SPARK_JAVA_OPTS");
javaOptsKeys.add("SPARK_EXECUTOR_OPTS");
memKey = "SPARK_EXECUTOR_MEMORY";
break;
@@ -84,7 +83,6 @@ public List buildCommand(Map env)
memKey = "SPARK_DAEMON_MEMORY";
break;
default:
- javaOptsKeys.add("SPARK_JAVA_OPTS");
memKey = "SPARK_DRIVER_MEMORY";
break;
}
diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java
index 5e64fa7ed152..5f2da036ff9f 100644
--- a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java
+++ b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java
@@ -240,7 +240,6 @@ private List buildSparkSubmitCommand(Map env)
addOptionString(cmd, System.getenv("SPARK_DAEMON_JAVA_OPTS"));
}
addOptionString(cmd, System.getenv("SPARK_SUBMIT_OPTS"));
- addOptionString(cmd, System.getenv("SPARK_JAVA_OPTS"));
// We don't want the client to specify Xmx. These have to be set by their corresponding
// memory flag --driver-memory or configuration entry spark.driver.memory
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index a503411b6361..810b02febbe7 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -17,6 +17,8 @@
package org.apache.spark.ml.feature
+import scala.language.existentials
+
import org.apache.hadoop.fs.Path
import org.apache.spark.SparkException
@@ -24,7 +26,7 @@ import org.apache.spark.annotation.Since
import org.apache.spark.ml.{Estimator, Model, Transformer}
import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
import org.apache.spark.ml.param._
-import org.apache.spark.ml.param.shared._
+import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.util._
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions._
@@ -34,8 +36,27 @@ import org.apache.spark.util.collection.OpenHashMap
/**
* Base trait for [[StringIndexer]] and [[StringIndexerModel]].
*/
-private[feature] trait StringIndexerBase extends Params with HasInputCol with HasOutputCol
- with HasHandleInvalid {
+private[feature] trait StringIndexerBase extends Params with HasInputCol with HasOutputCol {
+
+ /**
+ * Param for how to handle unseen labels. Options are 'skip' (filter out rows with
+ * unseen labels), 'error' (throw an error), or 'keep' (put unseen labels in a special additional
+ * bucket, at index numLabels).
+ * Default: "error"
+ * @group param
+ */
+ @Since("1.6.0")
+ val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle " +
+ "unseen labels. Options are 'skip' (filter out rows with unseen labels), " +
+ "error (throw an error), or 'keep' (put unseen labels in a special additional bucket, " +
+ "at index numLabels).",
+ ParamValidators.inArray(StringIndexer.supportedHandleInvalids))
+
+ setDefault(handleInvalid, StringIndexer.ERROR_UNSEEN_LABEL)
+
+ /** @group getParam */
+ @Since("1.6.0")
+ def getHandleInvalid: String = $(handleInvalid)
/** Validates and transforms the input schema. */
protected def validateAndTransformSchema(schema: StructType): StructType = {
@@ -73,7 +94,6 @@ class StringIndexer @Since("1.4.0") (
/** @group setParam */
@Since("1.6.0")
def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
- setDefault(handleInvalid, "error")
/** @group setParam */
@Since("1.4.0")
@@ -105,6 +125,11 @@ class StringIndexer @Since("1.4.0") (
@Since("1.6.0")
object StringIndexer extends DefaultParamsReadable[StringIndexer] {
+ private[feature] val SKIP_UNSEEN_LABEL: String = "skip"
+ private[feature] val ERROR_UNSEEN_LABEL: String = "error"
+ private[feature] val KEEP_UNSEEN_LABEL: String = "keep"
+ private[feature] val supportedHandleInvalids: Array[String] =
+ Array(SKIP_UNSEEN_LABEL, ERROR_UNSEEN_LABEL, KEEP_UNSEEN_LABEL)
@Since("1.6.0")
override def load(path: String): StringIndexer = super.load(path)
@@ -144,7 +169,6 @@ class StringIndexerModel (
/** @group setParam */
@Since("1.6.0")
def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
- setDefault(handleInvalid, "error")
/** @group setParam */
@Since("1.4.0")
@@ -163,25 +187,34 @@ class StringIndexerModel (
}
transformSchema(dataset.schema, logging = true)
- val indexer = udf { label: String =>
- if (labelToIndex.contains(label)) {
- labelToIndex(label)
- } else {
- throw new SparkException(s"Unseen label: $label.")
- }
+ val filteredLabels = getHandleInvalid match {
+ case StringIndexer.KEEP_UNSEEN_LABEL => labels :+ "__unknown"
+ case _ => labels
}
val metadata = NominalAttribute.defaultAttr
- .withName($(outputCol)).withValues(labels).toMetadata()
+ .withName($(outputCol)).withValues(filteredLabels).toMetadata()
// If we are skipping invalid records, filter them out.
- val filteredDataset = getHandleInvalid match {
- case "skip" =>
+ val (filteredDataset, keepInvalid) = getHandleInvalid match {
+ case StringIndexer.SKIP_UNSEEN_LABEL =>
val filterer = udf { label: String =>
labelToIndex.contains(label)
}
- dataset.where(filterer(dataset($(inputCol))))
- case _ => dataset
+ (dataset.where(filterer(dataset($(inputCol)))), false)
+ case _ => (dataset, getHandleInvalid == StringIndexer.KEEP_UNSEEN_LABEL)
}
+
+ val indexer = udf { label: String =>
+ if (labelToIndex.contains(label)) {
+ labelToIndex(label)
+ } else if (keepInvalid) {
+ labels.length
+ } else {
+ throw new SparkException(s"Unseen label: $label. To handle unseen labels, " +
+ s"set Param handleInvalid to ${StringIndexer.KEEP_UNSEEN_LABEL}.")
+ }
+ }
+
filteredDataset.select(col("*"),
indexer(dataset($(inputCol)).cast(StringType)).as($(outputCol), metadata))
}
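For illustration, a minimal sketch of the new 'keep' policy; the column names and the training/test DataFrames are placeholders:

{% highlight scala %}
import org.apache.spark.ml.feature.StringIndexer

// Hypothetical usage: labels unseen during fit() map to index numLabels instead of failing.
val indexer = new StringIndexer()
  .setInputCol("category")
  .setOutputCol("categoryIndex")
  .setHandleInvalid("keep")
val model = indexer.fit(trainingDF)    // learns the label -> index mapping
val indexed = model.transform(testDF)  // unseen labels get the extra "__unknown" bucket
{% endhighlight %}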
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
index 42e8a66a62b6..4ca062c0b5ad 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
@@ -227,25 +227,50 @@ class Word2VecModel private[ml] (
/**
* Find "num" number of words closest in similarity to the given word, not
- * including the word itself. Returns a dataframe with the words and the
- * cosine similarities between the synonyms and the given word.
+ * including the word itself.
+ * @return a dataframe with columns "word" and "similarity" of the word and the cosine
+ * similarities between the synonyms and the given word vector.
*/
@Since("1.5.0")
def findSynonyms(word: String, num: Int): DataFrame = {
val spark = SparkSession.builder().getOrCreate()
- spark.createDataFrame(wordVectors.findSynonyms(word, num)).toDF("word", "similarity")
+ spark.createDataFrame(findSynonymsArray(word, num)).toDF("word", "similarity")
}
/**
- * Find "num" number of words whose vector representation most similar to the supplied vector.
+ * Find "num" number of words whose vector representation is most similar to the supplied vector.
* If the supplied vector is the vector representation of a word in the model's vocabulary,
- * that word will be in the results. Returns a dataframe with the words and the cosine
+ * that word will be in the results.
+ * @return a dataframe with columns "word" and "similarity" of the word and the cosine
* similarities between the synonyms and the given word vector.
*/
@Since("2.0.0")
def findSynonyms(vec: Vector, num: Int): DataFrame = {
val spark = SparkSession.builder().getOrCreate()
- spark.createDataFrame(wordVectors.findSynonyms(vec, num)).toDF("word", "similarity")
+ spark.createDataFrame(findSynonymsArray(vec, num)).toDF("word", "similarity")
+ }
+
+ /**
+ * Find "num" number of words whose vector representation is most similar to the supplied vector.
+ * If the supplied vector is the vector representation of a word in the model's vocabulary,
+ * that word will be in the results.
+ * @return an array of the words and the cosine similarities between the synonyms and
+ * the given word vector.
+ */
+ @Since("2.2.0")
+ def findSynonymsArray(vec: Vector, num: Int): Array[(String, Double)] = {
+ wordVectors.findSynonyms(vec, num)
+ }
+
+ /**
+ * Find "num" number of words closest in similarity to the given word, not
+ * including the word itself.
+ * @return an array of the words and the cosine similarities between the synonyms and
+ * the given word.
+ */
+ @Since("2.2.0")
+ def findSynonymsArray(word: String, num: Int): Array[(String, Double)] = {
+ wordVectors.findSynonyms(word, num)
}
/** @group setParam */
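The `findSynonymsArray` overloads added above return a local `Array[(String, Double)]`, so callers that only need the raw values avoid the DataFrame round-trip. A hedged sketch of both call styles, assuming `model` is an already fitted `Word2VecModel`:

```scala
import org.apache.spark.ml.feature.Word2VecModel
import org.apache.spark.ml.linalg.Vectors

def showSynonyms(model: Word2VecModel): Unit = {
  // Existing DataFrame-based API, unchanged.
  model.findSynonyms("spark", 3).show()

  // New array-based API: results come back as a local collection.
  val byWord: Array[(String, Double)] = model.findSynonymsArray("spark", 3)
  // The query vector's dimension must match the model's vectorSize (3 here is illustrative).
  val byVector: Array[(String, Double)] =
    model.findSynonymsArray(Vectors.dense(0.1, 0.2, 0.3), 3)

  byWord.foreach { case (w, sim) => println(s"$w -> $sim") }
  byVector.foreach { case (w, sim) => println(s"$w -> $sim") }
}
```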
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index 799e881fad74..60dd7367053e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -40,7 +40,8 @@ import org.apache.spark.ml.util._
import org.apache.spark.mllib.linalg.CholeskyDecomposition
import org.apache.spark.mllib.optimization.NNLS
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
+import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.storage.StorageLevel
@@ -284,18 +285,20 @@ class ALSModel private[ml] (
@Since("2.2.0")
def setColdStartStrategy(value: String): this.type = set(coldStartStrategy, value)
+ private val predict = udf { (featuresA: Seq[Float], featuresB: Seq[Float]) =>
+ if (featuresA != null && featuresB != null) {
+ // TODO(SPARK-19759): try dot-producting on Seqs or another non-converted type for
+ // potential optimization.
+ blas.sdot(rank, featuresA.toArray, 1, featuresB.toArray, 1)
+ } else {
+ Float.NaN
+ }
+ }
+
@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema)
- // Register a UDF for DataFrame, and then
// create a new column named map(predictionCol) by running the predict UDF.
- val predict = udf { (userFeatures: Seq[Float], itemFeatures: Seq[Float]) =>
- if (userFeatures != null && itemFeatures != null) {
- blas.sdot(rank, userFeatures.toArray, 1, itemFeatures.toArray, 1)
- } else {
- Float.NaN
- }
- }
val predictions = dataset
.join(userFactors,
checkedCast(dataset($(userCol))) === userFactors("id"), "left")
@@ -327,6 +330,64 @@ class ALSModel private[ml] (
@Since("1.6.0")
override def write: MLWriter = new ALSModel.ALSModelWriter(this)
+
+ /**
+ * Returns top `numItems` items recommended for each user, for all users.
+ * @param numItems max number of recommendations for each user
+ * @return a DataFrame of (userCol: Int, recommendations), where recommendations are
+ * stored as an array of (itemCol: Int, rating: Float) Rows.
+ */
+ @Since("2.2.0")
+ def recommendForAllUsers(numItems: Int): DataFrame = {
+ recommendForAll(userFactors, itemFactors, $(userCol), $(itemCol), numItems)
+ }
+
+ /**
+ * Returns top `numUsers` users recommended for each item, for all items.
+ * @param numUsers max number of recommendations for each item
+ * @return a DataFrame of (itemCol: Int, recommendations), where recommendations are
+ * stored as an array of (userCol: Int, rating: Float) Rows.
+ */
+ @Since("2.2.0")
+ def recommendForAllItems(numUsers: Int): DataFrame = {
+ recommendForAll(itemFactors, userFactors, $(itemCol), $(userCol), numUsers)
+ }
+
+ /**
+ * Makes recommendations for all users (or items).
+ * @param srcFactors src factors for which to generate recommendations
+ * @param dstFactors dst factors used to make recommendations
+ * @param srcOutputColumn name of the column for the source ID in the output DataFrame
+ * @param dstOutputColumn name of the column for the destination ID in the output DataFrame
+ * @param num max number of recommendations for each record
+ * @return a DataFrame of (srcOutputColumn: Int, recommendations), where recommendations are
+ * stored as an array of (dstOutputColumn: Int, rating: Float) Rows.
+ */
+ private def recommendForAll(
+ srcFactors: DataFrame,
+ dstFactors: DataFrame,
+ srcOutputColumn: String,
+ dstOutputColumn: String,
+ num: Int): DataFrame = {
+ import srcFactors.sparkSession.implicits._
+
+ val ratings = srcFactors.crossJoin(dstFactors)
+ .select(
+ srcFactors("id"),
+ dstFactors("id"),
+ predict(srcFactors("features"), dstFactors("features")))
+ // We'll force the IDs to be Int, which unfortunately means the IDs in the output are Int too.
+ val topKAggregator = new TopByKeyAggregator[Int, Int, Float](num, Ordering.by(_._2))
+ val recs = ratings.as[(Int, Int, Float)].groupByKey(_._1).agg(topKAggregator.toColumn)
+ .toDF("id", "recommendations")
+
+ val arrayType = ArrayType(
+ new StructType()
+ .add(dstOutputColumn, IntegerType)
+ .add("rating", FloatType)
+ )
+ recs.select($"id" as srcOutputColumn, $"recommendations" cast arrayType)
+ }
}
@Since("1.6.0")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/TopByKeyAggregator.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/TopByKeyAggregator.scala
new file mode 100644
index 000000000000..517179c0eb9a
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/TopByKeyAggregator.scala
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.recommendation
+
+import scala.language.implicitConversions
+import scala.reflect.runtime.universe.TypeTag
+
+import org.apache.spark.sql.{Encoder, Encoders}
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.expressions.Aggregator
+import org.apache.spark.util.BoundedPriorityQueue
+
+
+/**
+ * Works on rows of the form (K1, K2, V) where K1 & K2 are IDs and V is the score value. Finds
+ * the top `num` K2 items based on the given Ordering.
+ */
+private[recommendation] class TopByKeyAggregator[K1: TypeTag, K2: TypeTag, V: TypeTag]
+ (num: Int, ord: Ordering[(K2, V)])
+ extends Aggregator[(K1, K2, V), BoundedPriorityQueue[(K2, V)], Array[(K2, V)]] {
+
+ override def zero: BoundedPriorityQueue[(K2, V)] = new BoundedPriorityQueue[(K2, V)](num)(ord)
+
+ override def reduce(
+ q: BoundedPriorityQueue[(K2, V)],
+ a: (K1, K2, V)): BoundedPriorityQueue[(K2, V)] = {
+ q += {(a._2, a._3)}
+ }
+
+ override def merge(
+ q1: BoundedPriorityQueue[(K2, V)],
+ q2: BoundedPriorityQueue[(K2, V)]): BoundedPriorityQueue[(K2, V)] = {
+ q1 ++= q2
+ }
+
+ override def finish(r: BoundedPriorityQueue[(K2, V)]): Array[(K2, V)] = {
+ r.toArray.sorted(ord.reverse)
+ }
+
+ override def bufferEncoder: Encoder[BoundedPriorityQueue[(K2, V)]] = {
+ Encoders.kryo[BoundedPriorityQueue[(K2, V)]]
+ }
+
+ override def outputEncoder: Encoder[Array[(K2, V)]] = ExpressionEncoder[Array[(K2, V)]]()
+}
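The aggregator is `private[recommendation]`, so it is only usable from that package; a minimal sketch of how it is driven over (srcId, dstId, rating) triples, mirroring the suite added further below (the data is made up):

```scala
// Assumed to be compiled inside org.apache.spark.ml.recommendation so the
// private[recommendation] class is visible; `spark` is an existing SparkSession.
import spark.implicits._

val ratings = Seq(
  (0, 3, 54f), (0, 4, 44f), (0, 5, 42f),
  (1, 3, 39f), (1, 5, 33f)
).toDS()

// Keep the top 2 dstIds per srcId, ordered by the Float score.
val top2 = new TopByKeyAggregator[Int, Int, Float](2, Ordering.by(_._2))
val result = ratings.groupByKey(_._1).agg(top2.toColumn)

// result: Dataset[(Int, Array[(Int, Float)])]
result.collect().foreach { case (id, recs) => println(s"$id -> ${recs.mkString(", ")}") }
```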
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
index 110764dc074f..3be8b533ee3f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
@@ -66,7 +66,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
/**
* Param for the power in the variance function of the Tweedie distribution which provides
* the relationship between the variance and mean of the distribution.
- * Only applicable for the Tweedie family.
+ * Only applicable to the Tweedie family.
* (see
* Tweedie Distribution (Wikipedia))
* Supported values: 0 and [1, Inf).
@@ -79,7 +79,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
final val variancePower: DoubleParam = new DoubleParam(this, "variancePower",
"The power in the variance function of the Tweedie distribution which characterizes " +
"the relationship between the variance and mean of the distribution. " +
- "Only applicable for the Tweedie family. Supported values: 0 and [1, Inf).",
+ "Only applicable to the Tweedie family. Supported values: 0 and [1, Inf).",
(x: Double) => x >= 1.0 || x == 0.0)
/** @group getParam */
@@ -106,7 +106,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
def getLink: String = $(link)
/**
- * Param for the index in the power link function. Only applicable for the Tweedie family.
+ * Param for the index in the power link function. Only applicable to the Tweedie family.
* Note that link power 0, 1, -1 or 0.5 corresponds to the Log, Identity, Inverse or Sqrt
* link, respectively.
* When not set, this value defaults to 1 - [[variancePower]], which matches the R "statmod"
@@ -116,7 +116,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
*/
@Since("2.2.0")
final val linkPower: DoubleParam = new DoubleParam(this, "linkPower",
- "The index in the power link function. Only applicable for the Tweedie family.")
+ "The index in the power link function. Only applicable to the Tweedie family.")
/** @group getParam */
@Since("2.2.0")
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
index 2364d43aaa0e..531c8b07910f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -30,6 +30,7 @@ import org.json4s.jackson.JsonMethods._
import org.apache.spark.SparkContext
import org.apache.spark.annotation.Since
import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.broadcast.Broadcast
import org.apache.spark.internal.Logging
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.{Loader, Saveable}
@@ -314,6 +315,20 @@ class Word2Vec extends Serializable with Logging {
val expTable = sc.broadcast(createExpTable())
val bcVocab = sc.broadcast(vocab)
val bcVocabHash = sc.broadcast(vocabHash)
+ try {
+ doFit(dataset, sc, expTable, bcVocab, bcVocabHash)
+ } finally {
+ expTable.destroy(blocking = false)
+ bcVocab.destroy(blocking = false)
+ bcVocabHash.destroy(blocking = false)
+ }
+ }
+
+ private def doFit[S <: Iterable[String]](
+ dataset: RDD[S], sc: SparkContext,
+ expTable: Broadcast[Array[Float]],
+ bcVocab: Broadcast[Array[VocabWord]],
+ bcVocabHash: Broadcast[mutable.HashMap[String, Int]]) = {
// each partition is a collection of sentences,
// will be translated into arrays of Index integer
val sentences: RDD[Array[Int]] = dataset.mapPartitions { sentenceIter =>
@@ -435,9 +450,6 @@ class Word2Vec extends Serializable with Logging {
bcSyn1Global.destroy(false)
}
newSentences.unpersist()
- expTable.destroy(false)
- bcVocab.destroy(false)
- bcVocabHash.destroy(false)
val wordArray = vocab.map(_.word)
new Word2VecModel(wordArray.zipWithIndex.toMap, syn0Global)
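The refactor above exists so that the broadcast variables are destroyed even when training fails part-way through. A generic sketch of the same try/finally pattern, with illustrative names not taken from the patch:

```scala
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD

// Broadcast a lookup table, run the job, and always release the broadcast afterwards.
def countKnownWords(sc: SparkContext, words: RDD[String], table: Map[String, Int]): Long = {
  val bcTable = sc.broadcast(table)
  try {
    words.filter(word => bcTable.value.contains(word)).count()
  } finally {
    // Same cleanup the patch moves into a finally block; non-blocking destroy.
    bcTable.destroy(blocking = false)
  }
}
```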
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
index c711e7fa9dc6..10de50306a5c 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
@@ -372,16 +372,18 @@ class DecisionTreeClassifierSuite
// Categorical splits with tree depth 2
val categoricalData: DataFrame =
TreeTests.setMetadata(rdd, Map(0 -> 2, 1 -> 3), numClasses = 2)
- testEstimatorAndModelReadWrite(dt, categoricalData, allParamSettings, checkModelData)
+ testEstimatorAndModelReadWrite(dt, categoricalData, allParamSettings,
+ allParamSettings, checkModelData)
// Continuous splits with tree depth 2
val continuousData: DataFrame =
TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 2)
- testEstimatorAndModelReadWrite(dt, continuousData, allParamSettings, checkModelData)
+ testEstimatorAndModelReadWrite(dt, continuousData, allParamSettings,
+ allParamSettings, checkModelData)
// Continuous splits with tree depth 0
testEstimatorAndModelReadWrite(dt, continuousData, allParamSettings ++ Map("maxDepth" -> 0),
- checkModelData)
+ allParamSettings ++ Map("maxDepth" -> 0), checkModelData)
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
index 0598943c3d4b..0cddb37281b3 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
@@ -374,7 +374,8 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext
val continuousData: DataFrame =
TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 2)
- testEstimatorAndModelReadWrite(gbt, continuousData, allParamSettings, checkModelData)
+ testEstimatorAndModelReadWrite(gbt, continuousData, allParamSettings,
+ allParamSettings, checkModelData)
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala
index a165d8a9345c..4c63a2a88c6c 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala
@@ -24,12 +24,13 @@ import breeze.linalg.{DenseVector => BDV}
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.classification.LinearSVCSuite._
import org.apache.spark.ml.feature.{Instance, LabeledPoint}
-import org.apache.spark.ml.linalg.{Vector, Vectors}
+import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{Dataset, Row}
+import org.apache.spark.sql.functions.udf
class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
@@ -41,6 +42,9 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
@transient var smallValidationDataset: Dataset[_] = _
@transient var binaryDataset: Dataset[_] = _
+ @transient var smallSparseBinaryDataset: Dataset[_] = _
+ @transient var smallSparseValidationDataset: Dataset[_] = _
+
override def beforeAll(): Unit = {
super.beforeAll()
@@ -51,6 +55,13 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
smallBinaryDataset = generateSVMInput(A, Array[Double](B, C), nPoints, 42).toDF()
smallValidationDataset = generateSVMInput(A, Array[Double](B, C), nPoints, 17).toDF()
binaryDataset = generateSVMInput(1.0, Array[Double](1.0, 2.0, 3.0, 4.0), 10000, 42).toDF()
+
+ // Dataset for testing SparseVector
+ val toSparse: Vector => SparseVector = _.asInstanceOf[DenseVector].toSparse
+ val sparse = udf(toSparse)
+ smallSparseBinaryDataset = smallBinaryDataset.withColumn("features", sparse('features))
+ smallSparseValidationDataset = smallValidationDataset.withColumn("features", sparse('features))
+
}
/**
@@ -68,6 +79,8 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
val model = svm.fit(smallBinaryDataset)
assert(model.transform(smallValidationDataset)
.where("prediction=label").count() > nPoints * 0.8)
+ val sparseModel = svm.fit(smallSparseBinaryDataset)
+ checkModels(model, sparseModel)
}
test("Linear SVC binary classification with regularization") {
@@ -75,6 +88,8 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
val model = svm.setRegParam(0.1).fit(smallBinaryDataset)
assert(model.transform(smallValidationDataset)
.where("prediction=label").count() > nPoints * 0.8)
+ val sparseModel = svm.fit(smallSparseBinaryDataset)
+ checkModels(model, sparseModel)
}
test("params") {
@@ -217,7 +232,7 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
}
val svm = new LinearSVC()
testEstimatorAndModelReadWrite(svm, smallBinaryDataset, LinearSVCSuite.allParamSettings,
- checkModelData)
+ LinearSVCSuite.allParamSettings, checkModelData)
}
}
@@ -235,7 +250,7 @@ object LinearSVCSuite {
"aggregationDepth" -> 3
)
- // Generate noisy input of the form Y = signum(x.dot(weights) + intercept + noise)
+ // Generate noisy input of the form Y = signum(x.dot(weights) + intercept + noise)
def generateSVMInput(
intercept: Double,
weights: Array[Double],
@@ -252,5 +267,10 @@ object LinearSVCSuite {
y.zip(x).map(p => LabeledPoint(p._1, Vectors.dense(p._2)))
}
+ def checkModels(model1: LinearSVCModel, model2: LinearSVCModel): Unit = {
+ assert(model1.intercept == model2.intercept)
+ assert(model1.coefficients.equals(model2.coefficients))
+ }
+
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index d89a958eed45..affaa573749e 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -2089,7 +2089,7 @@ class LogisticRegressionSuite
}
val lr = new LogisticRegression()
testEstimatorAndModelReadWrite(lr, smallBinaryDataset, LogisticRegressionSuite.allParamSettings,
- checkModelData)
+ LogisticRegressionSuite.allParamSettings, checkModelData)
}
test("should support all NumericType labels and weights, and not support other types") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
index 37d7991fe8dd..4d5d299d1408 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
@@ -280,7 +280,8 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
assert(model.theta === model2.theta)
}
val nb = new NaiveBayes()
- testEstimatorAndModelReadWrite(nb, dataset, NaiveBayesSuite.allParamSettings, checkModelData)
+ testEstimatorAndModelReadWrite(nb, dataset, NaiveBayesSuite.allParamSettings,
+ NaiveBayesSuite.allParamSettings, checkModelData)
}
test("should support all NumericType labels and weights, and not support other types") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
index 44e1585ee514..c3003cec73b4 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
@@ -218,7 +218,8 @@ class RandomForestClassifierSuite
val continuousData: DataFrame =
TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 2)
- testEstimatorAndModelReadWrite(rf, continuousData, allParamSettings, checkModelData)
+ testEstimatorAndModelReadWrite(rf, continuousData, allParamSettings,
+ allParamSettings, checkModelData)
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
index 30513c1e276a..200a892f6c69 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
@@ -138,8 +138,8 @@ class BisectingKMeansSuite
assert(model.clusterCenters === model2.clusterCenters)
}
val bisectingKMeans = new BisectingKMeans()
- testEstimatorAndModelReadWrite(
- bisectingKMeans, dataset, BisectingKMeansSuite.allParamSettings, checkModelData)
+ testEstimatorAndModelReadWrite(bisectingKMeans, dataset, BisectingKMeansSuite.allParamSettings,
+ BisectingKMeansSuite.allParamSettings, checkModelData)
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala
index c500c5b3e365..61da897b666f 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala
@@ -163,7 +163,7 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext
assert(model.gaussians.map(_.cov) === model2.gaussians.map(_.cov))
}
val gm = new GaussianMixture()
- testEstimatorAndModelReadWrite(gm, dataset,
+ testEstimatorAndModelReadWrite(gm, dataset, GaussianMixtureSuite.allParamSettings,
GaussianMixtureSuite.allParamSettings, checkModelData)
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
index e10127f7d108..ca05b9c389f6 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
@@ -150,7 +150,8 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR
assert(model.clusterCenters === model2.clusterCenters)
}
val kmeans = new KMeans()
- testEstimatorAndModelReadWrite(kmeans, dataset, KMeansSuite.allParamSettings, checkModelData)
+ testEstimatorAndModelReadWrite(kmeans, dataset, KMeansSuite.allParamSettings,
+ KMeansSuite.allParamSettings, checkModelData)
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
index 9aa11fbdbe86..75aa0be61a3e 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
@@ -250,7 +250,8 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
Vectors.dense(model2.getDocConcentration) absTol 1e-6)
}
val lda = new LDA()
- testEstimatorAndModelReadWrite(lda, dataset, LDASuite.allParamSettings, checkModelData)
+ testEstimatorAndModelReadWrite(lda, dataset, LDASuite.allParamSettings,
+ LDASuite.allParamSettings, checkModelData)
}
test("read/write DistributedLDAModel") {
@@ -271,6 +272,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
}
val lda = new LDA()
testEstimatorAndModelReadWrite(lda, dataset,
+ LDASuite.allParamSettings ++ Map("optimizer" -> "em"),
LDASuite.allParamSettings ++ Map("optimizer" -> "em"), checkModelData)
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala
index ab937685a555..91eac9e73331 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala
@@ -63,7 +63,7 @@ class BucketedRandomProjectionLSHSuite
}
val mh = new BucketedRandomProjectionLSH()
val settings = Map("inputCol" -> "keys", "outputCol" -> "values", "bucketLength" -> 1.0)
- testEstimatorAndModelReadWrite(mh, dataset, settings, checkModelData)
+ testEstimatorAndModelReadWrite(mh, dataset, settings, settings, checkModelData)
}
test("hashFunction") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala
index 482e5d54260d..d6925da97d57 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala
@@ -151,7 +151,8 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext
assert(model.selectedFeatures === model2.selectedFeatures)
}
val nb = new ChiSqSelector
- testEstimatorAndModelReadWrite(nb, dataset, ChiSqSelectorSuite.allParamSettings, checkModelData)
+ testEstimatorAndModelReadWrite(nb, dataset, ChiSqSelectorSuite.allParamSettings,
+ ChiSqSelectorSuite.allParamSettings, checkModelData)
}
test("should support all NumericType labels and not support other types") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala
index 3461cdf82460..a2f009310fd7 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala
@@ -54,7 +54,7 @@ class MinHashLSHSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
}
val mh = new MinHashLSH()
val settings = Map("inputCol" -> "keys", "outputCol" -> "values")
- testEstimatorAndModelReadWrite(mh, dataset, settings, checkModelData)
+ testEstimatorAndModelReadWrite(mh, dataset, settings, settings, checkModelData)
}
test("hashFunction") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
index 2d0e63c9d669..188dffb3dd55 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
@@ -64,7 +64,7 @@ class StringIndexerSuite
test("StringIndexerUnseen") {
val data = Seq((0, "a"), (1, "b"), (4, "b"))
- val data2 = Seq((0, "a"), (1, "b"), (2, "c"))
+ val data2 = Seq((0, "a"), (1, "b"), (2, "c"), (3, "d"))
val df = data.toDF("id", "label")
val df2 = data2.toDF("id", "label")
val indexer = new StringIndexer()
@@ -75,22 +75,32 @@ class StringIndexerSuite
intercept[SparkException] {
indexer.transform(df2).collect()
}
- val indexerSkipInvalid = new StringIndexer()
- .setInputCol("label")
- .setOutputCol("labelIndex")
- .setHandleInvalid("skip")
- .fit(df)
+
+ indexer.setHandleInvalid("skip")
// Verify that we skip the c record
- val transformed = indexerSkipInvalid.transform(df2)
- val attr = Attribute.fromStructField(transformed.schema("labelIndex"))
+ val transformedSkip = indexer.transform(df2)
+ val attrSkip = Attribute.fromStructField(transformedSkip.schema("labelIndex"))
.asInstanceOf[NominalAttribute]
- assert(attr.values.get === Array("b", "a"))
- val output = transformed.select("id", "labelIndex").rdd.map { r =>
+ assert(attrSkip.values.get === Array("b", "a"))
+ val outputSkip = transformedSkip.select("id", "labelIndex").rdd.map { r =>
(r.getInt(0), r.getDouble(1))
}.collect().toSet
// a -> 1, b -> 0
- val expected = Set((0, 1.0), (1, 0.0))
- assert(output === expected)
+ val expectedSkip = Set((0, 1.0), (1, 0.0))
+ assert(outputSkip === expectedSkip)
+
+ indexer.setHandleInvalid("keep")
+ // Verify that we keep the unseen records
+ val transformedKeep = indexer.transform(df2)
+ val attrKeep = Attribute.fromStructField(transformedKeep.schema("labelIndex"))
+ .asInstanceOf[NominalAttribute]
+ assert(attrKeep.values.get === Array("b", "a", "__unknown"))
+ val outputKeep = transformedKeep.select("id", "labelIndex").rdd.map { r =>
+ (r.getInt(0), r.getDouble(1))
+ }.collect().toSet
+ // a -> 1, b -> 0; unseen labels c and d both map to index 2
+ val expectedKeep = Set((0, 1.0), (1, 0.0), (2, 2.0), (3, 2.0))
+ assert(outputKeep === expectedKeep)
}
test("StringIndexer with a numeric input column") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
index 613cc3d60b22..2043a16c15f1 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
@@ -133,14 +133,22 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
.setSeed(42L)
.fit(docDF)
- val expectedSimilarity = Array(0.2608488929093532, -0.8271274846926078)
- val (synonyms, similarity) = model.findSynonyms("a", 2).rdd.map {
+ val expected = Map(("b", 0.2608488929093532), ("c", -0.8271274846926078))
+ val findSynonymsResult = model.findSynonyms("a", 2).rdd.map {
case Row(w: String, sim: Double) => (w, sim)
- }.collect().unzip
+ }.collectAsMap()
+
+ expected.foreach {
+ case (expectedSynonym, expectedSimilarity) =>
+ assert(findSynonymsResult.contains(expectedSynonym))
+ assert(expectedSimilarity ~== findSynonymsResult.get(expectedSynonym).get absTol 1E-5)
+ }
- assert(synonyms === Array("b", "c"))
- expectedSimilarity.zip(similarity).foreach {
- case (expected, actual) => assert(math.abs((expected - actual) / expected) < 1E-5)
+ val findSynonymsArrayResult = model.findSynonymsArray("a", 2).toMap
+ findSynonymsResult.foreach {
+ case (expectedSynonym, expectedSimilarity) =>
+ assert(findSynonymsArrayResult.contains(expectedSynonym))
+ assert(expectedSimilarity ~== findSynonymsArrayResult.get(expectedSynonym).get absTol 1E-5)
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala
index 74c746140190..076d55c18054 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala
@@ -99,8 +99,8 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
model2.freqItemsets.sort("items").collect())
}
val fPGrowth = new FPGrowth()
- testEstimatorAndModelReadWrite(
- fPGrowth, dataset, FPGrowthSuite.allParamSettings, checkModelData)
+ testEstimatorAndModelReadWrite(fPGrowth, dataset, FPGrowthSuite.allParamSettings,
+ FPGrowthSuite.allParamSettings, checkModelData)
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
index c8228dd00437..a177ed13bf8e 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
@@ -22,6 +22,7 @@ import java.util.Random
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.WrappedArray
import scala.collection.JavaConverters._
import scala.language.existentials
@@ -517,37 +518,26 @@ class ALSSuite
}
test("read/write") {
- import ALSSuite._
- val (ratings, _) = genExplicitTestData(numUsers = 4, numItems = 4, rank = 1)
- val als = new ALS()
- allEstimatorParamSettings.foreach { case (p, v) =>
- als.set(als.getParam(p), v)
- }
val spark = this.spark
import spark.implicits._
- val model = als.fit(ratings.toDF())
-
- // Test Estimator save/load
- val als2 = testDefaultReadWrite(als)
- allEstimatorParamSettings.foreach { case (p, v) =>
- val param = als.getParam(p)
- assert(als.get(param).get === als2.get(param).get)
- }
+ import ALSSuite._
+ val (ratings, _) = genExplicitTestData(numUsers = 4, numItems = 4, rank = 1)
- // Test Model save/load
- val model2 = testDefaultReadWrite(model)
- allModelParamSettings.foreach { case (p, v) =>
- val param = model.getParam(p)
- assert(model.get(param).get === model2.get(param).get)
- }
- assert(model.rank === model2.rank)
def getFactors(df: DataFrame): Set[(Int, Array[Float])] = {
df.select("id", "features").collect().map { case r =>
(r.getInt(0), r.getAs[Array[Float]](1))
}.toSet
}
- assert(getFactors(model.userFactors) === getFactors(model2.userFactors))
- assert(getFactors(model.itemFactors) === getFactors(model2.itemFactors))
+
+ def checkModelData(model: ALSModel, model2: ALSModel): Unit = {
+ assert(model.rank === model2.rank)
+ assert(getFactors(model.userFactors) === getFactors(model2.userFactors))
+ assert(getFactors(model.itemFactors) === getFactors(model2.itemFactors))
+ }
+
+ val als = new ALS()
+ testEstimatorAndModelReadWrite(als, ratings.toDF(), allEstimatorParamSettings,
+ allModelParamSettings, checkModelData)
}
test("input type validation") {
@@ -660,6 +650,99 @@ class ALSSuite
model.setColdStartStrategy(s).transform(data)
}
}
+
+ private def getALSModel = {
+ val spark = this.spark
+ import spark.implicits._
+
+ val userFactors = Seq(
+ (0, Array(6.0f, 4.0f)),
+ (1, Array(3.0f, 4.0f)),
+ (2, Array(3.0f, 6.0f))
+ ).toDF("id", "features")
+ val itemFactors = Seq(
+ (3, Array(5.0f, 6.0f)),
+ (4, Array(6.0f, 2.0f)),
+ (5, Array(3.0f, 6.0f)),
+ (6, Array(4.0f, 1.0f))
+ ).toDF("id", "features")
+ val als = new ALS().setRank(2)
+ new ALSModel(als.uid, als.getRank, userFactors, itemFactors)
+ .setUserCol("user")
+ .setItemCol("item")
+ }
+
+ test("recommendForAllUsers with k < num_items") {
+ val topItems = getALSModel.recommendForAllUsers(2)
+ assert(topItems.count() == 3)
+ assert(topItems.columns.contains("user"))
+
+ val expected = Map(
+ 0 -> Array((3, 54f), (4, 44f)),
+ 1 -> Array((3, 39f), (5, 33f)),
+ 2 -> Array((3, 51f), (5, 45f))
+ )
+ checkRecommendations(topItems, expected, "item")
+ }
+
+ test("recommendForAllUsers with k = num_items") {
+ val topItems = getALSModel.recommendForAllUsers(4)
+ assert(topItems.count() == 3)
+ assert(topItems.columns.contains("user"))
+
+ val expected = Map(
+ 0 -> Array((3, 54f), (4, 44f), (5, 42f), (6, 28f)),
+ 1 -> Array((3, 39f), (5, 33f), (4, 26f), (6, 16f)),
+ 2 -> Array((3, 51f), (5, 45f), (4, 30f), (6, 18f))
+ )
+ checkRecommendations(topItems, expected, "item")
+ }
+
+ test("recommendForAllItems with k < num_users") {
+ val topUsers = getALSModel.recommendForAllItems(2)
+ assert(topUsers.count() == 4)
+ assert(topUsers.columns.contains("item"))
+
+ val expected = Map(
+ 3 -> Array((0, 54f), (2, 51f)),
+ 4 -> Array((0, 44f), (2, 30f)),
+ 5 -> Array((2, 45f), (0, 42f)),
+ 6 -> Array((0, 28f), (2, 18f))
+ )
+ checkRecommendations(topUsers, expected, "user")
+ }
+
+ test("recommendForAllItems with k = num_users") {
+ val topUsers = getALSModel.recommendForAllItems(3)
+ assert(topUsers.count() == 4)
+ assert(topUsers.columns.contains("item"))
+
+ val expected = Map(
+ 3 -> Array((0, 54f), (2, 51f), (1, 39f)),
+ 4 -> Array((0, 44f), (2, 30f), (1, 26f)),
+ 5 -> Array((2, 45f), (0, 42f), (1, 33f)),
+ 6 -> Array((0, 28f), (2, 18f), (1, 16f))
+ )
+ checkRecommendations(topUsers, expected, "user")
+ }
+
+ private def checkRecommendations(
+ topK: DataFrame,
+ expected: Map[Int, Array[(Int, Float)]],
+ dstColName: String): Unit = {
+ val spark = this.spark
+ import spark.implicits._
+
+ assert(topK.columns.contains("recommendations"))
+ topK.as[(Int, Seq[(Int, Float)])].collect().foreach { case (id: Int, recs: Seq[(Int, Float)]) =>
+ assert(recs === expected(id))
+ }
+ topK.collect().foreach { row =>
+ val recs = row.getAs[WrappedArray[Row]]("recommendations")
+ assert(recs(0).fieldIndex(dstColName) == 0)
+ assert(recs(0).fieldIndex("rating") == 1)
+ }
+ }
}
class ALSCleanerSuite extends SparkFunSuite {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/TopByKeyAggregatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/TopByKeyAggregatorSuite.scala
new file mode 100644
index 000000000000..5e763a8e908b
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/TopByKeyAggregatorSuite.scala
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.recommendation
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.Dataset
+
+
+class TopByKeyAggregatorSuite extends SparkFunSuite with MLlibTestSparkContext {
+
+ private def getTopK(k: Int): Dataset[(Int, Array[(Int, Float)])] = {
+ val sqlContext = spark.sqlContext
+ import sqlContext.implicits._
+
+ val topKAggregator = new TopByKeyAggregator[Int, Int, Float](k, Ordering.by(_._2))
+ Seq(
+ (0, 3, 54f),
+ (0, 4, 44f),
+ (0, 5, 42f),
+ (0, 6, 28f),
+ (1, 3, 39f),
+ (2, 3, 51f),
+ (2, 5, 45f),
+ (2, 6, 18f)
+ ).toDS().groupByKey(_._1).agg(topKAggregator.toColumn)
+ }
+
+ test("topByKey with k < #items") {
+ val topK = getTopK(2)
+ assert(topK.count() === 3)
+
+ val expected = Map(
+ 0 -> Array((3, 54f), (4, 44f)),
+ 1 -> Array((3, 39f)),
+ 2 -> Array((3, 51f), (5, 45f))
+ )
+ checkTopK(topK, expected)
+ }
+
+ test("topByKey with k > #items") {
+ val topK = getTopK(5)
+ assert(topK.count() === 3)
+
+ val expected = Map(
+ 0 -> Array((3, 54f), (4, 44f), (5, 42f), (6, 28f)),
+ 1 -> Array((3, 39f)),
+ 2 -> Array((3, 51f), (5, 45f), (6, 18f))
+ )
+ checkTopK(topK, expected)
+ }
+
+ private def checkTopK(
+ topK: Dataset[(Int, Array[(Int, Float)])],
+ expected: Map[Int, Array[(Int, Float)]]): Unit = {
+ topK.collect().foreach { case (id, recs) => assert(recs === expected(id)) }
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
index 3cd4b0ac308e..708185a0943d 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
@@ -419,7 +419,8 @@ class AFTSurvivalRegressionSuite
}
val aft = new AFTSurvivalRegression()
testEstimatorAndModelReadWrite(aft, datasetMultivariate,
- AFTSurvivalRegressionSuite.allParamSettings, checkModelData)
+ AFTSurvivalRegressionSuite.allParamSettings, AFTSurvivalRegressionSuite.allParamSettings,
+ checkModelData)
}
test("SPARK-15892: Incorrectly merged AFTAggregator with zero total count") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
index 15fa26e8b527..0e91284d03d9 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
@@ -165,16 +165,17 @@ class DecisionTreeRegressorSuite
val categoricalData: DataFrame =
TreeTests.setMetadata(rdd, Map(0 -> 2, 1 -> 3), numClasses = 0)
testEstimatorAndModelReadWrite(dt, categoricalData,
- TreeTests.allParamSettings, checkModelData)
+ TreeTests.allParamSettings, TreeTests.allParamSettings, checkModelData)
// Continuous splits with tree depth 2
val continuousData: DataFrame =
TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 0)
testEstimatorAndModelReadWrite(dt, continuousData,
- TreeTests.allParamSettings, checkModelData)
+ TreeTests.allParamSettings, TreeTests.allParamSettings, checkModelData)
// Continuous splits with tree depth 0
testEstimatorAndModelReadWrite(dt, continuousData,
+ TreeTests.allParamSettings ++ Map("maxDepth" -> 0),
TreeTests.allParamSettings ++ Map("maxDepth" -> 0), checkModelData)
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
index dcf3f9a1ea9b..03c2f97797bc 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
@@ -184,7 +184,8 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext
val allParamSettings = TreeTests.allParamSettings ++ Map("lossType" -> "squared")
val continuousData: DataFrame =
TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 0)
- testEstimatorAndModelReadWrite(gbt, continuousData, allParamSettings, checkModelData)
+ testEstimatorAndModelReadWrite(gbt, continuousData, allParamSettings,
+ allParamSettings, checkModelData)
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
index add28a72b680..401911763fa3 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
@@ -1418,6 +1418,7 @@ class GeneralizedLinearRegressionSuite
val glr = new GeneralizedLinearRegression()
testEstimatorAndModelReadWrite(glr, datasetPoissonLog,
+ GeneralizedLinearRegressionSuite.allParamSettings,
GeneralizedLinearRegressionSuite.allParamSettings, checkModelData)
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala
index 8cbb2acad243..f41a3601b1fa 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala
@@ -178,7 +178,7 @@ class IsotonicRegressionSuite
val ir = new IsotonicRegression()
testEstimatorAndModelReadWrite(ir, dataset, IsotonicRegressionSuite.allParamSettings,
- checkModelData)
+ IsotonicRegressionSuite.allParamSettings, checkModelData)
}
test("should support all NumericType labels and weights, and not support other types") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
index 584a1b272f6c..6a51e75e12a3 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
@@ -985,7 +985,7 @@ class LinearRegressionSuite
}
val lr = new LinearRegression()
testEstimatorAndModelReadWrite(lr, datasetWithWeight, LinearRegressionSuite.allParamSettings,
- checkModelData)
+ LinearRegressionSuite.allParamSettings, checkModelData)
}
test("should support all NumericType labels and weights, and not support other types") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
index c08335f9f84a..3bf0445ebd3d 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
@@ -124,7 +124,8 @@ class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContex
val continuousData: DataFrame =
TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 0)
- testEstimatorAndModelReadWrite(rf, continuousData, allParamSettings, checkModelData)
+ testEstimatorAndModelReadWrite(rf, continuousData, allParamSettings,
+ allParamSettings, checkModelData)
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
index 553b8725b30a..bfe8f12258bb 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
@@ -85,11 +85,12 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite =>
* - Check Params on Estimator and Model
* - Compare model data
*
- * This requires that the [[Estimator]] and [[Model]] share the same set of [[Param]]s.
+ * This requires that the [[Model]]'s [[Param]]s are a subset of the [[Estimator]]'s [[Param]]s.
*
* @param estimator Estimator to test
* @param dataset Dataset to pass to [[Estimator.fit()]]
- * @param testParams Set of [[Param]] values to set in estimator
+ * @param testEstimatorParams Set of [[Param]] values to set in estimator
+ * @param testModelParams Set of [[Param]] values to set in model
* @param checkModelData Method which takes the original and loaded [[Model]] and compares their
* data. This method does not need to check [[Param]] values.
* @tparam E Type of [[Estimator]]
@@ -99,24 +100,25 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite =>
E <: Estimator[M] with MLWritable, M <: Model[M] with MLWritable](
estimator: E,
dataset: Dataset[_],
- testParams: Map[String, Any],
+ testEstimatorParams: Map[String, Any],
+ testModelParams: Map[String, Any],
checkModelData: (M, M) => Unit): Unit = {
// Set some Params to make sure set Params are serialized.
- testParams.foreach { case (p, v) =>
+ testEstimatorParams.foreach { case (p, v) =>
estimator.set(estimator.getParam(p), v)
}
val model = estimator.fit(dataset)
// Test Estimator save/load
val estimator2 = testDefaultReadWrite(estimator)
- testParams.foreach { case (p, v) =>
+ testEstimatorParams.foreach { case (p, v) =>
val param = estimator.getParam(p)
assert(estimator.get(param).get === estimator2.get(param).get)
}
// Test Model save/load
val model2 = testDefaultReadWrite(model)
- testParams.foreach { case (p, v) =>
+ testModelParams.foreach { case (p, v) =>
val param = model.getParam(p)
assert(model.get(param).get === model2.get(param).get)
}
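Call sites now pass two param maps: one verified on the reloaded Estimator and one on the reloaded Model. A hedged sketch of the new call shape, modeled on the KMeansSuite change above and assumed to live inside a suite that mixes in DefaultReadWriteTest with `dataset` already prepared:

```scala
import org.apache.spark.ml.clustering.{KMeans, KMeansModel}

def checkModelData(model: KMeansModel, model2: KMeansModel): Unit = {
  assert(model.clusterCenters === model2.clusterCenters)
}

val kmeans = new KMeans()
testEstimatorAndModelReadWrite(kmeans, dataset,
  KMeansSuite.allParamSettings,  // params set on and verified against the Estimator
  KMeansSuite.allParamSettings,  // params verified against the Model (may be a subset)
  checkModelData)
```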
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 56b8c0b95e8a..bd4528bd2126 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -914,6 +914,10 @@ object MimaExcludes {
) ++ Seq(
// [SPARK-17163] Unify logistic regression interface. Private constructor has new signature.
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionModel.this")
+ ) ++ Seq(
+ // [SPARK-17498] StringIndexer enhancement for handling unseen labels
+ ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.feature.StringIndexer"),
+ ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.feature.StringIndexerModel")
) ++ Seq(
// [SPARK-17365][Core] Remove/Kill multiple executors together to reduce RPC call time
ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.SparkContext")
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index 93a31897c9fc..e52baf51aed1 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -655,6 +655,7 @@ object Unidoc {
.map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/util/collection")))
.map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/sql/catalyst")))
.map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/sql/execution")))
+ .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/sql/internal")))
.map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/sql/hive/test")))
}
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index b199bf282e4f..3c3fcc8d9b8d 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -1294,8 +1294,8 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha
Fit a Generalized Linear Model specified by giving a symbolic description of the linear
predictor (link function) and a description of the error distribution (family). It supports
- "gaussian", "binomial", "poisson" and "gamma" as family. Valid link functions for each family
- is listed below. The first link function of each family is the default one.
+ "gaussian", "binomial", "poisson", "gamma" and "tweedie" as family. Valid link functions for
+ each family are listed below. The first link function of each family is the default one.
* "gaussian" -> "identity", "log", "inverse"
@@ -1305,6 +1305,9 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha
* "gamma" -> "inverse", "identity", "log"
+ * "tweedie" -> power link function specified through "linkPower". \
+ The default link power in the tweedie family is 1 - variancePower.
+
.. seealso:: `GLM `_
>>> from pyspark.ml.linalg import Vectors
@@ -1344,7 +1347,7 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha
family = Param(Params._dummy(), "family", "The name of family which is a description of " +
"the error distribution to be used in the model. Supported options: " +
- "gaussian (default), binomial, poisson and gamma.",
+ "gaussian (default), binomial, poisson, gamma and tweedie.",
typeConverter=TypeConverters.toString)
link = Param(Params._dummy(), "link", "The name of link function which provides the " +
"relationship between the linear predictor and the mean of the distribution " +
@@ -1352,32 +1355,46 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha
"and sqrt.", typeConverter=TypeConverters.toString)
linkPredictionCol = Param(Params._dummy(), "linkPredictionCol", "link prediction (linear " +
"predictor) column name", typeConverter=TypeConverters.toString)
+ variancePower = Param(Params._dummy(), "variancePower", "The power in the variance function " +
+ "of the Tweedie distribution which characterizes the relationship " +
+ "between the variance and mean of the distribution. Only applicable " +
+ "for the Tweedie family. Supported values: 0 and [1, Inf).",
+ typeConverter=TypeConverters.toFloat)
+ linkPower = Param(Params._dummy(), "linkPower", "The index in the power link function. " +
+ "Only applicable to the Tweedie family.",
+ typeConverter=TypeConverters.toFloat)
@keyword_only
def __init__(self, labelCol="label", featuresCol="features", predictionCol="prediction",
family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6,
- regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None):
+ regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None,
+ variancePower=0.0, linkPower=None):
"""
__init__(self, labelCol="label", featuresCol="features", predictionCol="prediction", \
family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, \
- regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None)
+ regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None, \
+ variancePower=0.0, linkPower=None)
"""
super(GeneralizedLinearRegression, self).__init__()
self._java_obj = self._new_java_obj(
"org.apache.spark.ml.regression.GeneralizedLinearRegression", self.uid)
- self._setDefault(family="gaussian", maxIter=25, tol=1e-6, regParam=0.0, solver="irls")
+ self._setDefault(family="gaussian", maxIter=25, tol=1e-6, regParam=0.0, solver="irls",
+ variancePower=0.0)
kwargs = self._input_kwargs
+
self.setParams(**kwargs)
@keyword_only
@since("2.0.0")
def setParams(self, labelCol="label", featuresCol="features", predictionCol="prediction",
family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6,
- regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None):
+ regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None,
+ variancePower=0.0, linkPower=None):
"""
setParams(self, labelCol="label", featuresCol="features", predictionCol="prediction", \
family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, \
- regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None)
+ regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None, \
+ variancePower=0.0, linkPower=None)
Sets params for generalized linear regression.
"""
kwargs = self._input_kwargs
@@ -1428,6 +1445,34 @@ def getLink(self):
"""
return self.getOrDefault(self.link)
+ @since("2.2.0")
+ def setVariancePower(self, value):
+ """
+ Sets the value of :py:attr:`variancePower`.
+ """
+ return self._set(variancePower=value)
+
+ @since("2.2.0")
+ def getVariancePower(self):
+ """
+ Gets the value of variancePower or its default value.
+ """
+ return self.getOrDefault(self.variancePower)
+
+ @since("2.2.0")
+ def setLinkPower(self, value):
+ """
+ Sets the value of :py:attr:`linkPower`.
+ """
+ return self._set(linkPower=value)
+
+ @since("2.2.0")
+ def getLinkPower(self):
+ """
+ Gets the value of linkPower or its default value.
+ """
+ return self.getOrDefault(self.linkPower)
+
class GeneralizedLinearRegressionModel(JavaModel, JavaPredictionModel, JavaMLWritable,
JavaMLReadable):
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 352416055791..f052f5bb770c 100755
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -1223,6 +1223,26 @@ def test_apply_binary_term_freqs(self):
": expected " + str(expected[i]) + ", got " + str(features[i]))
+class GeneralizedLinearRegressionTest(SparkSessionTestCase):
+
+ def test_tweedie_distribution(self):
+
+ df = self.spark.createDataFrame(
+ [(1.0, Vectors.dense(0.0, 0.0)),
+ (1.0, Vectors.dense(1.0, 2.0)),
+ (2.0, Vectors.dense(0.0, 0.0)),
+ (2.0, Vectors.dense(1.0, 1.0)), ], ["label", "features"])
+
+ glr = GeneralizedLinearRegression(family="tweedie", variancePower=1.6)
+ model = glr.fit(df)
+ self.assertTrue(np.allclose(model.coefficients.toArray(), [-0.4645, 0.3402], atol=1E-4))
+ self.assertTrue(np.isclose(model.intercept, 0.7841, atol=1E-4))
+
+ model2 = glr.setLinkPower(-1.0).fit(df)
+ self.assertTrue(np.allclose(model2.coefficients.toArray(), [-0.6667, 0.5], atol=1E-4))
+ self.assertTrue(np.isclose(model2.intercept, 0.6667, atol=1E-4))
+
+
class ALSTest(SparkSessionTestCase):
def test_storage_levels(self):
diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py
index c10ab9638a21..ec05c18d4f06 100644
--- a/python/pyspark/sql/column.py
+++ b/python/pyspark/sql/column.py
@@ -180,7 +180,9 @@ def __init__(self, jc):
__ror__ = _bin_op("or")
# container operators
- __contains__ = _bin_op("contains")
+ def __contains__(self, item):
+ raise ValueError("Cannot apply 'in' operator against a column: please use 'contains' "
+ "in a string column or 'array_contains' function for an array column.")
# bitwise operators
bitwiseOR = _bin_op("bitwiseOR")
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 426a4a8c93a6..376b86ea69bd 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -1773,11 +1773,11 @@ def json_tuple(col, *fields):
@since(2.1)
def from_json(col, schema, options={}):
"""
- Parses a column containing a JSON string into a [[StructType]] with the
- specified schema. Returns `null`, in the case of an unparseable string.
+ Parses a column containing a JSON string into a [[StructType]] or [[ArrayType]]
+ with the specified schema. Returns `null`, in the case of an unparseable string.
:param col: string column in json format
- :param schema: a StructType to use when parsing the json column
+ :param schema: a StructType or ArrayType to use when parsing the json column
:param options: options to control parsing. accepts the same options as the json datasource
>>> from pyspark.sql.types import *
@@ -1786,6 +1786,11 @@ def from_json(col, schema, options={}):
>>> df = spark.createDataFrame(data, ("key", "value"))
>>> df.select(from_json(df.value, schema).alias("json")).collect()
[Row(json=Row(a=1))]
+ >>> data = [(1, '''[{"a": 1}]''')]
+ >>> schema = ArrayType(StructType([StructField("a", IntegerType())]))
+ >>> df = spark.createDataFrame(data, ("key", "value"))
+ >>> df.select(from_json(df.value, schema).alias("json")).collect()
+ [Row(json=[Row(a=1)])]
"""
sc = SparkContext._active_spark_context
diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index 45fb9b759152..4354345ebc55 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -161,7 +161,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None,
timeZone=None, wholeFile=None):
"""
- Loads a JSON file and returns the results as a :class:`DataFrame`.
+ Loads JSON files and returns the results as a :class:`DataFrame`.
`JSON Lines `_(newline-delimited JSON) is supported by default.
For JSON (one record per file), set the `wholeFile` parameter to ``true``.
@@ -169,7 +169,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
If the ``schema`` parameter is not specified, this function goes
through the input once to determine the input schema.
- :param path: string represents path to the JSON dataset,
+ :param path: string represents path to the JSON dataset, or a list of paths,
or RDD of Strings storing JSON objects.
:param schema: an optional :class:`pyspark.sql.types.StructType` for the input schema.
:param primitivesAsString: infers all primitive values as a string type. If None is set,
@@ -252,7 +252,7 @@ def func(iterator):
jrdd = keyed._jrdd.map(self._spark._jvm.BytesToString())
return self._df(self._jreader.json(jrdd))
else:
- raise TypeError("path can be only string or RDD")
+ raise TypeError("path can be only string, list or RDD")
@since(1.4)
def table(self, tableName):
@@ -269,7 +269,7 @@ def table(self, tableName):
@since(1.4)
def parquet(self, *paths):
- """Loads a Parquet file, returning the result as a :class:`DataFrame`.
+ """Loads Parquet files, returning the result as a :class:`DataFrame`.
You can set the following Parquet-specific option(s) for reading Parquet files:
* ``mergeSchema``: sets whether we should merge schemas collected from all \
@@ -407,7 +407,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
@since(1.5)
def orc(self, path):
- """Loads an ORC file, returning the result as a :class:`DataFrame`.
+ """Loads ORC files, returning the result as a :class:`DataFrame`.
.. note:: Currently ORC support is only available together with Hive support.
@@ -415,7 +415,9 @@ def orc(self, path):
>>> df.dtypes
[('a', 'bigint'), ('b', 'int'), ('c', 'int')]
"""
- return self._df(self._jreader.orc(path))
+ if isinstance(path, basestring):
+ path = [path]
+ return self._df(self._jreader.orc(_to_seq(self._spark._sc, path)))
@since(1.4)
def jdbc(self, url, table, column=None, lowerBound=None, upperBound=None, numPartitions=None,
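A hedged usage sketch of the list-of-paths reads touched above. The directory names are placeholders and `spark` is an assumed active SparkSession.

```python
# Sketch only: passing a list of paths to the json reader and the (newly list-aware) orc reader.
# The paths below are placeholders.
people = spark.read.json(["data/people1.json", "data/people2.json"])
orc_df = spark.read.orc(["warehouse/orc/b=0/c=0", "warehouse/orc/b=1/c=1"])
print(people.count(), orc_df.count())
```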
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index e943f8da3db1..f0a9a0400e39 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -450,6 +450,11 @@ def test_wholefile_csv(self):
Row(_c0=u'Hyukjin', _c1=u'25', _c2=u'I am Hyukjin\n\nI love Spark!')]
self.assertEqual(ages_newlines.collect(), expected)
+ def test_read_multiple_orc_file(self):
+ df = self.spark.read.orc(["python/test_support/sql/orc_partitioned/b=0/c=0",
+ "python/test_support/sql/orc_partitioned/b=1/c=1"])
+ self.assertEqual(2, df.count())
+
def test_udf_with_input_file_name(self):
from pyspark.sql.functions import udf, input_file_name
from pyspark.sql.types import StringType
@@ -967,6 +972,9 @@ def test_column_operators(self):
cs.startswith('a'), cs.endswith('a')
self.assertTrue(all(isinstance(c, Column) for c in css))
self.assertTrue(isinstance(ci.cast(LongType()), Column))
+ self.assertRaisesRegexp(ValueError,
+ "Cannot apply 'in' operator against a column",
+ lambda: 1 in cs)
def test_column_getitem(self):
from pyspark.sql.functions import col
@@ -1552,6 +1560,14 @@ def test_time_with_timezone(self):
self.assertEqual(now, now1)
self.assertEqual(now, utcnow1)
+ # regression test for SPARK-19561
+ def test_datetime_at_epoch(self):
+ epoch = datetime.datetime.fromtimestamp(0)
+ df = self.spark.createDataFrame([Row(date=epoch)])
+ first = df.select('date', lit(epoch).alias('lit_date')).first()
+ self.assertEqual(first['date'], epoch)
+ self.assertEqual(first['lit_date'], epoch)
+
def test_decimal(self):
from decimal import Decimal
schema = StructType([StructField("decimal", DecimalType(10, 5))])
diff --git a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala
index 55c91675ed3b..121a02a9be0a 100644
--- a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala
+++ b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala
@@ -473,4 +473,15 @@ class ReplSuite extends SparkFunSuite {
assertDoesNotContain("AssertionError", output)
assertDoesNotContain("Exception", output)
}
+
+ test("newProductSeqEncoder with REPL defined class") {
+ val output = runInterpreterInPasteMode("local-cluster[1,4,4096]",
+ """
+ |case class Click(id: Int)
+ |spark.implicits.newProductSeqEncoder[Click]
+ """.stripMargin)
+
+ assertDoesNotContain("error:", output)
+ assertDoesNotContain("Exception", output)
+ }
}
diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala
index 2760f31b12fa..1bc6f71860c3 100644
--- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala
+++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala
@@ -152,6 +152,7 @@ private[spark] class MesosClusterScheduler(
// is registered with Mesos master.
@volatile protected var ready = false
private var masterInfo: Option[MasterInfo] = None
+ private var schedulerDriver: SchedulerDriver = _
def submitDriver(desc: MesosDriverDescription): CreateSubmissionResponse = {
val c = new CreateSubmissionResponse
@@ -168,9 +169,8 @@ private[spark] class MesosClusterScheduler(
return c
}
c.submissionId = desc.submissionId
- queuedDriversState.persist(desc.submissionId, desc)
- queuedDrivers += desc
c.success = true
+ addDriverToQueue(desc)
}
c
}
@@ -191,7 +191,7 @@ private[spark] class MesosClusterScheduler(
// 4. Check if it has already completed.
if (launchedDrivers.contains(submissionId)) {
val task = launchedDrivers(submissionId)
- mesosDriver.killTask(task.taskId)
+ schedulerDriver.killTask(task.taskId)
k.success = true
k.message = "Killing running driver"
} else if (removeFromQueuedDrivers(submissionId)) {
@@ -324,7 +324,7 @@ private[spark] class MesosClusterScheduler(
ready = false
metricsSystem.report()
metricsSystem.stop()
- mesosDriver.stop(true)
+ schedulerDriver.stop(true)
}
override def registered(
@@ -340,6 +340,8 @@ private[spark] class MesosClusterScheduler(
stateLock.synchronized {
this.masterInfo = Some(masterInfo)
+ this.schedulerDriver = driver
+
if (!pendingRecover.isEmpty) {
// Start task reconciliation if we need to recover.
val statuses = pendingRecover.collect {
@@ -506,11 +508,10 @@ private[spark] class MesosClusterScheduler(
}
private class ResourceOffer(
- val offerId: OfferID,
- val slaveId: SlaveID,
- var resources: JList[Resource]) {
+ val offer: Offer,
+ var remainingResources: JList[Resource]) {
override def toString(): String = {
- s"Offer id: ${offerId}, resources: ${resources}"
+ s"Offer id: ${offer.getId}, resources: ${remainingResources}"
}
}
@@ -518,16 +519,16 @@ private[spark] class MesosClusterScheduler(
val taskId = TaskID.newBuilder().setValue(desc.submissionId).build()
val (remainingResources, cpuResourcesToUse) =
- partitionResources(offer.resources, "cpus", desc.cores)
+ partitionResources(offer.remainingResources, "cpus", desc.cores)
val (finalResources, memResourcesToUse) =
partitionResources(remainingResources.asJava, "mem", desc.mem)
- offer.resources = finalResources.asJava
+ offer.remainingResources = finalResources.asJava
val appName = desc.conf.get("spark.app.name")
val taskInfo = TaskInfo.newBuilder()
.setTaskId(taskId)
.setName(s"Driver for ${appName}")
- .setSlaveId(offer.slaveId)
+ .setSlaveId(offer.offer.getSlaveId)
.setCommand(buildDriverCommand(desc))
.addAllResources(cpuResourcesToUse.asJava)
.addAllResources(memResourcesToUse.asJava)
@@ -549,23 +550,29 @@ private[spark] class MesosClusterScheduler(
val driverCpu = submission.cores
val driverMem = submission.mem
logTrace(s"Finding offer to launch driver with cpu: $driverCpu, mem: $driverMem")
- val offerOption = currentOffers.find { o =>
- getResource(o.resources, "cpus") >= driverCpu &&
- getResource(o.resources, "mem") >= driverMem
+ val offerOption = currentOffers.find { offer =>
+ getResource(offer.remainingResources, "cpus") >= driverCpu &&
+ getResource(offer.remainingResources, "mem") >= driverMem
}
if (offerOption.isEmpty) {
logDebug(s"Unable to find offer to launch driver id: ${submission.submissionId}, " +
s"cpu: $driverCpu, mem: $driverMem")
} else {
val offer = offerOption.get
- val queuedTasks = tasks.getOrElseUpdate(offer.offerId, new ArrayBuffer[TaskInfo])
+ val queuedTasks = tasks.getOrElseUpdate(offer.offer.getId, new ArrayBuffer[TaskInfo])
try {
val task = createTaskInfo(submission, offer)
queuedTasks += task
- logTrace(s"Using offer ${offer.offerId.getValue} to launch driver " +
+ logTrace(s"Using offer ${offer.offer.getId.getValue} to launch driver " +
submission.submissionId)
- val newState = new MesosClusterSubmissionState(submission, task.getTaskId, offer.slaveId,
- None, new Date(), None, getDriverFrameworkID(submission))
+ val newState = new MesosClusterSubmissionState(
+ submission,
+ task.getTaskId,
+ offer.offer.getSlaveId,
+ None,
+ new Date(),
+ None,
+ getDriverFrameworkID(submission))
launchedDrivers(submission.submissionId) = newState
launchedDriversState.persist(submission.submissionId, newState)
afterLaunchCallback(submission.submissionId)
@@ -588,7 +595,7 @@ private[spark] class MesosClusterScheduler(
val currentTime = new Date()
val currentOffers = offers.asScala.map {
- o => new ResourceOffer(o.getId, o.getSlaveId, o.getResourcesList)
+ offer => new ResourceOffer(offer, offer.getResourcesList)
}.toList
stateLock.synchronized {
@@ -615,8 +622,8 @@ private[spark] class MesosClusterScheduler(
driver.launchTasks(Collections.singleton(offerId), taskInfos.asJava)
}
- for (o <- currentOffers if !tasks.contains(o.offerId)) {
- driver.declineOffer(o.offerId)
+ for (offer <- currentOffers if !tasks.contains(offer.offer.getId)) {
+ declineOffer(driver, offer.offer, None, Some(getRejectOfferDuration(conf)))
}
}
@@ -662,6 +669,12 @@ private[spark] class MesosClusterScheduler(
override def statusUpdate(driver: SchedulerDriver, status: TaskStatus): Unit = {
val taskId = status.getTaskId.getValue
+
+ logInfo(s"Received status update: taskId=${taskId}" +
+ s" state=${status.getState}" +
+ s" message=${status.getMessage}" +
+ s" reason=${status.getReason}");
+
stateLock.synchronized {
if (launchedDrivers.contains(taskId)) {
if (status.getReason == Reason.REASON_RECONCILIATION &&
@@ -682,8 +695,7 @@ private[spark] class MesosClusterScheduler(
val newDriverDescription = state.driverDescription.copy(
retryState = Some(new MesosClusterRetryState(status, retries, nextRetry, waitTimeSec)))
- pendingRetryDrivers += newDriverDescription
- pendingRetryDriversState.persist(taskId, newDriverDescription)
+        addDriverToPending(newDriverDescription, taskId)
} else if (TaskState.isFinished(mesosToTaskState(status.getState))) {
removeFromLaunchedDrivers(taskId)
state.finishDate = Some(new Date())
@@ -746,4 +758,21 @@ private[spark] class MesosClusterScheduler(
def getQueuedDriversSize: Int = queuedDrivers.size
def getLaunchedDriversSize: Int = launchedDrivers.size
def getPendingRetryDriversSize: Int = pendingRetryDrivers.size
+
+ private def addDriverToQueue(desc: MesosDriverDescription): Unit = {
+ queuedDriversState.persist(desc.submissionId, desc)
+ queuedDrivers += desc
+ revive()
+ }
+
+  private def addDriverToPending(desc: MesosDriverDescription, taskId: String): Unit = {
+ pendingRetryDriversState.persist(taskId, desc)
+ pendingRetryDrivers += desc
+ revive()
+ }
+
+ private def revive(): Unit = {
+ logInfo("Reviving Offers.")
+ schedulerDriver.reviveOffers()
+ }
}
diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala
index f69c223ab9b6..c049a32eabf9 100644
--- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala
+++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala
@@ -26,6 +26,7 @@ import scala.collection.mutable
import scala.concurrent.Future
import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, _}
+import org.apache.mesos.SchedulerDriver
import org.apache.spark.{SecurityManager, SparkContext, SparkException, TaskState}
import org.apache.spark.network.netty.SparkTransportConf
@@ -119,11 +120,11 @@ private[spark] class MesosCoarseGrainedSchedulerBackend(
// Reject offers with mismatched constraints in seconds
private val rejectOfferDurationForUnmetConstraints =
- getRejectOfferDurationForUnmetConstraints(sc)
+ getRejectOfferDurationForUnmetConstraints(sc.conf)
// Reject offers when we reached the maximum number of cores for this framework
private val rejectOfferDurationForReachedMaxCores =
- getRejectOfferDurationForReachedMaxCores(sc)
+ getRejectOfferDurationForReachedMaxCores(sc.conf)
// A client for talking to the external shuffle service
private val mesosExternalShuffleClient: Option[MesosExternalShuffleClient] = {
@@ -146,6 +147,8 @@ private[spark] class MesosCoarseGrainedSchedulerBackend(
@volatile var appId: String = _
+ private var schedulerDriver: SchedulerDriver = _
+
def newMesosTaskId(): String = {
val id = nextMesosTaskId
nextMesosTaskId += 1
@@ -172,11 +175,6 @@ private[spark] class MesosCoarseGrainedSchedulerBackend(
def createCommand(offer: Offer, numCores: Int, taskId: String): CommandInfo = {
val environment = Environment.newBuilder()
- val extraClassPath = conf.getOption("spark.executor.extraClassPath")
- extraClassPath.foreach { cp =>
- environment.addVariables(
- Environment.Variable.newBuilder().setName("SPARK_CLASSPATH").setValue(cp).build())
- }
val extraJavaOpts = conf.get("spark.executor.extraJavaOptions", "")
// Set the environment variable through a command prefix
@@ -252,9 +250,12 @@ private[spark] class MesosCoarseGrainedSchedulerBackend(
override def offerRescinded(d: org.apache.mesos.SchedulerDriver, o: OfferID) {}
override def registered(
- d: org.apache.mesos.SchedulerDriver, frameworkId: FrameworkID, masterInfo: MasterInfo) {
- appId = frameworkId.getValue
- mesosExternalShuffleClient.foreach(_.init(appId))
+ driver: org.apache.mesos.SchedulerDriver,
+ frameworkId: FrameworkID,
+ masterInfo: MasterInfo) {
+ this.appId = frameworkId.getValue
+ this.mesosExternalShuffleClient.foreach(_.init(appId))
+ this.schedulerDriver = driver
markRegistered()
}
@@ -293,46 +294,25 @@ private[spark] class MesosCoarseGrainedSchedulerBackend(
}
private def declineUnmatchedOffers(
- d: org.apache.mesos.SchedulerDriver, offers: mutable.Buffer[Offer]): Unit = {
+ driver: org.apache.mesos.SchedulerDriver, offers: mutable.Buffer[Offer]): Unit = {
offers.foreach { offer =>
- declineOffer(d, offer, Some("unmet constraints"),
+ declineOffer(
+ driver,
+ offer,
+ Some("unmet constraints"),
Some(rejectOfferDurationForUnmetConstraints))
}
}
- private def declineOffer(
- d: org.apache.mesos.SchedulerDriver,
- offer: Offer,
- reason: Option[String] = None,
- refuseSeconds: Option[Long] = None): Unit = {
-
- val id = offer.getId.getValue
- val offerAttributes = toAttributeMap(offer.getAttributesList)
- val mem = getResource(offer.getResourcesList, "mem")
- val cpus = getResource(offer.getResourcesList, "cpus")
- val ports = getRangeResource(offer.getResourcesList, "ports")
-
- logDebug(s"Declining offer: $id with attributes: $offerAttributes mem: $mem" +
- s" cpu: $cpus port: $ports for $refuseSeconds seconds" +
- reason.map(r => s" (reason: $r)").getOrElse(""))
-
- refuseSeconds match {
- case Some(seconds) =>
- val filters = Filters.newBuilder().setRefuseSeconds(seconds).build()
- d.declineOffer(offer.getId, filters)
- case _ => d.declineOffer(offer.getId)
- }
- }
-
/**
* Launches executors on accepted offers, and declines unused offers. Executors are launched
* round-robin on offers.
*
- * @param d SchedulerDriver
+ * @param driver SchedulerDriver
* @param offers Mesos offers that match attribute constraints
*/
private def handleMatchedOffers(
- d: org.apache.mesos.SchedulerDriver, offers: mutable.Buffer[Offer]): Unit = {
+ driver: org.apache.mesos.SchedulerDriver, offers: mutable.Buffer[Offer]): Unit = {
val tasks = buildMesosTasks(offers)
for (offer <- offers) {
val offerAttributes = toAttributeMap(offer.getAttributesList)
@@ -358,15 +338,19 @@ private[spark] class MesosCoarseGrainedSchedulerBackend(
s" ports: $ports")
}
- d.launchTasks(
+ driver.launchTasks(
Collections.singleton(offer.getId),
offerTasks.asJava)
} else if (totalCoresAcquired >= maxCores) {
// Reject an offer for a configurable amount of time to avoid starving other frameworks
- declineOffer(d, offer, Some("reached spark.cores.max"),
+ declineOffer(driver,
+ offer,
+ Some("reached spark.cores.max"),
Some(rejectOfferDurationForReachedMaxCores))
} else {
- declineOffer(d, offer)
+ declineOffer(
+ driver,
+ offer)
}
}
}
@@ -582,8 +566,8 @@ private[spark] class MesosCoarseGrainedSchedulerBackend(
// Close the mesos external shuffle client if used
mesosExternalShuffleClient.foreach(_.close())
- if (mesosDriver != null) {
- mesosDriver.stop()
+ if (schedulerDriver != null) {
+ schedulerDriver.stop()
}
}
@@ -634,13 +618,13 @@ private[spark] class MesosCoarseGrainedSchedulerBackend(
}
override def doKillExecutors(executorIds: Seq[String]): Future[Boolean] = Future.successful {
- if (mesosDriver == null) {
+ if (schedulerDriver == null) {
logWarning("Asked to kill executors before the Mesos driver was started.")
false
} else {
for (executorId <- executorIds) {
val taskId = TaskID.newBuilder().setValue(executorId).build()
- mesosDriver.killTask(taskId)
+ schedulerDriver.killTask(taskId)
}
// no need to adjust `executorLimitOption` since the AllocationManager already communicated
// the desired limit through a call to `doRequestTotalExecutors`.
diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala
index 7e561916a71e..f198f8893b3d 100644
--- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala
+++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala
@@ -24,6 +24,7 @@ import scala.collection.JavaConverters._
import scala.collection.mutable.{HashMap, HashSet}
import org.apache.mesos.Protos.{ExecutorInfo => MesosExecutorInfo, TaskInfo => MesosTaskInfo, _}
+import org.apache.mesos.SchedulerDriver
import org.apache.mesos.protobuf.ByteString
import org.apache.spark.{SparkContext, SparkException, TaskState}
@@ -65,7 +66,9 @@ private[spark] class MesosFineGrainedSchedulerBackend(
// reject offers with mismatched constraints in seconds
private val rejectOfferDurationForUnmetConstraints =
- getRejectOfferDurationForUnmetConstraints(sc)
+ getRejectOfferDurationForUnmetConstraints(sc.conf)
+
+ private var schedulerDriver: SchedulerDriver = _
@volatile var appId: String = _
@@ -89,6 +92,7 @@ private[spark] class MesosFineGrainedSchedulerBackend(
/**
* Creates a MesosExecutorInfo that is used to launch a Mesos executor.
+ *
   * @param availableResources Available resources that are offered by Mesos
* @param execId The executor id to assign to this new executor.
* @return A tuple of the new mesos executor info and the remaining available resources.
@@ -102,10 +106,6 @@ private[spark] class MesosFineGrainedSchedulerBackend(
throw new SparkException("Executor Spark home `spark.mesos.executor.home` is not set!")
}
val environment = Environment.newBuilder()
- sc.conf.getOption("spark.executor.extraClassPath").foreach { cp =>
- environment.addVariables(
- Environment.Variable.newBuilder().setName("SPARK_CLASSPATH").setValue(cp).build())
- }
val extraJavaOpts = sc.conf.getOption("spark.executor.extraJavaOptions").getOrElse("")
val prefixEnv = sc.conf.getOption("spark.executor.extraLibraryPath").map { p =>
@@ -178,10 +178,13 @@ private[spark] class MesosFineGrainedSchedulerBackend(
override def offerRescinded(d: org.apache.mesos.SchedulerDriver, o: OfferID) {}
override def registered(
- d: org.apache.mesos.SchedulerDriver, frameworkId: FrameworkID, masterInfo: MasterInfo) {
+ driver: org.apache.mesos.SchedulerDriver,
+ frameworkId: FrameworkID,
+ masterInfo: MasterInfo) {
inClassLoader() {
appId = frameworkId.getValue
logInfo("Registered as framework ID " + appId)
+ this.schedulerDriver = driver
markRegistered()
}
}
@@ -383,13 +386,13 @@ private[spark] class MesosFineGrainedSchedulerBackend(
}
override def stop() {
- if (mesosDriver != null) {
- mesosDriver.stop()
+ if (schedulerDriver != null) {
+ schedulerDriver.stop()
}
}
override def reviveOffers() {
- mesosDriver.reviveOffers()
+ schedulerDriver.reviveOffers()
}
override def frameworkMessage(
@@ -426,7 +429,7 @@ private[spark] class MesosFineGrainedSchedulerBackend(
}
override def killTask(taskId: Long, executorId: String, interruptThread: Boolean): Unit = {
- mesosDriver.killTask(
+ schedulerDriver.killTask(
TaskID.newBuilder()
.setValue(taskId.toString).build()
)
diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala
index 1d742fefbbac..3f25535cb5ec 100644
--- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala
+++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala
@@ -46,9 +46,6 @@ trait MesosSchedulerUtils extends Logging {
// Lock used to wait for scheduler to be registered
private final val registerLatch = new CountDownLatch(1)
- // Driver for talking to Mesos
- protected var mesosDriver: SchedulerDriver = null
-
/**
* Creates a new MesosSchedulerDriver that communicates to the Mesos master.
*
@@ -115,10 +112,6 @@ trait MesosSchedulerUtils extends Logging {
*/
def startScheduler(newDriver: SchedulerDriver): Unit = {
synchronized {
- if (mesosDriver != null) {
- registerLatch.await()
- return
- }
@volatile
var error: Option[Exception] = None
@@ -128,8 +121,7 @@ trait MesosSchedulerUtils extends Logging {
setDaemon(true)
override def run() {
try {
- mesosDriver = newDriver
- val ret = mesosDriver.run()
+ val ret = newDriver.run()
logInfo("driver.run() returned with code " + ret)
if (ret != null && ret.equals(Status.DRIVER_ABORTED)) {
error = Some(new SparkException("Error starting driver, DRIVER_ABORTED"))
@@ -379,12 +371,24 @@ trait MesosSchedulerUtils extends Logging {
}
}
- protected def getRejectOfferDurationForUnmetConstraints(sc: SparkContext): Long = {
- sc.conf.getTimeAsSeconds("spark.mesos.rejectOfferDurationForUnmetConstraints", "120s")
+ private def getRejectOfferDurationStr(conf: SparkConf): String = {
+ conf.get("spark.mesos.rejectOfferDuration", "120s")
+ }
+
+ protected def getRejectOfferDuration(conf: SparkConf): Long = {
+ Utils.timeStringAsSeconds(getRejectOfferDurationStr(conf))
+ }
+
+ protected def getRejectOfferDurationForUnmetConstraints(conf: SparkConf): Long = {
+ conf.getTimeAsSeconds(
+ "spark.mesos.rejectOfferDurationForUnmetConstraints",
+ getRejectOfferDurationStr(conf))
}
- protected def getRejectOfferDurationForReachedMaxCores(sc: SparkContext): Long = {
- sc.conf.getTimeAsSeconds("spark.mesos.rejectOfferDurationForReachedMaxCores", "120s")
+ protected def getRejectOfferDurationForReachedMaxCores(conf: SparkConf): Long = {
+ conf.getTimeAsSeconds(
+ "spark.mesos.rejectOfferDurationForReachedMaxCores",
+ getRejectOfferDurationStr(conf))
}
/**
@@ -438,6 +442,7 @@ trait MesosSchedulerUtils extends Logging {
/**
* The values of the non-zero ports to be used by the executor process.
+ *
* @param conf the spark config to use
   * @return the non-zero values of the ports
*/
@@ -521,4 +526,33 @@ trait MesosSchedulerUtils extends Logging {
case TaskState.KILLED => MesosTaskState.TASK_KILLED
case TaskState.LOST => MesosTaskState.TASK_LOST
}
+
+ protected def declineOffer(
+ driver: org.apache.mesos.SchedulerDriver,
+ offer: Offer,
+ reason: Option[String] = None,
+ refuseSeconds: Option[Long] = None): Unit = {
+
+ val id = offer.getId.getValue
+ val offerAttributes = toAttributeMap(offer.getAttributesList)
+ val mem = getResource(offer.getResourcesList, "mem")
+ val cpus = getResource(offer.getResourcesList, "cpus")
+ val ports = getRangeResource(offer.getResourcesList, "ports")
+
+ logDebug(s"Declining offer: $id with " +
+ s"attributes: $offerAttributes " +
+ s"mem: $mem " +
+ s"cpu: $cpus " +
+ s"port: $ports " +
+ refuseSeconds.map(s => s"for ${s} seconds ").getOrElse("") +
+ reason.map(r => s" (reason: $r)").getOrElse(""))
+
+ refuseSeconds match {
+ case Some(seconds) =>
+ val filters = Filters.newBuilder().setRefuseSeconds(seconds).build()
+ driver.declineOffer(offer.getId, filters)
+ case _ =>
+ driver.declineOffer(offer.getId)
+ }
+ }
}
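The refactor above routes the case-specific reject durations through a common `spark.mesos.rejectOfferDuration` default. A minimal sketch of setting these properties from PySpark follows; the Mesos master URL is a placeholder and the values are examples, not recommendations.

```python
# Sketch only: the common reject-offer default plus one of the more specific overrides.
from pyspark import SparkConf
from pyspark.sql import SparkSession

conf = (SparkConf()
        .set("spark.mesos.rejectOfferDuration", "120s")
        .set("spark.mesos.rejectOfferDurationForUnmetConstraints", "300s"))

spark = (SparkSession.builder
         .master("mesos://localhost:5050")  # placeholder master URL
         .config(conf=conf)
         .getOrCreate())
```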
diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala
index b9d098486b67..32967b04cd34 100644
--- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala
+++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala
@@ -53,19 +53,32 @@ class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext wi
override def start(): Unit = { ready = true }
}
scheduler.start()
+ scheduler.registered(driver, Utils.TEST_FRAMEWORK_ID, Utils.TEST_MASTER_INFO)
+ }
+
+ private def testDriverDescription(submissionId: String): MesosDriverDescription = {
+ new MesosDriverDescription(
+ "d1",
+ "jar",
+ 1000,
+ 1,
+ true,
+ command,
+ Map[String, String](),
+ submissionId,
+ new Date())
}
test("can queue drivers") {
setScheduler()
- val response = scheduler.submitDriver(
- new MesosDriverDescription("d1", "jar", 1000, 1, true,
- command, Map[String, String](), "s1", new Date()))
+ val response = scheduler.submitDriver(testDriverDescription("s1"))
assert(response.success)
- val response2 =
- scheduler.submitDriver(new MesosDriverDescription(
- "d1", "jar", 1000, 1, true, command, Map[String, String](), "s2", new Date()))
+ verify(driver, times(1)).reviveOffers()
+
+ val response2 = scheduler.submitDriver(testDriverDescription("s2"))
assert(response2.success)
+
val state = scheduler.getSchedulerState()
val queuedDrivers = state.queuedDrivers.toList
assert(queuedDrivers(0).submissionId == response.submissionId)
@@ -75,9 +88,7 @@ class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext wi
test("can kill queued drivers") {
setScheduler()
- val response = scheduler.submitDriver(
- new MesosDriverDescription("d1", "jar", 1000, 1, true,
- command, Map[String, String](), "s1", new Date()))
+ val response = scheduler.submitDriver(testDriverDescription("s1"))
assert(response.success)
val killResponse = scheduler.killDriver(response.submissionId)
assert(killResponse.success)
@@ -238,18 +249,10 @@ class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext wi
}
test("can kill supervised drivers") {
- val driver = mock[SchedulerDriver]
val conf = new SparkConf()
conf.setMaster("mesos://localhost:5050")
conf.setAppName("spark mesos")
- scheduler = new MesosClusterScheduler(
- new BlackHoleMesosClusterPersistenceEngineFactory, conf) {
- override def start(): Unit = {
- ready = true
- mesosDriver = driver
- }
- }
- scheduler.start()
+ setScheduler(conf.getAll.toMap)
val response = scheduler.submitDriver(
new MesosDriverDescription("d1", "jar", 100, 1, true, command,
@@ -291,4 +294,16 @@ class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext wi
assert(state.launchedDrivers.isEmpty)
assert(state.finishedDrivers.size == 1)
}
+
+ test("Declines offer with refuse seconds = 120.") {
+ setScheduler()
+
+ val filter = Filters.newBuilder().setRefuseSeconds(120).build()
+ val offerId = OfferID.newBuilder().setValue("o1").build()
+ val offer = Utils.createOffer(offerId.getValue, "s1", 1000, 1)
+
+ scheduler.resourceOffers(driver, Collections.singletonList(offer))
+
+ verify(driver, times(1)).declineOffer(offerId, filter)
+ }
}
diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala
index 78346e974495..98033bec6dd6 100644
--- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala
+++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala
@@ -552,17 +552,14 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite
override protected def getShuffleClient(): MesosExternalShuffleClient = shuffleClient
// override to avoid race condition with the driver thread on `mesosDriver`
- override def startScheduler(newDriver: SchedulerDriver): Unit = {
- mesosDriver = newDriver
- }
+ override def startScheduler(newDriver: SchedulerDriver): Unit = {}
override def stopExecutors(): Unit = {
stopCalled = true
}
-
- markRegistered()
}
backend.start()
+ backend.registered(driver, Utils.TEST_FRAMEWORK_ID, Utils.TEST_MASTER_INFO)
backend
}
diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/Utils.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/Utils.scala
index 7ebb294aa908..2a67cbc913ff 100644
--- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/Utils.scala
+++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/Utils.scala
@@ -28,6 +28,17 @@ import org.mockito.{ArgumentCaptor, Matchers}
import org.mockito.Mockito._
object Utils {
+
+ val TEST_FRAMEWORK_ID = FrameworkID.newBuilder()
+ .setValue("test-framework-id")
+ .build()
+
+ val TEST_MASTER_INFO = MasterInfo.newBuilder()
+ .setId("test-master")
+ .setIp(0)
+ .setPort(0)
+ .build()
+
def createOffer(
offerId: String,
slaveId: String,
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
index e86bd5459311..ccb0f8fdbbc2 100644
--- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
@@ -748,14 +748,6 @@ private[spark] class Client(
.map { case (k, v) => (k.substring(amEnvPrefix.length), v) }
.foreach { case (k, v) => YarnSparkHadoopUtil.addPathToEnvironment(env, k, v) }
- // Keep this for backwards compatibility but users should move to the config
- sys.env.get("SPARK_YARN_USER_ENV").foreach { userEnvs =>
- // Allow users to specify some environment variables.
- YarnSparkHadoopUtil.setEnvFromInputString(env, userEnvs)
- // Pass SPARK_YARN_USER_ENV itself to the AM so it can use it to set up executor environments.
- env("SPARK_YARN_USER_ENV") = userEnvs
- }
-
// If pyFiles contains any .py files, we need to add LOCALIZED_PYTHON_DIR to the PYTHONPATH
// of the container processes too. Add all non-.py files directly to PYTHONPATH.
//
@@ -782,35 +774,7 @@ private[spark] class Client(
sparkConf.setExecutorEnv("PYTHONPATH", pythonPathStr)
}
- // In cluster mode, if the deprecated SPARK_JAVA_OPTS is set, we need to propagate it to
- // executors. But we can't just set spark.executor.extraJavaOptions, because the driver's
- // SparkContext will not let that set spark* system properties, which is expected behavior for
- // Yarn clients. So propagate it through the environment.
- //
- // Note that to warn the user about the deprecation in cluster mode, some code from
- // SparkConf#validateSettings() is duplicated here (to avoid triggering the condition
- // described above).
if (isClusterMode) {
- sys.env.get("SPARK_JAVA_OPTS").foreach { value =>
- val warning =
- s"""
- |SPARK_JAVA_OPTS was detected (set to '$value').
- |This is deprecated in Spark 1.0+.
- |
- |Please instead use:
- | - ./spark-submit with conf/spark-defaults.conf to set defaults for an application
- | - ./spark-submit with --driver-java-options to set -X options for a driver
- | - spark.executor.extraJavaOptions to set -X options for executors
- """.stripMargin
- logWarning(warning)
- for (proc <- Seq("driver", "executor")) {
- val key = s"spark.$proc.extraJavaOptions"
- if (sparkConf.contains(key)) {
- throw new SparkException(s"Found both $key and SPARK_JAVA_OPTS. Use only the former.")
- }
- }
- env("SPARK_JAVA_OPTS") = value
- }
// propagate PYSPARK_DRIVER_PYTHON and PYSPARK_PYTHON to driver in cluster mode
Seq("PYSPARK_DRIVER_PYTHON", "PYSPARK_PYTHON").foreach { envname =>
if (!env.contains(envname)) {
@@ -883,8 +847,7 @@ private[spark] class Client(
// Include driver-specific java options if we are launching a driver
if (isClusterMode) {
- val driverOpts = sparkConf.get(DRIVER_JAVA_OPTIONS).orElse(sys.env.get("SPARK_JAVA_OPTS"))
- driverOpts.foreach { opts =>
+ sparkConf.get(DRIVER_JAVA_OPTIONS).foreach { opts =>
javaOpts ++= Utils.splitCommandString(opts).map(YarnSparkHadoopUtil.escapeForShell)
}
val libraryPaths = Seq(sparkConf.get(DRIVER_LIBRARY_PATH),
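With the SPARK_JAVA_OPTS and SPARK_YARN_USER_ENV code paths removed above, the config-based equivalents named in the deleted warning remain the supported route. A hedged PySpark sketch with placeholder values; `spark.yarn.appMasterEnv.*` appears in the surrounding code, and `spark.executorEnv.*` is its standard executor-side counterpart.

```python
# Sketch only: config replacements for the removed environment variables.
from pyspark import SparkConf

conf = (SparkConf()
        # formerly SPARK_JAVA_OPTS
        .set("spark.driver.extraJavaOptions", "-XX:+UseG1GC")
        .set("spark.executor.extraJavaOptions", "-XX:+UseG1GC")
        # formerly SPARK_YARN_USER_ENV (MY_VAR is a placeholder name)
        .set("spark.yarn.appMasterEnv.MY_VAR", "some-value")
        .set("spark.executorEnv.MY_VAR", "some-value"))
```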
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala
index ee85c043b8bc..3f4d236571ff 100644
--- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala
@@ -143,9 +143,6 @@ private[yarn] class ExecutorRunnable(
sparkConf.get(EXECUTOR_JAVA_OPTIONS).foreach { opts =>
javaOpts ++= Utils.splitCommandString(opts).map(YarnSparkHadoopUtil.escapeForShell)
}
- sys.env.get("SPARK_JAVA_OPTS").foreach { opts =>
- javaOpts ++= Utils.splitCommandString(opts).map(YarnSparkHadoopUtil.escapeForShell)
- }
sparkConf.get(EXECUTOR_LIBRARY_PATH).foreach { p =>
prefixEnv = Some(Client.getClusterPath(sparkConf, Utils.libraryPathEnvPrefix(Seq(p))))
}
@@ -229,11 +226,6 @@ private[yarn] class ExecutorRunnable(
YarnSparkHadoopUtil.addPathToEnvironment(env, key, value)
}
- // Keep this for backwards compatibility but users should move to the config
- sys.env.get("SPARK_YARN_USER_ENV").foreach { userEnvs =>
- YarnSparkHadoopUtil.setEnvFromInputString(env, userEnvs)
- }
-
// lookup appropriate http scheme for container log urls
val yarnHttpPolicy = conf.get(
YarnConfiguration.YARN_HTTP_POLICY_KEY,
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/CredentialUpdater.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/CredentialUpdater.scala
index 2fdb70a73c75..41b7b5d60b03 100644
--- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/CredentialUpdater.scala
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/CredentialUpdater.scala
@@ -60,7 +60,7 @@ private[spark] class CredentialUpdater(
if (remainingTime <= 0) {
credentialUpdater.schedule(credentialUpdaterRunnable, 1, TimeUnit.MINUTES)
} else {
- logInfo(s"Scheduling credentials refresh from HDFS in $remainingTime millis.")
+ logInfo(s"Scheduling credentials refresh from HDFS in $remainingTime ms.")
credentialUpdater.schedule(credentialUpdaterRunnable, remainingTime, TimeUnit.MILLISECONDS)
}
}
@@ -81,8 +81,8 @@ private[spark] class CredentialUpdater(
UserGroupInformation.getCurrentUser.addCredentials(newCredentials)
logInfo("Credentials updated from credentials file.")
- val remainingTime = getTimeOfNextUpdateFromFileName(credentialsStatus.getPath)
- - System.currentTimeMillis()
+ val remainingTime = (getTimeOfNextUpdateFromFileName(credentialsStatus.getPath)
+ - System.currentTimeMillis())
if (remainingTime <= 0) TimeUnit.MINUTES.toMillis(1) else remainingTime
} else {
// If current credential file is older than expected, sleep 1 hour and check again.
@@ -100,6 +100,7 @@ private[spark] class CredentialUpdater(
TimeUnit.HOURS.toMillis(1)
}
+ logInfo(s"Scheduling credentials refresh from HDFS in $timeToNextUpdate ms.")
credentialUpdater.schedule(
credentialUpdaterRunnable, timeToNextUpdate, TimeUnit.MILLISECONDS)
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala
index 5f50ce1ba68f..cff0efa97993 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala
@@ -60,6 +60,14 @@ trait CatalystConf {
* Enables CBO for estimation of plan statistics when set true.
*/
def cboEnabled: Boolean
+
+ /** Enables join reorder in CBO. */
+ def joinReorderEnabled: Boolean
+
+ /** The maximum number of joined nodes allowed in the dynamic programming algorithm. */
+ def joinReorderDPThreshold: Int
+
+ override def clone(): CatalystConf = throw new CloneNotSupportedException()
}
@@ -75,6 +83,11 @@ case class SimpleCatalystConf(
runSQLonFile: Boolean = true,
crossJoinEnabled: Boolean = false,
cboEnabled: Boolean = false,
+ joinReorderEnabled: Boolean = false,
+ joinReorderDPThreshold: Int = 12,
warehousePath: String = "/user/hive/warehouse",
sessionLocalTimeZone: String = TimeZone.getDefault().getID)
- extends CatalystConf
+ extends CatalystConf {
+
+ override def clone(): SimpleCatalystConf = this.copy()
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
index 5b9161551a7a..d4ebdb139fe0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
@@ -310,11 +310,7 @@ object CatalystTypeConverters {
case d: JavaBigInteger => Decimal(d)
case d: Decimal => d
}
- if (decimal.changePrecision(dataType.precision, dataType.scale)) {
- decimal
- } else {
- null
- }
+ decimal.toPrecision(dataType.precision, dataType.scale).orNull
}
override def toScala(catalystValue: Decimal): JavaBigDecimal = {
if (catalystValue == null) null
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 6d569b612de7..93666f14958e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -117,6 +117,8 @@ class Analyzer(
Batch("Hints", fixedPoint,
new ResolveHints.ResolveBroadcastHints(conf),
ResolveHints.RemoveAllHints),
+ Batch("Simple Sanity Check", Once,
+ LookupFunctions),
Batch("Substitution", fixedPoint,
CTESubstitution,
WindowsSubstitution,
@@ -596,7 +598,7 @@ class Analyzer(
execute(child)
}
view.copy(child = newChild)
- case p @ SubqueryAlias(_, view: View, _) =>
+ case p @ SubqueryAlias(_, view: View) =>
val newChild = resolveRelation(view)
p.copy(child = newChild)
case _ => plan
@@ -604,7 +606,11 @@ class Analyzer(
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case i @ InsertIntoTable(u: UnresolvedRelation, parts, child, _, _) if child.resolved =>
- i.copy(table = EliminateSubqueryAliases(lookupTableFromCatalog(u)))
+ lookupTableFromCatalog(u).canonicalized match {
+ case v: View =>
+ u.failAnalysis(s"Inserting into a view is not allowed. View: ${v.desc.identifier}.")
+ case other => i.copy(table = other)
+ }
case u: UnresolvedRelation => resolveRelation(u)
}
@@ -1038,6 +1044,25 @@ class Analyzer(
}
}
+ /**
+ * Checks whether a function identifier referenced by an [[UnresolvedFunction]] is defined in the
+ * function registry. Note that this rule doesn't try to resolve the [[UnresolvedFunction]]. It
+ * only performs simple existence check according to the function identifier to quickly identify
+ * undefined functions without triggering relation resolution, which may incur potentially
+ * expensive partition/schema discovery process in some cases.
+ *
+ * @see [[ResolveFunctions]]
+ * @see https://issues.apache.org/jira/browse/SPARK-19737
+ */
+ object LookupFunctions extends Rule[LogicalPlan] {
+ override def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressions {
+ case f: UnresolvedFunction if !catalog.functionExists(f.name) =>
+ withPosition(f) {
+ throw new NoSuchFunctionException(f.name.database.getOrElse("default"), f.name.funcName)
+ }
+ }
+ }
+
/**
* Replaces [[UnresolvedFunction]]s with concrete [[Expression]]s.
*/
@@ -2338,7 +2363,7 @@ class Analyzer(
*/
object EliminateSubqueryAliases extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
- case SubqueryAlias(_, child, _) => child
+ case SubqueryAlias(_, child) => child
}
}
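From the user's side, the new LookupFunctions batch means a reference to an undefined function fails analysis up front, before any relation resolution. A rough PySpark sketch; the function name is made up and `spark` is an assumed active SparkSession.

```python
# Sketch only: an unregistered function name now fails fast at analysis time,
# typically surfacing in PySpark as an AnalysisException.
try:
    spark.sql("SELECT definitely_not_a_function(1)").collect()
except Exception as e:
    print(e)  # reports the undefined function
```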
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 7529f9028498..d32fbeb4e91e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -44,6 +44,18 @@ trait CheckAnalysis extends PredicateHelper {
}).length > 1
}
+ protected def hasMapType(dt: DataType): Boolean = {
+ dt.existsRecursively(_.isInstanceOf[MapType])
+ }
+
+ protected def mapColumnInSetOperation(plan: LogicalPlan): Option[Attribute] = plan match {
+ case _: Intersect | _: Except | _: Distinct =>
+ plan.output.find(a => hasMapType(a.dataType))
+ case d: Deduplicate =>
+ d.keys.find(a => hasMapType(a.dataType))
+ case _ => None
+ }
+
private def checkLimitClause(limitExpr: Expression): Unit = {
limitExpr match {
case e if !e.foldable => failAnalysis(
@@ -121,8 +133,7 @@ trait CheckAnalysis extends PredicateHelper {
if (conditions.isEmpty && query.output.size != 1) {
failAnalysis(
s"Scalar subquery must return only one column, but got ${query.output.size}")
- }
- else if (conditions.nonEmpty) {
+ } else if (conditions.nonEmpty) {
// Collect the columns from the subquery for further checking.
var subqueryColumns = conditions.flatMap(_.references).filter(query.output.contains)
@@ -200,7 +211,7 @@ trait CheckAnalysis extends PredicateHelper {
s"filter expression '${f.condition.sql}' " +
s"of type ${f.condition.dataType.simpleString} is not a boolean.")
- case f @ Filter(condition, child) =>
+ case Filter(condition, _) =>
splitConjunctivePredicates(condition).foreach {
case _: PredicateSubquery | Not(_: PredicateSubquery) =>
case e if PredicateSubquery.hasNullAwarePredicateWithinNot(e) =>
@@ -374,6 +385,14 @@ trait CheckAnalysis extends PredicateHelper {
|Conflicting attributes: ${conflictingAttributes.mkString(",")}
""".stripMargin)
+ // TODO: although map type is not orderable, technically map type should be able to be
+ // used in equality comparison, remove this type check once we support it.
+ case o if mapColumnInSetOperation(o).isDefined =>
+ val mapCol = mapColumnInSetOperation(o).get
+ failAnalysis("Cannot have map type columns in DataFrame which calls " +
+ s"set operations(intersect, except, etc.), but the type of column ${mapCol.name} " +
+ "is " + mapCol.dataType.simpleString)
+
case o if o.expressions.exists(!_.deterministic) &&
!o.isInstanceOf[Project] && !o.isInstanceOf[Filter] &&
!o.isInstanceOf[Aggregate] && !o.isInstanceOf[Window] =>
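Seen from the API, the new check above rejects set-like operations on DataFrames with map-typed columns. A small hedged sketch, assuming a SparkSession `spark` running with this change applied:

```python
# Sketch only: intersect on a DataFrame with a MapType column now fails analysis.
from pyspark.sql.utils import AnalysisException

df = spark.createDataFrame([({"k": 1},)], ["m"])  # column `m` infers as a map
try:
    df.intersect(df).count()
except AnalysisException as e:
    print(e)  # mentions map type columns and set operations
```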
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 9c9465f6b8de..0dcb44081f60 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -64,6 +64,8 @@ trait FunctionRegistry {
/** Clear all registered functions. */
def clear(): Unit
+  /** Create a copy of this registry containing the same registered functions. */
+ override def clone(): FunctionRegistry = throw new CloneNotSupportedException()
}
class SimpleFunctionRegistry extends FunctionRegistry {
@@ -107,7 +109,7 @@ class SimpleFunctionRegistry extends FunctionRegistry {
functionBuilders.clear()
}
- def copy(): SimpleFunctionRegistry = synchronized {
+ override def clone(): SimpleFunctionRegistry = synchronized {
val registry = new SimpleFunctionRegistry
functionBuilders.iterator.foreach { case (name, (info, builder)) =>
registry.registerFunction(name, info, builder)
@@ -150,6 +152,7 @@ object EmptyFunctionRegistry extends FunctionRegistry {
throw new UnsupportedOperationException
}
+ override def clone(): FunctionRegistry = this
}
@@ -421,6 +424,9 @@ object FunctionRegistry {
expression[BitwiseOr]("|"),
expression[BitwiseXor]("^"),
+ // json
+ expression[StructToJson]("to_json"),
+
// Cast aliases (SPARK-16730)
castAlias("boolean", BooleanType),
castAlias("tinyint", ByteType),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
index 397f5cfe2a54..a9ff61e0e880 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
@@ -51,6 +51,37 @@ object UnsupportedOperationChecker {
subplan.collect { case a: Aggregate if a.isStreaming => a }
}
+ val mapGroupsWithStates = plan.collect {
+ case f: FlatMapGroupsWithState if f.isStreaming && f.isMapGroupsWithState => f
+ }
+
+ // Disallow multiple `mapGroupsWithState`s.
+ if (mapGroupsWithStates.size >= 2) {
+ throwError(
+ "Multiple mapGroupsWithStates are not supported on a streaming DataFrames/Datasets")(plan)
+ }
+
+ val flatMapGroupsWithStates = plan.collect {
+ case f: FlatMapGroupsWithState if f.isStreaming && !f.isMapGroupsWithState => f
+ }
+
+ // Disallow mixing `mapGroupsWithState`s and `flatMapGroupsWithState`s
+ if (mapGroupsWithStates.nonEmpty && flatMapGroupsWithStates.nonEmpty) {
+ throwError(
+ "Mixing mapGroupsWithStates and flatMapGroupsWithStates are not supported on a " +
+ "streaming DataFrames/Datasets")(plan)
+ }
+
+ // Only allow multiple `FlatMapGroupsWithState(Append)`s in append mode.
+ if (flatMapGroupsWithStates.size >= 2 && (
+ outputMode != InternalOutputModes.Append ||
+ flatMapGroupsWithStates.exists(_.outputMode != InternalOutputModes.Append)
+ )) {
+ throwError(
+ "Multiple flatMapGroupsWithStates are not supported when they are not all in append mode" +
+ " or the output mode is not append on a streaming DataFrames/Datasets")(plan)
+ }
+
// Disallow multiple streaming aggregations
val aggregates = collectStreamingAggregates(plan)
@@ -116,9 +147,49 @@ object UnsupportedOperationChecker {
throwError("Commands like CreateTable*, AlterTable*, Show* are not supported with " +
"streaming DataFrames/Datasets")
- case m: MapGroupsWithState if collectStreamingAggregates(m).nonEmpty =>
- throwError("(map/flatMap)GroupsWithState is not supported after aggregation on a " +
- "streaming DataFrame/Dataset")
+ // mapGroupsWithState: Allowed only when no aggregation + Update output mode
+ case m: FlatMapGroupsWithState if m.isStreaming && m.isMapGroupsWithState =>
+ if (collectStreamingAggregates(plan).isEmpty) {
+ if (outputMode != InternalOutputModes.Update) {
+ throwError("mapGroupsWithState is not supported with " +
+ s"$outputMode output mode on a streaming DataFrame/Dataset")
+ } else {
+ // Allowed when no aggregation + Update output mode
+ }
+ } else {
+ throwError("mapGroupsWithState is not supported with aggregation " +
+ "on a streaming DataFrame/Dataset")
+ }
+
+ // flatMapGroupsWithState without aggregation
+ case m: FlatMapGroupsWithState
+ if m.isStreaming && collectStreamingAggregates(plan).isEmpty =>
+ m.outputMode match {
+ case InternalOutputModes.Update =>
+ if (outputMode != InternalOutputModes.Update) {
+ throwError("flatMapGroupsWithState in update mode is not supported with " +
+ s"$outputMode output mode on a streaming DataFrame/Dataset")
+ }
+ case InternalOutputModes.Append =>
+ if (outputMode != InternalOutputModes.Append) {
+ throwError("flatMapGroupsWithState in append mode is not supported with " +
+ s"$outputMode output mode on a streaming DataFrame/Dataset")
+ }
+ }
+
+ // flatMapGroupsWithState(Update) with aggregation
+ case m: FlatMapGroupsWithState
+ if m.isStreaming && m.outputMode == InternalOutputModes.Update
+ && collectStreamingAggregates(plan).nonEmpty =>
+ throwError("flatMapGroupsWithState in update mode is not supported with " +
+ "aggregation on a streaming DataFrame/Dataset")
+
+ // flatMapGroupsWithState(Append) with aggregation
+ case m: FlatMapGroupsWithState
+ if m.isStreaming && m.outputMode == InternalOutputModes.Append
+ && collectStreamingAggregates(m).nonEmpty =>
+ throwError("flatMapGroupsWithState in append mode is not supported after " +
+ s"aggregation on a streaming DataFrame/Dataset")
case d: Deduplicate if collectStreamingAggregates(d).nonEmpty =>
throwError("dropDuplicates is not supported after aggregation on a " +
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala
index 31eded4deba7..08a01e860189 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.catalog
import org.apache.spark.sql.catalyst.analysis.{FunctionAlreadyExistsException, NoSuchDatabaseException, NoSuchFunctionException, NoSuchTableException}
import org.apache.spark.sql.catalyst.expressions.Expression
-
+import org.apache.spark.sql.types.StructType
/**
* Interface for the system catalog (of functions, partitions, tables, and databases).
@@ -104,6 +104,19 @@ abstract class ExternalCatalog {
*/
def alterTable(tableDefinition: CatalogTable): Unit
+ /**
+ * Alter the schema of a table identified by the provided database and table name. The new schema
+ * should still contain the existing bucket columns and partition columns used by the table. This
+ * method will also update any Spark SQL-related parameters stored as Hive table properties (such
+ * as the schema itself).
+ *
+   * @param db Name of the database that the table to alter belongs to
+   * @param table Name of the table whose schema is to be altered
+ * @param schema Updated schema to be used for the table (must contain existing partition and
+ * bucket columns)
+ */
+ def alterTableSchema(db: String, table: String, schema: StructType): Unit
+
def getTable(db: String, table: String): CatalogTable
def getTableOption(db: String, table: String): Option[CatalogTable]
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala
index 58ced549bafe..a418edc302d9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.catalyst.catalog
+import java.net.URI
+
import org.apache.hadoop.fs.Path
import org.apache.hadoop.util.Shell
@@ -162,6 +164,30 @@ object CatalogUtils {
BucketSpec(numBuckets, normalizedBucketCols, normalizedSortCols)
}
+ /**
+ * Convert URI to String.
+   * Since URI.toString does not decode the URI (e.g. it does not turn '%25' back into '%'),
+   * we create a Hadoop Path from the given URI and rely on Path.toString
+   * to do the decoding.
+ * @param uri the URI of the path
+ * @return the String of the path
+ */
+ def URIToString(uri: URI): String = {
+ new Path(uri).toString
+ }
+
+ /**
+ * Convert String to URI.
+   * Since new URI(string) does not encode the string (e.g. it does not turn '%' into '%25'),
+   * we create a Hadoop Path from the given String and rely on Path.toUri
+   * to do the encoding.
+ * @param str the String of the path
+ * @return the URI of the path
+ */
+ def stringToURI(str: String): URI = {
+ new Path(str).toUri
+ }
+
private def normalizeColumnName(
tableName: String,
tableCols: Seq[String],
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala
index 340e8451f14e..5cc6b0abc6fd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala
@@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.escapePathName
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.util.StringUtils
+import org.apache.spark.sql.types.StructType
/**
* An in-memory (ephemeral) implementation of the system catalog.
@@ -202,7 +203,7 @@ class InMemoryCatalog(
tableDefinition.storage.locationUri.isEmpty
val tableWithLocation = if (needDefaultTableLocation) {
- val defaultTableLocation = new Path(catalog(db).db.locationUri, table)
+ val defaultTableLocation = new Path(new Path(catalog(db).db.locationUri), table)
try {
val fs = defaultTableLocation.getFileSystem(hadoopConfig)
fs.mkdirs(defaultTableLocation)
@@ -211,7 +212,7 @@ class InMemoryCatalog(
throw new SparkException(s"Unable to create table $table as failed " +
s"to create its directory $defaultTableLocation", e)
}
- tableDefinition.withNewStorage(locationUri = Some(defaultTableLocation.toUri.toString))
+ tableDefinition.withNewStorage(locationUri = Some(defaultTableLocation.toUri))
} else {
tableDefinition
}
@@ -274,7 +275,7 @@ class InMemoryCatalog(
"Managed table should always have table location, as we will assign a default location " +
"to it if it doesn't have one.")
val oldDir = new Path(oldDesc.table.location)
- val newDir = new Path(catalog(db).db.locationUri, newName)
+ val newDir = new Path(new Path(catalog(db).db.locationUri), newName)
try {
val fs = oldDir.getFileSystem(hadoopConfig)
fs.rename(oldDir, newDir)
@@ -283,7 +284,7 @@ class InMemoryCatalog(
throw new SparkException(s"Unable to rename table $oldName to $newName as failed " +
s"to rename its directory $oldDir", e)
}
- oldDesc.table = oldDesc.table.withNewStorage(locationUri = Some(newDir.toUri.toString))
+ oldDesc.table = oldDesc.table.withNewStorage(locationUri = Some(newDir.toUri))
}
catalog(db).tables.put(newName, oldDesc)
@@ -297,6 +298,15 @@ class InMemoryCatalog(
catalog(db).tables(tableDefinition.identifier.table).table = tableDefinition
}
+ override def alterTableSchema(
+ db: String,
+ table: String,
+ schema: StructType): Unit = synchronized {
+ requireTableExists(db, table)
+ val origTable = catalog(db).tables(table).table
+ catalog(db).tables(table).table = origTable.copy(schema = schema)
+ }
+
override def getTable(db: String, table: String): CatalogTable = synchronized {
requireTableExists(db, table)
catalog(db).tables(table).table
@@ -389,7 +399,7 @@ class InMemoryCatalog(
existingParts.put(
p.spec,
- p.copy(storage = p.storage.copy(locationUri = Some(partitionPath.toString))))
+ p.copy(storage = p.storage.copy(locationUri = Some(partitionPath.toUri))))
}
}
@@ -462,7 +472,7 @@ class InMemoryCatalog(
}
oldPartition.copy(
spec = newSpec,
- storage = oldPartition.storage.copy(locationUri = Some(newPartPath.toString)))
+ storage = oldPartition.storage.copy(locationUri = Some(newPartPath.toUri)))
} else {
oldPartition.copy(spec = newSpec)
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
index f6412e42c13d..bfcdb70fe47c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.catalyst.catalog
+import java.net.URI
import javax.annotation.concurrent.GuardedBy
import scala.collection.mutable
@@ -49,7 +50,6 @@ object SessionCatalog {
class SessionCatalog(
externalCatalog: ExternalCatalog,
globalTempViewManager: GlobalTempViewManager,
- functionResourceLoader: FunctionResourceLoader,
functionRegistry: FunctionRegistry,
conf: CatalystConf,
hadoopConf: Configuration,
@@ -65,16 +65,19 @@ class SessionCatalog(
this(
externalCatalog,
new GlobalTempViewManager("global_temp"),
- DummyFunctionResourceLoader,
functionRegistry,
conf,
new Configuration(),
CatalystSqlParser)
+ functionResourceLoader = DummyFunctionResourceLoader
}
// For testing only.
def this(externalCatalog: ExternalCatalog) {
- this(externalCatalog, new SimpleFunctionRegistry, new SimpleCatalystConf(true))
+ this(
+ externalCatalog,
+ new SimpleFunctionRegistry,
+ SimpleCatalystConf(caseSensitiveAnalysis = true))
}
/** List of temporary tables, mapping from table name to their logical plan. */
@@ -88,6 +91,8 @@ class SessionCatalog(
@GuardedBy("this")
protected var currentDb = formatDatabaseName(DEFAULT_DATABASE)
+ @volatile var functionResourceLoader: FunctionResourceLoader = _
+
/**
* Checks if the given name conforms the Hive standard ("[a-zA-z_0-9]+"),
* i.e. if this name only contains characters, numbers, and _.
@@ -131,10 +136,10 @@ class SessionCatalog(
* does not contain a scheme, this path will not be changed after the default
* FileSystem is changed.
*/
- private def makeQualifiedPath(path: String): Path = {
+ private def makeQualifiedPath(path: URI): URI = {
val hadoopPath = new Path(path)
val fs = hadoopPath.getFileSystem(hadoopConf)
- fs.makeQualified(hadoopPath)
+ fs.makeQualified(hadoopPath).toUri
}
private def requireDbExists(db: String): Unit = {
@@ -170,7 +175,7 @@ class SessionCatalog(
"you cannot create a database with this name.")
}
validateName(dbName)
- val qualifiedPath = makeQualifiedPath(dbDefinition.locationUri).toString
+ val qualifiedPath = makeQualifiedPath(dbDefinition.locationUri)
externalCatalog.createDatabase(
dbDefinition.copy(name = dbName, locationUri = qualifiedPath),
ignoreIfExists)
@@ -228,9 +233,9 @@ class SessionCatalog(
* Get the path for creating a non-default database when database location is not provided
* by users.
*/
- def getDefaultDBPath(db: String): String = {
+ def getDefaultDBPath(db: String): URI = {
val database = formatDatabaseName(db)
- new Path(new Path(conf.warehousePath), database + ".db").toString
+ new Path(new Path(conf.warehousePath), database + ".db").toUri
}
// ----------------------------------------------------------------------------
@@ -254,7 +259,19 @@ class SessionCatalog(
val db = formatDatabaseName(tableDefinition.identifier.database.getOrElse(getCurrentDatabase))
val table = formatTableName(tableDefinition.identifier.table)
validateName(table)
- val newTableDefinition = tableDefinition.copy(identifier = TableIdentifier(table, Some(db)))
+
+ val newTableDefinition = if (tableDefinition.storage.locationUri.isDefined
+ && !tableDefinition.storage.locationUri.get.isAbsolute) {
+ // make the location of the table qualified.
+ val qualifiedTableLocation =
+ makeQualifiedPath(tableDefinition.storage.locationUri.get)
+ tableDefinition.copy(
+ storage = tableDefinition.storage.copy(locationUri = Some(qualifiedTableLocation)),
+ identifier = TableIdentifier(table, Some(db)))
+ } else {
+ tableDefinition.copy(identifier = TableIdentifier(table, Some(db)))
+ }
+
requireDbExists(db)
externalCatalog.createTable(newTableDefinition, ignoreIfExists)
}
@@ -351,11 +368,11 @@ class SessionCatalog(
db, table, loadPath, spec, isOverwrite, inheritTableSpecs, isSrcLocal)
}
- def defaultTablePath(tableIdent: TableIdentifier): String = {
+ def defaultTablePath(tableIdent: TableIdentifier): URI = {
val dbName = formatDatabaseName(tableIdent.database.getOrElse(getCurrentDatabase))
val dbLocation = getDatabaseMetadata(dbName).locationUri
- new Path(new Path(dbLocation), formatTableName(tableIdent.table)).toString
+ new Path(new Path(dbLocation), formatTableName(tableIdent.table)).toUri
}
// ----------------------------------------------
@@ -577,7 +594,7 @@ class SessionCatalog(
val table = formatTableName(name.table)
if (db == globalTempViewManager.database) {
globalTempViewManager.get(table).map { viewDef =>
- SubqueryAlias(table, viewDef, None)
+ SubqueryAlias(table, viewDef)
}.getOrElse(throw new NoSuchTableException(db, table))
} else if (name.database.isDefined || !tempTables.contains(table)) {
val metadata = externalCatalog.getTable(db, table)
@@ -590,17 +607,17 @@ class SessionCatalog(
desc = metadata,
output = metadata.schema.toAttributes,
child = parser.parsePlan(viewText))
- SubqueryAlias(table, child, Some(name.copy(table = table, database = Some(db))))
+ SubqueryAlias(table, child)
} else {
val tableRelation = CatalogRelation(
metadata,
// we assume all the columns are nullable.
metadata.dataSchema.asNullable.toAttributes,
metadata.partitionSchema.asNullable.toAttributes)
- SubqueryAlias(table, tableRelation, None)
+ SubqueryAlias(table, tableRelation)
}
} else {
- SubqueryAlias(table, tempTables(table), None)
+ SubqueryAlias(table, tempTables(table))
}
}
}
@@ -986,6 +1003,9 @@ class SessionCatalog(
* by a tuple (resource type, resource uri).
*/
def loadFunctionResources(resources: Seq[FunctionResource]): Unit = {
+ if (functionResourceLoader == null) {
+ throw new IllegalStateException("functionResourceLoader has not yet been initialized")
+ }
resources.foreach(functionResourceLoader.loadResource)
}
@@ -1181,4 +1201,29 @@ class SessionCatalog(
}
}
+ /**
+ * Create a new [[SessionCatalog]] with the provided parameters. `externalCatalog` and
+ * `globalTempViewManager` are inherited, while `currentDb` and `tempTables` are copied.
+ */
+ def newSessionCatalogWith(
+ conf: CatalystConf,
+ hadoopConf: Configuration,
+ functionRegistry: FunctionRegistry,
+ parser: ParserInterface): SessionCatalog = {
+ val catalog = new SessionCatalog(
+ externalCatalog,
+ globalTempViewManager,
+ functionRegistry,
+ conf,
+ hadoopConf,
+ parser)
+
+ synchronized {
+ catalog.currentDb = currentDb
+ // copy over temporary tables
+ tempTables.foreach(kv => catalog.tempTables.put(kv._1, kv._2))
+ }
+
+ catalog
+ }
}
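For illustration, a minimal sketch of what `makeQualifiedPath` now does with the URI-typed location (assuming the default local `FileSystem`; the path value is made up):

    import java.net.URI
    import org.apache.hadoop.conf.Configuration
    import org.apache.hadoop.fs.Path

    val hadoopConf = new Configuration()
    val raw = new URI("/user/warehouse/db1.db")   // scheme-less database location
    val p = new Path(raw)
    // Qualification adds the FileSystem's scheme/authority, e.g. file:/user/warehouse/db1.db
    val qualified: URI = p.getFileSystem(hadoopConf).makeQualified(p).toUri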
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala
index 887caf07d148..e3631b0c0773 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.catalyst.catalog
+import java.net.URI
import java.util.Date
import com.google.common.base.Objects
@@ -48,10 +49,7 @@ case class CatalogFunction(
* Storage format, used to describe how a partition or a table is stored.
*/
case class CatalogStorageFormat(
- // TODO(ekl) consider storing this field as java.net.URI for type safety. Note that this must
- // be converted to/from a hadoop Path object using new Path(new URI(locationUri)) and
- // path.toUri respectively before use as a filesystem path due to URI char escaping.
- locationUri: Option[String],
+ locationUri: Option[URI],
inputFormat: Option[String],
outputFormat: Option[String],
serde: Option[String],
@@ -105,7 +103,7 @@ case class CatalogTablePartition(
}
/** Return the partition location, assuming it is specified. */
- def location: String = storage.locationUri.getOrElse {
+ def location: URI = storage.locationUri.getOrElse {
val specString = spec.map { case (k, v) => s"$k=$v" }.mkString(", ")
throw new AnalysisException(s"Partition [$specString] did not specify locationUri")
}
@@ -165,6 +163,11 @@ case class BucketSpec(
* @param tracksPartitionsInCatalog whether this table's partition metadata is stored in the
* catalog. If false, it is inferred automatically based on file
* structure.
+ * @param schemaPreservesCase Whether or not the schema resolved for this table is case-sensitive.
+ * When using a Hive Metastore, this flag is set to false if a case-
+ * sensitive schema could not be read from the table properties.
+ * Used to trigger case-sensitive schema inference at query time, when
+ * configured.
*/
case class CatalogTable(
identifier: TableIdentifier,
@@ -182,7 +185,8 @@ case class CatalogTable(
viewText: Option[String] = None,
comment: Option[String] = None,
unsupportedFeatures: Seq[String] = Seq.empty,
- tracksPartitionsInCatalog: Boolean = false) {
+ tracksPartitionsInCatalog: Boolean = false,
+ schemaPreservesCase: Boolean = true) {
import CatalogTable._
@@ -210,7 +214,7 @@ case class CatalogTable(
}
/** Return the table location, assuming it is specified. */
- def location: String = storage.locationUri.getOrElse {
+ def location: URI = storage.locationUri.getOrElse {
throw new AnalysisException(s"table $identifier did not specify locationUri")
}
@@ -241,7 +245,7 @@ case class CatalogTable(
/** Syntactic sugar to update a field in `storage`. */
def withNewStorage(
- locationUri: Option[String] = storage.locationUri,
+ locationUri: Option[URI] = storage.locationUri,
inputFormat: Option[String] = storage.inputFormat,
outputFormat: Option[String] = storage.outputFormat,
compressed: Boolean = false,
@@ -337,7 +341,7 @@ object CatalogTableType {
case class CatalogDatabase(
name: String,
description: String,
- locationUri: String,
+ locationUri: URI,
properties: Map[String, String])
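A small sketch of the `Path`/`URI` round trip that the `URI`-typed `locationUri` relies on (the path here is made up; special characters stay escaped inside the URI):

    import java.net.URI
    import org.apache.hadoop.fs.Path

    val location: URI = new Path("/warehouse/db1.db/t space").toUri  // space escaped as %20
    val asPath: Path = new Path(location)                            // new Path(URI) preserves escaping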
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index c062e4e84bcd..35ca2a0aa53a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -346,7 +346,7 @@ package object dsl {
orderSpec: Seq[SortOrder]): LogicalPlan =
Window(windowExpressions, partitionSpec, orderSpec, logicalPlan)
- def subquery(alias: Symbol): LogicalPlan = SubqueryAlias(alias.name, logicalPlan, None)
+ def subquery(alias: Symbol): LogicalPlan = SubqueryAlias(alias.name, logicalPlan)
def except(otherPlan: LogicalPlan): LogicalPlan = Except(logicalPlan, otherPlan)
@@ -368,7 +368,10 @@ package object dsl {
analysis.UnresolvedRelation(TableIdentifier(tableName)),
Map.empty, logicalPlan, overwrite, false)
- def as(alias: String): LogicalPlan = SubqueryAlias(alias, logicalPlan, None)
+ def as(alias: String): LogicalPlan = SubqueryAlias(alias, logicalPlan)
+
+ def coalesce(num: Integer): LogicalPlan =
+ Repartition(num, shuffle = false, logicalPlan)
def repartition(num: Integer): LogicalPlan =
Repartition(num, shuffle = true, logicalPlan)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index 0782143d465b..93fc565a5341 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -45,8 +45,8 @@ import org.apache.spark.util.Utils
object ExpressionEncoder {
def apply[T : TypeTag](): ExpressionEncoder[T] = {
// We convert the not-serializable TypeTag into StructType and ClassTag.
- val mirror = typeTag[T].mirror
- val tpe = typeTag[T].tpe
+ val mirror = ScalaReflection.mirror
+ val tpe = typeTag[T].in(mirror).tpe
if (ScalaReflection.optionOfProductType(tpe)) {
throw new UnsupportedOperationException(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index a36d3507d92e..1049915986d9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
-
+import org.apache.spark.unsafe.types.UTF8String.{IntWrapper, LongWrapper}
object Cast {
@@ -277,9 +277,8 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
// LongConverter
private[this] def castToLong(from: DataType): Any => Any = from match {
case StringType =>
- buildCast[UTF8String](_, s => try s.toLong catch {
- case _: NumberFormatException => null
- })
+ val result = new LongWrapper()
+ buildCast[UTF8String](_, s => if (s.toLong(result)) result.value else null)
case BooleanType =>
buildCast[Boolean](_, b => if (b) 1L else 0L)
case DateType =>
@@ -293,9 +292,8 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
// IntConverter
private[this] def castToInt(from: DataType): Any => Any = from match {
case StringType =>
- buildCast[UTF8String](_, s => try s.toInt catch {
- case _: NumberFormatException => null
- })
+ val result = new IntWrapper()
+ buildCast[UTF8String](_, s => if (s.toInt(result)) result.value else null)
case BooleanType =>
buildCast[Boolean](_, b => if (b) 1 else 0)
case DateType =>
@@ -309,8 +307,11 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
// ShortConverter
private[this] def castToShort(from: DataType): Any => Any = from match {
case StringType =>
- buildCast[UTF8String](_, s => try s.toShort catch {
- case _: NumberFormatException => null
+ val result = new IntWrapper()
+ buildCast[UTF8String](_, s => if (s.toShort(result)) {
+ result.value.toShort
+ } else {
+ null
})
case BooleanType =>
buildCast[Boolean](_, b => if (b) 1.toShort else 0.toShort)
@@ -325,8 +326,11 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
// ByteConverter
private[this] def castToByte(from: DataType): Any => Any = from match {
case StringType =>
- buildCast[UTF8String](_, s => try s.toByte catch {
- case _: NumberFormatException => null
+ val result = new IntWrapper()
+ buildCast[UTF8String](_, s => if (s.toByte(result)) {
+ result.value.toByte
+ } else {
+ null
})
case BooleanType =>
buildCast[Boolean](_, b => if (b) 1.toByte else 0.toByte)
@@ -348,6 +352,15 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
if (value.changePrecision(decimalType.precision, decimalType.scale)) value else null
}
+ /**
+ * Creates a new `Decimal` with the precision and scale given in `decimalType`,
+ * returning null if the value cannot be represented at that precision (overflow),
+ * or the adjusted `Decimal` otherwise.
+ */
+ private[this] def toPrecision(value: Decimal, decimalType: DecimalType): Decimal =
+ value.toPrecision(decimalType.precision, decimalType.scale).orNull
+
+
private[this] def castToDecimal(from: DataType, target: DecimalType): Any => Any = from match {
case StringType =>
buildCast[UTF8String](_, s => try {
@@ -356,14 +369,14 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
case _: NumberFormatException => null
})
case BooleanType =>
- buildCast[Boolean](_, b => changePrecision(if (b) Decimal.ONE else Decimal.ZERO, target))
+ buildCast[Boolean](_, b => toPrecision(if (b) Decimal.ONE else Decimal.ZERO, target))
case DateType =>
buildCast[Int](_, d => null) // date can't cast to decimal in Hive
case TimestampType =>
// Note that we lose precision here.
buildCast[Long](_, t => changePrecision(Decimal(timestampToDouble(t)), target))
case dt: DecimalType =>
- b => changePrecision(b.asInstanceOf[Decimal].clone(), target)
+ b => toPrecision(b.asInstanceOf[Decimal], target)
case t: IntegralType =>
b => changePrecision(Decimal(t.integral.asInstanceOf[Integral[Any]].toLong(b)), target)
case x: FractionalType =>
@@ -503,11 +516,11 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
case TimestampType => castToTimestampCode(from, ctx)
case CalendarIntervalType => castToIntervalCode(from)
case BooleanType => castToBooleanCode(from)
- case ByteType => castToByteCode(from)
- case ShortType => castToShortCode(from)
- case IntegerType => castToIntCode(from)
+ case ByteType => castToByteCode(from, ctx)
+ case ShortType => castToShortCode(from, ctx)
+ case IntegerType => castToIntCode(from, ctx)
case FloatType => castToFloatCode(from)
- case LongType => castToLongCode(from)
+ case LongType => castToLongCode(from, ctx)
case DoubleType => castToDoubleCode(from)
case array: ArrayType =>
@@ -734,13 +747,16 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
(c, evPrim, evNull) => s"$evPrim = $c != 0;"
}
- private[this] def castToByteCode(from: DataType): CastFunction = from match {
+ private[this] def castToByteCode(from: DataType, ctx: CodegenContext): CastFunction = from match {
case StringType =>
+ val wrapper = ctx.freshName("wrapper")
+ ctx.addMutableState("UTF8String.IntWrapper", wrapper,
+ s"$wrapper = new UTF8String.IntWrapper();")
(c, evPrim, evNull) =>
s"""
- try {
- $evPrim = $c.toByte();
- } catch (java.lang.NumberFormatException e) {
+ if ($c.toByte($wrapper)) {
+ $evPrim = (byte) $wrapper.value;
+ } else {
$evNull = true;
}
"""
@@ -756,13 +772,18 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
(c, evPrim, evNull) => s"$evPrim = (byte) $c;"
}
- private[this] def castToShortCode(from: DataType): CastFunction = from match {
+ private[this] def castToShortCode(
+ from: DataType,
+ ctx: CodegenContext): CastFunction = from match {
case StringType =>
+ val wrapper = ctx.freshName("wrapper")
+ ctx.addMutableState("UTF8String.IntWrapper", wrapper,
+ s"$wrapper = new UTF8String.IntWrapper();")
(c, evPrim, evNull) =>
s"""
- try {
- $evPrim = $c.toShort();
- } catch (java.lang.NumberFormatException e) {
+ if ($c.toShort($wrapper)) {
+ $evPrim = (short) $wrapper.value;
+ } else {
$evNull = true;
}
"""
@@ -778,13 +799,16 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
(c, evPrim, evNull) => s"$evPrim = (short) $c;"
}
- private[this] def castToIntCode(from: DataType): CastFunction = from match {
+ private[this] def castToIntCode(from: DataType, ctx: CodegenContext): CastFunction = from match {
case StringType =>
+ val wrapper = ctx.freshName("wrapper")
+ ctx.addMutableState("UTF8String.IntWrapper", wrapper,
+ s"$wrapper = new UTF8String.IntWrapper();")
(c, evPrim, evNull) =>
s"""
- try {
- $evPrim = $c.toInt();
- } catch (java.lang.NumberFormatException e) {
+ if ($c.toInt($wrapper)) {
+ $evPrim = $wrapper.value;
+ } else {
$evNull = true;
}
"""
@@ -800,13 +824,17 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
(c, evPrim, evNull) => s"$evPrim = (int) $c;"
}
- private[this] def castToLongCode(from: DataType): CastFunction = from match {
+ private[this] def castToLongCode(from: DataType, ctx: CodegenContext): CastFunction = from match {
case StringType =>
+ val wrapper = ctx.freshName("wrapper")
+ ctx.addMutableState("UTF8String.LongWrapper", wrapper,
+ s"$wrapper = new UTF8String.LongWrapper();")
+
(c, evPrim, evNull) =>
s"""
- try {
- $evPrim = $c.toLong();
- } catch (java.lang.NumberFormatException e) {
+ if ($c.toLong($wrapper)) {
+ $evPrim = $wrapper.value;
+ } else {
$evNull = true;
}
"""
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
index fa5dea684114..c2211ae5d594 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
@@ -84,14 +84,8 @@ case class CheckOverflow(child: Expression, dataType: DecimalType) extends Unary
override def nullable: Boolean = true
- override def nullSafeEval(input: Any): Any = {
- val d = input.asInstanceOf[Decimal].clone()
- if (d.changePrecision(dataType.precision, dataType.scale)) {
- d
- } else {
- null
- }
- }
+ override def nullSafeEval(input: Any): Any =
+ input.asInstanceOf[Decimal].toPrecision(dataType.precision, dataType.scale).orNull
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
nullSafeCodeGen(ctx, ev, eval => {
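A small sketch of the `Option`-returning `Decimal.toPrecision` relied on here (the literal values are only for illustration):

    import org.apache.spark.sql.types.Decimal

    val d = Decimal("123.456")
    d.toPrecision(6, 3).orNull   // fits: Decimal(123.456)
    d.toPrecision(4, 3).orNull   // too many integer digits, overflows: null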
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
index 2d9c2e42064b..03101b4bfc5f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.catalyst.expressions
+import java.math.{BigDecimal, RoundingMode}
import java.security.{MessageDigest, NoSuchAlgorithmException}
import java.util.zip.CRC32
@@ -580,7 +581,7 @@ object XxHash64Function extends InterpretedHashFunction {
* We should use this hash function for both shuffle and bucket of Hive tables, so that
* we can guarantee shuffle and bucketing have same data distribution
*
- * TODO: Support Decimal and date related types
+ * TODO: Support date related types
*/
@ExpressionDescription(
usage = "_FUNC_(expr1, expr2, ...) - Returns a hash value of the arguments.")
@@ -635,6 +636,16 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] {
override protected def genHashBytes(b: String, result: String): String =
s"$result = $hasherClassName.hashUnsafeBytes($b, Platform.BYTE_ARRAY_OFFSET, $b.length);"
+ override protected def genHashDecimal(
+ ctx: CodegenContext,
+ d: DecimalType,
+ input: String,
+ result: String): String = {
+ s"""
+ $result = ${HiveHashFunction.getClass.getName.stripSuffix("$")}.normalizeDecimal(
+ $input.toJavaBigDecimal()).hashCode();"""
+ }
+
override protected def genHashCalendarInterval(input: String, result: String): String = {
s"""
$result = (31 * $hasherClassName.hashInt($input.months)) +
@@ -732,6 +743,44 @@ object HiveHashFunction extends InterpretedHashFunction {
HiveHasher.hashUnsafeBytes(base, offset, len)
}
+ private val HIVE_DECIMAL_MAX_PRECISION = 38
+ private val HIVE_DECIMAL_MAX_SCALE = 38
+
+ // Mimics normalization done for decimals in Hive at HiveDecimalV1.normalize()
+ def normalizeDecimal(input: BigDecimal): BigDecimal = {
+ if (input == null) return null
+
+ def trimDecimal(input: BigDecimal) = {
+ var result = input
+ if (result.compareTo(BigDecimal.ZERO) == 0) {
+ // Special case for 0, because java doesn't strip zeros correctly on that number.
+ result = BigDecimal.ZERO
+ } else {
+ result = result.stripTrailingZeros
+ if (result.scale < 0) {
+ // no negative scale decimals
+ result = result.setScale(0)
+ }
+ }
+ result
+ }
+
+ var result = trimDecimal(input)
+ val intDigits = result.precision - result.scale
+ if (intDigits > HIVE_DECIMAL_MAX_PRECISION) {
+ return null
+ }
+
+ val maxScale = Math.min(HIVE_DECIMAL_MAX_SCALE,
+ Math.min(HIVE_DECIMAL_MAX_PRECISION - intDigits, result.scale))
+ if (result.scale > maxScale) {
+ result = result.setScale(maxScale, RoundingMode.HALF_UP)
+ // Trimming is again necessary, because rounding may introduce new trailing 0's.
+ result = trimDecimal(result)
+ }
+ result
+ }
+
override def hash(value: Any, dataType: DataType, seed: Long): Long = {
value match {
case null => 0
@@ -785,7 +834,10 @@ object HiveHashFunction extends InterpretedHashFunction {
}
result
- case _ => super.hash(value, dataType, 0)
+ case d: Decimal =>
+ normalizeDecimal(d.toJavaBigDecimal).hashCode()
+
+ case _ => super.hash(value, dataType, seed)
}
}
}
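For illustration, a sketch of what `normalizeDecimal` does to a couple of made-up values before hashing:

    import java.math.BigDecimal
    import org.apache.spark.sql.catalyst.expressions.HiveHashFunction

    HiveHashFunction.normalizeDecimal(new BigDecimal("1.1000"))  // trailing zeros stripped: 1.1
    HiveHashFunction.normalizeDecimal(new BigDecimal("1E+3"))    // negative scale removed: 1000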
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
index 1e690a446951..18b5f2f7ed2e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
@@ -23,11 +23,12 @@ import scala.util.parsing.combinator.RegexParsers
import com.fasterxml.jackson.core._
+import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.json._
-import org.apache.spark.sql.catalyst.util.ParseModes
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, ParseModes}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.Utils
@@ -330,7 +331,7 @@ case class GetJsonObject(json: Expression, path: Expression)
// scalastyle:off line.size.limit
@ExpressionDescription(
- usage = "_FUNC_(jsonStr, p1, p2, ..., pn) - Return a tuple like the function get_json_object, but it takes multiple names. All the input parameters and output column types are string.",
+ usage = "_FUNC_(jsonStr, p1, p2, ..., pn) - Returns a tuple like the function get_json_object, but it takes multiple names. All the input parameters and output column types are string.",
extended = """
Examples:
> SELECT _FUNC_('{"a":1, "b":2}', 'a', 'b');
@@ -480,23 +481,45 @@ case class JsonTuple(children: Seq[Expression])
}
/**
- * Converts an json input string to a [[StructType]] with the specified schema.
+ * Converts a JSON input string to a [[StructType]] or [[ArrayType]] with the specified schema.
*/
case class JsonToStruct(
- schema: StructType,
+ schema: DataType,
options: Map[String, String],
child: Expression,
timeZoneId: Option[String] = None)
extends UnaryExpression with TimeZoneAwareExpression with CodegenFallback with ExpectsInputTypes {
override def nullable: Boolean = true
- def this(schema: StructType, options: Map[String, String], child: Expression) =
+ def this(schema: DataType, options: Map[String, String], child: Expression) =
this(schema, options, child, None)
+ override def checkInputDataTypes(): TypeCheckResult = schema match {
+ case _: StructType | ArrayType(_: StructType, _) =>
+ super.checkInputDataTypes()
+ case _ => TypeCheckResult.TypeCheckFailure(
+ s"Input schema ${schema.simpleString} must be a struct or an array of structs.")
+ }
+
+ @transient
+ lazy val rowSchema = schema match {
+ case st: StructType => st
+ case ArrayType(st: StructType, _) => st
+ }
+
+ // This converts parsed rows to the desired output by the given schema.
+ @transient
+ lazy val converter = schema match {
+ case _: StructType =>
+ (rows: Seq[InternalRow]) => if (rows.length == 1) rows.head else null
+ case ArrayType(_: StructType, _) =>
+ (rows: Seq[InternalRow]) => new GenericArrayData(rows)
+ }
+
@transient
lazy val parser =
new JacksonParser(
- schema,
+ rowSchema,
new JSONOptions(options + ("mode" -> ParseModes.FAIL_FAST_MODE), timeZoneId.get))
override def dataType: DataType = schema
@@ -505,11 +528,32 @@ case class JsonToStruct(
copy(timeZoneId = Option(timeZoneId))
override def nullSafeEval(json: Any): Any = {
+ // When input is,
+ // - `null`: `null`.
+ // - invalid json: `null`.
+ // - empty string: `null`.
+ //
+ // When the schema is array,
+ // - json array: `Array(Row(...), ...)`
+ // - json object: `Array(Row(...))`
+ // - empty json array: `Array()`.
+ // - empty json object: `Array(Row(null))`.
+ //
+ // When the schema is a struct,
+ // - json object/array with single element: `Row(...)`
+ // - json array with multiple elements: `null`
+ // - empty json array: `null`.
+ // - empty json object: `Row(null)`.
+
+ // We need `null` if the input string is an empty string. `JacksonParser` can
+ // deal with this but produces `Nil`.
+ if (json.toString.trim.isEmpty) return null
+
try {
- parser.parse(
+ converter(parser.parse(
json.asInstanceOf[UTF8String],
CreateJacksonParser.utf8String,
- identity[UTF8String]).headOption.orNull
+ identity[UTF8String]))
} catch {
case _: SparkSQLJsonProcessingException => null
}
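A minimal, hedged sketch of the struct-versus-array behavior documented above, driving the expression directly (the literals and the `UTC` time zone are only for illustration):

    import org.apache.spark.sql.catalyst.expressions.{JsonToStruct, Literal}
    import org.apache.spark.sql.types._

    val element = new StructType().add("a", IntegerType)

    // Array schema: a JSON array yields one row per element.
    JsonToStruct(ArrayType(element), Map.empty, Literal("""[{"a": 1}, {"a": 2}]"""), Some("UTC")).eval()

    // Struct schema: a single JSON object yields a single row.
    JsonToStruct(element, Map.empty, Literal("""{"a": 1}"""), Some("UTC")).eval()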
@@ -521,6 +565,17 @@ case class JsonToStruct(
/**
* Converts a [[StructType]] to a json output string.
*/
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = "_FUNC_(expr[, options]) - Returns a json string with a given struct value",
+ extended = """
+ Examples:
+ > SELECT _FUNC_(named_struct('a', 1, 'b', 2));
+ {"a":1,"b":2}
+ > SELECT _FUNC_(named_struct('time', to_timestamp('2015-08-26', 'yyyy-MM-dd')), map('timestampFormat', 'dd/MM/yyyy'));
+ {"time":"26/08/2015"}
+ """)
+// scalastyle:on line.size.limit
case class StructToJson(
options: Map[String, String],
child: Expression,
@@ -530,6 +585,14 @@ case class StructToJson(
def this(options: Map[String, String], child: Expression) = this(options, child, None)
+ // Used in `FunctionRegistry`
+ def this(child: Expression) = this(Map.empty, child, None)
+ def this(child: Expression, options: Expression) =
+ this(
+ options = StructToJson.convertToMapData(options),
+ child = child,
+ timeZoneId = None)
+
@transient
lazy val writer = new CharArrayWriter()
@@ -570,3 +633,20 @@ case class StructToJson(
override def inputTypes: Seq[AbstractDataType] = StructType :: Nil
}
+
+object StructToJson {
+
+ def convertToMapData(exp: Expression): Map[String, String] = exp match {
+ case m: CreateMap
+ if m.dataType.acceptsType(MapType(StringType, StringType, valueContainsNull = false)) =>
+ val arrayMap = m.eval().asInstanceOf[ArrayBasedMapData]
+ ArrayBasedMapData.toScalaMap(arrayMap).map { case (key, value) =>
+ key.toString -> value.toString
+ }
+ case m: CreateMap =>
+ throw new AnalysisException(
+ s"A type of keys and values in map() must be string, but got ${m.dataType}")
+ case _ =>
+ throw new AnalysisException("Must use a map() function for options")
+ }
+}
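And a small sketch of `convertToMapData` on a `map()` of string literals (the option values are chosen arbitrarily):

    import org.apache.spark.sql.catalyst.expressions.{CreateMap, Literal, StructToJson}

    val options = CreateMap(Seq(Literal("timestampFormat"), Literal("dd/MM/yyyy")))
    StructToJson.convertToMapData(options)  // Map("timestampFormat" -> "dd/MM/yyyy")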
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
index 65273a77b105..dea5f85cb08c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
@@ -1024,7 +1024,7 @@ abstract class RoundBase(child: Expression, scale: Expression,
child.dataType match {
case _: DecimalType =>
val decimal = input1.asInstanceOf[Decimal]
- if (decimal.changePrecision(decimal.precision, _scale, mode)) decimal else null
+ decimal.toPrecision(decimal.precision, _scale, mode).orNull
case ByteType =>
BigDecimal(input1.asInstanceOf[Byte]).setScale(_scale, mode).toByte
case ShortType =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala
new file mode 100644
index 000000000000..b694561e5372
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala
@@ -0,0 +1,297 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import scala.collection.mutable
+
+import org.apache.spark.sql.catalyst.CatalystConf
+import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeSet, Expression, PredicateHelper}
+import org.apache.spark.sql.catalyst.plans.{Inner, InnerLike}
+import org.apache.spark.sql.catalyst.plans.logical.{BinaryNode, Join, LogicalPlan, Project}
+import org.apache.spark.sql.catalyst.rules.Rule
+
+
+/**
+ * Cost-based join reorder.
+ * We may have several join reorder algorithms in the future. This class is the entry point for
+ * these algorithms, and chooses which one to use.
+ */
+case class CostBasedJoinReorder(conf: CatalystConf) extends Rule[LogicalPlan] with PredicateHelper {
+ def apply(plan: LogicalPlan): LogicalPlan = {
+ if (!conf.cboEnabled || !conf.joinReorderEnabled) {
+ plan
+ } else {
+ val result = plan transform {
+ case p @ Project(projectList, j @ Join(_, _, _: InnerLike, _)) =>
+ reorder(p, p.outputSet)
+ case j @ Join(_, _, _: InnerLike, _) =>
+ reorder(j, j.outputSet)
+ }
+ // After reordering is finished, convert OrderedJoin back to Join
+ result transform {
+ case oj: OrderedJoin => oj.join
+ }
+ }
+ }
+
+ def reorder(plan: LogicalPlan, output: AttributeSet): LogicalPlan = {
+ val (items, conditions) = extractInnerJoins(plan)
+ val result =
+ // Do reordering if the number of items is appropriate and join conditions exist.
+ // We also need to check if costs of all items can be evaluated.
+ if (items.size > 2 && items.size <= conf.joinReorderDPThreshold && conditions.nonEmpty &&
+ items.forall(_.stats(conf).rowCount.isDefined)) {
+ JoinReorderDP.search(conf, items, conditions, output).getOrElse(plan)
+ } else {
+ plan
+ }
+ // Set consecutive join nodes ordered.
+ replaceWithOrderedJoin(result)
+ }
+
+ /**
+ * Extract consecutive inner joinable items and join conditions.
+ * This method works for bushy trees and left/right deep trees.
+ */
+ private def extractInnerJoins(plan: LogicalPlan): (Seq[LogicalPlan], Set[Expression]) = {
+ plan match {
+ case Join(left, right, _: InnerLike, cond) =>
+ val (leftPlans, leftConditions) = extractInnerJoins(left)
+ val (rightPlans, rightConditions) = extractInnerJoins(right)
+ (leftPlans ++ rightPlans, cond.toSet.flatMap(splitConjunctivePredicates) ++
+ leftConditions ++ rightConditions)
+ case Project(projectList, join) if projectList.forall(_.isInstanceOf[Attribute]) =>
+ extractInnerJoins(join)
+ case _ =>
+ (Seq(plan), Set())
+ }
+ }
+
+ private def replaceWithOrderedJoin(plan: LogicalPlan): LogicalPlan = plan match {
+ case j @ Join(left, right, _: InnerLike, cond) =>
+ val replacedLeft = replaceWithOrderedJoin(left)
+ val replacedRight = replaceWithOrderedJoin(right)
+ OrderedJoin(j.copy(left = replacedLeft, right = replacedRight))
+ case p @ Project(_, join) =>
+ p.copy(child = replaceWithOrderedJoin(join))
+ case _ =>
+ plan
+ }
+
+ /** This is a wrapper class for a join node that has been ordered. */
+ private case class OrderedJoin(join: Join) extends BinaryNode {
+ override def left: LogicalPlan = join.left
+ override def right: LogicalPlan = join.right
+ override def output: Seq[Attribute] = join.output
+ }
+}
+
+/**
+ * Reorder the joins using a dynamic programming algorithm. This implementation is based on the
+ * paper: Access Path Selection in a Relational Database Management System.
+ * http://www.inf.ed.ac.uk/teaching/courses/adbs/AccessPath.pdf
+ *
+ * First we put all items (basic joined nodes) into level 0, then we build all two-way joins
+ * at level 1 from plans at level 0 (single items), then build all 3-way joins from plans
+ * at previous levels (two-way joins and single items), then 4-way joins ... etc, until we
+ * build all n-way joins and pick the best plan among them.
+ *
+ * When building m-way joins, we only keep the best plan (with the lowest cost) for the same set
+ * of m items. E.g., for 3-way joins, we keep only the best plan for items {A, B, C} among
+ * plans (A J B) J C, (A J C) J B and (B J C) J A.
+ *
+ * Thus the plans maintained for each level when reordering four items A, B, C, D are as follows:
+ * level 0: p({A}), p({B}), p({C}), p({D})
+ * level 1: p({A, B}), p({A, C}), p({A, D}), p({B, C}), p({B, D}), p({C, D})
+ * level 2: p({A, B, C}), p({A, B, D}), p({A, C, D}), p({B, C, D})
+ * level 3: p({A, B, C, D})
+ * where p({A, B, C, D}) is the final output plan.
+ *
+ * For cost evaluation, since physical costs for operators are not available currently, we use
+ * cardinalities and sizes to compute costs.
+ */
+object JoinReorderDP extends PredicateHelper {
+
+ def search(
+ conf: CatalystConf,
+ items: Seq[LogicalPlan],
+ conditions: Set[Expression],
+ topOutput: AttributeSet): Option[LogicalPlan] = {
+
+ // Level i maintains all found plans for i + 1 items.
+ // Create the initial plans: each plan is a single item with zero cost.
+ val itemIndex = items.zipWithIndex
+ val foundPlans = mutable.Buffer[JoinPlanMap](itemIndex.map {
+ case (item, id) => Set(id) -> JoinPlan(Set(id), item, Set(), Cost(0, 0))
+ }.toMap)
+
+ for (lev <- 1 until items.length) {
+ // Build plans for the next level.
+ foundPlans += searchLevel(foundPlans, conf, conditions, topOutput)
+ }
+
+ val plansLastLevel = foundPlans(items.length - 1)
+ if (plansLastLevel.isEmpty) {
+ // Failed to find a plan, fall back to the original plan
+ None
+ } else {
+ // There must be only one plan at the last level, which contains all items.
+ assert(plansLastLevel.size == 1 && plansLastLevel.head._1.size == items.length)
+ Some(plansLastLevel.head._2.plan)
+ }
+ }
+
+ /** Find all possible plans at the next level, based on existing levels. */
+ private def searchLevel(
+ existingLevels: Seq[JoinPlanMap],
+ conf: CatalystConf,
+ conditions: Set[Expression],
+ topOutput: AttributeSet): JoinPlanMap = {
+
+ val nextLevel = mutable.Map.empty[Set[Int], JoinPlan]
+ var k = 0
+ val lev = existingLevels.length - 1
+ // Build plans for the next level from plans at level k (one side of the join) and level
+ // lev - k (the other side of the join).
+ // For the lower level k, we only need to search from 0 to lev - k, because when building
+ // a join from A and B, both A J B and B J A are handled.
+ while (k <= lev - k) {
+ val oneSideCandidates = existingLevels(k).values.toSeq
+ for (i <- oneSideCandidates.indices) {
+ val oneSidePlan = oneSideCandidates(i)
+ val otherSideCandidates = if (k == lev - k) {
+ // Both sides of a join are at the same level, no need to repeat for previous ones.
+ oneSideCandidates.drop(i)
+ } else {
+ existingLevels(lev - k).values.toSeq
+ }
+
+ otherSideCandidates.foreach { otherSidePlan =>
+ // Should not join two overlapping item sets.
+ if (oneSidePlan.itemIds.intersect(otherSidePlan.itemIds).isEmpty) {
+ val joinPlan = buildJoin(oneSidePlan, otherSidePlan, conf, conditions, topOutput)
+ // Check if it's the first plan for the item set, or it's a better plan than
+ // the existing one due to lower cost.
+ val existingPlan = nextLevel.get(joinPlan.itemIds)
+ if (existingPlan.isEmpty || joinPlan.cost.lessThan(existingPlan.get.cost)) {
+ nextLevel.update(joinPlan.itemIds, joinPlan)
+ }
+ }
+ }
+ }
+ k += 1
+ }
+ nextLevel.toMap
+ }
+
+ /** Build a new join node. */
+ private def buildJoin(
+ oneJoinPlan: JoinPlan,
+ otherJoinPlan: JoinPlan,
+ conf: CatalystConf,
+ conditions: Set[Expression],
+ topOutput: AttributeSet): JoinPlan = {
+
+ val onePlan = oneJoinPlan.plan
+ val otherPlan = otherJoinPlan.plan
+ // Now both onePlan and otherPlan become intermediate joins, so the cost of the
+ // new join should also include their own cardinalities and sizes.
+ val newCost = if (isCartesianProduct(onePlan) || isCartesianProduct(otherPlan)) {
+ // We consider a cartesian product very expensive, so we set a very large cost for it.
+ // This pushes all cartesian products to the end of the join order, because having a cartesian
+ // product as an intermediate join significantly increases a plan's cost, making it
+ // impossible to be selected as the best plan for the items unless there is no other choice.
+ Cost(
+ rows = BigInt(Long.MaxValue) * BigInt(Long.MaxValue),
+ size = BigInt(Long.MaxValue) * BigInt(Long.MaxValue))
+ } else {
+ val onePlanStats = onePlan.stats(conf)
+ val otherPlanStats = otherPlan.stats(conf)
+ Cost(
+ rows = oneJoinPlan.cost.rows + onePlanStats.rowCount.get +
+ otherJoinPlan.cost.rows + otherPlanStats.rowCount.get,
+ size = oneJoinPlan.cost.size + onePlanStats.sizeInBytes +
+ otherJoinPlan.cost.size + otherPlanStats.sizeInBytes)
+ }
+
+ // Put the deeper side on the left, tending to build a left-deep tree.
+ val (left, right) = if (oneJoinPlan.itemIds.size >= otherJoinPlan.itemIds.size) {
+ (onePlan, otherPlan)
+ } else {
+ (otherPlan, onePlan)
+ }
+ val joinConds = conditions
+ .filterNot(l => canEvaluate(l, onePlan))
+ .filterNot(r => canEvaluate(r, otherPlan))
+ .filter(e => e.references.subsetOf(onePlan.outputSet ++ otherPlan.outputSet))
+ // We use an inner join whether or not the join condition is empty, since a cross join is
+ // equivalent to an inner join without a condition.
+ val newJoin = Join(left, right, Inner, joinConds.reduceOption(And))
+ val collectedJoinConds = joinConds ++ oneJoinPlan.joinConds ++ otherJoinPlan.joinConds
+ val remainingConds = conditions -- collectedJoinConds
+ val neededAttr = AttributeSet(remainingConds.flatMap(_.references)) ++ topOutput
+ val neededFromNewJoin = newJoin.outputSet.filter(neededAttr.contains)
+ val newPlan =
+ if ((newJoin.outputSet -- neededFromNewJoin).nonEmpty) {
+ Project(neededFromNewJoin.toSeq, newJoin)
+ } else {
+ newJoin
+ }
+
+ val itemIds = oneJoinPlan.itemIds.union(otherJoinPlan.itemIds)
+ JoinPlan(itemIds, newPlan, collectedJoinConds, newCost)
+ }
+
+ private def isCartesianProduct(plan: LogicalPlan): Boolean = plan match {
+ case Join(_, _, _, None) => true
+ case Project(_, Join(_, _, _, None)) => true
+ case _ => false
+ }
+
+ /** Map[set of item ids, join plan for these items] */
+ type JoinPlanMap = Map[Set[Int], JoinPlan]
+
+ /**
+ * Partial join order in a specific level.
+ *
+ * @param itemIds Set of item ids participating in this partial plan.
+ * @param plan The plan tree with the lowest cost for these items found so far.
+ * @param joinConds Join conditions included in the plan.
+ * @param cost The cost of this plan is the sum of costs of all intermediate joins.
+ */
+ case class JoinPlan(itemIds: Set[Int], plan: LogicalPlan, joinConds: Set[Expression], cost: Cost)
+}
+
+/** This class defines the cost model. */
+case class Cost(rows: BigInt, size: BigInt) {
+ /**
+ * An empirical value for the weight of cardinality (number of rows) in the cost formula:
+ * cost = rows * weight + size * (1 - weight). Cardinality is usually more important than size.
+ */
+ val weight = 0.7
+
+ def lessThan(other: Cost): Boolean = {
+ if (other.rows == 0 || other.size == 0) {
+ false
+ } else {
+ val relativeRows = BigDecimal(rows) / BigDecimal(other.rows)
+ val relativeSize = BigDecimal(size) / BigDecimal(other.size)
+ relativeRows * weight + relativeSize * (1 - weight) < 1
+ }
+ }
+}
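A worked example (with made-up numbers) of the relative, weighted comparison above:

    import org.apache.spark.sql.catalyst.optimizer.Cost

    val cheap = Cost(rows = BigInt(1000), size = BigInt(100000))
    val pricey = Cost(rows = BigInt(5000), size = BigInt(80000))
    // relativeRows = 0.2, relativeSize = 1.25
    // 0.2 * 0.7 + 1.25 * 0.3 = 0.515 < 1, so `cheap` is considered cheaper.
    cheap.lessThan(pricey)  // true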
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 036da3ad2062..caafa1c134cd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -118,6 +118,8 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf)
SimplifyCreateMapOps) ::
Batch("Check Cartesian Products", Once,
CheckCartesianProducts(conf)) ::
+ Batch("Join Reorder", Once,
+ CostBasedJoinReorder(conf)) ::
Batch("Decimal Optimizations", fixedPoint,
DecimalAggregates(conf)) ::
Batch("Typed Filter Optimization", fixedPoint,
@@ -562,27 +564,23 @@ object CollapseProject extends Rule[LogicalPlan] {
}
/**
- * Combines adjacent [[Repartition]] and [[RepartitionByExpression]] operator combinations
- * by keeping only the one.
- * 1. For adjacent [[Repartition]]s, collapse into the last [[Repartition]].
- * 2. For adjacent [[RepartitionByExpression]]s, collapse into the last [[RepartitionByExpression]].
- * 3. For a combination of [[Repartition]] and [[RepartitionByExpression]], collapse as a single
- * [[RepartitionByExpression]] with the expression and last number of partition.
+ * Combines adjacent [[RepartitionOperation]] operators
*/
object CollapseRepartition extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
- // Case 1
- case Repartition(numPartitions, shuffle, Repartition(_, _, child)) =>
- Repartition(numPartitions, shuffle, child)
- // Case 2
- case RepartitionByExpression(exprs, RepartitionByExpression(_, child, _), numPartitions) =>
- RepartitionByExpression(exprs, child, numPartitions)
- // Case 3
- case Repartition(numPartitions, _, r: RepartitionByExpression) =>
- r.copy(numPartitions = numPartitions)
- // Case 3
- case RepartitionByExpression(exprs, Repartition(_, _, child), numPartitions) =>
- RepartitionByExpression(exprs, child, numPartitions)
+ // Case 1: When a Repartition has a child of Repartition or RepartitionByExpression:
+ // 1) When the top node does not enable the shuffle (i.e., coalesce API) but the child does,
+ // return the child node if the top node's numPartitions is not smaller;
+ // otherwise, keep the plan unchanged.
+ // 2) In the other cases, return the top node with the child's child.
+ case r @ Repartition(_, _, child: RepartitionOperation) => (r.shuffle, child.shuffle) match {
+ case (false, true) => if (r.numPartitions >= child.numPartitions) child else r
+ case _ => r.copy(child = child.child)
+ }
+ // Case 2: When a RepartitionByExpression has a child of Repartition or RepartitionByExpression,
+ // we can remove the child.
+ case r @ RepartitionByExpression(_, child: RepartitionOperation, _) =>
+ r.copy(child = child.child)
}
}
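A usage-level sketch of the effect of the rewritten rule (assuming a `SparkSession` named `spark`; the partition counts follow from the match above):

    val df = spark.range(100).repartition(10)
    df.coalesce(13).rdd.getNumPartitions    // 10: coalesce cannot raise the shuffle's partition count
    df.coalesce(7).rdd.getNumPartitions     // 7
    df.repartition(5).rdd.getNumPartitions  // 5: adjacent repartitions collapse into the last one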
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
index 4d62cce9da0a..fb7ce6aecea5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
@@ -169,7 +169,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
// and Project operators, followed by an optional Filter, followed by an
// Aggregate. Traverse the operators recursively.
def evalPlan(lp : LogicalPlan) : Map[ExprId, Option[Any]] = lp match {
- case SubqueryAlias(_, child, _) => evalPlan(child)
+ case SubqueryAlias(_, child) => evalPlan(child)
case Filter(condition, child) =>
val bindings = evalPlan(child)
if (bindings.isEmpty) bindings
@@ -227,7 +227,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
topPart += p
bottomPart = child
- case s @ SubqueryAlias(_, child, _) =>
+ case s @ SubqueryAlias(_, child) =>
topPart += s
bottomPart = child
@@ -298,8 +298,8 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
topPart.reverse.foreach {
case Project(projList, _) =>
subqueryRoot = Project(projList ++ havingInputs, subqueryRoot)
- case s @ SubqueryAlias(alias, _, None) =>
- subqueryRoot = SubqueryAlias(alias, subqueryRoot, None)
+ case s @ SubqueryAlias(alias, _) =>
+ subqueryRoot = SubqueryAlias(alias, subqueryRoot)
case op => sys.error(s"Unexpected operator $op in corelated subquery")
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index d2e091f4dda6..3cf11adc1953 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -108,7 +108,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
* This is only used for Common Table Expressions.
*/
override def visitNamedQuery(ctx: NamedQueryContext): SubqueryAlias = withOrigin(ctx) {
- SubqueryAlias(ctx.name.getText, plan(ctx.query), None)
+ SubqueryAlias(ctx.name.getText, plan(ctx.query))
}
/**
@@ -666,7 +666,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
val tableWithAlias = Option(ctx.strictIdentifier).map(_.getText) match {
case Some(strictIdentifier) =>
- SubqueryAlias(strictIdentifier, table, None)
+ SubqueryAlias(strictIdentifier, table)
case _ => table
}
tableWithAlias.optionalMap(ctx.sample)(withSample)
@@ -731,7 +731,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
* Create an alias (SubqueryAlias) for a LogicalPlan.
*/
private def aliasPlan(alias: ParserRuleContext, plan: LogicalPlan): LogicalPlan = {
- SubqueryAlias(alias.getText, plan, None)
+ SubqueryAlias(alias.getText, plan)
}
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/EventTimeWatermark.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/EventTimeWatermark.scala
index 77309ce391a1..06196b5afb03 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/EventTimeWatermark.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/EventTimeWatermark.scala
@@ -24,6 +24,12 @@ import org.apache.spark.unsafe.types.CalendarInterval
object EventTimeWatermark {
/** The [[org.apache.spark.sql.types.Metadata]] key used to hold the eventTime watermark delay. */
val delayKey = "spark.watermarkDelayMs"
+
+ def getDelayMs(delay: CalendarInterval): Long = {
+ // We define month as `31 days` to simplify calculation.
+ val millisPerMonth = CalendarInterval.MICROS_PER_DAY / 1000 * 31
+ delay.milliseconds + delay.months * millisPerMonth
+ }
}
/**
@@ -37,9 +43,17 @@ case class EventTimeWatermark(
// Update the metadata on the eventTime column to include the desired delay.
override val output: Seq[Attribute] = child.output.map { a =>
if (a semanticEquals eventTime) {
+ val delayMs = EventTimeWatermark.getDelayMs(delay)
+ val updatedMetadata = new MetadataBuilder()
+ .withMetadata(a.metadata)
+ .putLong(EventTimeWatermark.delayKey, delayMs)
+ .build()
+ a.withMetadata(updatedMetadata)
+ } else if (a.metadata.contains(EventTimeWatermark.delayKey)) {
+ // Remove existing watermark
val updatedMetadata = new MetadataBuilder()
.withMetadata(a.metadata)
- .putLong(EventTimeWatermark.delayKey, delay.milliseconds)
+ .remove(EventTimeWatermark.delayKey)
.build()
a.withMetadata(updatedMetadata)
} else {
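A small sketch of the month-as-31-days convention in `getDelayMs` (the interval value is arbitrary; `CalendarInterval(months, microseconds)` is the two-argument constructor of this Spark version):

    import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark
    import org.apache.spark.unsafe.types.CalendarInterval

    val delay = new CalendarInterval(1, 10 * 1000L)  // 1 month + 10 ms (as microseconds)
    EventTimeWatermark.getDelayMs(delay)             // 31 * 86400000 + 10 = 2678400010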
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index ccebae3cc270..31b6ed48a223 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -752,14 +752,13 @@ case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryN
}
override def computeStats(conf: CatalystConf): Statistics = {
val limit = limitExpr.eval().asInstanceOf[Int]
- val sizeInBytes = if (limit == 0) {
- // sizeInBytes can't be zero, or sizeInBytes of BinaryNode will also be zero
- // (product of children).
- 1
- } else {
- (limit: Long) * output.map(a => a.dataType.defaultSize).sum
- }
- child.stats(conf).copy(sizeInBytes = sizeInBytes)
+ val childStats = child.stats(conf)
+ val rowCount: BigInt = childStats.rowCount.map(_.min(limit)).getOrElse(limit)
+ // Don't propagate column stats, because we don't know the distribution after a limit operation
+ Statistics(
+ sizeInBytes = EstimationUtils.getOutputSize(output, rowCount, childStats.attributeStats),
+ rowCount = Some(rowCount),
+ isBroadcastable = childStats.isBroadcastable)
}
}
@@ -773,21 +772,27 @@ case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNo
}
override def computeStats(conf: CatalystConf): Statistics = {
val limit = limitExpr.eval().asInstanceOf[Int]
- val sizeInBytes = if (limit == 0) {
+ val childStats = child.stats(conf)
+ if (limit == 0) {
// sizeInBytes can't be zero, or sizeInBytes of BinaryNode will also be zero
// (product of children).
- 1
+ Statistics(
+ sizeInBytes = 1,
+ rowCount = Some(0),
+ isBroadcastable = childStats.isBroadcastable)
} else {
- (limit: Long) * output.map(a => a.dataType.defaultSize).sum
+ // The output row count of LocalLimit should be the sum of row counts from each partition.
+ // However, since the number of partitions is not available here, we just use statistics of
+ // the child. Because the distribution after a limit operation is unknown, we do not propagate
+ // the column stats.
+ childStats.copy(attributeStats = AttributeMap(Nil))
}
- child.stats(conf).copy(sizeInBytes = sizeInBytes)
}
}
case class SubqueryAlias(
alias: String,
- child: LogicalPlan,
- view: Option[TableIdentifier])
+ child: LogicalPlan)
extends UnaryNode {
override def output: Seq[Attribute] = child.output.map(_.withQualifier(Some(alias)))
@@ -816,12 +821,14 @@ case class Sample(
override def computeStats(conf: CatalystConf): Statistics = {
val ratio = upperBound - lowerBound
- // BigInt can't multiply with Double
- var sizeInBytes = child.stats(conf).sizeInBytes * (ratio * 100).toInt / 100
+ val childStats = child.stats(conf)
+ var sizeInBytes = EstimationUtils.ceil(BigDecimal(childStats.sizeInBytes) * ratio)
if (sizeInBytes == 0) {
sizeInBytes = 1
}
- child.stats(conf).copy(sizeInBytes = sizeInBytes)
+ val sampledRowCount = childStats.rowCount.map(c => EstimationUtils.ceil(BigDecimal(c) * ratio))
+ // Don't propagate column stats, because we don't know the distribution after a sample operation
+ Statistics(sizeInBytes, sampledRowCount, isBroadcastable = childStats.isBroadcastable)
}
override protected def otherCopyArgs: Seq[AnyRef] = isTableSample :: Nil
@@ -835,6 +842,15 @@ case class Distinct(child: LogicalPlan) extends UnaryNode {
override def output: Seq[Attribute] = child.output
}
+/**
+ * A base interface for [[RepartitionByExpression]] and [[Repartition]]
+ */
+abstract class RepartitionOperation extends UnaryNode {
+ def shuffle: Boolean
+ def numPartitions: Int
+ override def output: Seq[Attribute] = child.output
+}
+
/**
* Returns a new RDD that has exactly `numPartitions` partitions. Differs from
* [[RepartitionByExpression]] as this method is called directly by DataFrame's, because the user
@@ -842,9 +858,8 @@ case class Distinct(child: LogicalPlan) extends UnaryNode {
* of the output requires some specific ordering or distribution of the data.
*/
case class Repartition(numPartitions: Int, shuffle: Boolean, child: LogicalPlan)
- extends UnaryNode {
+ extends RepartitionOperation {
require(numPartitions > 0, s"Number of partitions ($numPartitions) must be positive.")
- override def output: Seq[Attribute] = child.output
}
/**
@@ -856,12 +871,12 @@ case class Repartition(numPartitions: Int, shuffle: Boolean, child: LogicalPlan)
case class RepartitionByExpression(
partitionExpressions: Seq[Expression],
child: LogicalPlan,
- numPartitions: Int) extends UnaryNode {
+ numPartitions: Int) extends RepartitionOperation {
require(numPartitions > 0, s"Number of partitions ($numPartitions) must be positive.")
override def maxRows: Option[Long] = child.maxRows
- override def output: Seq[Attribute] = child.output
+ override def shuffle: Boolean = true
}
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
index 0be4823bbc89..7f4462e58360 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
@@ -26,7 +26,9 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer
import org.apache.spark.sql.catalyst.encoders._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.objects.Invoke
+import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types._
+import org.apache.spark.util.Utils
object CatalystSerde {
def deserialize[T : Encoder](child: LogicalPlan): DeserializeToObject = {
@@ -210,13 +212,48 @@ case class TypedFilter(
def typedCondition(input: Expression): Expression = {
val (funcClass, methodName) = func match {
case m: FilterFunction[_] => classOf[FilterFunction[_]] -> "call"
- case _ => classOf[Any => Boolean] -> "apply"
+ case _ => FunctionUtils.getFunctionOneName(BooleanType, input.dataType)
}
val funcObj = Literal.create(func, ObjectType(funcClass))
Invoke(funcObj, methodName, BooleanType, input :: Nil)
}
}
+object FunctionUtils {
+ private def getMethodType(dt: DataType, isOutput: Boolean): Option[String] = {
+ dt match {
+ case BooleanType if isOutput => Some("Z")
+ case IntegerType => Some("I")
+ case LongType => Some("J")
+ case FloatType => Some("F")
+ case DoubleType => Some("D")
+ case _ => None
+ }
+ }
+
+ def getFunctionOneName(outputDT: DataType, inputDT: DataType): (Class[_], String) = {
+ // load "scala.Function1" using Java API to avoid requirements of type parameters
+ Utils.classForName("scala.Function1") -> {
+ // If the pair of argument and return types is one of the specific types
+ // whose specialized method (apply$mc..$sp) is generated by scalac,
+ // Catalyst generates a direct method call to that specialized method.
+ // The following are references for this specialization:
+ // http://www.scala-lang.org/api/2.12.0/scala/Function1.html
+ // https://github.com/scala/scala/blob/2.11.x/src/compiler/scala/tools/nsc/transform/
+ // SpecializeTypes.scala
+ // http://www.cakesolutions.net/teamblogs/scala-dissection-functions
+ // http://axel22.github.io/2013/11/03/specialization-quirks.html
+ val inputType = getMethodType(inputDT, false)
+ val outputType = getMethodType(outputDT, true)
+ if (inputType.isDefined && outputType.isDefined) {
+ s"apply$$mc${outputType.get}${inputType.get}$$sp"
+ } else {
+ "apply"
+ }
+ }
+ }
+}
+
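A self-contained sketch of the naming scheme described in the comment above (the helper names here are made up; only the `apply$mc<Out><In>$sp` pattern comes from the comment): an `Int => Boolean` predicate resolves to `apply$mcZI$sp`.

```scala
object SpecializedNameSketch {
  // Type codes used by scalac's specialization: Z = Boolean, I = Int, J = Long, F = Float, D = Double.
  private val code = Map("Boolean" -> "Z", "Int" -> "I", "Long" -> "J", "Float" -> "F", "Double" -> "D")

  def applyName(returnType: String, argType: String): String =
    (code.get(returnType), code.get(argType)) match {
      case (Some(r), Some(a)) => s"apply$$mc$r$a$$sp"
      case _ => "apply" // fall back to the generic, boxing apply
    }

  def main(args: Array[String]): Unit = {
    val name = applyName("Boolean", "Int")
    println(name) // apply$mcZI$sp
    // The specialized variant really exists on scala.Function1:
    println(classOf[Int => Boolean].getMethods.exists(_.getName == name)) // true
  }
}
```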
/** Factory for constructing new `AppendColumn` nodes. */
object AppendColumns {
def apply[T : Encoder, U : Encoder](
@@ -317,13 +354,15 @@ case class MapGroups(
trait LogicalKeyedState[S]
/** Factory for constructing new `MapGroupsWithState` nodes. */
-object MapGroupsWithState {
+object FlatMapGroupsWithState {
def apply[K: Encoder, V: Encoder, S: Encoder, U: Encoder](
func: (Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any],
groupingAttributes: Seq[Attribute],
dataAttributes: Seq[Attribute],
+ outputMode: OutputMode,
+ isMapGroupsWithState: Boolean,
child: LogicalPlan): LogicalPlan = {
- val mapped = new MapGroupsWithState(
+ val mapped = new FlatMapGroupsWithState(
func,
UnresolvedDeserializer(encoderFor[K].deserializer, groupingAttributes),
UnresolvedDeserializer(encoderFor[V].deserializer, dataAttributes),
@@ -332,7 +371,9 @@ object MapGroupsWithState {
CatalystSerde.generateObjAttr[U],
encoderFor[S].resolveAndBind().deserializer,
encoderFor[S].namedExpressions,
- child)
+ outputMode,
+ child,
+ isMapGroupsWithState)
CatalystSerde.serialize[U](mapped)
}
}
@@ -350,8 +391,10 @@ object MapGroupsWithState {
* @param outputObjAttr used to define the output object
* @param stateDeserializer used to deserialize state before calling `func`
* @param stateSerializer used to serialize updated state after calling `func`
+ * @param outputMode the output mode of `func`
+ * @param isMapGroupsWithState whether it is created by the `mapGroupsWithState` method
*/
-case class MapGroupsWithState(
+case class FlatMapGroupsWithState(
func: (Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any],
keyDeserializer: Expression,
valueDeserializer: Expression,
@@ -360,7 +403,14 @@ case class MapGroupsWithState(
outputObjAttr: Attribute,
stateDeserializer: Expression,
stateSerializer: Seq[NamedExpression],
- child: LogicalPlan) extends UnaryNode with ObjectProducer
+ outputMode: OutputMode,
+ child: LogicalPlan,
+ isMapGroupsWithState: Boolean = false) extends UnaryNode with ObjectProducer {
+
+ if (isMapGroupsWithState) {
+ assert(outputMode == OutputMode.Update)
+ }
+}
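A toy illustration (the types below are stand-ins, not Spark classes) of the invariant asserted above: a node created through `mapGroupsWithState` must carry the `Update` output mode, while `flatMapGroupsWithState` may use either mode.

```scala
sealed trait Mode
case object Append extends Mode
case object Update extends Mode

// Mirrors only the assertion: isMapGroupsWithState implies the Update output mode.
final case class GroupsWithStateNode(outputMode: Mode, isMapGroupsWithState: Boolean = false) {
  if (isMapGroupsWithState) {
    assert(outputMode == Update, "mapGroupsWithState only supports the Update output mode")
  }
}

object GroupsWithStateDemo extends App {
  GroupsWithStateNode(Update, isMapGroupsWithState = true)    // fine: mapGroupsWithState
  GroupsWithStateNode(Append)                                 // fine: flatMapGroupsWithState(Append)
  // GroupsWithStateNode(Append, isMapGroupsWithState = true) // would trip the assertion
}
```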
/** Factory for constructing new `FlatMapGroupsInR` nodes. */
object FlatMapGroupsInR {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala
index 0c928832d7d2..b10785b05d6c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala
@@ -19,11 +19,12 @@ package org.apache.spark.sql.catalyst.plans.logical.statsEstimation
import scala.collection.immutable.HashSet
import scala.collection.mutable
+import scala.math.BigDecimal.RoundingMode
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.CatalystConf
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Filter, Statistics}
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
@@ -52,17 +53,19 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo
def estimate: Option[Statistics] = {
if (childStats.rowCount.isEmpty) return None
- // save a mutable copy of colStats so that we can later change it recursively
+ // Save a mutable copy of colStats so that we can later change it recursively.
colStatsMap.setInitValues(childStats.attributeStats)
- // estimate selectivity of this filter predicate
- val filterSelectivity: Double = calculateFilterSelectivity(plan.condition) match {
- case Some(percent) => percent
- // for not-supported condition, set filter selectivity to a conservative estimate 100%
- case None => 1.0
- }
+ // Estimate selectivity of this filter predicate, and update column stats if needed.
+ // For an unsupported condition, set the filter selectivity to a conservative estimate of 100%.
+ val filterSelectivity: Double = calculateFilterSelectivity(plan.condition).getOrElse(1.0)
- val newColStats = colStatsMap.toColumnStats
+ val newColStats = if (filterSelectivity == 0) {
+ // The output is empty, we don't need to keep column stats.
+ AttributeMap[ColumnStat](Nil)
+ } else {
+ colStatsMap.toColumnStats
+ }
val filteredRowCount: BigInt =
EstimationUtils.ceil(BigDecimal(childStats.rowCount.get) * filterSelectivity)
@@ -74,12 +77,14 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo
}
/**
- * Returns a percentage of rows meeting a compound condition in Filter node.
- * A compound condition is decomposed into multiple single conditions linked with AND, OR, NOT.
+ * Returns a percentage of rows meeting a condition in Filter node.
+ * If it's a single condition, we calculate the percentage directly.
+ * If it's a compound condition, it is decomposed into multiple single conditions linked with
+ * AND, OR, NOT.
* For logical AND conditions, we need to update stats after a condition estimation
* so that the stats will be more accurate for subsequent estimation. This is needed for
* range condition such as (c > 40 AND c <= 50)
- * For logical OR conditions, we do not update stats after a condition estimation.
+ * For logical OR and NOT conditions, we do not update stats after a condition estimation.
*
* @param condition the compound logical expression
* @param update a boolean flag to specify if we need to update ColumnStat of a column
@@ -90,34 +95,29 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo
def calculateFilterSelectivity(condition: Expression, update: Boolean = true): Option[Double] = {
condition match {
case And(cond1, cond2) =>
- // For ease of debugging, we compute percent1 and percent2 in 2 statements.
- val percent1 = calculateFilterSelectivity(cond1, update)
- val percent2 = calculateFilterSelectivity(cond2, update)
- (percent1, percent2) match {
- case (Some(p1), Some(p2)) => Some(p1 * p2)
- case (Some(p1), None) => Some(p1)
- case (None, Some(p2)) => Some(p2)
- case (None, None) => None
- }
+ val percent1 = calculateFilterSelectivity(cond1, update).getOrElse(1.0)
+ val percent2 = calculateFilterSelectivity(cond2, update).getOrElse(1.0)
+ Some(percent1 * percent2)
case Or(cond1, cond2) =>
- // For ease of debugging, we compute percent1 and percent2 in 2 statements.
- val percent1 = calculateFilterSelectivity(cond1, update = false)
- val percent2 = calculateFilterSelectivity(cond2, update = false)
- (percent1, percent2) match {
- case (Some(p1), Some(p2)) => Some(math.min(1.0, p1 + p2 - (p1 * p2)))
- case (Some(p1), None) => Some(1.0)
- case (None, Some(p2)) => Some(1.0)
- case (None, None) => None
- }
+ val percent1 = calculateFilterSelectivity(cond1, update = false).getOrElse(1.0)
+ val percent2 = calculateFilterSelectivity(cond2, update = false).getOrElse(1.0)
+ Some(percent1 + percent2 - (percent1 * percent2))
- case Not(cond) => calculateFilterSelectivity(cond, update = false) match {
- case Some(percent) => Some(1.0 - percent)
- // for not-supported condition, set filter selectivity to a conservative estimate 100%
- case None => None
- }
+ case Not(And(cond1, cond2)) =>
+ calculateFilterSelectivity(Or(Not(cond1), Not(cond2)), update = false)
+
+ case Not(Or(cond1, cond2)) =>
+ calculateFilterSelectivity(And(Not(cond1), Not(cond2)), update = false)
- case _ => calculateSingleCondition(condition, update)
+ case Not(cond) =>
+ calculateFilterSelectivity(cond, update = false) match {
+ case Some(percent) => Some(1.0 - percent)
+ case None => None
+ }
+
+ case _ =>
+ calculateSingleCondition(condition, update)
}
}
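A worked example of the combination rules above, with made-up selectivities and the usual independence assumption:

```scala
val p1 = 0.2 // assumed selectivity of cond1
val p2 = 0.5 // assumed selectivity of cond2

val andSel = p1 * p2            // And(cond1, cond2) = 0.10
val orSel  = p1 + p2 - p1 * p2  // Or(cond1, cond2)  = 0.60
// Not(And(cond1, cond2)) is rewritten to Or(Not(cond1), Not(cond2)):
val notAnd = (1 - p1) + (1 - p2) - (1 - p1) * (1 - p2)  // = 0.90 = 1 - andSel
// An unsupported sub-condition contributes the conservative default of 1.0.
```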
@@ -225,12 +225,12 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo
}
val percent = if (isNull) {
- nullPercent.toDouble
+ nullPercent
} else {
- 1.0 - nullPercent.toDouble
+ 1.0 - nullPercent
}
- Some(percent)
+ Some(percent.toDouble)
}
/**
@@ -249,17 +249,19 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo
attr: Attribute,
literal: Literal,
update: Boolean): Option[Double] = {
+ if (!colStatsMap.contains(attr)) {
+ logDebug("[CBO] No statistics for " + attr)
+ return None
+ }
+
attr.dataType match {
- case _: NumericType | DateType | TimestampType =>
+ case _: NumericType | DateType | TimestampType | BooleanType =>
evaluateBinaryForNumeric(op, attr, literal, update)
case StringType | BinaryType =>
// TODO: It is difficult to support other binary comparisons for String/Binary
// type without min/max and advanced statistics like histogram.
logDebug("[CBO] No range comparison statistics for String/Binary type " + attr)
None
- case _ =>
- // TODO: support boolean type.
- None
}
}
@@ -291,6 +293,10 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo
* Returns a percentage of rows meeting an equality (=) expression.
* This method evaluates the equality predicate for all data types.
*
+ * For EqualNullSafe (<=>), if the literal is not null, result will be the same as EqualTo;
+ * if the literal is null, the condition will be changed to IsNull after optimization.
+ * So we don't need specific logic for EqualNullSafe here.
+ *
* @param attr an Attribute (or a column)
* @param literal a literal value (or constant)
* @param update a boolean flag to specify if we need to update ColumnStat of a given column
@@ -323,7 +329,7 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo
colStatsMap(attr) = newStats
}
- Some(1.0 / ndv.toDouble)
+ Some((1.0 / BigDecimal(ndv)).toDouble)
} else {
Some(0.0)
}
@@ -394,12 +400,12 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo
// return the filter selectivity. Without advanced statistics such as histograms,
// we have to assume uniform distribution.
- Some(math.min(1.0, newNdv.toDouble / ndv.toDouble))
+ Some(math.min(1.0, (BigDecimal(newNdv) / BigDecimal(ndv)).toDouble))
}
/**
* Returns a percentage of rows meeting a binary comparison expression.
- * This method evaluate expression for Numeric columns only.
+ * This method evaluates the expression for Numeric/Date/Timestamp/Boolean columns.
*
* @param op a binary comparison operator such as =, <, <=, >, >=
* @param attr an Attribute (or a column)
@@ -414,53 +420,66 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo
literal: Literal,
update: Boolean): Option[Double] = {
- var percent = 1.0
val colStat = colStatsMap(attr)
- val statsRange =
- Range(colStat.min, colStat.max, attr.dataType).asInstanceOf[NumericRange]
+ val statsRange = Range(colStat.min, colStat.max, attr.dataType).asInstanceOf[NumericRange]
+ val max = BigDecimal(statsRange.max)
+ val min = BigDecimal(statsRange.min)
+ val ndv = BigDecimal(colStat.distinctCount)
// determine the overlapping degree between predicate range and column's range
- val literalValueBD = BigDecimal(literal.value.toString)
+ val numericLiteral = if (literal.dataType == BooleanType) {
+ if (literal.value.asInstanceOf[Boolean]) BigDecimal(1) else BigDecimal(0)
+ } else {
+ BigDecimal(literal.value.toString)
+ }
val (noOverlap: Boolean, completeOverlap: Boolean) = op match {
case _: LessThan =>
- (literalValueBD <= statsRange.min, literalValueBD > statsRange.max)
+ (numericLiteral <= min, numericLiteral > max)
case _: LessThanOrEqual =>
- (literalValueBD < statsRange.min, literalValueBD >= statsRange.max)
+ (numericLiteral < min, numericLiteral >= max)
case _: GreaterThan =>
- (literalValueBD >= statsRange.max, literalValueBD < statsRange.min)
+ (numericLiteral >= max, numericLiteral < min)
case _: GreaterThanOrEqual =>
- (literalValueBD > statsRange.max, literalValueBD <= statsRange.min)
+ (numericLiteral > max, numericLiteral <= min)
}
+ var percent = BigDecimal(1.0)
if (noOverlap) {
percent = 0.0
} else if (completeOverlap) {
percent = 1.0
} else {
- // this is partial overlap case
- val literalDouble = literalValueBD.toDouble
- val maxDouble = BigDecimal(statsRange.max).toDouble
- val minDouble = BigDecimal(statsRange.min).toDouble
-
+ // This is the partial overlap case:
// Without advanced statistics like histogram, we assume uniform data distribution.
// We just prorate the adjusted range over the initial range to compute filter selectivity.
- // For ease of computation, we convert all relevant numeric values to Double.
+ assert(max > min)
percent = op match {
case _: LessThan =>
- (literalDouble - minDouble) / (maxDouble - minDouble)
+ if (numericLiteral == max) {
+ // If the literal value is right on the boundary, we can subtract the
+ // boundary value's share (1/ndv).
+ 1.0 - 1.0 / ndv
+ } else {
+ (numericLiteral - min) / (max - min)
+ }
case _: LessThanOrEqual =>
- if (literalValueBD == BigDecimal(statsRange.min)) {
- 1.0 / colStat.distinctCount.toDouble
+ if (numericLiteral == min) {
+ // The boundary value is the only satisfying value.
+ 1.0 / ndv
} else {
- (literalDouble - minDouble) / (maxDouble - minDouble)
+ (numericLiteral - min) / (max - min)
}
case _: GreaterThan =>
- (maxDouble - literalDouble) / (maxDouble - minDouble)
+ if (numericLiteral == min) {
+ 1.0 - 1.0 / ndv
+ } else {
+ (max - numericLiteral) / (max - min)
+ }
case _: GreaterThanOrEqual =>
- if (literalValueBD == BigDecimal(statsRange.max)) {
- 1.0 / colStat.distinctCount.toDouble
+ if (numericLiteral == max) {
+ 1.0 / ndv
} else {
- (maxDouble - literalDouble) / (maxDouble - minDouble)
+ (max - numericLiteral) / (max - min)
}
}
@@ -469,22 +488,25 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo
val newValue = convertBoundValue(attr.dataType, literal.value)
var newMax = colStat.max
var newMin = colStat.min
+ var newNdv = (ndv * percent).setScale(0, RoundingMode.HALF_UP).toBigInt()
+ if (newNdv < 1) newNdv = 1
+
op match {
- case _: GreaterThan => newMin = newValue
- case _: GreaterThanOrEqual => newMin = newValue
- case _: LessThan => newMax = newValue
- case _: LessThanOrEqual => newMax = newValue
+ case _: GreaterThan | _: GreaterThanOrEqual =>
+ // If new ndv is 1, then new max must be equal to new min.
+ newMin = if (newNdv == 1) newMax else newValue
+ case _: LessThan | _: LessThanOrEqual =>
+ newMax = if (newNdv == 1) newMin else newValue
}
- val newNdv = math.max(math.round(colStat.distinctCount.toDouble * percent), 1)
- val newStats = colStat.copy(distinctCount = newNdv, min = newMin,
- max = newMax, nullCount = 0)
+ val newStats =
+ colStat.copy(distinctCount = newNdv, min = newMin, max = newMax, nullCount = 0)
colStatsMap(attr) = newStats
}
}
- Some(percent)
+ Some(percent.toDouble)
}
}
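A worked example of the partial-overlap proration and the boundary special cases above, under assumed column stats min = 0, max = 100, ndv = 50 and a uniform distribution:

```scala
val (min, max, ndv) = (BigDecimal(0), BigDecimal(100), BigDecimal(50))

// Selectivity of `a < literal`, following the branches above.
def lessThanSelectivity(lit: BigDecimal): BigDecimal =
  if (lit <= min) BigDecimal(0)                             // no overlap with [min, max]
  else if (lit > max) BigDecimal(1)                         // complete overlap
  else if (lit == max) BigDecimal(1) - BigDecimal(1) / ndv  // literal on the max boundary
  else (lit - min) / (max - min)                            // plain proration

lessThanSelectivity(BigDecimal(40))   // 0.40
lessThanSelectivity(BigDecimal(100))  // 0.98 (= 1 - 1/50)
lessThanSelectivity(BigDecimal(-5))   // 0
// The surviving range then gets ndv * selectivity distinct values (rounded, floored at 1).
```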
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/streaming/InternalOutputModes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/streaming/InternalOutputModes.scala
index 351bd6fff4ad..bdf2baf7361d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/streaming/InternalOutputModes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/streaming/InternalOutputModes.scala
@@ -44,4 +44,19 @@ private[sql] object InternalOutputModes {
* aggregations, it will be equivalent to `Append` mode.
*/
case object Update extends OutputMode
+
+
+ def apply(outputMode: String): OutputMode = {
+ outputMode.toLowerCase match {
+ case "append" =>
+ OutputMode.Append
+ case "complete" =>
+ OutputMode.Complete
+ case "update" =>
+ OutputMode.Update
+ case _ =>
+ throw new IllegalArgumentException(s"Unknown output mode $outputMode. " +
+ "Accepted output modes are 'append', 'complete', 'update'")
+ }
+ }
}
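For reference, this is how the new factory is meant to be used (it is `private[sql]`, so it is only callable from Spark's own `sql` packages); matching is case-insensitive:

```scala
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes

InternalOutputModes("append")    // OutputMode.Append
InternalOutputModes("Complete")  // OutputMode.Complete
InternalOutputModes("UPDATE")    // OutputMode.Update
// InternalOutputModes("upsert") // IllegalArgumentException: Unknown output mode upsert. ...
```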
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
index 089c84d5f773..e8f6884c025c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
@@ -21,6 +21,7 @@ import java.lang.{Long => JLong}
import java.math.{BigInteger, MathContext, RoundingMode}
import org.apache.spark.annotation.InterfaceStability
+import org.apache.spark.sql.AnalysisException
/**
* A mutable implementation of BigDecimal that can hold a Long if values are small enough.
@@ -222,6 +223,19 @@ final class Decimal extends Ordered[Decimal] with Serializable {
case java.math.BigDecimal.ROUND_HALF_EVEN => changePrecision(precision, scale, ROUND_HALF_EVEN)
}
+ /**
+ * Create new `Decimal` with given precision and scale.
+ *
+ * @return `Some(decimal)` if successful or `None` if overflow would occur
+ */
+ private[sql] def toPrecision(
+ precision: Int,
+ scale: Int,
+ roundMode: BigDecimal.RoundingMode.Value = ROUND_HALF_UP): Option[Decimal] = {
+ val copy = clone()
+ if (copy.changePrecision(precision, scale, roundMode)) Some(copy) else None
+ }
+
/**
* Update precision and scale while keeping our value the same, and return true if successful.
*
@@ -362,17 +376,15 @@ final class Decimal extends Ordered[Decimal] with Serializable {
def abs: Decimal = if (this.compare(Decimal.ZERO) < 0) this.unary_- else this
def floor: Decimal = if (scale == 0) this else {
- val value = this.clone()
- value.changePrecision(
- DecimalType.bounded(precision - scale + 1, 0).precision, 0, ROUND_FLOOR)
- value
+ val newPrecision = DecimalType.bounded(precision - scale + 1, 0).precision
+ toPrecision(newPrecision, 0, ROUND_FLOOR).getOrElse(
+ throw new AnalysisException(s"Overflow when setting precision to $newPrecision"))
}
def ceil: Decimal = if (scale == 0) this else {
- val value = this.clone()
- value.changePrecision(
- DecimalType.bounded(precision - scale + 1, 0).precision, 0, ROUND_CEILING)
- value
+ val newPrecision = DecimalType.bounded(precision - scale + 1, 0).precision
+ toPrecision(newPrecision, 0, ROUND_CEILING).getOrElse(
+ throw new AnalysisException(s"Overflow when setting precision to $newPrecision"))
}
}
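A plain `java.math.BigDecimal` sketch (not Spark's `Decimal`) of what the rewritten `floor`/`ceil` compute, i.e. re-scaling to zero fractional digits with FLOOR/CEILING rounding; the behavioral change in the patch is that a failed `changePrecision` is no longer ignored but surfaced as an `AnalysisException` via `toPrecision`.

```scala
import java.math.{BigDecimal => JBigDecimal, RoundingMode}

new JBigDecimal("3.7").setScale(0, RoundingMode.FLOOR)     //  3
new JBigDecimal("3.7").setScale(0, RoundingMode.CEILING)   //  4
new JBigDecimal("-3.7").setScale(0, RoundingMode.FLOOR)    // -4
new JBigDecimal("-3.7").setScale(0, RoundingMode.CEILING)  // -3
```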
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index 01737e0a1734..893bb1b74cea 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -62,23 +62,23 @@ class AnalysisSuite extends AnalysisTest with ShouldMatchers {
checkAnalysis(
Project(Seq(UnresolvedAttribute("TbL.a")),
- SubqueryAlias("TbL", UnresolvedRelation(TableIdentifier("TaBlE")), None)),
+ SubqueryAlias("TbL", UnresolvedRelation(TableIdentifier("TaBlE")))),
Project(testRelation.output, testRelation))
assertAnalysisError(
Project(Seq(UnresolvedAttribute("tBl.a")),
- SubqueryAlias("TbL", UnresolvedRelation(TableIdentifier("TaBlE")), None)),
+ SubqueryAlias("TbL", UnresolvedRelation(TableIdentifier("TaBlE")))),
Seq("cannot resolve"))
checkAnalysis(
Project(Seq(UnresolvedAttribute("TbL.a")),
- SubqueryAlias("TbL", UnresolvedRelation(TableIdentifier("TaBlE")), None)),
+ SubqueryAlias("TbL", UnresolvedRelation(TableIdentifier("TaBlE")))),
Project(testRelation.output, testRelation),
caseSensitive = false)
checkAnalysis(
Project(Seq(UnresolvedAttribute("tBl.a")),
- SubqueryAlias("TbL", UnresolvedRelation(TableIdentifier("TaBlE")), None)),
+ SubqueryAlias("TbL", UnresolvedRelation(TableIdentifier("TaBlE")))),
Project(testRelation.output, testRelation),
caseSensitive = false)
}
@@ -374,8 +374,8 @@ class AnalysisSuite extends AnalysisTest with ShouldMatchers {
val query =
Project(Seq($"x.key", $"y.key"),
Join(
- Project(Seq($"x.key"), SubqueryAlias("x", input, None)),
- Project(Seq($"y.key"), SubqueryAlias("y", input, None)),
+ Project(Seq($"x.key"), SubqueryAlias("x", input)),
+ Project(Seq($"y.key"), SubqueryAlias("y", input)),
Cross, None))
assertAnalysisSuccess(query)
@@ -435,10 +435,10 @@ class AnalysisSuite extends AnalysisTest with ShouldMatchers {
test("resolve as with an already existed alias") {
checkAnalysis(
Project(Seq(UnresolvedAttribute("tbl2.a")),
- SubqueryAlias("tbl", testRelation, None).as("tbl2")),
+ SubqueryAlias("tbl", testRelation).as("tbl2")),
Project(testRelation.output, testRelation),
caseSensitive = false)
- checkAnalysis(SubqueryAlias("tbl", testRelation, None).as("tbl2"), testRelation)
+ checkAnalysis(SubqueryAlias("tbl", testRelation).as("tbl2"), testRelation)
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala
index 82be69a0f7d7..200c39f43a6b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Literal, NamedExpression}
import org.apache.spark.sql.catalyst.expressions.aggregate.Count
import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.logical.{MapGroupsWithState, _}
+import org.apache.spark.sql.catalyst.plans.logical.{FlatMapGroupsWithState, _}
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.{IntegerType, LongType, MetadataBuilder}
@@ -138,29 +138,202 @@ class UnsupportedOperationsSuite extends SparkFunSuite {
outputMode = Complete,
expectedMsgs = Seq("distinct aggregation"))
- // MapGroupsWithState: Not supported after a streaming aggregation
val att = new AttributeReference(name = "a", dataType = LongType)()
- assertSupportedInBatchPlan(
- "mapGroupsWithState - mapGroupsWithState on batch relation",
- MapGroupsWithState(null, att, att, Seq(att), Seq(att), att, att, Seq(att), batchRelation))
+ // FlatMapGroupsWithState: both function modes are equivalent and supported on batch relations.
+ for (funcMode <- Seq(Append, Update)) {
+ assertSupportedInBatchPlan(
+ s"flatMapGroupsWithState - flatMapGroupsWithState($funcMode) on batch relation",
+ FlatMapGroupsWithState(
+ null, att, att, Seq(att), Seq(att), att, att, Seq(att), funcMode, batchRelation))
+
+ assertSupportedInBatchPlan(
+ s"flatMapGroupsWithState - multiple flatMapGroupsWithState($funcMode)s on batch relation",
+ FlatMapGroupsWithState(
+ null, att, att, Seq(att), Seq(att), att, att, Seq(att), funcMode,
+ FlatMapGroupsWithState(
+ null, att, att, Seq(att), Seq(att), att, att, Seq(att), funcMode, batchRelation)))
+ }
+
+ // FlatMapGroupsWithState(Update) in streaming without aggregation
+ assertSupportedInStreamingPlan(
+ "flatMapGroupsWithState - flatMapGroupsWithState(Update) " +
+ "on streaming relation without aggregation in update mode",
+ FlatMapGroupsWithState(
+ null, att, att, Seq(att), Seq(att), att, att, Seq(att), Update, streamRelation),
+ outputMode = Update)
+
+ assertNotSupportedInStreamingPlan(
+ "flatMapGroupsWithState - flatMapGroupsWithState(Update) " +
+ "on streaming relation without aggregation in append mode",
+ FlatMapGroupsWithState(
+ null, att, att, Seq(att), Seq(att), att, att, Seq(att), Update, streamRelation),
+ outputMode = Append,
+ expectedMsgs = Seq("flatMapGroupsWithState in update mode", "Append"))
+
+ assertNotSupportedInStreamingPlan(
+ "flatMapGroupsWithState - flatMapGroupsWithState(Update) " +
+ "on streaming relation without aggregation in complete mode",
+ FlatMapGroupsWithState(
+ null, att, att, Seq(att), Seq(att), att, att, Seq(att), Update, streamRelation),
+ outputMode = Complete,
+ // Disallowed by the aggregation check but let's still keep this test in case it's
+ // broken in the future.
+ expectedMsgs = Seq("Complete"))
+
+ // FlatMapGroupsWithState(Update) in streaming with aggregation
+ for (outputMode <- Seq(Append, Update, Complete)) {
+ assertNotSupportedInStreamingPlan(
+ "flatMapGroupsWithState - flatMapGroupsWithState(Update) on streaming relation " +
+ s"with aggregation in $outputMode mode",
+ FlatMapGroupsWithState(
+ null, att, att, Seq(att), Seq(att), att, att, Seq(att), Update,
+ Aggregate(Seq(attributeWithWatermark), aggExprs("c"), streamRelation)),
+ outputMode = outputMode,
+ expectedMsgs = Seq("flatMapGroupsWithState in update mode", "with aggregation"))
+ }
+ // FlatMapGroupsWithState(Append) in streaming without aggregation
assertSupportedInStreamingPlan(
- "mapGroupsWithState - mapGroupsWithState on streaming relation before aggregation",
- MapGroupsWithState(null, att, att, Seq(att), Seq(att), att, att, Seq(att), streamRelation),
+ "flatMapGroupsWithState - flatMapGroupsWithState(Append) " +
+ "on streaming relation without aggregation in append mode",
+ FlatMapGroupsWithState(
+ null, att, att, Seq(att), Seq(att), att, att, Seq(att), Append, streamRelation),
outputMode = Append)
assertNotSupportedInStreamingPlan(
- "mapGroupsWithState - mapGroupsWithState on streaming relation after aggregation",
- MapGroupsWithState(null, att, att, Seq(att), Seq(att), att, att, Seq(att),
- Aggregate(Nil, aggExprs("c"), streamRelation)),
+ "flatMapGroupsWithState - flatMapGroupsWithState(Append) " +
+ "on streaming relation without aggregation in update mode",
+ FlatMapGroupsWithState(
+ null, att, att, Seq(att), Seq(att), att, att, Seq(att), Append, streamRelation),
+ outputMode = Update,
+ expectedMsgs = Seq("flatMapGroupsWithState in append mode", "update"))
+
+ // FlatMapGroupsWithState(Append) in streaming with aggregation
+ for (outputMode <- Seq(Append, Update, Complete)) {
+ assertSupportedInStreamingPlan(
+ "flatMapGroupsWithState - flatMapGroupsWithState(Append) " +
+ s"on streaming relation before aggregation in $outputMode mode",
+ Aggregate(
+ Seq(attributeWithWatermark),
+ aggExprs("c"),
+ FlatMapGroupsWithState(
+ null, att, att, Seq(att), Seq(att), att, att, Seq(att), Append, streamRelation)),
+ outputMode = outputMode)
+ }
+
+ for (outputMode <- Seq(Append, Update)) {
+ assertNotSupportedInStreamingPlan(
+ "flatMapGroupsWithState - flatMapGroupsWithState(Append) " +
+ s"on streaming relation after aggregation in $outputMode mode",
+ FlatMapGroupsWithState(null, att, att, Seq(att), Seq(att), att, att, Seq(att), Append,
+ Aggregate(Seq(attributeWithWatermark), aggExprs("c"), streamRelation)),
+ outputMode = outputMode,
+ expectedMsgs = Seq("flatMapGroupsWithState", "after aggregation"))
+ }
+
+ assertNotSupportedInStreamingPlan(
+ "flatMapGroupsWithState - " +
+ "flatMapGroupsWithState(Update) on streaming relation in complete mode",
+ FlatMapGroupsWithState(
+ null, att, att, Seq(att), Seq(att), att, att, Seq(att), Append, streamRelation),
outputMode = Complete,
- expectedMsgs = Seq("(map/flatMap)GroupsWithState"))
+ // Disallowed by the aggregation check but let's still keep this test in case it's
+ // broken in the future.
+ expectedMsgs = Seq("Complete"))
+ // FlatMapGroupsWithState on a batch relation inside a streaming plan should always be allowed
+ for (funcMode <- Seq(Append, Update)) {
+ for (outputMode <- Seq(Append, Update)) { // Complete is not supported without aggregation
+ assertSupportedInStreamingPlan(
+ s"flatMapGroupsWithState - flatMapGroupsWithState($funcMode) on batch relation inside " +
+ s"streaming relation in $outputMode output mode",
+ FlatMapGroupsWithState(
+ null, att, att, Seq(att), Seq(att), att, att, Seq(att), funcMode, batchRelation),
+ outputMode = outputMode
+ )
+ }
+ }
+
+ // multiple FlatMapGroupsWithStates
assertSupportedInStreamingPlan(
- "mapGroupsWithState - mapGroupsWithState on batch relation inside streaming relation",
- MapGroupsWithState(null, att, att, Seq(att), Seq(att), att, att, Seq(att), batchRelation),
- outputMode = Append
- )
+ "flatMapGroupsWithState - multiple flatMapGroupsWithStates on streaming relation and all are " +
+ "in append mode",
+ FlatMapGroupsWithState(
+ null, att, att, Seq(att), Seq(att), att, att, Seq(att), Append,
+ FlatMapGroupsWithState(
+ null, att, att, Seq(att), Seq(att), att, att, Seq(att), Append, streamRelation)),
+ outputMode = Append)
+
+ assertNotSupportedInStreamingPlan(
+ "flatMapGroupsWithState - multiple flatMapGroupsWithStates on s streaming relation but some" +
+ " are not in append mode",
+ FlatMapGroupsWithState(
+ null, att, att, Seq(att), Seq(att), att, att, Seq(att), Update,
+ FlatMapGroupsWithState(
+ null, att, att, Seq(att), Seq(att), att, att, Seq(att), Append, streamRelation)),
+ outputMode = Append,
+ expectedMsgs = Seq("multiple flatMapGroupsWithState", "append"))
+
+ // mapGroupsWithState
+ assertNotSupportedInStreamingPlan(
+ "mapGroupsWithState - mapGroupsWithState " +
+ "on streaming relation without aggregation in append mode",
+ FlatMapGroupsWithState(
+ null, att, att, Seq(att), Seq(att), att, att, Seq(att), Update, streamRelation,
+ isMapGroupsWithState = true),
+ outputMode = Append,
+ // Disallowed by the aggregation check but let's still keep this test in case it's
+ // broken in the future.
+ expectedMsgs = Seq("mapGroupsWithState", "append"))
+
+ assertNotSupportedInStreamingPlan(
+ "mapGroupsWithState - mapGroupsWithState " +
+ "on streaming relation without aggregation in complete mode",
+ FlatMapGroupsWithState(
+ null, att, att, Seq(att), Seq(att), att, att, Seq(att), Update, streamRelation,
+ isMapGroupsWithState = true),
+ outputMode = Complete,
+ // Disallowed by the aggregation check but let's still keep this test in case it's
+ // broken in the future.
+ expectedMsgs = Seq("Complete"))
+
+ for (outputMode <- Seq(Append, Update, Complete)) {
+ assertNotSupportedInStreamingPlan(
+ "mapGroupsWithState - mapGroupsWithState on streaming relation " +
+ s"with aggregation in $outputMode mode",
+ FlatMapGroupsWithState(
+ null, att, att, Seq(att), Seq(att), att, att, Seq(att), Update,
+ Aggregate(Seq(attributeWithWatermark), aggExprs("c"), streamRelation),
+ isMapGroupsWithState = true),
+ outputMode = outputMode,
+ expectedMsgs = Seq("mapGroupsWithState", "with aggregation"))
+ }
+
+ // multiple mapGroupsWithStates
+ assertNotSupportedInStreamingPlan(
+ "mapGroupsWithState - multiple mapGroupsWithStates on streaming relation and all are " +
+ "in append mode",
+ FlatMapGroupsWithState(
+ null, att, att, Seq(att), Seq(att), att, att, Seq(att), Update,
+ FlatMapGroupsWithState(
+ null, att, att, Seq(att), Seq(att), att, att, Seq(att), Update, streamRelation,
+ isMapGroupsWithState = true),
+ isMapGroupsWithState = true),
+ outputMode = Append,
+ expectedMsgs = Seq("multiple mapGroupsWithStates"))
+
+ // mixing mapGroupsWithStates and flatMapGroupsWithStates
+ assertNotSupportedInStreamingPlan(
+ "mapGroupsWithState - " +
+ "mixing mapGroupsWithStates and flatMapGroupsWithStates on streaming relation",
+ FlatMapGroupsWithState(
+ null, att, att, Seq(att), Seq(att), att, att, Seq(att), Update,
+ FlatMapGroupsWithState(
+ null, att, att, Seq(att), Seq(att), att, att, Seq(att), Update, streamRelation,
+ isMapGroupsWithState = false),
+ isMapGroupsWithState = true),
+ outputMode = Append,
+ expectedMsgs = Seq("Mixing mapGroupsWithStates and flatMapGroupsWithStates"))
// Deduplicate
assertSupportedInStreamingPlan(
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala
index a5d399a06558..7820f39d9642 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.catalyst.catalog
+import java.net.URI
+
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.scalatest.BeforeAndAfterEach
@@ -26,7 +28,7 @@ import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
import org.apache.spark.sql.catalyst.analysis.{FunctionAlreadyExistsException, NoSuchDatabaseException, NoSuchFunctionException}
import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
@@ -238,6 +240,19 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEac
}
}
+ test("alter table schema") {
+ val catalog = newBasicCatalog()
+ val tbl1 = catalog.getTable("db2", "tbl1")
+ val newSchema = StructType(Seq(
+ StructField("new_field_1", IntegerType),
+ StructField("new_field_2", StringType),
+ StructField("a", IntegerType),
+ StructField("b", StringType)))
+ catalog.alterTableSchema("db2", "tbl1", newSchema)
+ val newTbl1 = catalog.getTable("db2", "tbl1")
+ assert(newTbl1.schema == newSchema)
+ }
+
test("get table") {
assert(newBasicCatalog().getTable("db2", "tbl1").identifier.table == "tbl1")
}
@@ -340,7 +355,7 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEac
"db1",
"tbl",
Map("partCol1" -> "1", "partCol2" -> "2")).location
- val tableLocation = catalog.getTable("db1", "tbl").location
+ val tableLocation = new Path(catalog.getTable("db1", "tbl").location)
val defaultPartitionLocation = new Path(new Path(tableLocation, "partCol1=1"), "partCol2=2")
assert(new Path(partitionLocation) == defaultPartitionLocation)
}
@@ -508,7 +523,7 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEac
partitionColumnNames = Seq("partCol1", "partCol2"))
catalog.createTable(table, ignoreIfExists = false)
- val tableLocation = catalog.getTable("db1", "tbl").location
+ val tableLocation = new Path(catalog.getTable("db1", "tbl").location)
val mixedCasePart1 = CatalogTablePartition(
Map("partCol1" -> "1", "partCol2" -> "2"), storageFormat)
@@ -699,7 +714,7 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEac
// File System operations
// --------------------------------------------------------------------------
- private def exists(uri: String, children: String*): Boolean = {
+ private def exists(uri: URI, children: String*): Boolean = {
val base = new Path(uri)
val finalPath = children.foldLeft(base) {
case (parent, child) => new Path(parent, child)
@@ -742,7 +757,7 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEac
identifier = TableIdentifier("external_table", Some("db1")),
tableType = CatalogTableType.EXTERNAL,
storage = CatalogStorageFormat(
- Some(Utils.createTempDir().getAbsolutePath),
+ Some(Utils.createTempDir().toURI),
None, None, None, false, Map.empty),
schema = new StructType().add("a", "int").add("b", "string"),
provider = Some(defaultProvider)
@@ -790,7 +805,7 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEac
val partWithExistingDir = CatalogTablePartition(
Map("partCol1" -> "7", "partCol2" -> "8"),
CatalogStorageFormat(
- Some(tempPath.toURI.toString),
+ Some(tempPath.toURI),
None, None, None, false, Map.empty))
catalog.createPartitions("db1", "tbl", Seq(partWithExistingDir), ignoreIfExists = false)
@@ -799,7 +814,7 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEac
val partWithNonExistingDir = CatalogTablePartition(
Map("partCol1" -> "9", "partCol2" -> "10"),
CatalogStorageFormat(
- Some(tempPath.toURI.toString),
+ Some(tempPath.toURI),
None, None, None, false, Map.empty))
catalog.createPartitions("db1", "tbl", Seq(partWithNonExistingDir), ignoreIfExists = false)
assert(tempPath.exists())
@@ -883,7 +898,7 @@ abstract class CatalogTestUtils {
def newFunc(): CatalogFunction = newFunc("funcName")
- def newUriForDatabase(): String = Utils.createTempDir().toURI.toString.stripSuffix("/")
+ def newUriForDatabase(): URI = new URI(Utils.createTempDir().toURI.toString.stripSuffix("/"))
def newDb(name: String): CatalogDatabase = {
CatalogDatabase(name, name + " description", newUriForDatabase(), Map.empty)
@@ -895,7 +910,7 @@ abstract class CatalogTestUtils {
CatalogTable(
identifier = TableIdentifier(name, database),
tableType = CatalogTableType.EXTERNAL,
- storage = storageFormat.copy(locationUri = Some(Utils.createTempDir().getAbsolutePath)),
+ storage = storageFormat.copy(locationUri = Some(Utils.createTempDir().toURI)),
schema = new StructType()
.add("col1", "int")
.add("col2", "string")
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala
index a755231962be..7e74dcdef0e2 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala
@@ -17,8 +17,10 @@
package org.apache.spark.sql.catalyst.catalog
+import org.apache.hadoop.conf.Configuration
+
import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
+import org.apache.spark.sql.catalyst.{FunctionIdentifier, SimpleCatalystConf, TableIdentifier}
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
@@ -437,7 +439,7 @@ class SessionCatalogSuite extends PlanTest {
.asInstanceOf[CatalogRelation].tableMeta == metastoreTable1)
// Otherwise, we'll first look up a temporary table with the same name
assert(sessionCatalog.lookupRelation(TableIdentifier("tbl1"))
- == SubqueryAlias("tbl1", tempTable1, None))
+ == SubqueryAlias("tbl1", tempTable1))
// Then, if that does not exist, look up the relation in the current database
sessionCatalog.dropTable(TableIdentifier("tbl1"), ignoreIfNotExists = false, purge = false)
assert(sessionCatalog.lookupRelation(TableIdentifier("tbl1")).children.head
@@ -454,11 +456,11 @@ class SessionCatalogSuite extends PlanTest {
val view = View(desc = metadata, output = metadata.schema.toAttributes,
child = CatalystSqlParser.parsePlan(metadata.viewText.get))
comparePlans(sessionCatalog.lookupRelation(TableIdentifier("view1", Some("db3"))),
- SubqueryAlias("view1", view, Some(TableIdentifier("view1", Some("db3")))))
+ SubqueryAlias("view1", view))
// Look up a view using current database of the session catalog.
sessionCatalog.setCurrentDatabase("db3")
comparePlans(sessionCatalog.lookupRelation(TableIdentifier("view1")),
- SubqueryAlias("view1", view, Some(TableIdentifier("view1", Some("db3")))))
+ SubqueryAlias("view1", view))
}
test("table exists") {
@@ -1196,4 +1198,78 @@ class SessionCatalogSuite extends PlanTest {
catalog.listFunctions("unknown_db", "func*")
}
}
+
+ test("clone SessionCatalog - temp views") {
+ val externalCatalog = newEmptyCatalog()
+ val original = new SessionCatalog(externalCatalog)
+ val tempTable1 = Range(1, 10, 1, 10)
+ original.createTempView("copytest1", tempTable1, overrideIfExists = false)
+
+ // check that the temp views were copied over
+ val clone = original.newSessionCatalogWith(
+ SimpleCatalystConf(caseSensitiveAnalysis = true),
+ new Configuration(),
+ new SimpleFunctionRegistry,
+ CatalystSqlParser)
+ assert(original ne clone)
+ assert(clone.getTempView("copytest1") == Some(tempTable1))
+
+ // check that the clone and the original are independent
+ clone.dropTable(TableIdentifier("copytest1"), ignoreIfNotExists = false, purge = false)
+ assert(original.getTempView("copytest1") == Some(tempTable1))
+
+ val tempTable2 = Range(1, 20, 2, 10)
+ original.createTempView("copytest2", tempTable2, overrideIfExists = false)
+ assert(clone.getTempView("copytest2").isEmpty)
+ }
+
+ test("clone SessionCatalog - current db") {
+ val externalCatalog = newEmptyCatalog()
+ val db1 = "db1"
+ val db2 = "db2"
+ val db3 = "db3"
+
+ externalCatalog.createDatabase(newDb(db1), ignoreIfExists = true)
+ externalCatalog.createDatabase(newDb(db2), ignoreIfExists = true)
+ externalCatalog.createDatabase(newDb(db3), ignoreIfExists = true)
+
+ val original = new SessionCatalog(externalCatalog)
+ original.setCurrentDatabase(db1)
+
+ // check that the current database was copied over
+ val clone = original.newSessionCatalogWith(
+ SimpleCatalystConf(caseSensitiveAnalysis = true),
+ new Configuration(),
+ new SimpleFunctionRegistry,
+ CatalystSqlParser)
+ assert(original ne clone)
+ assert(clone.getCurrentDatabase == db1)
+
+ // check that the clone and the original are independent
+ clone.setCurrentDatabase(db2)
+ assert(original.getCurrentDatabase == db1)
+ original.setCurrentDatabase(db3)
+ assert(clone.getCurrentDatabase == db2)
+ }
+
+ test("SPARK-19737: detect undefined functions without triggering relation resolution") {
+ import org.apache.spark.sql.catalyst.dsl.plans._
+
+ Seq(true, false) foreach { caseSensitive =>
+ val conf = SimpleCatalystConf(caseSensitive)
+ val catalog = new SessionCatalog(newBasicCatalog(), new SimpleFunctionRegistry, conf)
+ val analyzer = new Analyzer(catalog, conf)
+
+ // The analyzer should report the undefined function rather than the undefined table first.
+ val cause = intercept[AnalysisException] {
+ analyzer.execute(
+ UnresolvedRelation(TableIdentifier("undefined_table")).select(
+ UnresolvedFunction("undefined_fn", Nil, isDistinct = false)
+ )
+ )
+ }
+
+ assert(cause.getMessage.contains("Undefined function: 'undefined_fn'"))
+ }
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala
index 0cb3a79eee67..0c77dc2709da 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala
@@ -75,7 +75,6 @@ class HashExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkConsistencyBetweenInterpretedAndCodegen(Crc32, BinaryType)
}
-
def checkHiveHash(input: Any, dataType: DataType, expected: Long): Unit = {
// Note : All expected hashes need to be computed using Hive 1.2.1
val actual = HiveHashFunction.hash(input, dataType, seed = 0)
@@ -371,6 +370,51 @@ class HashExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
new StructType().add("array", arrayOfString).add("map", mapOfString))
.add("structOfUDT", structOfUDT))
+ test("hive-hash for decimal") {
+ def checkHiveHashForDecimal(
+ input: String,
+ precision: Int,
+ scale: Int,
+ expected: Long): Unit = {
+ val decimalType = DataTypes.createDecimalType(precision, scale)
+ val decimal = {
+ val value = Decimal.apply(new java.math.BigDecimal(input))
+ if (value.changePrecision(precision, scale)) value else null
+ }
+
+ checkHiveHash(decimal, decimalType, expected)
+ }
+
+ checkHiveHashForDecimal("18", 38, 0, 558)
+ checkHiveHashForDecimal("-18", 38, 0, -558)
+ checkHiveHashForDecimal("-18", 38, 12, -558)
+ checkHiveHashForDecimal("18446744073709001000", 38, 19, 0)
+ checkHiveHashForDecimal("-18446744073709001000", 38, 22, 0)
+ checkHiveHashForDecimal("-18446744073709001000", 38, 3, 17070057)
+ checkHiveHashForDecimal("18446744073709001000", 38, 4, -17070057)
+ checkHiveHashForDecimal("9223372036854775807", 38, 4, 2147482656)
+ checkHiveHashForDecimal("-9223372036854775807", 38, 5, -2147482656)
+ checkHiveHashForDecimal("00000.00000000000", 38, 34, 0)
+ checkHiveHashForDecimal("-00000.00000000000", 38, 11, 0)
+ checkHiveHashForDecimal("123456.1234567890", 38, 2, 382713974)
+ checkHiveHashForDecimal("123456.1234567890", 38, 20, 1871500252)
+ checkHiveHashForDecimal("123456.1234567890", 38, 10, 1871500252)
+ checkHiveHashForDecimal("-123456.1234567890", 38, 10, -1871500234)
+ checkHiveHashForDecimal("123456.1234567890", 38, 0, 3827136)
+ checkHiveHashForDecimal("-123456.1234567890", 38, 0, -3827136)
+ checkHiveHashForDecimal("123456.1234567890", 38, 20, 1871500252)
+ checkHiveHashForDecimal("-123456.1234567890", 38, 20, -1871500234)
+ checkHiveHashForDecimal("123456.123456789012345678901234567890", 38, 0, 3827136)
+ checkHiveHashForDecimal("-123456.123456789012345678901234567890", 38, 0, -3827136)
+ checkHiveHashForDecimal("123456.123456789012345678901234567890", 38, 10, 1871500252)
+ checkHiveHashForDecimal("-123456.123456789012345678901234567890", 38, 10, -1871500234)
+ checkHiveHashForDecimal("123456.123456789012345678901234567890", 38, 20, 236317582)
+ checkHiveHashForDecimal("-123456.123456789012345678901234567890", 38, 20, -236317544)
+ checkHiveHashForDecimal("123456.123456789012345678901234567890", 38, 30, 1728235666)
+ checkHiveHashForDecimal("-123456.123456789012345678901234567890", 38, 30, -1728235608)
+ checkHiveHashForDecimal("123456.123456789012345678901234567890", 38, 31, 1728235666)
+ }
+
test("SPARK-18207: Compute hash for a lot of expressions") {
val N = 1000
val wideRow = new GenericInternalRow(
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala
index 0c46819cdb9c..e3584909ddc4 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala
@@ -22,7 +22,7 @@ import java.util.Calendar
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.{DateTimeTestUtils, DateTimeUtils, ParseModes}
-import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType, TimestampType}
+import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -372,6 +372,62 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
)
}
+ test("from_json - input=array, schema=array, output=array") {
+ val input = """[{"a": 1}, {"a": 2}]"""
+ val schema = ArrayType(StructType(StructField("a", IntegerType) :: Nil))
+ val output = InternalRow(1) :: InternalRow(2) :: Nil
+ checkEvaluation(JsonToStruct(schema, Map.empty, Literal(input), gmtId), output)
+ }
+
+ test("from_json - input=object, schema=array, output=array of single row") {
+ val input = """{"a": 1}"""
+ val schema = ArrayType(StructType(StructField("a", IntegerType) :: Nil))
+ val output = InternalRow(1) :: Nil
+ checkEvaluation(JsonToStruct(schema, Map.empty, Literal(input), gmtId), output)
+ }
+
+ test("from_json - input=empty array, schema=array, output=empty array") {
+ val input = "[ ]"
+ val schema = ArrayType(StructType(StructField("a", IntegerType) :: Nil))
+ val output = Nil
+ checkEvaluation(JsonToStruct(schema, Map.empty, Literal(input), gmtId), output)
+ }
+
+ test("from_json - input=empty object, schema=array, output=array of single row with null") {
+ val input = "{ }"
+ val schema = ArrayType(StructType(StructField("a", IntegerType) :: Nil))
+ val output = InternalRow(null) :: Nil
+ checkEvaluation(JsonToStruct(schema, Map.empty, Literal(input), gmtId), output)
+ }
+
+ test("from_json - input=array of single object, schema=struct, output=single row") {
+ val input = """[{"a": 1}]"""
+ val schema = StructType(StructField("a", IntegerType) :: Nil)
+ val output = InternalRow(1)
+ checkEvaluation(JsonToStruct(schema, Map.empty, Literal(input), gmtId), output)
+ }
+
+ test("from_json - input=array, schema=struct, output=null") {
+ val input = """[{"a": 1}, {"a": 2}]"""
+ val schema = StructType(StructField("a", IntegerType) :: Nil)
+ val output = null
+ checkEvaluation(JsonToStruct(schema, Map.empty, Literal(input), gmtId), output)
+ }
+
+ test("from_json - input=empty array, schema=struct, output=null") {
+ val input = """[]"""
+ val schema = StructType(StructField("a", IntegerType) :: Nil)
+ val output = null
+ checkEvaluation(JsonToStruct(schema, Map.empty, Literal(input), gmtId), output)
+ }
+
+ test("from_json - input=empty object, schema=struct, output=single row with null") {
+ val input = """{ }"""
+ val schema = StructType(StructField("a", IntegerType) :: Nil)
+ val output = InternalRow(null)
+ checkEvaluation(JsonToStruct(schema, Map.empty, Literal(input), gmtId), output)
+ }
+
test("from_json null input column") {
val schema = StructType(StructField("a", IntegerType) :: Nil)
checkEvaluation(
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseRepartitionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseRepartitionSuite.scala
index 8952c72fe42f..59d2dc46f00c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseRepartitionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseRepartitionSuite.scala
@@ -32,47 +32,168 @@ class CollapseRepartitionSuite extends PlanTest {
val testRelation = LocalRelation('a.int, 'b.int)
+
+ test("collapse two adjacent coalesces into one") {
+ // Always respects the top coalesce and removes the useless coalesce below it
+ val query1 = testRelation
+ .coalesce(10)
+ .coalesce(20)
+ val query2 = testRelation
+ .coalesce(30)
+ .coalesce(20)
+
+ val optimized1 = Optimize.execute(query1.analyze)
+ val optimized2 = Optimize.execute(query2.analyze)
+ val correctAnswer = testRelation.coalesce(20).analyze
+
+ comparePlans(optimized1, correctAnswer)
+ comparePlans(optimized2, correctAnswer)
+ }
+
test("collapse two adjacent repartitions into one") {
- val query = testRelation
+ // Always respects the top repartition and removes the useless repartition below it
+ val query1 = testRelation
+ .repartition(10)
+ .repartition(20)
+ val query2 = testRelation
+ .repartition(30)
+ .repartition(20)
+
+ val optimized1 = Optimize.execute(query1.analyze)
+ val optimized2 = Optimize.execute(query2.analyze)
+ val correctAnswer = testRelation.repartition(20).analyze
+
+ comparePlans(optimized1, correctAnswer)
+ comparePlans(optimized2, correctAnswer)
+ }
+
+ test("coalesce above repartition") {
+ // Remove useless coalesce above repartition
+ val query1 = testRelation
.repartition(10)
+ .coalesce(20)
+
+ val optimized1 = Optimize.execute(query1.analyze)
+ val correctAnswer1 = testRelation.repartition(10).analyze
+
+ comparePlans(optimized1, correctAnswer1)
+
+ // No change in this case
+ val query2 = testRelation
+ .repartition(30)
+ .coalesce(20)
+
+ val optimized2 = Optimize.execute(query2.analyze)
+ val correctAnswer2 = query2.analyze
+
+ comparePlans(optimized2, correctAnswer2)
+ }
+
+ test("repartition above coalesce") {
+ // Always respects the top repartition and removes the useless coalesce below it
+ val query1 = testRelation
+ .coalesce(10)
+ .repartition(20)
+ val query2 = testRelation
+ .coalesce(30)
.repartition(20)
- val optimized = Optimize.execute(query.analyze)
+ val optimized1 = Optimize.execute(query1.analyze)
+ val optimized2 = Optimize.execute(query2.analyze)
val correctAnswer = testRelation.repartition(20).analyze
- comparePlans(optimized, correctAnswer)
+ comparePlans(optimized1, correctAnswer)
+ comparePlans(optimized2, correctAnswer)
}
- test("collapse repartition and repartitionBy into one") {
- val query = testRelation
+ test("repartitionBy above repartition") {
+ // Always respects the top repartitionBy and removes the useless repartition below it
+ val query1 = testRelation
.repartition(10)
.distribute('a)(20)
+ val query2 = testRelation
+ .repartition(30)
+ .distribute('a)(20)
- val optimized = Optimize.execute(query.analyze)
+ val optimized1 = Optimize.execute(query1.analyze)
+ val optimized2 = Optimize.execute(query2.analyze)
val correctAnswer = testRelation.distribute('a)(20).analyze
- comparePlans(optimized, correctAnswer)
+ comparePlans(optimized1, correctAnswer)
+ comparePlans(optimized2, correctAnswer)
}
- test("collapse repartitionBy and repartition into one") {
- val query = testRelation
+ test("repartitionBy above coalesce") {
+ // Always respects the top repartitionBy and removes the useless coalesce below it
+ val query1 = testRelation
+ .coalesce(10)
+ .distribute('a)(20)
+ val query2 = testRelation
+ .coalesce(30)
.distribute('a)(20)
- .repartition(10)
- val optimized = Optimize.execute(query.analyze)
- val correctAnswer = testRelation.distribute('a)(10).analyze
+ val optimized1 = Optimize.execute(query1.analyze)
+ val optimized2 = Optimize.execute(query2.analyze)
+ val correctAnswer = testRelation.distribute('a)(20).analyze
- comparePlans(optimized, correctAnswer)
+ comparePlans(optimized1, correctAnswer)
+ comparePlans(optimized2, correctAnswer)
+ }
+
+ test("repartition above repartitionBy") {
+ // Always respects the top repartition and removes the useless distribute below it
+ val query1 = testRelation
+ .distribute('a)(10)
+ .repartition(20)
+ val query2 = testRelation
+ .distribute('a)(30)
+ .repartition(20)
+
+ val optimized1 = Optimize.execute(query1.analyze)
+ val optimized2 = Optimize.execute(query2.analyze)
+ val correctAnswer = testRelation.repartition(20).analyze
+
+ comparePlans(optimized1, correctAnswer)
+ comparePlans(optimized2, correctAnswer)
+
+ }
+
+ test("coalesce above repartitionBy") {
+ // Remove useless coalesce above repartitionBy
+ val query1 = testRelation
+ .distribute('a)(10)
+ .coalesce(20)
+
+ val optimized1 = Optimize.execute(query1.analyze)
+ val correctAnswer1 = testRelation.distribute('a)(10).analyze
+
+ comparePlans(optimized1, correctAnswer1)
+
+ // No change in this case
+ val query2 = testRelation
+ .distribute('a)(30)
+ .coalesce(20)
+
+ val optimized2 = Optimize.execute(query2.analyze)
+ val correctAnswer2 = query2.analyze
+
+ comparePlans(optimized2, correctAnswer2)
}
test("collapse two adjacent repartitionBys into one") {
- val query = testRelation
+ // Always respects the top repartitionBy
+ val query1 = testRelation
.distribute('b)(10)
.distribute('a)(20)
+ val query2 = testRelation
+ .distribute('b)(30)
+ .distribute('a)(20)
- val optimized = Optimize.execute(query.analyze)
+ val optimized1 = Optimize.execute(query1.analyze)
+ val optimized2 = Optimize.execute(query2.analyze)
val correctAnswer = testRelation.distribute('a)(20).analyze
- comparePlans(optimized, correctAnswer)
+ comparePlans(optimized1, correctAnswer)
+ comparePlans(optimized2, correctAnswer)
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
index 5bd1bc80c3b8..589607e3ad5c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
@@ -320,16 +320,16 @@ class ColumnPruningSuite extends PlanTest {
val query =
Project(Seq($"x.key", $"y.key"),
Join(
- SubqueryAlias("x", input, None),
- BroadcastHint(SubqueryAlias("y", input, None)), Inner, None)).analyze
+ SubqueryAlias("x", input),
+ BroadcastHint(SubqueryAlias("y", input)), Inner, None)).analyze
val optimized = Optimize.execute(query)
val expected =
Join(
- Project(Seq($"x.key"), SubqueryAlias("x", input, None)),
+ Project(Seq($"x.key"), SubqueryAlias("x", input)),
BroadcastHint(
- Project(Seq($"y.key"), SubqueryAlias("y", input, None))),
+ Project(Seq($"y.key"), SubqueryAlias("y", input))),
Inner, None).analyze
comparePlans(optimized, expected)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSubqueryAliasesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSubqueryAliasesSuite.scala
index a8aeedbd6275..9b6d68aee803 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSubqueryAliasesSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSubqueryAliasesSuite.scala
@@ -46,13 +46,13 @@ class EliminateSubqueryAliasesSuite extends PlanTest with PredicateHelper {
test("eliminate top level subquery") {
val input = LocalRelation('a.int, 'b.int)
- val query = SubqueryAlias("a", input, None)
+ val query = SubqueryAlias("a", input)
comparePlans(afterOptimization(query), input)
}
test("eliminate mid-tree subquery") {
val input = LocalRelation('a.int, 'b.int)
- val query = Filter(TrueLiteral, SubqueryAlias("a", input, None))
+ val query = Filter(TrueLiteral, SubqueryAlias("a", input))
comparePlans(
afterOptimization(query),
Filter(TrueLiteral, LocalRelation('a.int, 'b.int)))
@@ -61,7 +61,7 @@ class EliminateSubqueryAliasesSuite extends PlanTest with PredicateHelper {
test("eliminate multiple subqueries") {
val input = LocalRelation('a.int, 'b.int)
val query = Filter(TrueLiteral,
- SubqueryAlias("c", SubqueryAlias("b", SubqueryAlias("a", input, None), None), None))
+ SubqueryAlias("c", SubqueryAlias("b", SubqueryAlias("a", input))))
comparePlans(
afterOptimization(query),
Filter(TrueLiteral, LocalRelation('a.int, 'b.int)))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala
index 65dd6225cea0..985e49069da9 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala
@@ -129,15 +129,15 @@ class JoinOptimizationSuite extends PlanTest {
val query =
Project(Seq($"x.key", $"y.key"),
Join(
- SubqueryAlias("x", input, None),
- BroadcastHint(SubqueryAlias("y", input, None)), Cross, None)).analyze
+ SubqueryAlias("x", input),
+ BroadcastHint(SubqueryAlias("y", input)), Cross, None)).analyze
val optimized = Optimize.execute(query)
val expected =
Join(
- Project(Seq($"x.key"), SubqueryAlias("x", input, None)),
- BroadcastHint(Project(Seq($"y.key"), SubqueryAlias("y", input, None))),
+ Project(Seq($"x.key"), SubqueryAlias("x", input)),
+ BroadcastHint(Project(Seq($"y.key"), SubqueryAlias("y", input))),
Cross, None).analyze
comparePlans(optimized, expected)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala
new file mode 100644
index 000000000000..1b2f7a66b6a0
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala
@@ -0,0 +1,194 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.SimpleCatalystConf
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap}
+import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest}
+import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Join, LogicalPlan}
+import org.apache.spark.sql.catalyst.rules.RuleExecutor
+import org.apache.spark.sql.catalyst.statsEstimation.{StatsEstimationTestBase, StatsTestPlan}
+import org.apache.spark.sql.catalyst.util._
+
+
+class JoinReorderSuite extends PlanTest with StatsEstimationTestBase {
+
+ override val conf = SimpleCatalystConf(
+ caseSensitiveAnalysis = true, cboEnabled = true, joinReorderEnabled = true)
+
+ object Optimize extends RuleExecutor[LogicalPlan] {
+ val batches =
+ Batch("Operator Optimizations", FixedPoint(100),
+ CombineFilters,
+ PushDownPredicate,
+ PushPredicateThroughJoin,
+ ColumnPruning,
+ CollapseProject) ::
+ Batch("Join Reorder", Once,
+ CostBasedJoinReorder(conf)) :: Nil
+ }
+
+ /** Set up tables and columns for testing */
+ private val columnInfo: AttributeMap[ColumnStat] = AttributeMap(Seq(
+ attr("t1.k-1-2") -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(2),
+ nullCount = 0, avgLen = 4, maxLen = 4),
+ attr("t1.v-1-10") -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
+ nullCount = 0, avgLen = 4, maxLen = 4),
+ attr("t2.k-1-5") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5),
+ nullCount = 0, avgLen = 4, maxLen = 4),
+ attr("t3.v-1-100") -> ColumnStat(distinctCount = 100, min = Some(1), max = Some(100),
+ nullCount = 0, avgLen = 4, maxLen = 4),
+ attr("t4.k-1-2") -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(2),
+ nullCount = 0, avgLen = 4, maxLen = 4),
+ attr("t4.v-1-10") -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
+ nullCount = 0, avgLen = 4, maxLen = 4)
+ ))
+
+ private val nameToAttr: Map[String, Attribute] = columnInfo.map(kv => kv._1.name -> kv._1)
+ private val nameToColInfo: Map[String, (Attribute, ColumnStat)] =
+ columnInfo.map(kv => kv._1.name -> kv)
+
+ // Table t1/t4: big table with two columns
+ private val t1 = StatsTestPlan(
+ outputList = Seq("t1.k-1-2", "t1.v-1-10").map(nameToAttr),
+ rowCount = 1000,
+ // size = rows * (overhead + column length)
+ size = Some(1000 * (8 + 4 + 4)),
+ attributeStats = AttributeMap(Seq("t1.k-1-2", "t1.v-1-10").map(nameToColInfo)))
+
+ private val t4 = StatsTestPlan(
+ outputList = Seq("t4.k-1-2", "t4.v-1-10").map(nameToAttr),
+ rowCount = 2000,
+ size = Some(2000 * (8 + 4 + 4)),
+ attributeStats = AttributeMap(Seq("t4.k-1-2", "t4.v-1-10").map(nameToColInfo)))
+
+ // Table t2/t3: small table with only one column
+ private val t2 = StatsTestPlan(
+ outputList = Seq("t2.k-1-5").map(nameToAttr),
+ rowCount = 20,
+ size = Some(20 * (8 + 4)),
+ attributeStats = AttributeMap(Seq("t2.k-1-5").map(nameToColInfo)))
+
+ private val t3 = StatsTestPlan(
+ outputList = Seq("t3.v-1-100").map(nameToAttr),
+ rowCount = 100,
+ size = Some(100 * (8 + 4)),
+ attributeStats = AttributeMap(Seq("t3.v-1-100").map(nameToColInfo)))
+
+ test("reorder 3 tables") {
+ val originalPlan =
+ t1.join(t2).join(t3).where((nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")) &&
+ (nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100")))
+
+ // The cost of the original plan (using only cardinality to simplify the explanation):
+ // cost = cost(t1 J t2) = 1000 * 20 / 5 = 4000
+ // In contrast, the cost of the best plan:
+ // cost = cost(t1 J t3) = 1000 * 100 / 100 = 1000 < 4000
+ // so (t1 J t3) J t2 is better (has lower cost, i.e. intermediate result size) than
+ // the original order (t1 J t2) J t3.
+ val bestPlan =
+ t1.join(t3, Inner, Some(nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100")))
+ .join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")))
+
+ assertEqualPlans(originalPlan, bestPlan)
+ }
+
+ test("reorder 3 tables - put cross join at the end") {
+ val originalPlan =
+ t1.join(t2).join(t3).where(nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100"))
+
+ val bestPlan =
+ t1.join(t3, Inner, Some(nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100")))
+ .join(t2, Inner, None)
+
+ assertEqualPlans(originalPlan, bestPlan)
+ }
+
+ test("reorder 3 tables with pure-attribute project") {
+ val originalPlan =
+ t1.join(t2).join(t3).where((nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")) &&
+ (nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100")))
+ .select(nameToAttr("t1.v-1-10"))
+
+ val bestPlan =
+ t1.join(t3, Inner, Some(nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100")))
+ .select(nameToAttr("t1.k-1-2"), nameToAttr("t1.v-1-10"))
+ .join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")))
+ .select(nameToAttr("t1.v-1-10"))
+
+ assertEqualPlans(originalPlan, bestPlan)
+ }
+
+ test("don't reorder if project contains non-attribute") {
+ val originalPlan =
+ t1.join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")))
+ .select((nameToAttr("t1.k-1-2") + nameToAttr("t2.k-1-5")) as "key", nameToAttr("t1.v-1-10"))
+ .join(t3, Inner, Some(nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100")))
+ .select("key".attr)
+
+ assertEqualPlans(originalPlan, originalPlan)
+ }
+
+ test("reorder 4 tables (bushy tree)") {
+ val originalPlan =
+ t1.join(t4).join(t2).join(t3).where((nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2")) &&
+ (nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")) &&
+ (nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100")))
+
+ // The cost of the original plan (using only cardinality to simplify the explanation):
+ // cost(t1 J t4) = 1000 * 2000 / 2 = 1000000, cost(t1t4 J t2) = 1000000 * 20 / 5 = 4000000,
+ // cost = cost(t1 J t4) + cost(t1t4 J t2) = 5000000
+ // In contrast, the cost of the best plan (a bushy tree):
+ // cost(t1 J t2) = 1000 * 20 / 5 = 4000, cost(t4 J t3) = 2000 * 100 / 100 = 2000,
+ // cost = cost(t1 J t2) + cost(t4 J t3) = 6000 << 5000000.
+ val bestPlan =
+ t1.join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")))
+ .join(t4.join(t3, Inner, Some(nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100"))),
+ Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2")))
+
+ assertEqualPlans(originalPlan, bestPlan)
+ }
+
+ private def assertEqualPlans(
+ originalPlan: LogicalPlan,
+ groundTruthBestPlan: LogicalPlan): Unit = {
+ val optimized = Optimize.execute(originalPlan.analyze)
+ val normalized1 = normalizePlan(normalizeExprIds(optimized))
+ val normalized2 = normalizePlan(normalizeExprIds(groundTruthBestPlan.analyze))
+ if (!sameJoinPlan(normalized1, normalized2)) {
+ fail(
+ s"""
+ |== FAIL: Plans do not match ===
+ |${sideBySide(normalized1.treeString, normalized2.treeString).mkString("\n")}
+ """.stripMargin)
+ }
+ }
+
+ /** Consider symmetry for joins when comparing plans. */
+ private def sameJoinPlan(plan1: LogicalPlan, plan2: LogicalPlan): Boolean = {
+ (plan1, plan2) match {
+ case (j1: Join, j2: Join) =>
+ (sameJoinPlan(j1.left, j2.left) && sameJoinPlan(j1.right, j2.right)) ||
+ (sameJoinPlan(j1.left, j2.right) && sameJoinPlan(j1.right, j2.left))
+ case _ =>
+ plan1 == plan2
+ }
+ }
+}
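
The cost figures quoted in the comments above follow a simple cardinality estimate. The helper below is a hedged sketch of that arithmetic only (it is not the actual `CostBasedJoinReorder` implementation), reproducing the numbers from the "reorder 3 tables" test.

```scala
// Estimated output rows of an equi-join: |L| * |R| / distinct count of the join key,
// as used in the cost comments above.
def estimatedJoinRows(leftRows: BigInt, rightRows: BigInt, keyDistinctCount: BigInt): BigInt =
  leftRows * rightRows / keyDistinctCount

// Original order: (t1 J t2) J t3
val costT1T2 = estimatedJoinRows(1000, 20, 5)     // 4000
// Reordered:     (t1 J t3) J t2
val costT1T3 = estimatedJoinRows(1000, 100, 100)  // 1000, so the reordered plan is cheaper
```
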
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
index 67d5d2202b68..411777d6e85a 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
@@ -79,7 +79,7 @@ class PlanParserSuite extends PlanTest {
def cte(plan: LogicalPlan, namedPlans: (String, LogicalPlan)*): With = {
val ctes = namedPlans.map {
case (name, cte) =>
- name -> SubqueryAlias(name, cte, None)
+ name -> SubqueryAlias(name, cte)
}
With(plan, ctes)
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
index 3b7e5e938a8e..e9b7a0c6ad67 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
@@ -62,7 +62,7 @@ abstract class PlanTest extends SparkFunSuite with PredicateHelper {
* - Sample the seed will replaced by 0L.
* - Join conditions will be resorted by hashCode.
*/
- private def normalizePlan(plan: LogicalPlan): LogicalPlan = {
+ protected def normalizePlan(plan: LogicalPlan): LogicalPlan = {
plan transform {
case filter @ Filter(condition: Expression, child: LogicalPlan) =>
Filter(splitConjunctivePredicates(condition).map(rewriteEqual(_)).sortBy(_.hashCode())
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala
new file mode 100644
index 000000000000..e5dc811c8b7d
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala
@@ -0,0 +1,122 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.statsEstimation
+
+import org.apache.spark.sql.catalyst.CatalystConf
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, Literal}
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.types.IntegerType
+
+
+class BasicStatsEstimationSuite extends StatsEstimationTestBase {
+ val attribute = attr("key")
+ val colStat = ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
+ nullCount = 0, avgLen = 4, maxLen = 4)
+
+ val plan = StatsTestPlan(
+ outputList = Seq(attribute),
+ attributeStats = AttributeMap(Seq(attribute -> colStat)),
+ rowCount = 10,
+ // row count * (overhead + column size)
+ size = Some(10 * (8 + 4)))
+
+ test("limit estimation: limit < child's rowCount") {
+ val localLimit = LocalLimit(Literal(2), plan)
+ val globalLimit = GlobalLimit(Literal(2), plan)
+ // LocalLimit's stats are just its child's stats, except for the column stats.
+ checkStats(localLimit, plan.stats(conf).copy(attributeStats = AttributeMap(Nil)))
+ checkStats(globalLimit, Statistics(sizeInBytes = 24, rowCount = Some(2)))
+ }
+
+ test("limit estimation: limit > child's rowCount") {
+ val localLimit = LocalLimit(Literal(20), plan)
+ val globalLimit = GlobalLimit(Literal(20), plan)
+ checkStats(localLimit, plan.stats(conf).copy(attributeStats = AttributeMap(Nil)))
+ // The limit exceeds the child's rowCount, so GlobalLimit's stats equal its child's stats.
+ checkStats(globalLimit, plan.stats(conf).copy(attributeStats = AttributeMap(Nil)))
+ }
+
+ test("limit estimation: limit = 0") {
+ val localLimit = LocalLimit(Literal(0), plan)
+ val globalLimit = GlobalLimit(Literal(0), plan)
+ val stats = Statistics(sizeInBytes = 1, rowCount = Some(0))
+ checkStats(localLimit, stats)
+ checkStats(globalLimit, stats)
+ }
+
+ test("sample estimation") {
+ val sample = Sample(0.0, 0.5, withReplacement = false, (math.random * 1000).toLong, plan)()
+ checkStats(sample, Statistics(sizeInBytes = 60, rowCount = Some(5)))
+
+ // Child doesn't have rowCount in stats
+ val childStats = Statistics(sizeInBytes = 120)
+ val childPlan = DummyLogicalPlan(childStats, childStats)
+ val sample2 =
+ Sample(0.0, 0.11, withReplacement = false, (math.random * 1000).toLong, childPlan)()
+ checkStats(sample2, Statistics(sizeInBytes = 14))
+ }
+
+ test("estimate statistics when the conf changes") {
+ val expectedDefaultStats =
+ Statistics(
+ sizeInBytes = 40,
+ rowCount = Some(10),
+ attributeStats = AttributeMap(Seq(
+ AttributeReference("c1", IntegerType)() -> ColumnStat(10, Some(1), Some(10), 0, 4, 4))),
+ isBroadcastable = false)
+ val expectedCboStats =
+ Statistics(
+ sizeInBytes = 4,
+ rowCount = Some(1),
+ attributeStats = AttributeMap(Seq(
+ AttributeReference("c1", IntegerType)() -> ColumnStat(1, Some(5), Some(5), 0, 4, 4))),
+ isBroadcastable = false)
+
+ val plan = DummyLogicalPlan(defaultStats = expectedDefaultStats, cboStats = expectedCboStats)
+ checkStats(
+ plan, expectedStatsCboOn = expectedCboStats, expectedStatsCboOff = expectedDefaultStats)
+ }
+
+ /** Check estimated stats when cbo is turned on/off. */
+ private def checkStats(
+ plan: LogicalPlan,
+ expectedStatsCboOn: Statistics,
+ expectedStatsCboOff: Statistics): Unit = {
+ assert(plan.stats(conf.copy(cboEnabled = true)) == expectedStatsCboOn)
+ // Invalidate statistics
+ plan.invalidateStatsCache()
+ assert(plan.stats(conf.copy(cboEnabled = false)) == expectedStatsCboOff)
+ }
+
+ /** Check estimated stats when it's the same whether cbo is turned on or off. */
+ private def checkStats(plan: LogicalPlan, expectedStats: Statistics): Unit =
+ checkStats(plan, expectedStats, expectedStats)
+}
+
+/**
+ * This class is used for unit-testing the CBO switch. It mimics a logical plan that computes
+ * either simple statistics or CBO-estimated statistics based on the conf.
+ */
+private case class DummyLogicalPlan(
+ defaultStats: Statistics,
+ cboStats: Statistics) extends LogicalPlan {
+ override def output: Seq[Attribute] = Nil
+ override def children: Seq[LogicalPlan] = Nil
+ override def computeStats(conf: CatalystConf): Statistics =
+ if (conf.cboEnabled) cboStats else defaultStats
+}
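
The expected sizes asserted above follow the row-size convention noted in the test fixture (`row count * (overhead + column size)`). A hedged arithmetic sketch, using the 8-byte row overhead and 4-byte int column assumed by `StatsTestPlan`:

```scala
val rowOverhead = 8      // per-row overhead assumed by the test plan
val intColumnSize = 4    // avgLen of the single int column

// GlobalLimit(2): 2 rows survive, so sizeInBytes = 2 * (8 + 4) = 24, as asserted.
val globalLimitSize = 2 * (rowOverhead + intColumnSize)

// Sample with a 50% fraction over 10 rows of 12 bytes each: 5 rows, 60 bytes, as asserted.
val sampleSize = 5 * (rowOverhead + intColumnSize)
```
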
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala
index 8be74ced7bb7..4691913c8c98 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.statsEstimation
import java.sql.Date
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Filter, Statistics}
import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._
import org.apache.spark.sql.types._
@@ -33,219 +33,235 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
// Suppose our test table has 10 rows and 6 columns.
// First column cint has values: 1, 2, 3, 4, 5, 6, 7, 8, 9, 10
// Hence, distinctCount:10, min:1, max:10, nullCount:0, avgLen:4, maxLen:4
- val arInt = AttributeReference("cint", IntegerType)()
- val childColStatInt = ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
+ val attrInt = AttributeReference("cint", IntegerType)()
+ val colStatInt = ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
nullCount = 0, avgLen = 4, maxLen = 4)
// only 2 values
- val arBool = AttributeReference("cbool", BooleanType)()
- val childColStatBool = ColumnStat(distinctCount = 2, min = Some(false), max = Some(true),
+ val attrBool = AttributeReference("cbool", BooleanType)()
+ val colStatBool = ColumnStat(distinctCount = 2, min = Some(false), max = Some(true),
nullCount = 0, avgLen = 1, maxLen = 1)
// Second column cdate has 10 values from 2017-01-01 through 2017-01-10.
val dMin = Date.valueOf("2017-01-01")
val dMax = Date.valueOf("2017-01-10")
- val arDate = AttributeReference("cdate", DateType)()
- val childColStatDate = ColumnStat(distinctCount = 10, min = Some(dMin), max = Some(dMax),
+ val attrDate = AttributeReference("cdate", DateType)()
+ val colStatDate = ColumnStat(distinctCount = 10, min = Some(dMin), max = Some(dMax),
nullCount = 0, avgLen = 4, maxLen = 4)
// Fourth column cdecimal has 4 values from 0.20 through 0.80 at increment of 0.20.
val decMin = new java.math.BigDecimal("0.200000000000000000")
val decMax = new java.math.BigDecimal("0.800000000000000000")
- val arDecimal = AttributeReference("cdecimal", DecimalType(18, 18))()
- val childColStatDecimal = ColumnStat(distinctCount = 4, min = Some(decMin), max = Some(decMax),
+ val attrDecimal = AttributeReference("cdecimal", DecimalType(18, 18))()
+ val colStatDecimal = ColumnStat(distinctCount = 4, min = Some(decMin), max = Some(decMax),
nullCount = 0, avgLen = 8, maxLen = 8)
// Fifth column cdouble has 10 double values: 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0
- val arDouble = AttributeReference("cdouble", DoubleType)()
- val childColStatDouble = ColumnStat(distinctCount = 10, min = Some(1.0), max = Some(10.0),
+ val attrDouble = AttributeReference("cdouble", DoubleType)()
+ val colStatDouble = ColumnStat(distinctCount = 10, min = Some(1.0), max = Some(10.0),
nullCount = 0, avgLen = 8, maxLen = 8)
// Sixth column cstring has 10 String values:
// "A0", "A1", "A2", "A3", "A4", "A5", "A6", "A7", "A8", "A9"
- val arString = AttributeReference("cstring", StringType)()
- val childColStatString = ColumnStat(distinctCount = 10, min = None, max = None,
+ val attrString = AttributeReference("cstring", StringType)()
+ val colStatString = ColumnStat(distinctCount = 10, min = None, max = None,
nullCount = 0, avgLen = 2, maxLen = 2)
+ val attributeMap = AttributeMap(Seq(
+ attrInt -> colStatInt,
+ attrBool -> colStatBool,
+ attrDate -> colStatDate,
+ attrDecimal -> colStatDecimal,
+ attrDouble -> colStatDouble,
+ attrString -> colStatString))
+
test("cint = 2") {
validateEstimatedStats(
- arInt,
- Filter(EqualTo(arInt, Literal(2)), childStatsTestPlan(Seq(arInt), 10L)),
- ColumnStat(distinctCount = 1, min = Some(2), max = Some(2),
- nullCount = 0, avgLen = 4, maxLen = 4),
- 1)
+ Filter(EqualTo(attrInt, Literal(2)), childStatsTestPlan(Seq(attrInt), 10L)),
+ Seq(attrInt -> ColumnStat(distinctCount = 1, min = Some(2), max = Some(2),
+ nullCount = 0, avgLen = 4, maxLen = 4)),
+ expectedRowCount = 1)
}
test("cint <=> 2") {
validateEstimatedStats(
- arInt,
- Filter(EqualNullSafe(arInt, Literal(2)), childStatsTestPlan(Seq(arInt), 10L)),
- ColumnStat(distinctCount = 1, min = Some(2), max = Some(2),
- nullCount = 0, avgLen = 4, maxLen = 4),
- 1)
+ Filter(EqualNullSafe(attrInt, Literal(2)), childStatsTestPlan(Seq(attrInt), 10L)),
+ Seq(attrInt -> ColumnStat(distinctCount = 1, min = Some(2), max = Some(2),
+ nullCount = 0, avgLen = 4, maxLen = 4)),
+ expectedRowCount = 1)
}
test("cint = 0") {
// This is an out-of-range case since 0 is outside the range [min, max]
validateEstimatedStats(
- arInt,
- Filter(EqualTo(arInt, Literal(0)), childStatsTestPlan(Seq(arInt), 10L)),
- ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
- nullCount = 0, avgLen = 4, maxLen = 4),
- 0)
+ Filter(EqualTo(attrInt, Literal(0)), childStatsTestPlan(Seq(attrInt), 10L)),
+ Nil,
+ expectedRowCount = 0)
}
test("cint < 3") {
validateEstimatedStats(
- arInt,
- Filter(LessThan(arInt, Literal(3)), childStatsTestPlan(Seq(arInt), 10L)),
- ColumnStat(distinctCount = 2, min = Some(1), max = Some(3),
- nullCount = 0, avgLen = 4, maxLen = 4),
- 3)
+ Filter(LessThan(attrInt, Literal(3)), childStatsTestPlan(Seq(attrInt), 10L)),
+ Seq(attrInt -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(3),
+ nullCount = 0, avgLen = 4, maxLen = 4)),
+ expectedRowCount = 3)
}
test("cint < 0") {
// This is a corner case since literal 0 is smaller than min.
validateEstimatedStats(
- arInt,
- Filter(LessThan(arInt, Literal(0)), childStatsTestPlan(Seq(arInt), 10L)),
- ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
- nullCount = 0, avgLen = 4, maxLen = 4),
- 0)
+ Filter(LessThan(attrInt, Literal(0)), childStatsTestPlan(Seq(attrInt), 10L)),
+ Nil,
+ expectedRowCount = 0)
}
test("cint <= 3") {
validateEstimatedStats(
- arInt,
- Filter(LessThanOrEqual(arInt, Literal(3)), childStatsTestPlan(Seq(arInt), 10L)),
- ColumnStat(distinctCount = 2, min = Some(1), max = Some(3),
- nullCount = 0, avgLen = 4, maxLen = 4),
- 3)
+ Filter(LessThanOrEqual(attrInt, Literal(3)), childStatsTestPlan(Seq(attrInt), 10L)),
+ Seq(attrInt -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(3),
+ nullCount = 0, avgLen = 4, maxLen = 4)),
+ expectedRowCount = 3)
}
test("cint > 6") {
validateEstimatedStats(
- arInt,
- Filter(GreaterThan(arInt, Literal(6)), childStatsTestPlan(Seq(arInt), 10L)),
- ColumnStat(distinctCount = 4, min = Some(6), max = Some(10),
- nullCount = 0, avgLen = 4, maxLen = 4),
- 5)
+ Filter(GreaterThan(attrInt, Literal(6)), childStatsTestPlan(Seq(attrInt), 10L)),
+ Seq(attrInt -> ColumnStat(distinctCount = 4, min = Some(6), max = Some(10),
+ nullCount = 0, avgLen = 4, maxLen = 4)),
+ expectedRowCount = 5)
}
test("cint > 10") {
// This is a corner case since max value is 10.
validateEstimatedStats(
- arInt,
- Filter(GreaterThan(arInt, Literal(10)), childStatsTestPlan(Seq(arInt), 10L)),
- ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
- nullCount = 0, avgLen = 4, maxLen = 4),
- 0)
+ Filter(GreaterThan(attrInt, Literal(10)), childStatsTestPlan(Seq(attrInt), 10L)),
+ Nil,
+ expectedRowCount = 0)
}
test("cint >= 6") {
validateEstimatedStats(
- arInt,
- Filter(GreaterThanOrEqual(arInt, Literal(6)), childStatsTestPlan(Seq(arInt), 10L)),
- ColumnStat(distinctCount = 4, min = Some(6), max = Some(10),
- nullCount = 0, avgLen = 4, maxLen = 4),
- 5)
+ Filter(GreaterThanOrEqual(attrInt, Literal(6)), childStatsTestPlan(Seq(attrInt), 10L)),
+ Seq(attrInt -> ColumnStat(distinctCount = 4, min = Some(6), max = Some(10),
+ nullCount = 0, avgLen = 4, maxLen = 4)),
+ expectedRowCount = 5)
}
test("cint IS NULL") {
validateEstimatedStats(
- arInt,
- Filter(IsNull(arInt), childStatsTestPlan(Seq(arInt), 10L)),
- ColumnStat(distinctCount = 0, min = None, max = None,
- nullCount = 0, avgLen = 4, maxLen = 4),
- 0)
+ Filter(IsNull(attrInt), childStatsTestPlan(Seq(attrInt), 10L)),
+ Nil,
+ expectedRowCount = 0)
}
test("cint IS NOT NULL") {
validateEstimatedStats(
- arInt,
- Filter(IsNotNull(arInt), childStatsTestPlan(Seq(arInt), 10L)),
- ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
- nullCount = 0, avgLen = 4, maxLen = 4),
- 10)
+ Filter(IsNotNull(attrInt), childStatsTestPlan(Seq(attrInt), 10L)),
+ Seq(attrInt -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
+ nullCount = 0, avgLen = 4, maxLen = 4)),
+ expectedRowCount = 10)
}
test("cint > 3 AND cint <= 6") {
- val condition = And(GreaterThan(arInt, Literal(3)), LessThanOrEqual(arInt, Literal(6)))
+ val condition = And(GreaterThan(attrInt, Literal(3)), LessThanOrEqual(attrInt, Literal(6)))
validateEstimatedStats(
- arInt,
- Filter(condition, childStatsTestPlan(Seq(arInt), 10L)),
- ColumnStat(distinctCount = 3, min = Some(3), max = Some(6),
- nullCount = 0, avgLen = 4, maxLen = 4),
- 4)
+ Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)),
+ Seq(attrInt -> ColumnStat(distinctCount = 3, min = Some(3), max = Some(6),
+ nullCount = 0, avgLen = 4, maxLen = 4)),
+ expectedRowCount = 4)
}
test("cint = 3 OR cint = 6") {
- val condition = Or(EqualTo(arInt, Literal(3)), EqualTo(arInt, Literal(6)))
+ val condition = Or(EqualTo(attrInt, Literal(3)), EqualTo(attrInt, Literal(6)))
+ validateEstimatedStats(
+ Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)),
+ Seq(attrInt -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
+ nullCount = 0, avgLen = 4, maxLen = 4)),
+ expectedRowCount = 2)
+ }
+
+ test("Not(cint > 3 AND cint <= 6)") {
+ val condition = Not(And(GreaterThan(attrInt, Literal(3)), LessThanOrEqual(attrInt, Literal(6))))
+ validateEstimatedStats(
+ Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)),
+ Seq(attrInt -> colStatInt),
+ expectedRowCount = 6)
+ }
+
+ test("Not(cint <= 3 OR cint > 6)") {
+ val condition = Not(Or(LessThanOrEqual(attrInt, Literal(3)), GreaterThan(attrInt, Literal(6))))
+ validateEstimatedStats(
+ Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)),
+ Seq(attrInt -> colStatInt),
+ expectedRowCount = 5)
+ }
+
+ test("Not(cint = 3 AND cstring < 'A8')") {
+ val condition = Not(And(EqualTo(attrInt, Literal(3)), LessThan(attrString, Literal("A8"))))
+ validateEstimatedStats(
+ Filter(condition, childStatsTestPlan(Seq(attrInt, attrString), 10L)),
+ Seq(attrInt -> colStatInt, attrString -> colStatString),
+ expectedRowCount = 10)
+ }
+
+ test("Not(cint = 3 OR cstring < 'A8')") {
+ val condition = Not(Or(EqualTo(attrInt, Literal(3)), LessThan(attrString, Literal("A8"))))
validateEstimatedStats(
- arInt,
- Filter(condition, childStatsTestPlan(Seq(arInt), 10L)),
- ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
- nullCount = 0, avgLen = 4, maxLen = 4),
- 2)
+ Filter(condition, childStatsTestPlan(Seq(attrInt, attrString), 10L)),
+ Seq(attrInt -> colStatInt, attrString -> colStatString),
+ expectedRowCount = 9)
}
test("cint IN (3, 4, 5)") {
validateEstimatedStats(
- arInt,
- Filter(InSet(arInt, Set(3, 4, 5)), childStatsTestPlan(Seq(arInt), 10L)),
- ColumnStat(distinctCount = 3, min = Some(3), max = Some(5),
- nullCount = 0, avgLen = 4, maxLen = 4),
- 3)
+ Filter(InSet(attrInt, Set(3, 4, 5)), childStatsTestPlan(Seq(attrInt), 10L)),
+ Seq(attrInt -> ColumnStat(distinctCount = 3, min = Some(3), max = Some(5),
+ nullCount = 0, avgLen = 4, maxLen = 4)),
+ expectedRowCount = 3)
}
test("cint NOT IN (3, 4, 5)") {
validateEstimatedStats(
- arInt,
- Filter(Not(InSet(arInt, Set(3, 4, 5))), childStatsTestPlan(Seq(arInt), 10L)),
- ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
- nullCount = 0, avgLen = 4, maxLen = 4),
- 7)
+ Filter(Not(InSet(attrInt, Set(3, 4, 5))), childStatsTestPlan(Seq(attrInt), 10L)),
+ Seq(attrInt -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
+ nullCount = 0, avgLen = 4, maxLen = 4)),
+ expectedRowCount = 7)
}
test("cbool = true") {
validateEstimatedStats(
- arBool,
- Filter(EqualTo(arBool, Literal(true)), childStatsTestPlan(Seq(arBool), 10L)),
- ColumnStat(distinctCount = 1, min = Some(true), max = Some(true),
- nullCount = 0, avgLen = 1, maxLen = 1),
- 5)
+ Filter(EqualTo(attrBool, Literal(true)), childStatsTestPlan(Seq(attrBool), 10L)),
+ Seq(attrBool -> ColumnStat(distinctCount = 1, min = Some(true), max = Some(true),
+ nullCount = 0, avgLen = 1, maxLen = 1)),
+ expectedRowCount = 5)
}
test("cbool > false") {
- // bool comparison is not supported yet, so stats remain same.
validateEstimatedStats(
- arBool,
- Filter(GreaterThan(arBool, Literal(false)), childStatsTestPlan(Seq(arBool), 10L)),
- ColumnStat(distinctCount = 2, min = Some(false), max = Some(true),
- nullCount = 0, avgLen = 1, maxLen = 1),
- 10)
+ Filter(GreaterThan(attrBool, Literal(false)), childStatsTestPlan(Seq(attrBool), 10L)),
+ Seq(attrBool -> ColumnStat(distinctCount = 1, min = Some(true), max = Some(true),
+ nullCount = 0, avgLen = 1, maxLen = 1)),
+ expectedRowCount = 5)
}
test("cdate = cast('2017-01-02' AS DATE)") {
val d20170102 = Date.valueOf("2017-01-02")
validateEstimatedStats(
- arDate,
- Filter(EqualTo(arDate, Literal(d20170102)),
- childStatsTestPlan(Seq(arDate), 10L)),
- ColumnStat(distinctCount = 1, min = Some(d20170102), max = Some(d20170102),
- nullCount = 0, avgLen = 4, maxLen = 4),
- 1)
+ Filter(EqualTo(attrDate, Literal(d20170102)),
+ childStatsTestPlan(Seq(attrDate), 10L)),
+ Seq(attrDate -> ColumnStat(distinctCount = 1, min = Some(d20170102), max = Some(d20170102),
+ nullCount = 0, avgLen = 4, maxLen = 4)),
+ expectedRowCount = 1)
}
test("cdate < cast('2017-01-03' AS DATE)") {
val d20170103 = Date.valueOf("2017-01-03")
validateEstimatedStats(
- arDate,
- Filter(LessThan(arDate, Literal(d20170103)),
- childStatsTestPlan(Seq(arDate), 10L)),
- ColumnStat(distinctCount = 2, min = Some(dMin), max = Some(d20170103),
- nullCount = 0, avgLen = 4, maxLen = 4),
- 3)
+ Filter(LessThan(attrDate, Literal(d20170103)),
+ childStatsTestPlan(Seq(attrDate), 10L)),
+ Seq(attrDate -> ColumnStat(distinctCount = 2, min = Some(dMin), max = Some(d20170103),
+ nullCount = 0, avgLen = 4, maxLen = 4)),
+ expectedRowCount = 3)
}
test("""cdate IN ( cast('2017-01-03' AS DATE),
@@ -254,133 +270,118 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
val d20170104 = Date.valueOf("2017-01-04")
val d20170105 = Date.valueOf("2017-01-05")
validateEstimatedStats(
- arDate,
- Filter(In(arDate, Seq(Literal(d20170103), Literal(d20170104), Literal(d20170105))),
- childStatsTestPlan(Seq(arDate), 10L)),
- ColumnStat(distinctCount = 3, min = Some(d20170103), max = Some(d20170105),
- nullCount = 0, avgLen = 4, maxLen = 4),
- 3)
+ Filter(In(attrDate, Seq(Literal(d20170103), Literal(d20170104), Literal(d20170105))),
+ childStatsTestPlan(Seq(attrDate), 10L)),
+ Seq(attrDate -> ColumnStat(distinctCount = 3, min = Some(d20170103), max = Some(d20170105),
+ nullCount = 0, avgLen = 4, maxLen = 4)),
+ expectedRowCount = 3)
}
test("cdecimal = 0.400000000000000000") {
val dec_0_40 = new java.math.BigDecimal("0.400000000000000000")
validateEstimatedStats(
- arDecimal,
- Filter(EqualTo(arDecimal, Literal(dec_0_40)),
- childStatsTestPlan(Seq(arDecimal), 4L)),
- ColumnStat(distinctCount = 1, min = Some(dec_0_40), max = Some(dec_0_40),
- nullCount = 0, avgLen = 8, maxLen = 8),
- 1)
+ Filter(EqualTo(attrDecimal, Literal(dec_0_40)),
+ childStatsTestPlan(Seq(attrDecimal), 4L)),
+ Seq(attrDecimal -> ColumnStat(distinctCount = 1, min = Some(dec_0_40), max = Some(dec_0_40),
+ nullCount = 0, avgLen = 8, maxLen = 8)),
+ expectedRowCount = 1)
}
test("cdecimal < 0.60 ") {
val dec_0_60 = new java.math.BigDecimal("0.600000000000000000")
validateEstimatedStats(
- arDecimal,
- Filter(LessThan(arDecimal, Literal(dec_0_60)),
- childStatsTestPlan(Seq(arDecimal), 4L)),
- ColumnStat(distinctCount = 3, min = Some(decMin), max = Some(dec_0_60),
- nullCount = 0, avgLen = 8, maxLen = 8),
- 3)
+ Filter(LessThan(attrDecimal, Literal(dec_0_60)),
+ childStatsTestPlan(Seq(attrDecimal), 4L)),
+ Seq(attrDecimal -> ColumnStat(distinctCount = 3, min = Some(decMin), max = Some(dec_0_60),
+ nullCount = 0, avgLen = 8, maxLen = 8)),
+ expectedRowCount = 3)
}
test("cdouble < 3.0") {
validateEstimatedStats(
- arDouble,
- Filter(LessThan(arDouble, Literal(3.0)), childStatsTestPlan(Seq(arDouble), 10L)),
- ColumnStat(distinctCount = 2, min = Some(1.0), max = Some(3.0),
- nullCount = 0, avgLen = 8, maxLen = 8),
- 3)
+ Filter(LessThan(attrDouble, Literal(3.0)), childStatsTestPlan(Seq(attrDouble), 10L)),
+ Seq(attrDouble -> ColumnStat(distinctCount = 2, min = Some(1.0), max = Some(3.0),
+ nullCount = 0, avgLen = 8, maxLen = 8)),
+ expectedRowCount = 3)
}
test("cstring = 'A2'") {
validateEstimatedStats(
- arString,
- Filter(EqualTo(arString, Literal("A2")), childStatsTestPlan(Seq(arString), 10L)),
- ColumnStat(distinctCount = 1, min = None, max = None,
- nullCount = 0, avgLen = 2, maxLen = 2),
- 1)
+ Filter(EqualTo(attrString, Literal("A2")), childStatsTestPlan(Seq(attrString), 10L)),
+ Seq(attrString -> ColumnStat(distinctCount = 1, min = None, max = None,
+ nullCount = 0, avgLen = 2, maxLen = 2)),
+ expectedRowCount = 1)
}
- // There is no min/max statistics for String type. We estimate 10 rows returned.
- test("cstring < 'A2'") {
+ test("cstring < 'A2' - unsupported condition") {
validateEstimatedStats(
- arString,
- Filter(LessThan(arString, Literal("A2")), childStatsTestPlan(Seq(arString), 10L)),
- ColumnStat(distinctCount = 10, min = None, max = None,
- nullCount = 0, avgLen = 2, maxLen = 2),
- 10)
+ Filter(LessThan(attrString, Literal("A2")), childStatsTestPlan(Seq(attrString), 10L)),
+ Seq(attrString -> ColumnStat(distinctCount = 10, min = None, max = None,
+ nullCount = 0, avgLen = 2, maxLen = 2)),
+ expectedRowCount = 10)
}
- // This is a corner test case. We want to test if we can handle the case when the number of
- // valid values in IN clause is greater than the number of distinct values for a given column.
- // For example, column has only 2 distinct values 1 and 6.
- // The predicate is: column IN (1, 2, 3, 4, 5).
test("cint IN (1, 2, 3, 4, 5)") {
+ // This is a corner test case. We want to test if we can handle the case when the number of
+ // valid values in IN clause is greater than the number of distinct values for a given column.
+ // For example, column has only 2 distinct values 1 and 6.
+ // The predicate is: column IN (1, 2, 3, 4, 5).
val cornerChildColStatInt = ColumnStat(distinctCount = 2, min = Some(1), max = Some(6),
nullCount = 0, avgLen = 4, maxLen = 4)
val cornerChildStatsTestplan = StatsTestPlan(
- outputList = Seq(arInt),
+ outputList = Seq(attrInt),
rowCount = 2L,
- attributeStats = AttributeMap(Seq(arInt -> cornerChildColStatInt))
+ attributeStats = AttributeMap(Seq(attrInt -> cornerChildColStatInt))
)
validateEstimatedStats(
- arInt,
- Filter(InSet(arInt, Set(1, 2, 3, 4, 5)), cornerChildStatsTestplan),
- ColumnStat(distinctCount = 2, min = Some(1), max = Some(5),
- nullCount = 0, avgLen = 4, maxLen = 4),
- 2)
+ Filter(InSet(attrInt, Set(1, 2, 3, 4, 5)), cornerChildStatsTestplan),
+ Seq(attrInt -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(5),
+ nullCount = 0, avgLen = 4, maxLen = 4)),
+ expectedRowCount = 2)
}
private def childStatsTestPlan(outList: Seq[Attribute], tableRowCount: BigInt): StatsTestPlan = {
StatsTestPlan(
outputList = outList,
rowCount = tableRowCount,
- attributeStats = AttributeMap(Seq(
- arInt -> childColStatInt,
- arBool -> childColStatBool,
- arDate -> childColStatDate,
- arDecimal -> childColStatDecimal,
- arDouble -> childColStatDouble,
- arString -> childColStatString
- ))
- )
+ attributeStats = AttributeMap(outList.map(a => a -> attributeMap(a))))
}
private def validateEstimatedStats(
- ar: AttributeReference,
filterNode: Filter,
- expectedColStats: ColumnStat,
- rowCount: Int): Unit = {
-
- val expectedAttrStats = toAttributeMap(Seq(ar.name -> expectedColStats), filterNode)
- val expectedSizeInBytes = getOutputSize(filterNode.output, rowCount, expectedAttrStats)
-
- val filteredStats = filterNode.stats(conf)
- assert(filteredStats.sizeInBytes == expectedSizeInBytes)
- assert(filteredStats.rowCount.get == rowCount)
- assert(filteredStats.attributeStats(ar) == expectedColStats)
-
- // If the filter has a binary operator (including those nested inside
- // AND/OR/NOT), swap the sides of the attribte and the literal, reverse the
- // operator, and then check again.
- val rewrittenFilter = filterNode transformExpressionsDown {
- case EqualTo(ar: AttributeReference, l: Literal) =>
- EqualTo(l, ar)
-
- case LessThan(ar: AttributeReference, l: Literal) =>
- GreaterThan(l, ar)
- case LessThanOrEqual(ar: AttributeReference, l: Literal) =>
- GreaterThanOrEqual(l, ar)
-
- case GreaterThan(ar: AttributeReference, l: Literal) =>
- LessThan(l, ar)
- case GreaterThanOrEqual(ar: AttributeReference, l: Literal) =>
- LessThanOrEqual(l, ar)
+ expectedColStats: Seq[(Attribute, ColumnStat)],
+ expectedRowCount: Int): Unit = {
+
+ // If the filter has a binary operator (including those nested inside AND/OR/NOT), swap the
+ // sides of the attribute and the literal, reverse the operator, and then check again.
+ val swappedFilter = filterNode transformExpressionsDown {
+ case EqualTo(attr: Attribute, l: Literal) =>
+ EqualTo(l, attr)
+
+ case LessThan(attr: Attribute, l: Literal) =>
+ GreaterThan(l, attr)
+ case LessThanOrEqual(attr: Attribute, l: Literal) =>
+ GreaterThanOrEqual(l, attr)
+
+ case GreaterThan(attr: Attribute, l: Literal) =>
+ LessThan(l, attr)
+ case GreaterThanOrEqual(attr: Attribute, l: Literal) =>
+ LessThanOrEqual(l, attr)
+ }
+
+ val testFilters = if (swappedFilter != filterNode) {
+ Seq(swappedFilter, filterNode)
+ } else {
+ Seq(filterNode)
}
- if (rewrittenFilter != filterNode) {
- validateEstimatedStats(ar, rewrittenFilter, expectedColStats, rowCount)
+ testFilters.foreach { filter =>
+ val expectedAttributeMap = AttributeMap(expectedColStats)
+ val expectedStats = Statistics(
+ sizeInBytes = getOutputSize(filter.output, expectedRowCount, expectedAttributeMap),
+ rowCount = Some(expectedRowCount),
+ attributeStats = expectedAttributeMap)
+ assert(filter.stats(conf) == expectedStats)
}
}
}
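
As a hedged illustration of what the operator-swapping step in `validateEstimatedStats` exercises (`attrInt` is defined in the suite above), the estimator is expected to produce identical statistics for a predicate and its mirrored form:

```scala
import org.apache.spark.sql.catalyst.expressions.{GreaterThan, LessThan, Literal}

// "cint < 3" and its mirrored form "3 > cint" describe the same predicate, so the Filter
// plans built from both should yield the same estimated Statistics.
val original = LessThan(attrInt, Literal(3))
val swapped  = GreaterThan(Literal(3), attrInt)
```
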
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsConfSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsConfSuite.scala
deleted file mode 100644
index 212d57a9bcf9..000000000000
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsConfSuite.scala
+++ /dev/null
@@ -1,64 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.statsEstimation
-
-import org.apache.spark.sql.catalyst.CatalystConf
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference}
-import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LogicalPlan, Statistics}
-import org.apache.spark.sql.types.IntegerType
-
-
-class StatsConfSuite extends StatsEstimationTestBase {
- test("estimate statistics when the conf changes") {
- val expectedDefaultStats =
- Statistics(
- sizeInBytes = 40,
- rowCount = Some(10),
- attributeStats = AttributeMap(Seq(
- AttributeReference("c1", IntegerType)() -> ColumnStat(10, Some(1), Some(10), 0, 4, 4))),
- isBroadcastable = false)
- val expectedCboStats =
- Statistics(
- sizeInBytes = 4,
- rowCount = Some(1),
- attributeStats = AttributeMap(Seq(
- AttributeReference("c1", IntegerType)() -> ColumnStat(1, Some(5), Some(5), 0, 4, 4))),
- isBroadcastable = false)
-
- val plan = DummyLogicalPlan(defaultStats = expectedDefaultStats, cboStats = expectedCboStats)
- // Return the statistics estimated by cbo
- assert(plan.stats(conf.copy(cboEnabled = true)) == expectedCboStats)
- // Invalidate statistics
- plan.invalidateStatsCache()
- // Return the simple statistics
- assert(plan.stats(conf.copy(cboEnabled = false)) == expectedDefaultStats)
- }
-}
-
-/**
- * This class is used for unit-testing the cbo switch, it mimics a logical plan which computes
- * a simple statistics or a cbo estimated statistics based on the conf.
- */
-private case class DummyLogicalPlan(
- defaultStats: Statistics,
- cboStats: Statistics) extends LogicalPlan {
- override def output: Seq[Attribute] = Nil
- override def children: Seq[LogicalPlan] = Nil
- override def computeStats(conf: CatalystConf): Statistics =
- if (conf.cboEnabled) cboStats else defaultStats
-}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala
index c56b41ce3763..9b2b8dbe1bf4 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LeafNode, Logica
import org.apache.spark.sql.types.{IntegerType, StringType}
-class StatsEstimationTestBase extends SparkFunSuite {
+trait StatsEstimationTestBase extends SparkFunSuite {
/** Enable stats estimation based on CBO. */
protected val conf = SimpleCatalystConf(caseSensitiveAnalysis = true, cboEnabled = true)
@@ -48,7 +48,7 @@ class StatsEstimationTestBase extends SparkFunSuite {
/**
* This class is used for unit-testing. It's a logical plan whose output and stats are passed in.
*/
-protected case class StatsTestPlan(
+case class StatsTestPlan(
outputList: Seq[Attribute],
rowCount: BigInt,
attributeStats: AttributeMap[ColumnStat],
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/streaming/InternalOutputModesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/streaming/InternalOutputModesSuite.scala
new file mode 100644
index 000000000000..201dac35ed2d
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/streaming/InternalOutputModesSuite.scala
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.streaming
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.streaming.OutputMode
+
+class InternalOutputModesSuite extends SparkFunSuite {
+
+ test("supported strings") {
+ def testMode(outputMode: String, expected: OutputMode): Unit = {
+ assert(InternalOutputModes(outputMode) === expected)
+ }
+
+ testMode("append", OutputMode.Append)
+ testMode("Append", OutputMode.Append)
+ testMode("complete", OutputMode.Complete)
+ testMode("Complete", OutputMode.Complete)
+ testMode("update", OutputMode.Update)
+ testMode("Update", OutputMode.Update)
+ }
+
+ test("unsupported strings") {
+ def testMode(outputMode: String): Unit = {
+ val acceptedModes = Seq("append", "update", "complete")
+ val e = intercept[IllegalArgumentException](InternalOutputModes(outputMode))
+ (Seq("output mode", "unknown", outputMode) ++ acceptedModes).foreach { s =>
+ assert(e.getMessage.toLowerCase.contains(s.toLowerCase))
+ }
+ }
+ testMode("Xyz")
+ }
+}
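
For orientation, a hedged sketch of where these case-insensitive strings originate in user code; `streamingDf` is an assumed streaming DataFrame, and the string is resolved through `InternalOutputModes` behind the scenes.

```scala
// "append", "complete" and "update" (any casing) are the accepted values.
val query = streamingDf.writeStream
  .outputMode("update")   // resolved to OutputMode.Update
  .format("console")
  .start()
```
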
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
index af1eaa1f2374..37e3dfabd0b2 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
@@ -491,7 +491,8 @@ class TreeNodeSuite extends SparkFunSuite {
"lastAccessTime" -> -1,
"tracksPartitionsInCatalog" -> false,
"properties" -> JNull,
- "unsupportedFeatures" -> List.empty[String]))
+ "unsupportedFeatures" -> List.empty[String],
+ "schemaPreservesCase" -> JBool(true)))
// For unknown case class, returns JNull.
val bigValue = new Array[Int](10000)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala
index 52d0692524d0..714883a4099c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala
@@ -193,7 +193,7 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester {
assert(Decimal(Long.MaxValue, 100, 0).toUnscaledLong === Long.MaxValue)
}
- test("changePrecision() on compact decimal should respect rounding mode") {
+ test("changePrecision/toPrecision on compact decimal should respect rounding mode") {
Seq(ROUND_FLOOR, ROUND_CEILING, ROUND_HALF_UP, ROUND_HALF_EVEN).foreach { mode =>
Seq("0.4", "0.5", "0.6", "1.0", "1.1", "1.6", "2.5", "5.5").foreach { n =>
Seq("", "-").foreach { sign =>
@@ -202,6 +202,12 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester {
val d = Decimal(unscaled, 8, 1)
assert(d.changePrecision(10, 0, mode))
assert(d.toString === bd.setScale(0, mode).toString(), s"num: $sign$n, mode: $mode")
+
+ val copy = d.toPrecision(10, 0, mode).orNull
+ assert(copy !== null)
+ assert(d.ne(copy))
+ assert(d === copy)
+ assert(copy.toString === bd.setScale(0, mode).toString(), s"num: $sign$n, mode: $mode")
}
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index 41470ae6aae1..4f4cc9311749 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -29,6 +29,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions}
import org.apache.spark.sql.execution.LogicalRDD
import org.apache.spark.sql.execution.command.DDLUtils
+import org.apache.spark.sql.execution.datasources.csv._
import org.apache.spark.sql.execution.datasources.DataSource
import org.apache.spark.sql.execution.datasources.jdbc._
import org.apache.spark.sql.execution.datasources.json.JsonInferSchema
@@ -261,7 +262,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
}
/**
- * Loads a JSON file and returns the results as a `DataFrame`.
+ * Loads JSON files and returns the results as a `DataFrame`.
*
* JSON Lines (newline-delimited JSON) is supported by
* default. For JSON (one record per file), set the `wholeFile` option to true.
@@ -368,14 +369,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
createParser)
}
- // Check a field requirement for corrupt records here to throw an exception in a driver side
- schema.getFieldIndex(parsedOptions.columnNameOfCorruptRecord).foreach { corruptFieldIndex =>
- val f = schema(corruptFieldIndex)
- if (f.dataType != StringType || !f.nullable) {
- throw new AnalysisException(
- "The field for corrupt records must be string type and nullable")
- }
- }
+ verifyColumnNameOfCorruptRecord(schema, parsedOptions.columnNameOfCorruptRecord)
val parsed = jsonDataset.rdd.mapPartitions { iter =>
val parser = new JacksonParser(schema, parsedOptions)
@@ -399,7 +393,52 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
}
/**
- * Loads a CSV file and returns the result as a `DataFrame`.
+ * Loads a `Dataset[String]` storing CSV rows and returns the result as a `DataFrame`.
+ *
+ * If the schema is not specified using the `schema` function and the `inferSchema` option is
+ * enabled, this function goes through the input once to determine the input schema.
+ *
+ * If the schema is not specified using the `schema` function and the `inferSchema` option is
+ * disabled, the columns are treated as string types, and only the first line is read to
+ * determine the names and the number of fields.
+ *
+ * @param csvDataset input Dataset with one CSV row per record
+ * @since 2.2.0
+ */
+ def csv(csvDataset: Dataset[String]): DataFrame = {
+ val parsedOptions: CSVOptions = new CSVOptions(
+ extraOptions.toMap,
+ sparkSession.sessionState.conf.sessionLocalTimeZone)
+ val filteredLines: Dataset[String] =
+ CSVUtils.filterCommentAndEmpty(csvDataset, parsedOptions)
+ val maybeFirstLine: Option[String] = filteredLines.take(1).headOption
+
+ val schema = userSpecifiedSchema.getOrElse {
+ TextInputCSVDataSource.inferFromDataset(
+ sparkSession,
+ csvDataset,
+ maybeFirstLine,
+ parsedOptions)
+ }
+
+ verifyColumnNameOfCorruptRecord(schema, parsedOptions.columnNameOfCorruptRecord)
+
+ val linesWithoutHeader: RDD[String] = maybeFirstLine.map { firstLine =>
+ filteredLines.rdd.mapPartitions(CSVUtils.filterHeaderLine(_, firstLine, parsedOptions))
+ }.getOrElse(filteredLines.rdd)
+
+ val parsed = linesWithoutHeader.mapPartitions { iter =>
+ val parser = new UnivocityParser(schema, parsedOptions)
+ iter.flatMap(line => parser.parse(line))
+ }
+
+ Dataset.ofRows(
+ sparkSession,
+ LogicalRDD(schema.toAttributes, parsed)(sparkSession))
+ }
+
+ /**
+ * Loads CSV files and returns the result as a `DataFrame`.
*
* This function will go through the input once to determine the input schema if `inferSchema`
* is enabled. To avoid going through the entire data once, disable `inferSchema` option or
@@ -510,7 +549,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
}
/**
- * Loads an ORC file and returns the result as a `DataFrame`.
+ * Loads ORC files and returns the result as a `DataFrame`.
*
* @param paths input paths
* @since 2.0.0
@@ -604,6 +643,22 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
}
}
+ /**
+ * A convenience method for schema validation in data sources that support
+ * `columnNameOfCorruptRecord` as an option.
+ */
+ private def verifyColumnNameOfCorruptRecord(
+ schema: StructType,
+ columnNameOfCorruptRecord: String): Unit = {
+ schema.getFieldIndex(columnNameOfCorruptRecord).foreach { corruptFieldIndex =>
+ val f = schema(corruptFieldIndex)
+ if (f.dataType != StringType || !f.nullable) {
+ throw new AnalysisException(
+ "The field for corrupt records must be string type and nullable")
+ }
+ }
+ }
+
///////////////////////////////////////////////////////////////////////////////////////
// Builder pattern config options
///////////////////////////////////////////////////////////////////////////////////////
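
A hedged usage sketch of the `csv(Dataset[String])` method added above; the local `spark` session and the sample rows are illustrative only.

```scala
import org.apache.spark.sql.{Dataset, SparkSession}

val spark = SparkSession.builder().master("local[*]").appName("csv-dataset-sketch").getOrCreate()
import spark.implicits._

val csvLines: Dataset[String] = Seq("name,age", "Alice,30", "Bob,25").toDS()

val df = spark.read
  .option("header", "true")       // the first line supplies the column names
  .option("inferSchema", "true")  // goes through the input once to infer types
  .csv(csvLines)                  // columns: name: string, age: int
```
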
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 1b0462359607..520663f62440 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -36,6 +36,7 @@ import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst._
import org.apache.spark.sql.catalyst.analysis._
+import org.apache.spark.sql.catalyst.catalog.CatalogRelation
import org.apache.spark.sql.catalyst.encoders._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
@@ -563,7 +564,7 @@ class Dataset[T] private[sql](
* @param eventTime the name of the column that contains the event time of the row.
* @param delayThreshold the minimum delay to wait to data to arrive late, relative to the latest
* record that has been processed in the form of an interval
- * (e.g. "1 minute" or "5 hours").
+ * (e.g. "1 minute" or "5 hours"). NOTE: This should not be negative.
*
* @group streaming
* @since 2.1.0
@@ -576,6 +577,8 @@ class Dataset[T] private[sql](
val parsedDelay =
Option(CalendarInterval.fromString("interval " + delayThreshold))
.getOrElse(throw new AnalysisException(s"Unable to parse time delay '$delayThreshold'"))
+ require(parsedDelay.milliseconds >= 0 && parsedDelay.months >= 0,
+ s"delay threshold ($delayThreshold) should not be negative.")
EventTimeWatermark(UnresolvedAttribute(eventTime), parsedDelay, logicalPlan)
}
@@ -1093,7 +1096,7 @@ class Dataset[T] private[sql](
* @since 1.6.0
*/
def as(alias: String): Dataset[T] = withTypedPlan {
- SubqueryAlias(alias, logicalPlan, None)
+ SubqueryAlias(alias, logicalPlan)
}
/**
@@ -2441,11 +2444,11 @@ class Dataset[T] private[sql](
}
/**
- * Returns a new Dataset that has exactly `numPartitions` partitions.
- * Similar to coalesce defined on an `RDD`, this operation results in a narrow dependency, e.g.
- * if you go from 1000 partitions to 100 partitions, there will not be a shuffle, instead each of
- * the 100 new partitions will claim 10 of the current partitions. If a larger number of
- * partitions is requested, it will stay at the current number of partitions.
+ * Returns a new Dataset that has exactly `numPartitions` partitions, when fewer partitions
+ * are requested. If a larger number of partitions is requested, it will stay at the current
+ * number of partitions. Similar to coalesce defined on an `RDD`, this operation results in
+ * a narrow dependency, e.g. if you go from 1000 partitions to 100 partitions, there will not
+ * be a shuffle; instead each of the 100 new partitions will claim 10 of the current partitions.
*
* However, if you're doing a drastic coalesce, e.g. to numPartitions = 1,
* this may result in your computation taking place on fewer nodes than
@@ -2732,6 +2735,8 @@ class Dataset[T] private[sql](
fsBasedRelation.inputFiles
case fr: FileRelation =>
fr.inputFiles
+ case r: CatalogRelation if DDLUtils.isHiveTable(r.tableMeta) =>
+ r.tableMeta.storage.locationUri.map(_.toString).toArray
}.flatten
files.toSet.toArray
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala b/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala
index 1e8ba51e59e3..bd8dd6ea3fe0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala
@@ -46,4 +46,10 @@ class ExperimentalMethods private[sql]() {
@volatile var extraOptimizations: Seq[Rule[LogicalPlan]] = Nil
+ override def clone(): ExperimentalMethods = {
+ val result = new ExperimentalMethods
+ result.extraStrategies = extraStrategies
+ result.extraOptimizations = extraOptimizations
+ result
+ }
}
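// Sketch of why clone() matters for the session cloning introduced later in this patch: the
// copy starts with the parent's extra strategies and optimizations, and later mutations of
// either instance do not leak into the other (assumes a SparkSession `spark`).
val parent = spark.experimental
val child  = parent.clone()
child.extraOptimizations = Nil   // parent.extraOptimizations is unaffected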
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
index 3a548c251f5b..ab956ffd642e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
@@ -24,8 +24,10 @@ import org.apache.spark.api.java.function._
import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, CreateStruct}
import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.streaming.InternalOutputModes
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.expressions.ReduceAggregator
+import org.apache.spark.sql.streaming.OutputMode
/**
* :: Experimental ::
@@ -238,8 +240,16 @@ class KeyValueGroupedDataset[K, V] private[sql](
@InterfaceStability.Evolving
def mapGroupsWithState[S: Encoder, U: Encoder](
func: (K, Iterator[V], KeyedState[S]) => U): Dataset[U] = {
- flatMapGroupsWithState[S, U](
- (key: K, it: Iterator[V], s: KeyedState[S]) => Iterator(func(key, it, s)))
+ val flatMapFunc = (key: K, it: Iterator[V], s: KeyedState[S]) => Iterator(func(key, it, s))
+ Dataset[U](
+ sparkSession,
+ FlatMapGroupsWithState[K, V, S, U](
+ flatMapFunc.asInstanceOf[(Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any]],
+ groupingAttributes,
+ dataAttributes,
+ OutputMode.Update,
+ isMapGroupsWithState = true,
+ child = logicalPlan))
}
/**
@@ -267,8 +277,8 @@ class KeyValueGroupedDataset[K, V] private[sql](
func: MapGroupsWithStateFunction[K, V, S, U],
stateEncoder: Encoder[S],
outputEncoder: Encoder[U]): Dataset[U] = {
- flatMapGroupsWithState[S, U](
- (key: K, it: Iterator[V], s: KeyedState[S]) => Iterator(func.call(key, it.asJava, s))
+ mapGroupsWithState[S, U](
+ (key: K, it: Iterator[V], s: KeyedState[S]) => func.call(key, it.asJava, s)
)(stateEncoder, outputEncoder)
}
@@ -284,6 +294,8 @@ class KeyValueGroupedDataset[K, V] private[sql](
*
* @tparam S The type of the user-defined state. Must be encodable to Spark SQL types.
* @tparam U The type of the output objects. Must be encodable to Spark SQL types.
+ * @param func Function to be called on every group.
+ * @param outputMode The output mode of the function.
*
* See [[Encoder]] for more details on what types are encodable to Spark SQL.
* @since 2.1.1
@@ -291,14 +303,44 @@ class KeyValueGroupedDataset[K, V] private[sql](
@Experimental
@InterfaceStability.Evolving
def flatMapGroupsWithState[S: Encoder, U: Encoder](
- func: (K, Iterator[V], KeyedState[S]) => Iterator[U]): Dataset[U] = {
+ func: (K, Iterator[V], KeyedState[S]) => Iterator[U], outputMode: OutputMode): Dataset[U] = {
+ if (outputMode != OutputMode.Append && outputMode != OutputMode.Update) {
+ throw new IllegalArgumentException("The output mode of the function should be append or update")
+ }
Dataset[U](
sparkSession,
- MapGroupsWithState[K, V, S, U](
+ FlatMapGroupsWithState[K, V, S, U](
func.asInstanceOf[(Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any]],
groupingAttributes,
dataAttributes,
- logicalPlan))
+ outputMode,
+ isMapGroupsWithState = false,
+ child = logicalPlan))
+ }
+
+ /**
+ * ::Experimental::
+ * (Scala-specific)
+ * Applies the given function to each group of data, while maintaining a user-defined per-group
+ * state. The result Dataset will represent the objects returned by the function.
+ * For a static batch Dataset, the function will be invoked once per group. For a streaming
+ * Dataset, the function will be invoked for each group repeatedly in every trigger, and
+ * updates to each group's state will be saved across invocations.
+ * See [[KeyedState]] for more details.
+ *
+ * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types.
+ * @tparam U The type of the output objects. Must be encodable to Spark SQL types.
+ * @param func Function to be called on every group.
+ * @param outputMode The output mode of the function.
+ *
+ * See [[Encoder]] for more details on what types are encodable to Spark SQL.
+ * @since 2.1.1
+ */
+ @Experimental
+ @InterfaceStability.Evolving
+ def flatMapGroupsWithState[S: Encoder, U: Encoder](
+ func: (K, Iterator[V], KeyedState[S]) => Iterator[U], outputMode: String): Dataset[U] = {
+ flatMapGroupsWithState(func, InternalOutputModes(outputMode))
}
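// Illustrative usage of the output-mode-aware API defined above, assuming a SparkSession with
// spark.implicits._ in scope and a Dataset[String] `userEvents` of user names: keep a per-user
// running count and emit the updated count on every trigger.
import org.apache.spark.sql.streaming.OutputMode
val counts = userEvents
  .groupByKey(u => u)
  .flatMapGroupsWithState[Long, (String, Long)](
    (user, events, state) => {
      val newCount = state.getOption.getOrElse(0L) + events.size
      state.update(newCount)
      Iterator(user -> newCount)
    },
    OutputMode.Update)   // the String overload also accepts "update"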
/**
@@ -314,6 +356,7 @@ class KeyValueGroupedDataset[K, V] private[sql](
* @tparam S The type of the user-defined state. Must be encodable to Spark SQL types.
* @tparam U The type of the output objects. Must be encodable to Spark SQL types.
* @param func Function to be called on every group.
+ * @param outputMode The output mode of the function.
* @param stateEncoder Encoder for the state type.
* @param outputEncoder Encoder for the output type.
*
@@ -324,13 +367,45 @@ class KeyValueGroupedDataset[K, V] private[sql](
@InterfaceStability.Evolving
def flatMapGroupsWithState[S, U](
func: FlatMapGroupsWithStateFunction[K, V, S, U],
+ outputMode: OutputMode,
stateEncoder: Encoder[S],
outputEncoder: Encoder[U]): Dataset[U] = {
flatMapGroupsWithState[S, U](
- (key: K, it: Iterator[V], s: KeyedState[S]) => func.call(key, it.asJava, s).asScala
+ (key: K, it: Iterator[V], s: KeyedState[S]) => func.call(key, it.asJava, s).asScala,
+ outputMode
)(stateEncoder, outputEncoder)
}
+ /**
+ * ::Experimental::
+ * (Java-specific)
+ * Applies the given function to each group of data, while maintaining a user-defined per-group
+ * state. The result Dataset will represent the objects returned by the function.
+ * For a static batch Dataset, the function will be invoked once per group. For a streaming
+ * Dataset, the function will be invoked for each group repeatedly in every trigger, and
+ * updates to each group's state will be saved across invocations.
+ * See [[KeyedState]] for more details.
+ *
+ * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types.
+ * @tparam U The type of the output objects. Must be encodable to Spark SQL types.
+ * @param func Function to be called on every group.
+ * @param outputMode The output mode of the function.
+ * @param stateEncoder Encoder for the state type.
+ * @param outputEncoder Encoder for the output type.
+ *
+ * See [[Encoder]] for more details on what types are encodable to Spark SQL.
+ * @since 2.1.1
+ */
+ @Experimental
+ @InterfaceStability.Evolving
+ def flatMapGroupsWithState[S, U](
+ func: FlatMapGroupsWithStateFunction[K, V, S, U],
+ outputMode: String,
+ stateEncoder: Encoder[S],
+ outputEncoder: Encoder[U]): Dataset[U] = {
+ flatMapGroupsWithState(func, InternalOutputModes(outputMode), stateEncoder, outputEncoder)
+ }
+
/**
* (Scala-specific)
* Reduces the elements of each group of data using the specified binary function.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index afc1827e7eec..49562578b23c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -21,7 +21,6 @@ import java.io.Closeable
import java.util.concurrent.atomic.AtomicReference
import scala.collection.JavaConverters._
-import scala.reflect.ClassTag
import scala.reflect.runtime.universe.TypeTag
import scala.util.control.NonFatal
@@ -43,7 +42,7 @@ import org.apache.spark.sql.internal.{CatalogImpl, SessionState, SharedState}
import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.streaming._
-import org.apache.spark.sql.types.{DataType, LongType, StructType}
+import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.sql.util.ExecutionListenerManager
import org.apache.spark.util.Utils
@@ -67,15 +66,22 @@ import org.apache.spark.util.Utils
* .config("spark.some.config.option", "some-value")
* .getOrCreate()
* }}}
+ *
+ * @param sparkContext The Spark context associated with this Spark session.
+ * @param existingSharedState If supplied, use the existing shared state
+ * instead of creating a new one.
+ * @param parentSessionState If supplied, inherit all session state (i.e. temporary
+ * views, SQL config, UDFs etc) from parent.
*/
@InterfaceStability.Stable
class SparkSession private(
@transient val sparkContext: SparkContext,
- @transient private val existingSharedState: Option[SharedState])
+ @transient private val existingSharedState: Option[SharedState],
+ @transient private val parentSessionState: Option[SessionState])
extends Serializable with Closeable with Logging { self =>
private[sql] def this(sc: SparkContext) {
- this(sc, None)
+ this(sc, None, None)
}
sparkContext.assertNotStopped()
@@ -108,6 +114,7 @@ class SparkSession private(
/**
* State isolated across sessions, including SQL configurations, temporary tables, registered
* functions, and everything else that accepts a [[org.apache.spark.sql.internal.SQLConf]].
+ * If `parentSessionState` is not null, the `SessionState` will be a copy of the parent.
*
* This is internal to Spark and there is no guarantee on interface stability.
*
@@ -116,9 +123,13 @@ class SparkSession private(
@InterfaceStability.Unstable
@transient
lazy val sessionState: SessionState = {
- SparkSession.reflect[SessionState, SparkSession](
- SparkSession.sessionStateClassName(sparkContext.conf),
- self)
+ parentSessionState
+ .map(_.clone(this))
+ .getOrElse {
+ SparkSession.instantiateSessionState(
+ SparkSession.sessionStateClassName(sparkContext.conf),
+ self)
+ }
}
/**
@@ -208,7 +219,25 @@ class SparkSession private(
* @since 2.0.0
*/
def newSession(): SparkSession = {
- new SparkSession(sparkContext, Some(sharedState))
+ new SparkSession(sparkContext, Some(sharedState), parentSessionState = None)
+ }
+
+ /**
+ * Create an identical copy of this `SparkSession`, sharing the underlying `SparkContext`
+ * and shared state. All the state of this session (i.e. SQL configurations, temporary tables,
+ * registered functions) is copied over, and the cloned session is set up with the same shared
+ * state as this session. The cloned session is independent of this session, that is, any
+ * non-global change in either session is not reflected in the other.
+ *
+ * @note Other than the `SparkContext`, all shared state is initialized lazily.
+ * This method will force the initialization of the shared state to ensure that parent
+ * and child sessions are set up with the same shared state. If the underlying catalog
+ * implementation is Hive, this will initialize the metastore, which may take some time.
+ */
+ private[sql] def cloneSession(): SparkSession = {
+ val result = new SparkSession(sparkContext, Some(sharedState), Some(sessionState))
+ result.sessionState // force copy of SessionState
+ result
}
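// Sketch of the cloning semantics described above (cloneSession is private[sql], so this is
// illustrative rather than user-facing API):
val parent = SparkSession.builder().getOrCreate()
val child  = parent.cloneSession()               // copies SQL conf, temp views, UDFs, ...
child.conf.set("spark.sql.shuffle.partitions", "7")
parent.conf.get("spark.sql.shuffle.partitions")  // unchanged; the two sessions are independent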
@@ -971,16 +1000,18 @@ object SparkSession {
}
/**
- * Helper method to create an instance of [[T]] using a single-arg constructor that
- * accepts an [[Arg]].
+ * Helper method to create an instance of `SessionState` based on `className` from conf.
+ * The result is either `SessionState` or `HiveSessionState`.
*/
- private def reflect[T, Arg <: AnyRef](
+ private def instantiateSessionState(
className: String,
- ctorArg: Arg)(implicit ctorArgTag: ClassTag[Arg]): T = {
+ sparkSession: SparkSession): SessionState = {
+
try {
+ // get `SessionState.apply(SparkSession)`
val clazz = Utils.classForName(className)
- val ctor = clazz.getDeclaredConstructor(ctorArgTag.runtimeClass)
- ctor.newInstance(ctorArg).asInstanceOf[T]
+ val method = clazz.getMethod("apply", sparkSession.getClass)
+ method.invoke(null, sparkSession).asInstanceOf[SessionState]
} catch {
case NonFatal(e) =>
throw new IllegalArgumentException(s"Error while instantiating '$className':", e)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
index 80138510dc9e..0ea806d6cb50 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql.execution
import java.util.concurrent.locks.ReentrantReadWriteLock
+import scala.collection.JavaConverters._
+
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.spark.internal.Logging
@@ -45,7 +47,7 @@ case class CachedData(plan: LogicalPlan, cachedRepresentation: InMemoryRelation)
class CacheManager extends Logging {
@transient
- private val cachedData = new scala.collection.mutable.ArrayBuffer[CachedData]
+ private val cachedData = new java.util.LinkedList[CachedData]
@transient
private val cacheLock = new ReentrantReadWriteLock
@@ -70,7 +72,7 @@ class CacheManager extends Logging {
/** Clears all cached tables. */
def clearCache(): Unit = writeLock {
- cachedData.foreach(_.cachedRepresentation.cachedColumnBuffers.unpersist())
+ cachedData.asScala.foreach(_.cachedRepresentation.cachedColumnBuffers.unpersist())
cachedData.clear()
}
@@ -88,46 +90,81 @@ class CacheManager extends Logging {
query: Dataset[_],
tableName: Option[String] = None,
storageLevel: StorageLevel = MEMORY_AND_DISK): Unit = writeLock {
- val planToCache = query.queryExecution.analyzed
+ val planToCache = query.logicalPlan
if (lookupCachedData(planToCache).nonEmpty) {
logWarning("Asked to cache already cached data.")
} else {
val sparkSession = query.sparkSession
- cachedData +=
- CachedData(
- planToCache,
- InMemoryRelation(
- sparkSession.sessionState.conf.useCompression,
- sparkSession.sessionState.conf.columnBatchSize,
- storageLevel,
- sparkSession.sessionState.executePlan(planToCache).executedPlan,
- tableName))
+ cachedData.add(CachedData(
+ planToCache,
+ InMemoryRelation(
+ sparkSession.sessionState.conf.useCompression,
+ sparkSession.sessionState.conf.columnBatchSize,
+ storageLevel,
+ sparkSession.sessionState.executePlan(planToCache).executedPlan,
+ tableName)))
}
}
/**
- * Tries to remove the data for the given [[Dataset]] from the cache.
- * No operation, if it's already uncached.
+ * Un-cache all the cache entries that refer to the given plan.
+ */
+ def uncacheQuery(query: Dataset[_], blocking: Boolean = true): Unit = writeLock {
+ uncacheQuery(query.sparkSession, query.logicalPlan, blocking)
+ }
+
+ /**
+ * Un-cache all the cache entries that refer to the given plan.
*/
- def uncacheQuery(query: Dataset[_], blocking: Boolean = true): Boolean = writeLock {
- val planToCache = query.queryExecution.analyzed
- val dataIndex = cachedData.indexWhere(cd => planToCache.sameResult(cd.plan))
- val found = dataIndex >= 0
- if (found) {
- cachedData(dataIndex).cachedRepresentation.cachedColumnBuffers.unpersist(blocking)
- cachedData.remove(dataIndex)
+ def uncacheQuery(spark: SparkSession, plan: LogicalPlan, blocking: Boolean): Unit = writeLock {
+ val it = cachedData.iterator()
+ while (it.hasNext) {
+ val cd = it.next()
+ if (cd.plan.find(_.sameResult(plan)).isDefined) {
+ cd.cachedRepresentation.cachedColumnBuffers.unpersist(blocking)
+ it.remove()
+ }
}
- found
+ }
+
+ /**
+ * Tries to re-cache all the cache entries that refer to the given plan.
+ */
+ def recacheByPlan(spark: SparkSession, plan: LogicalPlan): Unit = writeLock {
+ recacheByCondition(spark, _.find(_.sameResult(plan)).isDefined)
+ }
+
+ private def recacheByCondition(spark: SparkSession, condition: LogicalPlan => Boolean): Unit = {
+ val it = cachedData.iterator()
+ val needToRecache = scala.collection.mutable.ArrayBuffer.empty[CachedData]
+ while (it.hasNext) {
+ val cd = it.next()
+ if (condition(cd.plan)) {
+ cd.cachedRepresentation.cachedColumnBuffers.unpersist()
+ // Remove the cache entry before we create a new one, so that we can have a different
+ // physical plan.
+ it.remove()
+ val newCache = InMemoryRelation(
+ useCompression = cd.cachedRepresentation.useCompression,
+ batchSize = cd.cachedRepresentation.batchSize,
+ storageLevel = cd.cachedRepresentation.storageLevel,
+ child = spark.sessionState.executePlan(cd.plan).executedPlan,
+ tableName = cd.cachedRepresentation.tableName)
+ needToRecache += cd.copy(cachedRepresentation = newCache)
+ }
+ }
+
+ needToRecache.foreach(cachedData.add)
}
/** Optionally returns cached data for the given [[Dataset]] */
def lookupCachedData(query: Dataset[_]): Option[CachedData] = readLock {
- lookupCachedData(query.queryExecution.analyzed)
+ lookupCachedData(query.logicalPlan)
}
/** Optionally returns cached data for the given [[LogicalPlan]]. */
def lookupCachedData(plan: LogicalPlan): Option[CachedData] = readLock {
- cachedData.find(cd => plan.sameResult(cd.plan))
+ cachedData.asScala.find(cd => plan.sameResult(cd.plan))
}
/** Replaces segments of the given logical plan with cached versions where possible. */
@@ -145,40 +182,17 @@ class CacheManager extends Logging {
}
/**
- * Invalidates the cache of any data that contains `plan`. Note that it is possible that this
- * function will over invalidate.
- */
- def invalidateCache(plan: LogicalPlan): Unit = writeLock {
- cachedData.foreach {
- case data if data.plan.collect { case p if p.sameResult(plan) => p }.nonEmpty =>
- data.cachedRepresentation.recache()
- case _ =>
- }
- }
-
- /**
- * Invalidates the cache of any data that contains `resourcePath` in one or more
+ * Tries to re-cache all the cache entries that contain `resourcePath` in one or more
* `HadoopFsRelation` node(s) as part of its logical plan.
*/
- def invalidateCachedPath(
- sparkSession: SparkSession, resourcePath: String): Unit = writeLock {
+ def recacheByPath(spark: SparkSession, resourcePath: String): Unit = writeLock {
val (fs, qualifiedPath) = {
val path = new Path(resourcePath)
- val fs = path.getFileSystem(sparkSession.sessionState.newHadoopConf())
- (fs, path.makeQualified(fs.getUri, fs.getWorkingDirectory))
+ val fs = path.getFileSystem(spark.sessionState.newHadoopConf())
+ (fs, fs.makeQualified(path))
}
- cachedData.filter {
- case data if data.plan.find(lookupAndRefresh(_, fs, qualifiedPath)).isDefined => true
- case _ => false
- }.foreach { data =>
- val dataIndex = cachedData.indexWhere(cd => data.plan.sameResult(cd.plan))
- if (dataIndex >= 0) {
- data.cachedRepresentation.cachedColumnBuffers.unpersist(blocking = true)
- cachedData.remove(dataIndex)
- }
- sparkSession.sharedState.cacheManager.cacheQuery(Dataset.ofRows(sparkSession, data.plan))
- }
+ recacheByCondition(spark, _.find(lookupAndRefresh(_, fs, qualifiedPath)).isDefined)
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
index 6ec2f4d84086..9a3656ddc79f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
@@ -46,9 +46,14 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) {
protected def planner = sparkSession.sessionState.planner
def assertAnalyzed(): Unit = {
- try sparkSession.sessionState.analyzer.checkAnalysis(analyzed) catch {
+ // Analyzer is invoked outside the try block to avoid calling it again from within the
+ // catch block below.
+ analyzed
+ try {
+ sparkSession.sessionState.analyzer.checkAnalysis(analyzed)
+ } catch {
case e: AnalysisException =>
- val ae = new AnalysisException(e.message, e.line, e.startPosition, Some(analyzed))
+ val ae = new AnalysisException(e.message, e.line, e.startPosition, Option(analyzed))
ae.setStackTrace(e.getStackTrace)
throw ae
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
index cc576bbc4c80..f98ae82574d2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
@@ -177,6 +177,8 @@ case class SortExec(
""".stripMargin.trim
}
+ protected override val shouldStopRequired = false
+
override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
s"""
|${row.code}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
index 65df68868939..00d1d6d2701f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
@@ -386,7 +386,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder {
"LOCATION and 'path' in OPTIONS are both used to indicate the custom table path, " +
"you can only specify one of them.", ctx)
}
- val customLocation = storage.locationUri.orElse(location)
+ val customLocation = storage.locationUri.orElse(location.map(CatalogUtils.stringToURI(_)))
val tableType = if (customLocation.isDefined) {
CatalogTableType.EXTERNAL
@@ -1080,8 +1080,10 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder {
if (external && location.isEmpty) {
operationNotAllowed("CREATE EXTERNAL TABLE must be accompanied by LOCATION", ctx)
}
+
+ val locUri = location.map(CatalogUtils.stringToURI(_))
val storage = CatalogStorageFormat(
- locationUri = location,
+ locationUri = locUri,
inputFormat = fileStorage.inputFormat.orElse(defaultStorage.inputFormat),
outputFormat = fileStorage.outputFormat.orElse(defaultStorage.outputFormat),
serde = rowStorage.serde.orElse(fileStorage.serde).orElse(defaultStorage.serde),
@@ -1132,7 +1134,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder {
// At here, both rowStorage.serdeProperties and fileStorage.serdeProperties
// are empty Maps.
val newTableDesc = tableDesc.copy(
- storage = CatalogStorageFormat.empty.copy(locationUri = location),
+ storage = CatalogStorageFormat.empty.copy(locationUri = locUri),
provider = Some(conf.defaultDataSourceName))
CreateTable(newTableDesc, mode, Some(q))
} else {
@@ -1329,6 +1331,15 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder {
if (ctx.identifierList != null) {
operationNotAllowed("CREATE VIEW ... PARTITIONED ON", ctx)
} else {
+ // CREATE VIEW ... AS INSERT INTO is not allowed.
+ ctx.query.queryNoWith match {
+ case s: SingleInsertQueryContext if s.insertInto != null =>
+ operationNotAllowed("CREATE VIEW ... AS INSERT INTO", ctx)
+ case _: MultiInsertQueryContext =>
+ operationNotAllowed("CREATE VIEW ... AS FROM ... [INSERT INTO ...]+", ctx)
+ case _ => // OK
+ }
+
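// Illustration of the new parser check, assuming a SparkSession `spark` and an existing table `t`:
spark.sql("CREATE VIEW v AS SELECT id FROM t")             // still allowed
// spark.sql("CREATE VIEW v2 AS INSERT INTO t VALUES (1)")
// -> ParseException: Operation not allowed: CREATE VIEW ... AS INSERT INTO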
val userSpecifiedColumns = Option(ctx.identifierCommentList).toSeq.flatMap { icl =>
icl.identifierComment.asScala.map { ic =>
ic.identifier.getText -> Option(ic.STRING).map(string)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 20bf4925dbec..0f7aa3709c1c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -326,14 +326,24 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
}
/**
- * Strategy to convert MapGroupsWithState logical operator to physical operator
+ * Strategy to convert [[FlatMapGroupsWithState]] logical operator to physical operator
* in streaming plans. Conversion for batch plans is handled by [[BasicOperators]].
*/
object MapGroupsWithStateStrategy extends Strategy {
override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
- case MapGroupsWithState(
- f, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, stateDeser, stateSer, child) =>
- val execPlan = MapGroupsWithStateExec(
+ case FlatMapGroupsWithState(
+ f,
+ keyDeser,
+ valueDeser,
+ groupAttr,
+ dataAttr,
+ outputAttr,
+ stateDeser,
+ stateSer,
+ outputMode,
+ child,
+ _) =>
+ val execPlan = FlatMapGroupsWithStateExec(
f, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, None, stateDeser, stateSer,
planLater(child))
execPlan :: Nil
@@ -381,7 +391,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
execution.AppendColumnsWithObjectExec(f, childSer, newSer, planLater(child)) :: Nil
case logical.MapGroups(f, key, value, grouping, data, objAttr, child) =>
execution.MapGroupsExec(f, key, value, grouping, data, objAttr, planLater(child)) :: Nil
- case logical.MapGroupsWithState(f, key, value, grouping, data, output, _, _, child) =>
+ case logical.FlatMapGroupsWithState(
+ f, key, value, grouping, data, output, _, _, _, child, _) =>
execution.MapGroupsExec(f, key, value, grouping, data, output, planLater(child)) :: Nil
case logical.CoGroup(f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr, left, right) =>
execution.CoGroupExec(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
index c58474eba05d..c31fd92447c0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
@@ -206,6 +206,21 @@ trait CodegenSupport extends SparkPlan {
def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
throw new UnsupportedOperationException
}
+
+ /**
+ * Used as an optimization to suppress the shouldStop() check in the row-producing loop of
+ * WholeStageCodegen. Returning true means shouldStop() needs to be inserted into the loop
+ * producing rows, if any.
+ */
+ def isShouldStopRequired: Boolean = {
+ return shouldStopRequired && (this.parent == null || this.parent.isShouldStopRequired)
+ }
+
+ /**
+ * Set to false if this plan consumes all rows produced by its children but doesn't output
+ * rows to the buffer by calling append(), so the children don't require a shouldStop() check
+ * in their row-producing loops.
+ */
+ protected def shouldStopRequired: Boolean = true
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index 4529ed067e56..68c8e6ce62cb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -238,6 +238,8 @@ case class HashAggregateExec(
""".stripMargin
}
+ protected override val shouldStopRequired = false
+
private def doConsumeWithoutKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = {
// only have DeclarativeAggregate
val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index 87e90ed685cc..d876688a8aab 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -387,8 +387,8 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
// How many values should be generated in the next batch.
val nextBatchTodo = ctx.freshName("nextBatchTodo")
- // The default size of a batch.
- val batchSize = 1000L
+ // The default size of a batch, which must be a positive integer
+ val batchSize = 1000
ctx.addNewFunction("initRange",
s"""
@@ -434,6 +434,15 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
val input = ctx.freshName("input")
// Right now, Range is only used when there is one upstream.
ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];")
+
+ val localIdx = ctx.freshName("localIdx")
+ val localEnd = ctx.freshName("localEnd")
+ val range = ctx.freshName("range")
+ val shouldStop = if (isShouldStopRequired) {
+ s"if (shouldStop()) { $number = $value + ${step}L; return; }"
+ } else {
+ "// shouldStop check is eliminated"
+ }
s"""
| // initialize Range
| if (!$initTerm) {
@@ -442,11 +451,15 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
| }
|
| while (true) {
- | while ($number != $batchEnd) {
- | long $value = $number;
- | $number += ${step}L;
- | ${consume(ctx, Seq(ev))}
- | if (shouldStop()) return;
+ | long $range = $batchEnd - $number;
+ | if ($range != 0L) {
+ | int $localEnd = (int)($range / ${step}L);
+ | for (int $localIdx = 0; $localIdx < $localEnd; $localIdx++) {
+ | long $value = ((long)$localIdx * ${step}L) + $number;
+ | ${consume(ctx, Seq(ev))}
+ | $shouldStop
+ | }
+ | $number = $batchEnd;
| }
|
| if ($taskContext.isInterrupted()) {
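// Rough Scala rendition of the restructured loop above (variable names simplified; `consume`
// stands in for the generated consume(ctx, Seq(ev)) call): each batch is emitted through a
// tight int-indexed loop, and the shouldStop() check is generated only when
// isShouldStopRequired is true.
val step = 1L
val batchSize = 1000
var number = 0L
val batchEnd = number + batchSize * step
def consume(value: Long): Unit = println(value)
val range = batchEnd - number
if (range != 0L) {
  val localEnd = (range / step).toInt
  var localIdx = 0
  while (localIdx < localEnd) {
    val value = localIdx.toLong * step + number
    consume(value)
    localIdx += 1
  }
  number = batchEnd
}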
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
index 37bd95e73778..36037ac00372 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
@@ -85,12 +85,6 @@ case class InMemoryRelation(
buildBuffers()
}
- def recache(): Unit = {
- _cachedColumnBuffers.unpersist()
- _cachedColumnBuffers = null
- buildBuffers()
- }
-
private def buildBuffers(): Unit = {
val output = child.output
val cached = child.execute().mapPartitionsInternal { rowIterator =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala
index 7afa4e78a378..5f12830ee621 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala
@@ -60,6 +60,23 @@ case class SetCommand(kv: Option[(String, Option[String])]) extends RunnableComm
}
(keyValueOutput, runFunc)
+ case Some((SQLConf.Replaced.MAPREDUCE_JOB_REDUCES, Some(value))) =>
+ val runFunc = (sparkSession: SparkSession) => {
+ logWarning(
+ s"Property ${SQLConf.Replaced.MAPREDUCE_JOB_REDUCES} is Hadoop's property, " +
+ s"automatically converted to ${SQLConf.SHUFFLE_PARTITIONS.key} instead.")
+ if (value.toInt < 1) {
+ val msg =
+ s"Setting negative ${SQLConf.Replaced.MAPREDUCE_JOB_REDUCES} for automatically " +
+ "determining the number of reducers is not supported."
+ throw new IllegalArgumentException(msg)
+ } else {
+ sparkSession.conf.set(SQLConf.SHUFFLE_PARTITIONS.key, value)
+ Seq(Row(SQLConf.SHUFFLE_PARTITIONS.key, value))
+ }
+ }
+ (keyValueOutput, runFunc)
+
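// Illustration of the new SET handling above, assuming a SparkSession `spark`:
spark.sql("SET mapreduce.job.reduces=10")
spark.conf.get("spark.sql.shuffle.partitions")   // "10": the Hadoop property is translated
// spark.sql("SET mapreduce.job.reduces=-1")      // IllegalArgumentException (value must be >= 1)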
case Some((key @ SetCommand.VariableName(name), Some(value))) =>
val runFunc = (sparkSession: SparkSession) => {
sparkSession.conf.set(name, value)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala
index d835b521166a..2d890118ae0a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala
@@ -17,6 +17,10 @@
package org.apache.spark.sql.execution.command
+import java.net.URI
+
+import org.apache.hadoop.fs.Path
+
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
@@ -54,7 +58,7 @@ case class CreateDataSourceTableCommand(table: CatalogTable, ignoreIfExists: Boo
// Create the relation to validate the arguments before writing the metadata to the metastore,
// and infer the table schema and partition if users didn't specify schema in CREATE TABLE.
- val pathOption = table.storage.locationUri.map("path" -> _)
+ val pathOption = table.storage.locationUri.map("path" -> CatalogUtils.URIToString(_))
// Fill in some default table options from the session conf
val tableWithDefaultOptions = table.copy(
identifier = table.identifier.copy(
@@ -69,7 +73,8 @@ case class CreateDataSourceTableCommand(table: CatalogTable, ignoreIfExists: Boo
className = table.provider.get,
bucketSpec = table.bucketSpec,
options = table.storage.properties ++ pathOption,
- catalogTable = Some(tableWithDefaultOptions)).resolveRelation()
+ // As discussed in SPARK-19583, we don't check whether the location exists
+ catalogTable = Some(tableWithDefaultOptions)).resolveRelation(checkFilesExist = false)
val partitionColumnNames = if (table.schema.nonEmpty) {
table.partitionColumnNames
@@ -175,12 +180,12 @@ case class CreateDataSourceTableAsSelectCommand(
private def saveDataIntoTable(
session: SparkSession,
table: CatalogTable,
- tableLocation: Option[String],
+ tableLocation: Option[URI],
data: LogicalPlan,
mode: SaveMode,
tableExists: Boolean): BaseRelation = {
// Create the relation based on the input logical plan: `data`.
- val pathOption = tableLocation.map("path" -> _)
+ val pathOption = tableLocation.map("path" -> CatalogUtils.URIToString(_))
val dataSource = DataSource(
session,
className = table.provider.get,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
index 82cbb4aa4744..9d3c55060dfb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
@@ -66,7 +66,7 @@ case class CreateDatabaseCommand(
CatalogDatabase(
databaseName,
comment.getOrElse(""),
- path.getOrElse(catalog.getDefaultDBPath(databaseName)),
+ path.map(CatalogUtils.stringToURI(_)).getOrElse(catalog.getDefaultDBPath(databaseName)),
props),
ifNotExists)
Seq.empty[Row]
@@ -146,7 +146,7 @@ case class DescribeDatabaseCommand(
val result =
Row("Database Name", dbMetadata.name) ::
Row("Description", dbMetadata.description) ::
- Row("Location", dbMetadata.locationUri) :: Nil
+ Row("Location", CatalogUtils.URIToString(dbMetadata.locationUri)) :: Nil
if (extended) {
val properties =
@@ -199,8 +199,7 @@ case class DropTableCommand(
}
}
try {
- sparkSession.sharedState.cacheManager.uncacheQuery(
- sparkSession.table(tableName.quotedString))
+ sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession.table(tableName))
} catch {
case _: NoSuchTableException if ifExists =>
case NonFatal(e) => log.warn(e.toString, e)
@@ -426,7 +425,8 @@ case class AlterTableAddPartitionCommand(
table.identifier.quotedString,
sparkSession.sessionState.conf.resolver)
// inherit table storage format (possibly except for location)
- CatalogTablePartition(normalizedSpec, table.storage.copy(locationUri = location))
+ CatalogTablePartition(normalizedSpec, table.storage.copy(
+ locationUri = location.map(CatalogUtils.stringToURI(_))))
}
catalog.createPartitions(table.identifier, parts, ignoreIfExists = ifNotExists)
Seq.empty[Row]
@@ -710,7 +710,7 @@ case class AlterTableRecoverPartitionsCommand(
// inherit table storage format (possibly except for location)
CatalogTablePartition(
spec,
- table.storage.copy(locationUri = Some(location.toUri.toString)),
+ table.storage.copy(locationUri = Some(location.toUri)),
params)
}
spark.sessionState.catalog.createPartitions(tableName, parts, ignoreIfExists = true)
@@ -741,6 +741,7 @@ case class AlterTableSetLocationCommand(
override def run(sparkSession: SparkSession): Seq[Row] = {
val catalog = sparkSession.sessionState.catalog
val table = catalog.getTableMetadata(tableName)
+ val locUri = CatalogUtils.stringToURI(location)
DDLUtils.verifyAlterTableType(catalog, table, isView = false)
partitionSpec match {
case Some(spec) =>
@@ -748,11 +749,11 @@ case class AlterTableSetLocationCommand(
sparkSession, table, "ALTER TABLE ... SET LOCATION")
// Partition spec is specified, so we set the location only for this partition
val part = catalog.getPartition(table.identifier, spec)
- val newPart = part.copy(storage = part.storage.copy(locationUri = Some(location)))
+ val newPart = part.copy(storage = part.storage.copy(locationUri = Some(locUri)))
catalog.alterPartitions(table.identifier, Seq(newPart))
case None =>
// No partition spec is specified, so we set the location for the table itself
- catalog.alterTable(table.withNewStorage(locationUri = Some(location)))
+ catalog.alterTable(table.withNewStorage(locationUri = Some(locUri)))
}
Seq.empty[Row]
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
index 3e80916104bd..86394ff23e37 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
@@ -79,7 +79,8 @@ case class CreateTableLikeCommand(
CatalogTable(
identifier = targetTable,
tableType = tblType,
- storage = sourceTableDesc.storage.copy(locationUri = location),
+ storage = sourceTableDesc.storage.copy(
+ locationUri = location.map(CatalogUtils.stringToURI(_))),
schema = sourceTableDesc.schema,
provider = newProvider,
partitionColumnNames = sourceTableDesc.partitionColumnNames,
@@ -495,7 +496,8 @@ case class DescribeTableCommand(
append(buffer, "Owner:", table.owner, "")
append(buffer, "Create Time:", new Date(table.createTime).toString, "")
append(buffer, "Last Access Time:", new Date(table.lastAccessTime).toString, "")
- append(buffer, "Location:", table.storage.locationUri.getOrElse(""), "")
+ append(buffer, "Location:", table.storage.locationUri.map(CatalogUtils.URIToString(_))
+ .getOrElse(""), "")
append(buffer, "Table Type:", table.tableType.name, "")
table.stats.foreach(s => append(buffer, "Statistics:", s.simpleString, ""))
@@ -587,7 +589,8 @@ case class DescribeTableCommand(
append(buffer, "Partition Value:", s"[${partition.spec.values.mkString(", ")}]", "")
append(buffer, "Database:", table.database, "")
append(buffer, "Table:", tableIdentifier.table, "")
- append(buffer, "Location:", partition.storage.locationUri.getOrElse(""), "")
+ append(buffer, "Location:", partition.storage.locationUri.map(CatalogUtils.URIToString(_))
+ .getOrElse(""), "")
append(buffer, "Partition Parameters:", "", "")
partition.parameters.foreach { case (key, value) =>
append(buffer, s" $key", value, "")
@@ -953,7 +956,7 @@ case class ShowCreateTableCommand(table: TableIdentifier) extends RunnableComman
// when the table creation DDL contains the PATH option.
None
} else {
- Some(s"path '${escapeSingleQuotedString(location)}'")
+ Some(s"path '${escapeSingleQuotedString(CatalogUtils.URIToString(location))}'")
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala
index 921c84895598..00f0acab21aa 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala
@@ -23,9 +23,9 @@ import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, UnresolvedRelation}
import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType}
-import org.apache.spark.sql.catalyst.expressions.Alias
+import org.apache.spark.sql.catalyst.expressions.{Alias, SubqueryExpression}
import org.apache.spark.sql.catalyst.plans.QueryPlan
-import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, View}
import org.apache.spark.sql.types.MetadataBuilder
@@ -154,6 +154,10 @@ case class CreateViewCommand(
} else if (tableMetadata.tableType != CatalogTableType.VIEW) {
throw new AnalysisException(s"$name is not a view")
} else if (replace) {
+ // Detect cyclic view reference on CREATE OR REPLACE VIEW.
+ val viewIdent = tableMetadata.identifier
+ checkCyclicViewReference(analyzedPlan, Seq(viewIdent), viewIdent)
+
// Handles `CREATE OR REPLACE VIEW v0 AS SELECT ...`
catalog.alterTable(prepareTable(sparkSession, analyzedPlan))
} else {
@@ -283,6 +287,10 @@ case class AlterViewAsCommand(
throw new AnalysisException(s"${viewMeta.identifier} is not a view.")
}
+ // Detect cyclic view reference on ALTER VIEW.
+ val viewIdent = viewMeta.identifier
+ checkCyclicViewReference(analyzedPlan, Seq(viewIdent), viewIdent)
+
val newProperties = generateViewProperties(viewMeta.properties, session, analyzedPlan)
val updatedViewMeta = viewMeta.copy(
@@ -358,4 +366,53 @@ object ViewHelper {
generateViewDefaultDatabase(viewDefaultDatabase) ++
generateQueryColumnNames(queryOutput)
}
+
+ /**
+ * Recursively search the logical plan to detect cyclic view references, throw an
+ * AnalysisException if cycle detected.
+ *
+ * A cyclic view reference is a cycle of reference dependencies, for example, if the following
+ * statements are executed:
+ * CREATE VIEW testView AS SELECT id FROM tbl
+ * CREATE VIEW testView2 AS SELECT id FROM testView
+ * ALTER VIEW testView AS SELECT * FROM testView2
+ * The view `testView` references `testView2`, and `testView2` also references `testView`,
+ * therefore a reference cycle (testView -> testView2 -> testView) exists.
+ *
+ * @param plan the logical plan we detect cyclic view references from.
+ * @param path the path between the altered view and current node.
+ * @param viewIdent the table identifier of the altered view, we compare two views by the
+ * `desc.identifier`.
+ */
+ def checkCyclicViewReference(
+ plan: LogicalPlan,
+ path: Seq[TableIdentifier],
+ viewIdent: TableIdentifier): Unit = {
+ plan match {
+ case v: View =>
+ val ident = v.desc.identifier
+ val newPath = path :+ ident
+ // If the table identifier equals `viewIdent`, the current view node is the same as the
+ // altered view. We have detected a view reference cycle, so throw an AnalysisException.
+ if (ident == viewIdent) {
+ throw new AnalysisException(s"Recursive view $viewIdent detected " +
+ s"(cycle: ${newPath.mkString(" -> ")})")
+ } else {
+ v.children.foreach { child =>
+ checkCyclicViewReference(child, newPath, viewIdent)
+ }
+ }
+ case _ =>
+ plan.children.foreach(child => checkCyclicViewReference(child, path, viewIdent))
+ }
+
+ // Detect cyclic references from subqueries.
+ plan.expressions.foreach { expr =>
+ expr match {
+ case s: SubqueryExpression =>
+ checkCyclicViewReference(s.plan, path, viewIdent)
+ case _ => // Do nothing.
+ }
+ }
+ }
}
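// Illustration of the cycle detection above, mirroring the scaladoc example (assumes a
// SparkSession `spark` and a table `tbl` with an `id` column):
spark.sql("CREATE VIEW testView AS SELECT id FROM tbl")
spark.sql("CREATE VIEW testView2 AS SELECT id FROM testView")
// spark.sql("ALTER VIEW testView AS SELECT * FROM testView2")
// -> AnalysisException: Recursive view ... detected (cycle: testView -> testView2 -> testView)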
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala
index 2068811661fe..d6c4b97ebd08 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.execution.datasources
+import java.net.URI
+
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
@@ -46,7 +48,7 @@ class CatalogFileIndex(
assert(table.identifier.database.isDefined,
"The table identifier must be qualified in CatalogFileIndex")
- private val baseLocation: Option[String] = table.storage.locationUri
+ private val baseLocation: Option[URI] = table.storage.locationUri
override def partitionSchema: StructType = table.partitionSchema
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
index 4947dfda6fc7..c9384e44255b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
@@ -29,7 +29,7 @@ import org.apache.hadoop.fs.Path
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.internal.Logging
import org.apache.spark.sql._
-import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable}
+import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogUtils}
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat
@@ -597,6 +597,7 @@ object DataSource {
def buildStorageFormatFromOptions(options: Map[String, String]): CatalogStorageFormat = {
val path = CaseInsensitiveMap(options).get("path")
val optionsWithoutPath = options.filterKeys(_.toLowerCase != "path")
- CatalogStorageFormat.empty.copy(locationUri = path, properties = optionsWithoutPath)
+ CatalogStorageFormat.empty.copy(
+ locationUri = path.map(CatalogUtils.stringToURI(_)), properties = optionsWithoutPath)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index f694a0d6d724..bddf5af23e06 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -21,13 +21,15 @@ import java.util.concurrent.Callable
import scala.collection.mutable.ArrayBuffer
+import org.apache.hadoop.fs.Path
+
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.{CatalystConf, CatalystTypeConverters, InternalRow, QualifiedTableName, TableIdentifier}
import org.apache.spark.sql.catalyst.CatalystTypeConverters.convertToScala
import org.apache.spark.sql.catalyst.analysis._
-import org.apache.spark.sql.catalyst.catalog.CatalogRelation
+import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogUtils}
import org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
@@ -220,7 +222,7 @@ class FindDataSourceTable(sparkSession: SparkSession) extends Rule[LogicalPlan]
val plan = cache.get(qualifiedTableName, new Callable[LogicalPlan]() {
override def call(): LogicalPlan = {
- val pathOption = table.storage.locationUri.map("path" -> _)
+ val pathOption = table.storage.locationUri.map("path" -> CatalogUtils.URIToString(_))
val dataSource =
DataSource(
sparkSession,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
index 950e5ca0d621..30a09a9ad337 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
@@ -341,7 +341,7 @@ object FileFormatWriter extends Logging {
Seq(Cast(c, StringType, Option(desc.timeZoneId))),
Seq(StringType))
val str = If(IsNull(c), Literal(ExternalCatalogUtils.DEFAULT_PARTITION_NAME), escaped)
- val partitionName = Literal(c.name + "=") :: str :: Nil
+ val partitionName = Literal(ExternalCatalogUtils.escapePathName(c.name) + "=") :: str :: Nil
if (i == 0) partitionName else Literal(Path.SEPARATOR) :: partitionName
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala
index b2ff68a833fe..a813829d50cb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala
@@ -42,8 +42,9 @@ case class InsertIntoDataSourceCommand(
val df = sparkSession.internalCreateDataFrame(data.queryExecution.toRdd, logicalRelation.schema)
relation.insert(df, overwrite)
- // Invalidate the cache.
- sparkSession.sharedState.cacheManager.invalidateCache(logicalRelation)
+ // Re-cache all cached plans (including this relation itself, if it's cached) that refer to this
+ // data source relation.
+ sparkSession.sharedState.cacheManager.recacheByPlan(sparkSession, logicalRelation)
Seq.empty[Row]
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
index 73e6abc6dad3..35ff924f27ce 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
@@ -17,12 +17,11 @@
package org.apache.spark.sql.execution.datasources.csv
-import java.io.InputStream
import java.nio.charset.{Charset, StandardCharsets}
-import com.univocity.parsers.csv.{CsvParser, CsvParserSettings}
+import com.univocity.parsers.csv.CsvParser
import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.fs.{FileStatus, Path}
+import org.apache.hadoop.fs.FileStatus
import org.apache.hadoop.io.{LongWritable, Text}
import org.apache.hadoop.mapred.TextInputFormat
import org.apache.hadoop.mapreduce.Job
@@ -133,20 +132,34 @@ object TextInputCSVDataSource extends CSVDataSource {
sparkSession: SparkSession,
inputPaths: Seq[FileStatus],
parsedOptions: CSVOptions): Option[StructType] = {
- val csv: Dataset[String] = createBaseDataset(sparkSession, inputPaths, parsedOptions)
- val firstLine: String = CSVUtils.filterCommentAndEmpty(csv, parsedOptions).first()
- val firstRow = new CsvParser(parsedOptions.asParserSettings).parseLine(firstLine)
- val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
- val header = makeSafeHeader(firstRow, caseSensitive, parsedOptions)
- val tokenRDD = csv.rdd.mapPartitions { iter =>
- val filteredLines = CSVUtils.filterCommentAndEmpty(iter, parsedOptions)
- val linesWithoutHeader =
- CSVUtils.filterHeaderLine(filteredLines, firstLine, parsedOptions)
- val parser = new CsvParser(parsedOptions.asParserSettings)
- linesWithoutHeader.map(parser.parseLine)
- }
+ val csv = createBaseDataset(sparkSession, inputPaths, parsedOptions)
+ val maybeFirstLine = CSVUtils.filterCommentAndEmpty(csv, parsedOptions).take(1).headOption
+ Some(inferFromDataset(sparkSession, csv, maybeFirstLine, parsedOptions))
+ }
- Some(CSVInferSchema.infer(tokenRDD, header, parsedOptions))
+ /**
+ * Infers the schema from a `Dataset` that stores CSV string records.
+ */
+ def inferFromDataset(
+ sparkSession: SparkSession,
+ csv: Dataset[String],
+ maybeFirstLine: Option[String],
+ parsedOptions: CSVOptions): StructType = maybeFirstLine match {
+ case Some(firstLine) =>
+ val firstRow = new CsvParser(parsedOptions.asParserSettings).parseLine(firstLine)
+ val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
+ val header = makeSafeHeader(firstRow, caseSensitive, parsedOptions)
+ val tokenRDD = csv.rdd.mapPartitions { iter =>
+ val filteredLines = CSVUtils.filterCommentAndEmpty(iter, parsedOptions)
+ val linesWithoutHeader =
+ CSVUtils.filterHeaderLine(filteredLines, firstLine, parsedOptions)
+ val parser = new CsvParser(parsedOptions.asParserSettings)
+ linesWithoutHeader.map(parser.parseLine)
+ }
+ CSVInferSchema.infer(tokenRDD, header, parsedOptions)
+ case None =>
+ // If the first line could not be read, just return the empty schema.
+ StructType(Nil)
}
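// Hedged usage sketch of inferFromDataset, now that CSVOptions is public; the CSVOptions
// arguments follow the primary constructor shown in this patch, and "UTC", "_corrupt_record",
// and the accessibility of TextInputCSVDataSource from the calling code are assumptions.
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import spark.implicits._
val lines  = Seq("a,b", "1,x", "2,y").toDS()
val opts   = new CSVOptions(CaseInsensitiveMap(Map("header" -> "true")), "UTC", "_corrupt_record")
val schema = TextInputCSVDataSource.inferFromDataset(spark, lines, lines.head(1).headOption, opts)
// schema: column "a" inferred as IntegerType, column "b" as StringType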
private def createBaseDataset(
@@ -190,28 +203,28 @@ object WholeFileCSVDataSource extends CSVDataSource {
sparkSession: SparkSession,
inputPaths: Seq[FileStatus],
parsedOptions: CSVOptions): Option[StructType] = {
- val csv: RDD[PortableDataStream] = createBaseRdd(sparkSession, inputPaths, parsedOptions)
- val maybeFirstRow: Option[Array[String]] = csv.flatMap { lines =>
+ val csv = createBaseRdd(sparkSession, inputPaths, parsedOptions)
+ csv.flatMap { lines =>
UnivocityParser.tokenizeStream(
CodecStreams.createInputStreamWithCloseResource(lines.getConfiguration, lines.getPath()),
- false,
+ shouldDropHeader = false,
new CsvParser(parsedOptions.asParserSettings))
- }.take(1).headOption
-
- if (maybeFirstRow.isDefined) {
- val firstRow = maybeFirstRow.get
- val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
- val header = makeSafeHeader(firstRow, caseSensitive, parsedOptions)
- val tokenRDD = csv.flatMap { lines =>
- UnivocityParser.tokenizeStream(
- CodecStreams.createInputStreamWithCloseResource(lines.getConfiguration, lines.getPath()),
- parsedOptions.headerFlag,
- new CsvParser(parsedOptions.asParserSettings))
- }
- Some(CSVInferSchema.infer(tokenRDD, header, parsedOptions))
- } else {
- // If the first row could not be read, just return the empty schema.
- Some(StructType(Nil))
+ }.take(1).headOption match {
+ case Some(firstRow) =>
+ val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
+ val header = makeSafeHeader(firstRow, caseSensitive, parsedOptions)
+ val tokenRDD = csv.flatMap { lines =>
+ UnivocityParser.tokenizeStream(
+ CodecStreams.createInputStreamWithCloseResource(
+ lines.getConfiguration,
+ lines.getPath()),
+ parsedOptions.headerFlag,
+ new CsvParser(parsedOptions.asParserSettings))
+ }
+ Some(CSVInferSchema.infer(tokenRDD, header, parsedOptions))
+ case None =>
+ // If the first row could not be read, just return the empty schema.
+ Some(StructType(Nil))
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
index 50503385ad6d..0b1e5dac2da6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
@@ -26,7 +26,7 @@ import org.apache.commons.lang3.time.FastDateFormat
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CompressionCodecs, ParseModes}
-private[csv] class CSVOptions(
+class CSVOptions(
@transient private val parameters: CaseInsensitiveMap[String],
defaultTimeZoneId: String,
defaultColumnNameOfCorruptRecord: String)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala
index 3b3b87e4354d..e42ea3fa391f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala
@@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
-private[csv] class UnivocityParser(
+class UnivocityParser(
schema: StructType,
requiredSchema: StructType,
private val options: CSVOptions) extends Logging {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
index 828949eddc8e..5313c2f3746a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
@@ -475,71 +475,6 @@ object ParquetFileFormat extends Logging {
}
}
- /**
- * Reconciles Hive Metastore case insensitivity issue and data type conflicts between Metastore
- * schema and Parquet schema.
- *
- * Hive doesn't retain case information, while Parquet is case sensitive. On the other hand, the
- * schema read from Parquet files may be incomplete (e.g. older versions of Parquet doesn't
- * distinguish binary and string). This method generates a correct schema by merging Metastore
- * schema data types and Parquet schema field names.
- */
- def mergeMetastoreParquetSchema(
- metastoreSchema: StructType,
- parquetSchema: StructType): StructType = {
- def schemaConflictMessage: String =
- s"""Converting Hive Metastore Parquet, but detected conflicting schemas. Metastore schema:
- |${metastoreSchema.prettyJson}
- |
- |Parquet schema:
- |${parquetSchema.prettyJson}
- """.stripMargin
-
- val mergedParquetSchema = mergeMissingNullableFields(metastoreSchema, parquetSchema)
-
- assert(metastoreSchema.size <= mergedParquetSchema.size, schemaConflictMessage)
-
- val ordinalMap = metastoreSchema.zipWithIndex.map {
- case (field, index) => field.name.toLowerCase -> index
- }.toMap
-
- val reorderedParquetSchema = mergedParquetSchema.sortBy(f =>
- ordinalMap.getOrElse(f.name.toLowerCase, metastoreSchema.size + 1))
-
- StructType(metastoreSchema.zip(reorderedParquetSchema).map {
- // Uses Parquet field names but retains Metastore data types.
- case (mSchema, pSchema) if mSchema.name.toLowerCase == pSchema.name.toLowerCase =>
- mSchema.copy(name = pSchema.name)
- case _ =>
- throw new SparkException(schemaConflictMessage)
- })
- }
-
- /**
- * Returns the original schema from the Parquet file with any missing nullable fields from the
- * Hive Metastore schema merged in.
- *
- * When constructing a DataFrame from a collection of structured data, the resulting object has
- * a schema corresponding to the union of the fields present in each element of the collection.
- * Spark SQL simply assigns a null value to any field that isn't present for a particular row.
- * In some cases, it is possible that a given table partition stored as a Parquet file doesn't
- * contain a particular nullable field in its schema despite that field being present in the
- * table schema obtained from the Hive Metastore. This method returns a schema representing the
- * Parquet file schema along with any additional nullable fields from the Metastore schema
- * merged in.
- */
- private[parquet] def mergeMissingNullableFields(
- metastoreSchema: StructType,
- parquetSchema: StructType): StructType = {
- val fieldMap = metastoreSchema.map(f => f.name.toLowerCase -> f).toMap
- val missingFields = metastoreSchema
- .map(_.name.toLowerCase)
- .diff(parquetSchema.map(_.name.toLowerCase))
- .map(fieldMap(_))
- .filter(_.nullable)
- StructType(parquetSchema ++ missingFields)
- }
-
/**
* Reads Parquet footers in multi-threaded manner.
* If the config "spark.sql.files.ignoreCorruptFiles" is set to true, we will ignore the corrupted
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
index 4d781b96abac..8b598cc60e77 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
@@ -66,7 +66,8 @@ class ResolveSQLOnFile(sparkSession: SparkSession) extends Rule[LogicalPlan] {
* Preprocess [[CreateTable]], to do some normalization and checking.
*/
case class PreprocessTableCreation(sparkSession: SparkSession) extends Rule[LogicalPlan] {
- private val catalog = sparkSession.sessionState.catalog
+ // catalog is a def and not a val/lazy val as the latter would introduce a circular reference
+ private def catalog = sparkSession.sessionState.catalog
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
// When we CREATE TABLE without specifying the table schema, we should fail the query if
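
The new comment above explains that `catalog` must be a `def` because a `val` or `lazy val` would create a circular reference through the session state while it is still being constructed. A standalone sketch of the initialization-order problem it guards against, using hypothetical `Session`/`State` classes rather than Spark's:

class Catalog

class State(session: Session) {
  // With `val catalog = session.catalog` this would capture null, because State
  // is built while Session is still initializing its own fields.
  def catalog: Catalog = session.catalog
}

class Session {
  val state = new State(this)        // constructed before `catalog` below is assigned
  val catalog: Catalog = new Catalog
}

object DefVsValDemo extends App {
  val s = new Session
  assert(s.state.catalog != null)    // safe: the def defers the lookup until after construction
}
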
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala
index 125a4930c652..f06544ea8ed0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala
@@ -46,7 +46,7 @@ case class ShuffleExchange(
override def nodeName: String = {
val extraInfo = coordinator match {
case Some(exchangeCoordinator) =>
- s"(coordinator id: ${System.identityHashCode(coordinator)})"
+ s"(coordinator id: ${System.identityHashCode(exchangeCoordinator)})"
case None => ""
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
index 199ba5ce6969..fdd1bcc94be2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
@@ -28,11 +28,13 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.objects.Invoke
+import org.apache.spark.sql.catalyst.plans.logical.FunctionUtils
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.plans.logical.LogicalKeyedState
import org.apache.spark.sql.execution.streaming.KeyedStateImpl
-import org.apache.spark.sql.types.{DataType, ObjectType, StructType}
+import org.apache.spark.sql.types._
+import org.apache.spark.util.Utils
/**
@@ -219,7 +221,7 @@ case class MapElementsExec(
override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
val (funcClass, methodName) = func match {
case m: MapFunction[_, _] => classOf[MapFunction[_, _]] -> "call"
- case _ => classOf[Any => Any] -> "apply"
+ case _ => FunctionUtils.getFunctionOneName(outputObjAttr.dataType, child.output(0).dataType)
}
val funcObj = Literal.create(func, ObjectType(funcClass))
val callFunc = Invoke(funcObj, methodName, outputObjAttr.dataType, child.output)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
index 46fd54e5c742..fcd84705f7e8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
@@ -112,6 +112,8 @@ object EvaluatePython {
case (c: Int, DateType) => c
case (c: Long, TimestampType) => c
+ // Py4J serializes values between MIN_INT and MAX_INT as Ints, not Longs
+ case (c: Int, TimestampType) => c.toLong
case (c, StringType) => UTF8String.fromString(c.toString)
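
The added case above widens Py4J integers to Long before they are interpreted as timestamps, because Py4J serializes values that fit in 32 bits as Java ints. A minimal sketch of the widening pattern (not the Spark converter itself):

object WidenToTimestamp {
  def toMicros(value: Any): Long = value match {
    case l: Long => l          // value arrived as a Java long
    case i: Int  => i.toLong   // small values are serialized as ints by Py4J
    case other   => throw new IllegalArgumentException(s"Unexpected timestamp value: $other")
  }

  def main(args: Array[String]): Unit = {
    assert(toMicros(42) == 42L)
    assert(toMicros(1490000000000000L) == 1490000000000000L)
  }
}
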
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/EventTimeWatermarkExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/EventTimeWatermarkExec.scala
index 5a9a99e11188..25cf609fc336 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/EventTimeWatermarkExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/EventTimeWatermarkExec.scala
@@ -84,10 +84,7 @@ case class EventTimeWatermarkExec(
child: SparkPlan) extends SparkPlan {
val eventTimeStats = new EventTimeStatsAccum()
- val delayMs = {
- val millisPerMonth = CalendarInterval.MICROS_PER_DAY / 1000 * 31
- delay.milliseconds + delay.months * millisPerMonth
- }
+ val delayMs = EventTimeWatermark.getDelayMs(delay)
sparkContext.register(eventTimeStats)
@@ -105,10 +102,16 @@ case class EventTimeWatermarkExec(
override val output: Seq[Attribute] = child.output.map { a =>
if (a semanticEquals eventTime) {
val updatedMetadata = new MetadataBuilder()
- .withMetadata(a.metadata)
- .putLong(EventTimeWatermark.delayKey, delayMs)
- .build()
-
+ .withMetadata(a.metadata)
+ .putLong(EventTimeWatermark.delayKey, delayMs)
+ .build()
+ a.withMetadata(updatedMetadata)
+ } else if (a.metadata.contains(EventTimeWatermark.delayKey)) {
+ // Remove existing watermark
+ val updatedMetadata = new MetadataBuilder()
+ .withMetadata(a.metadata)
+ .remove(EventTimeWatermark.delayKey)
+ .build()
a.withMetadata(updatedMetadata)
} else {
a
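
The second branch above strips the delay metadata from any column other than the current event-time column, so only the most recent watermark is carried downstream. A usage sketch of that behavior, assuming the `rate` source is available and that the second `withWatermark` call supersedes the first:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col

object WatermarkReplacementDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("watermark-demo").getOrCreate()
    val events = spark.readStream.format("rate").load()
      .withColumn("clickTime", col("timestamp"))
      .withColumn("impressionTime", col("timestamp"))

    // After the second call only `impressionTime` carries the watermark delay metadata.
    val watermarked = events
      .withWatermark("clickTime", "10 minutes")
      .withWatermark("impressionTime", "20 minutes")

    watermarked.printSchema()
    spark.stop()
  }
}
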
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamOptions.scala
index 2f802d782f5a..d54ed44b43bf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamOptions.scala
@@ -38,7 +38,10 @@ class FileStreamOptions(parameters: CaseInsensitiveMap[String]) extends Logging
}
/**
- * Maximum age of a file that can be found in this directory, before it is deleted.
+ * Maximum age of a file that can be found in this directory, before it is ignored. For the
+ * first batch all files will be considered valid. If `latestFirst` is set to `true` and
+ * `maxFilesPerTrigger` is set, then this parameter will be ignored, because otherwise old files
+ * that are still valid and should be processed could be missed. Please refer to SPARK-19813 for details.
*
* The max age is specified with respect to the timestamp of the latest file, and not the
 * timestamp of the current system. This means that if the last file has timestamp 1000, and the
@@ -58,13 +61,29 @@ class FileStreamOptions(parameters: CaseInsensitiveMap[String]) extends Logging
* Whether to scan latest files first. If it's true, when the source finds unprocessed files in a
* trigger, it will first process the latest files.
*/
- val latestFirst: Boolean = parameters.get("latestFirst").map { str =>
- try {
- str.toBoolean
- } catch {
- case _: IllegalArgumentException =>
- throw new IllegalArgumentException(
- s"Invalid value '$str' for option 'latestFirst', must be 'true' or 'false'")
- }
- }.getOrElse(false)
+ val latestFirst: Boolean = withBooleanParameter("latestFirst", false)
+
+ /**
+ * Whether to check new files based only on the filename instead of the full path.
+ *
+ * With this set to `true`, the following files would be considered as the same file, because
+ * their filenames, "dataset.txt", are the same:
+ * - "file:///dataset.txt"
+ * - "s3://a/dataset.txt"
+ * - "s3n://a/b/dataset.txt"
+ * - "s3a://a/b/c/dataset.txt"
+ */
+ val fileNameOnly: Boolean = withBooleanParameter("fileNameOnly", false)
+
+ private def withBooleanParameter(name: String, default: Boolean) = {
+ parameters.get(name).map { str =>
+ try {
+ str.toBoolean
+ } catch {
+ case _: IllegalArgumentException =>
+ throw new IllegalArgumentException(
+ s"Invalid value '$str' for option '$name', must be 'true' or 'false'")
+ }
+ }.getOrElse(default)
+ }
}
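
A usage sketch of the options handled above, assuming a directory of text files at `in/`; `fileNameOnly` is the option introduced in this hunk, and combining `latestFirst` with `maxFilesPerTrigger` causes `maxFileAge` to be ignored as described earlier:

import org.apache.spark.sql.SparkSession

object FileStreamOptionsDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("file-stream-options").getOrCreate()
    val stream = spark.readStream
      .format("text")
      .option("maxFileAge", "7d")           // ignored here because of the two options below
      .option("latestFirst", "true")
      .option("maxFilesPerTrigger", "100")
      .option("fileNameOnly", "true")       // deduplicate files by name only
      .load("in/")
    stream.printSchema()
    spark.stop()
  }
}
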
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala
index 6a7263ca45d8..411a15ffceb6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.execution.streaming
+import java.net.URI
+
import scala.collection.JavaConverters._
import org.apache.hadoop.fs.{FileStatus, Path}
@@ -66,23 +68,36 @@ class FileStreamSource(
private val fileSortOrder = if (sourceOptions.latestFirst) {
logWarning(
- """'latestFirst' is true. New files will be processed first.
- |It may affect the watermark value""".stripMargin)
+ """'latestFirst' is true. New files will be processed first, which may affect the watermark
+ |value. In addition, 'maxFileAge' will be ignored.""".stripMargin)
implicitly[Ordering[Long]].reverse
} else {
implicitly[Ordering[Long]]
}
+ private val maxFileAgeMs: Long = if (sourceOptions.latestFirst && maxFilesPerBatch.isDefined) {
+ Long.MaxValue
+ } else {
+ sourceOptions.maxFileAgeMs
+ }
+
+ private val fileNameOnly = sourceOptions.fileNameOnly
+ if (fileNameOnly) {
+ logWarning("'fileNameOnly' is enabled. Make sure your file names are unique (e.g. using " +
+ "UUID), otherwise, files with the same name but under different paths will be considered " +
+ "the same and causes data lost.")
+ }
+
/** A mapping from a file that we have processed to some timestamp it was last modified. */
// Visible for testing and debugging in production.
- val seenFiles = new SeenFilesMap(sourceOptions.maxFileAgeMs)
+ val seenFiles = new SeenFilesMap(maxFileAgeMs, fileNameOnly)
metadataLog.allFiles().foreach { entry =>
seenFiles.add(entry.path, entry.timestamp)
}
seenFiles.purge()
- logInfo(s"maxFilesPerBatch = $maxFilesPerBatch, maxFileAge = ${sourceOptions.maxFileAgeMs}")
+ logInfo(s"maxFilesPerBatch = $maxFilesPerBatch, maxFileAge = $maxFileAgeMs")
/**
* Returns the maximum offset that can be retrieved from the source.
@@ -262,7 +277,7 @@ object FileStreamSource {
* To prevent the hash map from growing indefinitely, a purge function is available to
* remove files "maxAgeMs" older than the latest file.
*/
- class SeenFilesMap(maxAgeMs: Long) {
+ class SeenFilesMap(maxAgeMs: Long, fileNameOnly: Boolean) {
require(maxAgeMs >= 0)
/** Mapping from file to its timestamp. */
@@ -274,9 +289,13 @@ object FileStreamSource {
/** Timestamp for the last purge operation. */
private var lastPurgeTimestamp: Timestamp = 0L
+ @inline private def stripPathIfNecessary(path: String) = {
+ if (fileNameOnly) new Path(new URI(path)).getName else path
+ }
+
/** Add a new file to the map. */
def add(path: String, timestamp: Timestamp): Unit = {
- map.put(path, timestamp)
+ map.put(stripPathIfNecessary(path), timestamp)
if (timestamp > latestTimestamp) {
latestTimestamp = timestamp
}
@@ -289,7 +308,7 @@ object FileStreamSource {
def isNewFile(path: String, timestamp: Timestamp): Boolean = {
// Note that we are testing against lastPurgeTimestamp here so we'd never miss a file that
// is older than (latestTimestamp - maxAgeMs) but has not been purged yet.
- timestamp >= lastPurgeTimestamp && !map.containsKey(path)
+ timestamp >= lastPurgeTimestamp && !map.containsKey(stripPathIfNecessary(path))
}
/** Removes aged entries and returns the number of files removed. */
@@ -308,9 +327,5 @@ object FileStreamSource {
}
def size: Int = map.size()
-
- def allEntries: Seq[(String, Timestamp)] = {
- map.asScala.toSeq
- }
}
}
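
A standalone analog of the map above (not Spark's `SeenFilesMap` itself), showing how keying by file name makes differently-pathed copies of `dataset.txt` collide when `fileNameOnly` is enabled:

import java.net.URI
import scala.collection.mutable

class SeenFiles(fileNameOnly: Boolean) {
  private val seen = mutable.Map.empty[String, Long]

  // Strip the path down to the file name when fileNameOnly is set.
  private def key(path: String): String =
    if (fileNameOnly) new URI(path).getPath.split('/').last else path

  def add(path: String, timestamp: Long): Unit = seen(key(path)) = timestamp
  def isNewFile(path: String): Boolean = !seen.contains(key(path))
}

object SeenFilesDemo extends App {
  val m = new SeenFiles(fileNameOnly = true)
  m.add("file:///dataset.txt", 1L)
  assert(!m.isNewFile("s3a://a/b/c/dataset.txt")) // same name, treated as already seen
}
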
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
index ffdcd9b19d05..610ce5e1ebf5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
@@ -103,11 +103,11 @@ class IncrementalExecution(
child,
Some(stateId),
Some(currentEventTimeWatermark))
- case MapGroupsWithStateExec(
+ case FlatMapGroupsWithStateExec(
f, kDeser, vDeser, group, data, output, None, stateDeser, stateSer, child) =>
val stateId =
OperatorStateId(checkpointLocation, operatorId.getAndIncrement(), currentBatchId)
- MapGroupsWithStateExec(
+ FlatMapGroupsWithStateExec(
f, kDeser, vDeser, group, data, output, Some(stateId), stateDeser, stateSer, child)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
index 70912d13ae45..529263805c0a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
@@ -361,6 +361,13 @@ class StreamExecution(
}
}
} finally {
+ awaitBatchLock.lock()
+ try {
+ // Wake up any threads that are waiting for the stream to progress.
+ awaitBatchLockCondition.signalAll()
+ } finally {
+ awaitBatchLock.unlock()
+ }
terminationLatch.countDown()
}
}
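
A minimal analog of the pattern added above, using a hypothetical run loop rather than StreamExecution: signaling all waiters in the `finally` block guarantees that threads blocked on the condition are released even when the stream terminates with an error:

import java.util.concurrent.locks.ReentrantLock

object SignalOnExitDemo {
  private val lock = new ReentrantLock()
  private val progressed = lock.newCondition()

  def runLoop(body: => Unit): Unit = {
    try {
      body
    } finally {
      lock.lock()
      try {
        progressed.signalAll() // release any thread blocked on await(), even after a failure
      } finally {
        lock.unlock()
      }
    }
  }

  def main(args: Array[String]): Unit = {
    runLoop(println("one batch"))
  }
}
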
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
index d92529748b6a..c3075a3eacaa 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
@@ -68,7 +68,7 @@ trait StateStoreWriter extends StatefulOperator {
}
/** An operator that supports watermark. */
-trait WatermarkSupport extends SparkPlan {
+trait WatermarkSupport extends UnaryExecNode {
/** The keys that may have a watermark attribute. */
def keyExpressions: Seq[Attribute]
@@ -76,8 +76,8 @@ trait WatermarkSupport extends SparkPlan {
/** The watermark value. */
def eventTimeWatermark: Option[Long]
- /** Generate a predicate that matches data older than the watermark */
- lazy val watermarkPredicate: Option[Predicate] = {
+ /** Generate an expression that matches data older than the watermark */
+ lazy val watermarkExpression: Option[Expression] = {
val optionalWatermarkAttribute =
keyExpressions.find(_.metadata.contains(EventTimeWatermark.delayKey))
@@ -96,9 +96,19 @@ trait WatermarkSupport extends SparkPlan {
}
logInfo(s"Filtering state store on: $evictionExpression")
- newPredicate(evictionExpression, keyExpressions)
+ evictionExpression
}
}
+
+ /** Generate a predicate based on keys that matches data older than the watermark */
+ lazy val watermarkPredicateForKeys: Option[Predicate] =
+ watermarkExpression.map(newPredicate(_, keyExpressions))
+
+ /**
+ * Generate a predicate based on the child output that matches data older than the watermark.
+ */
+ lazy val watermarkPredicate: Option[Predicate] =
+ watermarkExpression.map(newPredicate(_, child.output))
}
/**
@@ -192,7 +202,7 @@ case class StateStoreSaveExec(
}
// Assumption: Append mode can be done only when watermark has been specified
- store.remove(watermarkPredicate.get.eval _)
+ store.remove(watermarkPredicateForKeys.get.eval _)
store.commit()
numTotalStateRows += store.numKeys()
@@ -215,7 +225,9 @@ case class StateStoreSaveExec(
override def hasNext: Boolean = {
if (!baseIterator.hasNext) {
// Remove old aggregates if watermark specified
- if (watermarkPredicate.nonEmpty) store.remove(watermarkPredicate.get.eval _)
+ if (watermarkPredicateForKeys.nonEmpty) {
+ store.remove(watermarkPredicateForKeys.get.eval _)
+ }
store.commit()
numTotalStateRows += store.numKeys()
false
@@ -245,8 +257,8 @@ case class StateStoreSaveExec(
}
-/** Physical operator for executing streaming mapGroupsWithState. */
-case class MapGroupsWithStateExec(
+/** Physical operator for executing streaming flatMapGroupsWithState. */
+case class FlatMapGroupsWithStateExec(
func: (Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any],
keyDeserializer: Expression,
valueDeserializer: Expression,
@@ -361,7 +373,7 @@ case class StreamingDeduplicateExec(
val numUpdatedStateRows = longMetric("numUpdatedStateRows")
val baseIterator = watermarkPredicate match {
- case Some(predicate) => iter.filter((row: InternalRow) => !predicate.eval(row))
+ case Some(predicate) => iter.filter(row => !predicate.eval(row))
case None => iter
}
@@ -381,7 +393,7 @@ case class StreamingDeduplicateExec(
}
CompletionIterator[InternalRow, Iterator[InternalRow]](result, {
- watermarkPredicate.foreach(f => store.remove(f.eval _))
+ watermarkPredicateForKeys.foreach(f => store.remove(f.eval _))
store.commit()
numTotalStateRows += store.numKeys()
})
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 2247010ac3f3..201f726db3fa 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -2973,7 +2973,22 @@ object functions {
* @group collection_funcs
* @since 2.1.0
*/
- def from_json(e: Column, schema: StructType, options: Map[String, String]): Column = withExpr {
+ def from_json(e: Column, schema: StructType, options: Map[String, String]): Column =
+ from_json(e, schema.asInstanceOf[DataType], options)
+
+ /**
+ * (Scala-specific) Parses a column containing a JSON string into a `StructType` or `ArrayType`
+ * with the specified schema. Returns `null`, in the case of an unparseable string.
+ *
+ * @param e a string column containing JSON data.
+ * @param schema the schema to use when parsing the json string
+ * @param options options to control how the json is parsed. Accepts the same options as the
+ * json data source.
+ *
+ * @group collection_funcs
+ * @since 2.2.0
+ */
+ def from_json(e: Column, schema: DataType, options: Map[String, String]): Column = withExpr {
JsonToStruct(schema, options, e.expr)
}
@@ -2992,6 +3007,21 @@ object functions {
def from_json(e: Column, schema: StructType, options: java.util.Map[String, String]): Column =
from_json(e, schema, options.asScala.toMap)
+ /**
+ * (Java-specific) Parses a column containing a JSON string into a `StructType` or `ArrayType`
+ * with the specified schema. Returns `null`, in the case of an unparseable string.
+ *
+ * @param e a string column containing JSON data.
+ * @param schema the schema to use when parsing the json string
+ * @param options options to control how the json is parsed. Accepts the same options as the
+ * json data source.
+ *
+ * @group collection_funcs
+ * @since 2.2.0
+ */
+ def from_json(e: Column, schema: DataType, options: java.util.Map[String, String]): Column =
+ from_json(e, schema, options.asScala.toMap)
+
/**
* Parses a column containing a JSON string into a `StructType` with the specified schema.
* Returns `null`, in the case of an unparseable string.
@@ -3006,8 +3036,21 @@ object functions {
from_json(e, schema, Map.empty[String, String])
/**
- * Parses a column containing a JSON string into a `StructType` with the specified schema.
- * Returns `null`, in the case of an unparseable string.
+ * Parses a column containing a JSON string into a `StructType` or `ArrayType`
+ * with the specified schema. Returns `null`, in the case of an unparseable string.
+ *
+ * @param e a string column containing JSON data.
+ * @param schema the schema to use when parsing the json string
+ *
+ * @group collection_funcs
+ * @since 2.2.0
+ */
+ def from_json(e: Column, schema: DataType): Column =
+ from_json(e, schema, Map.empty[String, String])
+
+ /**
+ * Parses a column containing a JSON string into a `StructType` or `ArrayType`
+ * with the specified schema. Returns `null`, in the case of an unparseable string.
*
* @param e a string column containing JSON data.
 * @param schema the schema to use when parsing the json string, itself provided as a json string
@@ -3016,8 +3059,7 @@ object functions {
* @since 2.1.0
*/
def from_json(e: Column, schema: String, options: java.util.Map[String, String]): Column =
- from_json(e, DataType.fromJson(schema).asInstanceOf[StructType], options)
-
+ from_json(e, DataType.fromJson(schema), options)
/**
* (Scala-specific) Converts a column containing a `StructType` into a JSON string with the
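
A usage sketch of the new `DataType`-typed overloads, assuming an existing local SparkSession: an `ArrayType` schema can now be passed directly, which the `StructType`-only signatures could not express:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.from_json
import org.apache.spark.sql.types._

object FromJsonArrayDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("from_json-demo").getOrCreate()
    import spark.implicits._

    val schema = ArrayType(new StructType().add("a", IntegerType))
    val df = Seq("""[{"a": 1}, {"a": 2}]""").toDF("json")
    df.select(from_json($"json", schema)).show(truncate = false)

    spark.stop()
  }
}
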
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala
index 3d9f41832bc7..53374859f13f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql.internal
import scala.reflect.runtime.universe.TypeTag
+import org.apache.hadoop.fs.Path
+
import org.apache.spark.annotation.Experimental
import org.apache.spark.sql._
import org.apache.spark.sql.catalog.{Catalog, Column, Database, Function, Table}
@@ -77,7 +79,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
new Database(
name = metadata.name,
description = metadata.description,
- locationUri = metadata.locationUri)
+ locationUri = CatalogUtils.URIToString(metadata.locationUri))
}
/**
@@ -341,8 +343,8 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
* @since 2.0.0
*/
override def dropTempView(viewName: String): Boolean = {
- sparkSession.sessionState.catalog.getTempView(viewName).exists { tempView =>
- sparkSession.sharedState.cacheManager.uncacheQuery(Dataset.ofRows(sparkSession, tempView))
+ sparkSession.sessionState.catalog.getTempView(viewName).exists { viewDef =>
+ sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession, viewDef, blocking = true)
sessionCatalog.dropTempView(viewName)
}
}
@@ -357,7 +359,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
*/
override def dropGlobalTempView(viewName: String): Boolean = {
sparkSession.sessionState.catalog.getGlobalTempView(viewName).exists { viewDef =>
- sparkSession.sharedState.cacheManager.uncacheQuery(Dataset.ofRows(sparkSession, viewDef))
+ sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession, viewDef, blocking = true)
sessionCatalog.dropGlobalTempView(viewName)
}
}
@@ -402,7 +404,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
* @since 2.0.0
*/
override def uncacheTable(tableName: String): Unit = {
- sparkSession.sharedState.cacheManager.uncacheQuery(query = sparkSession.table(tableName))
+ sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession.table(tableName))
}
/**
@@ -440,17 +442,12 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
// If this table is cached as an InMemoryRelation, drop the original
// cached version and make the new version cached lazily.
- val logicalPlan = sparkSession.table(tableIdent).queryExecution.analyzed
- // Use lookupCachedData directly since RefreshTable also takes databaseName.
- val isCached = sparkSession.sharedState.cacheManager.lookupCachedData(logicalPlan).nonEmpty
- if (isCached) {
- // Create a data frame to represent the table.
- // TODO: Use uncacheTable once it supports database name.
- val df = Dataset.ofRows(sparkSession, logicalPlan)
+ val table = sparkSession.table(tableIdent)
+ if (isCached(table)) {
// Uncache the logicalPlan.
- sparkSession.sharedState.cacheManager.uncacheQuery(df, blocking = true)
+ sparkSession.sharedState.cacheManager.uncacheQuery(table, blocking = true)
// Cache it again.
- sparkSession.sharedState.cacheManager.cacheQuery(df, Some(tableIdent.table))
+ sparkSession.sharedState.cacheManager.cacheQuery(table, Some(tableIdent.table))
}
}
@@ -462,7 +459,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
* @since 2.0.0
*/
override def refreshByPath(resourcePath: String): Unit = {
- sparkSession.sharedState.cacheManager.invalidateCachedPath(sparkSession, resourcePath)
+ sparkSession.sharedState.cacheManager.recacheByPath(sparkSession, resourcePath)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 461dfe3a66e1..8e3f567b7dd9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -296,6 +296,25 @@ object SQLConf {
.longConf
.createWithDefault(250 * 1024 * 1024)
+ object HiveCaseSensitiveInferenceMode extends Enumeration {
+ val INFER_AND_SAVE, INFER_ONLY, NEVER_INFER = Value
+ }
+
+ val HIVE_CASE_SENSITIVE_INFERENCE = buildConf("spark.sql.hive.caseSensitiveInferenceMode")
+ .doc("Sets the action to take when a case-sensitive schema cannot be read from a Hive " +
+ "table's properties. Although Spark SQL itself is not case-sensitive, Hive compatible file " +
+ "formats such as Parquet are. Spark SQL must use a case-preserving schema when querying " +
+ "any table backed by files containing case-sensitive field names or queries may not return " +
+ "accurate results. Valid options include INFER_AND_SAVE (the default mode-- infer the " +
+ "case-sensitive schema from the underlying data files and write it back to the table " +
+ "properties), INFER_ONLY (infer the schema but don't attempt to write it to the table " +
+ "properties) and NEVER_INFER (fallback to using the case-insensitive metastore schema " +
+ "instead of inferring).")
+ .stringConf
+ .transform(_.toUpperCase())
+ .checkValues(HiveCaseSensitiveInferenceMode.values.map(_.toString))
+ .createWithDefault(HiveCaseSensitiveInferenceMode.INFER_AND_SAVE.toString)
+
val OPTIMIZER_METADATA_ONLY = buildConf("spark.sql.optimizer.metadataOnly")
.doc("When true, enable the metadata-only query optimization that use the table's metadata " +
"to produce the partition columns instead of table scans. It applies when all the columns " +
@@ -668,6 +687,18 @@ object SQLConf {
.booleanConf
.createWithDefault(false)
+ val JOIN_REORDER_ENABLED =
+ buildConf("spark.sql.cbo.joinReorder.enabled")
+ .doc("Enables join reorder in CBO.")
+ .booleanConf
+ .createWithDefault(false)
+
+ val JOIN_REORDER_DP_THRESHOLD =
+ buildConf("spark.sql.cbo.joinReorder.dp.threshold")
+ .doc("The maximum number of joined nodes allowed in the dynamic programming algorithm.")
+ .intConf
+ .createWithDefault(12)
+
val SESSION_LOCAL_TIMEZONE =
buildConf("spark.sql.session.timeZone")
.doc("""The ID of session local timezone, e.g. "GMT", "America/Los_Angeles", etc.""")
@@ -677,6 +708,10 @@ object SQLConf {
object Deprecated {
val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks"
}
+
+ object Replaced {
+ val MAPREDUCE_JOB_REDUCES = "mapreduce.job.reduces"
+ }
}
/**
@@ -776,6 +811,9 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging {
def filesourcePartitionFileCacheSize: Long = getConf(HIVE_FILESOURCE_PARTITION_FILE_CACHE_SIZE)
+ def caseSensitiveInferenceMode: HiveCaseSensitiveInferenceMode.Value =
+ HiveCaseSensitiveInferenceMode.withName(getConf(HIVE_CASE_SENSITIVE_INFERENCE))
+
def gatherFastStats: Boolean = getConf(GATHER_FASTSTAT)
def optimizerMetadataOnly: Boolean = getConf(OPTIMIZER_METADATA_ONLY)
@@ -881,6 +919,10 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging {
override def cboEnabled: Boolean = getConf(SQLConf.CBO_ENABLED)
+ override def joinReorderEnabled: Boolean = getConf(SQLConf.JOIN_REORDER_ENABLED)
+
+ override def joinReorderDPThreshold: Int = getConf(SQLConf.JOIN_REORDER_DP_THRESHOLD)
+
/** ********************** SQLConf functionality methods ************ */
/** Set Spark SQL configuration properties. */
@@ -999,6 +1041,14 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging {
def clear(): Unit = {
settings.clear()
}
+
+ override def clone(): SQLConf = {
+ val result = new SQLConf
+ getAllConfs.foreach {
+ case(k, v) => if (v ne null) result.setConfString(k, v)
+ }
+ result
+ }
}
/**
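
A usage sketch for the settings added above, assuming they can be toggled at runtime on an existing session:

import org.apache.spark.sql.SparkSession

object NewConfDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("conf-demo").getOrCreate()

    spark.conf.set("spark.sql.cbo.enabled", "true")
    spark.conf.set("spark.sql.cbo.joinReorder.enabled", "true")      // off by default
    spark.conf.set("spark.sql.cbo.joinReorder.dp.threshold", "10")   // default is 12
    spark.conf.set("spark.sql.hive.caseSensitiveInferenceMode", "INFER_ONLY")

    spark.stop()
  }
}
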
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
index 69085605113e..ce80604bd365 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
@@ -22,38 +22,49 @@ import java.io.File
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
+import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql._
-import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.{Analyzer, FunctionRegistry}
import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.optimizer.Optimizer
import org.apache.spark.sql.catalyst.parser.ParserInterface
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution._
-import org.apache.spark.sql.execution.command.AnalyzeTableCommand
import org.apache.spark.sql.execution.datasources._
-import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryManager}
+import org.apache.spark.sql.streaming.StreamingQueryManager
import org.apache.spark.sql.util.ExecutionListenerManager
/**
* A class that holds all session-specific state in a given [[SparkSession]].
+ * @param sparkContext The [[SparkContext]].
+ * @param sharedState The shared state.
+ * @param conf SQL-specific key-value configurations.
+ * @param experimentalMethods The experimental methods.
+ * @param functionRegistry Internal catalog for managing functions registered by the user.
+ * @param catalog Internal catalog for managing table and database states.
+ * @param sqlParser Parser that extracts expressions, plans, table identifiers etc. from SQL texts.
+ * @param analyzer Logical query plan analyzer for resolving unresolved attributes and relations.
+ * @param streamingQueryManager Interface to start and stop
+ * [[org.apache.spark.sql.streaming.StreamingQuery]]s.
+ * @param queryExecutionCreator Lambda to create a [[QueryExecution]] from a [[LogicalPlan]]
*/
-private[sql] class SessionState(sparkSession: SparkSession) {
+private[sql] class SessionState(
+ sparkContext: SparkContext,
+ sharedState: SharedState,
+ val conf: SQLConf,
+ val experimentalMethods: ExperimentalMethods,
+ val functionRegistry: FunctionRegistry,
+ val catalog: SessionCatalog,
+ val sqlParser: ParserInterface,
+ val analyzer: Analyzer,
+ val streamingQueryManager: StreamingQueryManager,
+ val queryExecutionCreator: LogicalPlan => QueryExecution) {
- // Note: These are all lazy vals because they depend on each other (e.g. conf) and we
- // want subclasses to override some of the fields. Otherwise, we would get a lot of NPEs.
-
- /**
- * SQL-specific key-value configurations.
- */
- lazy val conf: SQLConf = new SQLConf
-
- def newHadoopConf(): Configuration = {
- val hadoopConf = new Configuration(sparkSession.sparkContext.hadoopConfiguration)
- conf.getAllConfs.foreach { case (k, v) => if (v ne null) hadoopConf.set(k, v) }
- hadoopConf
- }
+ def newHadoopConf(): Configuration = SessionState.newHadoopConf(
+ sparkContext.hadoopConfiguration,
+ conf)
def newHadoopConfWithOptions(options: Map[String, String]): Configuration = {
val hadoopConf = newHadoopConf()
@@ -65,22 +76,15 @@ private[sql] class SessionState(sparkSession: SparkSession) {
hadoopConf
}
- lazy val experimentalMethods = new ExperimentalMethods
-
- /**
- * Internal catalog for managing functions registered by the user.
- */
- lazy val functionRegistry: FunctionRegistry = FunctionRegistry.builtin.copy()
-
/**
* A class for loading resources specified by a function.
*/
- lazy val functionResourceLoader: FunctionResourceLoader = {
+ val functionResourceLoader: FunctionResourceLoader = {
new FunctionResourceLoader {
override def loadResource(resource: FunctionResource): Unit = {
resource.resourceType match {
case JarResource => addJar(resource.uri)
- case FileResource => sparkSession.sparkContext.addFile(resource.uri)
+ case FileResource => sparkContext.addFile(resource.uri)
case ArchiveResource =>
throw new AnalysisException(
"Archive is not allowed to be loaded. If YARN mode is used, " +
@@ -90,93 +94,78 @@ private[sql] class SessionState(sparkSession: SparkSession) {
}
}
- /**
- * Internal catalog for managing table and database states.
- */
- lazy val catalog = new SessionCatalog(
- sparkSession.sharedState.externalCatalog,
- sparkSession.sharedState.globalTempViewManager,
- functionResourceLoader,
- functionRegistry,
- conf,
- newHadoopConf(),
- sqlParser)
-
/**
* Interface exposed to the user for registering user-defined functions.
* Note that the user-defined functions must be deterministic.
*/
- lazy val udf: UDFRegistration = new UDFRegistration(functionRegistry)
-
- /**
- * Logical query plan analyzer for resolving unresolved attributes and relations.
- */
- lazy val analyzer: Analyzer = {
- new Analyzer(catalog, conf) {
- override val extendedResolutionRules =
- new FindDataSourceTable(sparkSession) ::
- new ResolveSQLOnFile(sparkSession) :: Nil
-
- override val postHocResolutionRules =
- PreprocessTableCreation(sparkSession) ::
- PreprocessTableInsertion(conf) ::
- DataSourceAnalysis(conf) :: Nil
-
- override val extendedCheckRules = Seq(PreWriteCheck, HiveOnlyCheck)
- }
- }
+ val udf: UDFRegistration = new UDFRegistration(functionRegistry)
/**
* Logical query plan optimizer.
*/
- lazy val optimizer: Optimizer = new SparkOptimizer(catalog, conf, experimentalMethods)
-
- /**
- * Parser that extracts expressions, plans, table identifiers etc. from SQL texts.
- */
- lazy val sqlParser: ParserInterface = new SparkSqlParser(conf)
+ val optimizer: Optimizer = new SparkOptimizer(catalog, conf, experimentalMethods)
/**
* Planner that converts optimized logical plans to physical plans.
*/
def planner: SparkPlanner =
- new SparkPlanner(sparkSession.sparkContext, conf, experimentalMethods.extraStrategies)
+ new SparkPlanner(sparkContext, conf, experimentalMethods.extraStrategies)
/**
* An interface to register custom [[org.apache.spark.sql.util.QueryExecutionListener]]s
* that listen for execution metrics.
*/
- lazy val listenerManager: ExecutionListenerManager = new ExecutionListenerManager
+ val listenerManager: ExecutionListenerManager = new ExecutionListenerManager
/**
- * Interface to start and stop [[StreamingQuery]]s.
+ * Get an identical copy of the `SessionState` and associate it with the given `SparkSession`
*/
- lazy val streamingQueryManager: StreamingQueryManager = {
- new StreamingQueryManager(sparkSession)
- }
+ def clone(newSparkSession: SparkSession): SessionState = {
+ val sparkContext = newSparkSession.sparkContext
+ val confCopy = conf.clone()
+ val functionRegistryCopy = functionRegistry.clone()
+ val sqlParser: ParserInterface = new SparkSqlParser(confCopy)
+ val catalogCopy = catalog.newSessionCatalogWith(
+ confCopy,
+ SessionState.newHadoopConf(sparkContext.hadoopConfiguration, confCopy),
+ functionRegistryCopy,
+ sqlParser)
+ val queryExecutionCreator = (plan: LogicalPlan) => new QueryExecution(newSparkSession, plan)
- private val jarClassLoader: NonClosableMutableURLClassLoader =
- sparkSession.sharedState.jarClassLoader
+ SessionState.mergeSparkConf(confCopy, sparkContext.getConf)
- // Automatically extract all entries and put it in our SQLConf
- // We need to call it after all of vals have been initialized.
- sparkSession.sparkContext.getConf.getAll.foreach { case (k, v) =>
- conf.setConfString(k, v)
+ new SessionState(
+ sparkContext,
+ newSparkSession.sharedState,
+ confCopy,
+ experimentalMethods.clone(),
+ functionRegistryCopy,
+ catalogCopy,
+ sqlParser,
+ SessionState.createAnalyzer(newSparkSession, catalogCopy, confCopy),
+ new StreamingQueryManager(newSparkSession),
+ queryExecutionCreator)
}
// ------------------------------------------------------
// Helper methods, partially leftover from pre-2.0 days
// ------------------------------------------------------
- def executePlan(plan: LogicalPlan): QueryExecution = new QueryExecution(sparkSession, plan)
+ def executePlan(plan: LogicalPlan): QueryExecution = queryExecutionCreator(plan)
def refreshTable(tableName: String): Unit = {
catalog.refreshTable(sqlParser.parseTableIdentifier(tableName))
}
+ /**
+ * Add a jar path to [[SparkContext]] and the classloader.
+ *
+ * Note: this method does not seem to access any session state, but the subclass `HiveSessionState` needs
+ * to add the jar to its hive client for the current session. Hence, it still needs to be in
+ * [[SessionState]].
+ */
def addJar(path: String): Unit = {
- sparkSession.sparkContext.addJar(path)
-
+ sparkContext.addJar(path)
val uri = new Path(path).toUri
val jarURL = if (uri.getScheme == null) {
// `path` is a local file path without a URL scheme
@@ -185,15 +174,93 @@ private[sql] class SessionState(sparkSession: SparkSession) {
// `path` is a URL with a scheme
uri.toURL
}
- jarClassLoader.addURL(jarURL)
- Thread.currentThread().setContextClassLoader(jarClassLoader)
+ sharedState.jarClassLoader.addURL(jarURL)
+ Thread.currentThread().setContextClassLoader(sharedState.jarClassLoader)
+ }
+}
+
+
+private[sql] object SessionState {
+
+ def apply(sparkSession: SparkSession): SessionState = {
+ apply(sparkSession, new SQLConf)
+ }
+
+ def apply(sparkSession: SparkSession, sqlConf: SQLConf): SessionState = {
+ val sparkContext = sparkSession.sparkContext
+
+ // Automatically extract all entries and put them in our SQLConf
+ mergeSparkConf(sqlConf, sparkContext.getConf)
+
+ val functionRegistry = FunctionRegistry.builtin.clone()
+
+ val sqlParser: ParserInterface = new SparkSqlParser(sqlConf)
+
+ val catalog = new SessionCatalog(
+ sparkSession.sharedState.externalCatalog,
+ sparkSession.sharedState.globalTempViewManager,
+ functionRegistry,
+ sqlConf,
+ newHadoopConf(sparkContext.hadoopConfiguration, sqlConf),
+ sqlParser)
+
+ val analyzer: Analyzer = createAnalyzer(sparkSession, catalog, sqlConf)
+
+ val streamingQueryManager: StreamingQueryManager = new StreamingQueryManager(sparkSession)
+
+ val queryExecutionCreator = (plan: LogicalPlan) => new QueryExecution(sparkSession, plan)
+
+ val sessionState = new SessionState(
+ sparkContext,
+ sparkSession.sharedState,
+ sqlConf,
+ new ExperimentalMethods,
+ functionRegistry,
+ catalog,
+ sqlParser,
+ analyzer,
+ streamingQueryManager,
+ queryExecutionCreator)
+ // functionResourceLoader needs to access SessionState.addJar, so it cannot be created before
+ // creating SessionState. Setting `catalog.functionResourceLoader` here is safe since the caller
+ // cannot use SessionCatalog before we return SessionState.
+ catalog.functionResourceLoader = sessionState.functionResourceLoader
+ sessionState
+ }
+
+ def newHadoopConf(hadoopConf: Configuration, sqlConf: SQLConf): Configuration = {
+ val newHadoopConf = new Configuration(hadoopConf)
+ sqlConf.getAllConfs.foreach { case (k, v) => if (v ne null) newHadoopConf.set(k, v) }
+ newHadoopConf
+ }
+
+ /**
+ * Create a logical query plan `Analyzer` with rules specific to a non-Hive `SessionState`.
+ */
+ private def createAnalyzer(
+ sparkSession: SparkSession,
+ catalog: SessionCatalog,
+ sqlConf: SQLConf): Analyzer = {
+ new Analyzer(catalog, sqlConf) {
+ override val extendedResolutionRules: Seq[Rule[LogicalPlan]] =
+ new FindDataSourceTable(sparkSession) ::
+ new ResolveSQLOnFile(sparkSession) :: Nil
+
+ override val postHocResolutionRules: Seq[Rule[LogicalPlan]] =
+ PreprocessTableCreation(sparkSession) ::
+ PreprocessTableInsertion(sqlConf) ::
+ DataSourceAnalysis(sqlConf) :: Nil
+
+ override val extendedCheckRules = Seq(PreWriteCheck, HiveOnlyCheck)
+ }
}
/**
- * Analyzes the given table in the current database to generate statistics, which will be
- * used in query optimizations.
+ * Extract entries from `SparkConf` and put them in the `SQLConf`.
*/
- def analyze(tableIdent: TableIdentifier, noscan: Boolean = true): Unit = {
- AnalyzeTableCommand(tableIdent, noscan).run(sparkSession)
+ def mergeSparkConf(sqlConf: SQLConf, sparkConf: SparkConf): Unit = {
+ sparkConf.getAll.foreach { case (k, v) =>
+ sqlConf.setConfString(k, v)
+ }
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala
index bce84de45c3d..86129fa87fea 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala
@@ -21,6 +21,7 @@ import scala.reflect.ClassTag
import scala.util.control.NonFatal
import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.Path
import org.apache.spark.{SparkConf, SparkContext, SparkException}
import org.apache.spark.internal.Logging
@@ -95,7 +96,10 @@ private[sql] class SharedState(val sparkContext: SparkContext) extends Logging {
// Create the default database if it doesn't exist.
{
val defaultDbDefinition = CatalogDatabase(
- SessionCatalog.DEFAULT_DATABASE, "default database", warehousePath, Map())
+ SessionCatalog.DEFAULT_DATABASE,
+ "default database",
+ CatalogUtils.stringToURI(warehousePath),
+ Map())
// Initialize default database if it doesn't exist
if (!externalCatalog.databaseExists(SessionCatalog.DEFAULT_DATABASE)) {
// There may be another Spark application creating default database at the same time, here we
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
index 0f7a33723ccc..c8fda8cd8359 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
@@ -20,8 +20,8 @@ package org.apache.spark.sql.streaming
import scala.collection.JavaConverters._
import org.apache.spark.annotation.{Experimental, InterfaceStability}
-import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, ForeachWriter}
-import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
+import org.apache.spark.sql.{AnalysisException, Dataset, ForeachWriter}
+import org.apache.spark.sql.catalyst.streaming.InternalOutputModes
import org.apache.spark.sql.execution.command.DDLUtils
import org.apache.spark.sql.execution.datasources.DataSource
import org.apache.spark.sql.execution.streaming.{ForeachSink, MemoryPlan, MemorySink}
@@ -69,17 +69,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
* @since 2.0.0
*/
def outputMode(outputMode: String): DataStreamWriter[T] = {
- this.outputMode = outputMode.toLowerCase match {
- case "append" =>
- OutputMode.Append
- case "complete" =>
- OutputMode.Complete
- case "update" =>
- OutputMode.Update
- case _ =>
- throw new IllegalArgumentException(s"Unknown output mode $outputMode. " +
- "Accepted output modes are 'append', 'complete', 'update'")
- }
+ this.outputMode = InternalOutputModes(outputMode)
this
}
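
For callers the string-based API is unchanged; a usage sketch, assuming the `rate` source and console sink are available:

import org.apache.spark.sql.SparkSession

object OutputModeDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("output-mode-demo").getOrCreate()
    val counts = spark.readStream.format("rate").load().groupBy("value").count()

    val query = counts.writeStream
      .outputMode("update")      // "append", "complete" and "update" are accepted
      .format("console")
      .start()

    query.awaitTermination(5000)
    spark.stop()
  }
}
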
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index e3b0e37ccab0..d06e35bb44d0 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -23,6 +23,7 @@
import java.sql.Timestamp;
import java.util.*;
+import org.apache.spark.sql.streaming.OutputMode;
import scala.Tuple2;
import scala.Tuple3;
import scala.Tuple4;
@@ -205,6 +206,7 @@ public void testGroupBy() {
}
return Collections.singletonList(sb.toString()).iterator();
},
+ OutputMode.Append(),
Encoders.LONG(),
Encoders.STRING());
diff --git a/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql
new file mode 100644
index 000000000000..9308560451bf
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql
@@ -0,0 +1,8 @@
+-- to_json
+describe function to_json;
+describe function extended to_json;
+select to_json(named_struct('a', 1, 'b', 2));
+select to_json(named_struct('time', to_timestamp('2015-08-26', 'yyyy-MM-dd')), map('timestampFormat', 'dd/MM/yyyy'));
+-- Check if errors handled
+select to_json(named_struct('a', 1, 'b', 2), named_struct('mode', 'PERMISSIVE'));
+select to_json();
diff --git a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out
new file mode 100644
index 000000000000..d8aa4fb9fa78
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out
@@ -0,0 +1,63 @@
+-- Automatically generated by SQLQueryTestSuite
+-- Number of queries: 6
+
+
+-- !query 0
+describe function to_json
+-- !query 0 schema
+struct
+-- !query 0 output
+Class: org.apache.spark.sql.catalyst.expressions.StructToJson
+Function: to_json
+Usage: to_json(expr[, options]) - Returns a json string with a given struct value
+
+
+-- !query 1
+describe function extended to_json
+-- !query 1 schema
+struct
+-- !query 1 output
+Class: org.apache.spark.sql.catalyst.expressions.StructToJson
+Extended Usage:
+ Examples:
+ > SELECT to_json(named_struct('a', 1, 'b', 2));
+ {"a":1,"b":2}
+ > SELECT to_json(named_struct('time', to_timestamp('2015-08-26', 'yyyy-MM-dd')), map('timestampFormat', 'dd/MM/yyyy'));
+ {"time":"26/08/2015"}
+
+Function: to_json
+Usage: to_json(expr[, options]) - Returns a json string with a given struct value
+
+
+-- !query 2
+select to_json(named_struct('a', 1, 'b', 2))
+-- !query 2 schema
+struct
+-- !query 2 output
+{"a":1,"b":2}
+
+
+-- !query 3
+select to_json(named_struct('time', to_timestamp('2015-08-26', 'yyyy-MM-dd')), map('timestampFormat', 'dd/MM/yyyy'))
+-- !query 3 schema
+struct
+-- !query 3 output
+{"time":"26/08/2015"}
+
+
+-- !query 4
+select to_json(named_struct('a', 1, 'b', 2), named_struct('mode', 'PERMISSIVE'))
+-- !query 4 schema
+struct<>
+-- !query 4 output
+org.apache.spark.sql.AnalysisException
+Must use a map() function for options;; line 1 pos 7
+
+
+-- !query 5
+select to_json()
+-- !query 5 schema
+struct<>
+-- !query 5 output
+org.apache.spark.sql.AnalysisException
+Invalid number of arguments for function to_json; line 1 pos 7
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
index 2a0e088437fd..7a7d52b21427 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
@@ -24,15 +24,15 @@ import scala.language.postfixOps
import org.scalatest.concurrent.Eventually._
import org.apache.spark.CleanerListener
-import org.apache.spark.sql.catalyst.expressions.{Expression, SubqueryExpression}
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
import org.apache.spark.sql.execution.RDDScanExec
import org.apache.spark.sql.execution.columnar._
import org.apache.spark.sql.execution.exchange.ShuffleExchange
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils}
import org.apache.spark.storage.{RDDBlockId, StorageLevel}
-import org.apache.spark.util.AccumulatorContext
+import org.apache.spark.util.{AccumulatorContext, Utils}
private case class BigData(s: String)
@@ -65,7 +65,8 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
maybeBlock.nonEmpty
}
- private def getNumInMemoryRelations(plan: LogicalPlan): Int = {
+ private def getNumInMemoryRelations(ds: Dataset[_]): Int = {
+ val plan = ds.queryExecution.withCachedData
var sum = plan.collect { case _: InMemoryRelation => 1 }.sum
plan.transformAllExpressions {
case e: SubqueryExpression =>
@@ -187,7 +188,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
assertCached(spark.table("testData"))
assertResult(1, "InMemoryRelation not found, testData should have been cached") {
- getNumInMemoryRelations(spark.table("testData").queryExecution.withCachedData)
+ getNumInMemoryRelations(spark.table("testData"))
}
spark.catalog.cacheTable("testData")
@@ -580,21 +581,21 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
localRelation.createOrReplaceTempView("localRelation")
spark.catalog.cacheTable("localRelation")
- assert(getNumInMemoryRelations(localRelation.queryExecution.withCachedData) == 1)
+ assert(getNumInMemoryRelations(localRelation) == 1)
}
test("SPARK-19093 Caching in side subquery") {
withTempView("t1") {
Seq(1).toDF("c1").createOrReplaceTempView("t1")
spark.catalog.cacheTable("t1")
- val cachedPlan =
+ val ds =
sql(
"""
|SELECT * FROM t1
|WHERE
|NOT EXISTS (SELECT * FROM t1)
- """.stripMargin).queryExecution.optimizedPlan
- assert(getNumInMemoryRelations(cachedPlan) == 2)
+ """.stripMargin)
+ assert(getNumInMemoryRelations(ds) == 2)
}
}
@@ -610,17 +611,17 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
spark.catalog.cacheTable("t4")
// Nested predicate subquery
- val cachedPlan =
+ val ds =
sql(
"""
|SELECT * FROM t1
|WHERE
|c1 IN (SELECT c1 FROM t2 WHERE c1 IN (SELECT c1 FROM t3 WHERE c1 = 1))
- """.stripMargin).queryExecution.optimizedPlan
- assert(getNumInMemoryRelations(cachedPlan) == 3)
+ """.stripMargin)
+ assert(getNumInMemoryRelations(ds) == 3)
// Scalar subquery and predicate subquery
- val cachedPlan2 =
+ val ds2 =
sql(
"""
|SELECT * FROM (SELECT max(c1) FROM t1 GROUP BY c1)
@@ -630,8 +631,27 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
|EXISTS (SELECT c1 FROM t3)
|OR
|c1 IN (SELECT c1 FROM t4)
- """.stripMargin).queryExecution.optimizedPlan
- assert(getNumInMemoryRelations(cachedPlan2) == 4)
+ """.stripMargin)
+ assert(getNumInMemoryRelations(ds2) == 4)
+ }
+ }
+
+ test("SPARK-19765: UNCACHE TABLE should un-cache all cached plans that refer to this table") {
+ withTable("t") {
+ withTempPath { path =>
+ Seq(1 -> "a").toDF("i", "j").write.parquet(path.getCanonicalPath)
+ sql(s"CREATE TABLE t USING parquet LOCATION '$path'")
+ spark.catalog.cacheTable("t")
+ spark.table("t").select($"i").cache()
+ checkAnswer(spark.table("t").select($"i"), Row(1))
+ assertCached(spark.table("t").select($"i"))
+
+ Utils.deleteRecursively(path)
+ spark.sessionState.catalog.refreshTable(TableIdentifier("t"))
+ spark.catalog.uncacheTable("t")
+ assert(spark.table("t").select($"i").count() == 0)
+ assert(getNumInMemoryRelations(spark.table("t").select($"i")) == 0)
+ }
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala
index acf393a9b0fa..5e323c02b253 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala
@@ -89,6 +89,22 @@ class DataFrameRangeSuite extends QueryTest with SharedSQLContext with Eventuall
val n = 9L * 1000 * 1000 * 1000 * 1000 * 1000 * 1000
val res13 = spark.range(-n, n, n / 9).select("id")
assert(res13.count == 18)
+
+ // range with non aggregation operation
+ val res14 = spark.range(0, 100, 2).toDF.filter("50 <= id")
+ val len14 = res14.collect.length
+ assert(len14 == 25)
+
+ val res15 = spark.range(100, -100, -2).toDF.filter("id <= 0")
+ val len15 = res15.collect.length
+ assert(len15 == 50)
+
+ val res16 = spark.range(-1500, 1500, 3).toDF.filter("0 <= id")
+ val len16 = res16.collect.length
+ assert(len16 == 500)
+
+ val res17 = spark.range(10, 0, -1, 1).toDF.sortWithinPartitions("id")
+ assert(res17.collect === (1 to 10).map(i => Row(i)).toArray)
}
test("Range with randomized parameters") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 19c2d5532d08..52bd4e19f895 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -1703,4 +1703,23 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
val df = spark.range(1).selectExpr("CAST(id as DECIMAL) as x").selectExpr("percentile(x, 0.5)")
checkAnswer(df, Row(BigDecimal(0.0)) :: Nil)
}
+
+ test("SPARK-19893: cannot run set operations with map type") {
+ val df = spark.range(1).select(map(lit("key"), $"id").as("m"))
+ val e = intercept[AnalysisException](df.intersect(df))
+ assert(e.message.contains(
+ "Cannot have map type columns in DataFrame which calls set operations"))
+ val e2 = intercept[AnalysisException](df.except(df))
+ assert(e2.message.contains(
+ "Cannot have map type columns in DataFrame which calls set operations"))
+ val e3 = intercept[AnalysisException](df.distinct())
+ assert(e3.message.contains(
+ "Cannot have map type columns in DataFrame which calls set operations"))
+ withTempView("v") {
+ df.createOrReplaceTempView("v")
+ val e4 = intercept[AnalysisException](sql("SELECT DISTINCT m FROM v"))
+ assert(e4.message.contains(
+ "Cannot have map type columns in DataFrame which calls set operations"))
+ }
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala
index 66d94d601605..1a0672b8876d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala
@@ -31,6 +31,49 @@ object DatasetBenchmark {
case class Data(l: Long, s: String)
+ def backToBackMapLong(spark: SparkSession, numRows: Long, numChains: Int): Benchmark = {
+ import spark.implicits._
+
+ val rdd = spark.sparkContext.range(0, numRows)
+ val ds = spark.range(0, numRows)
+ val df = ds.toDF("l")
+ val func = (l: Long) => l + 1
+
+ val benchmark = new Benchmark("back-to-back map long", numRows)
+
+ benchmark.addCase("RDD") { iter =>
+ var res = rdd
+ var i = 0
+ while (i < numChains) {
+ res = res.map(func)
+ i += 1
+ }
+ res.foreach(_ => Unit)
+ }
+
+ benchmark.addCase("DataFrame") { iter =>
+ var res = df
+ var i = 0
+ while (i < numChains) {
+ res = res.select($"l" + 1 as "l")
+ i += 1
+ }
+ res.queryExecution.toRdd.foreach(_ => Unit)
+ }
+
+ benchmark.addCase("Dataset") { iter =>
+ var res = ds.as[Long]
+ var i = 0
+ while (i < numChains) {
+ res = res.map(func)
+ i += 1
+ }
+ res.queryExecution.toRdd.foreach(_ => Unit)
+ }
+
+ benchmark
+ }
+
def backToBackMap(spark: SparkSession, numRows: Long, numChains: Int): Benchmark = {
import spark.implicits._
@@ -72,6 +115,49 @@ object DatasetBenchmark {
benchmark
}
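+ // Long-typed counterpart of backToBackFilter below: chains numChains filter stages over
+ // primitive Long values for the RDD, DataFrame, and typed Dataset APIs.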
+ def backToBackFilterLong(spark: SparkSession, numRows: Long, numChains: Int): Benchmark = {
+ import spark.implicits._
+
+ val rdd = spark.sparkContext.range(1, numRows)
+ val ds = spark.range(1, numRows)
+ val df = ds.toDF("l")
+ val func = (l: Long) => l % 2L == 0L
+
+ val benchmark = new Benchmark("back-to-back filter Long", numRows)
+
+ benchmark.addCase("RDD") { iter =>
+ var res = rdd
+ var i = 0
+ while (i < numChains) {
+ res = res.filter(func)
+ i += 1
+ }
+ res.foreach(_ => Unit)
+ }
+
+ benchmark.addCase("DataFrame") { iter =>
+ var res = df
+ var i = 0
+ while (i < numChains) {
+ res = res.filter($"l" % 2L === 0L)
+ i += 1
+ }
+ res.queryExecution.toRdd.foreach(_ => Unit)
+ }
+
+ benchmark.addCase("Dataset") { iter =>
+ var res = ds.as[Long]
+ var i = 0
+ while (i < numChains) {
+ res = res.filter(func)
+ i += 1
+ }
+ res.queryExecution.toRdd.foreach(_ => Unit)
+ }
+
+ benchmark
+ }
+
def backToBackFilter(spark: SparkSession, numRows: Long, numChains: Int): Benchmark = {
import spark.implicits._
@@ -165,9 +251,22 @@ object DatasetBenchmark {
val numRows = 100000000
val numChains = 10
- val benchmark = backToBackMap(spark, numRows, numChains)
- val benchmark2 = backToBackFilter(spark, numRows, numChains)
- val benchmark3 = aggregate(spark, numRows)
+ val benchmark0 = backToBackMapLong(spark, numRows, numChains)
+ val benchmark1 = backToBackMap(spark, numRows, numChains)
+ val benchmark2 = backToBackFilterLong(spark, numRows, numChains)
+ val benchmark3 = backToBackFilter(spark, numRows, numChains)
+ val benchmark4 = aggregate(spark, numRows)
+
+ /*
+ OpenJDK 64-Bit Server VM 1.8.0_111-8u111-b14-2ubuntu0.16.04.2-b14 on Linux 4.4.0-47-generic
+ Intel(R) Xeon(R) CPU E5-2667 v3 @ 3.20GHz
+ back-to-back map long: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+ ------------------------------------------------------------------------------------------------
+ RDD 1883 / 1892 53.1 18.8 1.0X
+ DataFrame 502 / 642 199.1 5.0 3.7X
+ Dataset 657 / 784 152.2 6.6 2.9X
+ */
+ benchmark0.run()
/*
OpenJDK 64-Bit Server VM 1.8.0_91-b14 on Linux 3.10.0-327.18.2.el7.x86_64
@@ -178,7 +277,18 @@ object DatasetBenchmark {
DataFrame 2647 / 3116 37.8 26.5 1.3X
Dataset 4781 / 5155 20.9 47.8 0.7X
*/
- benchmark.run()
+ benchmark1.run()
+
+ /*
+ OpenJDK 64-Bit Server VM 1.8.0_121-8u121-b13-0ubuntu1.16.04.2-b13 on Linux 4.4.0-47-generic
+ Intel(R) Xeon(R) CPU E5-2667 v3 @ 3.20GHz
+ back-to-back filter Long: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+ ------------------------------------------------------------------------------------------------
+ RDD 846 / 1120 118.1 8.5 1.0X
+ DataFrame 270 / 329 370.9 2.7 3.1X
+ Dataset 545 / 789 183.5 5.4 1.6X
+ */
+ benchmark2.run()
/*
OpenJDK 64-Bit Server VM 1.8.0_91-b14 on Linux 3.10.0-327.18.2.el7.x86_64
@@ -189,7 +299,7 @@ object DatasetBenchmark {
DataFrame 59 / 72 1695.4 0.6 22.8X
Dataset 2777 / 2805 36.0 27.8 0.5X
*/
- benchmark2.run()
+ benchmark3.run()
/*
Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.12.1
@@ -201,6 +311,6 @@ object DatasetBenchmark {
Dataset sum using Aggregator 4656 / 4758 21.5 46.6 0.4X
Dataset complex Aggregator 6636 / 7039 15.1 66.4 0.3X
*/
- benchmark3.run()
+ benchmark4.run()
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
index 6b50cb3e48c7..82b707537e45 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
@@ -62,6 +62,40 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
2, 3, 4)
}
+ test("mapPrimitive") {
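+ // Maps each primitive-typed Dataset through functions returning the other primitive types,
+ // exercising the Int, Long, Float, Double, and Boolean paths end to end.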
+ val dsInt = Seq(1, 2, 3).toDS()
+ checkDataset(dsInt.map(_ > 1), false, true, true)
+ checkDataset(dsInt.map(_ + 1), 2, 3, 4)
+ checkDataset(dsInt.map(_ + 8589934592L), 8589934593L, 8589934594L, 8589934595L)
+ checkDataset(dsInt.map(_ + 1.1F), 2.1F, 3.1F, 4.1F)
+ checkDataset(dsInt.map(_ + 1.23D), 2.23D, 3.23D, 4.23D)
+
+ val dsLong = Seq(1L, 2L, 3L).toDS()
+ checkDataset(dsLong.map(_ > 1), false, true, true)
+ checkDataset(dsLong.map(e => (e + 1).toInt), 2, 3, 4)
+ checkDataset(dsLong.map(_ + 8589934592L), 8589934593L, 8589934594L, 8589934595L)
+ checkDataset(dsLong.map(_ + 1.1F), 2.1F, 3.1F, 4.1F)
+ checkDataset(dsLong.map(_ + 1.23D), 2.23D, 3.23D, 4.23D)
+
+ val dsFloat = Seq(1F, 2F, 3F).toDS()
+ checkDataset(dsFloat.map(_ > 1), false, true, true)
+ checkDataset(dsFloat.map(e => (e + 1).toInt), 2, 3, 4)
+ checkDataset(dsFloat.map(e => (e + 123456L).toLong), 123457L, 123458L, 123459L)
+ checkDataset(dsFloat.map(_ + 1.1F), 2.1F, 3.1F, 4.1F)
+ checkDataset(dsFloat.map(_ + 1.23D), 2.23D, 3.23D, 4.23D)
+
+ val dsDouble = Seq(1D, 2D, 3D).toDS()
+ checkDataset(dsDouble.map(_ > 1), false, true, true)
+ checkDataset(dsDouble.map(e => (e + 1).toInt), 2, 3, 4)
+ checkDataset(dsDouble.map(e => (e + 8589934592L).toLong),
+ 8589934593L, 8589934594L, 8589934595L)
+ checkDataset(dsDouble.map(e => (e + 1.1F).toFloat), 2.1F, 3.1F, 4.1F)
+ checkDataset(dsDouble.map(_ + 1.23D), 2.23D, 3.23D, 4.23D)
+
+ val dsBoolean = Seq(true, false).toDS()
+ checkDataset(dsBoolean.map(e => !e), false, true)
+ }
+
test("filter") {
val ds = Seq(1, 2, 3, 4).toDS()
checkDataset(
@@ -69,6 +103,23 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
2, 4)
}
+ test("filterPrimitive") {
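+ // Filter counterpart of mapPrimitive above, covering the same primitive element types.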
+ val dsInt = Seq(1, 2, 3).toDS()
+ checkDataset(dsInt.filter(_ > 1), 2, 3)
+
+ val dsLong = Seq(1L, 2L, 3L).toDS()
+ checkDataset(dsLong.filter(_ > 1), 2L, 3L)
+
+ val dsFloat = Seq(1F, 2F, 3F).toDS()
+ checkDataset(dsFloat.filter(_ > 1), 2F, 3F)
+
+ val dsDouble = Seq(1D, 2D, 3D).toDS()
+ checkDataset(dsDouble.filter(_ > 1), 2D, 3D)
+
+ val dsBoolean = Seq(true, false).toDS()
+ checkDataset(dsBoolean.filter(e => !e), false)
+ }
+
test("foreach") {
val ds = Seq(1, 2, 3).toDS()
val acc = sparkContext.longAccumulator
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala
index 9c39b3c7f09b..cdea3b9a0f79 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql
import org.apache.spark.sql.functions.{from_json, struct, to_json}
import org.apache.spark.sql.test.SharedSQLContext
-import org.apache.spark.sql.types.{CalendarIntervalType, IntegerType, StructType, TimestampType}
+import org.apache.spark.sql.types._
class JsonFunctionsSuite extends QueryTest with SharedSQLContext {
import testImplicits._
@@ -133,6 +133,29 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext {
Row(null) :: Nil)
}
+ test("from_json invalid schema") {
+ val df = Seq("""{"a" 1}""").toDS()
+ val schema = ArrayType(StringType)
+ val message = intercept[AnalysisException] {
+ df.select(from_json($"value", schema))
+ }.getMessage
+
+ assert(message.contains(
+ "Input schema array<string> must be a struct or an array of structs."))
+ }
+
+ test("from_json array support") {
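+ // Array elements that omit a field are decoded with null for the missing column.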
+ val df = Seq("""[{"a": 1, "b": "a"}, {"a": 2}, { }]""").toDS()
+ val schema = ArrayType(
+ StructType(
+ StructField("a", IntegerType) ::
+ StructField("b", StringType) :: Nil))
+
+ checkAnswer(
+ df.select(from_json($"value", schema)),
+ Row(Seq(Row(1, "a"), Row(2, null), Row(null, null))))
+ }
+
test("to_json") {
val df = Seq(Tuple1(Tuple1(1))).toDF("a")
@@ -174,4 +197,27 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext {
.select(to_json($"struct").as("json"))
checkAnswer(dfTwo, readBackTwo)
}
+
+ test("SPARK-19637 Support to_json in SQL") {
+ val df1 = Seq(Tuple1(Tuple1(1))).toDF("a")
+ checkAnswer(
+ df1.selectExpr("to_json(a)"),
+ Row("""{"_1":1}""") :: Nil)
+
+ val df2 = Seq(Tuple1(Tuple1(java.sql.Timestamp.valueOf("2015-08-26 18:00:00.0")))).toDF("a")
+ checkAnswer(
+ df2.selectExpr("to_json(a, map('timestampFormat', 'dd/MM/yyyy HH:mm'))"),
+ Row("""{"_1":"26/08/2015 18:00"}""") :: Nil)
+
+ val errMsg1 = intercept[AnalysisException] {
+ df2.selectExpr("to_json(a, named_struct('a', 1))")
+ }
+ assert(errMsg1.getMessage.startsWith("Must use a map() function for options"))
+
+ val errMsg2 = intercept[AnalysisException] {
+ df2.selectExpr("to_json(a, map('a', 1))")
+ }
+ assert(errMsg2.getMessage.startsWith(
+ "A type of keys and values in map() must be string, but got"))
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala
index 37443d034298..328c5395ec91 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala
@@ -233,6 +233,18 @@ class MathFunctionsSuite extends QueryTest with SharedSQLContext {
)
}
+ test("round/bround with data frame from a local Seq of Product") {
+ val df = spark.createDataFrame(Seq(Tuple1(BigDecimal("5.9")))).toDF("value")
+ checkAnswer(
+ df.withColumn("value_rounded", round('value)),
+ Seq(Row(BigDecimal("5.9"), BigDecimal("6")))
+ )
+ checkAnswer(
+ df.withColumn("value_brounded", bround('value)),
+ Seq(Row(BigDecimal("5.9"), BigDecimal("6")))
+ )
+ }
+
test("exp") {
testOneToOneMathFunction(exp, math.exp)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 468ea0551298..d9e0196c5795 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -1019,6 +1019,18 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
spark.sessionState.conf.clear()
}
+ test("SET mapreduce.job.reduces automatically converted to spark.sql.shuffle.partitions") {
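+ // The Hadoop MapReduce property mapreduce.job.reduces is translated into
+ // spark.sql.shuffle.partitions, and a negative value is rejected with IllegalArgumentException.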
+ spark.sessionState.conf.clear()
+ val before = spark.conf.get(SQLConf.SHUFFLE_PARTITIONS.key).toInt
+ val newConf = before + 1
+ sql(s"SET mapreduce.job.reduces=${newConf.toString}")
+ val after = spark.conf.get(SQLConf.SHUFFLE_PARTITIONS.key).toInt
+ assert(before != after)
+ assert(newConf === after)
+ intercept[IllegalArgumentException](sql(s"SET mapreduce.job.reduces=-1"))
+ spark.sessionState.conf.clear()
+ }
+
test("apply schema") {
val schema1 = StructType(
StructField("f1", IntegerType, false) ::
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala
new file mode 100644
index 000000000000..2d5e37242a58
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala
@@ -0,0 +1,162 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.scalatest.BeforeAndAfterAll
+import org.scalatest.BeforeAndAfterEach
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.rules.Rule
+
+class SessionStateSuite extends SparkFunSuite
+ with BeforeAndAfterEach with BeforeAndAfterAll {
+
+ /**
+ * A shared SparkSession for all tests in this suite. Make sure to reset any changes made to
+ * this session, because it is a singleton HiveSparkSession in HiveSessionStateSuite and is
+ * shared with all Hive test suites.
+ */
+ protected var activeSession: SparkSession = _
+
+ override def beforeAll(): Unit = {
+ activeSession = SparkSession.builder().master("local").getOrCreate()
+ }
+
+ override def afterAll(): Unit = {
+ if (activeSession != null) {
+ activeSession.stop()
+ activeSession = null
+ }
+ super.afterAll()
+ }
+
+ test("fork new session and inherit RuntimeConfig options") {
+ val key = "spark-config-clone"
+ try {
+ activeSession.conf.set(key, "active")
+
+ // inheritance
+ val forkedSession = activeSession.cloneSession()
+ assert(forkedSession ne activeSession)
+ assert(forkedSession.conf ne activeSession.conf)
+ assert(forkedSession.conf.get(key) == "active")
+
+ // independence
+ forkedSession.conf.set(key, "forked")
+ assert(activeSession.conf.get(key) == "active")
+ activeSession.conf.set(key, "dontcopyme")
+ assert(forkedSession.conf.get(key) == "forked")
+ } finally {
+ activeSession.conf.unset(key)
+ }
+ }
+
+ test("fork new session and inherit function registry and udf") {
+ val testFuncName1 = "strlenScala"
+ val testFuncName2 = "addone"
+ try {
+ activeSession.udf.register(testFuncName1, (_: String).length + (_: Int))
+ val forkedSession = activeSession.cloneSession()
+
+ // inheritance
+ assert(forkedSession ne activeSession)
+ assert(forkedSession.sessionState.functionRegistry ne
+ activeSession.sessionState.functionRegistry)
+ assert(forkedSession.sessionState.functionRegistry.lookupFunction(testFuncName1).nonEmpty)
+
+ // independence
+ forkedSession.sessionState.functionRegistry.dropFunction(testFuncName1)
+ assert(activeSession.sessionState.functionRegistry.lookupFunction(testFuncName1).nonEmpty)
+ activeSession.udf.register(testFuncName2, (_: Int) + 1)
+ assert(forkedSession.sessionState.functionRegistry.lookupFunction(testFuncName2).isEmpty)
+ } finally {
+ activeSession.sessionState.functionRegistry.dropFunction(testFuncName1)
+ activeSession.sessionState.functionRegistry.dropFunction(testFuncName2)
+ }
+ }
+
+ test("fork new session and inherit experimental methods") {
+ val originalExtraOptimizations = activeSession.experimental.extraOptimizations
+ val originalExtraStrategies = activeSession.experimental.extraStrategies
+ try {
+ object DummyRule1 extends Rule[LogicalPlan] {
+ def apply(p: LogicalPlan): LogicalPlan = p
+ }
+ object DummyRule2 extends Rule[LogicalPlan] {
+ def apply(p: LogicalPlan): LogicalPlan = p
+ }
+ val optimizations = List(DummyRule1, DummyRule2)
+ activeSession.experimental.extraOptimizations = optimizations
+ val forkedSession = activeSession.cloneSession()
+
+ // inheritance
+ assert(forkedSession ne activeSession)
+ assert(forkedSession.experimental ne activeSession.experimental)
+ assert(forkedSession.experimental.extraOptimizations.toSet ==
+ activeSession.experimental.extraOptimizations.toSet)
+
+ // independence
+ forkedSession.experimental.extraOptimizations = List(DummyRule2)
+ assert(activeSession.experimental.extraOptimizations == optimizations)
+ activeSession.experimental.extraOptimizations = List(DummyRule1)
+ assert(forkedSession.experimental.extraOptimizations == List(DummyRule2))
+ } finally {
+ activeSession.experimental.extraOptimizations = originalExtraOptimizations
+ activeSession.experimental.extraStrategies = originalExtraStrategies
+ }
+ }
+
+ test("fork new sessions and run query on inherited table") {
+ def checkTableExists(sparkSession: SparkSession): Unit = {
+ QueryTest.checkAnswer(sparkSession.sql(
+ """
+ |SELECT x.str, COUNT(*)
+ |FROM df x JOIN df y ON x.str = y.str
+ |GROUP BY x.str
+ """.stripMargin),
+ Row("1", 1) :: Row("2", 1) :: Row("3", 1) :: Nil)
+ }
+
+ val spark = activeSession
+ // Cannot use `import activeSession.implicits._` due to a compiler limitation.
+ import spark.implicits._
+
+ try {
+ activeSession
+ .createDataset[(Int, String)](Seq(1, 2, 3).map(i => (i, i.toString)))
+ .toDF("int", "str")
+ .createOrReplaceTempView("df")
+ checkTableExists(activeSession)
+
+ val forkedSession = activeSession.cloneSession()
+ assert(forkedSession ne activeSession)
+ assert(forkedSession.sessionState ne activeSession.sessionState)
+ checkTableExists(forkedSession)
+ checkTableExists(activeSession.cloneSession()) // ability to clone multiple times
+ checkTableExists(forkedSession.cloneSession()) // clone of clone
+ } finally {
+ activeSession.sql("drop table df")
+ }
+ }
+
+ test("fork new session and inherit reference to SharedState") {
+ val forkedSession = activeSession.cloneSession()
+ assert(activeSession.sharedState eq forkedSession.sharedState)
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala
index bbb31dbc8f3d..1f547c5a2a8f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala
@@ -112,30 +112,6 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared
spark.sessionState.conf.autoBroadcastJoinThreshold)
}
- test("estimates the size of limit") {
- withTempView("test") {
- Seq(("one", 1), ("two", 2), ("three", 3), ("four", 4)).toDF("k", "v")
- .createOrReplaceTempView("test")
- Seq((0, 1), (1, 24), (2, 48)).foreach { case (limit, expected) =>
- val df = sql(s"""SELECT * FROM test limit $limit""")
-
- val sizesGlobalLimit = df.queryExecution.analyzed.collect { case g: GlobalLimit =>
- g.stats(conf).sizeInBytes
- }
- assert(sizesGlobalLimit.size === 1, s"Size wrong for:\n ${df.queryExecution}")
- assert(sizesGlobalLimit.head === BigInt(expected),
- s"expected exact size $expected for table 'test', got: ${sizesGlobalLimit.head}")
-
- val sizesLocalLimit = df.queryExecution.analyzed.collect { case l: LocalLimit =>
- l.stats(conf).sizeInBytes
- }
- assert(sizesLocalLimit.size === 1, s"Size wrong for:\n ${df.queryExecution}")
- assert(sizesLocalLimit.head === BigInt(expected),
- s"expected exact size $expected for table 'test', got: ${sizesLocalLimit.head}")
- }
- }
- }
-
test("column stats round trip serialization") {
// Make sure we serialize and then deserialize and we will get the result data
val df = data.toDF(stats.keys.toSeq :+ "carray" : _*)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 0bfc92fdb621..02ccebd22bdf 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -242,11 +242,12 @@ class PlannerSuite extends SharedSQLContext {
val doubleRepartitioned = testData.repartition(10).repartition(20).coalesce(5)
def countRepartitions(plan: LogicalPlan): Int = plan.collect { case r: Repartition => r }.length
assert(countRepartitions(doubleRepartitioned.queryExecution.logical) === 3)
- assert(countRepartitions(doubleRepartitioned.queryExecution.optimizedPlan) === 1)
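+ // The non-shuffle coalesce(5) is no longer collapsed into the shuffle repartition(20),
+ // so two Repartition nodes remain in the optimized plan.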
+ assert(countRepartitions(doubleRepartitioned.queryExecution.optimizedPlan) === 2)
doubleRepartitioned.queryExecution.optimizedPlan match {
- case r: Repartition =>
- assert(r.numPartitions === 5)
- assert(r.shuffle === false)
+ case Repartition(numPartitions, shuffle, Repartition(_, shuffleChild, _)) =>
+ assert(numPartitions === 5)
+ assert(shuffle === false)
+ assert(shuffleChild === true)
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala
index 2d95cb6d64a8..2ca2206bb9d4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala
@@ -172,7 +172,7 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils {
var e = intercept[AnalysisException] {
sql(s"INSERT INTO TABLE $viewName SELECT 1")
}.getMessage
- assert(e.contains("Inserting into an RDD-based table is not allowed"))
+ assert(e.contains("Inserting into a view is not allowed. View: `default`.`testview`"))
val dataFilePath =
Thread.currentThread().getContextClassLoader.getResource("data/files/employee.dat")
@@ -609,12 +609,39 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils {
}
}
- // TODO: Check for cyclic view references on ALTER VIEW.
- ignore("correctly handle a cyclic view reference") {
- withView("view1", "view2") {
+ test("correctly handle a cyclic view reference") {
+ withView("view1", "view2", "view3") {
sql("CREATE VIEW view1 AS SELECT * FROM jt")
sql("CREATE VIEW view2 AS SELECT * FROM view1")
- intercept[AnalysisException](sql("ALTER VIEW view1 AS SELECT * FROM view2"))
+ sql("CREATE VIEW view3 AS SELECT * FROM view2")
+
+ // Detect cyclic view reference on ALTER VIEW.
+ val e1 = intercept[AnalysisException] {
+ sql("ALTER VIEW view1 AS SELECT * FROM view2")
+ }.getMessage
+ assert(e1.contains("Recursive view `default`.`view1` detected (cycle: `default`.`view1` " +
+ "-> `default`.`view2` -> `default`.`view1`)"))
+
+ // Detect the most left cycle when there exists multiple cyclic view references.
+ val e2 = intercept[AnalysisException] {
+ sql("ALTER VIEW view1 AS SELECT * FROM view3 JOIN view2")
+ }.getMessage
+ assert(e2.contains("Recursive view `default`.`view1` detected (cycle: `default`.`view1` " +
+ "-> `default`.`view3` -> `default`.`view2` -> `default`.`view1`)"))
+
+ // Detect cyclic view reference on CREATE OR REPLACE VIEW.
+ val e3 = intercept[AnalysisException] {
+ sql("CREATE OR REPLACE VIEW view1 AS SELECT * FROM view2")
+ }.getMessage
+ assert(e3.contains("Recursive view `default`.`view1` detected (cycle: `default`.`view1` " +
+ "-> `default`.`view2` -> `default`.`view1`)"))
+
+ // Detect cyclic view reference from subqueries.
+ val e4 = intercept[AnalysisException] {
+ sql("ALTER VIEW view1 AS SELECT * FROM jt WHERE EXISTS (SELECT 1 FROM view2)")
+ }.getMessage
+ assert(e4.contains("Recursive view `default`.`view1` detected (cycle: `default`.`view1` " +
+ "-> `default`.`view2` -> `default`.`view1`)"))
}
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala
index bb6c486e880a..a4d012cd7611 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala
@@ -45,7 +45,7 @@ class SparkSqlParserSuite extends PlanTest {
* Normalizes plans:
* - CreateTable the createTime in tableDesc will replaced by -1L.
*/
- private def normalizePlan(plan: LogicalPlan): LogicalPlan = {
+ override def normalizePlan(plan: LogicalPlan): LogicalPlan = {
plan match {
case CreateTable(tableDesc, mode, query) =>
val newTableDesc = tableDesc.copy(createTime = -1L)
@@ -210,6 +210,17 @@ class SparkSqlParserSuite extends PlanTest {
"no viable alternative at input")
}
+ test("create view as insert into table") {
+ // Single insert query
+ intercept("CREATE VIEW testView AS INSERT INTO jt VALUES(1, 1)",
+ "Operation not allowed: CREATE VIEW ... AS INSERT INTO")
+
+ // Multi insert query
+ intercept("CREATE VIEW testView AS FROM jt INSERT INTO tbl1 SELECT * WHERE jt.id < 5 " +
+ "INSERT INTO tbl2 SELECT * WHERE jt.id > 4",
+ "Operation not allowed: CREATE VIEW ... AS FROM ... [INSERT INTO ...]+")
+ }
+
test("SPARK-17328 Fix NPE with EXPLAIN DESCRIBE TABLE") {
assertEqual("describe table t",
DescribeTableCommand(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala
index f355a5200ce2..0250a53fe232 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala
@@ -234,8 +234,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
Seq(StringType, BinaryType, NullType, BooleanType,
ByteType, ShortType, IntegerType, LongType,
FloatType, DoubleType, DecimalType(25, 5), DecimalType(6, 5),
- DateType, TimestampType,
- ArrayType(IntegerType), MapType(StringType, LongType), struct)
+ DateType, TimestampType, ArrayType(IntegerType), struct)
val fields = dataTypes.zipWithIndex.map { case (dataType, index) =>
StructField(s"col$index", dataType, true)
}
@@ -244,10 +243,10 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
// Create an RDD for the schema
val rdd =
- sparkContext.parallelize((1 to 10000), 10).map { i =>
+ sparkContext.parallelize(1 to 10000, 10).map { i =>
Row(
- s"str${i}: test cache.",
- s"binary${i}: test cache.".getBytes(StandardCharsets.UTF_8),
+ s"str$i: test cache.",
+ s"binary$i: test cache.".getBytes(StandardCharsets.UTF_8),
null,
i % 2 == 0,
i.toByte,
@@ -255,13 +254,12 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
i,
Long.MaxValue - i.toLong,
(i + 0.25).toFloat,
- (i + 0.75),
+ i + 0.75,
BigDecimal(Long.MaxValue.toString + ".12345"),
new java.math.BigDecimal(s"${i % 9 + 1}" + ".23456"),
new Date(i),
new Timestamp(i * 1000000L),
- (i to i + 10).toSeq,
- (i to i + 10).map(j => s"map_key_$j" -> (Long.MaxValue - j)).toMap,
+ i to i + 10,
Row((i - 0.25).toFloat, Seq(true, false, null)))
}
spark.createDataFrame(rdd, schema).createOrReplaceTempView("InMemoryCache_different_data_types")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala
index 76bb9e5929a7..4b73b078da38 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.execution.command
+import java.net.URI
+
import scala.reflect.{classTag, ClassTag}
import org.apache.spark.sql.catalyst.TableIdentifier
@@ -317,7 +319,7 @@ class DDLCommandSuite extends PlanTest {
val query = "CREATE EXTERNAL TABLE my_tab LOCATION '/something/anything'"
val ct = parseAs[CreateTable](query)
assert(ct.tableDesc.tableType == CatalogTableType.EXTERNAL)
- assert(ct.tableDesc.storage.locationUri == Some("/something/anything"))
+ assert(ct.tableDesc.storage.locationUri == Some(new URI("/something/anything")))
}
test("create hive table - property values must be set") {
@@ -334,7 +336,7 @@ class DDLCommandSuite extends PlanTest {
val query = "CREATE TABLE my_tab LOCATION '/something/anything'"
val ct = parseAs[CreateTable](query)
assert(ct.tableDesc.tableType == CatalogTableType.EXTERNAL)
- assert(ct.tableDesc.storage.locationUri == Some("/something/anything"))
+ assert(ct.tableDesc.storage.locationUri == Some(new URI("/something/anything")))
}
test("create table - with partitioned by") {
@@ -409,7 +411,7 @@ class DDLCommandSuite extends PlanTest {
val expectedTableDesc = CatalogTable(
identifier = TableIdentifier("my_tab"),
tableType = CatalogTableType.EXTERNAL,
- storage = CatalogStorageFormat.empty.copy(locationUri = Some("/tmp/file")),
+ storage = CatalogStorageFormat.empty.copy(locationUri = Some(new URI("/tmp/file"))),
schema = new StructType().add("a", IntegerType).add("b", StringType),
provider = Some("parquet"))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala
index 8b8cd0fdf4db..0666f446f3b5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala
@@ -26,29 +26,168 @@ import org.scalatest.BeforeAndAfterEach
import org.apache.spark.sql.{AnalysisException, QueryTest, Row, SaveMode}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.{DatabaseAlreadyExistsException, FunctionRegistry, NoSuchPartitionException, NoSuchTableException, TempTableAlreadyExistsException}
-import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogDatabase, CatalogStorageFormat}
-import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTableType}
-import org.apache.spark.sql.catalyst.catalog.{CatalogTablePartition, SessionCatalog}
+import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION
-import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils}
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
-class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
- private val escapedIdentifier = "`(.+)`".r
+class InMemoryCatalogedDDLSuite extends DDLSuite with SharedSQLContext with BeforeAndAfterEach {
override def afterEach(): Unit = {
try {
// drop all databases, tables and functions after each test
spark.sessionState.catalog.reset()
} finally {
- Utils.deleteRecursively(new File("spark-warehouse"))
+ Utils.deleteRecursively(new File(spark.sessionState.conf.warehousePath))
super.afterEach()
}
}
+ protected override def generateTable(
+ catalog: SessionCatalog,
+ name: TableIdentifier): CatalogTable = {
+ val storage =
+ CatalogStorageFormat.empty.copy(locationUri = Some(catalog.defaultTablePath(name)))
+ val metadata = new MetadataBuilder()
+ .putString("key", "value")
+ .build()
+ CatalogTable(
+ identifier = name,
+ tableType = CatalogTableType.EXTERNAL,
+ storage = storage,
+ schema = new StructType()
+ .add("col1", "int", nullable = true, metadata = metadata)
+ .add("col2", "string")
+ .add("a", "int")
+ .add("b", "int"),
+ provider = Some("parquet"),
+ partitionColumnNames = Seq("a", "b"),
+ createTime = 0L,
+ tracksPartitionsInCatalog = true)
+ }
+
+ test("desc table for parquet data source table using in-memory catalog") {
+ val tabName = "tab1"
+ withTable(tabName) {
+ sql(s"CREATE TABLE $tabName(a int comment 'test') USING parquet ")
+
+ checkAnswer(
+ sql(s"DESC $tabName").select("col_name", "data_type", "comment"),
+ Row("a", "int", "test")
+ )
+ }
+ }
+
+ test("alter table: set location (datasource table)") {
+ testSetLocation(isDatasourceTable = true)
+ }
+
+ test("alter table: set properties (datasource table)") {
+ testSetProperties(isDatasourceTable = true)
+ }
+
+ test("alter table: unset properties (datasource table)") {
+ testUnsetProperties(isDatasourceTable = true)
+ }
+
+ test("alter table: set serde (datasource table)") {
+ testSetSerde(isDatasourceTable = true)
+ }
+
+ test("alter table: set serde partition (datasource table)") {
+ testSetSerdePartition(isDatasourceTable = true)
+ }
+
+ test("alter table: change column (datasource table)") {
+ testChangeColumn(isDatasourceTable = true)
+ }
+
+ test("alter table: add partition (datasource table)") {
+ testAddPartitions(isDatasourceTable = true)
+ }
+
+ test("alter table: drop partition (datasource table)") {
+ testDropPartitions(isDatasourceTable = true)
+ }
+
+ test("alter table: rename partition (datasource table)") {
+ testRenamePartitions(isDatasourceTable = true)
+ }
+
+ test("drop table - data source table") {
+ testDropTable(isDatasourceTable = true)
+ }
+
+ test("create a managed Hive source table") {
+ assume(spark.sparkContext.conf.get(CATALOG_IMPLEMENTATION) == "in-memory")
+ val tabName = "tbl"
+ withTable(tabName) {
+ val e = intercept[AnalysisException] {
+ sql(s"CREATE TABLE $tabName (i INT, j STRING)")
+ }.getMessage
+ assert(e.contains("Hive support is required to CREATE Hive TABLE"))
+ }
+ }
+
+ test("create an external Hive source table") {
+ assume(spark.sparkContext.conf.get(CATALOG_IMPLEMENTATION) == "in-memory")
+ withTempDir { tempDir =>
+ val tabName = "tbl"
+ withTable(tabName) {
+ val e = intercept[AnalysisException] {
+ sql(
+ s"""
+ |CREATE EXTERNAL TABLE $tabName (i INT, j STRING)
+ |ROW FORMAT DELIMITED FIELDS TERMINATED BY ','
+ |LOCATION '${tempDir.toURI}'
+ """.stripMargin)
+ }.getMessage
+ assert(e.contains("Hive support is required to CREATE Hive TABLE"))
+ }
+ }
+ }
+
+ test("Create Hive Table As Select") {
+ import testImplicits._
+ withTable("t", "t1") {
+ var e = intercept[AnalysisException] {
+ sql("CREATE TABLE t SELECT 1 as a, 1 as b")
+ }.getMessage
+ assert(e.contains("Hive support is required to CREATE Hive TABLE (AS SELECT)"))
+
+ spark.range(1).select('id as 'a, 'id as 'b).write.saveAsTable("t1")
+ e = intercept[AnalysisException] {
+ sql("CREATE TABLE t SELECT a, b from t1")
+ }.getMessage
+ assert(e.contains("Hive support is required to CREATE Hive TABLE (AS SELECT)"))
+ }
+ }
+
+}
+
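+/**
+ * DDL test cases shared across catalog implementations. Concrete suites (such as
+ * InMemoryCatalogedDDLSuite above) supply generateTable and may override normalizeCatalogTable
+ * to smooth over metadata that their catalog fills in on its own.
+ */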
+abstract class DDLSuite extends QueryTest with SQLTestUtils {
+
+ protected def isUsingHiveMetastore: Boolean = {
+ spark.sparkContext.conf.get(CATALOG_IMPLEMENTATION) == "hive"
+ }
+
+ protected def generateTable(catalog: SessionCatalog, name: TableIdentifier): CatalogTable
+
+ private val escapedIdentifier = "`(.+)`".r
+
+ protected def normalizeCatalogTable(table: CatalogTable): CatalogTable = table
+
+ private def normalizeSerdeProp(props: Map[String, String]): Map[String, String] = {
+ props.filterNot(p => Seq("serialization.format", "path").contains(p._1))
+ }
+
+ private def checkCatalogTables(expected: CatalogTable, actual: CatalogTable): Unit = {
+ assert(normalizeCatalogTable(actual) == normalizeCatalogTable(expected))
+ }
+
/**
* Strip backticks, if any, from the string.
*/
@@ -72,37 +211,11 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
private def createDatabase(catalog: SessionCatalog, name: String): Unit = {
catalog.createDatabase(
- CatalogDatabase(name, "", spark.sessionState.conf.warehousePath, Map()),
+ CatalogDatabase(
+ name, "", CatalogUtils.stringToURI(spark.sessionState.conf.warehousePath), Map()),
ignoreIfExists = false)
}
- private def generateTable(catalog: SessionCatalog, name: TableIdentifier): CatalogTable = {
- val storage =
- CatalogStorageFormat(
- locationUri = Some(catalog.defaultTablePath(name)),
- inputFormat = None,
- outputFormat = None,
- serde = None,
- compressed = false,
- properties = Map())
- val metadata = new MetadataBuilder()
- .putString("key", "value")
- .build()
- CatalogTable(
- identifier = name,
- tableType = CatalogTableType.EXTERNAL,
- storage = storage,
- schema = new StructType()
- .add("col1", "int", nullable = true, metadata = metadata)
- .add("col2", "string")
- .add("a", "int")
- .add("b", "int"),
- provider = Some("parquet"),
- partitionColumnNames = Seq("a", "b"),
- createTime = 0L,
- tracksPartitionsInCatalog = true)
- }
-
private def createTable(catalog: SessionCatalog, name: TableIdentifier): Unit = {
catalog.createTable(generateTable(catalog, name), ignoreIfExists = false)
}
@@ -116,6 +229,11 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
catalog.createPartitions(tableName, Seq(part), ignoreIfExists = false)
}
+ private def getDBPath(dbName: String): URI = {
+ val warehousePath = makeQualifiedPath(spark.sessionState.conf.warehousePath)
+ new Path(CatalogUtils.URIToString(warehousePath), s"$dbName.db").toUri
+ }
+
test("the qualified path of a database is stored in the catalog") {
val catalog = spark.sessionState.catalog
@@ -133,24 +251,16 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
}
}
- private def makeQualifiedPath(path: String): String = {
- // copy-paste from SessionCatalog
- val hadoopPath = new Path(path)
- val fs = hadoopPath.getFileSystem(sparkContext.hadoopConfiguration)
- fs.makeQualified(hadoopPath).toString
- }
-
test("Create Database using Default Warehouse Path") {
val catalog = spark.sessionState.catalog
val dbName = "db1"
try {
sql(s"CREATE DATABASE $dbName")
val db1 = catalog.getDatabaseMetadata(dbName)
- val expectedLocation = makeQualifiedPath(s"spark-warehouse/$dbName.db")
assert(db1 == CatalogDatabase(
dbName,
"",
- expectedLocation,
+ getDBPath(dbName),
Map.empty))
sql(s"DROP DATABASE $dbName CASCADE")
assert(!catalog.databaseExists(dbName))
@@ -193,16 +303,17 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
val dbNameWithoutBackTicks = cleanIdentifier(dbName)
sql(s"CREATE DATABASE $dbName")
val db1 = catalog.getDatabaseMetadata(dbNameWithoutBackTicks)
- val expectedLocation = makeQualifiedPath(s"spark-warehouse/$dbNameWithoutBackTicks.db")
assert(db1 == CatalogDatabase(
dbNameWithoutBackTicks,
"",
- expectedLocation,
+ getDBPath(dbNameWithoutBackTicks),
Map.empty))
- intercept[DatabaseAlreadyExistsException] {
+ // TODO: HiveExternalCatalog should throw DatabaseAlreadyExistsException
+ val e = intercept[AnalysisException] {
sql(s"CREATE DATABASE $dbName")
- }
+ }.getMessage
+ assert(e.contains(s"already exists"))
} finally {
catalog.reset()
}
@@ -421,19 +532,6 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
}
}
- test("desc table for parquet data source table using in-memory catalog") {
- assume(spark.sparkContext.conf.get(CATALOG_IMPLEMENTATION) == "in-memory")
- val tabName = "tab1"
- withTable(tabName) {
- sql(s"CREATE TABLE $tabName(a int comment 'test') USING parquet ")
-
- checkAnswer(
- sql(s"DESC $tabName").select("col_name", "data_type", "comment"),
- Row("a", "int", "test")
- )
- }
- }
-
test("Alter/Describe Database") {
val catalog = spark.sessionState.catalog
val databaseNames = Seq("db1", "`database`")
@@ -441,7 +539,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
databaseNames.foreach { dbName =>
try {
val dbNameWithoutBackTicks = cleanIdentifier(dbName)
- val location = makeQualifiedPath(s"spark-warehouse/$dbNameWithoutBackTicks.db")
+ val location = getDBPath(dbNameWithoutBackTicks)
sql(s"CREATE DATABASE $dbName")
@@ -449,7 +547,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
sql(s"DESCRIBE DATABASE EXTENDED $dbName"),
Row("Database Name", dbNameWithoutBackTicks) ::
Row("Description", "") ::
- Row("Location", location) ::
+ Row("Location", CatalogUtils.URIToString(location)) ::
Row("Properties", "") :: Nil)
sql(s"ALTER DATABASE $dbName SET DBPROPERTIES ('a'='a', 'b'='b', 'c'='c')")
@@ -458,7 +556,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
sql(s"DESCRIBE DATABASE EXTENDED $dbName"),
Row("Database Name", dbNameWithoutBackTicks) ::
Row("Description", "") ::
- Row("Location", location) ::
+ Row("Location", CatalogUtils.URIToString(location)) ::
Row("Properties", "((a,a), (b,b), (c,c))") :: Nil)
sql(s"ALTER DATABASE $dbName SET DBPROPERTIES ('d'='d')")
@@ -467,7 +565,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
sql(s"DESCRIBE DATABASE EXTENDED $dbName"),
Row("Database Name", dbNameWithoutBackTicks) ::
Row("Description", "") ::
- Row("Location", location) ::
+ Row("Location", CatalogUtils.URIToString(location)) ::
Row("Properties", "((a,a), (b,b), (c,c), (d,d))") :: Nil)
} finally {
catalog.reset()
@@ -485,7 +583,12 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
var message = intercept[AnalysisException] {
sql(s"DROP DATABASE $dbName")
}.getMessage
- assert(message.contains(s"Database '$dbNameWithoutBackTicks' not found"))
+ // TODO: Unify the exception.
+ if (isUsingHiveMetastore) {
+ assert(message.contains(s"NoSuchObjectException: $dbNameWithoutBackTicks"))
+ } else {
+ assert(message.contains(s"Database '$dbNameWithoutBackTicks' not found"))
+ }
message = intercept[AnalysisException] {
sql(s"ALTER DATABASE $dbName SET DBPROPERTIES ('d'='d')")
@@ -514,7 +617,12 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
val message = intercept[AnalysisException] {
sql(s"DROP DATABASE $dbName RESTRICT")
}.getMessage
- assert(message.contains(s"Database '$dbName' is not empty. One or more tables exist"))
+ // TODO: Unify the exception.
+ if (isUsingHiveMetastore) {
+ assert(message.contains(s"Database $dbName is not empty. One or more tables exist"))
+ } else {
+ assert(message.contains(s"Database '$dbName' is not empty. One or more tables exist"))
+ }
catalog.dropTable(tableIdent1, ignoreIfNotExists = false, purge = false)
@@ -545,7 +653,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
createTable(catalog, tableIdent1)
val expectedTableIdent = tableIdent1.copy(database = Some("default"))
val expectedTable = generateTable(catalog, expectedTableIdent)
- assert(catalog.getTableMetadata(tableIdent1) === expectedTable)
+ checkCatalogTables(expectedTable, catalog.getTableMetadata(tableIdent1))
}
test("create table in a specific db") {
@@ -554,7 +662,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
val tableIdent1 = TableIdentifier("tab1", Some("dbx"))
createTable(catalog, tableIdent1)
val expectedTable = generateTable(catalog, tableIdent1)
- assert(catalog.getTableMetadata(tableIdent1) === expectedTable)
+ checkCatalogTables(expectedTable, catalog.getTableMetadata(tableIdent1))
}
test("create table using") {
@@ -739,52 +847,28 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
testSetLocation(isDatasourceTable = false)
}
- test("alter table: set location (datasource table)") {
- testSetLocation(isDatasourceTable = true)
- }
-
test("alter table: set properties") {
testSetProperties(isDatasourceTable = false)
}
- test("alter table: set properties (datasource table)") {
- testSetProperties(isDatasourceTable = true)
- }
-
test("alter table: unset properties") {
testUnsetProperties(isDatasourceTable = false)
}
- test("alter table: unset properties (datasource table)") {
- testUnsetProperties(isDatasourceTable = true)
- }
-
// TODO: move this test to HiveDDLSuite.scala
ignore("alter table: set serde") {
testSetSerde(isDatasourceTable = false)
}
- test("alter table: set serde (datasource table)") {
- testSetSerde(isDatasourceTable = true)
- }
-
// TODO: move this test to HiveDDLSuite.scala
ignore("alter table: set serde partition") {
testSetSerdePartition(isDatasourceTable = false)
}
- test("alter table: set serde partition (datasource table)") {
- testSetSerdePartition(isDatasourceTable = true)
- }
-
test("alter table: change column") {
testChangeColumn(isDatasourceTable = false)
}
- test("alter table: change column (datasource table)") {
- testChangeColumn(isDatasourceTable = true)
- }
-
test("alter table: bucketing is not supported") {
val catalog = spark.sessionState.catalog
val tableIdent = TableIdentifier("tab1", Some("dbx"))
@@ -813,10 +897,6 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
testAddPartitions(isDatasourceTable = false)
}
- test("alter table: add partition (datasource table)") {
- testAddPartitions(isDatasourceTable = true)
- }
-
test("alter table: recover partitions (sequential)") {
withSQLConf("spark.rdd.parallelListingThreshold" -> "10") {
testRecoverPartitions()
@@ -829,7 +909,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
}
}
- private def testRecoverPartitions() {
+ protected def testRecoverPartitions() {
val catalog = spark.sessionState.catalog
// table to alter does not exist
intercept[AnalysisException] {
@@ -868,8 +948,14 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
sql("ALTER TABLE tab1 RECOVER PARTITIONS")
assert(catalog.listPartitions(tableIdent).map(_.spec).toSet ==
Set(part1, part2))
- assert(catalog.getPartition(tableIdent, part1).parameters("numFiles") == "1")
- assert(catalog.getPartition(tableIdent, part2).parameters("numFiles") == "2")
+ if (!isUsingHiveMetastore) {
+ assert(catalog.getPartition(tableIdent, part1).parameters("numFiles") == "1")
+ assert(catalog.getPartition(tableIdent, part2).parameters("numFiles") == "2")
+ } else {
+ // After ALTER TABLE, the statistics of the first partition are removed by the Hive metastore
+ assert(catalog.getPartition(tableIdent, part1).parameters.get("numFiles").isEmpty)
+ assert(catalog.getPartition(tableIdent, part2).parameters("numFiles") == "2")
+ }
} finally {
fs.delete(root, true)
}
@@ -883,10 +969,6 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
testDropPartitions(isDatasourceTable = false)
}
- test("alter table: drop partition (datasource table)") {
- testDropPartitions(isDatasourceTable = true)
- }
-
test("alter table: drop partition is not supported for views") {
assertUnsupported("ALTER VIEW dbx.tab1 DROP IF EXISTS PARTITION (b='2')")
}
@@ -895,10 +977,6 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
testRenamePartitions(isDatasourceTable = false)
}
- test("alter table: rename partition (datasource table)") {
- testRenamePartitions(isDatasourceTable = true)
- }
-
test("show table extended") {
withTempView("show1a", "show2b") {
sql(
@@ -979,11 +1057,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
testDropTable(isDatasourceTable = false)
}
- test("drop table - data source table") {
- testDropTable(isDatasourceTable = true)
- }
-
- private def testDropTable(isDatasourceTable: Boolean): Unit = {
+ protected def testDropTable(isDatasourceTable: Boolean): Unit = {
val catalog = spark.sessionState.catalog
val tableIdent = TableIdentifier("tab1", Some("dbx"))
createDatabase(catalog, "dbx")
@@ -1019,9 +1093,10 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
tableIdent: TableIdentifier): Unit = {
catalog.alterTable(catalog.getTableMetadata(tableIdent).copy(
provider = Some("csv")))
+ assert(catalog.getTableMetadata(tableIdent).provider == Some("csv"))
}
- private def testSetProperties(isDatasourceTable: Boolean): Unit = {
+ protected def testSetProperties(isDatasourceTable: Boolean): Unit = {
val catalog = spark.sessionState.catalog
val tableIdent = TableIdentifier("tab1", Some("dbx"))
createDatabase(catalog, "dbx")
@@ -1030,7 +1105,11 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
convertToDatasourceTable(catalog, tableIdent)
}
def getProps: Map[String, String] = {
- catalog.getTableMetadata(tableIdent).properties
+ if (isUsingHiveMetastore) {
+ normalizeCatalogTable(catalog.getTableMetadata(tableIdent)).properties
+ } else {
+ catalog.getTableMetadata(tableIdent).properties
+ }
}
assert(getProps.isEmpty)
// set table properties
@@ -1046,7 +1125,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
}
}
- private def testUnsetProperties(isDatasourceTable: Boolean): Unit = {
+ protected def testUnsetProperties(isDatasourceTable: Boolean): Unit = {
val catalog = spark.sessionState.catalog
val tableIdent = TableIdentifier("tab1", Some("dbx"))
createDatabase(catalog, "dbx")
@@ -1055,7 +1134,11 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
convertToDatasourceTable(catalog, tableIdent)
}
def getProps: Map[String, String] = {
- catalog.getTableMetadata(tableIdent).properties
+ if (isUsingHiveMetastore) {
+ normalizeCatalogTable(catalog.getTableMetadata(tableIdent)).properties
+ } else {
+ catalog.getTableMetadata(tableIdent).properties
+ }
}
// unset table properties
sql("ALTER TABLE dbx.tab1 SET TBLPROPERTIES ('j' = 'am', 'p' = 'an', 'c' = 'lan', 'x' = 'y')")
@@ -1079,7 +1162,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
assert(getProps == Map("x" -> "y"))
}
- private def testSetLocation(isDatasourceTable: Boolean): Unit = {
+ protected def testSetLocation(isDatasourceTable: Boolean): Unit = {
val catalog = spark.sessionState.catalog
val tableIdent = TableIdentifier("tab1", Some("dbx"))
val partSpec = Map("a" -> "1", "b" -> "2")
@@ -1090,38 +1173,35 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
convertToDatasourceTable(catalog, tableIdent)
}
assert(catalog.getTableMetadata(tableIdent).storage.locationUri.isDefined)
- assert(catalog.getTableMetadata(tableIdent).storage.properties.isEmpty)
+ assert(normalizeSerdeProp(catalog.getTableMetadata(tableIdent).storage.properties).isEmpty)
assert(catalog.getPartition(tableIdent, partSpec).storage.locationUri.isDefined)
- assert(catalog.getPartition(tableIdent, partSpec).storage.properties.isEmpty)
+ assert(
+ normalizeSerdeProp(catalog.getPartition(tableIdent, partSpec).storage.properties).isEmpty)
+
// Verify that the location is set to the expected string
- def verifyLocation(expected: String, spec: Option[TablePartitionSpec] = None): Unit = {
+ def verifyLocation(expected: URI, spec: Option[TablePartitionSpec] = None): Unit = {
val storageFormat = spec
.map { s => catalog.getPartition(tableIdent, s).storage }
.getOrElse { catalog.getTableMetadata(tableIdent).storage }
- if (isDatasourceTable) {
- if (spec.isDefined) {
- assert(storageFormat.properties.isEmpty)
- assert(storageFormat.locationUri === Some(expected))
- } else {
- assert(storageFormat.locationUri === Some(expected))
- }
- } else {
- assert(storageFormat.locationUri === Some(expected))
- }
+ // TODO(gatorsmile): fix the bug in alter table set location.
+ // if (isUsingHiveMetastore) {
+ // assert(storageFormat.properties.get("path") === expected)
+ // }
+ assert(storageFormat.locationUri === Some(expected))
}
// set table location
sql("ALTER TABLE dbx.tab1 SET LOCATION '/path/to/your/lovely/heart'")
- verifyLocation("/path/to/your/lovely/heart")
+ verifyLocation(new URI("/path/to/your/lovely/heart"))
// set table partition location
sql("ALTER TABLE dbx.tab1 PARTITION (a='1', b='2') SET LOCATION '/path/to/part/ways'")
- verifyLocation("/path/to/part/ways", Some(partSpec))
+ verifyLocation(new URI("/path/to/part/ways"), Some(partSpec))
// set table location without explicitly specifying database
catalog.setCurrentDatabase("dbx")
sql("ALTER TABLE tab1 SET LOCATION '/swanky/steak/place'")
- verifyLocation("/swanky/steak/place")
+ verifyLocation(new URI("/swanky/steak/place"))
// set table partition location without explicitly specifying database
sql("ALTER TABLE tab1 PARTITION (a='1', b='2') SET LOCATION 'vienna'")
- verifyLocation("vienna", Some(partSpec))
+ verifyLocation(new URI("vienna"), Some(partSpec))
// table to alter does not exist
intercept[AnalysisException] {
sql("ALTER TABLE dbx.does_not_exist SET LOCATION '/mister/spark'")
@@ -1132,7 +1212,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
}
}
- private def testSetSerde(isDatasourceTable: Boolean): Unit = {
+ protected def testSetSerde(isDatasourceTable: Boolean): Unit = {
val catalog = spark.sessionState.catalog
val tableIdent = TableIdentifier("tab1", Some("dbx"))
createDatabase(catalog, "dbx")
@@ -1140,8 +1220,21 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
if (isDatasourceTable) {
convertToDatasourceTable(catalog, tableIdent)
}
- assert(catalog.getTableMetadata(tableIdent).storage.serde.isEmpty)
- assert(catalog.getTableMetadata(tableIdent).storage.properties.isEmpty)
+ def checkSerdeProps(expectedSerdeProps: Map[String, String]): Unit = {
+ val serdeProp = catalog.getTableMetadata(tableIdent).storage.properties
+ if (isUsingHiveMetastore) {
+ assert(normalizeSerdeProp(serdeProp) == expectedSerdeProps)
+ } else {
+ assert(serdeProp == expectedSerdeProps)
+ }
+ }
+ if (isUsingHiveMetastore) {
+ assert(catalog.getTableMetadata(tableIdent).storage.serde ==
+ Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe"))
+ } else {
+ assert(catalog.getTableMetadata(tableIdent).storage.serde.isEmpty)
+ }
+ checkSerdeProps(Map.empty[String, String])
// set table serde and/or properties (should fail on datasource tables)
if (isDatasourceTable) {
val e1 = intercept[AnalysisException] {
@@ -1154,31 +1247,30 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
assert(e1.getMessage.contains("datasource"))
assert(e2.getMessage.contains("datasource"))
} else {
- sql("ALTER TABLE dbx.tab1 SET SERDE 'org.apache.jadoop'")
- assert(catalog.getTableMetadata(tableIdent).storage.serde == Some("org.apache.jadoop"))
- assert(catalog.getTableMetadata(tableIdent).storage.properties.isEmpty)
- sql("ALTER TABLE dbx.tab1 SET SERDE 'org.apache.madoop' " +
+ val newSerde = "org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe"
+ sql(s"ALTER TABLE dbx.tab1 SET SERDE '$newSerde'")
+ assert(catalog.getTableMetadata(tableIdent).storage.serde == Some(newSerde))
+ checkSerdeProps(Map.empty[String, String])
+ val serde2 = "org.apache.hadoop.hive.serde2.columnar.LazyBinaryColumnarSerDe"
+ sql(s"ALTER TABLE dbx.tab1 SET SERDE '$serde2' " +
"WITH SERDEPROPERTIES ('k' = 'v', 'kay' = 'vee')")
- assert(catalog.getTableMetadata(tableIdent).storage.serde == Some("org.apache.madoop"))
- assert(catalog.getTableMetadata(tableIdent).storage.properties ==
- Map("k" -> "v", "kay" -> "vee"))
+ assert(catalog.getTableMetadata(tableIdent).storage.serde == Some(serde2))
+ checkSerdeProps(Map("k" -> "v", "kay" -> "vee"))
}
// set serde properties only
sql("ALTER TABLE dbx.tab1 SET SERDEPROPERTIES ('k' = 'vvv', 'kay' = 'vee')")
- assert(catalog.getTableMetadata(tableIdent).storage.properties ==
- Map("k" -> "vvv", "kay" -> "vee"))
+ checkSerdeProps(Map("k" -> "vvv", "kay" -> "vee"))
// set things without explicitly specifying database
catalog.setCurrentDatabase("dbx")
sql("ALTER TABLE tab1 SET SERDEPROPERTIES ('kay' = 'veee')")
- assert(catalog.getTableMetadata(tableIdent).storage.properties ==
- Map("k" -> "vvv", "kay" -> "veee"))
+ checkSerdeProps(Map("k" -> "vvv", "kay" -> "veee"))
// table to alter does not exist
intercept[AnalysisException] {
sql("ALTER TABLE does_not_exist SET SERDEPROPERTIES ('x' = 'y')")
}
}
- private def testSetSerdePartition(isDatasourceTable: Boolean): Unit = {
+ protected def testSetSerdePartition(isDatasourceTable: Boolean): Unit = {
val catalog = spark.sessionState.catalog
val tableIdent = TableIdentifier("tab1", Some("dbx"))
val spec = Map("a" -> "1", "b" -> "2")
@@ -1191,8 +1283,21 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
if (isDatasourceTable) {
convertToDatasourceTable(catalog, tableIdent)
}
- assert(catalog.getPartition(tableIdent, spec).storage.serde.isEmpty)
- assert(catalog.getPartition(tableIdent, spec).storage.properties.isEmpty)
+ def checkPartitionSerdeProps(expectedSerdeProps: Map[String, String]): Unit = {
+ val serdeProp = catalog.getPartition(tableIdent, spec).storage.properties
+ if (isUsingHiveMetastore) {
+ assert(normalizeSerdeProp(serdeProp) == expectedSerdeProps)
+ } else {
+ assert(serdeProp == expectedSerdeProps)
+ }
+ }
+ if (isUsingHiveMetastore) {
+ assert(catalog.getPartition(tableIdent, spec).storage.serde ==
+ Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe"))
+ } else {
+ assert(catalog.getPartition(tableIdent, spec).storage.serde.isEmpty)
+ }
+ checkPartitionSerdeProps(Map.empty[String, String])
// set table serde and/or properties (should fail on datasource tables)
if (isDatasourceTable) {
val e1 = intercept[AnalysisException] {
@@ -1207,26 +1312,23 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
} else {
sql("ALTER TABLE dbx.tab1 PARTITION (a=1, b=2) SET SERDE 'org.apache.jadoop'")
assert(catalog.getPartition(tableIdent, spec).storage.serde == Some("org.apache.jadoop"))
- assert(catalog.getPartition(tableIdent, spec).storage.properties.isEmpty)
+ checkPartitionSerdeProps(Map.empty[String, String])
sql("ALTER TABLE dbx.tab1 PARTITION (a=1, b=2) SET SERDE 'org.apache.madoop' " +
"WITH SERDEPROPERTIES ('k' = 'v', 'kay' = 'vee')")
assert(catalog.getPartition(tableIdent, spec).storage.serde == Some("org.apache.madoop"))
- assert(catalog.getPartition(tableIdent, spec).storage.properties ==
- Map("k" -> "v", "kay" -> "vee"))
+ checkPartitionSerdeProps(Map("k" -> "v", "kay" -> "vee"))
}
// set serde properties only
maybeWrapException(isDatasourceTable) {
sql("ALTER TABLE dbx.tab1 PARTITION (a=1, b=2) " +
"SET SERDEPROPERTIES ('k' = 'vvv', 'kay' = 'vee')")
- assert(catalog.getPartition(tableIdent, spec).storage.properties ==
- Map("k" -> "vvv", "kay" -> "vee"))
+ checkPartitionSerdeProps(Map("k" -> "vvv", "kay" -> "vee"))
}
// set things without explicitly specifying database
catalog.setCurrentDatabase("dbx")
maybeWrapException(isDatasourceTable) {
sql("ALTER TABLE tab1 PARTITION (a=1, b=2) SET SERDEPROPERTIES ('kay' = 'veee')")
- assert(catalog.getPartition(tableIdent, spec).storage.properties ==
- Map("k" -> "vvv", "kay" -> "veee"))
+ checkPartitionSerdeProps(Map("k" -> "vvv", "kay" -> "veee"))
}
// table to alter does not exist
intercept[AnalysisException] {
@@ -1234,7 +1336,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
}
}
- private def testAddPartitions(isDatasourceTable: Boolean): Unit = {
+ protected def testAddPartitions(isDatasourceTable: Boolean): Unit = {
val catalog = spark.sessionState.catalog
val tableIdent = TableIdentifier("tab1", Some("dbx"))
val part1 = Map("a" -> "1", "b" -> "5")
@@ -1255,7 +1357,15 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
"PARTITION (a='2', b='6') LOCATION 'paris' PARTITION (a='3', b='7')")
assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part1, part2, part3))
assert(catalog.getPartition(tableIdent, part1).storage.locationUri.isDefined)
- assert(catalog.getPartition(tableIdent, part2).storage.locationUri == Option("paris"))
+ val partitionLocation = if (isUsingHiveMetastore) {
+ val tableLocation = catalog.getTableMetadata(tableIdent).storage.locationUri
+ assert(tableLocation.isDefined)
+ makeQualifiedPath(new Path(tableLocation.get.toString, "paris").toString)
+ } else {
+ new URI("paris")
+ }
+
+ assert(catalog.getPartition(tableIdent, part2).storage.locationUri == Option(partitionLocation))
assert(catalog.getPartition(tableIdent, part3).storage.locationUri.isDefined)
// add partitions without explicitly specifying database
@@ -1285,7 +1395,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
Set(part1, part2, part3, part4, part5))
}
- private def testDropPartitions(isDatasourceTable: Boolean): Unit = {
+ protected def testDropPartitions(isDatasourceTable: Boolean): Unit = {
val catalog = spark.sessionState.catalog
val tableIdent = TableIdentifier("tab1", Some("dbx"))
val part1 = Map("a" -> "1", "b" -> "5")
@@ -1338,7 +1448,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
assert(catalog.listPartitions(tableIdent).isEmpty)
}
- private def testRenamePartitions(isDatasourceTable: Boolean): Unit = {
+ protected def testRenamePartitions(isDatasourceTable: Boolean): Unit = {
val catalog = spark.sessionState.catalog
val tableIdent = TableIdentifier("tab1", Some("dbx"))
val part1 = Map("a" -> "1", "b" -> "q")
@@ -1382,7 +1492,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
Set(Map("a" -> "1", "b" -> "p"), Map("a" -> "20", "b" -> "c"), Map("a" -> "3", "b" -> "p")))
}
- private def testChangeColumn(isDatasourceTable: Boolean): Unit = {
+ protected def testChangeColumn(isDatasourceTable: Boolean): Unit = {
val catalog = spark.sessionState.catalog
val resolver = spark.sessionState.conf.resolver
val tableIdent = TableIdentifier("tab1", Some("dbx"))
@@ -1482,35 +1592,6 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
)
}
- test("create a managed Hive source table") {
- assume(spark.sparkContext.conf.get(CATALOG_IMPLEMENTATION) == "in-memory")
- val tabName = "tbl"
- withTable(tabName) {
- val e = intercept[AnalysisException] {
- sql(s"CREATE TABLE $tabName (i INT, j STRING)")
- }.getMessage
- assert(e.contains("Hive support is required to CREATE Hive TABLE"))
- }
- }
-
- test("create an external Hive source table") {
- assume(spark.sparkContext.conf.get(CATALOG_IMPLEMENTATION) == "in-memory")
- withTempDir { tempDir =>
- val tabName = "tbl"
- withTable(tabName) {
- val e = intercept[AnalysisException] {
- sql(
- s"""
- |CREATE EXTERNAL TABLE $tabName (i INT, j STRING)
- |ROW FORMAT DELIMITED FIELDS TERMINATED BY ','
- |LOCATION '${tempDir.toURI}'
- """.stripMargin)
- }.getMessage
- assert(e.contains("Hive support is required to CREATE Hive TABLE"))
- }
- }
- }
-
test("create a data source table without schema") {
import testImplicits._
withTempPath { tempDir =>
@@ -1549,22 +1630,6 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
}
}
- test("Create Hive Table As Select") {
- import testImplicits._
- withTable("t", "t1") {
- var e = intercept[AnalysisException] {
- sql("CREATE TABLE t SELECT 1 as a, 1 as b")
- }.getMessage
- assert(e.contains("Hive support is required to CREATE Hive TABLE (AS SELECT)"))
-
- spark.range(1).select('id as 'a, 'id as 'b).write.saveAsTable("t1")
- e = intercept[AnalysisException] {
- sql("CREATE TABLE t SELECT a, b from t1")
- }.getMessage
- assert(e.contains("Hive support is required to CREATE Hive TABLE (AS SELECT)"))
- }
- }
-
test("Create Data Source Table As Select") {
import testImplicits._
withTable("t", "t1", "t2") {
@@ -1588,7 +1653,8 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
}
test("drop default database") {
- Seq("true", "false").foreach { caseSensitive =>
+ val caseSensitiveOptions = if (isUsingHiveMetastore) Seq("false") else Seq("true", "false")
+ caseSensitiveOptions.foreach { caseSensitive =>
withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive) {
var message = intercept[AnalysisException] {
sql("DROP DATABASE default")
@@ -1819,7 +1885,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
// SET LOCATION won't move data from previous table path to new table path.
assert(spark.table("tbl").count() == 0)
// the previous table path should still be there.
- assert(new File(new URI(defaultTablePath)).exists())
+ assert(new File(defaultTablePath).exists())
sql("INSERT INTO tbl SELECT 2")
checkAnswer(spark.table("tbl"), Row(2))
@@ -1833,7 +1899,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
}
}
- test("insert data to a data source table which has a not existed location should succeed") {
+ test("insert data to a data source table which has a non-existing location should succeed") {
withTable("t") {
withTempDir { dir =>
spark.sql(
@@ -1843,28 +1909,27 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
|OPTIONS(path "$dir")
""".stripMargin)
val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t"))
- assert(table.location == dir.getAbsolutePath)
+ assert(table.location == makeQualifiedPath(dir.getAbsolutePath))
dir.delete
- val tableLocFile = new File(table.location)
- assert(!tableLocFile.exists)
+ assert(!dir.exists)
spark.sql("INSERT INTO TABLE t SELECT 'c', 1")
- assert(tableLocFile.exists)
+ assert(dir.exists)
checkAnswer(spark.table("t"), Row("c", 1) :: Nil)
Utils.deleteRecursively(dir)
- assert(!tableLocFile.exists)
+ assert(!dir.exists)
spark.sql("INSERT OVERWRITE TABLE t SELECT 'c', 1")
- assert(tableLocFile.exists)
+ assert(dir.exists)
checkAnswer(spark.table("t"), Row("c", 1) :: Nil)
val newDirFile = new File(dir, "x")
- val newDir = newDirFile.toURI.toString
+ val newDir = newDirFile.getAbsolutePath
spark.sql(s"ALTER TABLE t SET LOCATION '$newDir'")
spark.sessionState.catalog.refreshTable(TableIdentifier("t"))
val table1 = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t"))
- assert(table1.location == newDir)
+ assert(table1.location == new URI(newDir))
assert(!newDirFile.exists)
spark.sql("INSERT INTO TABLE t SELECT 'c', 1")
@@ -1874,7 +1939,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
}
}
- test("insert into a data source table with no existed partition location should succeed") {
+ test("insert into a data source table with a non-existing partition location should succeed") {
withTable("t") {
withTempDir { dir =>
spark.sql(
@@ -1885,7 +1950,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
|LOCATION "$dir"
""".stripMargin)
val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t"))
- assert(table.location == dir.getAbsolutePath)
+ assert(table.location == makeQualifiedPath(dir.getAbsolutePath))
spark.sql("INSERT INTO TABLE t PARTITION(a=1, b=2) SELECT 3, 4")
checkAnswer(spark.table("t"), Row(3, 4, 1, 2) :: Nil)
@@ -1901,7 +1966,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
}
}
- test("read data from a data source table which has a not existed location should succeed") {
+ test("read data from a data source table which has a non-existing location should succeed") {
withTable("t") {
withTempDir { dir =>
spark.sql(
@@ -1911,13 +1976,14 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
|OPTIONS(path "$dir")
""".stripMargin)
val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t"))
- assert(table.location == dir.getAbsolutePath)
+
+ assert(table.location == makeQualifiedPath(dir.getAbsolutePath))
dir.delete()
checkAnswer(spark.table("t"), Nil)
val newDirFile = new File(dir, "x")
- val newDir = newDirFile.toURI.toString
+ val newDir = newDirFile.toURI
spark.sql(s"ALTER TABLE t SET LOCATION '$newDir'")
val table1 = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t"))
@@ -1928,7 +1994,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
}
}
- test("read data from a data source table with no existed partition location should succeed") {
+ test("read data from a data source table with non-existing partition location should succeed") {
withTable("t") {
withTempDir { dir =>
spark.sql(
@@ -1950,50 +2016,204 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
}
}
+ test("create datasource table with a non-existing location") {
+ withTable("t", "t1") {
+ withTempPath { dir =>
+ spark.sql(s"CREATE TABLE t(a int, b int) USING parquet LOCATION '$dir'")
+
+ val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t"))
+ assert(table.location == makeQualifiedPath(dir.getAbsolutePath))
+
+ spark.sql("INSERT INTO TABLE t SELECT 1, 2")
+ assert(dir.exists())
+
+ checkAnswer(spark.table("t"), Row(1, 2))
+ }
+ // partition table
+ withTempPath { dir =>
+ spark.sql(s"CREATE TABLE t1(a int, b int) USING parquet PARTITIONED BY(a) LOCATION '$dir'")
+
+ val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t1"))
+ assert(table.location == makeQualifiedPath(dir.getAbsolutePath))
+
+ spark.sql("INSERT INTO TABLE t1 PARTITION(a=1) SELECT 2")
+
+ val partDir = new File(dir, "a=1")
+ assert(partDir.exists())
+
+ checkAnswer(spark.table("t1"), Row(2, 1))
+ }
+ }
+ }
+
Seq(true, false).foreach { shouldDelete =>
- val tcName = if (shouldDelete) "non-existent" else "existed"
+ val tcName = if (shouldDelete) "non-existing" else "existing"
test(s"CTAS for external data source table with a $tcName location") {
withTable("t", "t1") {
- withTempDir {
- dir =>
- if (shouldDelete) {
- dir.delete()
- }
- spark.sql(
- s"""
- |CREATE TABLE t
- |USING parquet
- |LOCATION '$dir'
- |AS SELECT 3 as a, 4 as b, 1 as c, 2 as d
- """.stripMargin)
- val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t"))
- assert(table.location == dir.getAbsolutePath)
-
- checkAnswer(spark.table("t"), Row(3, 4, 1, 2))
+ withTempDir { dir =>
+ if (shouldDelete) dir.delete()
+ spark.sql(
+ s"""
+ |CREATE TABLE t
+ |USING parquet
+ |LOCATION '$dir'
+ |AS SELECT 3 as a, 4 as b, 1 as c, 2 as d
+ """.stripMargin)
+ val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t"))
+ assert(table.location == makeQualifiedPath(dir.getAbsolutePath))
+
+ checkAnswer(spark.table("t"), Row(3, 4, 1, 2))
}
// partition table
- withTempDir {
- dir =>
- if (shouldDelete) {
- dir.delete()
- }
- spark.sql(
- s"""
- |CREATE TABLE t1
- |USING parquet
- |PARTITIONED BY(a, b)
- |LOCATION '$dir'
- |AS SELECT 3 as a, 4 as b, 1 as c, 2 as d
- """.stripMargin)
- val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t1"))
- assert(table.location == dir.getAbsolutePath)
-
- val partDir = new File(dir, "a=3")
- assert(partDir.exists())
-
- checkAnswer(spark.table("t1"), Row(1, 2, 3, 4))
+ withTempDir { dir =>
+ if (shouldDelete) dir.delete()
+ spark.sql(
+ s"""
+ |CREATE TABLE t1
+ |USING parquet
+ |PARTITIONED BY(a, b)
+ |LOCATION '$dir'
+ |AS SELECT 3 as a, 4 as b, 1 as c, 2 as d
+ """.stripMargin)
+ val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t1"))
+ assert(table.location == makeQualifiedPath(dir.getAbsolutePath))
+
+ val partDir = new File(dir, "a=3")
+ assert(partDir.exists())
+
+ checkAnswer(spark.table("t1"), Row(1, 2, 3, 4))
+ }
+ }
+ }
+ }
+
+ Seq("a b", "a:b", "a%b", "a,b").foreach { specialChars =>
+ test(s"data source table:partition column name containing $specialChars") {
+ withTable("t") {
+ withTempDir { dir =>
+ spark.sql(
+ s"""
+ |CREATE TABLE t(a string, `$specialChars` string)
+ |USING parquet
+ |PARTITIONED BY(`$specialChars`)
+ |LOCATION '$dir'
+ """.stripMargin)
+
+ assert(dir.listFiles().isEmpty)
+ spark.sql(s"INSERT INTO TABLE t PARTITION(`$specialChars`=2) SELECT 1")
+ val partEscaped = s"${ExternalCatalogUtils.escapePathName(specialChars)}=2"
+ val partFile = new File(dir, partEscaped)
+ assert(partFile.listFiles().length >= 1)
+ checkAnswer(spark.table("t"), Row("1", "2") :: Nil)
+ }
+ }
+ }
+ }
+
+ Seq("a b", "a:b", "a%b").foreach { specialChars =>
+ test(s"location uri contains $specialChars for datasource table") {
+ withTable("t", "t1") {
+ withTempDir { dir =>
+ val loc = new File(dir, specialChars)
+ loc.mkdir()
+ spark.sql(
+ s"""
+ |CREATE TABLE t(a string)
+ |USING parquet
+ |LOCATION '$loc'
+ """.stripMargin)
+
+ val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t"))
+ assert(table.location == makeQualifiedPath(loc.getAbsolutePath))
+ assert(new Path(table.location).toString.contains(specialChars))
+
+ assert(loc.listFiles().isEmpty)
+ spark.sql("INSERT INTO TABLE t SELECT 1")
+ assert(loc.listFiles().length >= 1)
+ checkAnswer(spark.table("t"), Row("1") :: Nil)
+ }
+
+ withTempDir { dir =>
+ val loc = new File(dir, specialChars)
+ loc.mkdir()
+ spark.sql(
+ s"""
+ |CREATE TABLE t1(a string, b string)
+ |USING parquet
+ |PARTITIONED BY(b)
+ |LOCATION '$loc'
+ """.stripMargin)
+
+ val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t1"))
+ assert(table.location == makeQualifiedPath(loc.getAbsolutePath))
+ assert(new Path(table.location).toString.contains(specialChars))
+
+ assert(loc.listFiles().isEmpty)
+ spark.sql("INSERT INTO TABLE t1 PARTITION(b=2) SELECT 1")
+ val partFile = new File(loc, "b=2")
+ assert(partFile.listFiles().length >= 1)
+ checkAnswer(spark.table("t1"), Row("1", "2") :: Nil)
+
+ spark.sql("INSERT INTO TABLE t1 PARTITION(b='2017-03-03 12:13%3A14') SELECT 1")
+ val partFile1 = new File(loc, "b=2017-03-03 12:13%3A14")
+ assert(!partFile1.exists())
+ val partFile2 = new File(loc, "b=2017-03-03 12%3A13%253A14")
+ assert(partFile2.listFiles().length >= 1)
+ checkAnswer(spark.table("t1"), Row("1", "2") :: Row("1", "2017-03-03 12:13%3A14") :: Nil)
}
}
}
}
+
+ Seq("a b", "a:b", "a%b").foreach { specialChars =>
+ test(s"location uri contains $specialChars for database") {
+ try {
+ withTable("t") {
+ withTempDir { dir =>
+ val loc = new File(dir, specialChars)
+ spark.sql(s"CREATE DATABASE tmpdb LOCATION '$loc'")
+ spark.sql("USE tmpdb")
+
+ import testImplicits._
+ Seq(1).toDF("a").write.saveAsTable("t")
+ val tblloc = new File(loc, "t")
+ val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t"))
+ assert(table.location == makeQualifiedPath(tblloc.getAbsolutePath))
+ assert(tblloc.listFiles().nonEmpty)
+ }
+ }
+ } finally {
+ spark.sql("DROP DATABASE IF EXISTS tmpdb")
+ }
+ }
+ }
+
+ test("the qualified path of a datasource table is stored in the catalog") {
+ withTable("t", "t1") {
+ withTempDir { dir =>
+ assert(!dir.getAbsolutePath.startsWith("file:/"))
+ spark.sql(
+ s"""
+ |CREATE TABLE t(a string)
+ |USING parquet
+ |LOCATION '$dir'
+ """.stripMargin)
+ val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t"))
+ assert(table.location.toString.startsWith("file:/"))
+ }
+
+ withTempDir { dir =>
+ assert(!dir.getAbsolutePath.startsWith("file:/"))
+ spark.sql(
+ s"""
+ |CREATE TABLE t1(a string, b string)
+ |USING parquet
+ |PARTITIONED BY(b)
+ |LOCATION '$dir'
+ """.stripMargin)
+ val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t1"))
+ assert(table.location.toString.startsWith("file:/"))
+ }
+ }
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
index 56071803f685..4435e4df38ef 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
@@ -129,6 +129,22 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
verifyCars(cars, withHeader = true, checkTypes = true)
}
+ test("simple csv test with string dataset") {
+ val csvDataset = spark.read.text(testFile(carsFile)).as[String]
+ val cars = spark.read
+ .option("header", "true")
+ .option("inferSchema", "true")
+ .csv(csvDataset)
+
+ verifyCars(cars, withHeader = true, checkTypes = true)
+
+ val carsWithoutHeader = spark.read
+ .option("header", "false")
+ .csv(csvDataset)
+
+ verifyCars(carsWithoutHeader, withHeader = false, checkTypes = false)
+ }
+
test("test inferring booleans") {
val result = spark.read
.format("csv")
@@ -1077,17 +1093,26 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
}
}
- test("Empty file produces empty dataframe with empty schema - wholeFile option") {
- withTempPath { path =>
- path.createNewFile()
-
+ test("Empty file produces empty dataframe with empty schema") {
+ Seq(false, true).foreach { wholeFile =>
val df = spark.read.format("csv")
.option("header", true)
- .option("wholeFile", true)
- .load(path.getAbsolutePath)
+ .option("wholeFile", wholeFile)
+ .load(testFile(emptyFile))
assert(df.schema === spark.emptyDataFrame.schema)
checkAnswer(df, spark.emptyDataFrame)
}
}
+
+ test("Empty string dataset produces empty dataframe and keep user-defined schema") {
+ val df1 = spark.read.csv(spark.emptyDataset[String])
+ assert(df1.schema === spark.emptyDataFrame.schema)
+ checkAnswer(df1, spark.emptyDataFrame)
+
+ val schema = StructType(StructField("a", StringType) :: Nil)
+ val df2 = spark.read.schema(schema).csv(spark.emptyDataset[String])
+ assert(df2.schema === schema)
+ }
+
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala
index 8a980a7eb538..6aa940afbb2c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala
@@ -368,88 +368,6 @@ class ParquetSchemaSuite extends ParquetSchemaTest {
}
}
- test("merge with metastore schema") {
- // Field type conflict resolution
- assertResult(
- StructType(Seq(
- StructField("lowerCase", StringType),
- StructField("UPPERCase", DoubleType, nullable = false)))) {
-
- ParquetFileFormat.mergeMetastoreParquetSchema(
- StructType(Seq(
- StructField("lowercase", StringType),
- StructField("uppercase", DoubleType, nullable = false))),
-
- StructType(Seq(
- StructField("lowerCase", BinaryType),
- StructField("UPPERCase", IntegerType, nullable = true))))
- }
-
- // MetaStore schema is subset of parquet schema
- assertResult(
- StructType(Seq(
- StructField("UPPERCase", DoubleType, nullable = false)))) {
-
- ParquetFileFormat.mergeMetastoreParquetSchema(
- StructType(Seq(
- StructField("uppercase", DoubleType, nullable = false))),
-
- StructType(Seq(
- StructField("lowerCase", BinaryType),
- StructField("UPPERCase", IntegerType, nullable = true))))
- }
-
- // Metastore schema contains additional non-nullable fields.
- assert(intercept[Throwable] {
- ParquetFileFormat.mergeMetastoreParquetSchema(
- StructType(Seq(
- StructField("uppercase", DoubleType, nullable = false),
- StructField("lowerCase", BinaryType, nullable = false))),
-
- StructType(Seq(
- StructField("UPPERCase", IntegerType, nullable = true))))
- }.getMessage.contains("detected conflicting schemas"))
-
- // Conflicting non-nullable field names
- intercept[Throwable] {
- ParquetFileFormat.mergeMetastoreParquetSchema(
- StructType(Seq(StructField("lower", StringType, nullable = false))),
- StructType(Seq(StructField("lowerCase", BinaryType))))
- }
- }
-
- test("merge missing nullable fields from Metastore schema") {
- // Standard case: Metastore schema contains additional nullable fields not present
- // in the Parquet file schema.
- assertResult(
- StructType(Seq(
- StructField("firstField", StringType, nullable = true),
- StructField("secondField", StringType, nullable = true),
- StructField("thirdfield", StringType, nullable = true)))) {
- ParquetFileFormat.mergeMetastoreParquetSchema(
- StructType(Seq(
- StructField("firstfield", StringType, nullable = true),
- StructField("secondfield", StringType, nullable = true),
- StructField("thirdfield", StringType, nullable = true))),
- StructType(Seq(
- StructField("firstField", StringType, nullable = true),
- StructField("secondField", StringType, nullable = true))))
- }
-
- // Merge should fail if the Metastore contains any additional fields that are not
- // nullable.
- assert(intercept[Throwable] {
- ParquetFileFormat.mergeMetastoreParquetSchema(
- StructType(Seq(
- StructField("firstfield", StringType, nullable = true),
- StructField("secondfield", StringType, nullable = true),
- StructField("thirdfield", StringType, nullable = false))),
- StructType(Seq(
- StructField("firstField", StringType, nullable = true),
- StructField("secondField", StringType, nullable = true))))
- }.getMessage.contains("detected conflicting schemas"))
- }
-
test("schema merging failure error message") {
import testImplicits._
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
index 9c55357ab9bc..26c45e092dc6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
@@ -22,15 +22,12 @@ import scala.reflect.ClassTag
import org.apache.spark.AccumulatorSuite
import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession}
import org.apache.spark.sql.catalyst.expressions.{BitwiseAnd, BitwiseOr, Cast, Literal, ShiftLeft}
-import org.apache.spark.sql.catalyst.plans.logical.SubqueryAlias
-import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.execution.exchange.EnsureRequirements
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.sql.types.{LongType, ShortType}
-import org.apache.spark.util.Utils
/**
* Test various broadcast join operators.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala
index 75723d0abcfc..9742b3b2d5c2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala
@@ -18,7 +18,6 @@
package org.apache.spark.sql.internal
import java.io.File
-import java.net.URI
import org.scalatest.BeforeAndAfterEach
@@ -459,7 +458,7 @@ class CatalogSuite
options = Map("path" -> dir.getAbsolutePath))
val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t"))
assert(table.tableType == CatalogTableType.EXTERNAL)
- assert(table.storage.locationUri.get == dir.getAbsolutePath)
+ assert(table.storage.locationUri.get == makeQualifiedPath(dir.getAbsolutePath))
Seq((1)).toDF("i").write.insertInto("t")
assert(dir.exists() && dir.listFiles().nonEmpty)
@@ -481,7 +480,7 @@ class CatalogSuite
options = Map.empty[String, String])
val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t"))
assert(table.tableType == CatalogTableType.MANAGED)
- val tablePath = new File(new URI(table.storage.locationUri.get))
+ val tablePath = new File(table.storage.locationUri.get)
assert(tablePath.exists() && tablePath.listFiles().isEmpty)
Seq((1)).toDF("i").write.insertInto("t")
@@ -493,6 +492,25 @@ class CatalogSuite
}
}
- // TODO: add tests for the rest of them
+ test("clone Catalog") {
+ // verify that temp tables are cloned
+ assert(spark.catalog.listTables().collect().isEmpty)
+ createTempTable("my_temp_table")
+ assert(spark.catalog.listTables().collect().map(_.name).toSet == Set("my_temp_table"))
+
+ // inheritance
+ val forkedSession = spark.cloneSession()
+ assert(spark ne forkedSession)
+ assert(spark.catalog ne forkedSession.catalog)
+ assert(forkedSession.catalog.listTables().collect().map(_.name).toSet == Set("my_temp_table"))
+
+ // independence
+ dropTable("my_temp_table") // drop table in original session
+ assert(spark.catalog.listTables().collect().map(_.name).toSet == Set())
+ assert(forkedSession.catalog.listTables().collect().map(_.name).toSet == Set("my_temp_table"))
+ forkedSession.sessionState.catalog
+ .createTempView("fork_table", Range(1, 2, 3, 4), overrideIfExists = true)
+ assert(spark.catalog.listTables().collect().map(_.name).toSet == Set())
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfEntrySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfEntrySuite.scala
index 0e3a5ca9d71d..f2456c770406 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfEntrySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfEntrySuite.scala
@@ -187,4 +187,22 @@ class SQLConfEntrySuite extends SparkFunSuite {
}
assert(e2.getMessage === "The maximum size of the cache must not be negative")
}
+
+ test("clone SQLConf") {
+ val original = new SQLConf
+ val key = "spark.sql.SQLConfEntrySuite.clone"
+ assert(original.getConfString(key, "noentry") === "noentry")
+
+ // inheritance
+ original.setConfString(key, "orig")
+ val clone = original.clone()
+ assert(original ne clone)
+ assert(clone.getConfString(key, "noentry") === "orig")
+
+ // independence
+ clone.setConfString(key, "clone")
+ assert(original.getConfString(key, "noentry") === "orig")
+ original.setConfString(key, "dontcopyme")
+ assert(clone.getConfString(key, "noentry") === "clone")
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala
index 9082261af7b0..93f3efe2ccc4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala
@@ -92,7 +92,7 @@ abstract class BucketedWriteSuite extends QueryTest with SQLTestUtils {
def tableDir: File = {
val identifier = spark.sessionState.sqlParser.parseTableIdentifier("bucketed_table")
- new File(URI.create(spark.sessionState.catalog.defaultTablePath(identifier)))
+ new File(spark.sessionState.catalog.defaultTablePath(identifier))
}
/**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PathOptionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PathOptionSuite.scala
index faf9afc49a2d..60adee4599b0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PathOptionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PathOptionSuite.scala
@@ -17,10 +17,13 @@
package org.apache.spark.sql.sources
+import java.net.URI
+
import org.apache.hadoop.fs.Path
import org.apache.spark.sql.{DataFrame, SaveMode, SparkSession, SQLContext}
import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.catalog.CatalogUtils
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{IntegerType, Metadata, MetadataBuilder, StructType}
@@ -72,28 +75,29 @@ class PathOptionSuite extends DataSourceTest with SharedSQLContext {
|USING ${classOf[TestOptionsSource].getCanonicalName}
|OPTIONS (PATH '/tmp/path')
""".stripMargin)
- assert(getPathOption("src") == Some("/tmp/path"))
+ assert(getPathOption("src") == Some("file:/tmp/path"))
}
// should exist even if the path option is not specified when creating the table
withTable("src") {
sql(s"CREATE TABLE src(i int) USING ${classOf[TestOptionsSource].getCanonicalName}")
- assert(getPathOption("src") == Some(defaultTablePath("src")))
+ assert(getPathOption("src") == Some(CatalogUtils.URIToString(defaultTablePath("src"))))
}
}
test("path option also exist for write path") {
withTable("src") {
withTempPath { p =>
- val path = new Path(p.getAbsolutePath).toString
sql(
s"""
|CREATE TABLE src
|USING ${classOf[TestOptionsSource].getCanonicalName}
- |OPTIONS (PATH '$path')
+ |OPTIONS (PATH '$p')
|AS SELECT 1
""".stripMargin)
- assert(spark.table("src").schema.head.metadata.getString("path") == path)
+ assert(CatalogUtils.stringToURI(
+ spark.table("src").schema.head.metadata.getString("path")) ==
+ makeQualifiedPath(p.getAbsolutePath))
}
}
@@ -105,7 +109,8 @@ class PathOptionSuite extends DataSourceTest with SharedSQLContext {
|USING ${classOf[TestOptionsSource].getCanonicalName}
|AS SELECT 1
""".stripMargin)
- assert(spark.table("src").schema.head.metadata.getString("path") == defaultTablePath("src"))
+ assert(spark.table("src").schema.head.metadata.getString("path") ==
+ CatalogUtils.URIToString(defaultTablePath("src")))
}
}
@@ -123,7 +128,7 @@ class PathOptionSuite extends DataSourceTest with SharedSQLContext {
withTable("src", "src2") {
sql(s"CREATE TABLE src(i int) USING ${classOf[TestOptionsSource].getCanonicalName}")
sql("ALTER TABLE src RENAME TO src2")
- assert(getPathOption("src2") == Some(defaultTablePath("src2")))
+ assert(getPathOption("src2") == Some(CatalogUtils.URIToString(defaultTablePath("src2"))))
}
}
@@ -133,7 +138,7 @@ class PathOptionSuite extends DataSourceTest with SharedSQLContext {
}.head
}
- private def defaultTablePath(tableName: String): String = {
+ private def defaultTablePath(tableName: String): URI = {
spark.sessionState.catalog.defaultTablePath(TableIdentifier(tableName))
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala
index 7ea716231e5d..a15c2cff930f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala
@@ -249,4 +249,23 @@ class DeduplicateSuite extends StateStoreMetricsTest with BeforeAndAfterAll {
}
}
}
+
+ test("SPARK-19841: watermarkPredicate should filter based on keys") {
+ val input = MemoryStream[(Int, Int)]
+ val df = input.toDS.toDF("time", "id")
+ .withColumn("time", $"time".cast("timestamp"))
+ .withWatermark("time", "1 second")
+ .dropDuplicates("id", "time") // Change the column positions
+ .select($"id")
+ testStream(df)(
+ AddData(input, 1 -> 1, 1 -> 1, 1 -> 2),
+ CheckLastBatch(1, 2),
+ AddData(input, 1 -> 1, 2 -> 3, 2 -> 4),
+ CheckLastBatch(3, 4),
+ AddData(input, 1 -> 0, 1 -> 1, 3 -> 5, 3 -> 6), // Drop (1 -> 0, 1 -> 1) due to watermark
+ CheckLastBatch(5, 6),
+ AddData(input, 1 -> 0, 4 -> 7), // Drop (1 -> 0) due to watermark
+ CheckLastBatch(7)
+ )
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala
index c34d119734cc..7614ea5eb3c0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala
@@ -25,6 +25,7 @@ import org.scalatest.BeforeAndAfter
import org.apache.spark.internal.Logging
import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.functions.{count, window}
import org.apache.spark.sql.streaming.OutputMode._
@@ -305,6 +306,42 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Loggin
)
}
+ test("delay threshold should not be negative.") {
+ val inputData = MemoryStream[Int].toDF()
+ var e = intercept[IllegalArgumentException] {
+ inputData.withWatermark("value", "-1 year")
+ }
+ assert(e.getMessage contains "should not be negative.")
+
+ e = intercept[IllegalArgumentException] {
+ inputData.withWatermark("value", "1 year -13 months")
+ }
+ assert(e.getMessage contains "should not be negative.")
+
+ e = intercept[IllegalArgumentException] {
+ inputData.withWatermark("value", "1 month -40 days")
+ }
+ assert(e.getMessage contains "should not be negative.")
+
+ e = intercept[IllegalArgumentException] {
+ inputData.withWatermark("value", "-10 seconds")
+ }
+ assert(e.getMessage contains "should not be negative.")
+ }
+
+ test("the new watermark should override the old one") {
+ val df = MemoryStream[(Long, Long)].toDF()
+ .withColumn("first", $"_1".cast("timestamp"))
+ .withColumn("second", $"_2".cast("timestamp"))
+ .withWatermark("first", "1 minute")
+ .withWatermark("second", "2 minutes")
+
+ val eventTimeColumns = df.logicalPlan.output
+ .filter(_.metadata.contains(EventTimeWatermark.delayKey))
+ assert(eventTimeColumns.size === 1)
+ assert(eventTimeColumns(0).name === "second")
+ }
+
private def assertNumStateRows(numTotalRows: Long): AssertOnQuery = AssertOnQuery { q =>
val progressWithData = q.recentProgress.filter(_.numInputRows > 0).lastOption.get
assert(progressWithData.stateOperators(0).numRowsTotal === numTotalRows)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala
index 1586850c77fc..f705da3d6a70 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala
@@ -1173,6 +1173,41 @@ class FileStreamSourceSuite extends FileStreamSourceTest {
SerializedOffset(str.trim)
}
+ private def runTwoBatchesAndVerifyResults(
+ src: File,
+ latestFirst: Boolean,
+ firstBatch: String,
+ secondBatch: String,
+ maxFileAge: Option[String] = None): Unit = {
+ val srcOptions = Map("latestFirst" -> latestFirst.toString, "maxFilesPerTrigger" -> "1") ++
+ maxFileAge.map("maxFileAge" -> _)
+ val fileStream = createFileStream(
+ "text",
+ src.getCanonicalPath,
+ options = srcOptions)
+ val clock = new StreamManualClock()
+ testStream(fileStream)(
+ StartStream(trigger = ProcessingTime(10), triggerClock = clock),
+ AssertOnQuery { _ =>
+ // Block until the first batch finishes.
+ eventually(timeout(streamingTimeout)) {
+ assert(clock.isStreamWaitingAt(0))
+ }
+ true
+ },
+ CheckLastBatch(firstBatch),
+ AdvanceManualClock(10),
+ AssertOnQuery { _ =>
+ // Block until the second batch finishes.
+ eventually(timeout(streamingTimeout)) {
+ assert(clock.isStreamWaitingAt(10))
+ }
+ true
+ },
+ CheckLastBatch(secondBatch)
+ )
+ }
+
test("FileStreamSource - latestFirst") {
withTempDir { src =>
// Prepare two files: 1.txt, 2.txt, and make sure they have different modified time.
@@ -1180,47 +1215,28 @@ class FileStreamSourceSuite extends FileStreamSourceTest {
val f2 = stringToFile(new File(src, "2.txt"), "2")
f2.setLastModified(f1.lastModified + 1000)
- def runTwoBatchesAndVerifyResults(
- latestFirst: Boolean,
- firstBatch: String,
- secondBatch: String): Unit = {
- val fileStream = createFileStream(
- "text",
- src.getCanonicalPath,
- options = Map("latestFirst" -> latestFirst.toString, "maxFilesPerTrigger" -> "1"))
- val clock = new StreamManualClock()
- testStream(fileStream)(
- StartStream(trigger = ProcessingTime(10), triggerClock = clock),
- AssertOnQuery { _ =>
- // Block until the first batch finishes.
- eventually(timeout(streamingTimeout)) {
- assert(clock.isStreamWaitingAt(0))
- }
- true
- },
- CheckLastBatch(firstBatch),
- AdvanceManualClock(10),
- AssertOnQuery { _ =>
- // Block until the second batch finishes.
- eventually(timeout(streamingTimeout)) {
- assert(clock.isStreamWaitingAt(10))
- }
- true
- },
- CheckLastBatch(secondBatch)
- )
- }
-
// Read oldest files first, so the first batch is "1", and the second batch is "2".
- runTwoBatchesAndVerifyResults(latestFirst = false, firstBatch = "1", secondBatch = "2")
+ runTwoBatchesAndVerifyResults(src, latestFirst = false, firstBatch = "1", secondBatch = "2")
// Read latest files first, so the first batch is "2", and the second batch is "1".
- runTwoBatchesAndVerifyResults(latestFirst = true, firstBatch = "2", secondBatch = "1")
+ runTwoBatchesAndVerifyResults(src, latestFirst = true, firstBatch = "2", secondBatch = "1")
+ }
+ }
+
+ test("SPARK-19813: Ignore maxFileAge when maxFilesPerTrigger and latestFirst is used") {
+ withTempDir { src =>
+ // Prepare two files: 1.txt, 2.txt, and make sure they have different modified time.
+ val f1 = stringToFile(new File(src, "1.txt"), "1")
+ val f2 = stringToFile(new File(src, "2.txt"), "2")
+ f2.setLastModified(f1.lastModified + 3600 * 1000 /* 1 hour later */)
+
+ runTwoBatchesAndVerifyResults(src, latestFirst = true, firstBatch = "2", secondBatch = "1",
+ maxFileAge = Some("1m") /* 1 minute */)
}
}
test("SeenFilesMap") {
- val map = new SeenFilesMap(maxAgeMs = 10)
+ val map = new SeenFilesMap(maxAgeMs = 10, fileNameOnly = false)
map.add("a", 5)
assert(map.size == 1)
@@ -1253,8 +1269,26 @@ class FileStreamSourceSuite extends FileStreamSourceTest {
assert(map.isNewFile("e", 20))
}
+ test("SeenFilesMap with fileNameOnly = true") {
+ val map = new SeenFilesMap(maxAgeMs = 10, fileNameOnly = true)
+
+ map.add("file:///a/b/c/d", 5)
+ map.add("file:///a/b/c/e", 5)
+ assert(map.size === 2)
+
+ assert(!map.isNewFile("d", 5))
+ assert(!map.isNewFile("file:///d", 5))
+ assert(!map.isNewFile("file:///x/d", 5))
+ assert(!map.isNewFile("file:///x/y/d", 5))
+
+ map.add("s3:///bucket/d", 5)
+ map.add("s3n:///bucket/d", 5)
+ map.add("s3a:///bucket/d", 5)
+ assert(map.size === 2)
+ }
+
test("SeenFilesMap should only consider a file old if it is earlier than last purge time") {
- val map = new SeenFilesMap(maxAgeMs = 10)
+ val map = new SeenFilesMap(maxAgeMs = 10, fileNameOnly = false)
map.add("a", 20)
assert(map.size == 1)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
similarity index 88%
rename from sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala
rename to sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
index 6cf4d51f9933..902b842e97aa 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.execution.streaming.state.StateStore
/** Class to check custom state types */
case class RunningCount(count: Long)
-class MapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAfterAll {
+class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAfterAll {
import testImplicits._
@@ -119,9 +119,9 @@ class MapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAfterA
val result =
inputData.toDS()
.groupByKey(x => x)
- .flatMapGroupsWithState(stateFunc) // State: Int, Out: (Str, Str)
+ .flatMapGroupsWithState(stateFunc, Update) // State: Int, Out: (Str, Str)
- testStream(result, Append)(
+ testStream(result, Update)(
AddData(inputData, "a"),
CheckLastBatch(("a", "1")),
assertNumStateRows(total = 1, updated = 1),
@@ -162,9 +162,9 @@ class MapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAfterA
val result =
inputData.toDS()
.groupByKey(x => x)
- .flatMapGroupsWithState(stateFunc) // State: Int, Out: (Str, Str)
+ .flatMapGroupsWithState(stateFunc, Update) // State: Int, Out: (Str, Str)
- testStream(result, Append)(
+ testStream(result, Update)(
AddData(inputData, "a", "a", "b"),
CheckLastBatch(("a", "1"), ("a", "2"), ("b", "1")),
StopStream,
@@ -185,7 +185,7 @@ class MapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAfterA
Iterator((key, values.size))
}
checkAnswer(
- Seq("a", "a", "b").toDS.groupByKey(x => x).flatMapGroupsWithState(stateFunc).toDF,
+ Seq("a", "a", "b").toDS.groupByKey(x => x).flatMapGroupsWithState(stateFunc, Update).toDF,
Seq(("a", 2), ("b", 1)).toDF)
}
@@ -210,7 +210,7 @@ class MapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAfterA
.groupByKey(x => x)
.mapGroupsWithState(stateFunc) // Types = State: MyState, Out: (Str, Str)
- testStream(result, Append)(
+ testStream(result, Update)(
AddData(inputData, "a"),
CheckLastBatch(("a", "1")),
assertNumStateRows(total = 1, updated = 1),
@@ -230,7 +230,7 @@ class MapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAfterA
)
}
- test("mapGroupsWithState - streaming + aggregation") {
+ test("flatMapGroupsWithState - streaming + aggregation") {
// Function to maintain running count up to 2, and then remove the count
// Returns the data and the count (-1 if count reached beyond 2 and state was just removed)
val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
@@ -238,10 +238,10 @@ class MapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAfterA
val count = state.getOption.map(_.count).getOrElse(0L) + values.size
if (count == 3) {
state.remove()
- (key, "-1")
+ Iterator(key -> "-1")
} else {
state.update(RunningCount(count))
- (key, count.toString)
+ Iterator(key -> count.toString)
}
}
@@ -249,7 +249,7 @@ class MapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAfterA
val result =
inputData.toDS()
.groupByKey(x => x)
- .mapGroupsWithState(stateFunc) // Types = State: MyState, Out: (Str, Str)
+ .flatMapGroupsWithState(stateFunc, Append) // Types = State: MyState, Out: (Str, Str)
.groupByKey(_._1)
.count()
@@ -290,7 +290,7 @@ class MapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAfterA
testQuietly("StateStore.abort on task failure handling") {
val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
- if (MapGroupsWithStateSuite.failInTask) throw new Exception("expected failure")
+ if (FlatMapGroupsWithStateSuite.failInTask) throw new Exception("expected failure")
val count = state.getOption.map(_.count).getOrElse(0L) + values.size
state.update(RunningCount(count))
(key, count)
@@ -303,11 +303,11 @@ class MapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAfterA
.mapGroupsWithState(stateFunc) // Types = State: MyState, Out: (Str, Str)
def setFailInTask(value: Boolean): AssertOnQuery = AssertOnQuery { q =>
- MapGroupsWithStateSuite.failInTask = value
+ FlatMapGroupsWithStateSuite.failInTask = value
true
}
- testStream(result, Append)(
+ testStream(result, Update)(
setFailInTask(false),
AddData(inputData, "a"),
CheckLastBatch(("a", 1L)),
@@ -321,8 +321,24 @@ class MapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAfterA
CheckLastBatch(("a", 3L)) // task should not fail, and should show correct count
)
}
+
+ test("disallow complete mode") {
+ val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
+ Iterator[String]()
+ }
+
+ var e = intercept[IllegalArgumentException] {
+ MemoryStream[String].toDS().groupByKey(x => x).flatMapGroupsWithState(stateFunc, Complete)
+ }
+ assert(e.getMessage === "The output mode of function should be append or update")
+
+ e = intercept[IllegalArgumentException] {
+ MemoryStream[String].toDS().groupByKey(x => x).flatMapGroupsWithState(stateFunc, "complete")
+ }
+ assert(e.getMessage === "The output mode of function should be append or update")
+ }
}
-object MapGroupsWithStateSuite {
+object FlatMapGroupsWithStateSuite {
var failInTask = true
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala
index 0470411a0f10..f61dcdcbcf71 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala
@@ -24,8 +24,7 @@ import scala.concurrent.duration._
import org.apache.hadoop.fs.Path
import org.mockito.Mockito._
-import org.scalatest.{BeforeAndAfter, PrivateMethodTester}
-import org.scalatest.PrivateMethodTester.PrivateMethod
+import org.scalatest.BeforeAndAfter
import org.apache.spark.sql._
import org.apache.spark.sql.execution.streaming._
@@ -107,7 +106,7 @@ class DefaultSource extends StreamSourceProvider with StreamSinkProvider {
}
}
-class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter with PrivateMethodTester {
+class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter {
private def newMetadataDir =
Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
@@ -390,42 +389,6 @@ class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter with Pr
private def newTextInput = Utils.createTempDir(namePrefix = "text").getCanonicalPath
- test("supported strings in outputMode(string)") {
- val outputModeMethod = PrivateMethod[OutputMode]('outputMode)
-
- def testMode(outputMode: String, expected: OutputMode): Unit = {
- val df = spark.readStream
- .format("org.apache.spark.sql.streaming.test")
- .load()
- val w = df.writeStream
- w.outputMode(outputMode)
- val setOutputMode = w invokePrivate outputModeMethod()
- assert(setOutputMode === expected)
- }
-
- testMode("append", OutputMode.Append)
- testMode("Append", OutputMode.Append)
- testMode("complete", OutputMode.Complete)
- testMode("Complete", OutputMode.Complete)
- testMode("update", OutputMode.Update)
- testMode("Update", OutputMode.Update)
- }
-
- test("unsupported strings in outputMode(string)") {
- def testMode(outputMode: String): Unit = {
- val acceptedModes = Seq("append", "update", "complete")
- val df = spark.readStream
- .format("org.apache.spark.sql.streaming.test")
- .load()
- val w = df.writeStream
- val e = intercept[IllegalArgumentException](w.outputMode(outputMode))
- (Seq("output mode", "unknown", outputMode) ++ acceptedModes).foreach { s =>
- assert(e.getMessage.toLowerCase.contains(s.toLowerCase))
- }
- }
- testMode("Xyz")
- }
-
test("check foreach() catches null writers") {
val df = spark.readStream
.format("org.apache.spark.sql.streaming.test")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
index d4afb9d8af6f..9201954b66d1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
@@ -18,13 +18,14 @@
package org.apache.spark.sql.test
import java.io.File
+import java.net.URI
import java.util.UUID
import scala.language.implicitConversions
import scala.util.Try
import scala.util.control.NonFatal
-import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.Path
import org.scalatest.BeforeAndAfterAll
import org.apache.spark.SparkFunSuite
@@ -294,6 +295,17 @@ private[sql] trait SQLTestUtils
test(name) { runOnThread() }
}
}
+
+ /**
+ * Makes the given path fully qualified. Qualifying a path that does not contain
+ * a scheme ensures it will not change if the default FileSystem is changed later.
+ */
+ def makeQualifiedPath(path: String): URI = {
+ val hadoopPath = new Path(path)
+ val fs = hadoopPath.getFileSystem(spark.sessionState.newHadoopConf())
+ fs.makeQualified(hadoopPath).toUri
+ }
}
private[sql] object SQLTestUtils {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala
index 8ab6db175da5..898a2fb4f329 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala
@@ -35,18 +35,16 @@ private[sql] class TestSparkSession(sc: SparkContext) extends SparkSession(sc) {
}
@transient
- override lazy val sessionState: SessionState = new SessionState(self) {
- override lazy val conf: SQLConf = {
- new SQLConf {
- clear()
- override def clear(): Unit = {
- super.clear()
- // Make sure we start with the default test configs even after clear
- TestSQLContext.overrideConfs.foreach { case (key, value) => setConfString(key, value) }
- }
+ override lazy val sessionState: SessionState = SessionState(
+ this,
+ new SQLConf {
+ clear()
+ override def clear(): Unit = {
+ super.clear()
+ // Make sure we start with the default test configs even after clear
+ TestSQLContext.overrideConfs.foreach { case (key, value) => setConfString(key, value) }
}
- }
- }
+ })
// Needed for Java tests
def loadTestData(): Unit = {
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala
index 43d9c2bec682..78aa2bd2494f 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala
@@ -210,7 +210,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat
tableDefinition.storage.locationUri.isEmpty
val tableLocation = if (needDefaultTableLocation) {
- Some(defaultTablePath(tableDefinition.identifier))
+ Some(CatalogUtils.stringToURI(defaultTablePath(tableDefinition.identifier)))
} else {
tableDefinition.storage.locationUri
}
@@ -260,7 +260,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat
// However, in older version of Spark we already store table location in storage properties
// with key "path". Here we keep this behaviour for backward compatibility.
val storagePropsWithLocation = table.storage.properties ++
- table.storage.locationUri.map("path" -> _)
+ table.storage.locationUri.map("path" -> CatalogUtils.URIToString(_))
// converts the table metadata to Spark SQL specific format, i.e. set data schema, names and
// bucket specification to empty. Note that partition columns are retained, so that we can
@@ -285,7 +285,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat
// compatible format, which means the data source is file-based and must have a `path`.
require(table.storage.locationUri.isDefined,
"External file-based data source table must have a `path` entry in storage properties.")
- Some(new Path(table.location).toUri.toString)
+ Some(table.location)
} else {
None
}
@@ -432,13 +432,13 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat
//
// Please refer to https://issues.apache.org/jira/browse/SPARK-15269 for more details.
val tempPath = {
- val dbLocation = getDatabase(tableDefinition.database).locationUri
+ val dbLocation = new Path(getDatabase(tableDefinition.database).locationUri)
new Path(dbLocation, tableDefinition.identifier.table + "-__PLACEHOLDER__")
}
try {
client.createTable(
- tableDefinition.withNewStorage(locationUri = Some(tempPath.toString)),
+ tableDefinition.withNewStorage(locationUri = Some(tempPath.toUri)),
ignoreIfExists)
} finally {
FileSystem.get(tempPath.toUri, hadoopConf).delete(tempPath, true)
@@ -563,7 +563,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat
// want to alter the table location to a file path, we will fail. This should be fixed
// in the future.
- val newLocation = tableDefinition.storage.locationUri
+ val newLocation = tableDefinition.storage.locationUri.map(CatalogUtils.URIToString(_))
val storageWithPathOption = tableDefinition.storage.copy(
properties = tableDefinition.storage.properties ++ newLocation.map("path" -> _))
@@ -597,6 +597,25 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat
}
}
+ override def alterTableSchema(db: String, table: String, schema: StructType): Unit = withClient {
+ requireTableExists(db, table)
+ val rawTable = getRawTable(db, table)
+ val withNewSchema = rawTable.copy(schema = schema)
+ // Add table metadata such as table schema, partition columns, etc. to table properties.
+ val updatedTable = withNewSchema.copy(
+ properties = withNewSchema.properties ++ tableMetaToTableProps(withNewSchema))
+ try {
+ client.alterTable(updatedTable)
+ } catch {
+ case NonFatal(e) =>
+ val warningMessage =
+ s"Could not alter schema of table ${rawTable.identifier.quotedString} in a Hive " +
+ "compatible way. Updating Hive metastore in Spark SQL specific format."
+ logWarning(warningMessage, e)
+ client.alterTable(updatedTable.copy(schema = updatedTable.partitionSchema))
+ }
+ }
+
override def getTable(db: String, table: String): CatalogTable = withClient {
restoreTableMetadata(getRawTable(db, table))
}
@@ -690,10 +709,10 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat
"different from the schema when this table was created by Spark SQL" +
s"(${schemaFromTableProps.simpleString}). We have to fall back to the table schema " +
"from Hive metastore which is not case preserving.")
- hiveTable
+ hiveTable.copy(schemaPreservesCase = false)
}
} else {
- hiveTable
+ hiveTable.copy(schemaPreservesCase = false)
}
}
@@ -704,7 +723,8 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat
val storageWithLocation = {
val tableLocation = getLocationFromStorageProps(table)
// We pass None as `newPath` here, to remove the path option in storage properties.
- updateLocationInStorageProps(table, newPath = None).copy(locationUri = tableLocation)
+ updateLocationInStorageProps(table, newPath = None).copy(
+ locationUri = tableLocation.map(CatalogUtils.stringToURI(_)))
}
val partitionProvider = table.properties.get(TABLE_PARTITION_PROVIDER)
@@ -848,10 +868,10 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat
// However, Hive metastore is not case preserving and will generate wrong partition location
// with lower cased partition column names. Here we set the default partition location
// manually to avoid this problem.
- val partitionPath = p.storage.locationUri.map(uri => new Path(new URI(uri))).getOrElse {
+ val partitionPath = p.storage.locationUri.map(uri => new Path(uri)).getOrElse {
ExternalCatalogUtils.generatePartitionPath(p.spec, partitionColumnNames, tablePath)
}
- p.copy(storage = p.storage.copy(locationUri = Some(partitionPath.toUri.toString)))
+ p.copy(storage = p.storage.copy(locationUri = Some(partitionPath.toUri)))
}
val lowerCasedParts = partsWithLocation.map(p => p.copy(spec = lowerCasePartitionSpec(p.spec)))
client.createPartitions(db, table, lowerCasedParts, ignoreIfExists)
@@ -890,7 +910,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat
val newParts = newSpecs.map { spec =>
val rightPath = renamePartitionDirectory(fs, tablePath, partitionColumnNames, spec)
val partition = client.getPartition(db, table, lowerCasePartitionSpec(spec))
- partition.copy(storage = partition.storage.copy(locationUri = Some(rightPath.toString)))
+ partition.copy(storage = partition.storage.copy(locationUri = Some(rightPath.toUri)))
}
alterPartitions(db, table, newParts)
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
index 151a69aebf1d..9f0d1ceb28fc 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
@@ -19,9 +19,12 @@ package org.apache.spark.sql.hive
import java.net.URI
+import scala.util.control.NonFatal
+
import com.google.common.util.concurrent.Striped
import org.apache.hadoop.fs.Path
+import org.apache.spark.SparkException
import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.{QualifiedTableName, TableIdentifier}
@@ -32,6 +35,7 @@ import org.apache.spark.sql.execution.command.DDLUtils
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat, ParquetOptions}
import org.apache.spark.sql.hive.orc.OrcFileFormat
+import org.apache.spark.sql.internal.SQLConf.HiveCaseSensitiveInferenceMode._
import org.apache.spark.sql.types._
/**
@@ -41,8 +45,10 @@ import org.apache.spark.sql.types._
* cleaned up to integrate more nicely with [[HiveExternalCatalog]].
*/
private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Logging {
- private val sessionState = sparkSession.sessionState.asInstanceOf[HiveSessionState]
- private lazy val tableRelationCache = sparkSession.sessionState.catalog.tableRelationCache
+ // these are defs and not vals/lazy vals since the latter would introduce circular references
+ private def sessionState = sparkSession.sessionState.asInstanceOf[HiveSessionState]
+ private def tableRelationCache = sparkSession.sessionState.catalog.tableRelationCache
+ import HiveMetastoreCatalog._
private def getCurrentDatabase: String = sessionState.catalog.getCurrentDatabase
@@ -128,7 +134,9 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log
QualifiedTableName(relation.tableMeta.database, relation.tableMeta.identifier.table)
val lazyPruningEnabled = sparkSession.sqlContext.conf.manageFilesourcePartitions
- val tablePath = new Path(new URI(relation.tableMeta.location))
+ val tablePath = new Path(relation.tableMeta.location)
+ val fileFormat = fileFormatClass.newInstance()
+
val result = if (relation.isPartitioned) {
val partitionSchema = relation.tableMeta.partitionSchema
val rootPaths: Seq[Path] = if (lazyPruningEnabled) {
@@ -141,7 +149,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log
// locations, _omitting_ the table's base path.
val paths = sparkSession.sharedState.externalCatalog
.listPartitions(tableIdentifier.database, tableIdentifier.name)
- .map(p => new Path(new URI(p.storage.locationUri.get)))
+ .map(p => new Path(p.storage.locationUri.get))
if (paths.isEmpty) {
Seq(tablePath)
@@ -169,16 +177,18 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log
}
}
+ val (dataSchema, updatedTable) =
+ inferIfNeeded(relation, options, fileFormat, Option(fileIndex))
+
val fsRelation = HadoopFsRelation(
location = fileIndex,
partitionSchema = partitionSchema,
- dataSchema = relation.tableMeta.dataSchema,
+ dataSchema = dataSchema,
// We don't support hive bucketed tables, only ones we write out.
bucketSpec = None,
- fileFormat = fileFormatClass.newInstance(),
+ fileFormat = fileFormat,
options = options)(sparkSession = sparkSession)
-
- val created = LogicalRelation(fsRelation, catalogTable = Some(relation.tableMeta))
+ val created = LogicalRelation(fsRelation, catalogTable = Some(updatedTable))
tableRelationCache.put(tableIdentifier, created)
created
}
@@ -195,17 +205,18 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log
fileFormatClass,
None)
val logicalRelation = cached.getOrElse {
+ val (dataSchema, updatedTable) = inferIfNeeded(relation, options, fileFormat)
val created =
LogicalRelation(
DataSource(
sparkSession = sparkSession,
paths = rootPath.toString :: Nil,
- userSpecifiedSchema = Some(metastoreSchema),
+ userSpecifiedSchema = Option(dataSchema),
// We don't support hive bucketed tables, only ones we write out.
bucketSpec = None,
options = options,
className = fileType).resolveRelation(),
- catalogTable = Some(relation.tableMeta))
+ catalogTable = Some(updatedTable))
tableRelationCache.put(tableIdentifier, created)
created
@@ -217,6 +228,54 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log
result.copy(expectedOutputAttributes = Some(relation.output))
}
+ private def inferIfNeeded(
+ relation: CatalogRelation,
+ options: Map[String, String],
+ fileFormat: FileFormat,
+ fileIndexOpt: Option[FileIndex] = None): (StructType, CatalogTable) = {
+ val inferenceMode = sparkSession.sessionState.conf.caseSensitiveInferenceMode
+ val shouldInfer = (inferenceMode != NEVER_INFER) && !relation.tableMeta.schemaPreservesCase
+ val tableName = relation.tableMeta.identifier.unquotedString
+ if (shouldInfer) {
+ logInfo(s"Inferring case-sensitive schema for table $tableName (inference mode: " +
+ s"$inferenceMode)")
+ val fileIndex = fileIndexOpt.getOrElse {
+ val rootPath = new Path(relation.tableMeta.location)
+ new InMemoryFileIndex(sparkSession, Seq(rootPath), options, None)
+ }
+
+ val inferredSchema = fileFormat
+ .inferSchema(
+ sparkSession,
+ options,
+ fileIndex.listFiles(Nil).flatMap(_.files))
+ .map(mergeWithMetastoreSchema(relation.tableMeta.schema, _))
+
+ inferredSchema match {
+ case Some(schema) =>
+ if (inferenceMode == INFER_AND_SAVE) {
+ updateCatalogSchema(relation.tableMeta.identifier, schema)
+ }
+ (schema, relation.tableMeta.copy(schema = schema))
+ case None =>
+ logWarning(s"Unable to infer schema for table $tableName from file format " +
+ s"$fileFormat (inference mode: $inferenceMode). Using metastore schema.")
+ (relation.tableMeta.schema, relation.tableMeta)
+ }
+ } else {
+ (relation.tableMeta.schema, relation.tableMeta)
+ }
+ }
+
+ private def updateCatalogSchema(identifier: TableIdentifier, schema: StructType): Unit = try {
+ val db = identifier.database.get
+ logInfo(s"Saving case-sensitive schema for table ${identifier.unquotedString}")
+ sparkSession.sharedState.externalCatalog.alterTableSchema(db, identifier.table, schema)
+ } catch {
+ case NonFatal(ex) =>
+ logWarning(s"Unable to save case-sensitive schema for table ${identifier.unquotedString}", ex)
+ }
+
/**
* When scanning or writing to non-partitioned Metastore Parquet tables, convert them to Parquet
* data source relations for better performance.
@@ -286,3 +345,30 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log
}
}
}
+
+private[hive] object HiveMetastoreCatalog {
+ def mergeWithMetastoreSchema(
+ metastoreSchema: StructType,
+ inferredSchema: StructType): StructType = try {
+ // Find any nullable fields in the metastore schema that are missing from the inferred schema.
+ val metastoreFields = metastoreSchema.map(f => f.name.toLowerCase -> f).toMap
+ val missingNullables = metastoreFields
+ .filterKeys(!inferredSchema.map(_.name.toLowerCase).contains(_))
+ .values
+ .filter(_.nullable)
+ // Merge missing nullable fields to inferred schema and build a case-insensitive field map.
+ val inferredFields = StructType(inferredSchema ++ missingNullables)
+ .map(f => f.name.toLowerCase -> f).toMap
+ StructType(metastoreSchema.map(f => f.copy(name = inferredFields(f.name).name)))
+ } catch {
+ case NonFatal(_) =>
+ val msg = s"""Detected conflicting schemas when merging the schema obtained from the Hive
+ | Metastore with the one inferred from the file format. Metastore schema:
+ |${metastoreSchema.prettyJson}
+ |
+ |Inferred schema:
+ |${inferredSchema.prettyJson}
+ """.stripMargin
+ throw new SparkException(msg)
+ }
+}
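To make the merge rule above concrete: the result keeps the metastore field order, types and nullability, adopts the case-sensitive names inferred from the files, and carries over nullable metastore-only columns (non-nullable ones abort the merge). A small worked example, declared in the same package so the private[hive] helper is reachable; the schemas are made up for illustration:

    package org.apache.spark.sql.hive

    import org.apache.spark.sql.types._

    object MergeWithMetastoreSchemaExample {
      def main(args: Array[String]): Unit = {
        val metastoreSchema = StructType(Seq(
          StructField("id", LongType),
          StructField("fieldone", LongType),
          StructField("comment", StringType, nullable = true)))
        val inferredSchema = StructType(Seq(
          StructField("id", LongType),
          StructField("fieldOne", LongType)))

        val merged = HiveMetastoreCatalog.mergeWithMetastoreSchema(metastoreSchema, inferredSchema)
        // Expected output: id, fieldOne, comment -- metastore order and types,
        // case-sensitive names taken from the files, nullable extra column kept.
        println(merged.fieldNames.mkString(", "))
      }
    }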
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
index c9be1b9d100b..6b7599e3d340 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
@@ -26,7 +26,7 @@ import org.apache.hadoop.hive.ql.exec.{FunctionRegistry => HiveFunctionRegistry}
import org.apache.hadoop.hive.ql.udf.generic.{AbstractGenericUDAFResolver, GenericUDF, GenericUDTF}
import org.apache.spark.sql.{AnalysisException, SparkSession}
-import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
+import org.apache.spark.sql.catalyst.{CatalystConf, FunctionIdentifier, TableIdentifier}
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
import org.apache.spark.sql.catalyst.catalog.{FunctionResourceLoader, GlobalTempViewManager, SessionCatalog}
@@ -43,31 +43,23 @@ import org.apache.spark.util.Utils
private[sql] class HiveSessionCatalog(
externalCatalog: HiveExternalCatalog,
globalTempViewManager: GlobalTempViewManager,
- sparkSession: SparkSession,
- functionResourceLoader: FunctionResourceLoader,
+ private val metastoreCatalog: HiveMetastoreCatalog,
functionRegistry: FunctionRegistry,
conf: SQLConf,
hadoopConf: Configuration,
parser: ParserInterface)
extends SessionCatalog(
- externalCatalog,
- globalTempViewManager,
- functionResourceLoader,
- functionRegistry,
- conf,
- hadoopConf,
- parser) {
+ externalCatalog,
+ globalTempViewManager,
+ functionRegistry,
+ conf,
+ hadoopConf,
+ parser) {
// ----------------------------------------------------------------
// | Methods and fields for interacting with HiveMetastoreCatalog |
// ----------------------------------------------------------------
- // Catalog for handling data source tables. TODO: This really doesn't belong here since it is
- // essentially a cache for metastore tables. However, it relies on a lot of session-specific
- // things so it would be a lot of work to split its functionality between HiveSessionCatalog
- // and HiveCatalog. We should still do it at some point...
- private val metastoreCatalog = new HiveMetastoreCatalog(sparkSession)
-
// These 2 rules must be run before all other DDL post-hoc resolution rules, i.e.
// `PreprocessTableCreation`, `PreprocessTableInsertion`, `DataSourceAnalysis` and `HiveAnalysis`.
val ParquetConversions: Rule[LogicalPlan] = metastoreCatalog.ParquetConversions
@@ -77,10 +69,51 @@ private[sql] class HiveSessionCatalog(
metastoreCatalog.hiveDefaultTableFilePath(name)
}
+ /**
+ * Create a new [[HiveSessionCatalog]] with the provided parameters. `externalCatalog` and
+ * `globalTempViewManager` are `inherited`, while `currentDb` and `tempTables` are copied.
+ */
+ def newSessionCatalogWith(
+ newSparkSession: SparkSession,
+ conf: SQLConf,
+ hadoopConf: Configuration,
+ functionRegistry: FunctionRegistry,
+ parser: ParserInterface): HiveSessionCatalog = {
+ val catalog = HiveSessionCatalog(
+ newSparkSession,
+ functionRegistry,
+ conf,
+ hadoopConf,
+ parser)
+
+ synchronized {
+ catalog.currentDb = currentDb
+ // copy over temporary tables
+ tempTables.foreach(kv => catalog.tempTables.put(kv._1, kv._2))
+ }
+
+ catalog
+ }
+
+ /**
+ * The parent class [[SessionCatalog]] cannot access the [[SparkSession]] class, so we cannot add
+ * a [[SparkSession]] parameter to [[SessionCatalog.newSessionCatalogWith]]. However,
+ * [[HiveSessionCatalog]] requires a [[SparkSession]] parameter, so we add a new version of
+ * `newSessionCatalogWith` and disable this one.
+ *
+ * TODO Refactor HiveSessionCatalog to not use [[SparkSession]] directly.
+ */
+ override def newSessionCatalogWith(
+ conf: CatalystConf,
+ hadoopConf: Configuration,
+ functionRegistry: FunctionRegistry,
+ parser: ParserInterface): HiveSessionCatalog = throw new UnsupportedOperationException(
+ "to clone HiveSessionCatalog, use the other clone method that also accepts a SparkSession")
+
// For testing only
private[hive] def getCachedDataSourceTable(table: TableIdentifier): LogicalPlan = {
val key = metastoreCatalog.getQualifiedTableName(table)
- sparkSession.sessionState.catalog.tableRelationCache.getIfPresent(key)
+ tableRelationCache.getIfPresent(key)
}
override def makeFunctionBuilder(funcName: String, className: String): FunctionBuilder = {
@@ -199,6 +232,11 @@ private[sql] class HiveSessionCatalog(
}
}
+ // TODO Remove this method after implementing Spark native "histogram_numeric".
+ override def functionExists(name: FunctionIdentifier): Boolean = {
+ super.functionExists(name) || hiveFunctions.contains(name.funcName)
+ }
+
/** List of functions we pass over to Hive. Note that over time this list should go to 0. */
// We have a list of Hive built-in functions that we do not support. So, we will check
// Hive's function registry and lazily load needed functions into our own function registry.
@@ -212,3 +250,28 @@ private[sql] class HiveSessionCatalog(
"histogram_numeric"
)
}
+
+private[sql] object HiveSessionCatalog {
+
+ def apply(
+ sparkSession: SparkSession,
+ functionRegistry: FunctionRegistry,
+ conf: SQLConf,
+ hadoopConf: Configuration,
+ parser: ParserInterface): HiveSessionCatalog = {
+ // Catalog for handling data source tables. TODO: This really doesn't belong here since it is
+ // essentially a cache for metastore tables. However, it relies on a lot of session-specific
+ // things so it would be a lot of work to split its functionality between HiveSessionCatalog
+ // and HiveCatalog. We should still do it at some point...
+ val metastoreCatalog = new HiveMetastoreCatalog(sparkSession)
+
+ new HiveSessionCatalog(
+ sparkSession.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog],
+ sparkSession.sharedState.globalTempViewManager,
+ metastoreCatalog,
+ functionRegistry,
+ conf,
+ hadoopConf,
+ parser)
+ }
+}
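A hedged usage sketch of the cloning path above. The SparkSession-aware `newSessionCatalogWith` is the one to call; the CatalystConf-based override inherited from SessionCatalog deliberately throws. The wrapper object and the idea of cloning into the same session are illustrative only:

    package org.apache.spark.sql.hive

    import org.apache.hadoop.conf.Configuration
    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.catalyst.analysis.SimpleFunctionRegistry
    import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
    import org.apache.spark.sql.internal.SQLConf

    object CloneCatalogSketch {
      def cloneCatalog(spark: SparkSession): HiveSessionCatalog = {
        val original = spark.sessionState.catalog.asInstanceOf[HiveSessionCatalog]
        // The clone starts with the same current database and temp views, but
        // later changes on either side do not leak into the other.
        original.newSessionCatalogWith(
          spark,
          new SQLConf,
          new Configuration(),
          new SimpleFunctionRegistry,
          CatalystSqlParser)
      }
    }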
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala
index 5a08a6bc66f6..cb8bcb8591bd 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala
@@ -17,89 +17,65 @@
package org.apache.spark.sql.hive
+import org.apache.spark.SparkContext
import org.apache.spark.sql._
-import org.apache.spark.sql.catalyst.analysis.Analyzer
-import org.apache.spark.sql.execution.SparkPlanner
+import org.apache.spark.sql.catalyst.analysis.{Analyzer, FunctionRegistry}
+import org.apache.spark.sql.catalyst.parser.ParserInterface
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution.{QueryExecution, SparkPlanner, SparkSqlParser}
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.hive.client.HiveClient
-import org.apache.spark.sql.internal.SessionState
+import org.apache.spark.sql.internal.{SessionState, SharedState, SQLConf}
+import org.apache.spark.sql.streaming.StreamingQueryManager
/**
* A class that holds all session-specific state in a given [[SparkSession]] backed by Hive.
+ * @param sparkContext The [[SparkContext]].
+ * @param sharedState The shared state.
+ * @param conf SQL-specific key-value configurations.
+ * @param experimentalMethods The experimental methods.
+ * @param functionRegistry Internal catalog for managing functions registered by the user.
+ * @param catalog Internal catalog for managing table and database states that uses Hive client for
+ * interacting with the metastore.
+ * @param sqlParser Parser that extracts expressions, plans, table identifiers etc. from SQL texts.
+ * @param metadataHive The Hive metadata client.
+ * @param analyzer Logical query plan analyzer for resolving unresolved attributes and relations.
+ * @param streamingQueryManager Interface to start and stop
+ * [[org.apache.spark.sql.streaming.StreamingQuery]]s.
+ * @param queryExecutionCreator Lambda to create a [[QueryExecution]] from a [[LogicalPlan]]
+ * @param plannerCreator Lambda to create a planner that takes into account Hive-specific strategies
*/
-private[hive] class HiveSessionState(sparkSession: SparkSession)
- extends SessionState(sparkSession) {
-
- self =>
-
- /**
- * A Hive client used for interacting with the metastore.
- */
- lazy val metadataHive: HiveClient =
- sparkSession.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client.newSession()
-
- /**
- * Internal catalog for managing table and database states.
- */
- override lazy val catalog = {
- new HiveSessionCatalog(
- sparkSession.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog],
- sparkSession.sharedState.globalTempViewManager,
- sparkSession,
- functionResourceLoader,
- functionRegistry,
+private[hive] class HiveSessionState(
+ sparkContext: SparkContext,
+ sharedState: SharedState,
+ conf: SQLConf,
+ experimentalMethods: ExperimentalMethods,
+ functionRegistry: FunctionRegistry,
+ override val catalog: HiveSessionCatalog,
+ sqlParser: ParserInterface,
+ val metadataHive: HiveClient,
+ analyzer: Analyzer,
+ streamingQueryManager: StreamingQueryManager,
+ queryExecutionCreator: LogicalPlan => QueryExecution,
+ val plannerCreator: () => SparkPlanner)
+ extends SessionState(
+ sparkContext,
+ sharedState,
conf,
- newHadoopConf(),
- sqlParser)
- }
-
- /**
- * An analyzer that uses the Hive metastore.
- */
- override lazy val analyzer: Analyzer = {
- new Analyzer(catalog, conf) {
- override val extendedResolutionRules =
- new ResolveHiveSerdeTable(sparkSession) ::
- new FindDataSourceTable(sparkSession) ::
- new ResolveSQLOnFile(sparkSession) :: Nil
-
- override val postHocResolutionRules =
- new DetermineTableStats(sparkSession) ::
- catalog.ParquetConversions ::
- catalog.OrcConversions ::
- PreprocessTableCreation(sparkSession) ::
- PreprocessTableInsertion(conf) ::
- DataSourceAnalysis(conf) ::
- HiveAnalysis :: Nil
-
- override val extendedCheckRules = Seq(PreWriteCheck)
- }
- }
+ experimentalMethods,
+ functionRegistry,
+ catalog,
+ sqlParser,
+ analyzer,
+ streamingQueryManager,
+ queryExecutionCreator) { self =>
/**
* Planner that takes into account Hive-specific strategies.
*/
- override def planner: SparkPlanner = {
- new SparkPlanner(sparkSession.sparkContext, conf, experimentalMethods.extraStrategies)
- with HiveStrategies {
- override val sparkSession: SparkSession = self.sparkSession
-
- override def strategies: Seq[Strategy] = {
- experimentalMethods.extraStrategies ++ Seq(
- FileSourceStrategy,
- DataSourceStrategy,
- SpecialLimits,
- InMemoryScans,
- HiveTableScans,
- Scripts,
- Aggregation,
- JoinSelection,
- BasicOperators
- )
- }
- }
- }
+ override def planner: SparkPlanner = plannerCreator()
// ------------------------------------------------------
@@ -146,4 +122,149 @@ private[hive] class HiveSessionState(sparkSession: SparkSession)
conf.getConf(HiveUtils.HIVE_THRIFT_SERVER_ASYNC)
}
+ /**
+ * Get an identical copy of the `HiveSessionState`.
+ * This should ideally reuse `SessionState.clone`, but it cannot, because doing so
+ * throws an exception when trying to clone the catalog.
+ */
+ override def clone(newSparkSession: SparkSession): HiveSessionState = {
+ val sparkContext = newSparkSession.sparkContext
+ val confCopy = conf.clone()
+ val functionRegistryCopy = functionRegistry.clone()
+ val experimentalMethodsCopy = experimentalMethods.clone()
+ val sqlParser: ParserInterface = new SparkSqlParser(confCopy)
+ val catalogCopy = catalog.newSessionCatalogWith(
+ newSparkSession,
+ confCopy,
+ SessionState.newHadoopConf(sparkContext.hadoopConfiguration, confCopy),
+ functionRegistryCopy,
+ sqlParser)
+ val queryExecutionCreator = (plan: LogicalPlan) => new QueryExecution(newSparkSession, plan)
+
+ val hiveClient =
+ newSparkSession.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client
+ .newSession()
+
+ SessionState.mergeSparkConf(confCopy, sparkContext.getConf)
+
+ new HiveSessionState(
+ sparkContext,
+ newSparkSession.sharedState,
+ confCopy,
+ experimentalMethodsCopy,
+ functionRegistryCopy,
+ catalogCopy,
+ sqlParser,
+ hiveClient,
+ HiveSessionState.createAnalyzer(newSparkSession, catalogCopy, confCopy),
+ new StreamingQueryManager(newSparkSession),
+ queryExecutionCreator,
+ HiveSessionState.createPlannerCreator(
+ newSparkSession,
+ confCopy,
+ experimentalMethodsCopy))
+ }
+
+}
+
+private[hive] object HiveSessionState {
+
+ def apply(sparkSession: SparkSession): HiveSessionState = {
+ apply(sparkSession, new SQLConf)
+ }
+
+ def apply(sparkSession: SparkSession, conf: SQLConf): HiveSessionState = {
+ val initHelper = SessionState(sparkSession, conf)
+
+ val sparkContext = sparkSession.sparkContext
+
+ val catalog = HiveSessionCatalog(
+ sparkSession,
+ initHelper.functionRegistry,
+ initHelper.conf,
+ SessionState.newHadoopConf(sparkContext.hadoopConfiguration, initHelper.conf),
+ initHelper.sqlParser)
+
+ val metadataHive: HiveClient =
+ sparkSession.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client
+ .newSession()
+
+ val analyzer: Analyzer = createAnalyzer(sparkSession, catalog, initHelper.conf)
+
+ val plannerCreator = createPlannerCreator(
+ sparkSession,
+ initHelper.conf,
+ initHelper.experimentalMethods)
+
+ val hiveSessionState = new HiveSessionState(
+ sparkContext,
+ sparkSession.sharedState,
+ initHelper.conf,
+ initHelper.experimentalMethods,
+ initHelper.functionRegistry,
+ catalog,
+ initHelper.sqlParser,
+ metadataHive,
+ analyzer,
+ initHelper.streamingQueryManager,
+ initHelper.queryExecutionCreator,
+ plannerCreator)
+ catalog.functionResourceLoader = hiveSessionState.functionResourceLoader
+ hiveSessionState
+ }
+
+ /**
+ * Create a logical query plan `Analyzer` with rules specific to a `HiveSessionState`.
+ */
+ private def createAnalyzer(
+ sparkSession: SparkSession,
+ catalog: HiveSessionCatalog,
+ sqlConf: SQLConf): Analyzer = {
+ new Analyzer(catalog, sqlConf) {
+ override val extendedResolutionRules: Seq[Rule[LogicalPlan]] =
+ new ResolveHiveSerdeTable(sparkSession) ::
+ new FindDataSourceTable(sparkSession) ::
+ new ResolveSQLOnFile(sparkSession) :: Nil
+
+ override val postHocResolutionRules: Seq[Rule[LogicalPlan]] =
+ new DetermineTableStats(sparkSession) ::
+ catalog.ParquetConversions ::
+ catalog.OrcConversions ::
+ PreprocessTableCreation(sparkSession) ::
+ PreprocessTableInsertion(sqlConf) ::
+ DataSourceAnalysis(sqlConf) ::
+ HiveAnalysis :: Nil
+
+ override val extendedCheckRules = Seq(PreWriteCheck)
+ }
+ }
+
+ private def createPlannerCreator(
+ associatedSparkSession: SparkSession,
+ sqlConf: SQLConf,
+ experimentalMethods: ExperimentalMethods): () => SparkPlanner = {
+ () =>
+ new SparkPlanner(
+ associatedSparkSession.sparkContext,
+ sqlConf,
+ experimentalMethods.extraStrategies)
+ with HiveStrategies {
+
+ override val sparkSession: SparkSession = associatedSparkSession
+
+ override def strategies: Seq[Strategy] = {
+ experimentalMethods.extraStrategies ++ Seq(
+ FileSourceStrategy,
+ DataSourceStrategy,
+ SpecialLimits,
+ InMemoryScans,
+ HiveTableScans,
+ Scripts,
+ Aggregation,
+ JoinSelection,
+ BasicOperators
+ )
+ }
+ }
+ }
}
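With the constructor-injection refactor above, a Hive session state is assembled by the companion `apply` and duplicated with `clone`. A brief sketch of the shape of that API, not of how SparkSession builds its state internally:

    package org.apache.spark.sql.hive

    import org.apache.spark.sql.SparkSession

    object SessionStateCloneSketch {
      def forkState(spark: SparkSession): HiveSessionState = {
        // The companion wires the catalog, analyzer and planner together.
        val state = HiveSessionState(spark)
        // clone() copies the conf, function registry and catalog state into a
        // state bound to the new session.
        state.clone(spark.newSession())
      }
    }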
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
index 624cfa206eeb..b5ce027d51e7 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
@@ -133,7 +133,7 @@ class DetermineTableStats(session: SparkSession) extends Rule[LogicalPlan] {
} else if (session.sessionState.conf.fallBackToHdfsForStatsEnabled) {
try {
val hadoopConf = session.sessionState.newHadoopConf()
- val tablePath = new Path(new URI(table.location))
+ val tablePath = new Path(table.location)
val fs: FileSystem = tablePath.getFileSystem(hadoopConf)
fs.getContentSummary(tablePath).getLength
} catch {
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala
index 7acaa9a7ab41..6e1f429286cf 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala
@@ -278,6 +278,8 @@ private[hive] class HiveClientImpl(
state.getConf.setClassLoader(clientLoader.classLoader)
// Set the thread local metastore client to the client associated with this HiveClientImpl.
Hive.set(client)
+ // Replace conf in the thread local Hive with current conf
+ Hive.get(conf)
// setCurrentSessionState will use the classLoader associated
// with the HiveConf in `state` to override the context class loader of the current
// thread.
@@ -317,7 +319,7 @@ private[hive] class HiveClientImpl(
new HiveDatabase(
database.name,
database.description,
- database.locationUri,
+ CatalogUtils.URIToString(database.locationUri),
Option(database.properties).map(_.asJava).orNull),
ignoreIfExists)
}
@@ -335,7 +337,7 @@ private[hive] class HiveClientImpl(
new HiveDatabase(
database.name,
database.description,
- database.locationUri,
+ CatalogUtils.URIToString(database.locationUri),
Option(database.properties).map(_.asJava).orNull))
}
@@ -344,7 +346,7 @@ private[hive] class HiveClientImpl(
CatalogDatabase(
name = d.getName,
description = d.getDescription,
- locationUri = d.getLocationUri,
+ locationUri = CatalogUtils.stringToURI(d.getLocationUri),
properties = Option(d.getParameters).map(_.asScala.toMap).orNull)
}.getOrElse(throw new NoSuchDatabaseException(dbName))
}
@@ -410,7 +412,7 @@ private[hive] class HiveClientImpl(
createTime = h.getTTable.getCreateTime.toLong * 1000,
lastAccessTime = h.getLastAccessTime.toLong * 1000,
storage = CatalogStorageFormat(
- locationUri = shim.getDataLocation(h),
+ locationUri = shim.getDataLocation(h).map(CatalogUtils.stringToURI(_)),
// To avoid ClassNotFound exception, we try our best to not get the format class, but get
// the class name directly. However, for non-native tables, there is no interface to get
// the format class name, so we may still throw ClassNotFound in this case.
@@ -851,7 +853,8 @@ private[hive] object HiveClientImpl {
conf.foreach(c => hiveTable.setOwner(c.getUser))
hiveTable.setCreateTime((table.createTime / 1000).toInt)
hiveTable.setLastAccessTime((table.lastAccessTime / 1000).toInt)
- table.storage.locationUri.foreach { loc => hiveTable.getTTable.getSd.setLocation(loc)}
+ table.storage.locationUri.map(CatalogUtils.URIToString(_)).foreach { loc =>
+ hiveTable.getTTable.getSd.setLocation(loc)}
table.storage.inputFormat.map(toInputFormat).foreach(hiveTable.setInputFormatClass)
table.storage.outputFormat.map(toOutputFormat).foreach(hiveTable.setOutputFormatClass)
hiveTable.setSerializationLib(
@@ -885,7 +888,7 @@ private[hive] object HiveClientImpl {
}
val storageDesc = new StorageDescriptor
val serdeInfo = new SerDeInfo
- p.storage.locationUri.foreach(storageDesc.setLocation)
+ p.storage.locationUri.map(CatalogUtils.URIToString(_)).foreach(storageDesc.setLocation)
p.storage.inputFormat.foreach(storageDesc.setInputFormat)
p.storage.outputFormat.foreach(storageDesc.setOutputFormat)
p.storage.serde.foreach(serdeInfo.setSerializationLib)
@@ -906,7 +909,7 @@ private[hive] object HiveClientImpl {
CatalogTablePartition(
spec = Option(hp.getSpec).map(_.asScala.toMap).getOrElse(Map.empty),
storage = CatalogStorageFormat(
- locationUri = Option(apiPartition.getSd.getLocation),
+ locationUri = Option(CatalogUtils.stringToURI(apiPartition.getSd.getLocation)),
inputFormat = Option(apiPartition.getSd.getInputFormat),
outputFormat = Option(apiPartition.getSd.getOutputFormat),
serde = Option(apiPartition.getSd.getSerdeInfo.getSerializationLib),
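The HiveClientImpl changes above follow a single boundary rule: Spark-side catalog metadata carries java.net.URI locations while Hive metastore objects carry plain strings, so every crossing goes through CatalogUtils.URIToString and CatalogUtils.stringToURI. A small sketch of the round-trip for an ordinary path (the path itself is made up):

    package org.apache.spark.sql.hive

    import java.net.URI
    import org.apache.spark.sql.catalyst.catalog.CatalogUtils

    object UriBoundarySketch {
      def main(args: Array[String]): Unit = {
        val sparkSide: URI = CatalogUtils.stringToURI("/warehouse/db.db/tbl")
        val hiveSide: String = CatalogUtils.URIToString(sparkSide)
        // For a plain path like this the conversion is expected to round-trip.
        assert(CatalogUtils.stringToURI(hiveSide) == sparkSide)
      }
    }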
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
index 7280748361d6..c6188fc683e7 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
@@ -24,10 +24,9 @@ import java.util.{ArrayList => JArrayList, List => JList, Map => JMap, Set => JS
import java.util.concurrent.TimeUnit
import scala.collection.JavaConverters._
-import scala.util.Try
import scala.util.control.NonFatal
-import org.apache.hadoop.fs.{FileSystem, Path}
+import org.apache.hadoop.fs.Path
import org.apache.hadoop.hive.conf.HiveConf
import org.apache.hadoop.hive.metastore.api.{Function => HiveFunction, FunctionType, MetaException, PrincipalType, ResourceType, ResourceUri}
import org.apache.hadoop.hive.ql.Driver
@@ -41,7 +40,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.analysis.NoSuchPermanentFunctionException
-import org.apache.spark.sql.catalyst.catalog.{CatalogFunction, CatalogTablePartition, FunctionResource, FunctionResourceType}
+import org.apache.spark.sql.catalyst.catalog.{CatalogFunction, CatalogTablePartition, CatalogUtils, FunctionResource, FunctionResourceType}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{IntegralType, StringType}
@@ -268,7 +267,7 @@ private[client] class Shim_v0_12 extends Shim with Logging {
val table = hive.getTable(database, tableName)
parts.foreach { s =>
val location = s.storage.locationUri.map(
- uri => new Path(table.getPath, new Path(new URI(uri)))).orNull
+ uri => new Path(table.getPath, new Path(uri))).orNull
val params = if (s.parameters.nonEmpty) s.parameters.asJava else null
val spec = s.spec.asJava
if (hive.getPartition(table, spec, false) != null && ignoreIfExists) {
@@ -463,7 +462,7 @@ private[client] class Shim_v0_13 extends Shim_v0_12 {
val addPartitionDesc = new AddPartitionDesc(db, table, ignoreIfExists)
parts.zipWithIndex.foreach { case (s, i) =>
addPartitionDesc.addPartition(
- s.spec.asJava, s.storage.locationUri.map(u => new Path(new URI(u)).toString).orNull)
+ s.spec.asJava, s.storage.locationUri.map(CatalogUtils.URIToString(_)).orNull)
if (s.parameters.nonEmpty) {
addPartitionDesc.getPartition(i).setPartParams(s.parameters.asJava)
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
index 3c57ee4c8b8f..b8536d0c1bd5 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
@@ -393,8 +393,8 @@ case class InsertIntoHiveTable(
logWarning(s"Unable to delete staging directory: $stagingDir.\n" + e)
}
- // Invalidate the cache.
- sparkSession.catalog.uncacheTable(table.qualifiedName)
+ // un-cache this table.
+ sparkSession.catalog.uncacheTable(table.identifier.quotedString)
sparkSession.sessionState.catalog.refreshTable(table.identifier)
// It would be nice to just return the childRdd unchanged so insert operations could be chained,
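The InsertIntoHiveTable change above passes `uncacheTable` the backtick-quoted identifier rather than the unquoted qualified name, so the string remains parseable even when a database or table name needs quoting. A short illustration with a hypothetical identifier:

    import org.apache.spark.sql.catalyst.TableIdentifier

    object QuotedNameSketch {
      def main(args: Array[String]): Unit = {
        val id = TableIdentifier("page-views", Some("web_db"))
        println(id.quotedString)    // `web_db`.`page-views`
        println(id.unquotedString)  // web_db.page-views
      }
    }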
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
index efc2f0098454..076c40d45932 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
@@ -30,16 +30,17 @@ import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.internal.Logging
-import org.apache.spark.sql.{SparkSession, SQLContext}
-import org.apache.spark.sql.catalyst.analysis._
-import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
-import org.apache.spark.sql.catalyst.expressions.ExpressionInfo
+import org.apache.spark.sql.{ExperimentalMethods, SparkSession, SQLContext}
+import org.apache.spark.sql.catalyst.analysis.{Analyzer, UnresolvedRelation}
+import org.apache.spark.sql.catalyst.parser.ParserInterface
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.execution.QueryExecution
+import org.apache.spark.sql.execution.{QueryExecution, SparkPlanner}
import org.apache.spark.sql.execution.command.CacheTableCommand
import org.apache.spark.sql.hive._
-import org.apache.spark.sql.internal.{SharedState, SQLConf}
+import org.apache.spark.sql.hive.client.HiveClient
+import org.apache.spark.sql.internal.{SessionState, SharedState, SQLConf}
import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION
+import org.apache.spark.sql.streaming.StreamingQueryManager
import org.apache.spark.util.{ShutdownHookManager, Utils}
// SPARK-3729: Test key required to check for initialization errors with config.
@@ -84,7 +85,7 @@ class TestHiveContext(
new TestHiveContext(sparkSession.newSession())
}
- override def sessionState: TestHiveSessionState = sparkSession.sessionState
+ override def sessionState: HiveSessionState = sparkSession.sessionState
def setCacheTables(c: Boolean): Unit = {
sparkSession.setCacheTables(c)
@@ -144,11 +145,35 @@ private[hive] class TestHiveSparkSession(
existingSharedState.getOrElse(new SharedState(sc))
}
- // TODO: Let's remove TestHiveSessionState. Otherwise, we are not really testing the reflection
- // logic based on the setting of CATALOG_IMPLEMENTATION.
@transient
- override lazy val sessionState: TestHiveSessionState =
- new TestHiveSessionState(self)
+ override lazy val sessionState: HiveSessionState = {
+ val testConf =
+ new SQLConf {
+ clear()
+ override def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE, false)
+ override def clear(): Unit = {
+ super.clear()
+ TestHiveContext.overrideConfs.foreach { case (k, v) => setConfString(k, v) }
+ }
+ }
+ val queryExecutionCreator = (plan: LogicalPlan) => new TestHiveQueryExecution(this, plan)
+ val initHelper = HiveSessionState(this, testConf)
+ SessionState.mergeSparkConf(testConf, sparkContext.getConf)
+
+ new HiveSessionState(
+ sparkContext,
+ sharedState,
+ testConf,
+ initHelper.experimentalMethods,
+ initHelper.functionRegistry,
+ initHelper.catalog,
+ initHelper.sqlParser,
+ initHelper.metadataHive,
+ initHelper.analyzer,
+ initHelper.streamingQueryManager,
+ queryExecutionCreator,
+ initHelper.plannerCreator)
+ }
override def newSession(): TestHiveSparkSession = {
new TestHiveSparkSession(sc, Some(sharedState), loadTestTables)
@@ -492,26 +517,6 @@ private[hive] class TestHiveQueryExecution(
}
}
-private[hive] class TestHiveSessionState(
- sparkSession: TestHiveSparkSession)
- extends HiveSessionState(sparkSession) { self =>
-
- override lazy val conf: SQLConf = {
- new SQLConf {
- clear()
- override def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE, false)
- override def clear(): Unit = {
- super.clear()
- TestHiveContext.overrideConfs.foreach { case (k, v) => setConfString(k, v) }
- }
- }
- }
-
- override def executePlan(plan: LogicalPlan): TestHiveQueryExecution = {
- new TestHiveQueryExecution(sparkSession, plan)
- }
-}
-
private[hive] object TestHiveContext {
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala
index 8ccc2b7527f2..2b3f36064c1f 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala
@@ -195,10 +195,8 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with TestHiveSingleto
tempPath.delete()
table("src").write.mode(SaveMode.Overwrite).parquet(tempPath.toString)
sql("DROP TABLE IF EXISTS refreshTable")
- sparkSession.catalog.createExternalTable("refreshTable", tempPath.toString, "parquet")
- checkAnswer(
- table("refreshTable"),
- table("src").collect())
+ sparkSession.catalog.createTable("refreshTable", tempPath.toString, "parquet")
+ checkAnswer(table("refreshTable"), table("src"))
// Cache the table.
sql("CACHE TABLE refreshTable")
assertCached(table("refreshTable"))
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala
index 6d7a1c3937a9..490e02d0bd54 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.hive
+import java.net.URI
+
import org.apache.spark.sql.{AnalysisException, SaveMode}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
@@ -70,7 +72,7 @@ class HiveDDLCommandSuite extends PlanTest with SQLTestUtils with TestHiveSingle
assert(desc.identifier.database == Some("mydb"))
assert(desc.identifier.table == "page_view")
assert(desc.tableType == CatalogTableType.EXTERNAL)
- assert(desc.storage.locationUri == Some("/user/external/page_view"))
+ assert(desc.storage.locationUri == Some(new URI("/user/external/page_view")))
assert(desc.schema.isEmpty) // will be populated later when the table is actually created
assert(desc.comment == Some("This is the staging page view table"))
// TODO will be SQLText
@@ -102,7 +104,7 @@ class HiveDDLCommandSuite extends PlanTest with SQLTestUtils with TestHiveSingle
assert(desc.identifier.database == Some("mydb"))
assert(desc.identifier.table == "page_view")
assert(desc.tableType == CatalogTableType.EXTERNAL)
- assert(desc.storage.locationUri == Some("/user/external/page_view"))
+ assert(desc.storage.locationUri == Some(new URI("/user/external/page_view")))
assert(desc.schema.isEmpty) // will be populated later when the table is actually created
// TODO will be SQLText
assert(desc.comment == Some("This is the staging page view table"))
@@ -338,7 +340,7 @@ class HiveDDLCommandSuite extends PlanTest with SQLTestUtils with TestHiveSingle
val query = "CREATE EXTERNAL TABLE tab1 (id int, name string) LOCATION '/path/to/nowhere'"
val (desc, _) = extractTableDesc(query)
assert(desc.tableType == CatalogTableType.EXTERNAL)
- assert(desc.storage.locationUri == Some("/path/to/nowhere"))
+ assert(desc.storage.locationUri == Some(new URI("/path/to/nowhere")))
}
test("create table - if not exists") {
@@ -469,7 +471,7 @@ class HiveDDLCommandSuite extends PlanTest with SQLTestUtils with TestHiveSingle
assert(desc.viewText.isEmpty)
assert(desc.viewDefaultDatabase.isEmpty)
assert(desc.viewQueryColumnNames.isEmpty)
- assert(desc.storage.locationUri == Some("/path/to/mercury"))
+ assert(desc.storage.locationUri == Some(new URI("/path/to/mercury")))
assert(desc.storage.inputFormat == Some("winput"))
assert(desc.storage.outputFormat == Some("wowput"))
assert(desc.storage.serde == Some("org.apache.poof.serde.Baff"))
@@ -644,7 +646,7 @@ class HiveDDLCommandSuite extends PlanTest with SQLTestUtils with TestHiveSingle
.add("id", "int")
.add("name", "string", nullable = true, comment = "blabla"))
assert(table.provider == Some(DDLUtils.HIVE_PROVIDER))
- assert(table.storage.locationUri == Some("/tmp/file"))
+ assert(table.storage.locationUri == Some(new URI("/tmp/file")))
assert(table.storage.properties == Map("my_prop" -> "1"))
assert(table.comment == Some("BLABLA"))
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogBackwardCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogBackwardCompatibilitySuite.scala
index ee632d24b717..705d43f1f3ab 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogBackwardCompatibilitySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogBackwardCompatibilitySuite.scala
@@ -40,7 +40,8 @@ class HiveExternalCatalogBackwardCompatibilitySuite extends QueryTest
spark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client
val tempDir = Utils.createTempDir().getCanonicalFile
- val tempDirUri = tempDir.toURI.toString.stripSuffix("/")
+ val tempDirUri = tempDir.toURI
+ val tempDirStr = tempDir.getAbsolutePath
override def beforeEach(): Unit = {
sql("CREATE DATABASE test_db")
@@ -59,9 +60,7 @@ class HiveExternalCatalogBackwardCompatibilitySuite extends QueryTest
}
private def defaultTableURI(tableName: String): URI = {
- val defaultPath =
- spark.sessionState.catalog.defaultTablePath(TableIdentifier(tableName, Some("test_db")))
- new Path(defaultPath).toUri
+ spark.sessionState.catalog.defaultTablePath(TableIdentifier(tableName, Some("test_db")))
}
// Raw table metadata that are dumped from tables created by Spark 2.0. Note that, all spark
@@ -170,8 +169,8 @@ class HiveExternalCatalogBackwardCompatibilitySuite extends QueryTest
identifier = TableIdentifier("tbl7", Some("test_db")),
tableType = CatalogTableType.EXTERNAL,
storage = CatalogStorageFormat.empty.copy(
- locationUri = Some(defaultTableURI("tbl7").toString + "-__PLACEHOLDER__"),
- properties = Map("path" -> tempDirUri)),
+ locationUri = Some(new URI(defaultTableURI("tbl7") + "-__PLACEHOLDER__")),
+ properties = Map("path" -> tempDirStr)),
schema = new StructType(),
provider = Some("json"),
properties = Map(
@@ -184,7 +183,7 @@ class HiveExternalCatalogBackwardCompatibilitySuite extends QueryTest
tableType = CatalogTableType.EXTERNAL,
storage = CatalogStorageFormat.empty.copy(
locationUri = Some(tempDirUri),
- properties = Map("path" -> tempDirUri)),
+ properties = Map("path" -> tempDirStr)),
schema = simpleSchema,
properties = Map(
"spark.sql.sources.provider" -> "parquet",
@@ -195,8 +194,8 @@ class HiveExternalCatalogBackwardCompatibilitySuite extends QueryTest
identifier = TableIdentifier("tbl9", Some("test_db")),
tableType = CatalogTableType.EXTERNAL,
storage = CatalogStorageFormat.empty.copy(
- locationUri = Some(defaultTableURI("tbl9").toString + "-__PLACEHOLDER__"),
- properties = Map("path" -> tempDirUri)),
+ locationUri = Some(new URI(defaultTableURI("tbl9") + "-__PLACEHOLDER__")),
+ properties = Map("path" -> tempDirStr)),
schema = new StructType(),
provider = Some("json"),
properties = Map("spark.sql.sources.provider" -> "json"))
@@ -220,7 +219,7 @@ class HiveExternalCatalogBackwardCompatibilitySuite extends QueryTest
if (tbl.tableType == CatalogTableType.EXTERNAL) {
// trim the URI prefix
- val tableLocation = new URI(readBack.storage.locationUri.get).getPath
+ val tableLocation = readBack.storage.locationUri.get.getPath
val expectedLocation = tempDir.toURI.getPath.stripSuffix("/")
assert(tableLocation == expectedLocation)
}
@@ -236,7 +235,7 @@ class HiveExternalCatalogBackwardCompatibilitySuite extends QueryTest
val readBack = getTableMetadata(tbl.identifier.table)
// trim the URI prefix
- val actualTableLocation = new URI(readBack.storage.locationUri.get).getPath
+ val actualTableLocation = readBack.storage.locationUri.get.getPath
val expected = dir.toURI.getPath.stripSuffix("/")
assert(actualTableLocation == expected)
}
@@ -252,7 +251,7 @@ class HiveExternalCatalogBackwardCompatibilitySuite extends QueryTest
assert(readBack.schema.sameType(expectedSchema))
// trim the URI prefix
- val actualTableLocation = new URI(readBack.storage.locationUri.get).getPath
+ val actualTableLocation = readBack.storage.locationUri.get.getPath
val expectedLocation = if (tbl.tableType == CatalogTableType.EXTERNAL) {
tempDir.toURI.getPath.stripSuffix("/")
} else {
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala
index 16cf4d7ec67f..079358b29a19 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala
@@ -62,7 +62,7 @@ class HiveMetastoreCatalogSuite extends TestHiveSingleton with SQLTestUtils {
spark.sql("create view vw1 as select 1 as id")
val plan = spark.sql("select id from vw1").queryExecution.analyzed
val aliases = plan.collect {
- case x @ SubqueryAlias("vw1", _, Some(TableIdentifier("vw1", Some("default")))) => x
+ case x @ SubqueryAlias("vw1", _) => x
}
assert(aliases.size == 1)
}
@@ -140,7 +140,7 @@ class DataSourceWithHiveMetastoreCatalogSuite
assert(hiveTable.storage.serde === Some(serde))
assert(hiveTable.tableType === CatalogTableType.EXTERNAL)
- assert(hiveTable.storage.locationUri === Some(path.toString))
+ assert(hiveTable.storage.locationUri === Some(makeQualifiedPath(dir.getAbsolutePath)))
val columns = hiveTable.schema
assert(columns.map(_.name) === Seq("d1", "d2"))
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSchemaInferenceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSchemaInferenceSuite.scala
new file mode 100644
index 000000000000..e48ce2304d08
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSchemaInferenceSuite.scala
@@ -0,0 +1,326 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive
+
+import java.io.File
+import java.util.concurrent.{Executors, TimeUnit}
+
+import scala.util.Random
+
+import org.scalatest.BeforeAndAfterEach
+
+import org.apache.spark.metrics.source.HiveCatalogMetrics
+import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.catalog._
+import org.apache.spark.sql.execution.datasources.FileStatusCache
+import org.apache.spark.sql.QueryTest
+import org.apache.spark.sql.hive.client.HiveClient
+import org.apache.spark.sql.hive.test.TestHiveSingleton
+import org.apache.spark.sql.internal.{HiveSerDe, SQLConf}
+import org.apache.spark.sql.internal.SQLConf.HiveCaseSensitiveInferenceMode.{Value => InferenceMode, _}
+import org.apache.spark.sql.test.SQLTestUtils
+import org.apache.spark.sql.types._
+
+class HiveSchemaInferenceSuite
+ extends QueryTest with TestHiveSingleton with SQLTestUtils with BeforeAndAfterEach {
+
+ import HiveSchemaInferenceSuite._
+ import HiveExternalCatalog.DATASOURCE_SCHEMA_PREFIX
+
+ override def beforeEach(): Unit = {
+ super.beforeEach()
+ FileStatusCache.resetForTesting()
+ }
+
+ override def afterEach(): Unit = {
+ super.afterEach()
+ spark.sessionState.catalog.tableRelationCache.invalidateAll()
+ FileStatusCache.resetForTesting()
+ }
+
+ private val externalCatalog = spark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog]
+ private val client = externalCatalog.client
+
+ // Return a copy of the given schema with all field names converted to lower case.
+ private def lowerCaseSchema(schema: StructType): StructType = {
+ StructType(schema.map(f => f.copy(name = f.name.toLowerCase)))
+ }
+
+ // Create a Hive external test table containing the given field and partition column names.
+ // Returns a case-sensitive schema for the table.
+ private def setupExternalTable(
+ fileType: String,
+ fields: Seq[String],
+ partitionCols: Seq[String],
+ dir: File): StructType = {
+ // Treat all table fields as bigints...
+ val structFields = fields.map { field =>
+ StructField(
+ name = field,
+ dataType = LongType,
+ nullable = true,
+ metadata = new MetadataBuilder().putString(HIVE_TYPE_STRING, "bigint").build())
+ }
+ // and all partition columns as ints
+ val partitionStructFields = partitionCols.map { field =>
+ StructField(
+ // Partition column case isn't preserved
+ name = field.toLowerCase,
+ dataType = IntegerType,
+ nullable = true,
+ metadata = new MetadataBuilder().putString(HIVE_TYPE_STRING, "int").build())
+ }
+ val schema = StructType(structFields ++ partitionStructFields)
+
+ // Write some test data (partitioned if specified)
+ val writer = spark.range(NUM_RECORDS)
+ .selectExpr((fields ++ partitionCols).map("id as " + _): _*)
+ .write
+ .partitionBy(partitionCols: _*)
+ .mode("overwrite")
+ fileType match {
+ case ORC_FILE_TYPE =>
+ writer.orc(dir.getAbsolutePath)
+ case PARQUET_FILE_TYPE =>
+ writer.parquet(dir.getAbsolutePath)
+ }
+
+ // Create Hive external table with lowercased schema
+ val serde = HiveSerDe.serdeMap(fileType)
+ client.createTable(
+ CatalogTable(
+ identifier = TableIdentifier(table = TEST_TABLE_NAME, database = Option(DATABASE)),
+ tableType = CatalogTableType.EXTERNAL,
+ storage = CatalogStorageFormat(
+ locationUri = Option(new java.net.URI(dir.getAbsolutePath)),
+ inputFormat = serde.inputFormat,
+ outputFormat = serde.outputFormat,
+ serde = serde.serde,
+ compressed = false,
+ properties = Map("serialization.format" -> "1")),
+ schema = schema,
+ provider = Option("hive"),
+ partitionColumnNames = partitionCols.map(_.toLowerCase),
+ properties = Map.empty),
+ true)
+
+ // Add partition records (if specified)
+ if (!partitionCols.isEmpty) {
+ spark.catalog.recoverPartitions(TEST_TABLE_NAME)
+ }
+
+ // Check that the table returned by HiveExternalCatalog has schemaPreservesCase set to false
+ // and that the raw table returned by the Hive client doesn't have any Spark SQL properties
+ // set (table needs to be obtained from client since HiveExternalCatalog filters these
+ // properties out).
+ assert(!externalCatalog.getTable(DATABASE, TEST_TABLE_NAME).schemaPreservesCase)
+ val rawTable = client.getTable(DATABASE, TEST_TABLE_NAME)
+ assert(rawTable.properties.filterKeys(_.startsWith(DATASOURCE_SCHEMA_PREFIX)) == Map.empty)
+ schema
+ }
+
+ private def withTestTables(
+ fileType: String)(f: (Seq[String], Seq[String], StructType) => Unit): Unit = {
+ // Test both a partitioned and unpartitioned Hive table
+ val tableFields = Seq(
+ (Seq("fieldOne"), Seq("partCol1", "partCol2")),
+ (Seq("fieldOne", "fieldTwo"), Seq.empty[String]))
+
+ tableFields.foreach { case (fields, partCols) =>
+ withTempDir { dir =>
+ val schema = setupExternalTable(fileType, fields, partCols, dir)
+ withTable(TEST_TABLE_NAME) { f(fields, partCols, schema) }
+ }
+ }
+ }
+
+ private def withFileTypes(f: (String) => Unit): Unit
+ = Seq(ORC_FILE_TYPE, PARQUET_FILE_TYPE).foreach(f)
+
+ private def withInferenceMode(mode: InferenceMode)(f: => Unit): Unit = {
+ withSQLConf(
+ HiveUtils.CONVERT_METASTORE_ORC.key -> "true",
+ SQLConf.HIVE_CASE_SENSITIVE_INFERENCE.key -> mode.toString)(f)
+ }
+
+ private val inferenceKey = SQLConf.HIVE_CASE_SENSITIVE_INFERENCE.key
+
+ private def testFieldQuery(fields: Seq[String]): Unit = {
+ if (!fields.isEmpty) {
+ val query = s"SELECT * FROM ${TEST_TABLE_NAME} WHERE ${Random.shuffle(fields).head} >= 0"
+ assert(spark.sql(query).count == NUM_RECORDS)
+ }
+ }
+
+ private def testTableSchema(expectedSchema: StructType): Unit
+ = assert(spark.table(TEST_TABLE_NAME).schema == expectedSchema)
+
+ withFileTypes { fileType =>
+ test(s"$fileType: schema should be inferred and saved when INFER_AND_SAVE is specified") {
+ withInferenceMode(INFER_AND_SAVE) {
+ withTestTables(fileType) { (fields, partCols, schema) =>
+ testFieldQuery(fields)
+ testFieldQuery(partCols)
+ testTableSchema(schema)
+
+ // Verify the catalog table now contains the updated schema and properties
+ val catalogTable = externalCatalog.getTable(DATABASE, TEST_TABLE_NAME)
+ assert(catalogTable.schemaPreservesCase)
+ assert(catalogTable.schema == schema)
+ assert(catalogTable.partitionColumnNames == partCols.map(_.toLowerCase))
+ }
+ }
+ }
+ }
+
+ withFileTypes { fileType =>
+ test(s"$fileType: schema should be inferred but not stored when INFER_ONLY is specified") {
+ withInferenceMode(INFER_ONLY) {
+ withTestTables(fileType) { (fields, partCols, schema) =>
+ val originalTable = externalCatalog.getTable(DATABASE, TEST_TABLE_NAME)
+ testFieldQuery(fields)
+ testFieldQuery(partCols)
+ testTableSchema(schema)
+ // Catalog table shouldn't be altered
+ assert(externalCatalog.getTable(DATABASE, TEST_TABLE_NAME) == originalTable)
+ }
+ }
+ }
+ }
+
+ withFileTypes { fileType =>
+ test(s"$fileType: schema should not be inferred when NEVER_INFER is specified") {
+ withInferenceMode(NEVER_INFER) {
+ withTestTables(fileType) { (fields, partCols, schema) =>
+ val originalTable = externalCatalog.getTable(DATABASE, TEST_TABLE_NAME)
+ // Only check the table schema as the test queries will break
+ testTableSchema(lowerCaseSchema(schema))
+ assert(externalCatalog.getTable(DATABASE, TEST_TABLE_NAME) == originalTable)
+ }
+ }
+ }
+ }
+
+ test("mergeWithMetastoreSchema() should return expected results") {
+ // Field type conflict resolution
+ assertResult(
+ StructType(Seq(
+ StructField("lowerCase", StringType),
+ StructField("UPPERCase", DoubleType, nullable = false)))) {
+
+ HiveMetastoreCatalog.mergeWithMetastoreSchema(
+ StructType(Seq(
+ StructField("lowercase", StringType),
+ StructField("uppercase", DoubleType, nullable = false))),
+
+ StructType(Seq(
+ StructField("lowerCase", BinaryType),
+ StructField("UPPERCase", IntegerType, nullable = true))))
+ }
+
+ // MetaStore schema is subset of parquet schema
+ assertResult(
+ StructType(Seq(
+ StructField("UPPERCase", DoubleType, nullable = false)))) {
+
+ HiveMetastoreCatalog.mergeWithMetastoreSchema(
+ StructType(Seq(
+ StructField("uppercase", DoubleType, nullable = false))),
+
+ StructType(Seq(
+ StructField("lowerCase", BinaryType),
+ StructField("UPPERCase", IntegerType, nullable = true))))
+ }
+
+ // Metastore schema contains additional non-nullable fields.
+ assert(intercept[Throwable] {
+ HiveMetastoreCatalog.mergeWithMetastoreSchema(
+ StructType(Seq(
+ StructField("uppercase", DoubleType, nullable = false),
+ StructField("lowerCase", BinaryType, nullable = false))),
+
+ StructType(Seq(
+ StructField("UPPERCase", IntegerType, nullable = true))))
+ }.getMessage.contains("Detected conflicting schemas"))
+
+ // Conflicting non-nullable field names
+ intercept[Throwable] {
+ HiveMetastoreCatalog.mergeWithMetastoreSchema(
+ StructType(Seq(StructField("lower", StringType, nullable = false))),
+ StructType(Seq(StructField("lowerCase", BinaryType))))
+ }
+
+ // Check that merging missing nullable fields works as expected.
+ assertResult(
+ StructType(Seq(
+ StructField("firstField", StringType, nullable = true),
+ StructField("secondField", StringType, nullable = true),
+ StructField("thirdfield", StringType, nullable = true)))) {
+ HiveMetastoreCatalog.mergeWithMetastoreSchema(
+ StructType(Seq(
+ StructField("firstfield", StringType, nullable = true),
+ StructField("secondfield", StringType, nullable = true),
+ StructField("thirdfield", StringType, nullable = true))),
+ StructType(Seq(
+ StructField("firstField", StringType, nullable = true),
+ StructField("secondField", StringType, nullable = true))))
+ }
+
+ // Merge should fail if the Metastore contains any additional fields that are not
+ // nullable.
+ assert(intercept[Throwable] {
+ HiveMetastoreCatalog.mergeWithMetastoreSchema(
+ StructType(Seq(
+ StructField("firstfield", StringType, nullable = true),
+ StructField("secondfield", StringType, nullable = true),
+ StructField("thirdfield", StringType, nullable = false))),
+ StructType(Seq(
+ StructField("firstField", StringType, nullable = true),
+ StructField("secondField", StringType, nullable = true))))
+ }.getMessage.contains("Detected conflicting schemas"))
+
+ // Schema merge should maintain metastore order.
+ assertResult(
+ StructType(Seq(
+ StructField("first_field", StringType, nullable = true),
+ StructField("second_field", StringType, nullable = true),
+ StructField("third_field", StringType, nullable = true),
+ StructField("fourth_field", StringType, nullable = true),
+ StructField("fifth_field", StringType, nullable = true)))) {
+ HiveMetastoreCatalog.mergeWithMetastoreSchema(
+ StructType(Seq(
+ StructField("first_field", StringType, nullable = true),
+ StructField("second_field", StringType, nullable = true),
+ StructField("third_field", StringType, nullable = true),
+ StructField("fourth_field", StringType, nullable = true),
+ StructField("fifth_field", StringType, nullable = true))),
+ StructType(Seq(
+ StructField("fifth_field", StringType, nullable = true),
+ StructField("third_field", StringType, nullable = true),
+ StructField("second_field", StringType, nullable = true))))
+ }
+ }
+}
+
+object HiveSchemaInferenceSuite {
+ private val NUM_RECORDS = 10
+ private val DATABASE = "default"
+ private val TEST_TABLE_NAME = "test_table"
+ private val ORC_FILE_TYPE = "orc"
+ private val PARQUET_FILE_TYPE = "parquet"
+}
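The suite above drives the three HiveCaseSensitiveInferenceMode values through withSQLConf. From a user's perspective the same switch is a session configuration; a hedged sketch, assuming the key string shown matches SQLConf.HIVE_CASE_SENSITIVE_INFERENCE:

    import org.apache.spark.sql.SparkSession

    object InferenceModeSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder()
          .master("local[2]")
          .appName("schema-inference-sketch")
          .enableHiveSupport()
          .getOrCreate()
        // Accepted values are INFER_AND_SAVE, INFER_ONLY and NEVER_INFER,
        // matching the modes exercised by the suite above.
        spark.conf.set("spark.sql.hive.caseSensitiveInferenceMode", "INFER_ONLY")
        // Hive-serde Parquet/ORC tables read after this point get a case-sensitive
        // schema inferred from their files, but the metastore is left untouched.
        spark.stop()
      }
    }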
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionCatalogSuite.scala
new file mode 100644
index 000000000000..3b0f59b15916
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionCatalogSuite.scala
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive
+
+import java.net.URI
+
+import org.apache.hadoop.conf.Configuration
+
+import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.analysis.SimpleFunctionRegistry
+import org.apache.spark.sql.catalyst.catalog.CatalogDatabase
+import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
+import org.apache.spark.sql.catalyst.plans.logical.Range
+import org.apache.spark.sql.hive.test.TestHiveSingleton
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.util.Utils
+
+class HiveSessionCatalogSuite extends TestHiveSingleton {
+
+ test("clone HiveSessionCatalog") {
+ val original = spark.sessionState.catalog.asInstanceOf[HiveSessionCatalog]
+
+ val tempTableName1 = "copytest1"
+ val tempTableName2 = "copytest2"
+ try {
+ val tempTable1 = Range(1, 10, 1, 10)
+ original.createTempView(tempTableName1, tempTable1, overrideIfExists = false)
+
+      // create a clone and check that the temp view was copied over
+ val clone = original.newSessionCatalogWith(
+ spark,
+ new SQLConf,
+ new Configuration(),
+ new SimpleFunctionRegistry,
+ CatalystSqlParser)
+ assert(original ne clone)
+ assert(clone.getTempView(tempTableName1) == Some(tempTable1))
+
+      // check that the clone and the original are independent
+ clone.dropTable(TableIdentifier(tempTableName1), ignoreIfNotExists = false, purge = false)
+ assert(original.getTempView(tempTableName1) == Some(tempTable1))
+
+ val tempTable2 = Range(1, 20, 2, 10)
+ original.createTempView(tempTableName2, tempTable2, overrideIfExists = false)
+ assert(clone.getTempView(tempTableName2).isEmpty)
+ } finally {
+ // Drop the created temp views from the global singleton HiveSession.
+ original.dropTable(TableIdentifier(tempTableName1), ignoreIfNotExists = true, purge = true)
+ original.dropTable(TableIdentifier(tempTableName2), ignoreIfNotExists = true, purge = true)
+ }
+ }
+
+ test("clone SessionCatalog - current db") {
+ val original = spark.sessionState.catalog.asInstanceOf[HiveSessionCatalog]
+ val originalCurrentDatabase = original.getCurrentDatabase
+ val db1 = "db1"
+ val db2 = "db2"
+ val db3 = "db3"
+ try {
+ original.createDatabase(newDb(db1), ignoreIfExists = true)
+ original.createDatabase(newDb(db2), ignoreIfExists = true)
+ original.createDatabase(newDb(db3), ignoreIfExists = true)
+
+ original.setCurrentDatabase(db1)
+
+      // create a clone of the catalog
+ val clone = original.newSessionCatalogWith(
+ spark,
+ new SQLConf,
+ new Configuration(),
+ new SimpleFunctionRegistry,
+ CatalystSqlParser)
+
+      // check that the current database was copied over
+ assert(original ne clone)
+ assert(clone.getCurrentDatabase == db1)
+
+      // check that the clone and the original are independent
+ clone.setCurrentDatabase(db2)
+ assert(original.getCurrentDatabase == db1)
+ original.setCurrentDatabase(db3)
+ assert(clone.getCurrentDatabase == db2)
+ } finally {
+ // Drop the created databases from the global singleton HiveSession.
+ original.dropDatabase(db1, ignoreIfNotExists = true, cascade = true)
+ original.dropDatabase(db2, ignoreIfNotExists = true, cascade = true)
+ original.dropDatabase(db3, ignoreIfNotExists = true, cascade = true)
+ original.setCurrentDatabase(originalCurrentDatabase)
+ }
+ }
+
+ def newUriForDatabase(): URI = new URI(Utils.createTempDir().toURI.toString.stripSuffix("/"))
+
+ def newDb(name: String): CatalogDatabase = {
+ CatalogDatabase(name, name + " description", newUriForDatabase(), Map.empty)
+ }
+}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionStateSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionStateSuite.scala
new file mode 100644
index 000000000000..67c77fb62f4e
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionStateSuite.scala
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive
+
+import org.scalatest.BeforeAndAfterEach
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.hive.test.TestHiveSingleton
+
+/**
+ * Run all tests from `SessionStateSuite` with a `HiveSessionState`.
+ */
+class HiveSessionStateSuite extends SessionStateSuite
+ with TestHiveSingleton with BeforeAndAfterEach {
+
+ override def beforeAll(): Unit = {
+ // Reuse the singleton session
+ activeSession = spark
+ }
+
+ override def afterAll(): Unit = {
+ // Set activeSession to null to avoid stopping the singleton session
+ activeSession = null
+ super.afterAll()
+ }
+}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala
index 8f0d5d886c9d..5f15a705a2e9 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala
@@ -485,7 +485,7 @@ object SetWarehouseLocationTest extends Logging {
val tableMetadata =
catalog.getTableMetadata(TableIdentifier("testLocation", Some("default")))
val expectedLocation =
- "file:" + expectedWarehouseLocation.toString + "/testlocation"
+ CatalogUtils.stringToURI(s"file:${expectedWarehouseLocation.toString}/testlocation")
val actualLocation = tableMetadata.location
if (actualLocation != expectedLocation) {
throw new Exception(
@@ -500,8 +500,8 @@ object SetWarehouseLocationTest extends Logging {
sparkSession.sql("create table testLocation (a int)")
val tableMetadata =
catalog.getTableMetadata(TableIdentifier("testLocation", Some("testLocationDB")))
- val expectedLocation =
- "file:" + expectedWarehouseLocation.toString + "/testlocationdb.db/testlocation"
+ val expectedLocation = CatalogUtils.stringToURI(
+ s"file:${expectedWarehouseLocation.toString}/testlocationdb.db/testlocation")
val actualLocation = tableMetadata.location
if (actualLocation != expectedLocation) {
throw new Exception(
@@ -868,14 +868,16 @@ object SPARK_18360 {
val rawTable = hiveClient.getTable("default", "test_tbl")
// Hive will use the value of `hive.metastore.warehouse.dir` to generate default table
// location for tables in default database.
- assert(rawTable.storage.locationUri.get.contains(newWarehousePath))
+ assert(rawTable.storage.locationUri.map(
+ CatalogUtils.URIToString(_)).get.contains(newWarehousePath))
hiveClient.dropTable("default", "test_tbl", ignoreIfNotExists = false, purge = false)
spark.sharedState.externalCatalog.createTable(tableMeta, ignoreIfExists = false)
val readBack = spark.sharedState.externalCatalog.getTable("default", "test_tbl")
// Spark SQL will use the location of default database to generate default table
// location for tables in default database.
- assert(readBack.storage.locationUri.get.contains(defaultDbLocation))
+ assert(readBack.storage.locationUri.map(CatalogUtils.URIToString(_))
+ .get.contains(defaultDbLocation))
} finally {
hiveClient.dropTable("default", "test_tbl", ignoreIfNotExists = true, purge = false)
hiveClient.runSqlHive(s"SET hive.metastore.warehouse.dir=$defaultDbLocation")
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
index 03ea0c8c7768..f02b7218d6ee 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
@@ -1011,7 +1011,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv
identifier = TableIdentifier("not_skip_hive_metadata"),
tableType = CatalogTableType.EXTERNAL,
storage = CatalogStorageFormat.empty.copy(
- locationUri = Some(tempPath.getCanonicalPath),
+ locationUri = Some(tempPath.toURI),
properties = Map("skipHiveMetadata" -> "false")
),
schema = schema,
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala
index 47ee4dd4d952..4aea6d14efb0 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala
@@ -17,6 +17,10 @@
package org.apache.spark.sql.hive
+import java.net.URI
+
+import org.apache.hadoop.fs.Path
+
import org.apache.spark.sql.{AnalysisException, QueryTest, SaveMode}
import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.test.SQLTestUtils
@@ -26,8 +30,8 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle
private def checkTablePath(dbName: String, tableName: String): Unit = {
val metastoreTable = spark.sharedState.externalCatalog.getTable(dbName, tableName)
- val expectedPath =
- spark.sharedState.externalCatalog.getDatabase(dbName).locationUri + "/" + tableName
+ val expectedPath = new Path(new Path(
+ spark.sharedState.externalCatalog.getDatabase(dbName).locationUri), tableName).toUri
assert(metastoreTable.location === expectedPath)
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala
index d61d10bf869e..6025f8adbce2 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.hive.client
import java.io.{ByteArrayOutputStream, File, PrintStream}
+import java.net.URI
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
@@ -54,7 +55,7 @@ class VersionsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton w
test("success sanity check") {
val badClient = buildClient(HiveUtils.hiveExecutionVersion, new Configuration())
- val db = new CatalogDatabase("default", "desc", "loc", Map())
+ val db = new CatalogDatabase("default", "desc", new URI("loc"), Map())
badClient.createDatabase(db, ignoreIfExists = true)
}
@@ -125,10 +126,10 @@ class VersionsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton w
// Database related API
///////////////////////////////////////////////////////////////////////////
- val tempDatabasePath = Utils.createTempDir().getCanonicalPath
+ val tempDatabasePath = Utils.createTempDir().toURI
test(s"$version: createDatabase") {
- val defaultDB = CatalogDatabase("default", "desc", "loc", Map())
+ val defaultDB = CatalogDatabase("default", "desc", new URI("loc"), Map())
client.createDatabase(defaultDB, ignoreIfExists = true)
val tempDB = CatalogDatabase(
"temporary", description = "test create", tempDatabasePath, Map())
@@ -346,7 +347,7 @@ class VersionsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton w
test(s"$version: alterPartitions") {
val spec = Map("key1" -> "1", "key2" -> "2")
- val newLocation = Utils.createTempDir().getPath()
+ val newLocation = new URI(Utils.createTempDir().toURI.toString.stripSuffix("/"))
val storage = storageFormat.copy(
locationUri = Some(newLocation),
// needed for 0.12 alter partitions
@@ -657,19 +658,17 @@ class VersionsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton w
val tPath = new Path(spark.sessionState.conf.warehousePath, "t")
Seq("1").toDF("a").write.saveAsTable("t")
- val expectedPath = s"file:${tPath.toUri.getPath.stripSuffix("/")}"
val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t"))
- assert(table.location.stripSuffix("/") == expectedPath)
+ assert(table.location == makeQualifiedPath(tPath.toString))
assert(tPath.getFileSystem(spark.sessionState.newHadoopConf()).exists(tPath))
checkAnswer(spark.table("t"), Row("1") :: Nil)
val t1Path = new Path(spark.sessionState.conf.warehousePath, "t1")
spark.sql("create table t1 using parquet as select 2 as a")
val table1 = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t1"))
- val expectedPath1 = s"file:${t1Path.toUri.getPath.stripSuffix("/")}"
- assert(table1.location.stripSuffix("/") == expectedPath1)
+ assert(table1.location == makeQualifiedPath(t1Path.toString))
assert(t1Path.getFileSystem(spark.sessionState.newHadoopConf()).exists(t1Path))
checkAnswer(spark.table("t1"), Row(2) :: Nil)
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala
index 81ae5b7bdb67..d29242bb47e3 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala
@@ -18,6 +18,8 @@
package org.apache.spark.sql.hive.execution
import java.io.File
+import java.lang.reflect.InvocationTargetException
+import java.net.URI
import org.apache.hadoop.fs.Path
import org.scalatest.BeforeAndAfterEach
@@ -25,16 +27,88 @@ import org.scalatest.BeforeAndAfterEach
import org.apache.spark.SparkException
import org.apache.spark.sql.{AnalysisException, QueryTest, Row, SaveMode}
import org.apache.spark.sql.catalyst.analysis.{NoSuchPartitionException, TableAlreadyExistsException}
-import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, CatalogTable, CatalogTableType}
+import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.TableIdentifier
-import org.apache.spark.sql.execution.command.DDLUtils
+import org.apache.spark.sql.execution.command.{DDLSuite, DDLUtils}
import org.apache.spark.sql.hive.HiveExternalCatalog
import org.apache.spark.sql.hive.orc.OrcFileOperator
import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION
import org.apache.spark.sql.test.SQLTestUtils
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{MetadataBuilder, StructType}
+
+// TODO(gatorsmile): combine HiveCatalogedDDLSuite and HiveDDLSuite
+class HiveCatalogedDDLSuite extends DDLSuite with TestHiveSingleton with BeforeAndAfterEach {
+ override def afterEach(): Unit = {
+ try {
+ // drop all databases, tables and functions after each test
+ spark.sessionState.catalog.reset()
+ } finally {
+ super.afterEach()
+ }
+ }
+
+ protected override def generateTable(
+ catalog: SessionCatalog,
+ name: TableIdentifier): CatalogTable = {
+ val storage =
+ CatalogStorageFormat(
+ locationUri = Some(catalog.defaultTablePath(name)),
+ inputFormat = Some("org.apache.hadoop.mapred.SequenceFileInputFormat"),
+ outputFormat = Some("org.apache.hadoop.hive.ql.io.HiveSequenceFileOutputFormat"),
+ serde = Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe"),
+ compressed = false,
+ properties = Map("serialization.format" -> "1"))
+ val metadata = new MetadataBuilder()
+ .putString("key", "value")
+ .build()
+ CatalogTable(
+ identifier = name,
+ tableType = CatalogTableType.EXTERNAL,
+ storage = storage,
+ schema = new StructType()
+ .add("col1", "int", nullable = true, metadata = metadata)
+ .add("col2", "string")
+ .add("a", "int")
+ .add("b", "int"),
+ provider = Some("hive"),
+ partitionColumnNames = Seq("a", "b"),
+ createTime = 0L,
+ tracksPartitionsInCatalog = true)
+ }
+
+ protected override def normalizeCatalogTable(table: CatalogTable): CatalogTable = {
+ val nondeterministicProps = Set(
+ "CreateTime",
+ "transient_lastDdlTime",
+ "grantTime",
+ "lastUpdateTime",
+ "last_modified_by",
+ "last_modified_time",
+ "Owner:",
+ "COLUMN_STATS_ACCURATE",
+      // The following are Hive-specific schema parameters that we do not need to match exactly.
+ "numFiles",
+ "numRows",
+ "rawDataSize",
+ "totalSize",
+ "totalNumberFiles",
+ "maxFileSize",
+ "minFileSize"
+ )
+
+ table.copy(
+ createTime = 0L,
+ lastAccessTime = 0L,
+ owner = "",
+ properties = table.properties.filterKeys(!nondeterministicProps.contains(_)),
+ // View texts are checked separately
+ viewText = None
+ )
+ }
+
+}
class HiveDDLSuite
extends QueryTest with SQLTestUtils with TestHiveSingleton with BeforeAndAfterEach {
@@ -710,7 +784,7 @@ class HiveDDLSuite
}
sql(s"CREATE DATABASE $dbName Location '${tmpDir.toURI.getPath.stripSuffix("/")}'")
val db1 = catalog.getDatabaseMetadata(dbName)
- val dbPath = tmpDir.toURI.toString.stripSuffix("/")
+ val dbPath = new URI(tmpDir.toURI.toString.stripSuffix("/"))
assert(db1 == CatalogDatabase(dbName, "", dbPath, Map.empty))
sql("USE db1")
@@ -747,11 +821,12 @@ class HiveDDLSuite
sql(s"CREATE DATABASE $dbName")
val catalog = spark.sessionState.catalog
val expectedDBLocation = s"file:${dbPath.toUri.getPath.stripSuffix("/")}/$dbName.db"
+ val expectedDBUri = CatalogUtils.stringToURI(expectedDBLocation)
val db1 = catalog.getDatabaseMetadata(dbName)
assert(db1 == CatalogDatabase(
dbName,
"",
- expectedDBLocation,
+ expectedDBUri,
Map.empty))
// the database directory was created
assert(fs.exists(dbPath) && fs.isDirectory(dbPath))
@@ -1588,101 +1663,201 @@ class HiveDDLSuite
}
}
+ test("create hive table with a non-existing location") {
+ withTable("t", "t1") {
+ withTempPath { dir =>
+ spark.sql(s"CREATE TABLE t(a int, b int) USING hive LOCATION '$dir'")
+
+ val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t"))
+ assert(table.location == makeQualifiedPath(dir.getAbsolutePath))
+
+ spark.sql("INSERT INTO TABLE t SELECT 1, 2")
+ assert(dir.exists())
+
+ checkAnswer(spark.table("t"), Row(1, 2))
+ }
+ // partition table
+ withTempPath { dir =>
+ spark.sql(
+ s"""
+ |CREATE TABLE t1(a int, b int)
+ |USING hive
+ |PARTITIONED BY(a)
+ |LOCATION '$dir'
+ """.stripMargin)
+
+ val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t1"))
+ assert(table.location == makeQualifiedPath(dir.getAbsolutePath))
+
+ spark.sql("INSERT INTO TABLE t1 PARTITION(a=1) SELECT 2")
+
+ val partDir = new File(dir, "a=1")
+ assert(partDir.exists())
+
+ checkAnswer(spark.table("t1"), Row(2, 1))
+ }
+ }
+ }
+
Seq(true, false).foreach { shouldDelete =>
- val tcName = if (shouldDelete) "non-existent" else "existed"
- test(s"CTAS for external data source table with a $tcName location") {
+    val tcName = if (shouldDelete) "non-existing" else "existing"
+
+ test(s"CTAS for external hive table with a $tcName location") {
withTable("t", "t1") {
- withTempDir {
- dir =>
- if (shouldDelete) {
- dir.delete()
- }
+ withSQLConf("hive.exec.dynamic.partition.mode" -> "nonstrict") {
+ withTempDir { dir =>
+ if (shouldDelete) dir.delete()
spark.sql(
s"""
|CREATE TABLE t
- |USING parquet
+ |USING hive
|LOCATION '$dir'
|AS SELECT 3 as a, 4 as b, 1 as c, 2 as d
""".stripMargin)
-
val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t"))
- assert(table.location == dir.getAbsolutePath)
+ assert(table.location == makeQualifiedPath(dir.getAbsolutePath))
checkAnswer(spark.table("t"), Row(3, 4, 1, 2))
- }
- // partition table
- withTempDir {
- dir =>
- if (shouldDelete) {
- dir.delete()
- }
+ }
+ // partition table
+ withTempDir { dir =>
+ if (shouldDelete) dir.delete()
spark.sql(
s"""
|CREATE TABLE t1
- |USING parquet
+ |USING hive
|PARTITIONED BY(a, b)
|LOCATION '$dir'
|AS SELECT 3 as a, 4 as b, 1 as c, 2 as d
""".stripMargin)
-
val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t1"))
- assert(table.location == dir.getAbsolutePath)
+ assert(table.location == makeQualifiedPath(dir.getAbsolutePath))
val partDir = new File(dir, "a=3")
assert(partDir.exists())
checkAnswer(spark.table("t1"), Row(1, 2, 3, 4))
+ }
}
}
}
+ }
- test(s"CTAS for external hive table with a $tcName location") {
- withTable("t", "t1") {
- withSQLConf("hive.exec.dynamic.partition.mode" -> "nonstrict") {
- withTempDir {
- dir =>
- if (shouldDelete) {
- dir.delete()
- }
- spark.sql(
- s"""
- |CREATE TABLE t
- |USING hive
- |LOCATION '$dir'
- |AS SELECT 3 as a, 4 as b, 1 as c, 2 as d
- """.stripMargin)
- val dirPath = new Path(dir.getAbsolutePath)
- val fs = dirPath.getFileSystem(spark.sessionState.newHadoopConf())
- val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t"))
- assert(new Path(table.location) == fs.makeQualified(dirPath))
-
- checkAnswer(spark.table("t"), Row(3, 4, 1, 2))
+ Seq("parquet", "hive").foreach { datasource =>
+ Seq("a b", "a:b", "a%b", "a,b").foreach { specialChars =>
+ test(s"partition column name of $datasource table containing $specialChars") {
+ withTable("t") {
+ withTempDir { dir =>
+ spark.sql(
+ s"""
+ |CREATE TABLE t(a string, `$specialChars` string)
+ |USING $datasource
+ |PARTITIONED BY(`$specialChars`)
+ |LOCATION '$dir'
+ """.stripMargin)
+
+ assert(dir.listFiles().isEmpty)
+ spark.sql(s"INSERT INTO TABLE t PARTITION(`$specialChars`=2) SELECT 1")
+ val partEscaped = s"${ExternalCatalogUtils.escapePathName(specialChars)}=2"
+ val partFile = new File(dir, partEscaped)
+ assert(partFile.listFiles().length >= 1)
+ checkAnswer(spark.table("t"), Row("1", "2") :: Nil)
+
+ withSQLConf("hive.exec.dynamic.partition.mode" -> "nonstrict") {
+ spark.sql(s"INSERT INTO TABLE t PARTITION(`$specialChars`) SELECT 3, 4")
+ val partEscaped1 = s"${ExternalCatalogUtils.escapePathName(specialChars)}=4"
+ val partFile1 = new File(dir, partEscaped1)
+ assert(partFile1.listFiles().length >= 1)
+ checkAnswer(spark.table("t"), Row("1", "2") :: Row("3", "4") :: Nil)
+ }
}
- // partition table
- withTempDir {
- dir =>
- if (shouldDelete) {
- dir.delete()
- }
- spark.sql(
- s"""
- |CREATE TABLE t1
- |USING hive
- |PARTITIONED BY(a, b)
- |LOCATION '$dir'
- |AS SELECT 3 as a, 4 as b, 1 as c, 2 as d
- """.stripMargin)
- val dirPath = new Path(dir.getAbsolutePath)
- val fs = dirPath.getFileSystem(spark.sessionState.newHadoopConf())
- val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t1"))
- assert(new Path(table.location) == fs.makeQualified(dirPath))
-
- val partDir = new File(dir, "a=3")
- assert(partDir.exists())
-
- checkAnswer(spark.table("t1"), Row(1, 2, 3, 4))
+ }
+ }
+ }
+ }
+
+ Seq("a b", "a:b", "a%b").foreach { specialChars =>
+ test(s"hive table: location uri contains $specialChars") {
+ withTable("t") {
+ withTempDir { dir =>
+ val loc = new File(dir, specialChars)
+ loc.mkdir()
+ spark.sql(
+ s"""
+ |CREATE TABLE t(a string)
+ |USING hive
+ |LOCATION '$loc'
+ """.stripMargin)
+
+ val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t"))
+ assert(table.location == makeQualifiedPath(loc.getAbsolutePath))
+ assert(new Path(table.location).toString.contains(specialChars))
+
+ assert(loc.listFiles().isEmpty)
+ if (specialChars != "a:b") {
+ spark.sql("INSERT INTO TABLE t SELECT 1")
+ assert(loc.listFiles().length >= 1)
+ checkAnswer(spark.table("t"), Row("1") :: Nil)
+ } else {
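+          // A location containing ':' cannot be used as a relative path (the ':' is parsed as a
+          // URI scheme separator), so the insert fails with a wrapped URISyntaxException.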
+ val e = intercept[InvocationTargetException] {
+ spark.sql("INSERT INTO TABLE t SELECT 1")
+ }.getTargetException.getMessage
+ assert(e.contains("java.net.URISyntaxException: Relative path in absolute URI: a:b"))
}
}
+
+ withTempDir { dir =>
+ val loc = new File(dir, specialChars)
+ loc.mkdir()
+ spark.sql(
+ s"""
+ |CREATE TABLE t1(a string, b string)
+ |USING hive
+ |PARTITIONED BY(b)
+ |LOCATION '$loc'
+ """.stripMargin)
+
+ val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t1"))
+ assert(table.location == makeQualifiedPath(loc.getAbsolutePath))
+ assert(new Path(table.location).toString.contains(specialChars))
+
+ assert(loc.listFiles().isEmpty)
+ if (specialChars != "a:b") {
+ spark.sql("INSERT INTO TABLE t1 PARTITION(b=2) SELECT 1")
+ val partFile = new File(loc, "b=2")
+ assert(partFile.listFiles().length >= 1)
+ checkAnswer(spark.table("t1"), Row("1", "2") :: Nil)
+
+ spark.sql("INSERT INTO TABLE t1 PARTITION(b='2017-03-03 12:13%3A14') SELECT 1")
+ val partFile1 = new File(loc, "b=2017-03-03 12:13%3A14")
+ assert(!partFile1.exists())
+ val partFile2 = new File(loc, "b=2017-03-03 12%3A13%253A14")
+ assert(partFile2.listFiles().length >= 1)
+ checkAnswer(spark.table("t1"),
+ Row("1", "2") :: Row("1", "2017-03-03 12:13%3A14") :: Nil)
+ } else {
+ val e = intercept[InvocationTargetException] {
+ spark.sql("INSERT INTO TABLE t1 PARTITION(b=2) SELECT 1")
+ }.getTargetException.getMessage
+ assert(e.contains("java.net.URISyntaxException: Relative path in absolute URI: a:b"))
+
+ val e1 = intercept[InvocationTargetException] {
+ spark.sql("INSERT INTO TABLE t1 PARTITION(b='2017-03-03 12:13%3A14') SELECT 1")
+ }.getTargetException.getMessage
+ assert(e1.contains("java.net.URISyntaxException: Relative path in absolute URI: a:b"))
+ }
+ }
+ }
+ }
+ }
+
+ test("SPARK-19905: Hive SerDe table input paths") {
+ withTable("spark_19905") {
+ withTempView("spark_19905_view") {
+ spark.range(10).createOrReplaceTempView("spark_19905_view")
+ sql("CREATE TABLE spark_19905 STORED AS RCFILE AS SELECT * FROM spark_19905_view")
+ assert(spark.table("spark_19905").inputFiles.nonEmpty)
+ assert(sql("SELECT input_file_name() FROM spark_19905").count() > 0)
}
}
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
index ef2d451e6b2d..236135dcff52 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
@@ -28,7 +28,7 @@ import org.apache.spark.TestUtils
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, FunctionRegistry, NoSuchPartitionException}
-import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogTableType}
+import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogTableType, CatalogUtils}
import org.apache.spark.sql.catalyst.parser.ParseException
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias}
import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation}
@@ -544,7 +544,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
}
userSpecifiedLocation match {
case Some(location) =>
- assert(r.tableMeta.location === location)
+ assert(r.tableMeta.location === CatalogUtils.stringToURI(location))
case None => // OK.
}
// Also make sure that the format and serde are as desired.
@@ -1030,7 +1030,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
withSQLConf(SQLConf.CONVERT_CTAS.key -> "false") {
sql("CREATE TABLE explodeTest (key bigInt)")
table("explodeTest").queryExecution.analyzed match {
- case SubqueryAlias(_, r: CatalogRelation, _) => // OK
+ case SubqueryAlias(_, r: CatalogRelation) => // OK
case _ =>
fail("To correctly test the fix of SPARK-5875, explodeTest should be a MetastoreRelation")
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala
index 4f771caa1db2..ba0a7605da71 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala
@@ -19,10 +19,10 @@ package org.apache.spark.sql.hive.orc
import java.io.File
-import org.apache.hadoop.fs.{FileSystem, Path}
+import org.apache.hadoop.fs.Path
-import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.catalog.CatalogUtils
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources.HadoopFsRelationTest
import org.apache.spark.sql.types._
@@ -42,12 +42,9 @@ class OrcHadoopFsRelationSuite extends HadoopFsRelationTest {
test("save()/load() - partitioned table - simple queries - partition columns in data") {
withTempDir { file =>
- val basePath = new Path(file.getCanonicalPath)
- val fs = basePath.getFileSystem(SparkHadoopUtil.get.conf)
- val qualifiedBasePath = fs.makeQualified(basePath)
-
for (p1 <- 1 to 2; p2 <- Seq("foo", "bar")) {
- val partitionDir = new Path(qualifiedBasePath, s"p1=$p1/p2=$p2")
+ val partitionDir = new Path(
+ CatalogUtils.URIToString(makeQualifiedPath(file.getCanonicalPath)), s"p1=$p1/p2=$p2")
sparkContext
.parallelize(for (i <- 1 to 3) yield (i, s"val_$i", p1))
.toDF("a", "b", "p1")
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala
index 38a5477796a4..5d8ba9d7c85d 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala
@@ -33,6 +33,7 @@ import org.apache.spark.sql.hive.test.TestHive._
import org.apache.spark.sql.hive.test.TestHive.implicits._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{IntegerType, StructType}
+import org.apache.spark.util.Utils
case class AllDataTypesWithNonPrimitiveType(
stringField: String,
@@ -611,4 +612,12 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest {
}
}
}
+
+ test("read from multiple orc input paths") {
+ val path1 = Utils.createTempDir()
+ val path2 = Utils.createTempDir()
+ makeOrcFile((1 to 10).map(Tuple1.apply), path1)
+ makeOrcFile((1 to 10).map(Tuple1.apply), path2)
+ assertResult(20)(read.orc(path1.getCanonicalPath, path2.getCanonicalPath).count())
+ }
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala
index 3512c4a89031..81af24979d82 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala
@@ -453,7 +453,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest {
// Converted test_parquet should be cached.
sessionState.catalog.getCachedDataSourceTable(tableIdentifier) match {
case null => fail("Converted test_parquet should be cached in the cache.")
- case logical @ LogicalRelation(parquetRelation: HadoopFsRelation, _, _) => // OK
+ case LogicalRelation(_: HadoopFsRelation, _, _) => // OK
case other =>
fail(
"The cached test_parquet should be a Parquet Relation. " +
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala
index d79edee5b1a4..49be30435ad2 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala
@@ -21,8 +21,8 @@ import java.math.BigDecimal
import org.apache.hadoop.fs.Path
-import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.catalog.CatalogUtils
import org.apache.spark.sql.types._
class JsonHadoopFsRelationSuite extends HadoopFsRelationTest {
@@ -38,12 +38,9 @@ class JsonHadoopFsRelationSuite extends HadoopFsRelationTest {
test("save()/load() - partitioned table - simple queries - partition columns in data") {
withTempDir { file =>
- val basePath = new Path(file.getCanonicalPath)
- val fs = basePath.getFileSystem(SparkHadoopUtil.get.conf)
- val qualifiedBasePath = fs.makeQualified(basePath)
-
for (p1 <- 1 to 2; p2 <- Seq("foo", "bar")) {
- val partitionDir = new Path(qualifiedBasePath, s"p1=$p1/p2=$p2")
+ val partitionDir = new Path(
+ CatalogUtils.URIToString(makeQualifiedPath(file.getCanonicalPath)), s"p1=$p1/p2=$p2")
sparkContext
.parallelize(for (i <- 1 to 3) yield s"""{"a":$i,"b":"val_$i"}""")
.saveAsTextFile(partitionDir.toString)
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala
index 03207ab869d1..dce5bb7ddba6 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala
@@ -23,8 +23,8 @@ import com.google.common.io.Files
import org.apache.hadoop.fs.Path
import org.apache.parquet.hadoop.ParquetOutputFormat
-import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.catalog.CatalogUtils
import org.apache.spark.sql.execution.datasources.SQLHadoopMapReduceCommitProtocol
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
@@ -44,12 +44,9 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest {
test("save()/load() - partitioned table - simple queries - partition columns in data") {
withTempDir { file =>
- val basePath = new Path(file.getCanonicalPath)
- val fs = basePath.getFileSystem(SparkHadoopUtil.get.conf)
- val qualifiedBasePath = fs.makeQualified(basePath)
-
for (p1 <- 1 to 2; p2 <- Seq("foo", "bar")) {
- val partitionDir = new Path(qualifiedBasePath, s"p1=$p1/p2=$p2")
+ val partitionDir = new Path(
+ CatalogUtils.URIToString(makeQualifiedPath(file.getCanonicalPath)), s"p1=$p1/p2=$p2")
sparkContext
.parallelize(for (i <- 1 to 3) yield (i, s"val_$i", p1))
.toDF("a", "b", "p1")
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala
index a47a2246ddc3..2ec593b95c9b 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.sources
import org.apache.hadoop.fs.Path
-import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.sql.catalyst.catalog.CatalogUtils
import org.apache.spark.sql.catalyst.expressions.PredicateHelper
import org.apache.spark.sql.types._
@@ -45,12 +45,9 @@ class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest with Predicat
test("save()/load() - partitioned table - simple queries - partition columns in data") {
withTempDir { file =>
- val basePath = new Path(file.getCanonicalPath)
- val fs = basePath.getFileSystem(SparkHadoopUtil.get.conf)
- val qualifiedBasePath = fs.makeQualified(basePath)
-
for (p1 <- 1 to 2; p2 <- Seq("foo", "bar")) {
- val partitionDir = new Path(qualifiedBasePath, s"p1=$p1/p2=$p2")
+ val partitionDir = new Path(
+ CatalogUtils.URIToString(makeQualifiedPath(file.getCanonicalPath)), s"p1=$p1/p2=$p2")
sparkContext
.parallelize(for (i <- 1 to 3) yield s"$i,val_$i,$p1")
.saveAsTextFile(partitionDir.toString)
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
index 7fcf45e7dedc..ee2fd45a7e85 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
@@ -152,11 +152,9 @@ trait DStreamCheckpointTester { self: SparkFunSuite =>
stopSparkContext: Boolean
): Seq[Seq[V]] = {
try {
- val batchDuration = ssc.graph.batchDuration
val batchCounter = new BatchCounter(ssc)
ssc.start()
val clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
- val currentTime = clock.getTimeMillis()
logInfo("Manual clock before advancing = " + clock.getTimeMillis())
clock.setTime(targetBatchTime.milliseconds)
@@ -171,7 +169,7 @@ trait DStreamCheckpointTester { self: SparkFunSuite =>
eventually(timeout(10 seconds)) {
val checkpointFilesOfLatestTime = Checkpoint.getCheckpointFiles(checkpointDir).filter {
- _.toString.contains(clock.getTimeMillis.toString)
+ _.getName.contains(clock.getTimeMillis.toString)
}
// Checkpoint files are written twice for every batch interval. So assert that both
// are written to make sure that both of them have been written.