diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index c575fe255f57a..9696f6987ad78 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -201,10 +201,13 @@ exportMethods("%<=>%", "approxCountDistinct", "approxQuantile", "array_contains", + "array_join", "array_max", "array_min", "array_position", + "array_repeat", "array_sort", + "arrays_overlap", "asc", "ascii", "asin", @@ -302,6 +305,7 @@ exportMethods("%<=>%", "lower", "lpad", "ltrim", + "map_entries", "map_keys", "map_values", "max", diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index a1c9495b0795e..70eb7a874b75c 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -2297,6 +2297,8 @@ setMethod("rename", setClassUnion("characterOrColumn", c("character", "Column")) +setClassUnion("numericOrColumn", c("numeric", "Column")) + #' Arrange Rows by Variables #' #' Sort a SparkDataFrame by the specified column(s). diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index fcb3521f901ea..3bff633fbc1ff 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -189,6 +189,7 @@ NULL #' the map or array of maps. #' \item \code{from_json}: it is the column containing the JSON string. #' } +#' @param y Column to compute on. #' @param value A value to compute on. #' \itemize{ #' \item \code{array_contains}: a value to be checked if contained in the column. @@ -207,7 +208,7 @@ NULL #' tmp <- mutate(df, v1 = create_array(df$mpg, df$cyl, df$hp)) #' head(select(tmp, array_contains(tmp$v1, 21), size(tmp$v1))) #' head(select(tmp, array_max(tmp$v1), array_min(tmp$v1))) -#' head(select(tmp, array_position(tmp$v1, 21), array_sort(tmp$v1))) +#' head(select(tmp, array_position(tmp$v1, 21), array_repeat(df$mpg, 3), array_sort(tmp$v1))) #' head(select(tmp, flatten(tmp$v1), reverse(tmp$v1))) #' tmp2 <- mutate(tmp, v2 = explode(tmp$v1)) #' head(tmp2) @@ -216,12 +217,13 @@ NULL #' head(select(tmp, sort_array(tmp$v1))) #' head(select(tmp, sort_array(tmp$v1, asc = FALSE))) #' tmp3 <- mutate(df, v3 = create_map(df$model, df$cyl)) -#' head(select(tmp3, map_keys(tmp3$v3))) -#' head(select(tmp3, map_values(tmp3$v3))) +#' head(select(tmp3, map_entries(tmp3$v3), map_keys(tmp3$v3), map_values(tmp3$v3))) #' head(select(tmp3, element_at(tmp3$v3, "Valiant"))) -#' tmp4 <- mutate(df, v4 = create_array(df$mpg, df$cyl), v5 = create_array(df$hp)) -#' head(select(tmp4, concat(tmp4$v4, tmp4$v5))) -#' head(select(tmp, concat(df$mpg, df$cyl, df$hp)))} +#' tmp4 <- mutate(df, v4 = create_array(df$mpg, df$cyl), v5 = create_array(df$cyl, df$hp)) +#' head(select(tmp4, concat(tmp4$v4, tmp4$v5), arrays_overlap(tmp4$v4, tmp4$v5))) +#' head(select(tmp, concat(df$mpg, df$cyl, df$hp))) +#' tmp5 <- mutate(df, v6 = create_array(df$model, df$model)) +#' head(select(tmp5, array_join(tmp5$v6, "#"), array_join(tmp5$v6, "#", "NULL")))} NULL #' Window functions for Column operations @@ -3006,6 +3008,27 @@ setMethod("array_contains", column(jc) }) +#' @details +#' \code{array_join}: Concatenates the elements of column using the delimiter. +#' Null values are replaced with nullReplacement if set, otherwise they are ignored. +#' +#' @param delimiter a character string that is used to concatenate the elements of column. +#' @param nullReplacement an optional character string that is used to replace the Null values. 
+#' @rdname column_collection_functions +#' @aliases array_join array_join,Column-method +#' @note array_join since 2.4.0 +setMethod("array_join", + signature(x = "Column", delimiter = "character"), + function(x, delimiter, nullReplacement = NULL) { + jc <- if (is.null(nullReplacement)) { + callJStatic("org.apache.spark.sql.functions", "array_join", x@jc, delimiter) + } else { + callJStatic("org.apache.spark.sql.functions", "array_join", x@jc, delimiter, + as.character(nullReplacement)) + } + column(jc) + }) + #' @details #' \code{array_max}: Returns the maximum value of the array. #' @@ -3048,6 +3071,26 @@ setMethod("array_position", column(jc) }) +#' @details +#' \code{array_repeat}: Creates an array containing \code{x} repeated the number of times +#' given by \code{count}. +#' +#' @param count a Column or constant determining the number of repetitions. +#' @rdname column_collection_functions +#' @aliases array_repeat array_repeat,Column,numericOrColumn-method +#' @note array_repeat since 2.4.0 +setMethod("array_repeat", + signature(x = "Column", count = "numericOrColumn"), + function(x, count) { + if (class(count) == "Column") { + count <- count@jc + } else { + count <- as.integer(count) + } + jc <- callJStatic("org.apache.spark.sql.functions", "array_repeat", x@jc, count) + column(jc) + }) + #' @details #' \code{array_sort}: Sorts the input array in ascending order. The elements of the input array #' must be orderable. NA elements will be placed at the end of the returned array. @@ -3062,6 +3105,21 @@ setMethod("array_sort", column(jc) }) +#' @details +#' \code{arrays_overlap}: Returns true if the input arrays have at least one non-null element in +#' common. If not and both arrays are non-empty and any of them contains a null, it returns null. +#' It returns false otherwise. +#' +#' @rdname column_collection_functions +#' @aliases arrays_overlap arrays_overlap,Column-method +#' @note arrays_overlap since 2.4.0 +setMethod("arrays_overlap", + signature(x = "Column", y = "Column"), + function(x, y) { + jc <- callJStatic("org.apache.spark.sql.functions", "arrays_overlap", x@jc, y@jc) + column(jc) + }) + #' @details #' \code{flatten}: Creates a single array from an array of arrays. #' If a structure of nested arrays is deeper than two levels, only one level of nesting is removed. @@ -3076,6 +3134,19 @@ setMethod("flatten", column(jc) }) +#' @details +#' \code{map_entries}: Returns an unordered array of all entries in the given map. +#' +#' @rdname column_collection_functions +#' @aliases map_entries map_entries,Column-method +#' @note map_entries since 2.4.0 +setMethod("map_entries", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "map_entries", x@jc) + column(jc) + }) + #' @details #' \code{map_keys}: Returns an unordered array containing the keys of the map. #' @@ -3149,8 +3220,8 @@ setMethod("size", #' (or starting from the end if start is negative) with the specified length. #' #' @rdname column_collection_functions -#' @param start an index indicating the first element occuring in the result. -#' @param length a number of consecutive elements choosen to the result. +#' @param start an index indicating the first element occurring in the result. +#' @param length a number of consecutive elements chosen to the result. 
#' @aliases slice slice,Column-method #' @note slice since 2.4.0 setMethod("slice", diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 3ea181157b644..9321bbaf96ff8 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -757,6 +757,10 @@ setGeneric("approxCountDistinct", function(x, ...) { standardGeneric("approxCoun #' @name NULL setGeneric("array_contains", function(x, value) { standardGeneric("array_contains") }) +#' @rdname column_collection_functions +#' @name NULL +setGeneric("array_join", function(x, delimiter, ...) { standardGeneric("array_join") }) + #' @rdname column_collection_functions #' @name NULL setGeneric("array_max", function(x) { standardGeneric("array_max") }) @@ -769,10 +773,18 @@ setGeneric("array_min", function(x) { standardGeneric("array_min") }) #' @name NULL setGeneric("array_position", function(x, value) { standardGeneric("array_position") }) +#' @rdname column_collection_functions +#' @name NULL +setGeneric("array_repeat", function(x, count) { standardGeneric("array_repeat") }) + #' @rdname column_collection_functions #' @name NULL setGeneric("array_sort", function(x) { standardGeneric("array_sort") }) +#' @rdname column_collection_functions +#' @name NULL +setGeneric("arrays_overlap", function(x, y) { standardGeneric("arrays_overlap") }) + #' @rdname column_string_functions #' @name NULL setGeneric("ascii", function(x) { standardGeneric("ascii") }) @@ -1034,6 +1046,10 @@ setGeneric("lpad", function(x, len, pad) { standardGeneric("lpad") }) #' @name NULL setGeneric("ltrim", function(x, trimString) { standardGeneric("ltrim") }) +#' @rdname column_collection_functions +#' @name NULL +setGeneric("map_entries", function(x) { standardGeneric("map_entries") }) + #' @rdname column_collection_functions #' @name NULL setGeneric("map_keys", function(x) { standardGeneric("map_keys") }) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 13b55ac6e6e3c..36e0f78bb0599 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -1503,6 +1503,36 @@ test_that("column functions", { result <- collect(select(df2, reverse(df2[[1]])))[[1]] expect_equal(result, "cba") + # Test array_repeat() + df <- createDataFrame(list(list("a", 3L), list("b", 2L))) + result <- collect(select(df, array_repeat(df[[1]], df[[2]])))[[1]] + expect_equal(result, list(list("a", "a", "a"), list("b", "b"))) + + result <- collect(select(df, array_repeat(df[[1]], 2L)))[[1]] + expect_equal(result, list(list("a", "a"), list("b", "b"))) + + # Test arrays_overlap() + df <- createDataFrame(list(list(list(1L, 2L), list(3L, 1L)), + list(list(1L, 2L), list(3L, 4L)), + list(list(1L, NA), list(3L, 4L)))) + result <- collect(select(df, arrays_overlap(df[[1]], df[[2]])))[[1]] + expect_equal(result, c(TRUE, FALSE, NA)) + + # Test array_join() + df <- createDataFrame(list(list(list("Hello", "World!")))) + result <- collect(select(df, array_join(df[[1]], "#")))[[1]] + expect_equal(result, "Hello#World!") + df2 <- createDataFrame(list(list(list("Hello", NA, "World!")))) + result <- collect(select(df2, array_join(df2[[1]], "#", "Beautiful")))[[1]] + expect_equal(result, "Hello#Beautiful#World!") + result <- collect(select(df2, array_join(df2[[1]], "#")))[[1]] + expect_equal(result, "Hello#World!") + df3 <- createDataFrame(list(list(list("Hello", NULL, "World!")))) + result <- collect(select(df3, array_join(df3[[1]], "#", "Beautiful")))[[1]] + expect_equal(result, "Hello#Beautiful#World!") + result <- collect(select(df3, 
array_join(df3[[1]], "#")))[[1]] + expect_equal(result, "Hello#World!") + # Test array_sort() and sort_array() df <- createDataFrame(list(list(list(2L, 1L, 3L, NA)), list(list(NA, 6L, 5L, NA, 4L)))) @@ -1531,8 +1561,13 @@ test_that("column functions", { result <- collect(select(df, flatten(df[[1]])))[[1]] expect_equal(result, list(list(1L, 2L, 3L, 4L), list(5L, 6L, 7L, 8L))) - # Test map_keys(), map_values() and element_at() + # Test map_entries(), map_keys(), map_values() and element_at() df <- createDataFrame(list(list(map = as.environment(list(x = 1, y = 2))))) + result <- collect(select(df, map_entries(df$map)))[[1]] + expected_entries <- list(listToStruct(list(key = "x", value = 1)), + listToStruct(list(key = "y", value = 2))) + expect_equal(result, list(expected_entries)) + result <- collect(select(df, map_keys(df$map)))[[1]] expect_equal(result, list(list("x", "y"))) diff --git a/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java b/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java index ae91bc9cfdd08..5a0f575b0dac1 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java +++ b/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java @@ -146,7 +146,8 @@ public TransportChannelHandler initializePipeline( TransportChannelHandler channelHandler = createChannelHandler(channel, channelRpcHandler); channel.pipeline() .addLast("encoder", ENCODER) - .addLast(TransportFrameDecoder.HANDLER_NAME, NettyUtils.createFrameDecoder()) + .addLast(TransportFrameDecoder.HANDLER_NAME, + NettyUtils.createFrameDecoder(conf.maxRemoteBlockSizeFetchToMem(), false)) .addLast("decoder", DECODER) .addLast("idleStateHandler", new IdleStateHandler(0, 0, conf.connectionTimeoutMs() / 1000)) // NOTE: Chunks are currently guaranteed to be returned in the order of request, but this diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/StreamCallback.java b/common/network-common/src/main/java/org/apache/spark/network/client/StreamCallback.java index d322aec28793e..bf3b60648b3ca 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/StreamCallback.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/StreamCallback.java @@ -17,6 +17,8 @@ package org.apache.spark.network.client; +import io.netty.buffer.ByteBuf; + import java.io.IOException; import java.nio.ByteBuffer; @@ -28,13 +30,13 @@ * The network library guarantees that a single thread will call these methods at a time, but * different call may be made by different threads. */ -public interface StreamCallback { +public interface StreamCallback<T> { /** Called upon receipt of stream data. */ - void onData(String streamId, ByteBuffer buf) throws IOException; + void onData(T streamId, ByteBuffer buf) throws IOException; /** Called when all data from the stream has been received. */ - void onComplete(String streamId) throws IOException; + void onComplete(T streamId) throws IOException; /** Called if there's an error reading data from the stream. 
*/ - void onFailure(String streamId, Throwable cause) throws IOException; + void onFailure(T streamId, Throwable cause) throws IOException; } diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java b/common/network-common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java index b0e85bae7c309..19faacf67d00e 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java @@ -28,17 +28,17 @@ * An interceptor that is registered with the frame decoder to feed stream data to a * callback. */ -class StreamInterceptor implements TransportFrameDecoder.Interceptor { +class StreamInterceptor<T> implements TransportFrameDecoder.Interceptor { private final TransportResponseHandler handler; - private final String streamId; + private final T streamId; private final long byteCount; - private final StreamCallback callback; + private final StreamCallback<T> callback; private long bytesRead; StreamInterceptor( TransportResponseHandler handler, - String streamId, + T streamId, long byteCount, StreamCallback callback) { this.handler = handler; diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java index 8f354ad78bbaa..c8a8c83f93b9c 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java @@ -24,6 +24,7 @@ import java.util.UUID; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; +import java.util.function.Supplier; import javax.annotation.Nullable; import com.google.common.annotations.VisibleForTesting; @@ -132,14 +133,15 @@ public void setClientId(String id) { public void fetchChunk( long streamId, int chunkIndex, - ChunkReceivedCallback callback) { + ChunkReceivedCallback callback, + Supplier<StreamCallback<StreamChunkId>> streamCallbackFactory) { long startTime = System.currentTimeMillis(); if (logger.isDebugEnabled()) { logger.debug("Sending fetch chunk request {} to {}", chunkIndex, getRemoteAddress(channel)); } StreamChunkId streamChunkId = new StreamChunkId(streamId, chunkIndex); - handler.addFetchRequest(streamChunkId, callback); + handler.addFetchRequest(streamChunkId, callback, streamCallbackFactory); channel.writeAndFlush(new ChunkFetchRequest(streamChunkId)).addListener(future -> { if (future.isSuccess()) { @@ -169,7 +171,7 @@ public void fetchChunk( * @param streamId The stream to fetch. * @param callback Object to call with the stream data. 
*/ - public void stream(String streamId, StreamCallback callback) { + public void stream(String streamId, StreamCallback<String> callback) { long startTime = System.currentTimeMillis(); if (logger.isDebugEnabled()) { logger.debug("Sending stream request for {} to {}", streamId, getRemoteAddress(channel)); diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java index 7a3d96ceaef0c..41c2c7cc6b271 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java @@ -23,6 +23,7 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.atomic.AtomicLong; +import java.util.function.Supplier; import com.google.common.annotations.VisibleForTesting; import io.netty.channel.Channel; @@ -55,10 +56,12 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> { private final Channel channel; private final Map<StreamChunkId, ChunkReceivedCallback> outstandingFetches; + private final Map<StreamChunkId, Supplier<StreamCallback<StreamChunkId>>> + outstandingStreamFetches; private final Map<Long, RpcResponseCallback> outstandingRpcs; - private final Queue<Pair<String, StreamCallback>> streamCallbacks; + private final Queue<Pair<String, StreamCallback<String>>> streamCallbacks; private volatile boolean streamActive; /** Records the time (in system nanoseconds) that the last fetch or RPC request was sent. */ @@ -67,18 +70,24 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> { public TransportResponseHandler(Channel channel) { this.channel = channel; this.outstandingFetches = new ConcurrentHashMap<>(); + this.outstandingStreamFetches = new ConcurrentHashMap<>(); this.outstandingRpcs = new ConcurrentHashMap<>(); this.streamCallbacks = new ConcurrentLinkedQueue<>(); this.timeOfLastRequestNs = new AtomicLong(0); } - public void addFetchRequest(StreamChunkId streamChunkId, ChunkReceivedCallback callback) { + public void addFetchRequest( + StreamChunkId streamChunkId, + ChunkReceivedCallback callback, + Supplier<StreamCallback<StreamChunkId>> streamCallbackFactory) { updateTimeOfLastRequest(); outstandingFetches.put(streamChunkId, callback); + outstandingStreamFetches.put(streamChunkId, streamCallbackFactory); } public void removeFetchRequest(StreamChunkId streamChunkId) { outstandingFetches.remove(streamChunkId); + outstandingStreamFetches.remove(streamChunkId); } public void addRpcRequest(long requestId, RpcResponseCallback callback) { @@ -90,7 +99,7 @@ public void removeRpcRequest(long requestId) { outstandingRpcs.remove(requestId); } - public void addStreamCallback(String streamId, StreamCallback callback) { + public void addStreamCallback(String streamId, StreamCallback<String> callback) { timeOfLastRequestNs.set(System.nanoTime()); streamCallbacks.offer(ImmutablePair.of(streamId, callback)); } @@ -119,7 +128,7 @@ private void failOutstandingRequests(Throwable cause) { logger.warn("RpcResponseCallback.onFailure throws exception", e); } } - for (Pair<String, StreamCallback> entry : streamCallbacks) { + for (Pair<String, StreamCallback<String>> entry : streamCallbacks) { try { entry.getValue().onFailure(entry.getKey(), cause); } catch (Exception e) { @@ -131,6 +140,7 @@ private void failOutstandingRequests(Throwable cause) { outstandingFetches.clear(); outstandingRpcs.clear(); streamCallbacks.clear(); + outstandingStreamFetches.clear(); } @Override @@ -165,10 +175,37 @@ public void handle(ResponseMessage message) throws Exception { if (listener == null) { logger.warn("Ignoring response for block {} from 
{} since it is not outstanding", resp.streamChunkId, getRemoteAddress(channel)); - resp.body().release(); } else { - outstandingFetches.remove(resp.streamChunkId); - listener.onSuccess(resp.streamChunkId.chunkIndex, resp.body()); + if (resp.isBodyInFrame()) { + outstandingFetches.remove(resp.streamChunkId); + listener.onSuccess(resp.streamChunkId.chunkIndex, resp.body()); + } else { + StreamCallback streamCallback = + outstandingStreamFetches.get(resp.streamChunkId).get(); + outstandingFetches.remove(resp.streamChunkId); + outstandingStreamFetches.remove(resp.streamChunkId); + if (resp.remainingFrameSize > 0) { + StreamInterceptor interceptor = new StreamInterceptor(this, + resp.streamChunkId, resp.remainingFrameSize, streamCallback); + try { + TransportFrameDecoder frameDecoder = (TransportFrameDecoder) + channel.pipeline().get(TransportFrameDecoder.HANDLER_NAME); + frameDecoder.setInterceptor(interceptor); + streamActive = true; + } catch (Exception e) { + logger.error("Error installing stream handler.", e); + deactivateStream(); + } + } else { + try { + streamCallback.onComplete(resp.streamChunkId); + } catch (Exception e) { + logger.warn("Error in stream handler onComplete().", e); + } + } + } + } + if (resp.isBodyInFrame()) { resp.body().release(); } } else if (message instanceof ChunkFetchFailure) { @@ -208,12 +245,12 @@ public void handle(ResponseMessage message) throws Exception { } } else if (message instanceof StreamResponse) { StreamResponse resp = (StreamResponse) message; - Pair entry = streamCallbacks.poll(); + Pair> entry = streamCallbacks.poll(); if (entry != null) { StreamCallback callback = entry.getValue(); if (resp.byteCount > 0) { - StreamInterceptor interceptor = new StreamInterceptor(this, resp.streamId, resp.byteCount, - callback); + StreamInterceptor interceptor = new StreamInterceptor(this, resp.streamId, + resp.byteCount, callback); try { TransportFrameDecoder frameDecoder = (TransportFrameDecoder) channel.pipeline().get(TransportFrameDecoder.HANDLER_NAME); @@ -235,7 +272,7 @@ public void handle(ResponseMessage message) throws Exception { } } else if (message instanceof StreamFailure) { StreamFailure resp = (StreamFailure) message; - Pair entry = streamCallbacks.poll(); + Pair> entry = streamCallbacks.poll(); if (entry != null) { StreamCallback callback = entry.getValue(); try { diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java index 94c2ac9b20e43..7a315f9db2fd2 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java @@ -31,11 +31,23 @@ * Similarly, the client-side decoding will reuse the Netty ByteBuf as the buffer. 
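+ * When the chunk body is larger than spark.maxRemoteBlockSizeFetchToMem, the body is not decoded + * into this message: isBodyInFrame() is false and remainingFrameSize records how many body bytes + * are still left in the channel, to be streamed to a StreamCallback instead of being buffered in + * memory.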
*/ public final class ChunkFetchSuccess extends AbstractResponseMessage { + public static final int ENCODED_LENGTH = StreamChunkId.ENCODED_LENGTH; public final StreamChunkId streamChunkId; + public final long remainingFrameSize; public ChunkFetchSuccess(StreamChunkId streamChunkId, ManagedBuffer buffer) { super(buffer, true); this.streamChunkId = streamChunkId; + this.remainingFrameSize = 0; + } + + public ChunkFetchSuccess(StreamChunkId streamChunkId, + ManagedBuffer buffer, + boolean isBodyInFrame, + long remainingFrameSize) { + super(buffer, isBodyInFrame); + this.streamChunkId = streamChunkId; + this.remainingFrameSize = remainingFrameSize; } @Override @@ -58,11 +70,16 @@ public ResponseMessage createFailureResponse(String error) { } /** Decoding uses the given ByteBuf as our data, and will retain() it. */ - public static ChunkFetchSuccess decode(ByteBuf buf) { + public static ChunkFetchSuccess decode(ByteBuf buf, long remainingFrameSize) { StreamChunkId streamChunkId = StreamChunkId.decode(buf); - buf.retain(); - NettyManagedBuffer managedBuf = new NettyManagedBuffer(buf.duplicate()); - return new ChunkFetchSuccess(streamChunkId, managedBuf); + NettyManagedBuffer managedBuf = null; + final boolean isFullFrameProcessed = + remainingFrameSize == 0; + if (isFullFrameProcessed) { + buf.retain(); + managedBuf = new NettyManagedBuffer(buf.duplicate()); + } + return new ChunkFetchSuccess(streamChunkId, managedBuf, isFullFrameProcessed, remainingFrameSize); } @Override diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/Message.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/Message.java index 434935a8ef2ad..8360d7d0ac21b 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/Message.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/Message.java @@ -39,6 +39,9 @@ enum Type implements Encodable { StreamRequest(6), StreamResponse(7), StreamFailure(8), OneWayMessage(9), User(-1); + /** Encoded length in bytes. */ + public static final int LENGTH = 1; + private final byte id; Type(int id) { @@ -48,7 +51,7 @@ enum Type implements Encodable { public byte id() { return id; } - @Override public int encodedLength() { return 1; } + @Override public int encodedLength() { return LENGTH; } @Override public void encode(ByteBuf buf) { buf.writeByte(id); } diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java index 39a7495828a8a..fc7151d55a3e2 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java @@ -31,7 +31,7 @@ * This encoder is stateless so it is safe to be shared by multiple threads. 
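+ * With this change the decoder no longer reads the message type itself: it consumes ParsedFrame + * objects emitted by TransportFrameDecoder, which carry the already decoded message type and, for + * oversized ChunkFetchSuccess frames, the number of body bytes still remaining in the channel.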
*/ @ChannelHandler.Sharable -public final class MessageDecoder extends MessageToMessageDecoder { +public final class MessageDecoder extends MessageToMessageDecoder { private static final Logger logger = LoggerFactory.getLogger(MessageDecoder.class); @@ -40,21 +40,20 @@ public final class MessageDecoder extends MessageToMessageDecoder { private MessageDecoder() {} @Override - public void decode(ChannelHandlerContext ctx, ByteBuf in, List out) { - Message.Type msgType = Message.Type.decode(in); - Message decoded = decode(msgType, in); - assert decoded.type() == msgType; - logger.trace("Received message {}: {}", msgType, decoded); + public void decode(ChannelHandlerContext ctx, ParsedFrame in, List out) { + Message decoded = decode(in.messageType, in.byteBuf, in.remainingFrameSize); + assert decoded.type() == in.messageType; + logger.trace("Received message {}: {}", in.messageType, decoded); out.add(decoded); } - private Message decode(Message.Type msgType, ByteBuf in) { + private Message decode(Message.Type msgType, ByteBuf in, long remainingFrameSize) { switch (msgType) { case ChunkFetchRequest: return ChunkFetchRequest.decode(in); case ChunkFetchSuccess: - return ChunkFetchSuccess.decode(in); + return ChunkFetchSuccess.decode(in, remainingFrameSize); case ChunkFetchFailure: return ChunkFetchFailure.decode(in); diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/ParsedFrame.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/ParsedFrame.java new file mode 100644 index 0000000000000..b29fe7626eefa --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/ParsedFrame.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import io.netty.buffer.ByteBuf; + +public class ParsedFrame { + + public final Message.Type messageType; + + public final ByteBuf byteBuf; + + public final long remainingFrameSize; + + + public ParsedFrame(Message.Type messageType, ByteBuf byteBuf, long remainingFrameSize) { + this.messageType = messageType; + this.byteBuf = byteBuf; + this.remainingFrameSize = remainingFrameSize; + } +} diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamChunkId.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamChunkId.java index d46a263884807..0807cb127c9d1 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamChunkId.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/StreamChunkId.java @@ -24,6 +24,9 @@ * Encapsulates a request for a particular chunk of a stream. 
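+ * The encoded form is ENCODED_LENGTH (8 + 4) bytes: a long streamId followed by an int chunkIndex.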
*/ public final class StreamChunkId implements Encodable { + + public static final int ENCODED_LENGTH = 8 + 4; + public final long streamId; public final int chunkIndex; @@ -34,7 +37,7 @@ public StreamChunkId(long streamId, int chunkIndex) { @Override public int encodedLength() { - return 8 + 4; + return ENCODED_LENGTH; } public void encode(ByteBuf buffer) { @@ -43,7 +46,7 @@ public void encode(ByteBuf buffer) { } public static StreamChunkId decode(ByteBuf buffer) { - assert buffer.readableBytes() >= 8 + 4; + assert buffer.readableBytes() >= ENCODED_LENGTH; long streamId = buffer.readLong(); int chunkIndex = buffer.readInt(); return new StreamChunkId(streamId, chunkIndex); } diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java index 3ac9081d78a75..5128f90160a93 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java @@ -61,7 +61,7 @@ static void addToChannel( channel.pipeline() .addFirst(ENCRYPTION_HANDLER_NAME, new EncryptionHandler(backend, maxOutboundBlockSize)) .addFirst("saslDecryption", new DecryptionHandler(backend)) - .addFirst("saslFrameDecoder", NettyUtils.createFrameDecoder()); + .addFirst("saslFrameDecoder", NettyUtils.createFrameDecoder(Integer.MAX_VALUE, true)); } private static class EncryptionHandler extends ChannelOutboundHandlerAdapter { diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/JavaUtils.java b/common/network-common/src/main/java/org/apache/spark/network/util/JavaUtils.java index afc59efaef810..b5497087634ce 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/util/JavaUtils.java +++ b/common/network-common/src/main/java/org/apache/spark/network/util/JavaUtils.java @@ -17,10 +17,7 @@ package org.apache.spark.network.util; -import java.io.Closeable; -import java.io.EOFException; -import java.io.File; -import java.io.IOException; +import java.io.*; import java.nio.ByteBuffer; import java.nio.channels.ReadableByteChannel; import java.nio.charset.StandardCharsets; @@ -91,11 +88,24 @@ public static String bytesToString(ByteBuffer b) { * @throws IOException if deletion is unsuccessful */ public static void deleteRecursively(File file) throws IOException { + deleteRecursively(file, null); + } + + /** + * Delete a file or directory and its contents recursively. + * Don't follow directories if they are symlinks. + * + * @param filter A filename filter that makes sure only files / dirs with matching filenames + * are deleted. 
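+ * Note that when a non-null filter is given, the Unix native fast path is skipped and the + * Java IO implementation is used, since the native command cannot honor a FilenameFilter.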
+ * @throws IOException if deletion is unsuccessful + */ + public static void deleteRecursively(File file, FilenameFilter filter) throws IOException { if (file == null) { return; } // On Unix systems, use operating system command to run faster // If that does not work out, fallback to the Java IO way - if (SystemUtils.IS_OS_UNIX) { + if (SystemUtils.IS_OS_UNIX && filter == null) { try { deleteRecursivelyUsingUnixNative(file); return; @@ -105,15 +115,17 @@ public static void deleteRecursively(File file) throws IOException { } } - deleteRecursivelyUsingJavaIO(file); + deleteRecursivelyUsingJavaIO(file, filter); } - private static void deleteRecursivelyUsingJavaIO(File file) throws IOException { + private static void deleteRecursivelyUsingJavaIO( + File file, + FilenameFilter filter) throws IOException { if (file.isDirectory() && !isSymlink(file)) { IOException savedIOException = null; - for (File child : listFilesSafely(file)) { + for (File child : listFilesSafely(file, filter)) { try { - deleteRecursively(child); + deleteRecursively(child, filter); } catch (IOException e) { // In case of multiple exceptions, only last one will be thrown savedIOException = e; @@ -124,10 +136,13 @@ private static void deleteRecursivelyUsingJavaIO(File file) throws IOException { } } - boolean deleted = file.delete(); - // Delete can also fail if the file simply did not exist. - if (!deleted && file.exists()) { - throw new IOException("Failed to delete: " + file.getAbsolutePath()); + // Delete file only when it's a normal file or an empty directory. + if (file.isFile() || (file.isDirectory() && listFilesSafely(file, null).length == 0)) { + boolean deleted = file.delete(); + // Delete can also fail if the file simply did not exist. + if (!deleted && file.exists()) { + throw new IOException("Failed to delete: " + file.getAbsolutePath()); + } } } @@ -157,9 +172,9 @@ private static void deleteRecursivelyUsingUnixNative(File file) throws IOExcepti } } - private static File[] listFilesSafely(File file) throws IOException { + private static File[] listFilesSafely(File file, FilenameFilter filter) throws IOException { if (file.exists()) { - File[] files = file.listFiles(); + File[] files = file.listFiles(filter); if (files == null) { throw new IOException("Failed to list files for dir: " + file); } diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/NettyUtils.java b/common/network-common/src/main/java/org/apache/spark/network/util/NettyUtils.java index 5e85180bd6f9f..eaa02e78610e2 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/util/NettyUtils.java +++ b/common/network-common/src/main/java/org/apache/spark/network/util/NettyUtils.java @@ -84,8 +84,10 @@ public static Class getServerChannelClass(IOMode mode) * Creates a LengthFieldBasedFrameDecoder where the first 8 bytes are the length of the frame. * This is used before all decoders. */ - public static TransportFrameDecoder createFrameDecoder() { - return new TransportFrameDecoder(); + public static TransportFrameDecoder createFrameDecoder( + long maxRemoteBlockSizeFetchToMem, + boolean isSasl) { + return new TransportFrameDecoder(maxRemoteBlockSizeFetchToMem, isSasl); } /** Returns the remote address on the channel or "<unknown remote>" if none exists. 
*/ diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java index 91497b9492219..a2a0939a15586 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java +++ b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java @@ -281,4 +281,8 @@ public Properties cryptoConf() { public long maxChunksBeingTransferred() { return conf.getLong("spark.shuffle.maxChunksBeingTransferred", Long.MAX_VALUE); } + + public long maxRemoteBlockSizeFetchToMem() { + return conf.getLong("spark.maxRemoteBlockSizeFetchToMem", Long.MAX_VALUE); + } } diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java b/common/network-common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java index 8e73ab077a5c1..4e6b617e53996 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java +++ b/common/network-common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java @@ -25,6 +25,9 @@ import io.netty.buffer.Unpooled; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundHandlerAdapter; +import org.apache.spark.network.protocol.ChunkFetchSuccess; +import org.apache.spark.network.protocol.Message; +import org.apache.spark.network.protocol.ParsedFrame; /** * A customized frame decoder that allows intercepting raw data. @@ -55,6 +58,19 @@ public class TransportFrameDecoder extends ChannelInboundHandlerAdapter { private long totalSize = 0; private long nextFrameSize = UNKNOWN_FRAME_SIZE; private volatile Interceptor interceptor; + private Message.Type msgType = null; + private final boolean isSasl; + + private final long maxRemoteBlockSizeFetchToMem; + + public TransportFrameDecoder(long maxRemoteBlockSizeFetchToMem) { + this(maxRemoteBlockSizeFetchToMem, false); + } + + public TransportFrameDecoder(long maxRemoteBlockSizeFetchToMem, boolean isSasl) { + this.maxRemoteBlockSizeFetchToMem = maxRemoteBlockSizeFetchToMem; + this.isSasl = isSasl; + } @Override public void channelRead(ChannelHandlerContext ctx, Object data) throws Exception { @@ -77,12 +93,37 @@ public void channelRead(ChannelHandlerContext ctx, Object data) throws Exception } totalSize -= read; } else { - // Interceptor is not active, so try to decode one frame. - ByteBuf frame = decodeNext(); - if (frame == null) { - break; + if (isSasl) { + if (!isFrameSizeAvailable()) { + break; + } + ByteBuf frame = decodeNext(); + if (frame == null) { + break; + } + ctx.fireChannelRead(frame); + } else { + // Interceptor is not active, so try to decode one frame. 
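+ // The message type is decoded first. For a ChunkFetchSuccess whose body is larger than + // maxRemoteBlockSizeFetchToMem, only the header (StreamChunkId) part of the frame is decoded + // here; the remaining body bytes are left in the channel so that TransportResponseHandler can + // install a StreamInterceptor and stream them to a callback instead of buffering them.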
+ decodeNextMsgType(); + if (msgType == null) { + break; + } + long remainingFrameSize = 0; + if (msgType == Message.Type.ChunkFetchSuccess && + nextFrameSize - ChunkFetchSuccess.ENCODED_LENGTH > maxRemoteBlockSizeFetchToMem) { + remainingFrameSize = nextFrameSize - ChunkFetchSuccess.ENCODED_LENGTH; + nextFrameSize = ChunkFetchSuccess.ENCODED_LENGTH; + } + + ByteBuf frame = decodeNext(); + if (frame == null) { + break; + } + ParsedFrame parsedFrame = + new ParsedFrame(msgType, frame, remainingFrameSize); + msgType = null; + ctx.fireChannelRead(parsedFrame); } - ctx.fireChannelRead(frame); } } } @@ -121,18 +162,40 @@ private long decodeFrameSize() { return nextFrameSize; } - private ByteBuf decodeNext() { + private void decodeNextMsgType() { + if (msgType != null || !isFrameSizeAvailable() || totalSize < Message.Type.LENGTH) { + return; + } + + ByteBuf first = buffers.getFirst(); + msgType = Message.Type.decode(first); + totalSize -= Message.Type.LENGTH; + nextFrameSize -= Message.Type.LENGTH; + if (!first.isReadable()) { + buffers.removeFirst().release(); + } + } + + private boolean isFrameSizeAvailable() { long frameSize = decodeFrameSize(); - if (frameSize == UNKNOWN_FRAME_SIZE || totalSize < frameSize) { + if (frameSize == UNKNOWN_FRAME_SIZE) { + return false; + } + + Preconditions.checkArgument(frameSize < MAX_FRAME_SIZE, "Too large frame: %s", frameSize); + Preconditions.checkArgument(frameSize > 0, "Frame length should be positive: %s", frameSize); + return true; + } + + private ByteBuf decodeNext() { + long frameSize = nextFrameSize; + if (totalSize < frameSize) { return null; } // Reset size for next frame. nextFrameSize = UNKNOWN_FRAME_SIZE; - Preconditions.checkArgument(frameSize < MAX_FRAME_SIZE, "Too large frame: %s", frameSize); - Preconditions.checkArgument(frameSize > 0, "Frame length should be positive: %s", frameSize); - // If the first buffer holds the entire frame, return it. 
int remaining = (int) frameSize; if (buffers.getFirst().readableBytes() >= remaining) { diff --git a/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java index 824482af08dd4..1717f74a97a03 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java @@ -18,33 +18,30 @@ package org.apache.spark.network; import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; import java.io.RandomAccessFile; import java.nio.ByteBuffer; -import java.util.Arrays; -import java.util.Collections; -import java.util.HashSet; -import java.util.LinkedList; -import java.util.List; -import java.util.Random; -import java.util.Set; +import java.nio.channels.Channels; +import java.nio.channels.WritableByteChannel; +import java.util.*; import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; import com.google.common.collect.Sets; import com.google.common.io.Closeables; +import org.apache.spark.network.client.*; +import org.apache.spark.network.protocol.StreamChunkId; import org.junit.AfterClass; import org.junit.BeforeClass; import org.junit.Test; import static org.junit.Assert.*; +import static org.mockito.Mockito.mock; import org.apache.spark.network.buffer.FileSegmentManagedBuffer; import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.buffer.NioManagedBuffer; -import org.apache.spark.network.client.ChunkReceivedCallback; -import org.apache.spark.network.client.RpcResponseCallback; -import org.apache.spark.network.client.TransportClient; -import org.apache.spark.network.client.TransportClientFactory; import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.TransportServer; import org.apache.spark.network.server.StreamManager; @@ -55,6 +52,14 @@ public class ChunkFetchIntegrationSuite { static final long STREAM_ID = 1; static final int BUFFER_CHUNK_INDEX = 0; static final int FILE_CHUNK_INDEX = 1; + static final int BUFFER_FETCH_TO_DISK_CHUNK_INDEX = 2; + static final int MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM = 100000; + + static final TransportConf transportConf = + new TransportConf("shuffle", + new MapConfigProvider( + Collections.singletonMap( + "spark.maxRemoteBlockSizeFetchToMem", Integer.toString(MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM)))); static TransportServer server; static TransportClientFactory clientFactory; @@ -62,17 +67,68 @@ public class ChunkFetchIntegrationSuite { static File testFile; static ManagedBuffer bufferChunk; + static ManagedBuffer bufferToDiskChunk; + static ManagedBuffer fileChunk; + private class FetchChunkDownloadTestCallback implements StreamCallback { + private WritableByteChannel channel; + private File targetFile; + private FetchResult fetchResult; + private Semaphore semaphore; + + FetchChunkDownloadTestCallback(FetchResult fetchResult, Semaphore semaphore) { + this.fetchResult = fetchResult; + this.semaphore = semaphore; + try { + this.targetFile = File.createTempFile("shuffle-test-file-download-", "txt"); + this.targetFile.deleteOnExit(); + this.channel = Channels.newChannel(new FileOutputStream(targetFile)); + } catch (IOException e) { + throw new IllegalStateException(e); + } + } + + @Override + public void onData(StreamChunkId streamId, ByteBuffer buf) throws IOException { + while (buf.hasRemaining()) { 
+ channel.write(buf); + } + } + + @Override + public void onComplete(StreamChunkId streamId) throws IOException { + channel.close(); + ManagedBuffer buffer = new FileSegmentManagedBuffer(transportConf, targetFile, 0, + targetFile.length()); + fetchResult.successChunks.add(streamId.chunkIndex); + fetchResult.buffers.add(buffer); + semaphore.release(); + } + + @Override + public void onFailure(StreamChunkId streamId, Throwable cause) throws IOException { + channel.close(); + this.fetchResult.failedChunks.add(streamId.chunkIndex); + semaphore.release(); + } + } + @BeforeClass public static void setUp() throws Exception { - int bufSize = 100000; - final ByteBuffer buf = ByteBuffer.allocate(bufSize); - for (int i = 0; i < bufSize; i ++) { - buf.put((byte) i); + int bufSize = MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM + 100; + final ByteBuffer hugeBuf = ByteBuffer.allocate(bufSize); + for (int i = 0; i < MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM; i ++) { + hugeBuf.put((byte) i); + } + ByteBuffer smallBuff = hugeBuf.duplicate(); + smallBuff.flip(); + bufferChunk = new NioManagedBuffer(smallBuff); + for (int i = MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM; i < bufSize; i ++) { + hugeBuf.put((byte) i); } - buf.flip(); - bufferChunk = new NioManagedBuffer(buf); + hugeBuf.flip(); + bufferToDiskChunk = new NioManagedBuffer(hugeBuf); testFile = File.createTempFile("shuffle-test-file", "txt"); testFile.deleteOnExit(); @@ -87,19 +143,21 @@ public static void setUp() throws Exception { Closeables.close(fp, shouldSuppressIOException); } - final TransportConf conf = new TransportConf("shuffle", MapConfigProvider.EMPTY); - fileChunk = new FileSegmentManagedBuffer(conf, testFile, 10, testFile.length() - 25); + fileChunk = new FileSegmentManagedBuffer(transportConf, testFile, 10, testFile.length() - 25); streamManager = new StreamManager() { @Override public ManagedBuffer getChunk(long streamId, int chunkIndex) { assertEquals(STREAM_ID, streamId); - if (chunkIndex == BUFFER_CHUNK_INDEX) { - return new NioManagedBuffer(buf); - } else if (chunkIndex == FILE_CHUNK_INDEX) { - return new FileSegmentManagedBuffer(conf, testFile, 10, testFile.length() - 25); - } else { - throw new IllegalArgumentException("Invalid chunk index: " + chunkIndex); + switch (chunkIndex) { + case BUFFER_CHUNK_INDEX: + return new NioManagedBuffer(smallBuff); + case FILE_CHUNK_INDEX: + return new FileSegmentManagedBuffer(transportConf, testFile, 10, testFile.length() - 25); + case BUFFER_FETCH_TO_DISK_CHUNK_INDEX: + return new NioManagedBuffer(hugeBuf); + default: + throw new IllegalArgumentException("Invalid chunk index: " + chunkIndex); } } }; @@ -117,7 +175,7 @@ public StreamManager getStreamManager() { return streamManager; } }; - TransportContext context = new TransportContext(conf, handler); + TransportContext context = new TransportContext(transportConf, handler); server = context.createServer(); clientFactory = context.createClientFactory(); } @@ -125,6 +183,7 @@ public StreamManager getStreamManager() { @AfterClass public static void tearDown() { bufferChunk.release(); + bufferToDiskChunk.release(); server.close(); clientFactory.close(); testFile.delete(); @@ -168,7 +227,8 @@ public void onFailure(int chunkIndex, Throwable e) { }; for (int chunkIndex : chunkIndices) { - client.fetchChunk(STREAM_ID, chunkIndex, callback); + client.fetchChunk(STREAM_ID, chunkIndex, callback, + () -> new FetchChunkDownloadTestCallback(res, sem)); } if (!sem.tryAcquire(chunkIndices.size(), 5, TimeUnit.SECONDS)) { fail("Timeout getting response from the server"); @@ -182,6 
+242,7 @@ public void fetchBufferChunk() throws Exception { FetchResult res = fetchChunks(Arrays.asList(BUFFER_CHUNK_INDEX)); assertEquals(Sets.newHashSet(BUFFER_CHUNK_INDEX), res.successChunks); assertTrue(res.failedChunks.isEmpty()); + assertNumFileSegments(0, res.buffers); assertBufferListsEqual(Arrays.asList(bufferChunk), res.buffers); res.releaseBuffers(); } @@ -191,6 +252,7 @@ public void fetchFileChunk() throws Exception { FetchResult res = fetchChunks(Arrays.asList(FILE_CHUNK_INDEX)); assertEquals(Sets.newHashSet(FILE_CHUNK_INDEX), res.successChunks); assertTrue(res.failedChunks.isEmpty()); + assertNumFileSegments(0, res.buffers); assertBufferListsEqual(Arrays.asList(fileChunk), res.buffers); res.releaseBuffers(); } @@ -208,10 +270,25 @@ public void fetchBothChunks() throws Exception { FetchResult res = fetchChunks(Arrays.asList(BUFFER_CHUNK_INDEX, FILE_CHUNK_INDEX)); assertEquals(Sets.newHashSet(BUFFER_CHUNK_INDEX, FILE_CHUNK_INDEX), res.successChunks); assertTrue(res.failedChunks.isEmpty()); + assertNumFileSegments(0, res.buffers); assertBufferListsEqual(Arrays.asList(bufferChunk, fileChunk), res.buffers); res.releaseBuffers(); } + + @Test + public void fetchSomeChunksToDisk() throws Exception { + FetchResult res = fetchChunks( + Arrays.asList(BUFFER_CHUNK_INDEX, BUFFER_FETCH_TO_DISK_CHUNK_INDEX, FILE_CHUNK_INDEX)); + assertEquals( + Sets.newHashSet(BUFFER_CHUNK_INDEX, BUFFER_FETCH_TO_DISK_CHUNK_INDEX, FILE_CHUNK_INDEX), + res.successChunks); + assertTrue(res.failedChunks.isEmpty()); + assertNumFileSegments(1, res.buffers); + assertBufferListsEqual(Arrays.asList(bufferChunk, bufferToDiskChunk, fileChunk), res.buffers); + res.releaseBuffers(); + } + @Test public void fetchChunkAndNonExistent() throws Exception { FetchResult res = fetchChunks(Arrays.asList(BUFFER_CHUNK_INDEX, 12345)); @@ -221,6 +298,11 @@ public void fetchChunkAndNonExistent() throws Exception { res.releaseBuffers(); } + private void assertNumFileSegments(int expected, List buffers) { + assertEquals(expected, + buffers.stream().filter(b -> b instanceof FileSegmentManagedBuffer).count()); + } + private static void assertBufferListsEqual(List list0, List list1) throws Exception { assertEquals(list0.size(), list1.size()); diff --git a/common/network-common/src/test/java/org/apache/spark/network/ProtocolSuite.java b/common/network-common/src/test/java/org/apache/spark/network/ProtocolSuite.java index bc94f7ca63a96..fd318b632919e 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/ProtocolSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/ProtocolSuite.java @@ -28,6 +28,8 @@ import org.junit.Test; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; import org.apache.spark.network.protocol.ChunkFetchFailure; import org.apache.spark.network.protocol.ChunkFetchRequest; @@ -47,20 +49,26 @@ import org.apache.spark.network.util.NettyUtils; public class ProtocolSuite { - private void testServerToClient(Message msg) { + private Message decodedClientMessageFromChannel(Message msg, long maxRemoteBlockSizeFetchToMem) { EmbeddedChannel serverChannel = new EmbeddedChannel(new FileRegionEncoder(), - MessageEncoder.INSTANCE); + MessageEncoder.INSTANCE); serverChannel.writeOutbound(msg); EmbeddedChannel clientChannel = new EmbeddedChannel( - NettyUtils.createFrameDecoder(), MessageDecoder.INSTANCE); + NettyUtils.createFrameDecoder(maxRemoteBlockSizeFetchToMem, false), + MessageDecoder.INSTANCE); while 
(!serverChannel.outboundMessages().isEmpty()) { clientChannel.writeOneInbound(serverChannel.readOutbound()); } assertEquals(1, clientChannel.inboundMessages().size()); - assertEquals(msg, clientChannel.readInbound()); + return clientChannel.readInbound(); + } + + private void testServerToClient(Message msg) { + Message clientMessage = decodedClientMessageFromChannel(msg, Integer.MAX_VALUE); + assertEquals(msg, clientMessage); } private void testClientToServer(Message msg) { @@ -69,7 +77,8 @@ private void testClientToServer(Message msg) { clientChannel.writeOutbound(msg); EmbeddedChannel serverChannel = new EmbeddedChannel( - NettyUtils.createFrameDecoder(), MessageDecoder.INSTANCE); + NettyUtils.createFrameDecoder(Integer.MAX_VALUE, false), + MessageDecoder.INSTANCE); while (!clientChannel.outboundMessages().isEmpty()) { serverChannel.writeOneInbound(clientChannel.readOutbound()); @@ -79,6 +88,31 @@ private void testClientToServer(Message msg) { assertEquals(msg, serverChannel.readInbound()); } + private ChunkFetchSuccess chunkFetchSuccessWith100Bytes() { + return new ChunkFetchSuccess(new StreamChunkId(1, 2), new TestManagedBuffer(100)); + } + + private void testChunkFetchSuccess() { + // test without fetch to disk, maxRemoteBlockSizeFetchToMem is Integer.MAX_VALUE + testServerToClient(new ChunkFetchSuccess(new StreamChunkId(1, 2), new TestManagedBuffer(10))); + testServerToClient(new ChunkFetchSuccess(new StreamChunkId(1, 2), new TestManagedBuffer(0))); + + // test with fetch to disk + // under (and at) the fetch to mem limit + Message chunkFetchSuccessUnderFetchToMemBlockSize = + decodedClientMessageFromChannel(chunkFetchSuccessWith100Bytes(), 101); + assertEquals(chunkFetchSuccessWith100Bytes(), chunkFetchSuccessUnderFetchToMemBlockSize); + chunkFetchSuccessUnderFetchToMemBlockSize = + decodedClientMessageFromChannel(chunkFetchSuccessWith100Bytes(), 100); + assertEquals(chunkFetchSuccessWith100Bytes(), chunkFetchSuccessUnderFetchToMemBlockSize); + + // above the fetch to mem limit + Message chunkFetchSuccessAboveFetchToMemBlockSize = + decodedClientMessageFromChannel(chunkFetchSuccessWith100Bytes(), 99); + assertNull("message body must be not included", chunkFetchSuccessAboveFetchToMemBlockSize.body()); + assertFalse("message body must be not included", chunkFetchSuccessAboveFetchToMemBlockSize.isBodyInFrame()); + } + @Test public void requests() { testClientToServer(new ChunkFetchRequest(new StreamChunkId(1, 2))); @@ -88,10 +122,10 @@ public void requests() { testClientToServer(new OneWayMessage(new TestManagedBuffer(10))); } + @Test public void responses() { - testServerToClient(new ChunkFetchSuccess(new StreamChunkId(1, 2), new TestManagedBuffer(10))); - testServerToClient(new ChunkFetchSuccess(new StreamChunkId(1, 2), new TestManagedBuffer(0))); + testChunkFetchSuccess(); testServerToClient(new ChunkFetchFailure(new StreamChunkId(1, 2), "this is an error")); testServerToClient(new ChunkFetchFailure(new StreamChunkId(1, 2), "")); testServerToClient(new RpcResponse(12345, new TestManagedBuffer(0))); diff --git a/common/network-common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java index c0724e018263f..1d6dcf4a11280 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java @@ -20,10 +20,8 @@ import 
com.google.common.util.concurrent.Uninterruptibles; import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.buffer.NioManagedBuffer; -import org.apache.spark.network.client.ChunkReceivedCallback; -import org.apache.spark.network.client.RpcResponseCallback; -import org.apache.spark.network.client.TransportClient; -import org.apache.spark.network.client.TransportClientFactory; +import org.apache.spark.network.client.*; +import org.apache.spark.network.protocol.StreamChunkId; import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.StreamManager; import org.apache.spark.network.server.TransportServer; @@ -31,6 +29,7 @@ import org.apache.spark.network.util.TransportConf; import org.junit.*; import static org.junit.Assert.*; +import static org.mockito.Mockito.mock; import java.io.IOException; import java.nio.ByteBuffer; @@ -38,6 +37,7 @@ import java.util.concurrent.CountDownLatch; import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; +import java.util.function.Supplier; /** * Suite which ensures that requests that go without a response for the network timeout period are @@ -211,12 +211,13 @@ public StreamManager getStreamManager() { // Send one request, which will eventually fail. TestCallback callback0 = new TestCallback(); - client.fetchChunk(0, 0, callback0); + Supplier> streamCallbackFactory = mock(Supplier.class); + client.fetchChunk(0, 0, callback0, streamCallbackFactory); Uninterruptibles.sleepUninterruptibly(1200, TimeUnit.MILLISECONDS); // Send a second request before the first has failed. TestCallback callback1 = new TestCallback(); - client.fetchChunk(0, 1, callback1); + client.fetchChunk(0, 1, callback1, streamCallbackFactory); Uninterruptibles.sleepUninterruptibly(1200, TimeUnit.MILLISECONDS); // not complete yet, but should complete soon diff --git a/common/network-common/src/test/java/org/apache/spark/network/StreamSuite.java b/common/network-common/src/test/java/org/apache/spark/network/StreamSuite.java index f253a07e64be1..04ae42efc0abb 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/StreamSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/StreamSuite.java @@ -308,7 +308,7 @@ private void waitForCompletion(TestCallback callback) throws Exception { } - private static class TestCallback implements StreamCallback { + private static class TestCallback implements StreamCallback { private final OutputStream out; public volatile boolean completed; diff --git a/common/network-common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java b/common/network-common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java index b4032c4c3f031..f5935e4c404c1 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java @@ -19,19 +19,17 @@ import java.io.IOException; import java.nio.ByteBuffer; +import java.util.function.Supplier; import io.netty.channel.Channel; import io.netty.channel.local.LocalChannel; +import org.apache.spark.network.client.*; import org.junit.Test; import static org.junit.Assert.assertEquals; import static org.mockito.Mockito.*; import org.apache.spark.network.buffer.NioManagedBuffer; -import org.apache.spark.network.client.ChunkReceivedCallback; -import org.apache.spark.network.client.RpcResponseCallback; -import 
org.apache.spark.network.client.StreamCallback; -import org.apache.spark.network.client.TransportResponseHandler; import org.apache.spark.network.protocol.ChunkFetchFailure; import org.apache.spark.network.protocol.ChunkFetchSuccess; import org.apache.spark.network.protocol.RpcFailure; @@ -48,7 +46,8 @@ public void handleSuccessfulFetch() throws Exception { TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel()); ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class); - handler.addFetchRequest(streamChunkId, callback); + Supplier> streamCallbackFactory = mock(Supplier.class); + handler.addFetchRequest(streamChunkId, callback, streamCallbackFactory); assertEquals(1, handler.numOutstandingRequests()); handler.handle(new ChunkFetchSuccess(streamChunkId, new TestManagedBuffer(123))); @@ -61,7 +60,8 @@ public void handleFailedFetch() throws Exception { StreamChunkId streamChunkId = new StreamChunkId(1, 0); TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel()); ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class); - handler.addFetchRequest(streamChunkId, callback); + Supplier> streamCallbackFactory = mock(Supplier.class); + handler.addFetchRequest(streamChunkId, callback, streamCallbackFactory); assertEquals(1, handler.numOutstandingRequests()); handler.handle(new ChunkFetchFailure(streamChunkId, "some error msg")); @@ -73,9 +73,11 @@ public void handleFailedFetch() throws Exception { public void clearAllOutstandingRequests() throws Exception { TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel()); ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class); - handler.addFetchRequest(new StreamChunkId(1, 0), callback); - handler.addFetchRequest(new StreamChunkId(1, 1), callback); - handler.addFetchRequest(new StreamChunkId(1, 2), callback); + Supplier> streamCallback = mock(Supplier.class); + + handler.addFetchRequest(new StreamChunkId(1, 0), callback, streamCallback); + handler.addFetchRequest(new StreamChunkId(1, 1), callback, streamCallback); + handler.addFetchRequest(new StreamChunkId(1, 2), callback, streamCallback); assertEquals(3, handler.numOutstandingRequests()); handler.handle(new ChunkFetchSuccess(new StreamChunkId(1, 0), new TestManagedBuffer(12))); @@ -123,7 +125,8 @@ public void handleFailedRPC() throws Exception { @Test public void testActiveStreams() throws Exception { Channel c = new LocalChannel(); - c.pipeline().addLast(TransportFrameDecoder.HANDLER_NAME, new TransportFrameDecoder()); + c.pipeline() + .addLast(TransportFrameDecoder.HANDLER_NAME, new TransportFrameDecoder(Integer.MAX_VALUE)); TransportResponseHandler handler = new TransportResponseHandler(c); StreamResponse response = new StreamResponse("stream", 1234L, null); @@ -145,7 +148,8 @@ public void testActiveStreams() throws Exception { @Test public void failOutstandingStreamCallbackOnClose() throws Exception { Channel c = new LocalChannel(); - c.pipeline().addLast(TransportFrameDecoder.HANDLER_NAME, new TransportFrameDecoder()); + c.pipeline() + .addLast(TransportFrameDecoder.HANDLER_NAME, new TransportFrameDecoder(Integer.MAX_VALUE)); TransportResponseHandler handler = new TransportResponseHandler(c); StreamCallback cb = mock(StreamCallback.class); @@ -158,7 +162,8 @@ public void failOutstandingStreamCallbackOnClose() throws Exception { @Test public void failOutstandingStreamCallbackOnException() throws Exception { Channel c = new LocalChannel(); - 
c.pipeline().addLast(TransportFrameDecoder.HANDLER_NAME, new TransportFrameDecoder()); + c.pipeline() + .addLast(TransportFrameDecoder.HANDLER_NAME, new TransportFrameDecoder(Integer.MAX_VALUE)); TransportResponseHandler handler = new TransportResponseHandler(c); StreamCallback cb = mock(StreamCallback.class); diff --git a/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java b/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java index 6f15718bd8705..397e4f63f608f 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java @@ -33,6 +33,7 @@ import java.util.concurrent.TimeoutException; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Supplier; import javax.security.sasl.SaslException; import com.google.common.collect.ImmutableMap; @@ -44,16 +45,14 @@ import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelOutboundHandlerAdapter; import io.netty.channel.ChannelPromise; +import org.apache.spark.network.client.*; +import org.apache.spark.network.protocol.StreamChunkId; import org.junit.Test; import org.apache.spark.network.TestUtils; import org.apache.spark.network.TransportContext; import org.apache.spark.network.buffer.FileSegmentManagedBuffer; import org.apache.spark.network.buffer.ManagedBuffer; -import org.apache.spark.network.client.ChunkReceivedCallback; -import org.apache.spark.network.client.RpcResponseCallback; -import org.apache.spark.network.client.TransportClient; -import org.apache.spark.network.client.TransportClientBootstrap; import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.StreamManager; import org.apache.spark.network.server.TransportServer; @@ -281,7 +280,8 @@ public void testFileRegionEncryption() throws Exception { return null; }).when(callback).onSuccess(anyInt(), any(ManagedBuffer.class)); - ctx.client.fetchChunk(0, 0, callback); + Supplier> streamCallbackFactory = mock(Supplier.class); + ctx.client.fetchChunk(0, 0, callback, streamCallbackFactory); lock.await(10, TimeUnit.SECONDS); verify(callback, times(1)).onSuccess(anyInt(), any(ManagedBuffer.class)); diff --git a/common/network-common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java b/common/network-common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java index b53e41303751c..388bb41a57a5e 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java @@ -17,8 +17,8 @@ package org.apache.spark.network.util; -import java.nio.ByteBuffer; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import java.util.Random; import java.util.concurrent.atomic.AtomicInteger; @@ -26,6 +26,8 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; import io.netty.channel.ChannelHandlerContext; +import org.apache.spark.network.protocol.Message; +import org.apache.spark.network.protocol.ParsedFrame; import org.junit.AfterClass; import org.junit.Test; import static org.junit.Assert.*; @@ -42,7 +44,7 @@ public static void cleanup() { @Test public void testFrameDecoding() throws Exception { - TransportFrameDecoder decoder = new TransportFrameDecoder(); + TransportFrameDecoder 
decoder = new TransportFrameDecoder(Integer.MAX_VALUE); ChannelHandlerContext ctx = mockChannelHandlerContext(); ByteBuf data = createAndFeedFrames(100, decoder, ctx); verifyAndCloseDecoder(decoder, ctx, data); @@ -51,7 +53,7 @@ public void testFrameDecoding() throws Exception { @Test public void testInterception() throws Exception { int interceptedReads = 3; - TransportFrameDecoder decoder = new TransportFrameDecoder(); + TransportFrameDecoder decoder = new TransportFrameDecoder(Integer.MAX_VALUE); TransportFrameDecoder.Interceptor interceptor = spy(new MockInterceptor(interceptedReads)); ChannelHandlerContext ctx = mockChannelHandlerContext(); @@ -69,7 +71,7 @@ public void testInterception() throws Exception { decoder.channelRead(ctx, len); decoder.channelRead(ctx, dataBuf); verify(interceptor, times(interceptedReads)).handle(any(ByteBuf.class)); - verify(ctx).fireChannelRead(any(ByteBuffer.class)); + verify(ctx).fireChannelRead(any(ParsedFrame.class)); assertEquals(0, len.refCnt()); assertEquals(0, dataBuf.refCnt()); } finally { @@ -80,19 +82,19 @@ public void testInterception() throws Exception { @Test public void testRetainedFrames() throws Exception { - TransportFrameDecoder decoder = new TransportFrameDecoder(); + TransportFrameDecoder decoder = new TransportFrameDecoder(Integer.MAX_VALUE); AtomicInteger count = new AtomicInteger(); - List retained = new ArrayList<>(); + List retained = new ArrayList<>(); ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); when(ctx.fireChannelRead(any())).thenAnswer(in -> { // Retain a few frames but not others. - ByteBuf buf = (ByteBuf) in.getArguments()[0]; + ParsedFrame parsedFrame = (ParsedFrame) in.getArguments()[0]; if (count.incrementAndGet() % 2 == 0) { - retained.add(buf); + retained.add(parsedFrame); } else { - buf.release(); + parsedFrame.byteBuf.release(); } return null; }); @@ -100,15 +102,15 @@ public void testRetainedFrames() throws Exception { ByteBuf data = createAndFeedFrames(100, decoder, ctx); try { // Verify all retained buffers are readable. 
- for (ByteBuf b : retained) { - byte[] tmp = new byte[b.readableBytes()]; - b.readBytes(tmp); - b.release(); + for (ParsedFrame b : retained) { + byte[] tmp = new byte[b.byteBuf.readableBytes()]; + b.byteBuf.readBytes(tmp); + b.byteBuf.release(); } verifyAndCloseDecoder(decoder, ctx, data); } finally { - for (ByteBuf b : retained) { - release(b); + for (ParsedFrame b : retained) { + release(b.byteBuf); } } } @@ -120,13 +122,13 @@ public void testSplitLengthField() throws Exception { buf.writeLong(frame.length + 8); buf.writeBytes(frame); - TransportFrameDecoder decoder = new TransportFrameDecoder(); + TransportFrameDecoder decoder = new TransportFrameDecoder(Integer.MAX_VALUE); ChannelHandlerContext ctx = mockChannelHandlerContext(); try { decoder.channelRead(ctx, buf.readSlice(RND.nextInt(7)).retain()); - verify(ctx, never()).fireChannelRead(any(ByteBuf.class)); + verify(ctx, never()).fireChannelRead(any(ParsedFrame.class)); decoder.channelRead(ctx, buf); - verify(ctx).fireChannelRead(any(ByteBuf.class)); + verify(ctx).fireChannelRead(any(ParsedFrame.class)); assertEquals(0, buf.refCnt()); } finally { decoder.channelInactive(ctx); @@ -154,8 +156,14 @@ private ByteBuf createAndFeedFrames( TransportFrameDecoder decoder, ChannelHandlerContext ctx) throws Exception { ByteBuf data = Unpooled.buffer(); + Message.Type msgTypes[] = Arrays.stream(Message.Type.values()) + .filter(t -> t != Message.Type.User) + .toArray(Message.Type[]::new); + for (int i = 0; i < frameCount; i++) { byte[] frame = new byte[1024 * (RND.nextInt(31) + 1)]; + Message.Type randomMsgType = msgTypes[RND.nextInt(msgTypes.length)]; + frame[0] = randomMsgType.id(); data.writeLong(frame.length + 8); data.writeBytes(frame); } @@ -166,7 +174,7 @@ private ByteBuf createAndFeedFrames( decoder.channelRead(ctx, data.readSlice(Math.min(data.readableBytes(), size)).retain()); } - verify(ctx, times(frameCount)).fireChannelRead(any(ByteBuf.class)); + verify(ctx, times(frameCount)).fireChannelRead(any(ParsedFrame.class)); } catch (Exception e) { release(data); throw e; @@ -187,7 +195,7 @@ private void verifyAndCloseDecoder( } private void testInvalidFrame(long size) throws Exception { - TransportFrameDecoder decoder = new TransportFrameDecoder(); + TransportFrameDecoder decoder = new TransportFrameDecoder(Integer.MAX_VALUE); ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); ByteBuf frame = Unpooled.copyLong(size); try { @@ -200,8 +208,8 @@ private void testInvalidFrame(long size) throws Exception { private ChannelHandlerContext mockChannelHandlerContext() { ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); when(ctx.fireChannelRead(any())).thenAnswer(in -> { - ByteBuf buf = (ByteBuf) in.getArguments()[0]; - buf.release(); + ParsedFrame parsedFrame = (ParsedFrame) in.getArguments()[0]; + parsedFrame.byteBuf.release(); return null; }); return ctx; diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java index fc7bba41185f0..098fa7974b87b 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java @@ -138,6 +138,13 @@ public void applicationRemoved(String appId, boolean cleanupLocalDirs) { blockManager.applicationRemoved(appId, cleanupLocalDirs); } + /** + * Clean up any non-shuffle files in 
any local directories associated with a finished executor. + */ + public void executorRemoved(String executorId, String appId) { + blockManager.executorRemoved(executorId, appId); + } + /** * Register an (application, executor) with the given shuffle info. * diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java index e6399897be9c2..0b7a27402369d 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java @@ -24,6 +24,8 @@ import java.util.concurrent.ExecutionException; import java.util.concurrent.Executor; import java.util.concurrent.Executors; +import java.util.regex.Matcher; +import java.util.regex.Pattern; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; @@ -59,6 +61,7 @@ public class ExternalShuffleBlockResolver { private static final Logger logger = LoggerFactory.getLogger(ExternalShuffleBlockResolver.class); private static final ObjectMapper mapper = new ObjectMapper(); + /** * This a common prefix to the key for each app registration we stick in leveldb, so they * are easy to find, since leveldb lets you search based on prefix. @@ -66,6 +69,8 @@ public class ExternalShuffleBlockResolver { private static final String APP_KEY_PREFIX = "AppExecShuffleInfo"; private static final StoreVersion CURRENT_VERSION = new StoreVersion(1, 0); + private static final Pattern MULTIPLE_SEPARATORS = Pattern.compile(File.separator + "{2,}"); + // Map containing all registered executors' metadata. @VisibleForTesting final ConcurrentMap executors; @@ -211,6 +216,26 @@ public void applicationRemoved(String appId, boolean cleanupLocalDirs) { } } + /** + * Removes all the non-shuffle files in any local directories associated with the finished + * executor. + */ + public void executorRemoved(String executorId, String appId) { + logger.info("Clean up non-shuffle files associated with the finished executor {}", executorId); + AppExecId fullId = new AppExecId(appId, executorId); + final ExecutorShuffleInfo executor = executors.get(fullId); + if (executor == null) { + // Executor not registered, skip clean up of the local directories. + logger.info("Executor is not registered (appId={}, execId={})", appId, executorId); + } else { + logger.info("Cleaning up non-shuffle files in executor {}'s {} local dirs", fullId, + executor.localDirs.length); + + // Execute the actual deletion in a different thread, as it may take some time. + directoryCleaner.execute(() -> deleteNonShuffleFiles(executor.localDirs)); + } + } + /** * Synchronously deletes each directory one at a time. * Should be executed in its own thread, as this may take a long time. @@ -226,6 +251,29 @@ private void deleteExecutorDirs(String[] dirs) { } } + /** + * Synchronously deletes non-shuffle files in each directory recursively. + * Should be executed in its own thread, as this may take a long time. + */ + private void deleteNonShuffleFiles(String[] dirs) { + FilenameFilter filter = new FilenameFilter() { + @Override + public boolean accept(File dir, String name) { + // Don't delete shuffle data or shuffle index files.
+ return !name.endsWith(".index") && !name.endsWith(".data"); + } + }; + + for (String localDir : dirs) { + try { + JavaUtils.deleteRecursively(new File(localDir), filter); + logger.debug("Successfully cleaned up non-shuffle files in directory: {}", localDir); + } catch (Exception e) { + logger.error("Failed to delete non-shuffle files in directory: " + localDir, e); + } + } + } + /** * Sort-based shuffle data uses an index called "shuffle_ShuffleId_MapId_0.index" into a data file * called "shuffle_ShuffleId_MapId_0.data". This logic is from IndexShuffleBlockResolver, @@ -259,7 +307,8 @@ static File getFile(String[] localDirs, int subDirsPerLocalDir, String filename) int hash = JavaUtils.nonNegativeHash(filename); String localDir = localDirs[hash % localDirs.length]; int subDirId = (hash / localDirs.length) % subDirsPerLocalDir; - return new File(new File(localDir, String.format("%02x", subDirId)), filename); + return new File(createNormalizedInternedPathname( + localDir, String.format("%02x", subDirId), filename)); } void close() { @@ -272,6 +321,28 @@ void close() { } } + /** + * This method is needed to avoid the situation when multiple File instances for the + * same pathname "foo/bar" are created, each with a separate copy of the "foo/bar" String. + * According to measurements, in some scenarios such duplicate strings may waste a lot + * of memory (~ 10% of the heap). To avoid that, we intern the pathname, and before that + * we make sure that it's in a normalized form (contains no "//", "///" etc.) Otherwise, + * the internal code in java.io.File would normalize it later, creating a new "foo/bar" + * String copy. Unfortunately, we cannot just reuse the normalization code that java.io.File + * uses, since it is in the package-private class java.io.FileSystem. + */ + @VisibleForTesting + static String createNormalizedInternedPathname(String dir1, String dir2, String fname) { + String pathname = dir1 + File.separator + dir2 + File.separator + fname; + Matcher m = MULTIPLE_SEPARATORS.matcher(pathname); + pathname = m.replaceAll("/"); + // A single trailing slash needs to be taken care of separately + if (pathname.length() > 1 && pathname.endsWith("/")) { + pathname = pathname.substring(0, pathname.length() - 1); + } + return pathname.intern(); + } + /** Simply encodes an executor's full ID, which is appId + execId. 
*/ public static class AppExecId { public final String appId; diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java index 7ed0b6e93a7a8..f2f445da748fb 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java @@ -86,12 +86,13 @@ public void init(String appId) { @Override public void fetchBlocks( - String host, - int port, - String execId, - String[] blockIds, - BlockFetchingListener listener, - TempFileManager tempFileManager) { + String host, + int port, + String execId, + String[] blockIds, + BlockFetchingListener listener, + TempFileManager tempFileManager, + boolean useStreamRequestMessage) { checkInit(); logger.debug("External shuffle fetch from {}:{} (executor id {})", host, port, execId); try { @@ -99,7 +100,7 @@ public void fetchBlocks( (blockIds1, listener1) -> { TransportClient client = clientFactory.createClient(host, port); new OneForOneBlockFetcher(client, appId, execId, - blockIds1, listener1, conf, tempFileManager).start(); + blockIds1, listener1, conf, tempFileManager, useStreamRequestMessage).start(); }; int maxRetries = conf.maxIORetries(); diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java index 0bc571874f07c..5d1ec48ddaf0c 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java @@ -24,16 +24,18 @@ import java.nio.channels.Channels; import java.nio.channels.WritableByteChannel; import java.util.Arrays; +import java.util.function.Supplier; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.CompositeByteBuf; +import org.apache.spark.network.buffer.NettyManagedBuffer; +import org.apache.spark.network.client.*; +import org.apache.spark.network.protocol.StreamChunkId; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.apache.spark.network.buffer.FileSegmentManagedBuffer; import org.apache.spark.network.buffer.ManagedBuffer; -import org.apache.spark.network.client.ChunkReceivedCallback; -import org.apache.spark.network.client.RpcResponseCallback; -import org.apache.spark.network.client.StreamCallback; -import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.server.OneForOneStreamManager; import org.apache.spark.network.shuffle.protocol.BlockTransferMessage; import org.apache.spark.network.shuffle.protocol.OpenBlocks; @@ -57,8 +59,10 @@ public class OneForOneBlockFetcher { private final String[] blockIds; private final BlockFetchingListener listener; private final ChunkReceivedCallback chunkCallback; + private final Supplier> fetchChunkDownloadCallbackFactory; private final TransportConf transportConf; private final TempFileManager tempFileManager; + private final boolean useStreamRequestMessage; private StreamHandle streamHandle = null; @@ -69,7 +73,7 @@ public OneForOneBlockFetcher( String[] blockIds, BlockFetchingListener listener, TransportConf transportConf) { - this(client, appId, execId, blockIds, listener, transportConf, null); + this(client, appId, execId, blockIds, listener, transportConf, null, 
false); } public OneForOneBlockFetcher( @@ -79,18 +83,24 @@ public OneForOneBlockFetcher( String[] blockIds, BlockFetchingListener listener, TransportConf transportConf, - TempFileManager tempFileManager) { + TempFileManager tempFileManager, + boolean useStreamRequestMessage) { this.client = client; this.openMessage = new OpenBlocks(appId, execId, blockIds); this.blockIds = blockIds; this.listener = listener; this.chunkCallback = new ChunkCallback(); this.transportConf = transportConf; + // TODO extend tests to pass a valid tempFileManager and use: + // this.tempFileManager = Preconditions.checkNotNull(tempFileManager); + fetchChunkDownloadCallbackFactory = () -> new FetchChunkDownloadCallback(); this.tempFileManager = tempFileManager; + this.useStreamRequestMessage = useStreamRequestMessage; } /** Callback invoked on receipt of each chunk. We equate a single chunk to a single block. */ private class ChunkCallback implements ChunkReceivedCallback { + @Override public void onSuccess(int chunkIndex, ManagedBuffer buffer) { // On receipt of a chunk, pass it upwards as a block. @@ -125,11 +135,12 @@ public void onSuccess(ByteBuffer response) { // Immediately request all chunks -- we expect that the total size of the request is // reasonable due to higher level chunking in [[ShuffleBlockFetcherIterator]]. for (int i = 0; i < streamHandle.numChunks; i++) { - if (tempFileManager != null) { + if (useStreamRequestMessage) { client.stream(OneForOneStreamManager.genStreamChunkId(streamHandle.streamId, i), new DownloadCallback(i)); } else { - client.fetchChunk(streamHandle.streamId, i, chunkCallback); + client.fetchChunk(streamHandle.streamId, i, + chunkCallback, fetchChunkDownloadCallbackFactory); } } } catch (Exception e) { @@ -157,7 +168,7 @@ private void failRemainingBlocks(String[] failedBlockIds, Throwable e) { } } - private class DownloadCallback implements StreamCallback { + private class DownloadCallback implements StreamCallback { private WritableByteChannel channel = null; private File targetFile = null; @@ -196,4 +207,45 @@ public void onFailure(String streamId, Throwable cause) throws IOException { targetFile.delete(); } } + + private class FetchChunkDownloadCallback implements StreamCallback { + private WritableByteChannel channel = null; + private File targetFile = null; + + FetchChunkDownloadCallback() { + this.targetFile = tempFileManager.createTempFile(); + try { + this.channel = Channels.newChannel(new FileOutputStream(targetFile)); + } catch (IOException e) { + throw new IllegalStateException(e); + } + } + + @Override + public void onData(StreamChunkId streamId, ByteBuffer buf) throws IOException { + while (buf.hasRemaining()) { + channel.write(buf); + } + } + + @Override + public void onComplete(StreamChunkId streamId) throws IOException { + channel.close(); + ManagedBuffer buffer = new FileSegmentManagedBuffer(transportConf, targetFile, 0, + targetFile.length()); + listener.onBlockFetchSuccess(blockIds[streamId.chunkIndex], buffer); + if (!tempFileManager.registerTempFileToClean(targetFile)) { + targetFile.delete(); + } + } + + @Override + public void onFailure(StreamChunkId streamId, Throwable cause) throws IOException { + channel.close(); + // On receipt of a failure, fail every block from chunkIndex onwards. 
+ String[] remainingBlockIds = Arrays.copyOfRange(blockIds, streamId.chunkIndex, blockIds.length); + failRemainingBlocks(remainingBlockIds, cause); + targetFile.delete(); + } + } } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java index 18b04fedcac5b..eeb5c85b967ba 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java @@ -37,24 +37,24 @@ public void init(String appId) { } * Note that this API takes a sequence so the implementation can batch requests, and does not * return a future so the underlying implementation can invoke onBlockFetchSuccess as soon as * the data of a block is fetched, rather than waiting for all blocks to be fetched. - * * @param host the host of the remote node. * @param port the port of the remote node. * @param execId the executor id. * @param blockIds block ids to fetch. * @param listener the listener to receive block fetching status. * @param tempFileManager TempFileManager to create and clean temp files. - * If it's not null, the remote blocks will be streamed - * into temp shuffle files to reduce the memory usage, otherwise, - * they will be kept in memory. + * If it's not null, the remote blocks will be streamed + * into temp shuffle files to reduce the memory usage, otherwise, + * they will be kept in memory. + * @param useStreamRequestMessage whether the blocks should be fetched to disk because the request is too large */ public abstract void fetchBlocks( - String host, - int port, - String execId, - String[] blockIds, - BlockFetchingListener listener, - TempFileManager tempFileManager); + String host, + int port, + String execId, + String[] blockIds, + BlockFetchingListener listener, + TempFileManager tempFileManager, + boolean useStreamRequestMessage); /** * Get the shuffle MetricsSet from ShuffleClient, this will be used in MetricsSystem to diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java index 02e6eb3a4467e..416af7150b67f 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java @@ -23,7 +23,10 @@ import java.util.Arrays; import java.util.concurrent.CountDownLatch; import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Supplier; +import org.apache.spark.network.client.*; +import org.apache.spark.network.protocol.StreamChunkId; import org.junit.After; import org.junit.AfterClass; import org.junit.BeforeClass; @@ -35,10 +38,6 @@ import org.apache.spark.network.TestUtils; import org.apache.spark.network.TransportContext; import org.apache.spark.network.buffer.ManagedBuffer; -import org.apache.spark.network.client.ChunkReceivedCallback; -import org.apache.spark.network.client.RpcResponseCallback; -import org.apache.spark.network.client.TransportClient; -import org.apache.spark.network.client.TransportClientFactory; import org.apache.spark.network.server.OneForOneStreamManager; import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.StreamManager; @@ -232,6 +231,7 @@ public void onBlockFetchFailure(String blockId, Throwable t) { CountDownLatch chunkReceivedLatch = new
CountDownLatch(1); ChunkReceivedCallback callback = new ChunkReceivedCallback() { + @Override public void onSuccess(int chunkIndex, ManagedBuffer buffer) { chunkReceivedLatch.countDown(); @@ -244,7 +244,8 @@ public void onFailure(int chunkIndex, Throwable t) { }; exception.set(null); - client2.fetchChunk(streamId, 0, callback); + Supplier> streamCallbackFactory = mock(Supplier.class); + client2.fetchChunk(streamId, 0, callback, streamCallbackFactory); chunkReceivedLatch.await(); checkSecurityException(exception.get()); } finally { diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java index 6d201b8fe8d7d..d2072a54fa415 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java @@ -17,6 +17,7 @@ package org.apache.spark.network.shuffle; +import java.io.File; import java.io.IOException; import java.io.InputStream; import java.io.InputStreamReader; @@ -135,4 +136,23 @@ public void jsonSerializationOfExecutorRegistration() throws IOException { "\"subDirsPerLocalDir\": 7, \"shuffleManager\": " + "\"" + SORT_MANAGER + "\"}"; assertEquals(shuffleInfo, mapper.readValue(legacyShuffleJson, ExecutorShuffleInfo.class)); } + + @Test + public void testNormalizeAndInternPathname() { + assertPathsMatch("/foo", "bar", "baz", "/foo/bar/baz"); + assertPathsMatch("//foo/", "bar/", "//baz", "/foo/bar/baz"); + assertPathsMatch("foo", "bar", "baz///", "foo/bar/baz"); + assertPathsMatch("/foo/", "/bar//", "/baz", "/foo/bar/baz"); + assertPathsMatch("/", "", "", "/"); + assertPathsMatch("/", "/", "/", "/"); + } + + private void assertPathsMatch(String p1, String p2, String p3, String expectedPathname) { + String normPathname = + ExternalShuffleBlockResolver.createNormalizedInternedPathname(p1, p2, p3); + assertEquals(expectedPathname, normPathname); + File file = new File(normPathname); + String returnedPath = file.getPath(); + assertTrue(normPathname == returnedPath); + } } diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java index a6a1b8d0ac3f1..e3a85ea67741a 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java @@ -158,7 +158,7 @@ public void onBlockFetchFailure(String blockId, Throwable exception) { } } } - }, null); + }, null, false); if (!requestsRemaining.tryAcquire(blockIds.length, 5, TimeUnit.SECONDS)) { fail("Timeout getting response from the server"); diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/NonShuffleFilesCleanupSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/NonShuffleFilesCleanupSuite.java new file mode 100644 index 0000000000000..d22f3ace4103b --- /dev/null +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/NonShuffleFilesCleanupSuite.java @@ -0,0 +1,221 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle; + +import java.io.File; +import java.io.FilenameFilter; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.Random; +import java.util.concurrent.Executor; +import java.util.concurrent.atomic.AtomicBoolean; + +import com.google.common.util.concurrent.MoreExecutors; +import org.junit.Test; +import static org.junit.Assert.assertTrue; + +import org.apache.spark.network.util.MapConfigProvider; +import org.apache.spark.network.util.TransportConf; + +public class NonShuffleFilesCleanupSuite { + + // Same-thread Executor used to ensure cleanup happens synchronously in test thread. + private Executor sameThreadExecutor = MoreExecutors.sameThreadExecutor(); + private TransportConf conf = new TransportConf("shuffle", MapConfigProvider.EMPTY); + private static final String SORT_MANAGER = "org.apache.spark.shuffle.sort.SortShuffleManager"; + + @Test + public void cleanupOnRemovedExecutorWithShuffleFiles() throws IOException { + cleanupOnRemovedExecutor(true); + } + + @Test + public void cleanupOnRemovedExecutorWithoutShuffleFiles() throws IOException { + cleanupOnRemovedExecutor(false); + } + + private void cleanupOnRemovedExecutor(boolean withShuffleFiles) throws IOException { + TestShuffleDataContext dataContext = initDataContext(withShuffleFiles); + + ExternalShuffleBlockResolver resolver = + new ExternalShuffleBlockResolver(conf, null, sameThreadExecutor); + resolver.registerExecutor("app", "exec0", dataContext.createExecutorInfo(SORT_MANAGER)); + resolver.executorRemoved("exec0", "app"); + + assertCleanedUp(dataContext); + } + + @Test + public void cleanupUsesExecutorWithShuffleFiles() throws IOException { + cleanupUsesExecutor(true); + } + + @Test + public void cleanupUsesExecutorWithoutShuffleFiles() throws IOException { + cleanupUsesExecutor(false); + } + + private void cleanupUsesExecutor(boolean withShuffleFiles) throws IOException { + TestShuffleDataContext dataContext = initDataContext(withShuffleFiles); + + AtomicBoolean cleanupCalled = new AtomicBoolean(false); + + // Executor which does nothing to ensure we're actually using it. 
+ Executor noThreadExecutor = runnable -> cleanupCalled.set(true); + + ExternalShuffleBlockResolver manager = + new ExternalShuffleBlockResolver(conf, null, noThreadExecutor); + + manager.registerExecutor("app", "exec0", dataContext.createExecutorInfo(SORT_MANAGER)); + manager.executorRemoved("exec0", "app"); + + assertTrue(cleanupCalled.get()); + assertStillThere(dataContext); + } + + @Test + public void cleanupOnlyRemovedExecutorWithShuffleFiles() throws IOException { + cleanupOnlyRemovedExecutor(true); + } + + @Test + public void cleanupOnlyRemovedExecutorWithoutShuffleFiles() throws IOException { + cleanupOnlyRemovedExecutor(false); + } + + private void cleanupOnlyRemovedExecutor(boolean withShuffleFiles) throws IOException { + TestShuffleDataContext dataContext0 = initDataContext(withShuffleFiles); + TestShuffleDataContext dataContext1 = initDataContext(withShuffleFiles); + + ExternalShuffleBlockResolver resolver = + new ExternalShuffleBlockResolver(conf, null, sameThreadExecutor); + resolver.registerExecutor("app", "exec0", dataContext0.createExecutorInfo(SORT_MANAGER)); + resolver.registerExecutor("app", "exec1", dataContext1.createExecutorInfo(SORT_MANAGER)); + + + resolver.executorRemoved("exec-nonexistent", "app"); + assertStillThere(dataContext0); + assertStillThere(dataContext1); + + resolver.executorRemoved("exec0", "app"); + assertCleanedUp(dataContext0); + assertStillThere(dataContext1); + + resolver.executorRemoved("exec1", "app"); + assertCleanedUp(dataContext0); + assertCleanedUp(dataContext1); + + // Make sure it's not an error to cleanup multiple times + resolver.executorRemoved("exec1", "app"); + assertCleanedUp(dataContext0); + assertCleanedUp(dataContext1); + } + + @Test + public void cleanupOnlyRegisteredExecutorWithShuffleFiles() throws IOException { + cleanupOnlyRegisteredExecutor(true); + } + + @Test + public void cleanupOnlyRegisteredExecutorWithoutShuffleFiles() throws IOException { + cleanupOnlyRegisteredExecutor(false); + } + + private void cleanupOnlyRegisteredExecutor(boolean withShuffleFiles) throws IOException { + TestShuffleDataContext dataContext = initDataContext(withShuffleFiles); + + ExternalShuffleBlockResolver resolver = + new ExternalShuffleBlockResolver(conf, null, sameThreadExecutor); + resolver.registerExecutor("app", "exec0", dataContext.createExecutorInfo(SORT_MANAGER)); + + resolver.executorRemoved("exec1", "app"); + assertStillThere(dataContext); + + resolver.executorRemoved("exec0", "app"); + assertCleanedUp(dataContext); + } + + private static void assertStillThere(TestShuffleDataContext dataContext) { + for (String localDir : dataContext.localDirs) { + assertTrue(localDir + " was cleaned up prematurely", new File(localDir).exists()); + } + } + + private static FilenameFilter filter = new FilenameFilter() { + @Override + public boolean accept(File dir, String name) { + // Don't delete shuffle data or shuffle index files. 
+ return !name.endsWith(".index") && !name.endsWith(".data"); + } + }; + + private static boolean assertOnlyShuffleDataInDir(File[] dirs) { + for (File dir : dirs) { + assertTrue(dir.getName() + " wasn't cleaned up", !dir.exists() || + dir.listFiles(filter).length == 0 || assertOnlyShuffleDataInDir(dir.listFiles())); + } + return true; + } + + private static void assertCleanedUp(TestShuffleDataContext dataContext) { + for (String localDir : dataContext.localDirs) { + File[] dirs = new File[] {new File(localDir)}; + assertOnlyShuffleDataInDir(dirs); + } + } + + private static TestShuffleDataContext initDataContext(boolean withShuffleFiles) + throws IOException { + if (withShuffleFiles) { + return initDataContextWithShuffleFiles(); + } else { + return initDataContextWithoutShuffleFiles(); + } + } + + private static TestShuffleDataContext initDataContextWithShuffleFiles() throws IOException { + TestShuffleDataContext dataContext = createDataContext(); + createShuffleFiles(dataContext); + createNonShuffleFiles(dataContext); + return dataContext; + } + + private static TestShuffleDataContext initDataContextWithoutShuffleFiles() throws IOException { + TestShuffleDataContext dataContext = createDataContext(); + createNonShuffleFiles(dataContext); + return dataContext; + } + + private static TestShuffleDataContext createDataContext() { + TestShuffleDataContext dataContext = new TestShuffleDataContext(10, 5); + dataContext.create(); + return dataContext; + } + + private static void createShuffleFiles(TestShuffleDataContext dataContext) throws IOException { + Random rand = new Random(123); + dataContext.insertSortShuffleData(rand.nextInt(1000), rand.nextInt(1000), new byte[][] { + "ABC".getBytes(StandardCharsets.UTF_8), + "DEF".getBytes(StandardCharsets.UTF_8)}); + } + + private static void createNonShuffleFiles(TestShuffleDataContext dataContext) throws IOException { + // Create spill file(s) + dataContext.insertSpillData(); + } +} diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java index dc947a619bf02..ab8a37cdb3154 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java @@ -165,7 +165,7 @@ private static BlockFetchingListener fetchBlocks(LinkedHashMap T): T = { + try { + getValue + } catch { + case e: NumberFormatException => + // NumberFormatException doesn't have a constructor that takes a cause for some reason. + throw new NumberFormatException(s"Illegal value for config key $key: ${e.getMessage}") + .initCause(e) + case e: IllegalArgumentException => + throw new IllegalArgumentException(s"Illegal value for config key $key: ${e.getMessage}", e) + } + } + /** * Checks for illegal or deprecated config settings. Throws an exception for the former. Not * idempotent - may mutate this conf object to convert deprecated settings to supported ones. 
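The SparkConf hunk above adds a catchIllegalValue wrapper that re-raises parse errors with the offending configuration key in the message. The standalone Scala sketch below is illustrative only and not part of the patch: the assumed full signature def catchIllegalValue[T](key: String)(getValue: => T): T and the example config key are placeholders chosen for the demo.

import scala.util.Try

object CatchIllegalValueExample {
  // Re-raise parse failures with the config key included, mirroring the helper added above.
  private def catchIllegalValue[T](key: String)(getValue: => T): T = {
    try {
      getValue
    } catch {
      case e: NumberFormatException =>
        // NumberFormatException doesn't have a constructor that takes a cause, so attach it manually.
        throw new NumberFormatException(s"Illegal value for config key $key: ${e.getMessage}")
          .initCause(e)
      case e: IllegalArgumentException =>
        throw new IllegalArgumentException(s"Illegal value for config key $key: ${e.getMessage}", e)
    }
  }

  def main(args: Array[String]): Unit = {
    val settings = Map("spark.reducer.maxReqsInFlight" -> "not-a-number")
    // The failure message now names the misconfigured key instead of only the bad value.
    val result = Try(catchIllegalValue("spark.reducer.maxReqsInFlight") {
      settings("spark.reducer.maxReqsInFlight").toInt
    })
    println(result)
  }
}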
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index f075a7e0eb0b4..41eac10d9b267 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -183,6 +183,13 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( dataOut.writeInt(context.partitionId()) dataOut.writeInt(context.attemptNumber()) dataOut.writeLong(context.taskAttemptId()) + val localProps = context.asInstanceOf[TaskContextImpl].getLocalProperties.asScala + dataOut.writeInt(localProps.size) + localProps.foreach { case (k, v) => + PythonRDD.writeUTF(k, dataOut) + PythonRDD.writeUTF(v, dataOut) + } + // sparkFilesDir PythonRDD.writeUTF(SparkFiles.getRootDirectory(), dataOut) // Python includes (*.zip and *.egg files) diff --git a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala index f975fa5cb4e23..b59a4fe66587c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala @@ -94,6 +94,11 @@ class ExternalShuffleService(sparkConf: SparkConf, securityManager: SecurityMana blockHandler.applicationRemoved(appId, true /* cleanupLocalDirs */) } + /** Clean up all the non-shuffle files associated with an executor that has exited. */ + def executorRemoved(executorId: String, appId: String): Unit = { + blockHandler.executorRemoved(executorId, appId) + } + def stop() { if (server != null) { server.close() diff --git a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala index 1b7e031ee0678..ccb30e205ca40 100644 --- a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala @@ -19,6 +19,7 @@ package org.apache.spark.deploy import java.io.File import java.net.{InetAddress, URI} +import java.nio.file.Files import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer @@ -48,7 +49,7 @@ object PythonRunner { // Format python file paths before adding them to the PYTHONPATH val formattedPythonFile = formatPath(pythonFile) - val formattedPyFiles = formatPaths(pyFiles) + val formattedPyFiles = resolvePyFiles(formatPaths(pyFiles)) // Launch a Py4J gateway server for the process to connect to; this will let it see our // Java system properties and such @@ -153,4 +154,30 @@ object PythonRunner { .map { p => formatPath(p, testWindows) } } + /** + * Resolves the ".py" files. ".py" file should not be added as is because PYTHONPATH does + * not expect a file. This method creates a temporary directory and puts the ".py" files + * if exist in the given paths. + */ + private def resolvePyFiles(pyFiles: Array[String]): Array[String] = { + lazy val dest = Utils.createTempDir(namePrefix = "localPyFiles") + pyFiles.flatMap { pyFile => + // In case of client with submit, the python paths should be set before context + // initialization because the context initialization can be done later. + // We will copy the local ".py" files because ".py" file shouldn't be added + // alone but its parent directory in PYTHONPATH. See SPARK-24384. 
+ if (pyFile.endsWith(".py")) { + val source = new File(pyFile) + if (source.exists() && source.isFile && source.canRead) { + Files.copy(source.toPath, new File(dest, source.getName).toPath) + Some(dest.getAbsolutePath) + } else { + // Don't have to add it if it doesn't exist or isn't readable. + None + } + } else { + Some(pyFile) + } + }.distinct + } } diff --git a/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala b/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala index 5151df00476f9..ab8d8d96a9b08 100644 --- a/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala +++ b/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala @@ -32,7 +32,7 @@ import org.apache.spark.internal.Logging * * Also, each HadoopDelegationTokenProvider is controlled by * spark.security.credentials.{service}.enabled, and will not be loaded if this config is set to - * false. For example, Hive's delegation token provider [[HiveDelegationTokenProvider]] can be + * false. For example, Hive's delegation token provider [[HiveDelegationTokenProvider]] can be * enabled/disabled by the configuration spark.security.credentials.hive.enabled. * * @param sparkConf Spark configuration @@ -52,7 +52,7 @@ private[spark] class HadoopDelegationTokenManager( // Maintain all the registered delegation token providers private val delegationTokenProviders = getDelegationTokenProviders - logDebug(s"Using the following delegation token providers: " + + logDebug("Using the following builtin delegation token providers: " + s"${delegationTokenProviders.keys.mkString(", ")}.") /** Construct a [[HadoopDelegationTokenManager]] for the default Hadoop filesystem */ diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 563b84934f264..ee1ca0bba5749 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -23,6 +23,7 @@ import java.text.SimpleDateFormat import java.util.{Date, Locale, UUID} import java.util.concurrent._ import java.util.concurrent.{Future => JFuture, ScheduledFuture => JScheduledFuture} +import java.util.function.Supplier import scala.collection.mutable.{HashMap, HashSet, LinkedHashMap} import scala.concurrent.ExecutionContext @@ -49,7 +50,8 @@ private[deploy] class Worker( endpointName: String, workDirPath: String = null, val conf: SparkConf, - val securityMgr: SecurityManager) + val securityMgr: SecurityManager, + externalShuffleServiceSupplier: Supplier[ExternalShuffleService] = null) extends ThreadSafeRpcEndpoint with Logging { private val host = rpcEnv.address.host @@ -97,6 +99,10 @@ private[deploy] class Worker( private val APP_DATA_RETENTION_SECONDS = conf.getLong("spark.worker.cleanup.appDataTtl", 7 * 24 * 3600) + // Whether or not to clean up the non-shuffle files when an executor exits. + private val CLEANUP_NON_SHUFFLE_FILES_ENABLED = + conf.getBoolean("spark.storage.cleanupFilesAfterExecutorExit", true) + private val testing: Boolean = sys.props.contains("spark.testing") private var master: Option[RpcEndpointRef] = None @@ -142,7 +148,11 @@ private[deploy] class Worker( WorkerWebUI.DEFAULT_RETAINED_DRIVERS) // The shuffle service is not actually started unless configured.
- private val shuffleService = new ExternalShuffleService(conf, securityMgr) + private val shuffleService = if (externalShuffleServiceSupplier != null) { + externalShuffleServiceSupplier.get() + } else { + new ExternalShuffleService(conf, securityMgr) + } private val publicAddress = { val envVar = conf.getenv("SPARK_PUBLIC_DNS") @@ -732,6 +742,9 @@ private[deploy] class Worker( trimFinishedExecutorsIfNecessary() coresUsed -= executor.cores memoryUsed -= executor.memory + if (CLEANUP_NON_SHUFFLE_FILES_ENABLED) { + shuffleService.executorRemoved(executorStateChanged.execId.toString, appId) + } case None => logInfo("Unknown Executor " + fullId + " finished with state " + state + message.map(" message " + _).getOrElse("") + diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index a54b091a64d50..ad7a7d8af1a74 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -425,12 +425,18 @@ package object config { .doc("Remote block will be fetched to disk when size of the block is above this threshold " + "in bytes. This is to avoid a giant request takes too much memory. We can enable this " + "config by setting a specific value(e.g. 200m). Note this configuration will affect " + - "both shuffle fetch and block manager remote block fetch. For users who enabled " + - "external shuffle service, this feature can only be worked when external shuffle" + - "service is newer than Spark 2.2.") + "both shuffle fetch and block manager remote block fetch.") .bytesConf(ByteUnit.BYTE) .createWithDefault(Long.MaxValue) + private[spark] val STREAM_REQUEST_MESSAGE_ENABLED = + ConfigBuilder("spark.streamRequestMessageEnabled") + .doc("Remote blocks will be requested to be fetched to disk using stream request messages. " + + "For users who enabled the external shuffle service, this feature only works when the " + + "external shuffle service is newer than Spark 2.2.") + .booleanConf + .createWithDefault(false) + private[spark] val TASK_METRICS_TRACK_UPDATED_BLOCK_STATUSES = ConfigBuilder("spark.taskMetrics.trackUpdatedBlockStatuses") .doc("Enable tracking of updatedBlockStatuses in the TaskMetrics. Off by default since " + diff --git a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala index 1d8a266d0079c..3800c80c56a81 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala @@ -68,7 +68,8 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo execId: String, blockIds: Array[String], listener: BlockFetchingListener, - tempFileManager: TempFileManager): Unit + tempFileManager: TempFileManager, + useStreamRequestMessage: Boolean): Unit /** * Upload a single block to a remote node, available only after [[init]] is invoked. @@ -92,7 +93,8 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo port: Int, execId: String, blockId: String, - tempFileManager: TempFileManager): ManagedBuffer = { + tempFileManager: TempFileManager, + useStreamRequestMessage: Boolean): ManagedBuffer = { // A monitor for the thread to wait on.
val result = Promise[ManagedBuffer]() fetchBlocks(host, port, execId, Array(blockId), @@ -111,7 +113,9 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo result.success(new NioManagedBuffer(ret)) } } - }, tempFileManager) + }, + tempFileManager, + useStreamRequestMessage) ThreadUtils.awaitResult(result.future, Duration.Inf) } diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index b7d8c35032763..5f0e475a12f90 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -100,19 +100,19 @@ private[spark] class NettyBlockTransferService( } override def fetchBlocks( - host: String, - port: Int, + host: String, port: Int, execId: String, blockIds: Array[String], listener: BlockFetchingListener, - tempFileManager: TempFileManager): Unit = { + tempFileManager: TempFileManager, + useStreamRequestMessage: Boolean): Unit = { logTrace(s"Fetch blocks from $host:$port (executor id $execId)") try { val blockFetchStarter = new RetryingBlockFetcher.BlockFetchStarter { override def createAndStart(blockIds: Array[String], listener: BlockFetchingListener) { val client = clientFactory.createClient(host, port) new OneForOneBlockFetcher(client, appId, execId, blockIds, listener, - transportConf, tempFileManager).start() + transportConf, tempFileManager, useStreamRequestMessage).start() } } diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index a2936d6ad539c..0d8cf2f135d78 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -407,7 +407,7 @@ private[netty] class NettyRpcEnv( private class FileDownloadCallback( sink: WritableByteChannel, source: FileDownloadChannel, - client: TransportClient) extends StreamCallback { + client: TransportClient) extends StreamCallback[String] { override def onData(streamId: String, buf: ByteBuffer): Unit = { while (buf.remaining() > 0) { diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 8e97b3da33820..598b62f85a1fa 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -42,7 +42,7 @@ import org.apache.spark.util.{AccumulatorV2, ThreadUtils, Utils} * up to launch speculative tasks, etc. * * Clients should first call initialize() and start(), then submit task sets through the - * runTasks method. + * submitTasks method. * * THREADING: [[SchedulerBackend]]s and task-submitting clients can call this class from multiple * threads, so it needs locks in public API methods to maintain its state. In addition, some @@ -62,7 +62,7 @@ private[spark] class TaskSchedulerImpl( this(sc, sc.conf.get(config.MAX_TASK_FAILURES)) } - // Lazily initializing blackListTrackOpt to avoid getting empty ExecutorAllocationClient, + // Lazily initializing blacklistTrackerOpt to avoid getting empty ExecutorAllocationClient, // because ExecutorAllocationClient is created after this TaskSchedulerImpl. 
private[scheduler] lazy val blacklistTrackerOpt = maybeCreateBlacklistTracker(sc) @@ -228,7 +228,7 @@ private[spark] class TaskSchedulerImpl( // 1. The task set manager has been created and some tasks have been scheduled. // In this case, send a kill signal to the executors to kill the task and then abort // the stage. - // 2. The task set manager has been created but no tasks has been scheduled. In this case, + // 2. The task set manager has been created but no tasks have been scheduled. In this case, // simply abort the stage. tsm.runningTasksSet.foreach { tid => taskIdToExecutorId.get(tid).foreach(execId => @@ -694,7 +694,7 @@ private[spark] class TaskSchedulerImpl( * * After stage failure and retry, there may be multiple TaskSetManagers for the stage. * If an earlier attempt of a stage completes a task, we should ensure that the later attempts - * do not also submit those same tasks. That also means that a task completion from an earlier + * do not also submit those same tasks. That also means that a task completion from an earlier * attempt can lead to the entire stage getting marked as successful. */ private[scheduler] def markPartitionCompletedInAllTaskSets(stageId: Int, partitionId: Int) = { diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index 4103dfb10175e..44884395984a7 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -52,6 +52,7 @@ private[spark] class BlockStoreShuffleReader[K, C]( SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024, SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue), SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS), + SparkEnv.get.conf.get(config.STREAM_REQUEST_MESSAGE_ENABLED), SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM), SparkEnv.get.conf.getBoolean("spark.shuffle.detectCorrupt", true)) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index e0276a4dc4224..197154814e169 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -212,6 +212,7 @@ private[spark] class BlockManager( private[storage] val remoteBlockTempFileManager = new BlockManager.RemoteBlockTempFileManager(this) private val maxRemoteBlockToMem = conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM) + private val streamRequestMessageEnabled = conf.get(config.STREAM_REQUEST_MESSAGE_ENABLED) /** * Initializes the BlockManager with the given appId. This is not performed in the constructor as @@ -671,16 +672,7 @@ private[spark] class BlockManager( b.status.diskSize.max(b.status.memSize) }.getOrElse(0L) val blockLocations = locationsAndStatus.map(_.locations).getOrElse(Seq.empty) - - // If the block size is above the threshold, we should pass our FileManger to - // BlockTransferService, which will leverage it to spill the block; if not, then passed-in - // null value means the block will be persisted in memory. 
- val tempFileManager = if (blockSize > maxRemoteBlockToMem) { - remoteBlockTempFileManager - } else { - null - } - + val useStreamRequestMessage = streamRequestMessageEnabled && blockSize > maxRemoteBlockToMem val locations = sortLocations(blockLocations) val maxFetchFailures = locations.size var locationIterator = locations.iterator @@ -689,7 +681,12 @@ private[spark] class BlockManager( logDebug(s"Getting remote block $blockId from $loc") val data = try { blockTransferService.fetchBlockSync( - loc.host, loc.port, loc.executorId, blockId.toString, tempFileManager).nioByteBuffer() + loc.host, + loc.port, + loc.executorId, + blockId.toString, + remoteBlockTempFileManager, + useStreamRequestMessage).nioByteBuffer() } catch { case NonFatal(e) => runningFailureCount += 1 diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index b31862323a895..b7203e01b013d 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -69,6 +69,7 @@ final class ShuffleBlockFetcherIterator( maxBytesInFlight: Long, maxReqsInFlight: Int, maxBlocksInFlightPerAddress: Int, + streamRequestMessageEnabled: Boolean, maxReqSizeShuffleToMem: Long, detectCorrupt: Boolean) extends Iterator[(BlockId, InputStream)] with TempFileManager with Logging { @@ -248,13 +249,14 @@ final class ShuffleBlockFetcherIterator( // Fetch remote shuffle blocks to disk when the request is too large. Since the shuffle data is // already encrypted and compressed over the wire(w.r.t. the related configs), we can just fetch // the data and write it to file directly. - if (req.size > maxReqSizeShuffleToMem) { - shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray, - blockFetchingListener, this) - } else { - shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray, - blockFetchingListener, null) - } + shuffleClient.fetchBlocks( + address.host, + address.port, + address.executorId, + blockIds.toArray, + blockFetchingListener, + this, + req.size > maxReqSizeShuffleToMem && streamRequestMessageEnabled) } private[this] def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = { diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala index f651fe97c2cd5..178d2c8d1a10a 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala @@ -206,7 +206,9 @@ private[ui] class AllJobsPage(parent: JobsTab, store: AppStatusStore) extends We jobs: Seq[v1.JobData], killEnabled: Boolean): Seq[Node] = { // stripXSS is called to remove suspicious characters used in XSS attacks - val allParameters = request.getParameterMap.asScala.toMap.mapValues(_.map(UIUtils.stripXSS)) + val allParameters = request.getParameterMap.asScala.toMap.map { case (k, v) => + UIUtils.stripXSS(k) -> v.map(UIUtils.stripXSS).toSeq + } val parameterOtherTable = allParameters.filterNot(_._1.startsWith(jobTag)) .map(para => para._1 + "=" + para._2(0)) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 2575914121c39..d4e6a7bc3effa 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ 
b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -117,8 +117,7 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We val localitySummary = store.localitySummary(stageData.stageId, stageData.attemptId) - val totalTasks = stageData.numActiveTasks + stageData.numCompleteTasks + - stageData.numFailedTasks + stageData.numKilledTasks + val totalTasks = taskCount(stageData) if (totalTasks == 0) { val content =
@@ -133,7 +132,7 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We val totalTasksNumStr = if (totalTasks == storedTasks) { s"$totalTasks" } else { - s"$totalTasks, showing ${storedTasks}" + s"$storedTasks, showing ${totalTasks}" } val summary = @@ -686,7 +685,7 @@ private[ui] class TaskDataSource( private var _tasksToShow: Seq[TaskData] = null - override def dataSize: Int = stage.numTasks + override def dataSize: Int = taskCount(stage) override def sliceData(from: Int, to: Int): Seq[TaskData] = { if (_tasksToShow == null) { @@ -1052,4 +1051,9 @@ private[ui] object ApiHelper { (stage.map(_.name).getOrElse(""), stage.flatMap(_.description).getOrElse(job.name)) } + def taskCount(stageData: StageData): Int = { + stageData.numActiveTasks + stageData.numCompleteTasks + stageData.numFailedTasks + + stageData.numKilledTasks + } + } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index b8b20db1fa407..56e4d6838a99a 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -43,7 +43,9 @@ private[ui] class StageTableBase( killEnabled: Boolean, isFailedStage: Boolean) { // stripXSS is called to remove suspicious characters used in XSS attacks - val allParameters = request.getParameterMap.asScala.toMap.mapValues(_.map(UIUtils.stripXSS)) + val allParameters = request.getParameterMap.asScala.toMap.map { case (k, v) => + UIUtils.stripXSS(k) -> v.map(UIUtils.stripXSS).toSeq + } val parameterOtherTable = allParameters.filterNot(_._1.startsWith(stageTag)) .map(para => para._1 + "=" + para._2(0)) diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala index 28ea0c6f0bdba..0af967f39612e 100644 --- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala @@ -171,7 +171,7 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex val serializerManager = SparkEnv.get.serializerManager blockManager.master.getLocations(blockId).foreach { cmId => val bytes = blockTransfer.fetchBlockSync(cmId.host, cmId.port, cmId.executorId, - blockId.toString, null) + blockId.toString, null, false) val deserialized = serializerManager.dataDeserializeStream(blockId, new ChunkedByteBuffer(bytes.nioByteBuffer()).toInputStream())(data.elementClassTag).toList assert(deserialized === (1 to 100).toList) diff --git a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala index bff808eb540ac..0d06b02e74e34 100644 --- a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala @@ -339,6 +339,38 @@ class SparkConfSuite extends SparkFunSuite with LocalSparkContext with ResetSyst } } + val defaultIllegalValue = "SomeIllegalValue" + val illegalValueTests : Map[String, (SparkConf, String) => Any] = Map( + "getTimeAsSeconds" -> (_.getTimeAsSeconds(_)), + "getTimeAsSeconds with default" -> (_.getTimeAsSeconds(_, defaultIllegalValue)), + "getTimeAsMs" -> (_.getTimeAsMs(_)), + "getTimeAsMs with default" -> (_.getTimeAsMs(_, defaultIllegalValue)), + "getSizeAsBytes" -> (_.getSizeAsBytes(_)), + "getSizeAsBytes with default string" -> (_.getSizeAsBytes(_, defaultIllegalValue)), + "getSizeAsBytes with default long" -> (_.getSizeAsBytes(_, 0L)), + 
"getSizeAsKb" -> (_.getSizeAsKb(_)), + "getSizeAsKb with default" -> (_.getSizeAsKb(_, defaultIllegalValue)), + "getSizeAsMb" -> (_.getSizeAsMb(_)), + "getSizeAsMb with default" -> (_.getSizeAsMb(_, defaultIllegalValue)), + "getSizeAsGb" -> (_.getSizeAsGb(_)), + "getSizeAsGb with default" -> (_.getSizeAsGb(_, defaultIllegalValue)), + "getInt" -> (_.getInt(_, 0)), + "getLong" -> (_.getLong(_, 0L)), + "getDouble" -> (_.getDouble(_, 0.0)), + "getBoolean" -> (_.getBoolean(_, false)) + ) + + illegalValueTests.foreach { case (name, getValue) => + test(s"SPARK-24337: $name throws an useful error message with key name") { + val key = "SomeKey" + val conf = new SparkConf() + conf.set(key, "SomeInvalidValue") + val thrown = intercept[IllegalArgumentException] { + getValue(conf, key) + } + assert(thrown.getMessage.contains(key)) + } + } } class Class1 {} diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala index ce212a7513310..e3fe2b696aa1f 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala @@ -17,10 +17,19 @@ package org.apache.spark.deploy.worker +import java.util.concurrent.atomic.AtomicBoolean +import java.util.function.Supplier + +import org.mockito.{Mock, MockitoAnnotations} +import org.mockito.Answers.RETURNS_SMART_NULLS +import org.mockito.Matchers._ +import org.mockito.Mockito._ +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer import org.scalatest.{BeforeAndAfter, Matchers} import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} -import org.apache.spark.deploy.{Command, ExecutorState} +import org.apache.spark.deploy.{Command, ExecutorState, ExternalShuffleService} import org.apache.spark.deploy.DeployMessages.{DriverStateChanged, ExecutorStateChanged} import org.apache.spark.deploy.master.DriverState import org.apache.spark.rpc.{RpcAddress, RpcEnv} @@ -29,6 +38,8 @@ class WorkerSuite extends SparkFunSuite with Matchers with BeforeAndAfter { import org.apache.spark.deploy.DeployTestUtils._ + @Mock(answer = RETURNS_SMART_NULLS) private var shuffleService: ExternalShuffleService = _ + def cmd(javaOpts: String*): Command = { Command("", Seq.empty, Map.empty, Seq.empty, Seq.empty, Seq(javaOpts : _*)) } @@ -36,15 +47,21 @@ class WorkerSuite extends SparkFunSuite with Matchers with BeforeAndAfter { private var _worker: Worker = _ - private def makeWorker(conf: SparkConf): Worker = { + private def makeWorker( + conf: SparkConf, + shuffleServiceSupplier: Supplier[ExternalShuffleService] = null): Worker = { assert(_worker === null, "Some Worker's RpcEnv is leaked in tests") val securityMgr = new SecurityManager(conf) val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, securityMgr) _worker = new Worker(rpcEnv, 50000, 20, 1234 * 5, Array.fill(1)(RpcAddress("1.2.3.4", 1234)), - "Worker", "/tmp", conf, securityMgr) + "Worker", "/tmp", conf, securityMgr, shuffleServiceSupplier) _worker } + before { + MockitoAnnotations.initMocks(this) + } + after { if (_worker != null) { _worker.rpcEnv.shutdown() @@ -194,4 +211,36 @@ class WorkerSuite extends SparkFunSuite with Matchers with BeforeAndAfter { assert(worker.finishedDrivers.size === expectedValue) } } + + test("cleanup non-shuffle files after executor exits when config " + + "spark.storage.cleanupFilesAfterExecutorExit=true") { + testCleanupFilesWithConfig(true) + } + + test("don't cleanup non-shuffle 
files after executor exits when config " + + "spark.storage.cleanupFilesAfterExecutorExit=false") { + testCleanupFilesWithConfig(false) + } + + private def testCleanupFilesWithConfig(value: Boolean) = { + val conf = new SparkConf().set("spark.storage.cleanupFilesAfterExecutorExit", value.toString) + + val cleanupCalled = new AtomicBoolean(false) + when(shuffleService.executorRemoved(any[String], any[String])).thenAnswer(new Answer[Unit] { + override def answer(invocations: InvocationOnMock): Unit = { + cleanupCalled.set(true) + } + }) + val externalShuffleServiceSupplier = new Supplier[ExternalShuffleService] { + override def get: ExternalShuffleService = shuffleService + } + val worker = makeWorker(conf, externalShuffleServiceSupplier) + // initialize workers + for (i <- 0 until 10) { + worker.executors += s"app1/$i" -> createExecutorRunner(i) + } + worker.handleExecutorStateChanged( + ExecutorStateChanged("app1", 0, ExecutorState.EXITED, None, None)) + assert(cleanupCalled.get() == value) + } } diff --git a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala index 21138bd4a16ba..b574e07cc7b50 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala @@ -165,7 +165,9 @@ class NettyBlockTransferSecuritySuite extends SparkFunSuite with MockitoSugar wi override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = { promise.success(data.retain()) } - }, null) + }, + null, + false) ThreadUtils.awaitReady(promise.future, FiniteDuration(10, TimeUnit.SECONDS)) promise.future.value.get diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index b19d8ebf72c61..ec6a5cbcb280d 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -1434,7 +1434,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE execId: String, blockIds: Array[String], listener: BlockFetchingListener, - tempFileManager: TempFileManager): Unit = { + tempFileManager: TempFileManager, + useStreamRequestMessage: Boolean): Unit = { listener.onBlockFetchSuccess("mockBlockId", new NioManagedBuffer(ByteBuffer.allocate(1))) } @@ -1461,13 +1462,14 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE port: Int, execId: String, blockId: String, - tempFileManager: TempFileManager): ManagedBuffer = { + tempFileManager: TempFileManager, + useStreamRequestMessage: Boolean): ManagedBuffer = { numCalls += 1 this.tempFileManager = tempFileManager if (numCalls <= maxFailures) { throw new RuntimeException("Failing block fetch in the mock block transfer service") } - super.fetchBlockSync(host, port, execId, blockId, tempFileManager) + super.fetchBlockSync(host, port, execId, blockId, tempFileManager, useStreamRequestMessage) } } } diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index a2997dbd1b1ac..1e41a35aa4bdd 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ 
b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -46,7 +46,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT /** Creates a mock [[BlockTransferService]] that returns data from the given map. */ private def createMockTransfer(data: Map[BlockId, ManagedBuffer]): BlockTransferService = { val transfer = mock(classOf[BlockTransferService]) - when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) + when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any())) .thenAnswer(new Answer[Unit] { override def answer(invocation: InvocationOnMock): Unit = { val blocks = invocation.getArguments()(3).asInstanceOf[Array[String]] @@ -111,6 +111,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT 48 * 1024 * 1024, Int.MaxValue, Int.MaxValue, + false, Int.MaxValue, true) @@ -140,7 +141,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT // 3 local blocks, and 2 remote blocks // (but from the same block manager so one call to fetchBlocks) verify(blockManager, times(3)).getBlockData(any()) - verify(transfer, times(1)).fetchBlocks(any(), any(), any(), any(), any(), any()) + verify(transfer, times(1)).fetchBlocks(any(), any(), any(), any(), any(), any(), any()) } test("release current unexhausted buffer in case the task completes early") { @@ -159,7 +160,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val sem = new Semaphore(0) val transfer = mock(classOf[BlockTransferService]) - when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) + when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any())) .thenAnswer(new Answer[Unit] { override def answer(invocation: InvocationOnMock): Unit = { val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] @@ -189,6 +190,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT 48 * 1024 * 1024, Int.MaxValue, Int.MaxValue, + false, Int.MaxValue, true) @@ -227,7 +229,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val sem = new Semaphore(0) val transfer = mock(classOf[BlockTransferService]) - when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) + when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any())) .thenAnswer(new Answer[Unit] { override def answer(invocation: InvocationOnMock): Unit = { val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] @@ -257,6 +259,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT 48 * 1024 * 1024, Int.MaxValue, Int.MaxValue, + false, Int.MaxValue, true) @@ -297,7 +300,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val corruptLocalBuffer = new FileSegmentManagedBuffer(null, new File("a"), 0, 100) val transfer = mock(classOf[BlockTransferService]) - when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) + when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any())) .thenAnswer(new Answer[Unit] { override def answer(invocation: InvocationOnMock): Unit = { val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] @@ -327,6 +330,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT 48 * 1024 * 1024, Int.MaxValue, Int.MaxValue, + false, Int.MaxValue, true) @@ -337,7 +341,7 @@ class ShuffleBlockFetcherIteratorSuite extends 
SparkFunSuite with PrivateMethodT val (id1, _) = iterator.next() assert(id1 === ShuffleBlockId(0, 0, 0)) - when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) + when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any())) .thenAnswer(new Answer[Unit] { override def answer(invocation: InvocationOnMock): Unit = { val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] @@ -391,6 +395,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT 2048, Int.MaxValue, Int.MaxValue, + false, Int.MaxValue, true) // Blocks should be returned without exceptions. @@ -415,7 +420,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val sem = new Semaphore(0) val transfer = mock(classOf[BlockTransferService]) - when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) + when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any())) .thenAnswer(new Answer[Unit] { override def answer(invocation: InvocationOnMock): Unit = { val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] @@ -445,6 +450,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT 48 * 1024 * 1024, Int.MaxValue, Int.MaxValue, + false, Int.MaxValue, false) @@ -478,12 +484,12 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val remoteBlocks = Map[BlockId, ManagedBuffer]( ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer()) val transfer = mock(classOf[BlockTransferService]) - var tempFileManager: TempFileManager = null - when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) + var useStreamRequestMessage: Boolean = false + when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any())) .thenAnswer(new Answer[Unit] { override def answer(invocation: InvocationOnMock): Unit = { val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] - tempFileManager = invocation.getArguments()(5).asInstanceOf[TempFileManager] + useStreamRequestMessage = invocation.getArguments()(6).asInstanceOf[Boolean] Future { listener.onBlockFetchSuccess( ShuffleBlockId(0, 0, 0).toString, remoteBlocks(ShuffleBlockId(0, 0, 0))) @@ -505,6 +511,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT maxBytesInFlight = Int.MaxValue, maxReqsInFlight = Int.MaxValue, maxBlocksInFlightPerAddress = Int.MaxValue, + streamRequestMessageEnabled = true, maxReqSizeShuffleToMem = 200, detectCorrupt = true) } @@ -514,14 +521,14 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT fetchShuffleBlock(blocksByAddress1) // `maxReqSizeShuffleToMem` is 200, which is greater than the block size 100, so don't fetch // shuffle block to disk. - assert(tempFileManager == null) + assert(!useStreamRequestMessage) val blocksByAddress2 = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 300L)).toSeq)).toIterator fetchShuffleBlock(blocksByAddress2) // `maxReqSizeShuffleToMem` is 200, which is smaller than the block size 300, so fetch // shuffle block to disk. 
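The comment above states the rule these assertions now check: a fetch request is streamed to disk only when the new feature flag is enabled and the request is larger than `maxReqSizeShuffleToMem`. A minimal sketch of that decision, assuming the internal config entries referenced elsewhere in this patch are accessible; this is an illustration, not the patched code:

```scala
// Placed in a Spark package only so the private[spark] config entries are visible.
package org.apache.spark.storage

import org.apache.spark.SparkConf
import org.apache.spark.internal.config

// Illustration of the per-request decision that replaces the old
// "pass a TempFileManager or null" convention. Both conditions must hold.
object StreamRequestDecision {
  def useStreamRequest(conf: SparkConf, reqSize: Long): Boolean = {
    val streamEnabled = conf.get(config.STREAM_REQUEST_MESSAGE_ENABLED)          // entry added by this patch
    val maxReqSizeShuffleToMem = conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM)
    streamEnabled && reqSize > maxReqSizeShuffleToMem
  }
}
```

The same boolean is what the iterator now passes as the last argument of `shuffleClient.fetchBlocks(...)`, which is why the mocks in this suite gain an extra `any()` throughout.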
- assert(tempFileManager != null) + assert(useStreamRequestMessage) } test("fail zero-size blocks") {
@@ -551,6 +558,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT 48 * 1024 * 1024, Int.MaxValue, Int.MaxValue, + false, Int.MaxValue, true)
diff --git a/docs/configuration.md b/docs/configuration.md index 64af0e98a82f5..5588c372d3e42 100644 --- a/docs/configuration.md +++ b/docs/configuration.md
@@ -456,6 +456,33 @@ Apart from these, the following properties are also available, and may be useful from JVM to Python worker for every task.
+ + spark.sql.repl.eagerEval.enabled + false + + Enable eager evaluation or not. If true and the REPL you are using supports eager evaluation,
+ Dataset will be run automatically. The HTML table generated by _repr_html_ and rendered by
+ notebooks like Jupyter will reflect the queries the user has defined. For plain Python
+ REPL, the output will be shown like dataframe.show()
+ (see SPARK-24215 for more details). + +
+ + spark.sql.repl.eagerEval.maxNumRows + 20 +
+ Default number of rows shown in the eager evaluation output HTML table generated by _repr_html_ or plain text;
+ this only takes effect when spark.sql.repl.eagerEval.enabled is set to true. + +
+ + spark.sql.repl.eagerEval.truncate + 20 +
+ Maximum number of characters each cell is truncated to in the eager evaluation output HTML table generated by _repr_html_ or
+ plain text; this only takes effect when spark.sql.repl.eagerEval.enabled is set to true. + + spark.files
diff --git a/docs/running-on-kubernetes.md b/docs/running-on-kubernetes.md index e9e1f3e280609..4eac9bd9032e4 100644 --- a/docs/running-on-kubernetes.md +++ b/docs/running-on-kubernetes.md
@@ -140,6 +140,12 @@ namespace as that of the driver and executor pods. For example, to mount a secre --conf spark.kubernetes.executor.secrets.spark-secret=/etc/secrets ```
+To use a secret through an environment variable, pass the following options to the `spark-submit` command:
+```
+--conf spark.kubernetes.driver.secretKeyRef.ENV_NAME=name:key
+--conf spark.kubernetes.executor.secretKeyRef.ENV_NAME=name:key
+```
+
## Introspection and Debugging These are the different ways in which you can investigate a running/completed Spark application, monitor progress, and
@@ -321,6 +327,13 @@ specific to Spark on Kubernetes. Container image pull policy used when pulling images within Kubernetes.
+ + spark.kubernetes.container.image.pullSecrets + + + Comma-separated list of Kubernetes secrets used to pull images from private image registries. + +
spark.kubernetes.allocation.batch.size 5
@@ -602,4 +615,20 @@ specific to Spark on Kubernetes. spark.kubernetes.executor.secrets.spark-secret=/etc/secrets.
+ + spark.kubernetes.driver.secretKeyRef.[EnvName] + (none) +
+ Add an environment variable named EnvName (case sensitive) to the driver container, whose value is referenced by the given key in the data of the referenced Kubernetes Secret. For example,
+ spark.kubernetes.driver.secretKeyRef.ENV_VAR=spark-secret:key. + +
+ + spark.kubernetes.executor.secretKeyRef.[EnvName] + (none) +
+ Add an environment variable named EnvName (case sensitive) to the executor container, whose value is referenced by the given key in the data of the referenced Kubernetes Secret. For example,
+ spark.kubernetes.executor.secretKeyRef.ENV_VAR=spark-secret:key. + +
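For completeness, a hypothetical Scala sketch of expressing the same Kubernetes options programmatically on a SparkConf rather than on the `spark-submit` command line; the secret name `spark-secret`, key `db-password`, env var `DB_PASSWORD`, and registry secret `my-registry-secret` are placeholders, and only the config key patterns come from the documentation above:

```scala
import org.apache.spark.SparkConf

// Placeholder secret/key names; the key patterns match the new docs above.
val conf = new SparkConf()
  .set("spark.kubernetes.container.image.pullSecrets", "my-registry-secret")
  .set("spark.kubernetes.driver.secretKeyRef.DB_PASSWORD", "spark-secret:db-password")
  .set("spark.kubernetes.executor.secretKeyRef.DB_PASSWORD", "spark-secret:db-password")
```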
diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index 3c2a1501ca692..66ffb17949845 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md
@@ -753,6 +753,18 @@ See the [configuration page](configuration.html) for information on Spark config spark.cores.max is reached
+ + spark.mesos.appJar.local.resolution.mode + host +
+ Provides support for the `local:///` scheme to reference the app jar resource in cluster mode.
+ If the user uses a local resource (`local:///path/to/jar`) and the config option is not set, it defaults to `host`, i.e.
+ the Mesos fetcher tries to get the resource from the host's file system.
+ If the value is unknown, it prints a warning message in the dispatcher logs and defaults to `host`.
+ If the value is `container`, then spark-submit in the container will use the jar in the container's path:
+ `/path/to/jar`. + +
# Troubleshooting and Debugging
diff --git a/docs/spark-standalone.md b/docs/spark-standalone.md index f06e72a387df1..14d742de5655c 100644 --- a/docs/spark-standalone.md +++ b/docs/spark-standalone.md
@@ -254,6 +254,18 @@ SPARK_WORKER_OPTS supports the following system properties: especially if you run jobs very frequently.
+ + spark.storage.cleanupFilesAfterExecutorExit + true +
+ Enable cleanup of non-shuffle files (such as temp. shuffle blocks, cached RDD/broadcast blocks,
+ spill files, etc.) in worker directories following executor exits. Note that this doesn't
+ overlap with `spark.worker.cleanup.enabled`, as this enables cleanup of non-shuffle files in
+ local directories of a dead executor, while `spark.worker.cleanup.enabled` enables cleanup of
+ all files/subdirectories of a stopped and timed-out application.
+ This only affects Standalone mode; support for other cluster managers can be added in the future. + +
spark.worker.ui.compressedLogFileLengthCacheSize 100
diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 50600861912b1..4d8a738507bd1 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md
@@ -1752,6 +1752,15 @@ To use `groupBy().apply()`, the user needs to define the following: * A Python function that defines the computation for each group. * A `StructType` object or a string that defines the schema of the output `DataFrame`.
+The output schema will be applied to the columns of the returned `pandas.DataFrame` in order by position,
+not by name. This means that the columns in the `pandas.DataFrame` must be indexed so that their
+position matches the corresponding field in the schema.
+
+Note that when creating a new `pandas.DataFrame` using a dictionary, the actual position of the column
+can differ from the order that it was placed in the dictionary. It is recommended in this case to
+explicitly define the column order using the `columns` keyword, e.g.
+`pandas.DataFrame({'id': ids, 'a': data}, columns=['id', 'a'])`, or alternatively use an `OrderedDict`.
+
Note that all data for a group will be loaded into memory before the function is applied. This can lead to out of memory exceptions, especially if the group sizes are skewed.
The configuration for [maxRecordsPerBatch](#setting-arrow-batch-size) is not applied on groups and it is up to the user diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala index 1ad4e097246a3..9c9614509c64f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala @@ -276,8 +276,7 @@ class BisectingKMeans @Since("2.0.0") ( val summary = new BisectingKMeansSummary( model.transform(dataset), $(predictionCol), $(featuresCol), $(k)) model.setSummary(Some(summary)) - // TODO: need to extend logNamedValue to support Array - instr.logNamedValue("clusterSizes", summary.clusterSizes.mkString("[", ",", "]")) + instr.logNamedValue("clusterSizes", summary.clusterSizes) instr.logSuccess(model) model } diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala index 3091bb5a2e54c..64ecc1ebda589 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala @@ -426,8 +426,7 @@ class GaussianMixture @Since("2.0.0") ( $(predictionCol), $(probabilityCol), $(featuresCol), $(k), logLikelihood) model.setSummary(Some(summary)) instr.logNamedValue("logLikelihood", logLikelihood) - // TODO: need to extend logNamedValue to support Array - instr.logNamedValue("clusterSizes", summary.clusterSizes.mkString("[", ",", "]")) + instr.logNamedValue("clusterSizes", summary.clusterSizes) instr.logSuccess(model) model } diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index e72d7f9485e6a..1704412741d49 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -359,8 +359,7 @@ class KMeans @Since("1.5.0") ( model.transform(dataset), $(predictionCol), $(featuresCol), $(k)) model.setSummary(Some(summary)) - // TODO: need to extend logNamedValue to support Array - instr.logNamedValue("clusterSizes", summary.clusterSizes.mkString("[", ",", "]")) + instr.logNamedValue("clusterSizes", summary.clusterSizes) instr.logSuccess(model) if (handlePersistence) { instances.unpersist() diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala index 2c30a1d9aa947..1b9a3499947d9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala @@ -18,21 +18,20 @@ package org.apache.spark.ml.clustering import org.apache.spark.annotation.{Experimental, Since} -import org.apache.spark.ml.Transformer import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ import org.apache.spark.mllib.clustering.{PowerIterationClustering => MLlibPowerIterationClustering} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset, Row} -import org.apache.spark.sql.functions.col +import org.apache.spark.sql.functions.{col, lit} import org.apache.spark.sql.types._ /** * Common params for PowerIterationClustering */ private[clustering] trait 
PowerIterationClusteringParams extends Params with HasMaxIter - with HasPredictionCol { + with HasWeightCol { /** * The number of clusters to create (k). Must be > 1. Default: 2. @@ -66,62 +65,33 @@ private[clustering] trait PowerIterationClusteringParams extends Params with Has def getInitMode: String = $(initMode) /** - * Param for the name of the input column for vertex IDs. - * Default: "id" + * Param for the name of the input column for source vertex IDs. + * Default: "src" * @group param */ @Since("2.4.0") - val idCol = new Param[String](this, "idCol", "Name of the input column for vertex IDs.", + val srcCol = new Param[String](this, "srcCol", "Name of the input column for source vertex IDs.", (value: String) => value.nonEmpty) - setDefault(idCol, "id") - - /** @group getParam */ - @Since("2.4.0") - def getIdCol: String = getOrDefault(idCol) - - /** - * Param for the name of the input column for neighbors in the adjacency list representation. - * Default: "neighbors" - * @group param - */ - @Since("2.4.0") - val neighborsCol = new Param[String](this, "neighborsCol", - "Name of the input column for neighbors in the adjacency list representation.", - (value: String) => value.nonEmpty) - - setDefault(neighborsCol, "neighbors") - /** @group getParam */ @Since("2.4.0") - def getNeighborsCol: String = $(neighborsCol) + def getSrcCol: String = getOrDefault(srcCol) /** - * Param for the name of the input column for neighbors in the adjacency list representation. - * Default: "similarities" + * Name of the input column for destination vertex IDs. + * Default: "dst" * @group param */ @Since("2.4.0") - val similaritiesCol = new Param[String](this, "similaritiesCol", - "Name of the input column for neighbors in the adjacency list representation.", + val dstCol = new Param[String](this, "dstCol", + "Name of the input column for destination vertex IDs.", (value: String) => value.nonEmpty) - setDefault(similaritiesCol, "similarities") - /** @group getParam */ @Since("2.4.0") - def getSimilaritiesCol: String = $(similaritiesCol) + def getDstCol: String = $(dstCol) - protected def validateAndTransformSchema(schema: StructType): StructType = { - SchemaUtils.checkColumnTypes(schema, $(idCol), Seq(IntegerType, LongType)) - SchemaUtils.checkColumnTypes(schema, $(neighborsCol), - Seq(ArrayType(IntegerType, containsNull = false), - ArrayType(LongType, containsNull = false))) - SchemaUtils.checkColumnTypes(schema, $(similaritiesCol), - Seq(ArrayType(FloatType, containsNull = false), - ArrayType(DoubleType, containsNull = false))) - SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType) - } + setDefault(srcCol -> "src", dstCol -> "dst") } /** @@ -131,21 +101,8 @@ private[clustering] trait PowerIterationClusteringParams extends Params with Has * PIC finds a very low-dimensional embedding of a dataset using truncated power * iteration on a normalized pair-wise similarity matrix of the data. * - * PIC takes an affinity matrix between items (or vertices) as input. An affinity matrix - * is a symmetric matrix whose entries are non-negative similarities between items. - * PIC takes this matrix (or graph) as an adjacency matrix. Specifically, each input row includes: - * - `idCol`: vertex ID - * - `neighborsCol`: neighbors of vertex in `idCol` - * - `similaritiesCol`: non-negative weights (similarities) of edges between the vertex - * in `idCol` and each neighbor in `neighborsCol` - * PIC returns a cluster assignment for each input vertex. 
It appends a new column `predictionCol` - * containing the cluster assignment in `[0,k)` for each row (vertex). - * - * Notes: - * - [[PowerIterationClustering]] is a transformer with an expensive [[transform]] operation. - * Transform runs the iterative PIC algorithm to cluster the whole input dataset. - * - Input validation: This validates that similarities are non-negative but does NOT validate - * that the input matrix is symmetric. + * This class is not yet an Estimator/Transformer, use `assignClusters` method to run the + * PowerIterationClustering algorithm. * * @see * Spectral clustering (Wikipedia) @@ -154,7 +111,7 @@ private[clustering] trait PowerIterationClusteringParams extends Params with Has @Experimental class PowerIterationClustering private[clustering] ( @Since("2.4.0") override val uid: String) - extends Transformer with PowerIterationClusteringParams with DefaultParamsWritable { + extends PowerIterationClusteringParams with DefaultParamsWritable { setDefault( k -> 2, @@ -164,10 +121,6 @@ class PowerIterationClustering private[clustering] ( @Since("2.4.0") def this() = this(Identifiable.randomUID("PowerIterationClustering")) - /** @group setParam */ - @Since("2.4.0") - def setPredictionCol(value: String): this.type = set(predictionCol, value) - /** @group setParam */ @Since("2.4.0") def setK(value: Int): this.type = set(k, value) @@ -182,66 +135,56 @@ class PowerIterationClustering private[clustering] ( /** @group setParam */ @Since("2.4.0") - def setIdCol(value: String): this.type = set(idCol, value) + def setSrcCol(value: String): this.type = set(srcCol, value) /** @group setParam */ @Since("2.4.0") - def setNeighborsCol(value: String): this.type = set(neighborsCol, value) + def setDstCol(value: String): this.type = set(dstCol, value) /** @group setParam */ @Since("2.4.0") - def setSimilaritiesCol(value: String): this.type = set(similaritiesCol, value) + def setWeightCol(value: String): this.type = set(weightCol, value) + /** + * Run the PIC algorithm and returns a cluster assignment for each input vertex. + * + * @param dataset A dataset with columns src, dst, weight representing the affinity matrix, + * which is the matrix A in the PIC paper. Suppose the src column value is i, + * the dst column value is j, the weight column value is similarity s,,ij,, + * which must be nonnegative. This is a symmetric matrix and hence + * s,,ij,, = s,,ji,,. For any (i, j) with nonzero similarity, there should be + * either (i, j, s,,ij,,) or (j, i, s,,ji,,) in the input. Rows with i = j are + * ignored, because we assume s,,ij,, = 0.0. + * + * @return A dataset that contains columns of vertex id and the corresponding cluster for the id. 
+ * The schema of it will be: + * - id: Long + * - cluster: Int + */ @Since("2.4.0") - override def transform(dataset: Dataset[_]): DataFrame = { - transformSchema(dataset.schema, logging = true) + def assignClusters(dataset: Dataset[_]): DataFrame = { + val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) { + lit(1.0) + } else { + col($(weightCol)).cast(DoubleType) + } - val sparkSession = dataset.sparkSession - val idColValue = $(idCol) - val rdd: RDD[(Long, Long, Double)] = - dataset.select( - col($(idCol)).cast(LongType), - col($(neighborsCol)).cast(ArrayType(LongType, containsNull = false)), - col($(similaritiesCol)).cast(ArrayType(DoubleType, containsNull = false)) - ).rdd.flatMap { - case Row(id: Long, nbrs: Seq[_], sims: Seq[_]) => - require(nbrs.size == sims.size, s"The length of the neighbor ID list must be " + - s"equal to the the length of the neighbor similarity list. Row for ID " + - s"$idColValue=$id has neighbor ID list of length ${nbrs.length} but similarity list " + - s"of length ${sims.length}.") - nbrs.asInstanceOf[Seq[Long]].zip(sims.asInstanceOf[Seq[Double]]).map { - case (nbr, similarity) => (id, nbr, similarity) - } - } + SchemaUtils.checkColumnTypes(dataset.schema, $(srcCol), Seq(IntegerType, LongType)) + SchemaUtils.checkColumnTypes(dataset.schema, $(dstCol), Seq(IntegerType, LongType)) + val rdd: RDD[(Long, Long, Double)] = dataset.select( + col($(srcCol)).cast(LongType), + col($(dstCol)).cast(LongType), + w).rdd.map { + case Row(src: Long, dst: Long, weight: Double) => (src, dst, weight) + } val algorithm = new MLlibPowerIterationClustering() .setK($(k)) .setInitializationMode($(initMode)) .setMaxIterations($(maxIter)) val model = algorithm.run(rdd) - val predictionsRDD: RDD[Row] = model.assignments.map { assignment => - Row(assignment.id, assignment.cluster) - } - - val predictionsSchema = StructType(Seq( - StructField($(idCol), LongType, nullable = false), - StructField($(predictionCol), IntegerType, nullable = false))) - val predictions = { - val uncastPredictions = sparkSession.createDataFrame(predictionsRDD, predictionsSchema) - dataset.schema($(idCol)).dataType match { - case _: LongType => - uncastPredictions - case otherType => - uncastPredictions.select(col($(idCol)).cast(otherType).alias($(idCol))) - } - } - - dataset.join(predictions, $(idCol)) - } - - @Since("2.4.0") - override def transformSchema(schema: StructType): StructType = { - validateAndTransformSchema(schema) + import dataset.sparkSession.implicits._ + model.assignments.toDF } @Since("2.4.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala index 41716c621ca98..bd1c1a8885201 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala @@ -53,7 +53,7 @@ final class PrefixSpan(@Since("2.4.0") override val uid: String) extends Params @Since("2.4.0") val minSupport = new DoubleParam(this, "minSupport", "The minimal support level of the " + "sequential pattern. Sequential pattern that appears more than " + - "(minSupport * size-of-the-dataset)." + + "(minSupport * size-of-the-dataset) " + "times will be output.", ParamValidators.gtEq(0.0)) /** @group getParam */ @@ -128,10 +128,10 @@ final class PrefixSpan(@Since("2.4.0") override val uid: String) extends Params * Finds the complete set of frequent sequential patterns in the input sequences of itemsets. 
* * @param dataset A dataset or a dataframe containing a sequence column which is - * {{{Seq[Seq[_]]}}} type + * {{{ArrayType(ArrayType(T))}}} type, T is the item type for the input dataset. * @return A `DataFrame` that contains columns of sequence and corresponding frequency. * The schema of it will be: - * - `sequence: Seq[Seq[T]]` (T is the item type) + * - `sequence: ArrayType(ArrayType(T))` (T is the item type) * - `freq: Long` */ @Since("2.4.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala index 467130b37c16e..3a1c166d46257 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala @@ -132,6 +132,19 @@ private[spark] class Instrumentation[E <: Estimator[_]] private ( log(compact(render(name -> value))) } + def logNamedValue(name: String, value: Array[String]): Unit = { + log(compact(render(name -> compact(render(value.toSeq))))) + } + + def logNamedValue(name: String, value: Array[Long]): Unit = { + log(compact(render(name -> compact(render(value.toSeq))))) + } + + def logNamedValue(name: String, value: Array[Double]): Unit = { + log(compact(render(name -> compact(render(value.toSeq))))) + } + + /** * Logs the successful completion of the training session. */ diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala index 096b5416899e1..db92132d18b7b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala @@ -34,9 +34,8 @@ object LDASuite { vocabSize: Int): DataFrame = { val avgWC = 1 // average instances of each word in a doc val sc = spark.sparkContext - val rng = new java.util.Random() - rng.setSeed(1) val rdd = sc.parallelize(1 to rows).map { i => + val rng = new java.util.Random(i) Vectors.dense(Array.fill(vocabSize)(rng.nextInt(2 * avgWC).toDouble)) }.map(v => new TestRow(v)) spark.createDataFrame(rdd) diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala index 65328df17baff..b7072728d48f0 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala @@ -17,19 +17,19 @@ package org.apache.spark.ml.clustering -import scala.collection.mutable - import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} -import org.apache.spark.sql.functions.col +import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} +import org.apache.spark.sql.functions.{col, lit} import org.apache.spark.sql.types._ class PowerIterationClusteringSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + import testImplicits._ + @transient var data: Dataset[_] = _ final val r1 = 1.0 final val n1 = 10 @@ -48,10 +48,9 @@ class PowerIterationClusteringSuite extends SparkFunSuite assert(pic.getK === 2) assert(pic.getMaxIter === 20) assert(pic.getInitMode === "random") - assert(pic.getPredictionCol === "prediction") - assert(pic.getIdCol === "id") - 
assert(pic.getNeighborsCol === "neighbors") - assert(pic.getSimilaritiesCol === "similarities") + assert(pic.getSrcCol === "src") + assert(pic.getDstCol === "dst") + assert(!pic.isDefined(pic.weightCol)) } test("parameter validation") { @@ -62,125 +61,102 @@ class PowerIterationClusteringSuite extends SparkFunSuite new PowerIterationClustering().setInitMode("no_such_a_mode") } intercept[IllegalArgumentException] { - new PowerIterationClustering().setIdCol("") + new PowerIterationClustering().setSrcCol("") } intercept[IllegalArgumentException] { - new PowerIterationClustering().setNeighborsCol("") - } - intercept[IllegalArgumentException] { - new PowerIterationClustering().setSimilaritiesCol("") + new PowerIterationClustering().setDstCol("") } } test("power iteration clustering") { val n = n1 + n2 - val model = new PowerIterationClustering() + val assignments = new PowerIterationClustering() .setK(2) .setMaxIter(40) - val result = model.transform(data) - - val predictions = Array.fill(2)(mutable.Set.empty[Long]) - result.select("id", "prediction").collect().foreach { - case Row(id: Long, cluster: Integer) => predictions(cluster) += id - } - assert(predictions.toSet == Set((1 until n1).toSet, (n1 until n).toSet)) - - val result2 = new PowerIterationClustering() + .setWeightCol("weight") + .assignClusters(data) + val localAssignments = assignments + .select('id, 'cluster) + .as[(Long, Int)].collect().toSet + val expectedResult = (0 until n1).map(x => (x, 1)).toSet ++ + (n1 until n).map(x => (x, 0)).toSet + assert(localAssignments === expectedResult) + + val assignments2 = new PowerIterationClustering() .setK(2) .setMaxIter(10) .setInitMode("degree") - .transform(data) - val predictions2 = Array.fill(2)(mutable.Set.empty[Long]) - result2.select("id", "prediction").collect().foreach { - case Row(id: Long, cluster: Integer) => predictions2(cluster) += id - } - assert(predictions2.toSet == Set((1 until n1).toSet, (n1 until n).toSet)) + .setWeightCol("weight") + .assignClusters(data) + val localAssignments2 = assignments2 + .select('id, 'cluster) + .as[(Long, Int)].collect().toSet + assert(localAssignments2 === expectedResult) } test("supported input types") { - val model = new PowerIterationClustering() + val pic = new PowerIterationClustering() .setK(2) .setMaxIter(1) + .setWeightCol("weight") - def runTest(idType: DataType, neighborType: DataType, similarityType: DataType): Unit = { + def runTest(srcType: DataType, dstType: DataType, weightType: DataType): Unit = { val typedData = data.select( - col("id").cast(idType).alias("id"), - col("neighbors").cast(ArrayType(neighborType, containsNull = false)).alias("neighbors"), - col("similarities").cast(ArrayType(similarityType, containsNull = false)) - .alias("similarities") + col("src").cast(srcType).alias("src"), + col("dst").cast(dstType).alias("dst"), + col("weight").cast(weightType).alias("weight") ) - model.transform(typedData).collect() - } - - for (idType <- Seq(IntegerType, LongType)) { - runTest(idType, LongType, DoubleType) - } - for (neighborType <- Seq(IntegerType, LongType)) { - runTest(LongType, neighborType, DoubleType) - } - for (similarityType <- Seq(FloatType, DoubleType)) { - runTest(LongType, LongType, similarityType) + pic.assignClusters(typedData).collect() } - } - test("invalid input: wrong types") { - val model = new PowerIterationClustering() - .setK(2) - .setMaxIter(1) - intercept[IllegalArgumentException] { - val typedData = data.select( - col("id").cast(DoubleType).alias("id"), - col("neighbors"), - col("similarities") - 
) - model.transform(typedData) + for (srcType <- Seq(IntegerType, LongType)) { + runTest(srcType, LongType, DoubleType) } - intercept[IllegalArgumentException] { - val typedData = data.select( - col("id"), - col("neighbors").cast(ArrayType(DoubleType, containsNull = false)).alias("neighbors"), - col("similarities") - ) - model.transform(typedData) + for (dstType <- Seq(IntegerType, LongType)) { + runTest(LongType, dstType, DoubleType) } - intercept[IllegalArgumentException] { - val typedData = data.select( - col("id"), - col("neighbors"), - col("neighbors").alias("similarities") - ) - model.transform(typedData) + for (weightType <- Seq(FloatType, DoubleType)) { + runTest(LongType, LongType, weightType) } } test("invalid input: negative similarity") { - val model = new PowerIterationClustering() + val pic = new PowerIterationClustering() .setMaxIter(1) + .setWeightCol("weight") val badData = spark.createDataFrame(Seq( - (0, Array(1), Array(-1.0)), - (1, Array(0), Array(-1.0)) - )).toDF("id", "neighbors", "similarities") + (0, 1, -1.0), + (1, 0, -1.0) + )).toDF("src", "dst", "weight") val msg = intercept[SparkException] { - model.transform(badData) + pic.assignClusters(badData) }.getCause.getMessage assert(msg.contains("Similarity must be nonnegative")) } - test("invalid input: mismatched lengths for neighbor and similarity arrays") { - val model = new PowerIterationClustering() - .setMaxIter(1) - val badData = spark.createDataFrame(Seq( - (0, Array(1), Array(0.5)), - (1, Array(0, 2), Array(0.5)), - (2, Array(1), Array(0.5)) - )).toDF("id", "neighbors", "similarities") - val msg = intercept[SparkException] { - model.transform(badData) - }.getCause.getMessage - assert(msg.contains("The length of the neighbor ID list must be equal to the the length of " + - "the neighbor similarity list.")) - assert(msg.contains(s"Row for ID ${model.getIdCol}=1")) + test("test default weight") { + val dataWithoutWeight = data.sample(0.5, 1L).select('src, 'dst) + + val assignments = new PowerIterationClustering() + .setK(2) + .setMaxIter(40) + .assignClusters(dataWithoutWeight) + val localAssignments = assignments + .select('id, 'cluster) + .as[(Long, Int)].collect().toSet + + val dataWithWeightOne = dataWithoutWeight.withColumn("weight", lit(1.0)) + + val assignments2 = new PowerIterationClustering() + .setK(2) + .setMaxIter(40) + .assignClusters(dataWithWeightOne) + val localAssignments2 = assignments2 + .select('id, 'cluster) + .as[(Long, Int)].collect().toSet + + assert(localAssignments === localAssignments2) } test("read/write") { @@ -188,10 +164,9 @@ class PowerIterationClusteringSuite extends SparkFunSuite .setK(4) .setMaxIter(100) .setInitMode("degree") - .setIdCol("test_id") - .setNeighborsCol("myNeighborsCol") - .setSimilaritiesCol("mySimilaritiesCol") - .setPredictionCol("test_prediction") + .setSrcCol("src1") + .setDstCol("dst1") + .setWeightCol("weight") testDefaultReadWrite(t) } } @@ -222,17 +197,13 @@ object PowerIterationClusteringSuite { val n = n1 + n2 val points = genCircle(r1, n1) ++ genCircle(r2, n2) - val rows = for (i <- 1 until n) yield { - val neighbors = for (j <- 0 until i) yield { - j.toLong + val rows = (for (i <- 1 until n) yield { + for (j <- 0 until i) yield { + (i.toLong, j.toLong, sim(points(i), points(j))) } - val similarities = for (j <- 0 until i) yield { - sim(points(i), points(j)) - } - (i.toLong, neighbors.toArray, similarities.toArray) - } + }).flatMap(_.iterator) - spark.createDataFrame(rows).toDF("id", "neighbors", "similarities") + 
spark.createDataFrame(rows).toDF("src", "dst", "weight") } } diff --git a/project/build.properties b/project/build.properties index b19518fd7aa1c..d03985d980ec8 100644 --- a/project/build.properties +++ b/project/build.properties @@ -14,4 +14,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -sbt.version=0.13.16 +sbt.version=0.13.17 diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 424ecfd89b060..1754c48937a62 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -1131,6 +1131,13 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre def _create_model(self, java_model): return RandomForestClassificationModel(java_model) + @since("2.4.0") + def setFeatureSubsetStrategy(self, value): + """ + Sets the value of :py:attr:`featureSubsetStrategy`. + """ + return self._set(featureSubsetStrategy=value) + class RandomForestClassificationModel(TreeEnsembleModel, JavaClassificationModel, JavaMLWritable, JavaMLReadable): @@ -1193,6 +1200,8 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol >>> si_model = stringIndexer.fit(df) >>> td = si_model.transform(df) >>> gbt = GBTClassifier(maxIter=5, maxDepth=2, labelCol="indexed", seed=42) + >>> gbt.getFeatureSubsetStrategy() + 'all' >>> model = gbt.fit(td) >>> model.featureImportances SparseVector(1, {0: 1.0}) @@ -1226,6 +1235,8 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol ... ["indexed", "features"]) >>> model.evaluateEachIteration(validation) [0.25..., 0.23..., 0.21..., 0.19..., 0.18...] + >>> model.numClasses + 2 .. versionadded:: 1.4.0 """ @@ -1244,19 +1255,22 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, lossType="logistic", - maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0): + maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0, + featureSubsetStrategy="all"): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \ - lossType="logistic", maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0) + lossType="logistic", maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0, \ + featureSubsetStrategy="all") """ super(GBTClassifier, self).__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.classification.GBTClassifier", self.uid) self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, - lossType="logistic", maxIter=20, stepSize=0.1, subsamplingRate=1.0) + lossType="logistic", maxIter=20, stepSize=0.1, subsamplingRate=1.0, + featureSubsetStrategy="all") kwargs = self._input_kwargs self.setParams(**kwargs) @@ -1265,12 +1279,14 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, - lossType="logistic", maxIter=20, stepSize=0.1, seed=None, 
subsamplingRate=1.0): + lossType="logistic", maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0, + featureSubsetStrategy="all"): """ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \ - lossType="logistic", maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0) + lossType="logistic", maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0, \ + featureSubsetStrategy="all") Sets params for Gradient Boosted Tree Classification. """ kwargs = self._input_kwargs @@ -1293,8 +1309,15 @@ def getLossType(self): """ return self.getOrDefault(self.lossType) + @since("2.4.0") + def setFeatureSubsetStrategy(self, value): + """ + Sets the value of :py:attr:`featureSubsetStrategy`. + """ + return self._set(featureSubsetStrategy=value) + -class GBTClassificationModel(TreeEnsembleModel, JavaPredictionModel, JavaMLWritable, +class GBTClassificationModel(TreeEnsembleModel, JavaClassificationModel, JavaMLWritable, JavaMLReadable): """ Model fitted by GBTClassifier. diff --git a/python/pyspark/ml/fpm.py b/python/pyspark/ml/fpm.py index b8dafd49d354d..fd19fd96c4df6 100644 --- a/python/pyspark/ml/fpm.py +++ b/python/pyspark/ml/fpm.py @@ -16,8 +16,9 @@ # from pyspark import keyword_only, since +from pyspark.sql import DataFrame from pyspark.ml.util import * -from pyspark.ml.wrapper import JavaEstimator, JavaModel +from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams, _jvm from pyspark.ml.param.shared import * __all__ = ["FPGrowth", "FPGrowthModel"] @@ -243,3 +244,104 @@ def setParams(self, minSupport=0.3, minConfidence=0.8, itemsCol="items", def _create_model(self, java_model): return FPGrowthModel(java_model) + + +class PrefixSpan(JavaParams): + """ + .. note:: Experimental + + A parallel PrefixSpan algorithm to mine frequent sequential patterns. + The PrefixSpan algorithm is described in J. Pei, et al., PrefixSpan: Mining Sequential Patterns + Efficiently by Prefix-Projected Pattern Growth + (see here). + This class is not yet an Estimator/Transformer, use :py:func:`findFrequentSequentialPatterns` + method to run the PrefixSpan algorithm. + + @see Sequential Pattern Mining + (Wikipedia) + .. versionadded:: 2.4.0 + + """ + + minSupport = Param(Params._dummy(), "minSupport", "The minimal support level of the " + + "sequential pattern. Sequential pattern that appears more than " + + "(minSupport * size-of-the-dataset) times will be output. Must be >= 0.", + typeConverter=TypeConverters.toFloat) + + maxPatternLength = Param(Params._dummy(), "maxPatternLength", + "The maximal length of the sequential pattern. Must be > 0.", + typeConverter=TypeConverters.toInt) + + maxLocalProjDBSize = Param(Params._dummy(), "maxLocalProjDBSize", + "The maximum number of items (including delimiters used in the " + + "internal storage format) allowed in a projected database before " + + "local processing. If a projected database exceeds this size, " + + "another iteration of distributed prefix growth is run. 
" + + "Must be > 0.", + typeConverter=TypeConverters.toInt) + + sequenceCol = Param(Params._dummy(), "sequenceCol", "The name of the sequence column in " + + "dataset, rows with nulls in this column are ignored.", + typeConverter=TypeConverters.toString) + + @keyword_only + def __init__(self, minSupport=0.1, maxPatternLength=10, maxLocalProjDBSize=32000000, + sequenceCol="sequence"): + """ + __init__(self, minSupport=0.1, maxPatternLength=10, maxLocalProjDBSize=32000000, \ + sequenceCol="sequence") + """ + super(PrefixSpan, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.fpm.PrefixSpan", self.uid) + self._setDefault(minSupport=0.1, maxPatternLength=10, maxLocalProjDBSize=32000000, + sequenceCol="sequence") + kwargs = self._input_kwargs + self.setParams(**kwargs) + + @keyword_only + @since("2.4.0") + def setParams(self, minSupport=0.1, maxPatternLength=10, maxLocalProjDBSize=32000000, + sequenceCol="sequence"): + """ + setParams(self, minSupport=0.1, maxPatternLength=10, maxLocalProjDBSize=32000000, \ + sequenceCol="sequence") + """ + kwargs = self._input_kwargs + return self._set(**kwargs) + + @since("2.4.0") + def findFrequentSequentialPatterns(self, dataset): + """ + .. note:: Experimental + Finds the complete set of frequent sequential patterns in the input sequences of itemsets. + + :param dataset: A dataframe containing a sequence column which is + `ArrayType(ArrayType(T))` type, T is the item type for the input dataset. + :return: A `DataFrame` that contains columns of sequence and corresponding frequency. + The schema of it will be: + - `sequence: ArrayType(ArrayType(T))` (T is the item type) + - `freq: Long` + + >>> from pyspark.ml.fpm import PrefixSpan + >>> from pyspark.sql import Row + >>> df = sc.parallelize([Row(sequence=[[1, 2], [3]]), + ... Row(sequence=[[1], [3, 2], [1, 2]]), + ... Row(sequence=[[1, 2], [5]]), + ... Row(sequence=[[6]])]).toDF() + >>> prefixSpan = PrefixSpan(minSupport=0.5, maxPatternLength=5) + >>> prefixSpan.findFrequentSequentialPatterns(df).sort("sequence").show(truncate=False) + +----------+----+ + |sequence |freq| + +----------+----+ + |[[1]] |3 | + |[[1], [3]]|2 | + |[[1, 2]] |3 | + |[[2]] |3 | + |[[3]] |2 | + +----------+----+ + + .. versionadded:: 2.4.0 + """ + self._transfer_params_to_java() + jdf = self._java_obj.findFrequentSequentialPatterns(dataset._jdf) + return DataFrame(jdf, dataset.sql_ctx) diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index dd0b62f184d26..dba0e57b01a0b 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -602,6 +602,14 @@ class TreeEnsembleParams(DecisionTreeParams): "used for learning each decision tree, in range (0, 1].", typeConverter=TypeConverters.toFloat) + supportedFeatureSubsetStrategies = ["auto", "all", "onethird", "sqrt", "log2"] + + featureSubsetStrategy = \ + Param(Params._dummy(), "featureSubsetStrategy", + "The number of features to consider for splits at each tree node. Supported " + + "options: " + ", ".join(supportedFeatureSubsetStrategies) + ", (0.0-1.0], [1-n].", + typeConverter=TypeConverters.toString) + def __init__(self): super(TreeEnsembleParams, self).__init__() @@ -619,6 +627,22 @@ def getSubsamplingRate(self): """ return self.getOrDefault(self.subsamplingRate) + @since("1.4.0") + def setFeatureSubsetStrategy(self, value): + """ + Sets the value of :py:attr:`featureSubsetStrategy`. + + .. note:: Deprecated in 2.4.0 and will be removed in 3.0.0. 
+ """ + return self._set(featureSubsetStrategy=value) + + @since("1.4.0") + def getFeatureSubsetStrategy(self): + """ + Gets the value of featureSubsetStrategy or its default value. + """ + return self.getOrDefault(self.featureSubsetStrategy) + class TreeRegressorParams(Params): """ @@ -654,14 +678,8 @@ class RandomForestParams(TreeEnsembleParams): Private class to track supported random forest parameters. """ - supportedFeatureSubsetStrategies = ["auto", "all", "onethird", "sqrt", "log2"] numTrees = Param(Params._dummy(), "numTrees", "Number of trees to train (>= 1).", typeConverter=TypeConverters.toInt) - featureSubsetStrategy = \ - Param(Params._dummy(), "featureSubsetStrategy", - "The number of features to consider for splits at each tree node. Supported " + - "options: " + ", ".join(supportedFeatureSubsetStrategies) + ", (0.0-1.0], [1-n].", - typeConverter=TypeConverters.toString) def __init__(self): super(RandomForestParams, self).__init__() @@ -680,20 +698,6 @@ def getNumTrees(self): """ return self.getOrDefault(self.numTrees) - @since("1.4.0") - def setFeatureSubsetStrategy(self, value): - """ - Sets the value of :py:attr:`featureSubsetStrategy`. - """ - return self._set(featureSubsetStrategy=value) - - @since("1.4.0") - def getFeatureSubsetStrategy(self): - """ - Gets the value of featureSubsetStrategy or its default value. - """ - return self.getOrDefault(self.featureSubsetStrategy) - class GBTParams(TreeEnsembleParams): """ @@ -981,6 +985,13 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre def _create_model(self, java_model): return RandomForestRegressionModel(java_model) + @since("2.4.0") + def setFeatureSubsetStrategy(self, value): + """ + Sets the value of :py:attr:`featureSubsetStrategy`. + """ + return self._set(featureSubsetStrategy=value) + class RandomForestRegressionModel(TreeEnsembleModel, JavaPredictionModel, JavaMLWritable, JavaMLReadable): @@ -1029,6 +1040,8 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, >>> gbt = GBTRegressor(maxIter=5, maxDepth=2, seed=42) >>> print(gbt.getImpurity()) variance + >>> print(gbt.getFeatureSubsetStrategy()) + all >>> model = gbt.fit(df) >>> model.featureImportances SparseVector(1, {0: 1.0}) @@ -1079,20 +1092,20 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0, checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None, - impurity="variance"): + impurity="variance", featureSubsetStrategy="all"): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0, \ checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None, \ - impurity="variance") + impurity="variance", featureSubsetStrategy="all") """ super(GBTRegressor, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.regression.GBTRegressor", self.uid) self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0, checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, - impurity="variance") + impurity="variance", featureSubsetStrategy="all") kwargs = self._input_kwargs self.setParams(**kwargs) @@ -1102,13 +1115,13 @@ def setParams(self, 
featuresCol="features", labelCol="label", predictionCol="pre maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0, checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None, - impuriy="variance"): + impuriy="variance", featureSubsetStrategy="all"): """ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0, \ checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None, \ - impurity="variance") + impurity="variance", featureSubsetStrategy="all") Sets params for Gradient Boosted Tree Regression. """ kwargs = self._input_kwargs @@ -1131,6 +1144,13 @@ def getLossType(self): """ return self.getOrDefault(self.lossType) + @since("2.4.0") + def setFeatureSubsetStrategy(self, value): + """ + Sets the value of :py:attr:`featureSubsetStrategy`. + """ + return self._set(featureSubsetStrategy=value) + class GBTRegressionModel(TreeEnsembleModel, JavaPredictionModel, JavaMLWritable, JavaMLReadable): """ diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index d5a237a5b2855..14d9128502ab0 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -53,6 +53,7 @@ from pyspark.shuffle import Aggregator, ExternalMerger, \ get_used_memory, ExternalSorter, ExternalGroupBy from pyspark.traceback_utils import SCCallSiteSync +from pyspark.util import fail_on_stopiteration __all__ = ["RDD"] @@ -339,7 +340,7 @@ def map(self, f, preservesPartitioning=False): [('a', 1), ('b', 1), ('c', 1)] """ def func(_, iterator): - return map(f, iterator) + return map(fail_on_stopiteration(f), iterator) return self.mapPartitionsWithIndex(func, preservesPartitioning) def flatMap(self, f, preservesPartitioning=False): @@ -354,7 +355,7 @@ def flatMap(self, f, preservesPartitioning=False): [(2, 2), (2, 2), (3, 3), (3, 3), (4, 4), (4, 4)] """ def func(s, iterator): - return chain.from_iterable(map(f, iterator)) + return chain.from_iterable(map(fail_on_stopiteration(f), iterator)) return self.mapPartitionsWithIndex(func, preservesPartitioning) def mapPartitions(self, f, preservesPartitioning=False): @@ -417,7 +418,7 @@ def filter(self, f): [2, 4] """ def func(iterator): - return filter(f, iterator) + return filter(fail_on_stopiteration(f), iterator) return self.mapPartitions(func, True) def distinct(self, numPartitions=None): @@ -798,6 +799,8 @@ def foreach(self, f): >>> def f(x): print(x) >>> sc.parallelize([1, 2, 3, 4, 5]).foreach(f) """ + f = fail_on_stopiteration(f) + def processPartition(iterator): for x in iterator: f(x) @@ -847,6 +850,8 @@ def reduce(self, f): ... 
ValueError: Can not reduce() empty RDD """ + f = fail_on_stopiteration(f) + def func(iterator): iterator = iter(iterator) try: @@ -918,6 +923,8 @@ def fold(self, zeroValue, op): >>> sc.parallelize([1, 2, 3, 4, 5]).fold(0, add) 15 """ + op = fail_on_stopiteration(op) + def func(iterator): acc = zeroValue for obj in iterator: @@ -950,6 +957,9 @@ def aggregate(self, zeroValue, seqOp, combOp): >>> sc.parallelize([]).aggregate((0, 0), seqOp, combOp) (0, 0) """ + seqOp = fail_on_stopiteration(seqOp) + combOp = fail_on_stopiteration(combOp) + def func(iterator): acc = zeroValue for obj in iterator: @@ -1643,6 +1653,8 @@ def reduceByKeyLocally(self, func): >>> sorted(rdd.reduceByKeyLocally(add).items()) [('a', 2), ('b', 1)] """ + func = fail_on_stopiteration(func) + def reducePartition(iterator): m = {} for k, v in iterator: diff --git a/python/pyspark/shell.py b/python/pyspark/shell.py index b5fcf7092d93a..472c3cd4452f0 100644 --- a/python/pyspark/shell.py +++ b/python/pyspark/shell.py @@ -38,25 +38,13 @@ SparkContext._ensure_initialized() try: - # Try to access HiveConf, it will raise exception if Hive is not added - conf = SparkConf() - if conf.get('spark.sql.catalogImplementation', 'hive').lower() == 'hive': - SparkContext._jvm.org.apache.hadoop.hive.conf.HiveConf() - spark = SparkSession.builder\ - .enableHiveSupport()\ - .getOrCreate() - else: - spark = SparkSession.builder.getOrCreate() -except py4j.protocol.Py4JError: - if conf.get('spark.sql.catalogImplementation', '').lower() == 'hive': - warnings.warn("Fall back to non-hive support because failing to access HiveConf, " - "please make sure you build spark with hive") - spark = SparkSession.builder.getOrCreate() -except TypeError: - if conf.get('spark.sql.catalogImplementation', '').lower() == 'hive': - warnings.warn("Fall back to non-hive support because failing to access HiveConf, " - "please make sure you build spark with hive") - spark = SparkSession.builder.getOrCreate() + spark = SparkSession._create_shell_session() +except Exception: + import sys + import traceback + warnings.warn("Failed to initialize Spark session.") + traceback.print_exc(file=sys.stderr) + sys.exit(1) sc = spark.sparkContext sql = spark.sql diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py index 02c773302e9da..bd0ac0039ffe1 100644 --- a/python/pyspark/shuffle.py +++ b/python/pyspark/shuffle.py @@ -28,6 +28,7 @@ import pyspark.heapq3 as heapq from pyspark.serializers import BatchedSerializer, PickleSerializer, FlattenedValuesSerializer, \ CompressedSerializer, AutoBatchedSerializer +from pyspark.util import fail_on_stopiteration try: @@ -94,9 +95,9 @@ class Aggregator(object): """ def __init__(self, createCombiner, mergeValue, mergeCombiners): - self.createCombiner = createCombiner - self.mergeValue = mergeValue - self.mergeCombiners = mergeCombiners + self.createCombiner = fail_on_stopiteration(createCombiner) + self.mergeValue = fail_on_stopiteration(mergeValue) + self.mergeCombiners = fail_on_stopiteration(mergeCombiners) class SimpleAggregator(Aggregator): diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 808235ab25440..1e6a1acebb5ca 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -78,6 +78,9 @@ def __init__(self, jdf, sql_ctx): self.is_cached = False self._schema = None # initialized lazily self._lazy_rdd = None + # Check whether _repr_html is supported or not, we use it to avoid calling _jdf twice + # by __repr__ and _repr_html_ while eager evaluation opened. 
+ self._support_repr_html = False @property @since(1.3) @@ -351,8 +354,68 @@ def show(self, n=20, truncate=True, vertical=False): else: print(self._jdf.showString(n, int(truncate), vertical)) + @property + def _eager_eval(self): + """Returns true if the eager evaluation enabled. + """ + return self.sql_ctx.getConf( + "spark.sql.repl.eagerEval.enabled", "false").lower() == "true" + + @property + def _max_num_rows(self): + """Returns the max row number for eager evaluation. + """ + return int(self.sql_ctx.getConf( + "spark.sql.repl.eagerEval.maxNumRows", "20")) + + @property + def _truncate(self): + """Returns the truncate length for eager evaluation. + """ + return int(self.sql_ctx.getConf( + "spark.sql.repl.eagerEval.truncate", "20")) + def __repr__(self): - return "DataFrame[%s]" % (", ".join("%s: %s" % c for c in self.dtypes)) + if not self._support_repr_html and self._eager_eval: + vertical = False + return self._jdf.showString( + self._max_num_rows, self._truncate, vertical) + else: + return "DataFrame[%s]" % (", ".join("%s: %s" % c for c in self.dtypes)) + + def _repr_html_(self): + """Returns a dataframe with html code when you enabled eager evaluation + by 'spark.sql.repl.eagerEval.enabled', this only called by REPL you are + using support eager evaluation with HTML. + """ + import cgi + if not self._support_repr_html: + self._support_repr_html = True + if self._eager_eval: + max_num_rows = max(self._max_num_rows, 0) + vertical = False + sock_info = self._jdf.getRowsToPython( + max_num_rows, self._truncate, vertical) + rows = list(_load_from_socket(sock_info, BatchedSerializer(PickleSerializer()))) + head = rows[0] + row_data = rows[1:] + has_more_data = len(row_data) > max_num_rows + row_data = row_data[:max_num_rows] + + html = "<table border='1'>\n" + # generate table head + html += "<tr><th>%s</th></tr>\n" % "</th><th>".join(map(lambda x: cgi.escape(x), head)) + # generate table rows + for row in row_data: + html += "<tr><td>%s</td></tr>\n" % "</td><td>".join( + map(lambda x: cgi.escape(x), row)) + html += "</table>
\n" + if has_more_data: + html += "only showing top %d %s\n" % ( + max_num_rows, "row" if max_num_rows == 1 else "rows") + return html + else: + return None @since(2.1) def checkpoint(self, eager=True): diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index efcce25a08e04..1759195c6fcc0 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1964,6 +1964,22 @@ def element_at(col, extraction): return Column(sc._jvm.functions.element_at(_to_java_column(col), extraction)) +@since(2.4) +def array_remove(col, element): + """ + Collection function: Remove all elements that equal to element from the given array. + + :param col: name of column containing array + :param element: element to be removed from the array + + >>> df = spark.createDataFrame([([1, 2, 3, 1, 1],), ([],)], ['data']) + >>> df.select(array_remove(df.data, 1)).collect() + [Row(array_remove(data, 1)=[2, 3]), Row(array_remove(data, 1)=[])] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.array_remove(_to_java_column(col), element)) + + @since(1.4) def explode(col): """Returns a new row for each element in the given array or map. @@ -2500,7 +2516,8 @@ def pandas_udf(f=None, returnType=None, functionType=None): A grouped map UDF defines transformation: A `pandas.DataFrame` -> A `pandas.DataFrame` The returnType should be a :class:`StructType` describing the schema of the returned `pandas.DataFrame`. - The length of the returned `pandas.DataFrame` can be arbitrary. + The length of the returned `pandas.DataFrame` can be arbitrary and the columns must be + indexed so that their position matches the corresponding field in the schema. Grouped map UDFs are used with :meth:`pyspark.sql.GroupedData.apply`. @@ -2548,6 +2565,12 @@ def pandas_udf(f=None, returnType=None, functionType=None): | 2|6.0| +---+---+ + .. note:: If returning a new `pandas.DataFrame` constructed with a dictionary, it is + recommended to explicitly index the columns by name to ensure the positions are correct, + or alternatively use an `OrderedDict`. + For example, `pd.DataFrame({'id': ids, 'a': data}, columns=['id', 'a'])` or + `pd.DataFrame(OrderedDict([('id', ids), ('a', data)]))`. + .. seealso:: :meth:`pyspark.sql.GroupedData.apply` 3. GROUPED_AGG diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 448a4732001b5..a0e20d39c20da 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -346,7 +346,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None, maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None, columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None, - samplingRatio=None): + samplingRatio=None, enforceSchema=None): """Loads a CSV file and returns the result as a :class:`DataFrame`. This function will go through the input once to determine the input schema if @@ -373,6 +373,16 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non default value, ``false``. :param inferSchema: infers the input schema automatically from data. It requires one extra pass over the data. If None is set, it uses the default value, ``false``. + :param enforceSchema: If it is set to ``true``, the specified or inferred schema will be + forcibly applied to datasource files, and headers in CSV files will be + ignored. 
If the option is set to ``false``, the schema will be + validated against all headers in CSV files or the first header in RDD + if the ``header`` option is set to ``true``. Field names in the schema + and column names in CSV headers are checked by their positions + taking into account ``spark.sql.caseSensitive``. If None is set, + ``true`` is used by default. Though the default value is ``true``, + it is recommended to disable the ``enforceSchema`` option + to avoid incorrect results. :param ignoreLeadingWhiteSpace: A flag indicating whether or not leading whitespaces from values being read should be skipped. If None is set, it uses the default value, ``false``. @@ -449,7 +459,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non maxCharsPerColumn=maxCharsPerColumn, maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, multiLine=multiLine, - charToEscapeQuoteEscaping=charToEscapeQuoteEscaping, samplingRatio=samplingRatio) + charToEscapeQuoteEscaping=charToEscapeQuoteEscaping, samplingRatio=samplingRatio, + enforceSchema=enforceSchema) if isinstance(path, basestring): path = [path] if type(path) == list: diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index d675a240172a7..e880dd1ca6d1a 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -547,6 +547,40 @@ def _create_from_pandas_with_arrow(self, pdf, schema, timezone): df._schema = schema return df + @staticmethod + def _create_shell_session(): + """ + Initialize a SparkSession for a pyspark shell session. This is called from shell.py + to make error handling simpler without needing to declare local variables in that + script, which would expose those to users. 
+ """ + import py4j + from pyspark.conf import SparkConf + from pyspark.context import SparkContext + try: + # Try to access HiveConf, it will raise exception if Hive is not added + conf = SparkConf() + if conf.get('spark.sql.catalogImplementation', 'hive').lower() == 'hive': + SparkContext._jvm.org.apache.hadoop.hive.conf.HiveConf() + return SparkSession.builder\ + .enableHiveSupport()\ + .getOrCreate() + else: + return SparkSession.builder.getOrCreate() + except py4j.protocol.Py4JError: + if conf.get('spark.sql.catalogImplementation', '').lower() == 'hive': + warnings.warn("Fall back to non-hive support because failing to access HiveConf, " + "please make sure you build spark with hive") + + try: + return SparkSession.builder.getOrCreate() + except TypeError: + if conf.get('spark.sql.catalogImplementation', '').lower() == 'hive': + warnings.warn("Fall back to non-hive support because failing to access HiveConf, " + "please make sure you build spark with hive") + + return SparkSession.builder.getOrCreate() + @since(2.0) @ignore_unicode_prefix def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=True): diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index 15f9407389864..fae50b3d5d532 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -564,7 +564,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non ignoreTrailingWhiteSpace=None, nullValue=None, nanValue=None, positiveInf=None, negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None, maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None, - columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None): + columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None, + enforceSchema=None): """Loads a CSV file stream and returns the result as a :class:`DataFrame`. This function will go through the input once to determine the input schema if @@ -592,6 +593,16 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non default value, ``false``. :param inferSchema: infers the input schema automatically from data. It requires one extra pass over the data. If None is set, it uses the default value, ``false``. + :param enforceSchema: If it is set to ``true``, the specified or inferred schema will be + forcibly applied to datasource files, and headers in CSV files will be + ignored. If the option is set to ``false``, the schema will be + validated against all headers in CSV files or the first header in RDD + if the ``header`` option is set to ``true``. Field names in the schema + and column names in CSV headers are checked by their positions + taking into account ``spark.sql.caseSensitive``. If None is set, + ``true`` is used by default. Though the default value is ``true``, + it is recommended to disable the ``enforceSchema`` option + to avoid incorrect results. :param ignoreLeadingWhiteSpace: a flag indicating whether or not leading whitespaces from values being read should be skipped. If None is set, it uses the default value, ``false``. 
@@ -664,7 +675,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non maxCharsPerColumn=maxCharsPerColumn, maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, multiLine=multiLine, - charToEscapeQuoteEscaping=charToEscapeQuoteEscaping) + charToEscapeQuoteEscaping=charToEscapeQuoteEscaping, enforceSchema=enforceSchema) if isinstance(path, basestring): return self._df(self._jreader.csv(path)) else: diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index c7bd8f01b907f..487eb19c3b98a 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -900,6 +900,22 @@ def __call__(self, x): self.assertEqual(f, f_.func) self.assertEqual(return_type, f_.returnType) + def test_stopiteration_in_udf(self): + # test for SPARK-23754 + from pyspark.sql.functions import udf + from py4j.protocol import Py4JJavaError + + def foo(x): + raise StopIteration() + + with self.assertRaises(Py4JJavaError) as cm: + self.spark.range(0, 1000).withColumn('v', udf(foo)('id')).show() + + self.assertIn( + "Caught StopIteration thrown from user's code; failing the task", + cm.exception.java_exception.toString() + ) + def test_validate_column_types(self): from pyspark.sql.functions import udf, to_json from pyspark.sql.column import _to_java_column @@ -3040,6 +3056,54 @@ def test_csv_sampling_ratio(self): .csv(rdd, samplingRatio=0.5).schema self.assertEquals(schema, StructType([StructField("_c0", IntegerType(), True)])) + def test_checking_csv_header(self): + path = tempfile.mkdtemp() + shutil.rmtree(path) + try: + self.spark.createDataFrame([[1, 1000], [2000, 2]])\ + .toDF('f1', 'f2').write.option("header", "true").csv(path) + schema = StructType([ + StructField('f2', IntegerType(), nullable=True), + StructField('f1', IntegerType(), nullable=True)]) + df = self.spark.read.option('header', 'true').schema(schema)\ + .csv(path, enforceSchema=False) + self.assertRaisesRegexp( + Exception, + "CSV header does not conform to the schema", + lambda: df.collect()) + finally: + shutil.rmtree(path) + + def test_repr_html(self): + import re + pattern = re.compile(r'^ *\|', re.MULTILINE) + df = self.spark.createDataFrame([(1, "1"), (22222, "22222")], ("key", "value")) + self.assertEquals(None, df._repr_html_()) + with self.sql_conf({"spark.sql.repl.eagerEval.enabled": True}): + expected1 = """<table border='1'> + |<tr><th>key</th><th>value</th></tr> + |<tr><td>1</td><td>1</td></tr> + |<tr><td>22222</td><td>22222</td></tr> + |</table> + |""" + self.assertEquals(re.sub(pattern, '', expected1), df._repr_html_()) + with self.sql_conf({"spark.sql.repl.eagerEval.truncate": 3}): + expected2 = """<table border='1'> + |<tr><th>key</th><th>value</th></tr> + |<tr><td>1</td><td>1</td></tr> + |<tr><td>222</td><td>222</td></tr> + |</table> + |""" + self.assertEquals(re.sub(pattern, '', expected2), df._repr_html_()) + with self.sql_conf({"spark.sql.repl.eagerEval.maxNumRows": 1}): + expected3 = """<table border='1'> + |<tr><th>key</th><th>value</th></tr> + |<tr><td>1</td><td>1</td></tr> + |</table>
+ |only showing top 1 row + |""" + self.assertEquals(re.sub(pattern, '', expected3), df._repr_html_()) + class HiveSparkSubmitTests(SparkSubmitTests): diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index 9dbe49b831cef..c8fb49d7c2b65 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -25,7 +25,7 @@ from pyspark.sql.column import Column, _to_java_column, _to_seq from pyspark.sql.types import StringType, DataType, StructType, _parse_datatype_string,\ to_arrow_type, to_arrow_schema -from pyspark.util import _get_argspec +from pyspark.util import _get_argspec, fail_on_stopiteration __all__ = ["UDFRegistration"] @@ -157,7 +157,17 @@ def _create_judf(self): spark = SparkSession.builder.getOrCreate() sc = spark.sparkContext - wrapped_func = _wrap_function(sc, self.func, self.returnType) + func = fail_on_stopiteration(self.func) + + # for pandas UDFs the worker needs to know if the function takes + # one or two arguments, but the signature is lost when wrapping with + # fail_on_stopiteration, so we store it here + if self.evalType in (PythonEvalType.SQL_SCALAR_PANDAS_UDF, + PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, + PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF): + func._argspec = _get_argspec(self.func) + + wrapped_func = _wrap_function(sc, func, self.returnType) jdt = spark._jsparkSession.parseDataType(self.returnType.json()) judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction( self._name, wrapped_func, jdt, self.evalType, self.deterministic) diff --git a/python/pyspark/taskcontext.py b/python/pyspark/taskcontext.py index e5218d9e75e78..63ae1f30e17ca 100644 --- a/python/pyspark/taskcontext.py +++ b/python/pyspark/taskcontext.py @@ -34,6 +34,7 @@ class TaskContext(object): _partitionId = None _stageId = None _taskAttemptId = None + _localProperties = None def __new__(cls): """Even if users construct TaskContext instead of using get, give them the singleton.""" @@ -88,3 +89,9 @@ def taskAttemptId(self): TaskAttemptID. """ return self._taskAttemptId + + def getLocalProperty(self, key): + """ + Get a local property set upstream in the driver, or None if it is missing. 
+ """ + return self._localProperties.get(key, None) diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 498d6b57e4353..30723b8e15b36 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -161,6 +161,37 @@ def gen_gs(N, step=1): self.assertEqual(k, len(vs)) self.assertEqual(list(range(k)), list(vs)) + def test_stopiteration_is_raised(self): + + def stopit(*args, **kwargs): + raise StopIteration() + + def legit_create_combiner(x): + return [x] + + def legit_merge_value(x, y): + return x.append(y) or x + + def legit_merge_combiners(x, y): + return x.extend(y) or x + + data = [(x % 2, x) for x in range(100)] + + # wrong create combiner + m = ExternalMerger(Aggregator(stopit, legit_merge_value, legit_merge_combiners), 20) + with self.assertRaises((Py4JJavaError, RuntimeError)) as cm: + m.mergeValues(data) + + # wrong merge value + m = ExternalMerger(Aggregator(legit_create_combiner, stopit, legit_merge_combiners), 20) + with self.assertRaises((Py4JJavaError, RuntimeError)) as cm: + m.mergeValues(data) + + # wrong merge combiners + m = ExternalMerger(Aggregator(legit_create_combiner, legit_merge_value, stopit), 20) + with self.assertRaises((Py4JJavaError, RuntimeError)) as cm: + m.mergeCombiners(map(lambda x_y1: (x_y1[0], [x_y1[1]]), data)) + class SorterTests(unittest.TestCase): def test_in_memory_sort(self): @@ -543,6 +574,20 @@ def test_tc_on_driver(self): tc = TaskContext.get() self.assertTrue(tc is None) + def test_get_local_property(self): + """Verify that local properties set on the driver are available in TaskContext.""" + key = "testkey" + value = "testvalue" + self.sc.setLocalProperty(key, value) + try: + rdd = self.sc.parallelize(range(1), 1) + prop1 = rdd.map(lambda x: TaskContext.get().getLocalProperty(key)).collect()[0] + self.assertEqual(prop1, value) + prop2 = rdd.map(lambda x: TaskContext.get().getLocalProperty("otherkey")).collect()[0] + self.assertTrue(prop2 is None) + finally: + self.sc.setLocalProperty(key, None) + class RDDTests(ReusedPySparkTestCase): @@ -1246,6 +1291,28 @@ def test_pipe_unicode(self): result = rdd.pipe('cat').collect() self.assertEqual(data, result) + def test_stopiteration_in_client_code(self): + + def stopit(*x): + raise StopIteration() + + seq_rdd = self.sc.parallelize(range(10)) + keyed_rdd = self.sc.parallelize((x % 2, x) for x in range(10)) + + self.assertRaises(Py4JJavaError, seq_rdd.map(stopit).collect) + self.assertRaises(Py4JJavaError, seq_rdd.filter(stopit).collect) + self.assertRaises(Py4JJavaError, seq_rdd.cartesian(seq_rdd).flatMap(stopit).collect) + self.assertRaises(Py4JJavaError, seq_rdd.foreach, stopit) + self.assertRaises(Py4JJavaError, keyed_rdd.reduceByKeyLocally, stopit) + self.assertRaises(Py4JJavaError, seq_rdd.reduce, stopit) + self.assertRaises(Py4JJavaError, seq_rdd.fold, 0, stopit) + + # the exception raised is non-deterministic + self.assertRaises((Py4JJavaError, RuntimeError), + seq_rdd.aggregate, 0, stopit, lambda *x: 1) + self.assertRaises((Py4JJavaError, RuntimeError), + seq_rdd.aggregate, 0, lambda *x: 1, stopit) + class ProfilerTests(PySparkTestCase): diff --git a/python/pyspark/util.py b/python/pyspark/util.py index 59cc2a6329350..e95a9b523393f 100644 --- a/python/pyspark/util.py +++ b/python/pyspark/util.py @@ -53,11 +53,16 @@ def _get_argspec(f): """ Get argspec of a function. Supports both Python 2 and Python 3. """ - # `getargspec` is deprecated since python3.0 (incompatible with function annotations). - # See SPARK-23569. 
- if sys.version_info[0] < 3: + + if hasattr(f, '_argspec'): + # only used for pandas UDF: they wrap the user function, losing its signature + # workers need this signature, so UDF saves it here + argspec = f._argspec + elif sys.version_info[0] < 3: argspec = inspect.getargspec(f) else: + # `getargspec` is deprecated since python3.0 (incompatible with function annotations). + # See SPARK-23569. argspec = inspect.getfullargspec(f) return argspec @@ -89,6 +94,23 @@ def majorMinorVersion(sparkVersion): " version numbers.") +def fail_on_stopiteration(f): + """ + Wraps the input function to fail on 'StopIteration' by raising a 'RuntimeError' + prevents silent loss of data when 'f' is used in a for loop + """ + def wrapper(*args, **kwargs): + try: + return f(*args, **kwargs) + except StopIteration as exc: + raise RuntimeError( + "Caught StopIteration thrown from user's code; failing the task", + exc + ) + + return wrapper + + if __name__ == "__main__": import doctest (failure_count, test_count) = doctest.testmod() diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 5d2e58bef6466..fbcb8af8bfb24 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -222,6 +222,12 @@ def main(infile, outfile): taskContext._partitionId = read_int(infile) taskContext._attemptNumber = read_int(infile) taskContext._taskAttemptId = read_long(infile) + taskContext._localProperties = dict() + for i in range(read_int(infile)): + k = utf8_deserializer.loads(infile) + v = utf8_deserializer.loads(infile) + taskContext._localProperties[k] = v + shuffle.MemoryBytesSpilled = 0 shuffle.DiskBytesSpilled = 0 _accumulatorRegistry.clear() diff --git a/repl/src/main/scala/org/apache/spark/repl/Main.scala b/repl/src/main/scala/org/apache/spark/repl/Main.scala index cc76a703bdf8f..e4ddcef9772e4 100644 --- a/repl/src/main/scala/org/apache/spark/repl/Main.scala +++ b/repl/src/main/scala/org/apache/spark/repl/Main.scala @@ -44,6 +44,7 @@ object Main extends Logging { var interp: SparkILoop = _ private var hasErrors = false + private var isShellSession = false private def scalaOptionError(msg: String): Unit = { hasErrors = true @@ -53,6 +54,7 @@ object Main extends Logging { } def main(args: Array[String]) { + isShellSession = true doMain(args, new SparkILoop) } @@ -79,44 +81,50 @@ object Main extends Logging { } def createSparkSession(): SparkSession = { - val execUri = System.getenv("SPARK_EXECUTOR_URI") - conf.setIfMissing("spark.app.name", "Spark shell") - // SparkContext will detect this configuration and register it with the RpcEnv's - // file server, setting spark.repl.class.uri to the actual URI for executors to - // use. This is sort of ugly but since executors are started as part of SparkContext - // initialization in certain cases, there's an initialization order issue that prevents - // this from being set after SparkContext is instantiated. - conf.set("spark.repl.class.outputDir", outputDir.getAbsolutePath()) - if (execUri != null) { - conf.set("spark.executor.uri", execUri) - } - if (System.getenv("SPARK_HOME") != null) { - conf.setSparkHome(System.getenv("SPARK_HOME")) - } + try { + val execUri = System.getenv("SPARK_EXECUTOR_URI") + conf.setIfMissing("spark.app.name", "Spark shell") + // SparkContext will detect this configuration and register it with the RpcEnv's + // file server, setting spark.repl.class.uri to the actual URI for executors to + // use. 
This is sort of ugly but since executors are started as part of SparkContext + // initialization in certain cases, there's an initialization order issue that prevents + // this from being set after SparkContext is instantiated. + conf.set("spark.repl.class.outputDir", outputDir.getAbsolutePath()) + if (execUri != null) { + conf.set("spark.executor.uri", execUri) + } + if (System.getenv("SPARK_HOME") != null) { + conf.setSparkHome(System.getenv("SPARK_HOME")) + } - val builder = SparkSession.builder.config(conf) - if (conf.get(CATALOG_IMPLEMENTATION.key, "hive").toLowerCase(Locale.ROOT) == "hive") { - if (SparkSession.hiveClassesArePresent) { - // In the case that the property is not set at all, builder's config - // does not have this value set to 'hive' yet. The original default - // behavior is that when there are hive classes, we use hive catalog. - sparkSession = builder.enableHiveSupport().getOrCreate() - logInfo("Created Spark session with Hive support") + val builder = SparkSession.builder.config(conf) + if (conf.get(CATALOG_IMPLEMENTATION.key, "hive").toLowerCase(Locale.ROOT) == "hive") { + if (SparkSession.hiveClassesArePresent) { + // In the case that the property is not set at all, builder's config + // does not have this value set to 'hive' yet. The original default + // behavior is that when there are hive classes, we use hive catalog. + sparkSession = builder.enableHiveSupport().getOrCreate() + logInfo("Created Spark session with Hive support") + } else { + // Need to change it back to 'in-memory' if no hive classes are found + // in the case that the property is set to hive in spark-defaults.conf + builder.config(CATALOG_IMPLEMENTATION.key, "in-memory") + sparkSession = builder.getOrCreate() + logInfo("Created Spark session") + } } else { - // Need to change it back to 'in-memory' if no hive classes are found - // in the case that the property is set to hive in spark-defaults.conf - builder.config(CATALOG_IMPLEMENTATION.key, "in-memory") + // In the case that the property is set but not to 'hive', the internal + // default is 'in-memory'. So the sparkSession will use in-memory catalog. sparkSession = builder.getOrCreate() logInfo("Created Spark session") } - } else { - // In the case that the property is set but not to 'hive', the internal - // default is 'in-memory'. So the sparkSession will use in-memory catalog. - sparkSession = builder.getOrCreate() - logInfo("Created Spark session") + sparkContext = sparkSession.sparkContext + sparkSession + } catch { + case e: Exception if isShellSession => + logError("Failed to initialize Spark session.", e) + sys.exit(1) } - sparkContext = sparkSession.sparkContext - sparkSession } } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala index 4086970ffb256..560dedf431b08 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala @@ -162,10 +162,12 @@ private[spark] object Config extends Logging { val KUBERNETES_DRIVER_LABEL_PREFIX = "spark.kubernetes.driver.label." val KUBERNETES_DRIVER_ANNOTATION_PREFIX = "spark.kubernetes.driver.annotation." val KUBERNETES_DRIVER_SECRETS_PREFIX = "spark.kubernetes.driver.secrets." + val KUBERNETES_DRIVER_SECRET_KEY_REF_PREFIX = "spark.kubernetes.driver.secretKeyRef." 
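For context, the secretKeyRef prefixes added to Config.scala in this hunk are plain user-facing configuration keys. A rough illustration of how they might be set (the secret name "db-secret", key "password", and env var name DB_PASSWORD are invented for this sketch; EnvSecretsFeatureStep below splits each value on ':' into a secret name and key):

from pyspark import SparkConf

conf = (SparkConf()
        .set("spark.kubernetes.driver.secretKeyRef.DB_PASSWORD", "db-secret:password")
        .set("spark.kubernetes.executor.secretKeyRef.DB_PASSWORD", "db-secret:password"))
# Each value takes the form secret-name:key; the feature step exposes an env var
# named DB_PASSWORD sourced from key "password" of Kubernetes secret "db-secret".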
val KUBERNETES_EXECUTOR_LABEL_PREFIX = "spark.kubernetes.executor.label." val KUBERNETES_EXECUTOR_ANNOTATION_PREFIX = "spark.kubernetes.executor.annotation." val KUBERNETES_EXECUTOR_SECRETS_PREFIX = "spark.kubernetes.executor.secrets." + val KUBERNETES_EXECUTOR_SECRET_KEY_REF_PREFIX = "spark.kubernetes.executor.secretKeyRef." val KUBERNETES_DRIVER_ENV_PREFIX = "spark.kubernetes.driverEnv." } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala index 77b634ddfabcc..5a944187a7096 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala @@ -54,6 +54,7 @@ private[spark] case class KubernetesConf[T <: KubernetesRoleSpecificConf]( roleLabels: Map[String, String], roleAnnotations: Map[String, String], roleSecretNamesToMountPaths: Map[String, String], + roleSecretEnvNamesToKeyRefs: Map[String, String], roleEnvs: Map[String, String]) { def namespace(): String = sparkConf.get(KUBERNETES_NAMESPACE) @@ -129,6 +130,8 @@ private[spark] object KubernetesConf { sparkConf, KUBERNETES_DRIVER_ANNOTATION_PREFIX) val driverSecretNamesToMountPaths = KubernetesUtils.parsePrefixedKeyValuePairs( sparkConf, KUBERNETES_DRIVER_SECRETS_PREFIX) + val driverSecretEnvNamesToKeyRefs = KubernetesUtils.parsePrefixedKeyValuePairs( + sparkConf, KUBERNETES_DRIVER_SECRET_KEY_REF_PREFIX) val driverEnvs = KubernetesUtils.parsePrefixedKeyValuePairs( sparkConf, KUBERNETES_DRIVER_ENV_PREFIX) @@ -140,6 +143,7 @@ private[spark] object KubernetesConf { driverLabels, driverAnnotations, driverSecretNamesToMountPaths, + driverSecretEnvNamesToKeyRefs, driverEnvs) } @@ -167,8 +171,10 @@ private[spark] object KubernetesConf { executorCustomLabels val executorAnnotations = KubernetesUtils.parsePrefixedKeyValuePairs( sparkConf, KUBERNETES_EXECUTOR_ANNOTATION_PREFIX) - val executorSecrets = KubernetesUtils.parsePrefixedKeyValuePairs( + val executorMountSecrets = KubernetesUtils.parsePrefixedKeyValuePairs( sparkConf, KUBERNETES_EXECUTOR_SECRETS_PREFIX) + val executorEnvSecrets = KubernetesUtils.parsePrefixedKeyValuePairs( + sparkConf, KUBERNETES_EXECUTOR_SECRET_KEY_REF_PREFIX) val executorEnv = sparkConf.getExecutorEnv.toMap KubernetesConf( @@ -178,7 +184,8 @@ private[spark] object KubernetesConf { appId, executorLabels, executorAnnotations, - executorSecrets, + executorMountSecrets, + executorEnvSecrets, executorEnv) } } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/EnvSecretsFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/EnvSecretsFeatureStep.scala new file mode 100644 index 0000000000000..03ff7d48420ff --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/EnvSecretsFeatureStep.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features + +import scala.collection.JavaConverters._ + +import io.fabric8.kubernetes.api.model.{ContainerBuilder, EnvVarBuilder, HasMetadata} + +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesRoleSpecificConf, SparkPod} + +private[spark] class EnvSecretsFeatureStep( + kubernetesConf: KubernetesConf[_ <: KubernetesRoleSpecificConf]) + extends KubernetesFeatureConfigStep { + override def configurePod(pod: SparkPod): SparkPod = { + val addedEnvSecrets = kubernetesConf + .roleSecretEnvNamesToKeyRefs + .map{ case (envName, keyRef) => + // Keyref parts + val keyRefParts = keyRef.split(":") + require(keyRefParts.size == 2, "SecretKeyRef must be in the form name:key.") + val name = keyRefParts(0) + val key = keyRefParts(1) + new EnvVarBuilder() + .withName(envName) + .withNewValueFrom() + .withNewSecretKeyRef() + .withKey(key) + .withName(name) + .endSecretKeyRef() + .endValueFrom() + .build() + } + + val containerWithEnvVars = new ContainerBuilder(pod.container) + .addAllToEnv(addedEnvSecrets.toSeq.asJava) + .build() + SparkPod(pod.pod, containerWithEnvVars) + } + + override def getAdditionalPodSystemProperties(): Map[String, String] = Map.empty + + override def getAdditionalKubernetesResources(): Seq[HasMetadata] = Seq.empty +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala index 10b0154466a3a..fdc5eb0d75832 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala @@ -17,7 +17,7 @@ package org.apache.spark.deploy.k8s.submit import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpec, KubernetesDriverSpecificConf, KubernetesRoleSpecificConf} -import org.apache.spark.deploy.k8s.features.{BasicDriverFeatureStep, DriverKubernetesCredentialsFeatureStep, DriverServiceFeatureStep, LocalDirsFeatureStep, MountSecretsFeatureStep} +import org.apache.spark.deploy.k8s.features._ private[spark] class KubernetesDriverBuilder( provideBasicStep: (KubernetesConf[KubernetesDriverSpecificConf]) => BasicDriverFeatureStep = @@ -30,6 +30,9 @@ private[spark] class KubernetesDriverBuilder( provideSecretsStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf] => MountSecretsFeatureStep) = new MountSecretsFeatureStep(_), + provideEnvSecretsStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf] + => EnvSecretsFeatureStep) = + new EnvSecretsFeatureStep(_), provideLocalDirsStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf]) => LocalDirsFeatureStep = new LocalDirsFeatureStep(_)) { @@ -41,10 +44,14 @@ private[spark] class KubernetesDriverBuilder( provideCredentialsStep(kubernetesConf), provideServiceStep(kubernetesConf), provideLocalDirsStep(kubernetesConf)) - val allFeatures = if (kubernetesConf.roleSecretNamesToMountPaths.nonEmpty) { + var allFeatures = if 
(kubernetesConf.roleSecretNamesToMountPaths.nonEmpty) { baseFeatures ++ Seq(provideSecretsStep(kubernetesConf)) } else baseFeatures + allFeatures = if (kubernetesConf.roleSecretEnvNamesToKeyRefs.nonEmpty) { + allFeatures ++ Seq(provideEnvSecretsStep(kubernetesConf)) + } else allFeatures + var spec = KubernetesDriverSpec.initialSpec(kubernetesConf.sparkConf.getAll.toMap) for (feature <- allFeatures) { val configuredPod = feature.configurePod(spec.pod) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala index d8f63d57574fb..d5e1de36a58df 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala @@ -17,7 +17,7 @@ package org.apache.spark.scheduler.cluster.k8s import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, KubernetesRoleSpecificConf, SparkPod} -import org.apache.spark.deploy.k8s.features.{BasicExecutorFeatureStep, LocalDirsFeatureStep, MountSecretsFeatureStep} +import org.apache.spark.deploy.k8s.features.{BasicExecutorFeatureStep, EnvSecretsFeatureStep, LocalDirsFeatureStep, MountSecretsFeatureStep} private[spark] class KubernetesExecutorBuilder( provideBasicStep: (KubernetesConf[KubernetesExecutorSpecificConf]) => BasicExecutorFeatureStep = @@ -25,6 +25,9 @@ private[spark] class KubernetesExecutorBuilder( provideSecretsStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf]) => MountSecretsFeatureStep = new MountSecretsFeatureStep(_), + provideEnvSecretsStep: + (KubernetesConf[_ <: KubernetesRoleSpecificConf] => EnvSecretsFeatureStep) = + new EnvSecretsFeatureStep(_), provideLocalDirsStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf]) => LocalDirsFeatureStep = new LocalDirsFeatureStep(_)) { @@ -32,9 +35,14 @@ private[spark] class KubernetesExecutorBuilder( def buildFromFeatures( kubernetesConf: KubernetesConf[KubernetesExecutorSpecificConf]): SparkPod = { val baseFeatures = Seq(provideBasicStep(kubernetesConf), provideLocalDirsStep(kubernetesConf)) - val allFeatures = if (kubernetesConf.roleSecretNamesToMountPaths.nonEmpty) { + var allFeatures = if (kubernetesConf.roleSecretNamesToMountPaths.nonEmpty) { baseFeatures ++ Seq(provideSecretsStep(kubernetesConf)) } else baseFeatures + + allFeatures = if (kubernetesConf.roleSecretEnvNamesToKeyRefs.nonEmpty) { + allFeatures ++ Seq(provideEnvSecretsStep(kubernetesConf)) + } else allFeatures + var executorPod = SparkPod.initialPod() for (feature <- allFeatures) { executorPod = feature.configurePod(executorPod) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesConfSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesConfSuite.scala index f10202f7a3546..3d23e1cb90fd2 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesConfSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesConfSuite.scala @@ -40,6 +40,9 @@ class KubernetesConfSuite extends SparkFunSuite { private val SECRET_NAMES_TO_MOUNT_PATHS = Map( "secret1" -> "/mnt/secrets/secret1", "secret2" -> "/mnt/secrets/secret2") + private val SECRET_ENV_VARS = Map( + "envName1" 
-> "name1:key1", + "envName2" -> "name2:key2") private val CUSTOM_ENVS = Map( "customEnvKey1" -> "customEnvValue1", "customEnvKey2" -> "customEnvValue2") @@ -103,6 +106,9 @@ class KubernetesConfSuite extends SparkFunSuite { SECRET_NAMES_TO_MOUNT_PATHS.foreach { case (key, value) => sparkConf.set(s"$KUBERNETES_DRIVER_SECRETS_PREFIX$key", value) } + SECRET_ENV_VARS.foreach { case (key, value) => + sparkConf.set(s"$KUBERNETES_DRIVER_SECRET_KEY_REF_PREFIX$key", value) + } CUSTOM_ENVS.foreach { case (key, value) => sparkConf.set(s"$KUBERNETES_DRIVER_ENV_PREFIX$key", value) } @@ -121,6 +127,7 @@ class KubernetesConfSuite extends SparkFunSuite { CUSTOM_LABELS) assert(conf.roleAnnotations === CUSTOM_ANNOTATIONS) assert(conf.roleSecretNamesToMountPaths === SECRET_NAMES_TO_MOUNT_PATHS) + assert(conf.roleSecretEnvNamesToKeyRefs === SECRET_ENV_VARS) assert(conf.roleEnvs === CUSTOM_ENVS) } @@ -155,6 +162,9 @@ class KubernetesConfSuite extends SparkFunSuite { CUSTOM_ANNOTATIONS.foreach { case (key, value) => sparkConf.set(s"$KUBERNETES_EXECUTOR_ANNOTATION_PREFIX$key", value) } + SECRET_ENV_VARS.foreach { case (key, value) => + sparkConf.set(s"$KUBERNETES_EXECUTOR_SECRET_KEY_REF_PREFIX$key", value) + } SECRET_NAMES_TO_MOUNT_PATHS.foreach { case (key, value) => sparkConf.set(s"$KUBERNETES_EXECUTOR_SECRETS_PREFIX$key", value) } @@ -170,6 +180,6 @@ class KubernetesConfSuite extends SparkFunSuite { SPARK_ROLE_LABEL -> SPARK_POD_EXECUTOR_ROLE) ++ CUSTOM_LABELS) assert(conf.roleAnnotations === CUSTOM_ANNOTATIONS) assert(conf.roleSecretNamesToMountPaths === SECRET_NAMES_TO_MOUNT_PATHS) + assert(conf.roleSecretEnvNamesToKeyRefs === SECRET_ENV_VARS) } - } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala index eee85b8baa730..b2813d8b3265d 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala @@ -69,6 +69,7 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite { DRIVER_LABELS, DRIVER_ANNOTATIONS, Map.empty, + Map.empty, DRIVER_ENVS) val featureStep = new BasicDriverFeatureStep(kubernetesConf) @@ -138,6 +139,7 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite { DRIVER_LABELS, DRIVER_ANNOTATIONS, Map.empty, + Map.empty, Map.empty) val step = new BasicDriverFeatureStep(kubernetesConf) val additionalProperties = step.getAdditionalPodSystemProperties() diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala index a764f7630b5c8..9182134b3337c 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala @@ -87,6 +87,7 @@ class BasicExecutorFeatureStepSuite LABELS, ANNOTATIONS, Map.empty, + Map.empty, Map.empty)) val executor = step.configurePod(SparkPod.initialPod()) @@ -124,6 +125,7 @@ class BasicExecutorFeatureStepSuite LABELS, ANNOTATIONS, Map.empty, + Map.empty, Map.empty)) 
assert(step.configurePod(SparkPod.initialPod()).pod.getSpec.getHostname.length === 63) } @@ -142,6 +144,7 @@ class BasicExecutorFeatureStepSuite LABELS, ANNOTATIONS, Map.empty, + Map.empty, Map("qux" -> "quux"))) val executor = step.configurePod(SparkPod.initialPod()) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStepSuite.scala index 9f817d3bfc79a..f81894f8055f1 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStepSuite.scala @@ -59,6 +59,7 @@ class DriverKubernetesCredentialsFeatureStepSuite extends SparkFunSuite with Bef Map.empty, Map.empty, Map.empty, + Map.empty, Map.empty) val kubernetesCredentialsStep = new DriverKubernetesCredentialsFeatureStep(kubernetesConf) assert(kubernetesCredentialsStep.configurePod(BASE_DRIVER_POD) === BASE_DRIVER_POD) @@ -88,6 +89,7 @@ class DriverKubernetesCredentialsFeatureStepSuite extends SparkFunSuite with Bef Map.empty, Map.empty, Map.empty, + Map.empty, Map.empty) val kubernetesCredentialsStep = new DriverKubernetesCredentialsFeatureStep(kubernetesConf) @@ -124,6 +126,7 @@ class DriverKubernetesCredentialsFeatureStepSuite extends SparkFunSuite with Bef Map.empty, Map.empty, Map.empty, + Map.empty, Map.empty) val kubernetesCredentialsStep = new DriverKubernetesCredentialsFeatureStep(kubernetesConf) val resolvedProperties = kubernetesCredentialsStep.getAdditionalPodSystemProperties() diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStepSuite.scala index c299d56865ec0..f265522a8823a 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStepSuite.scala @@ -65,6 +65,7 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { DRIVER_LABELS, Map.empty, Map.empty, + Map.empty, Map.empty)) assert(configurationStep.configurePod(SparkPod.initialPod()) === SparkPod.initialPod()) assert(configurationStep.getAdditionalKubernetesResources().size === 1) @@ -94,6 +95,7 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { DRIVER_LABELS, Map.empty, Map.empty, + Map.empty, Map.empty)) val expectedServiceName = SHORT_RESOURCE_NAME_PREFIX + DriverServiceFeatureStep.DRIVER_SVC_POSTFIX @@ -113,6 +115,7 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { DRIVER_LABELS, Map.empty, Map.empty, + Map.empty, Map.empty)) val resolvedService = configurationStep .getAdditionalKubernetesResources() @@ -141,6 +144,7 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { DRIVER_LABELS, Map.empty, Map.empty, + Map.empty, Map.empty), clock) val driverService = configurationStep @@ -166,6 +170,7 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { DRIVER_LABELS, Map.empty, Map.empty, + Map.empty, Map.empty), clock) fail("The 
driver bind address should not be allowed.") @@ -189,6 +194,7 @@ class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { DRIVER_LABELS, Map.empty, Map.empty, + Map.empty, Map.empty), clock) fail("The driver host address should not be allowed.") diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/EnvSecretsFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/EnvSecretsFeatureStepSuite.scala new file mode 100644 index 0000000000000..8b0b2d0739c76 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/EnvSecretsFeatureStepSuite.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features + +import io.fabric8.kubernetes.api.model.PodBuilder + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s._ + +class EnvSecretsFeatureStepSuite extends SparkFunSuite{ + private val KEY_REF_NAME_FOO = "foo" + private val KEY_REF_NAME_BAR = "bar" + private val KEY_REF_KEY_FOO = "key_foo" + private val KEY_REF_KEY_BAR = "key_bar" + private val ENV_NAME_FOO = "MY_FOO" + private val ENV_NAME_BAR = "MY_bar" + + test("sets up all keyRefs") { + val baseDriverPod = SparkPod.initialPod() + val envVarsToKeys = Map( + ENV_NAME_BAR -> s"${KEY_REF_NAME_BAR}:${KEY_REF_KEY_BAR}", + ENV_NAME_FOO -> s"${KEY_REF_NAME_FOO}:${KEY_REF_KEY_FOO}") + val sparkConf = new SparkConf(false) + val kubernetesConf = KubernetesConf( + sparkConf, + KubernetesExecutorSpecificConf("1", new PodBuilder().build()), + "resource-name-prefix", + "app-id", + Map.empty, + Map.empty, + Map.empty, + envVarsToKeys, + Map.empty) + + val step = new EnvSecretsFeatureStep(kubernetesConf) + val driverContainerWithEnvSecrets = step.configurePod(baseDriverPod).container + + val expectedVars = + Seq(s"${ENV_NAME_BAR}", s"${ENV_NAME_FOO}") + + expectedVars.foreach { envName => + assert(KubernetesFeaturesTestUtils.containerHasEnvVar(driverContainerWithEnvSecrets, envName)) + } + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/KubernetesFeaturesTestUtils.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/KubernetesFeaturesTestUtils.scala index 27bff74ce38af..f90380e30e52a 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/KubernetesFeaturesTestUtils.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/KubernetesFeaturesTestUtils.scala @@ -16,7 +16,9 @@ */ package org.apache.spark.deploy.k8s.features -import io.fabric8.kubernetes.api.model.{HasMetadata, 
PodBuilder, SecretBuilder} +import scala.collection.JavaConverters._ + +import io.fabric8.kubernetes.api.model.{Container, HasMetadata, PodBuilder, SecretBuilder} import org.mockito.Matchers import org.mockito.Mockito._ import org.mockito.invocation.InvocationOnMock @@ -58,4 +60,7 @@ object KubernetesFeaturesTestUtils { .build()) } + def containerHasEnvVar(container: Container, envVarName: String): Boolean = { + container.getEnv.asScala.exists(envVar => envVar.getName == envVarName) + } } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStepSuite.scala index 91e184b84b86e..2542a02d37766 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStepSuite.scala @@ -43,6 +43,7 @@ class LocalDirsFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { Map.empty, Map.empty, Map.empty, + Map.empty, Map.empty) } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStepSuite.scala index 9d02f56cc206d..9155793774123 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStepSuite.scala @@ -41,6 +41,7 @@ class MountSecretsFeatureStepSuite extends SparkFunSuite { Map.empty, Map.empty, secretNamesToMountPaths, + Map.empty, Map.empty) val step = new MountSecretsFeatureStep(kubernetesConf) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala index c1b203e03a357..0775338098a13 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala @@ -142,6 +142,7 @@ class ClientSuite extends SparkFunSuite with BeforeAndAfter { Map.empty, Map.empty, Map.empty, + Map.empty, Map.empty) when(driverBuilder.buildFromFeatures(kubernetesConf)).thenReturn(BUILT_KUBERNETES_SPEC) when(kubernetesClient.pods()).thenReturn(podOperations) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala index a511d254d2175..cb724068ea4f3 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.deploy.k8s.submit import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpec, KubernetesDriverSpecificConf} -import org.apache.spark.deploy.k8s.features.{BasicDriverFeatureStep, 
DriverKubernetesCredentialsFeatureStep, DriverServiceFeatureStep, KubernetesFeaturesTestUtils, LocalDirsFeatureStep, MountSecretsFeatureStep} +import org.apache.spark.deploy.k8s.features.{BasicDriverFeatureStep, DriverKubernetesCredentialsFeatureStep, DriverServiceFeatureStep, EnvSecretsFeatureStep, KubernetesFeaturesTestUtils, LocalDirsFeatureStep, MountSecretsFeatureStep} class KubernetesDriverBuilderSuite extends SparkFunSuite { @@ -27,6 +27,7 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { private val SERVICE_STEP_TYPE = "service" private val LOCAL_DIRS_STEP_TYPE = "local-dirs" private val SECRETS_STEP_TYPE = "mount-secrets" + private val ENV_SECRETS_STEP_TYPE = "env-secrets" private val basicFeatureStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( BASIC_STEP_TYPE, classOf[BasicDriverFeatureStep]) @@ -43,12 +44,16 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { private val secretsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( SECRETS_STEP_TYPE, classOf[MountSecretsFeatureStep]) + private val envSecretsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( + ENV_SECRETS_STEP_TYPE, classOf[EnvSecretsFeatureStep]) + private val builderUnderTest: KubernetesDriverBuilder = new KubernetesDriverBuilder( _ => basicFeatureStep, _ => credentialsStep, _ => serviceStep, _ => secretsStep, + _ => envSecretsStep, _ => localDirsStep) test("Apply fundamental steps all the time.") { @@ -64,6 +69,7 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { Map.empty, Map.empty, Map.empty, + Map.empty, Map.empty) validateStepTypesApplied( builderUnderTest.buildFromFeatures(conf), @@ -86,6 +92,7 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { Map.empty, Map.empty, Map("secret" -> "secretMountPath"), + Map("EnvName" -> "SecretName:secretKey"), Map.empty) validateStepTypesApplied( builderUnderTest.buildFromFeatures(conf), @@ -93,7 +100,9 @@ class KubernetesDriverBuilderSuite extends SparkFunSuite { CREDENTIALS_STEP_TYPE, SERVICE_STEP_TYPE, LOCAL_DIRS_STEP_TYPE, - SECRETS_STEP_TYPE) + SECRETS_STEP_TYPE, + ENV_SECRETS_STEP_TYPE + ) } private def validateStepTypesApplied(resolvedSpec: KubernetesDriverSpec, stepTypes: String*) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala index 9ee86b5a423a9..753cd30a237f3 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala @@ -20,23 +20,27 @@ import io.fabric8.kubernetes.api.model.PodBuilder import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, SparkPod} -import org.apache.spark.deploy.k8s.features.{BasicExecutorFeatureStep, KubernetesFeaturesTestUtils, LocalDirsFeatureStep, MountSecretsFeatureStep} +import org.apache.spark.deploy.k8s.features.{BasicExecutorFeatureStep, EnvSecretsFeatureStep, KubernetesFeaturesTestUtils, LocalDirsFeatureStep, MountSecretsFeatureStep} class KubernetesExecutorBuilderSuite extends SparkFunSuite { private val BASIC_STEP_TYPE = "basic" private val SECRETS_STEP_TYPE = "mount-secrets" + private val ENV_SECRETS_STEP_TYPE = "env-secrets" private val 
LOCAL_DIRS_STEP_TYPE = "local-dirs" private val basicFeatureStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( BASIC_STEP_TYPE, classOf[BasicExecutorFeatureStep]) private val mountSecretsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( SECRETS_STEP_TYPE, classOf[MountSecretsFeatureStep]) + private val envSecretsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( + ENV_SECRETS_STEP_TYPE, classOf[EnvSecretsFeatureStep]) private val localDirsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( LOCAL_DIRS_STEP_TYPE, classOf[LocalDirsFeatureStep]) private val builderUnderTest = new KubernetesExecutorBuilder( _ => basicFeatureStep, _ => mountSecretsStep, + _ => envSecretsStep, _ => localDirsStep) test("Basic steps are consistently applied.") { @@ -49,6 +53,7 @@ class KubernetesExecutorBuilderSuite extends SparkFunSuite { Map.empty, Map.empty, Map.empty, + Map.empty, Map.empty) validateStepTypesApplied( builderUnderTest.buildFromFeatures(conf), BASIC_STEP_TYPE, LOCAL_DIRS_STEP_TYPE) @@ -64,12 +69,14 @@ class KubernetesExecutorBuilderSuite extends SparkFunSuite { Map.empty, Map.empty, Map("secret" -> "secretMountPath"), + Map("secret-name" -> "secret-key"), Map.empty) validateStepTypesApplied( builderUnderTest.buildFromFeatures(conf), BASIC_STEP_TYPE, LOCAL_DIRS_STEP_TYPE, - SECRETS_STEP_TYPE) + SECRETS_STEP_TYPE, + ENV_SECRETS_STEP_TYPE) } private def validateStepTypesApplied(resolvedPod: SparkPod, stepTypes: String*): Unit = { diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala index b36f46456f9a5..7d80eedcc43ce 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala @@ -30,8 +30,7 @@ import org.apache.mesos.Protos.Environment.Variable import org.apache.mesos.Protos.TaskStatus.Reason import org.apache.spark.{SecurityManager, SparkConf, SparkException, TaskState} -import org.apache.spark.deploy.mesos.MesosDriverDescription -import org.apache.spark.deploy.mesos.config +import org.apache.spark.deploy.mesos.{config, MesosDriverDescription} import org.apache.spark.deploy.rest.{CreateSubmissionResponse, KillSubmissionResponse, SubmissionStatusResponse} import org.apache.spark.metrics.MetricsSystem import org.apache.spark.util.Utils @@ -418,6 +417,18 @@ private[spark] class MesosClusterScheduler( envBuilder.build() } + private def isContainerLocalAppJar(desc: MesosDriverDescription): Boolean = { + val isLocalJar = desc.jarUrl.startsWith("local://") + val isContainerLocal = desc.conf.getOption("spark.mesos.appJar.local.resolution.mode").exists { + case "container" => true + case "host" => false + case other => + logWarning(s"Unknown spark.mesos.appJar.local.resolution.mode $other, using host.") + false + } + isLocalJar && isContainerLocal + } + private def getDriverUris(desc: MesosDriverDescription): List[CommandInfo.URI] = { val confUris = List(conf.getOption("spark.mesos.uris"), desc.conf.getOption("spark.mesos.uris"), @@ -425,10 +436,14 @@ private[spark] class MesosClusterScheduler( _.map(_.split(",").map(_.trim)) ).flatten - val jarUrl = desc.jarUrl.stripPrefix("file:").stripPrefix("local:") - - ((jarUrl :: confUris) ++ getDriverExecutorURI(desc).toList).map(uri => - 
CommandInfo.URI.newBuilder().setValue(uri.trim()).setCache(useFetchCache).build()) + if (isContainerLocalAppJar(desc)) { + (confUris ++ getDriverExecutorURI(desc).toList).map(uri => + CommandInfo.URI.newBuilder().setValue(uri.trim()).setCache(useFetchCache).build()) + } else { + val jarUrl = desc.jarUrl.stripPrefix("file:").stripPrefix("local:") + ((jarUrl :: confUris) ++ getDriverExecutorURI(desc).toList).map(uri => + CommandInfo.URI.newBuilder().setValue(uri.trim()).setCache(useFetchCache).build()) + } } private def getContainerInfo(desc: MesosDriverDescription): ContainerInfo.Builder = { @@ -480,7 +495,14 @@ private[spark] class MesosClusterScheduler( (cmdExecutable, ".") } val cmdOptions = generateCmdOption(desc, sandboxPath).mkString(" ") - val primaryResource = new File(sandboxPath, desc.jarUrl.split("/").last).toString() + val primaryResource = { + if (isContainerLocalAppJar(desc)) { + new File(desc.jarUrl.stripPrefix("local://")).toString() + } else { + new File(sandboxPath, desc.jarUrl.split("/").last).toString() + } + } + val appArguments = desc.command.arguments.mkString(" ") s"$executable $cmdOptions $primaryResource $appArguments" diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManager.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManager.scala index d4eeb6bbcf886..26a2e5d730218 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManager.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManager.scala @@ -44,6 +44,10 @@ private[yarn] class YARNHadoopDelegationTokenManager( // public for testing val credentialProviders = getCredentialProviders + if (credentialProviders.nonEmpty) { + logDebug("Using the following YARN-specific credential providers: " + + s"${credentialProviders.keys.mkString(", ")}.") + } /** * Writes delegation tokens to creds. Delegation tokens are fetched from all registered diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index 59b0f29e37d84..3b78b88de778d 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -271,16 +271,11 @@ class YarnClusterSuite extends BaseYarnClusterSuite { "PYSPARK_ARCHIVES_PATH" -> pythonPath.map("local:" + _).mkString(File.pathSeparator), "PYTHONPATH" -> pythonPath.mkString(File.pathSeparator)) ++ extraEnv - val moduleDir = - if (clientMode) { - // In client-mode, .py files added with --py-files are not visible in the driver. - // This is something that the launcher library would have to handle. 
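(Illustration, not part of the patch.) The MesosClusterScheduler change above adds container-local resolution of the application jar. A minimal sketch of how a submission could opt in, assuming the jar is already baked into the driver's container image at /opt/app/app.jar (the path is invented for illustration):

    val conf = new org.apache.spark.SparkConf()
      // Only this key comes from the patch; "host" (the effective default) keeps the old behaviour
      // of copying the jar into the Mesos sandbox.
      .set("spark.mesos.appJar.local.resolution.mode", "container")
    // With this setting and a jar URL such as "local:///opt/app/app.jar", getDriverUris no longer
    // adds the jar to the fetcher URIs, and the primary resource becomes the in-container path
    // /opt/app/app.jar instead of a sandbox-relative file.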
- tempDir - } else { - val subdir = new File(tempDir, "pyModules") - subdir.mkdir() - subdir - } + val moduleDir = { + val subdir = new File(tempDir, "pyModules") + subdir.mkdir() + subdir + } val pyModule = new File(moduleDir, "mod1.py") Files.write(TEST_PYMODULE, pyModule, StandardCharsets.UTF_8) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 7c54851097af3..3fe00eefde7d8 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -592,6 +592,7 @@ primaryExpression | identifier #columnReference | base=primaryExpression '.' fieldName=identifier #dereference | '(' expression ')' #parenthesizedExpression + | EXTRACT '(' field=identifier FROM source=valueExpression ')' #extract ; constant @@ -739,6 +740,7 @@ nonReserved | VIEW | REPLACE | IF | POSITION + | EXTRACT | NO | DATA | START | TRANSACTION | COMMIT | ROLLBACK | IGNORE | SORT | CLUSTER | DISTRIBUTE | UNSET | TBLPROPERTIES | SKEWED | STORED | DIRECTORIES | LOCATION @@ -878,6 +880,7 @@ TRAILING: 'TRAILING'; IF: 'IF'; POSITION: 'POSITION'; +EXTRACT: 'EXTRACT'; EQ : '=' | '=='; NSEQ: '<=>'; diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/MaskExpressionsUtils.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/MaskExpressionsUtils.java new file mode 100644 index 0000000000000..05879902a4ed9 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/MaskExpressionsUtils.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions; + +/** + * Contains all the Utils methods used in the masking expressions. + */ +public class MaskExpressionsUtils { + static final int UNMASKED_VAL = -1; + + /** + * Returns the masking character for {@param c} or {@param c} is it should not be masked. 
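(Aside.) A tiny hedged sketch of the transformChar contract described in this Javadoc, called from Scala; the replacement characters are arbitrary, and -1 is the internal UNMASKED_VAL meaning "leave this character class unchanged":

    import org.apache.spark.sql.catalyst.expressions.MaskExpressionsUtils.transformChar
    // 'A' is an uppercase letter, so the upper-case replacement 'Y' is returned.
    val a = transformChar('A', 'Y', 'y', '#', -1)  // == 'Y'
    // '7' is a digit, but -1 is passed for the digit class, so it comes back unchanged.
    val b = transformChar('7', 'Y', 'y', -1, '*')  // == '7'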
+ * @param c the character to transform + * @param maskedUpperChar the character to use instead of a uppercase letter + * @param maskedLowerChar the character to use instead of a lowercase letter + * @param maskedDigitChar the character to use instead of a digit + * @param maskedOtherChar the character to use instead of a any other character + * @return masking character for {@param c} + */ + public static int transformChar( + final int c, + int maskedUpperChar, + int maskedLowerChar, + int maskedDigitChar, + int maskedOtherChar) { + switch(Character.getType(c)) { + case Character.UPPERCASE_LETTER: + if(maskedUpperChar != UNMASKED_VAL) { + return maskedUpperChar; + } + break; + + case Character.LOWERCASE_LETTER: + if(maskedLowerChar != UNMASKED_VAL) { + return maskedLowerChar; + } + break; + + case Character.DECIMAL_DIGIT_NUMBER: + if(maskedDigitChar != UNMASKED_VAL) { + return maskedDigitChar; + } + break; + + default: + if(maskedOtherChar != UNMASKED_VAL) { + return maskedOtherChar; + } + break; + } + + return c; + } + + /** + * Returns the replacement char to use according to the {@param rep} specified by the user and + * the {@param def} default. + */ + public static int getReplacementChar(String rep, int def) { + if (rep != null && rep.length() > 0) { + return rep.codePointAt(0); + } + return def; + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 3eaa9ecf5d075..f9947d1fa6c78 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1744,10 +1744,10 @@ class Analyzer( * it into the plan tree. */ object ExtractWindowExpressions extends Rule[LogicalPlan] { - private def hasWindowFunction(projectList: Seq[NamedExpression]): Boolean = - projectList.exists(hasWindowFunction) + private def hasWindowFunction(exprs: Seq[Expression]): Boolean = + exprs.exists(hasWindowFunction) - private def hasWindowFunction(expr: NamedExpression): Boolean = { + private def hasWindowFunction(expr: Expression): Boolean = { expr.find { case window: WindowExpression => true case _ => false @@ -1830,6 +1830,10 @@ class Analyzer( seenWindowAggregates += newAgg WindowExpression(newAgg, spec) + case AggregateExpression(aggFunc, _, _, _) if hasWindowFunction(aggFunc.children) => + failAnalysis("It is not allowed to use a window function inside an aggregate " + + "function. Please use the inner window function in a sub-query.") + // Extracts AggregateExpression. For example, for SUM(x) - Sum(y) OVER (...), // we need to extract SUM(x). 
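(Aside.) The failAnalysis branch added just above rejects window functions nested inside aggregate functions. A hedged example of the query shape it now catches, plus the sub-query rewrite the error message suggests (table and column names are invented, `spark` is an assumed SparkSession):

    // Rejected: the window function sits inside the aggregate.
    //   SELECT dept, sum(rank() OVER (ORDER BY salary)) FROM emp GROUP BY dept
    //   => AnalysisException: "It is not allowed to use a window function inside an aggregate
    //      function. Please use the inner window function in a sub-query."
    // Accepted: evaluate the window function in a sub-query first, then aggregate its result.
    spark.sql(
      """SELECT dept, sum(r)
        |FROM (SELECT dept, rank() OVER (ORDER BY salary) AS r FROM emp) t
        |GROUP BY dept""".stripMargin)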
case agg: AggregateExpression if !seenWindowAggregates.contains(agg) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 1134a8866dc13..49fb35b083580 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -430,8 +430,17 @@ object FunctionRegistry { expression[Concat]("concat"), expression[Flatten]("flatten"), expression[ArrayRepeat]("array_repeat"), + expression[ArrayRemove]("array_remove"), CreateStruct.registryEntry, + // mask functions + expression[Mask]("mask"), + expression[MaskFirstN]("mask_first_n"), + expression[MaskLastN]("mask_last_n"), + expression[MaskShowFirstN]("mask_show_first_n"), + expression[MaskShowLastN]("mask_show_last_n"), + expression[MaskHash]("mask_hash"), + // misc functions expression[AssertTrue]("assert_true"), expression[Crc32]("crc32"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala index d848ba18356d3..7541f527a52a8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala @@ -30,6 +30,7 @@ package org.apache.spark.sql.catalyst.expressions * by `hashCode`. * - [[EqualTo]] and [[EqualNullSafe]] are reordered by `hashCode`. * - Other comparisons ([[GreaterThan]], [[LessThan]]) are reversed by `hashCode`. + * - Elements in [[In]] are reordered by `hashCode`. */ object Canonicalize { def execute(e: Expression): Expression = { @@ -85,6 +86,11 @@ object Canonicalize { case Not(GreaterThanOrEqual(l, r)) => LessThan(l, r) case Not(LessThanOrEqual(l, r)) => GreaterThan(l, r) + // order the list in the In operator + // In subqueries contain only one element of type ListQuery. So checking that the length > 1 + // we are not reordering In subqueries. 
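(Aside.) The effect of reordering In lists during canonicalization, in a small sketch mirroring the new CanonicalizeSuite further down: two filters that differ only in the order of the IN list now share one canonical form, so the planner recognizes them as producing the same result. `df` is an assumed DataFrame with an integer column `id`:

    val p1 = df.where("id IN (1, 2)").queryExecution.analyzed
    val p2 = df.where("id IN (2, 1)").queryExecution.analyzed
    assert(p1.sameResult(p2))  // holds after this change; the two lists previously canonicalized differently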
+ case In(value, list) if list.length > 1 => In(value, list.sortBy(_.hashCode())) + case _ => e } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 8a877b02c8191..176995affe701 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -2066,3 +2066,126 @@ case class ArrayRepeat(left: Expression, right: Expression) } } + +/** + * Remove all elements that equal to element from the given array + */ +@ExpressionDescription( + usage = "_FUNC_(array, element) - Remove all elements that equal to element from array.", + examples = """ + Examples: + > SELECT _FUNC_(array(1, 2, 3, null, 3), 3); + [1,2,null] + """, since = "2.4.0") +case class ArrayRemove(left: Expression, right: Expression) + extends BinaryExpression with ImplicitCastInputTypes { + + override def dataType: DataType = left.dataType + + override def inputTypes: Seq[AbstractDataType] = { + val elementType = left.dataType match { + case t: ArrayType => t.elementType + case _ => AnyDataType + } + Seq(ArrayType, elementType) + } + + lazy val elementType: DataType = left.dataType.asInstanceOf[ArrayType].elementType + + @transient private lazy val ordering: Ordering[Any] = + TypeUtils.getInterpretedOrdering(right.dataType) + + override def checkInputDataTypes(): TypeCheckResult = { + super.checkInputDataTypes() match { + case f: TypeCheckResult.TypeCheckFailure => f + case TypeCheckResult.TypeCheckSuccess => + TypeUtils.checkForOrderingExpr(right.dataType, s"function $prettyName") + } + } + + override def nullSafeEval(arr: Any, value: Any): Any = { + val newArray = new Array[Any](arr.asInstanceOf[ArrayData].numElements()) + var pos = 0 + arr.asInstanceOf[ArrayData].foreach(right.dataType, (i, v) => + if (v == null || !ordering.equiv(v, value)) { + newArray(pos) = v + pos += 1 + } + ) + new GenericArrayData(newArray.slice(0, pos)) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, (arr, value) => { + val numsToRemove = ctx.freshName("numsToRemove") + val newArraySize = ctx.freshName("newArraySize") + val i = ctx.freshName("i") + val getValue = CodeGenerator.getValue(arr, elementType, i) + val isEqual = ctx.genEqual(elementType, value, getValue) + s""" + |int $numsToRemove = 0; + |for (int $i = 0; $i < $arr.numElements(); $i ++) { + | if (!$arr.isNullAt($i) && $isEqual) { + | $numsToRemove = $numsToRemove + 1; + | } + |} + |int $newArraySize = $arr.numElements() - $numsToRemove; + |${genCodeForResult(ctx, ev, arr, value, newArraySize)} + """.stripMargin + }) + } + + def genCodeForResult( + ctx: CodegenContext, + ev: ExprCode, + inputArray: String, + value: String, + newArraySize: String): String = { + val values = ctx.freshName("values") + val i = ctx.freshName("i") + val pos = ctx.freshName("pos") + val getValue = CodeGenerator.getValue(inputArray, elementType, i) + val isEqual = ctx.genEqual(elementType, value, getValue) + if (!CodeGenerator.isPrimitiveType(elementType)) { + val arrayClass = classOf[GenericArrayData].getName + s""" + |int $pos = 0; + |Object[] $values = new Object[$newArraySize]; + |for (int $i = 0; $i < $inputArray.numElements(); $i ++) { + | if ($inputArray.isNullAt($i)) { + | $values[$pos] = null; + | $pos = $pos + 1; + | } + | else { + | if 
(!($isEqual)) { + | $values[$pos] = $getValue; + | $pos = $pos + 1; + | } + | } + |} + |${ev.value} = new $arrayClass($values); + """.stripMargin + } else { + val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) + s""" + |${ctx.createUnsafeArray(values, newArraySize, elementType, s" $prettyName failed.")} + |int $pos = 0; + |for (int $i = 0; $i < $inputArray.numElements(); $i ++) { + | if ($inputArray.isNullAt($i)) { + | $values.setNullAt($pos); + | $pos = $pos + 1; + | } + | else { + | if (!($isEqual)) { + | $values.set$primitiveValueTypeName($pos, $getValue); + | $pos = $pos + 1; + | } + | } + |} + |${ev.value} = $values; + """.stripMargin + } + } + + override def prettyName: String = "array_remove" +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala new file mode 100644 index 0000000000000..276a57266a6e0 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala @@ -0,0 +1,569 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.commons.codec.digest.DigestUtils + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.expressions.MaskExpressionsUtils._ +import org.apache.spark.sql.catalyst.expressions.MaskLike._ +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + + +trait MaskLike { + def upper: String + def lower: String + def digit: String + + protected lazy val upperReplacement: Int = getReplacementChar(upper, defaultMaskedUppercase) + protected lazy val lowerReplacement: Int = getReplacementChar(lower, defaultMaskedLowercase) + protected lazy val digitReplacement: Int = getReplacementChar(digit, defaultMaskedDigit) + + protected val maskUtilsClassName: String = classOf[MaskExpressionsUtils].getName + + def inputStringLengthCode(inputString: String, length: String): String = { + s"${CodeGenerator.JAVA_INT} $length = $inputString.codePointCount(0, $inputString.length());" + } + + def appendMaskedToStringBuilderCode( + ctx: CodegenContext, + sb: String, + inputString: String, + offset: String, + numChars: String): String = { + val i = ctx.freshName("i") + val codePoint = ctx.freshName("codePoint") + s""" + |for (${CodeGenerator.JAVA_INT} $i = 0; $i < $numChars; $i++) { + | ${CodeGenerator.JAVA_INT} $codePoint = $inputString.codePointAt($offset); + | $sb.appendCodePoint($maskUtilsClassName.transformChar($codePoint, + | $upperReplacement, $lowerReplacement, + | $digitReplacement, $defaultMaskedOther)); + | $offset += Character.charCount($codePoint); + |} + """.stripMargin + } + + def appendUnchangedToStringBuilderCode( + ctx: CodegenContext, + sb: String, + inputString: String, + offset: String, + numChars: String): String = { + val i = ctx.freshName("i") + val codePoint = ctx.freshName("codePoint") + s""" + |for (${CodeGenerator.JAVA_INT} $i = 0; $i < $numChars; $i++) { + | ${CodeGenerator.JAVA_INT} $codePoint = $inputString.codePointAt($offset); + | $sb.appendCodePoint($codePoint); + | $offset += Character.charCount($codePoint); + |} + """.stripMargin + } + + def appendMaskedToStringBuilder( + sb: java.lang.StringBuilder, + inputString: String, + startOffset: Int, + numChars: Int): Int = { + var offset = startOffset + (1 to numChars) foreach { _ => + val codePoint = inputString.codePointAt(offset) + sb.appendCodePoint(transformChar( + codePoint, + upperReplacement, + lowerReplacement, + digitReplacement, + defaultMaskedOther)) + offset += Character.charCount(codePoint) + } + offset + } + + def appendUnchangedToStringBuilder( + sb: java.lang.StringBuilder, + inputString: String, + startOffset: Int, + numChars: Int): Int = { + var offset = startOffset + (1 to numChars) foreach { _ => + val codePoint = inputString.codePointAt(offset) + sb.appendCodePoint(codePoint) + offset += Character.charCount(codePoint) + } + offset + } +} + +trait MaskLikeWithN extends MaskLike { + def n: Int + protected lazy val charCount: Int = if (n < 0) 0 else n +} + +/** + * Utils for mask operations. 
+ */ +object MaskLike { + val defaultCharCount = 4 + val defaultMaskedUppercase: Int = 'X' + val defaultMaskedLowercase: Int = 'x' + val defaultMaskedDigit: Int = 'n' + val defaultMaskedOther: Int = MaskExpressionsUtils.UNMASKED_VAL + + def extractCharCount(e: Expression): Int = e match { + case Literal(i, IntegerType | NullType) => + if (i == null) defaultCharCount else i.asInstanceOf[Int] + case Literal(_, dt) => throw new AnalysisException("Expected literal expression of type " + + s"${IntegerType.simpleString}, but got literal of ${dt.simpleString}") + case other => throw new AnalysisException(s"Expected literal expression, but got ${other.sql}") + } + + def extractReplacement(e: Expression): String = e match { + case Literal(s, StringType | NullType) => if (s == null) null else s.toString + case Literal(_, dt) => throw new AnalysisException("Expected literal expression of type " + + s"${StringType.simpleString}, but got literal of ${dt.simpleString}") + case other => throw new AnalysisException(s"Expected literal expression, but got ${other.sql}") + } +} + +/** + * Masks the input string. Additional parameters can be set to change the masking chars for + * uppercase letters, lowercase letters and digits. + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(str[, upper[, lower[, digit]]]) - Masks str. By default, upper case letters are converted to \"X\", lower case letters are converted to \"x\" and numbers are converted to \"n\". You can override the characters used in the mask by supplying additional arguments: the second argument controls the mask character for upper case letters, the third argument for lower case letters and the fourth argument for numbers.", + examples = """ + Examples: + > SELECT _FUNC_("abcd-EFGH-8765-4321", "U", "l", "#"); + llll-UUUU-####-#### + """) +// scalastyle:on line.size.limit +case class Mask(child: Expression, upper: String, lower: String, digit: String) + extends UnaryExpression with ExpectsInputTypes with MaskLike { + + def this(child: Expression) = this(child, null.asInstanceOf[String], null, null) + + def this(child: Expression, upper: Expression) = + this(child, extractReplacement(upper), null, null) + + def this(child: Expression, upper: Expression, lower: Expression) = + this(child, extractReplacement(upper), extractReplacement(lower), null) + + def this(child: Expression, upper: Expression, lower: Expression, digit: Expression) = + this(child, extractReplacement(upper), extractReplacement(lower), extractReplacement(digit)) + + override def nullSafeEval(input: Any): Any = { + val str = input.asInstanceOf[UTF8String].toString + val length = str.codePointCount(0, str.length()) + val sb = new java.lang.StringBuilder(length) + appendMaskedToStringBuilder(sb, str, 0, length) + UTF8String.fromString(sb.toString) + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, (input: String) => { + val sb = ctx.freshName("sb") + val length = ctx.freshName("length") + val offset = ctx.freshName("offset") + val inputString = ctx.freshName("inputString") + s""" + |String $inputString = $input.toString(); + |${inputStringLengthCode(inputString, length)} + |StringBuilder $sb = new StringBuilder($length); + |${CodeGenerator.JAVA_INT} $offset = 0; + |${appendMaskedToStringBuilderCode(ctx, sb, inputString, offset, length)} + |${ev.value} = UTF8String.fromString($sb.toString()); + """.stripMargin + }) + } + + override def dataType: DataType = StringType + + override def 
inputTypes: Seq[AbstractDataType] = Seq(StringType) +} + +/** + * Masks the first N chars of the input string. N defaults to 4. Additional parameters can be set + * to change the masking chars for uppercase letters, lowercase letters and digits. + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(str[, n[, upper[, lower[, digit]]]]) - Masks the first n values of str. By default, n is 4, upper case letters are converted to \"X\", lower case letters are converted to \"x\" and numbers are converted to \"n\". You can override the characters used in the mask by supplying additional arguments: the second argument controls the mask character for upper case letters, the third argument for lower case letters and the fourth argument for numbers.", + examples = """ + Examples: + > SELECT _FUNC_("1234-5678-8765-4321", 4); + nnnn-5678-8765-4321 + """) +// scalastyle:on line.size.limit +case class MaskFirstN( + child: Expression, + n: Int, + upper: String, + lower: String, + digit: String) + extends UnaryExpression with ExpectsInputTypes with MaskLikeWithN { + + def this(child: Expression) = + this(child, defaultCharCount, null, null, null) + + def this(child: Expression, n: Expression) = + this(child, extractCharCount(n), null, null, null) + + def this(child: Expression, n: Expression, upper: Expression) = + this(child, extractCharCount(n), extractReplacement(upper), null, null) + + def this(child: Expression, n: Expression, upper: Expression, lower: Expression) = + this(child, extractCharCount(n), extractReplacement(upper), extractReplacement(lower), null) + + def this( + child: Expression, + n: Expression, + upper: Expression, + lower: Expression, + digit: Expression) = + this(child, + extractCharCount(n), + extractReplacement(upper), + extractReplacement(lower), + extractReplacement(digit)) + + override def nullSafeEval(input: Any): Any = { + val str = input.asInstanceOf[UTF8String].toString + val length = str.codePointCount(0, str.length()) + val endOfMask = if (charCount > length) length else charCount + val sb = new java.lang.StringBuilder(length) + val offset = appendMaskedToStringBuilder(sb, str, 0, endOfMask) + appendUnchangedToStringBuilder(sb, str, offset, length - endOfMask) + UTF8String.fromString(sb.toString) + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, (input: String) => { + val sb = ctx.freshName("sb") + val length = ctx.freshName("length") + val offset = ctx.freshName("offset") + val inputString = ctx.freshName("inputString") + val endOfMask = ctx.freshName("endOfMask") + s""" + |String $inputString = $input.toString(); + |${inputStringLengthCode(inputString, length)} + |${CodeGenerator.JAVA_INT} $endOfMask = $charCount > $length ? $length : $charCount; + |${CodeGenerator.JAVA_INT} $offset = 0; + |StringBuilder $sb = new StringBuilder($length); + |${appendMaskedToStringBuilderCode(ctx, sb, inputString, offset, endOfMask)} + |${appendUnchangedToStringBuilderCode( + ctx, sb, inputString, offset, s"$length - $endOfMask")} + |${ev.value} = UTF8String.fromString($sb.toString()); + |""".stripMargin + }) + } + + override def dataType: DataType = StringType + + override def inputTypes: Seq[AbstractDataType] = Seq(StringType) + + override def prettyName: String = "mask_first_n" +} + +/** + * Masks the last N chars of the input string. N defaults to 4. Additional parameters can be set + * to change the masking chars for uppercase letters, lowercase letters and digits. 
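(Aside.) Once registered in FunctionRegistry above, the mask family is usable from SQL. A short sketch re-using the results quoted in the ExpressionDescription examples (`spark` is an assumed SparkSession):

    spark.sql("""SELECT mask("abcd-EFGH-8765-4321", "U", "l", "#")""").show()
    // llll-UUUU-####-####
    spark.sql("""SELECT mask_first_n("1234-5678-8765-4321", 4)""").show()
    // nnnn-5678-8765-4321
    spark.sql("""SELECT mask_show_last_n("1234-5678-8765-4321", 4)""").show()
    // nnnn-nnnn-nnnn-4321
    spark.sql("""SELECT mask_hash("abcd-EFGH-8765-4321")""").show()
    // 60c713f5ec6912229d2060df1c322776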
+ */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(str[, n[, upper[, lower[, digit]]]]) - Masks the last n values of str. By default, n is 4, upper case letters are converted to \"X\", lower case letters are converted to \"x\" and numbers are converted to \"n\". You can override the characters used in the mask by supplying additional arguments: the second argument controls the mask character for upper case letters, the third argument for lower case letters and the fourth argument for numbers.", + examples = """ + Examples: + > SELECT _FUNC_("1234-5678-8765-4321", 4); + 1234-5678-8765-nnnn + """, since = "2.4.0") +// scalastyle:on line.size.limit +case class MaskLastN( + child: Expression, + n: Int, + upper: String, + lower: String, + digit: String) + extends UnaryExpression with ExpectsInputTypes with MaskLikeWithN { + + def this(child: Expression) = + this(child, defaultCharCount, null, null, null) + + def this(child: Expression, n: Expression) = + this(child, extractCharCount(n), null, null, null) + + def this(child: Expression, n: Expression, upper: Expression) = + this(child, extractCharCount(n), extractReplacement(upper), null, null) + + def this(child: Expression, n: Expression, upper: Expression, lower: Expression) = + this(child, extractCharCount(n), extractReplacement(upper), extractReplacement(lower), null) + + def this( + child: Expression, + n: Expression, + upper: Expression, + lower: Expression, + digit: Expression) = + this(child, + extractCharCount(n), + extractReplacement(upper), + extractReplacement(lower), + extractReplacement(digit)) + + override def nullSafeEval(input: Any): Any = { + val str = input.asInstanceOf[UTF8String].toString + val length = str.codePointCount(0, str.length()) + val startOfMask = if (charCount >= length) 0 else length - charCount + val sb = new java.lang.StringBuilder(length) + val offset = appendUnchangedToStringBuilder(sb, str, 0, startOfMask) + appendMaskedToStringBuilder(sb, str, offset, length - startOfMask) + UTF8String.fromString(sb.toString) + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, (input: String) => { + val sb = ctx.freshName("sb") + val length = ctx.freshName("length") + val offset = ctx.freshName("offset") + val inputString = ctx.freshName("inputString") + val startOfMask = ctx.freshName("startOfMask") + s""" + |String $inputString = $input.toString(); + |${inputStringLengthCode(inputString, length)} + |${CodeGenerator.JAVA_INT} $startOfMask = $charCount >= $length ? + | 0 : $length - $charCount; + |${CodeGenerator.JAVA_INT} $offset = 0; + |StringBuilder $sb = new StringBuilder($length); + |${appendUnchangedToStringBuilderCode(ctx, sb, inputString, offset, startOfMask)} + |${appendMaskedToStringBuilderCode( + ctx, sb, inputString, offset, s"$length - $startOfMask")} + |${ev.value} = UTF8String.fromString($sb.toString()); + |""".stripMargin + }) + } + + override def dataType: DataType = StringType + + override def inputTypes: Seq[AbstractDataType] = Seq(StringType) + + override def prettyName: String = "mask_last_n" +} + +/** + * Masks all but the first N chars of the input string. N defaults to 4. Additional parameters can + * be set to change the masking chars for uppercase letters, lowercase letters and digits. + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(str[, n[, upper[, lower[, digit]]]]) - Masks all but the first n values of str. 
By default, n is 4, upper case letters are converted to \"X\", lower case letters are converted to \"x\" and numbers are converted to \"n\". You can override the characters used in the mask by supplying additional arguments: the second argument controls the mask character for upper case letters, the third argument for lower case letters and the fourth argument for numbers.", + examples = """ + Examples: + > SELECT _FUNC_("1234-5678-8765-4321", 4); + 1234-nnnn-nnnn-nnnn + """, since = "2.4.0") +// scalastyle:on line.size.limit +case class MaskShowFirstN( + child: Expression, + n: Int, + upper: String, + lower: String, + digit: String) + extends UnaryExpression with ExpectsInputTypes with MaskLikeWithN { + + def this(child: Expression) = + this(child, defaultCharCount, null, null, null) + + def this(child: Expression, n: Expression) = + this(child, extractCharCount(n), null, null, null) + + def this(child: Expression, n: Expression, upper: Expression) = + this(child, extractCharCount(n), extractReplacement(upper), null, null) + + def this(child: Expression, n: Expression, upper: Expression, lower: Expression) = + this(child, extractCharCount(n), extractReplacement(upper), extractReplacement(lower), null) + + def this( + child: Expression, + n: Expression, + upper: Expression, + lower: Expression, + digit: Expression) = + this(child, + extractCharCount(n), + extractReplacement(upper), + extractReplacement(lower), + extractReplacement(digit)) + + override def nullSafeEval(input: Any): Any = { + val str = input.asInstanceOf[UTF8String].toString + val length = str.codePointCount(0, str.length()) + val startOfMask = if (charCount > length) length else charCount + val sb = new java.lang.StringBuilder(length) + val offset = appendUnchangedToStringBuilder(sb, str, 0, startOfMask) + appendMaskedToStringBuilder(sb, str, offset, length - startOfMask) + UTF8String.fromString(sb.toString) + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, (input: String) => { + val sb = ctx.freshName("sb") + val length = ctx.freshName("length") + val offset = ctx.freshName("offset") + val inputString = ctx.freshName("inputString") + val startOfMask = ctx.freshName("startOfMask") + s""" + |String $inputString = $input.toString(); + |${inputStringLengthCode(inputString, length)} + |${CodeGenerator.JAVA_INT} $startOfMask = $charCount > $length ? $length : $charCount; + |${CodeGenerator.JAVA_INT} $offset = 0; + |StringBuilder $sb = new StringBuilder($length); + |${appendUnchangedToStringBuilderCode(ctx, sb, inputString, offset, startOfMask)} + |${appendMaskedToStringBuilderCode( + ctx, sb, inputString, offset, s"$length - $startOfMask")} + |${ev.value} = UTF8String.fromString($sb.toString()); + |""".stripMargin + }) + } + + override def dataType: DataType = StringType + + override def inputTypes: Seq[AbstractDataType] = Seq(StringType) + + override def prettyName: String = "mask_show_first_n" +} + +/** + * Masks all but the last N chars of the input string. N defaults to 4. Additional parameters can + * be set to change the masking chars for uppercase letters, lowercase letters and digits. + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(str[, n[, upper[, lower[, digit]]]]) - Masks all but the last n values of str. By default, n is 4, upper case letters are converted to \"X\", lower case letters are converted to \"x\" and numbers are converted to \"n\". 
You can override the characters used in the mask by supplying additional arguments: the second argument controls the mask character for upper case letters, the third argument for lower case letters and the fourth argument for numbers.", + examples = """ + Examples: + > SELECT _FUNC_("1234-5678-8765-4321", 4); + nnnn-nnnn-nnnn-4321 + """, since = "2.4.0") +// scalastyle:on line.size.limit +case class MaskShowLastN( + child: Expression, + n: Int, + upper: String, + lower: String, + digit: String) + extends UnaryExpression with ExpectsInputTypes with MaskLikeWithN { + + def this(child: Expression) = + this(child, defaultCharCount, null, null, null) + + def this(child: Expression, n: Expression) = + this(child, extractCharCount(n), null, null, null) + + def this(child: Expression, n: Expression, upper: Expression) = + this(child, extractCharCount(n), extractReplacement(upper), null, null) + + def this(child: Expression, n: Expression, upper: Expression, lower: Expression) = + this(child, extractCharCount(n), extractReplacement(upper), extractReplacement(lower), null) + + def this( + child: Expression, + n: Expression, + upper: Expression, + lower: Expression, + digit: Expression) = + this(child, + extractCharCount(n), + extractReplacement(upper), + extractReplacement(lower), + extractReplacement(digit)) + + override def nullSafeEval(input: Any): Any = { + val str = input.asInstanceOf[UTF8String].toString + val length = str.codePointCount(0, str.length()) + val endOfMask = if (charCount >= length) 0 else length - charCount + val sb = new java.lang.StringBuilder(length) + val offset = appendMaskedToStringBuilder(sb, str, 0, endOfMask) + appendUnchangedToStringBuilder(sb, str, offset, length - endOfMask) + UTF8String.fromString(sb.toString) + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, (input: String) => { + val sb = ctx.freshName("sb") + val length = ctx.freshName("length") + val offset = ctx.freshName("offset") + val inputString = ctx.freshName("inputString") + val endOfMask = ctx.freshName("endOfMask") + s""" + |String $inputString = $input.toString(); + |${inputStringLengthCode(inputString, length)} + |${CodeGenerator.JAVA_INT} $endOfMask = $charCount >= $length ? 0 : $length - $charCount; + |${CodeGenerator.JAVA_INT} $offset = 0; + |StringBuilder $sb = new StringBuilder($length); + |${appendMaskedToStringBuilderCode(ctx, sb, inputString, offset, endOfMask)} + |${appendUnchangedToStringBuilderCode( + ctx, sb, inputString, offset, s"$length - $endOfMask")} + |${ev.value} = UTF8String.fromString($sb.toString()); + |""".stripMargin + }) + } + + override def dataType: DataType = StringType + + override def inputTypes: Seq[AbstractDataType] = Seq(StringType) + + override def prettyName: String = "mask_show_last_n" +} + +/** + * Returns a hashed value based on str. + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(str) - Returns a hashed value based on str. 
The hash is consistent and can be used to join masked values together across tables.", + examples = """ + Examples: + > SELECT _FUNC_("abcd-EFGH-8765-4321"); + 60c713f5ec6912229d2060df1c322776 + """) +// scalastyle:on line.size.limit +case class MaskHash(child: Expression) + extends UnaryExpression with ExpectsInputTypes { + + override def nullSafeEval(input: Any): Any = { + UTF8String.fromString(DigestUtils.md5Hex(input.asInstanceOf[UTF8String].toString)) + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, (input: String) => { + val digestUtilsClass = classOf[DigestUtils].getName.stripSuffix("$") + s""" + |${ev.value} = UTF8String.fromString($digestUtilsClass.md5Hex($input.toString())); + |""".stripMargin + }) + } + + override def dataType: DataType = StringType + + override def inputTypes: Seq[AbstractDataType] = Seq(StringType) + + override def prettyName: String = "mask_hash" +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 9823b2fc5ad97..bedad7da334ae 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -1916,12 +1916,15 @@ case class Encode(value: Expression, charset: Expression) usage = """ _FUNC_(expr1, expr2) - Formats the number `expr1` like '#,###,###.##', rounded to `expr2` decimal places. If `expr2` is 0, the result has no decimal point or fractional part. + `expr2` also accept a user specified format. This is supposed to function like MySQL's FORMAT. """, examples = """ Examples: > SELECT _FUNC_(12332.123456, 4); 12,332.1235 + > SELECT _FUNC_(12332.123456, '##################.###'); + 12332.123 """) case class FormatNumber(x: Expression, d: Expression) extends BinaryExpression with ExpectsInputTypes { @@ -1930,14 +1933,20 @@ case class FormatNumber(x: Expression, d: Expression) override def right: Expression = d override def dataType: DataType = StringType override def nullable: Boolean = true - override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, IntegerType) + override def inputTypes: Seq[AbstractDataType] = + Seq(NumericType, TypeCollection(IntegerType, StringType)) + + private val defaultFormat = "#,###,###,###,###,###,##0" // Associated with the pattern, for the last d value, and we will update the // pattern (DecimalFormat) once the new coming d value differ with the last one. // This is an Option to distinguish between 0 (numberFormat is valid) and uninitialized after // serialization (numberFormat has not been updated for dValue = 0). @transient - private var lastDValue: Option[Int] = None + private var lastDIntValue: Option[Int] = None + + @transient + private var lastDStringValue: Option[String] = None // A cached DecimalFormat, for performance concern, we will change it // only if the d value changed. 
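(Aside.) The FormatNumber change being prepared here lets the second argument be a format string as well as an integer. A brief usage sketch based on the examples added to the ExpressionDescription (`spark` is an assumed SparkSession):

    spark.sql("SELECT format_number(12332.123456, 4)").show()
    // 12,332.1235  -- integer d: existing behaviour, rounded and padded to d decimal places
    spark.sql("SELECT format_number(12332.123456, '##################.###')").show()
    // 12332.123    -- string d: the DecimalFormat pattern is applied as-is (MySQL FORMAT-like)
    // An empty string pattern falls back to the default '#,###,###,###,###,###,##0' format.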
@@ -1950,33 +1959,49 @@ case class FormatNumber(x: Expression, d: Expression) private lazy val numberFormat = new DecimalFormat("", new DecimalFormatSymbols(Locale.US)) override protected def nullSafeEval(xObject: Any, dObject: Any): Any = { - val dValue = dObject.asInstanceOf[Int] - if (dValue < 0) { - return null - } - - lastDValue match { - case Some(last) if last == dValue => - // use the current pattern - case _ => - // construct a new DecimalFormat only if a new dValue - pattern.delete(0, pattern.length) - pattern.append("#,###,###,###,###,###,##0") - - // decimal place - if (dValue > 0) { - pattern.append(".") - - var i = 0 - while (i < dValue) { - i += 1 - pattern.append("0") - } + right.dataType match { + case IntegerType => + val dValue = dObject.asInstanceOf[Int] + if (dValue < 0) { + return null } - lastDValue = Some(dValue) + lastDIntValue match { + case Some(last) if last == dValue => + // use the current pattern + case _ => + // construct a new DecimalFormat only if a new dValue + pattern.delete(0, pattern.length) + pattern.append(defaultFormat) + + // decimal place + if (dValue > 0) { + pattern.append(".") + + var i = 0 + while (i < dValue) { + i += 1 + pattern.append("0") + } + } + + lastDIntValue = Some(dValue) - numberFormat.applyLocalizedPattern(pattern.toString) + numberFormat.applyLocalizedPattern(pattern.toString) + } + case StringType => + val dValue = dObject.asInstanceOf[UTF8String].toString + lastDStringValue match { + case Some(last) if last == dValue => + case _ => + pattern.delete(0, pattern.length) + lastDStringValue = Some(dValue) + if (dValue.isEmpty) { + numberFormat.applyLocalizedPattern(defaultFormat) + } else { + numberFormat.applyLocalizedPattern(dValue) + } + } } x.dataType match { @@ -2008,35 +2033,52 @@ case class FormatNumber(x: Expression, d: Expression) // SPARK-13515: US Locale configures the DecimalFormat object to use a dot ('.') // as a decimal separator. 
val usLocale = "US" - val i = ctx.freshName("i") - val dFormat = ctx.freshName("dFormat") - val lastDValue = - ctx.addMutableState(CodeGenerator.JAVA_INT, "lastDValue", v => s"$v = -100;") - val pattern = ctx.addMutableState(sb, "pattern", v => s"$v = new $sb();") val numberFormat = ctx.addMutableState(df, "numberFormat", v => s"""$v = new $df("", new $dfs($l.$usLocale));""") - s""" - if ($d >= 0) { - $pattern.delete(0, $pattern.length()); - if ($d != $lastDValue) { - $pattern.append("#,###,###,###,###,###,##0"); - - if ($d > 0) { - $pattern.append("."); - for (int $i = 0; $i < $d; $i++) { - $pattern.append("0"); + right.dataType match { + case IntegerType => + val pattern = ctx.addMutableState(sb, "pattern", v => s"$v = new $sb();") + val i = ctx.freshName("i") + val lastDValue = + ctx.addMutableState(CodeGenerator.JAVA_INT, "lastDValue", v => s"$v = -100;") + s""" + if ($d >= 0) { + $pattern.delete(0, $pattern.length()); + if ($d != $lastDValue) { + $pattern.append("$defaultFormat"); + + if ($d > 0) { + $pattern.append("."); + for (int $i = 0; $i < $d; $i++) { + $pattern.append("0"); + } + } + $lastDValue = $d; + $numberFormat.applyLocalizedPattern($pattern.toString()); } + ${ev.value} = UTF8String.fromString($numberFormat.format(${typeHelper(num)})); + } else { + ${ev.value} = null; + ${ev.isNull} = true; } - $lastDValue = $d; - $numberFormat.applyLocalizedPattern($pattern.toString()); - } - ${ev.value} = UTF8String.fromString($numberFormat.format(${typeHelper(num)})); - } else { - ${ev.value} = null; - ${ev.isNull} = true; - } - """ + """ + case StringType => + val lastDValue = ctx.addMutableState("String", "lastDValue", v => s"""$v = null;""") + val dValue = ctx.freshName("dValue") + s""" + String $dValue = $d.toString(); + if (!$dValue.equals($lastDValue)) { + $lastDValue = $dValue; + if ($dValue.isEmpty()) { + $numberFormat.applyLocalizedPattern("$defaultFormat"); + } else { + $numberFormat.applyLocalizedPattern($dValue); + } + } + ${ev.value} = UTF8String.fromString($numberFormat.format(${typeHelper(num)})); + """ + } }) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 1c0b7bd806801..1d363b8146e3f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -21,7 +21,6 @@ import scala.collection.immutable.HashSet import scala.collection.mutable.{ArrayBuffer, Stack} import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.analysis.TypeCoercion.ImplicitTypeCasts import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} import org.apache.spark.sql.catalyst.expressions.aggregate._ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index b9ece295c2510..383ebde3229d6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -1206,6 +1206,34 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging new StringLocate(expression(ctx.substr), expression(ctx.str)) } + /** + * Create a Extract expression. 
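(Aside.) Together with the EXTRACT rule added to SqlBase.g4 above, the visitor below makes the standard EXTRACT syntax parse into the existing datetime expressions. A hedged sketch (`spark` is an assumed SparkSession, literals chosen only for illustration):

    spark.sql("SELECT EXTRACT(YEAR FROM DATE '2018-05-25')").show()                  // 2018, via Year(...)
    spark.sql("SELECT EXTRACT(MINUTE FROM TIMESTAMP '2018-05-25 12:34:56')").show()  // 34, via Minute(...)
    // Supported fields: YEAR, QUARTER, MONTH, WEEK, DAY, DAYOFWEEK, HOUR, MINUTE, SECOND;
    // any other field name raises a ParseException.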
+ */ + override def visitExtract(ctx: ExtractContext): Expression = withOrigin(ctx) { + ctx.field.getText.toUpperCase(Locale.ROOT) match { + case "YEAR" => + Year(expression(ctx.source)) + case "QUARTER" => + Quarter(expression(ctx.source)) + case "MONTH" => + Month(expression(ctx.source)) + case "WEEK" => + WeekOfYear(expression(ctx.source)) + case "DAY" => + DayOfMonth(expression(ctx.source)) + case "DAYOFWEEK" => + DayOfWeek(expression(ctx.source)) + case "HOUR" => + Hour(expression(ctx.source)) + case "MINUTE" => + Minute(expression(ctx.source)) + case "SECOND" => + Second(expression(ctx.source)) + case other => + throw new ParseException(s"Literals of type '$other' are currently not supported.", ctx) + } + } + /** * Create a (windowed) Function expression. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index e487693927ab6..c486ad700f362 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -78,7 +78,7 @@ abstract class LogicalPlan schema.map { field => resolve(field.name :: Nil, resolver).map { case a: AttributeReference => a - case other => sys.error(s"can not handle nested schema yet... plan $this") + case _ => sys.error(s"can not handle nested schema yet... plan $this") }.getOrElse { throw new AnalysisException( s"Unable to resolve ${field.name} given [${output.map(_.name).mkString(", ")}]") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala new file mode 100644 index 0000000000000..28e6940f3cca3 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.plans.logical.Range + +class CanonicalizeSuite extends SparkFunSuite { + + test("SPARK-24276: IN expression with different order are semantically equal") { + val range = Range(1, 1, 1, 1) + val idAttr = range.output.head + + val in1 = In(idAttr, Seq(Literal(1), Literal(2))) + val in2 = In(idAttr, Seq(Literal(2), Literal(1))) + val in3 = In(idAttr, Seq(Literal(1), Literal(2), Literal(3))) + + assert(in1.canonicalized.semanticHash() == in2.canonicalized.semanticHash()) + assert(in1.canonicalized.semanticHash() != in3.canonicalized.semanticHash()) + + assert(range.where(in1).sameResult(range.where(in2))) + assert(!range.where(in1).sameResult(range.where(in3))) + + val arrays1 = In(idAttr, Seq(CreateArray(Seq(Literal(1), Literal(2))), + CreateArray(Seq(Literal(2), Literal(1))))) + val arrays2 = In(idAttr, Seq(CreateArray(Seq(Literal(2), Literal(1))), + CreateArray(Seq(Literal(1), Literal(2))))) + val arrays3 = In(idAttr, Seq(CreateArray(Seq(Literal(1), Literal(2))), + CreateArray(Seq(Literal(3), Literal(1))))) + + assert(arrays1.canonicalized.semanticHash() == arrays2.canonicalized.semanticHash()) + assert(arrays1.canonicalized.semanticHash() != arrays3.canonicalized.semanticHash()) + + assert(range.where(arrays1).sameResult(range.where(arrays2))) + assert(!range.where(arrays1).sameResult(range.where(arrays3))) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 3fc0b08c56e02..f8ad624ce0e3d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -622,4 +622,62 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayRepeat(strArray, Literal(2)), Seq(Seq("hi", "hola"), Seq("hi", "hola"))) checkEvaluation(ArrayRepeat(Literal("hi"), Literal(null, IntegerType)), null) } + + test("Array remove") { + val a0 = Literal.create(Seq(1, 2, 3, 2, 2, 5), ArrayType(IntegerType)) + val a1 = Literal.create(Seq("b", "a", "a", "c", "b"), ArrayType(StringType)) + val a2 = Literal.create(Seq[String](null, "", null, ""), ArrayType(StringType)) + val a3 = Literal.create(Seq.empty[Integer], ArrayType(IntegerType)) + val a4 = Literal.create(null, ArrayType(StringType)) + val a5 = Literal.create(Seq(1, null, 8, 9, null), ArrayType(IntegerType)) + val a6 = Literal.create(Seq(true, false, false, true), ArrayType(BooleanType)) + + checkEvaluation(ArrayRemove(a0, Literal(0)), Seq(1, 2, 3, 2, 2, 5)) + checkEvaluation(ArrayRemove(a0, Literal(1)), Seq(2, 3, 2, 2, 5)) + checkEvaluation(ArrayRemove(a0, Literal(2)), Seq(1, 3, 5)) + checkEvaluation(ArrayRemove(a0, Literal(3)), Seq(1, 2, 2, 2, 5)) + checkEvaluation(ArrayRemove(a0, Literal(5)), Seq(1, 2, 3, 2, 2)) + checkEvaluation(ArrayRemove(a0, Literal(null, IntegerType)), null) + + checkEvaluation(ArrayRemove(a1, Literal("")), Seq("b", "a", "a", "c", "b")) + checkEvaluation(ArrayRemove(a1, Literal("a")), Seq("b", "c", "b")) + checkEvaluation(ArrayRemove(a1, Literal("b")), Seq("a", "a", "c")) + checkEvaluation(ArrayRemove(a1, Literal("c")), Seq("b", "a", "a", "b")) + + checkEvaluation(ArrayRemove(a2, 
Literal("")), Seq(null, null)) + checkEvaluation(ArrayRemove(a2, Literal(null, StringType)), null) + + checkEvaluation(ArrayRemove(a3, Literal(1)), Seq.empty[Integer]) + + checkEvaluation(ArrayRemove(a4, Literal("a")), null) + + checkEvaluation(ArrayRemove(a5, Literal(9)), Seq(1, null, 8, null)) + checkEvaluation(ArrayRemove(a6, Literal(false)), Seq(true, true)) + + // complex data types + val b0 = Literal.create(Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](1, 2), + Array[Byte](1, 2), Array[Byte](5, 6)), ArrayType(BinaryType)) + val b1 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), null), + ArrayType(BinaryType)) + val b2 = Literal.create(Seq[Array[Byte]](null, Array[Byte](1, 2)), + ArrayType(BinaryType)) + val nullBinary = Literal.create(null, BinaryType) + + val dataToRemove1 = Literal.create(Array[Byte](5, 6), BinaryType) + checkEvaluation(ArrayRemove(b0, dataToRemove1), + Seq[Array[Byte]](Array[Byte](1, 2), Array[Byte](1, 2))) + checkEvaluation(ArrayRemove(b0, nullBinary), null) + checkEvaluation(ArrayRemove(b1, dataToRemove1), Seq[Array[Byte]](Array[Byte](2, 1), null)) + checkEvaluation(ArrayRemove(b2, dataToRemove1), Seq[Array[Byte]](null, Array[Byte](1, 2))) + + val c0 = Literal.create(Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4)), + ArrayType(ArrayType(IntegerType))) + val c1 = Literal.create(Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1)), + ArrayType(ArrayType(IntegerType))) + val c2 = Literal.create(Seq[Seq[Int]](null, Seq[Int](2, 1)), ArrayType(ArrayType(IntegerType))) + val dataToRemove2 = Literal.create(Seq[Int](1, 2), ArrayType(IntegerType)) + checkEvaluation(ArrayRemove(c0, dataToRemove2), Seq[Seq[Int]](Seq[Int](3, 4))) + checkEvaluation(ArrayRemove(c1, dataToRemove2), Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1))) + checkEvaluation(ArrayRemove(c2, dataToRemove2), Seq[Seq[Int]](null, Seq[Int](2, 1))) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MaskExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MaskExpressionsSuite.scala new file mode 100644 index 0000000000000..4d69dc32ace82 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MaskExpressionsSuite.scala @@ -0,0 +1,236 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.types.{IntegerType, StringType} + +class MaskExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { + + test("mask") { + checkEvaluation(Mask(Literal("abcd-EFGH-8765-4321"), "U", "l", "#"), "llll-UUUU-####-####") + checkEvaluation( + new Mask(Literal("abcd-EFGH-8765-4321"), Literal("U"), Literal("l"), Literal("#")), + "llll-UUUU-####-####") + checkEvaluation(new Mask(Literal("abcd-EFGH-8765-4321"), Literal("U"), Literal("l")), + "llll-UUUU-nnnn-nnnn") + checkEvaluation(new Mask(Literal("abcd-EFGH-8765-4321"), Literal("U")), "xxxx-UUUU-nnnn-nnnn") + checkEvaluation(new Mask(Literal("abcd-EFGH-8765-4321")), "xxxx-XXXX-nnnn-nnnn") + checkEvaluation(new Mask(Literal(null, StringType)), null) + checkEvaluation(Mask(Literal("abcd-EFGH-8765-4321"), null, "l", "#"), "llll-XXXX-####-####") + checkEvaluation(new Mask( + Literal("abcd-EFGH-8765-4321"), + Literal(null, StringType), + Literal(null, StringType), + Literal(null, StringType)), "xxxx-XXXX-nnnn-nnnn") + checkEvaluation(new Mask(Literal("abcd-EFGH-8765-4321"), Literal("Upper")), + "xxxx-UUUU-nnnn-nnnn") + checkEvaluation(new Mask(Literal("")), "") + checkEvaluation(new Mask(Literal("abcd-EFGH-8765-4321"), Literal("")), "xxxx-XXXX-nnnn-nnnn") + checkEvaluation(Mask(Literal("abcd-EFGH-8765-4321"), "", "", ""), "xxxx-XXXX-nnnn-nnnn") + // scalastyle:off nonascii + checkEvaluation(Mask(Literal("Ul9U"), "\u2200", null, null), "\u2200xn\u2200") + checkEvaluation(new Mask(Literal("Hello World, こんにちは, 𠀋"), Literal("あ"), Literal("𡈽")), + "あ𡈽𡈽𡈽𡈽 あ𡈽𡈽𡈽𡈽, こんにちは, 𠀋") + // scalastyle:on nonascii + intercept[AnalysisException] { + checkEvaluation(new Mask(Literal(""), Literal(1)), "") + } + } + + test("mask_first_n") { + checkEvaluation(MaskFirstN(Literal("aB3d-EFGH-8765"), 6, "U", "l", "#"), + "lU#l-UFGH-8765") + checkEvaluation(new MaskFirstN( + Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("U"), Literal("l"), Literal("#")), + "llll-UFGH-8765-4321") + checkEvaluation( + new MaskFirstN(Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("U"), Literal("l")), + "llll-UFGH-8765-4321") + checkEvaluation(new MaskFirstN(Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("U")), + "xxxx-UFGH-8765-4321") + checkEvaluation(new MaskFirstN(Literal("abcd-EFGH-8765-4321"), Literal(6)), + "xxxx-XFGH-8765-4321") + intercept[AnalysisException] { + checkEvaluation(new MaskFirstN(Literal("abcd-EFGH-8765-4321"), Literal("U")), "") + } + checkEvaluation(new MaskFirstN(Literal("abcd-EFGH-8765-4321")), "xxxx-EFGH-8765-4321") + checkEvaluation(new MaskFirstN(Literal(null, StringType)), null) + checkEvaluation(MaskFirstN(Literal("abcd-EFGH-8765-4321"), 4, "U", "l", null), + "llll-EFGH-8765-4321") + checkEvaluation(new MaskFirstN( + Literal("abcd-EFGH-8765-4321"), + Literal(null, IntegerType), + Literal(null, StringType), + Literal(null, StringType), + Literal(null, StringType)), "xxxx-EFGH-8765-4321") + checkEvaluation(new MaskFirstN(Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("Upper")), + "xxxx-UFGH-8765-4321") + checkEvaluation(new MaskFirstN(Literal("")), "") + checkEvaluation(new MaskFirstN(Literal("abcd-EFGH-8765-4321"), Literal(4), Literal("")), + "xxxx-EFGH-8765-4321") + checkEvaluation(MaskFirstN(Literal("abcd-EFGH-8765-4321"), 1000, "", "", ""), + "xxxx-XXXX-nnnn-nnnn") + checkEvaluation(MaskFirstN(Literal("abcd-EFGH-8765-4321"), -1, "", "", ""), + "abcd-EFGH-8765-4321") + 
// scalastyle:off nonascii + checkEvaluation(MaskFirstN(Literal("Ul9U"), 2, "\u2200", null, null), "\u2200x9U") + checkEvaluation(new MaskFirstN(Literal("あ, 𠀋, Hello World"), Literal(10)), + "あ, 𠀋, Xxxxo World") + // scalastyle:on nonascii + } + + test("mask_last_n") { + checkEvaluation(MaskLastN(Literal("abcd-EFGH-aB3d"), 6, "U", "l", "#"), + "abcd-EFGU-lU#l") + checkEvaluation(new MaskLastN( + Literal("abcd-EFGH-8765"), Literal(6), Literal("U"), Literal("l"), Literal("#")), + "abcd-EFGU-####") + checkEvaluation( + new MaskLastN(Literal("abcd-EFGH-8765"), Literal(6), Literal("U"), Literal("l")), + "abcd-EFGU-nnnn") + checkEvaluation( + new MaskLastN(Literal("abcd-EFGH-8765"), Literal(6), Literal("U")), + "abcd-EFGU-nnnn") + checkEvaluation( + new MaskLastN(Literal("abcd-EFGH-8765"), Literal(6)), + "abcd-EFGX-nnnn") + intercept[AnalysisException] { + checkEvaluation(new MaskLastN(Literal("abcd-EFGH-8765"), Literal("U")), "") + } + checkEvaluation(new MaskLastN(Literal("abcd-EFGH-8765-4321")), "abcd-EFGH-8765-nnnn") + checkEvaluation(new MaskLastN(Literal(null, StringType)), null) + checkEvaluation(MaskLastN(Literal("abcd-EFGH-8765-4321"), 4, "U", "l", null), + "abcd-EFGH-8765-nnnn") + checkEvaluation(new MaskLastN( + Literal("abcd-EFGH-8765-4321"), + Literal(null, IntegerType), + Literal(null, StringType), + Literal(null, StringType), + Literal(null, StringType)), "abcd-EFGH-8765-nnnn") + checkEvaluation(new MaskLastN(Literal("abcd-EFGH-8765-4321"), Literal(12), Literal("Upper")), + "abcd-EFUU-nnnn-nnnn") + checkEvaluation(new MaskLastN(Literal("")), "") + checkEvaluation(new MaskLastN(Literal("abcd-EFGH-8765-4321"), Literal(16), Literal("")), + "abcx-XXXX-nnnn-nnnn") + checkEvaluation(MaskLastN(Literal("abcd-EFGH-8765-4321"), 1000, "", "", ""), + "xxxx-XXXX-nnnn-nnnn") + checkEvaluation(MaskLastN(Literal("abcd-EFGH-8765-4321"), -1, "", "", ""), + "abcd-EFGH-8765-4321") + // scalastyle:off nonascii + checkEvaluation(MaskLastN(Literal("Ul9U"), 2, "\u2200", null, null), "Uln\u2200") + checkEvaluation(new MaskLastN(Literal("あ, 𠀋, Hello World あ 𠀋"), Literal(10)), + "あ, 𠀋, Hello Xxxxx あ 𠀋") + // scalastyle:on nonascii + } + + test("mask_show_first_n") { + checkEvaluation(MaskShowFirstN(Literal("abcd-EFGH-8765-aB3d"), 6, "U", "l", "#"), + "abcd-EUUU-####-lU#l") + checkEvaluation(new MaskShowFirstN( + Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("U"), Literal("l"), Literal("#")), + "abcd-EUUU-####-####") + checkEvaluation( + new MaskShowFirstN(Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("U"), Literal("l")), + "abcd-EUUU-nnnn-nnnn") + checkEvaluation(new MaskShowFirstN(Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("U")), + "abcd-EUUU-nnnn-nnnn") + checkEvaluation(new MaskShowFirstN(Literal("abcd-EFGH-8765-4321"), Literal(6)), + "abcd-EXXX-nnnn-nnnn") + intercept[AnalysisException] { + checkEvaluation(new MaskShowFirstN(Literal("abcd-EFGH-8765-4321"), Literal("U")), "") + } + checkEvaluation(new MaskShowFirstN(Literal("abcd-EFGH-8765-4321")), "abcd-XXXX-nnnn-nnnn") + checkEvaluation(new MaskShowFirstN(Literal(null, StringType)), null) + checkEvaluation(MaskShowFirstN(Literal("abcd-EFGH-8765-4321"), 4, "U", "l", null), + "abcd-UUUU-nnnn-nnnn") + checkEvaluation(new MaskShowFirstN( + Literal("abcd-EFGH-8765-4321"), + Literal(null, IntegerType), + Literal(null, StringType), + Literal(null, StringType), + Literal(null, StringType)), "abcd-XXXX-nnnn-nnnn") + checkEvaluation( + new MaskShowFirstN(Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("Upper")), + "abcd-EUUU-nnnn-nnnn") 
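+ // Default replacements used throughout these tests: upper-case letters become "X" (or the
+ // given upper replacement), lower-case letters become "x", digits become "n", and all other
+ // characters are left unchanged.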
+ checkEvaluation(new MaskShowFirstN(Literal("")), "") + checkEvaluation(new MaskShowFirstN(Literal("abcd-EFGH-8765-4321"), Literal(4), Literal("")), + "abcd-XXXX-nnnn-nnnn") + checkEvaluation(MaskShowFirstN(Literal("abcd-EFGH-8765-4321"), 1000, "", "", ""), + "abcd-EFGH-8765-4321") + checkEvaluation(MaskShowFirstN(Literal("abcd-EFGH-8765-4321"), -1, "", "", ""), + "xxxx-XXXX-nnnn-nnnn") + // scalastyle:off nonascii + checkEvaluation(MaskShowFirstN(Literal("Ul9U"), 2, "\u2200", null, null), "Uln\u2200") + checkEvaluation(new MaskShowFirstN(Literal("あ, 𠀋, Hello World"), Literal(10)), + "あ, 𠀋, Hellx Xxxxx") + // scalastyle:on nonascii + } + + test("mask_show_last_n") { + checkEvaluation(MaskShowLastN(Literal("aB3d-EFGH-8765"), 6, "U", "l", "#"), + "lU#l-UUUH-8765") + checkEvaluation(new MaskShowLastN( + Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("U"), Literal("l"), Literal("#")), + "llll-UUUU-###5-4321") + checkEvaluation( + new MaskShowLastN(Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("U"), Literal("l")), + "llll-UUUU-nnn5-4321") + checkEvaluation(new MaskShowLastN(Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("U")), + "xxxx-UUUU-nnn5-4321") + checkEvaluation(new MaskShowLastN(Literal("abcd-EFGH-8765-4321"), Literal(6)), + "xxxx-XXXX-nnn5-4321") + intercept[AnalysisException] { + checkEvaluation(new MaskShowLastN(Literal("abcd-EFGH-8765-4321"), Literal("U")), "") + } + checkEvaluation(new MaskShowLastN(Literal("abcd-EFGH-8765-4321")), "xxxx-XXXX-nnnn-4321") + checkEvaluation(new MaskShowLastN(Literal(null, StringType)), null) + checkEvaluation(MaskShowLastN(Literal("abcd-EFGH-8765-4321"), 4, "U", "l", null), + "llll-UUUU-nnnn-4321") + checkEvaluation(new MaskShowLastN( + Literal("abcd-EFGH-8765-4321"), + Literal(null, IntegerType), + Literal(null, StringType), + Literal(null, StringType), + Literal(null, StringType)), "xxxx-XXXX-nnnn-4321") + checkEvaluation(new MaskShowLastN(Literal("abcd-EFGH-8765-4321"), Literal(6), Literal("Upper")), + "xxxx-UUUU-nnn5-4321") + checkEvaluation(new MaskShowLastN(Literal("")), "") + checkEvaluation(new MaskShowLastN(Literal("abcd-EFGH-8765-4321"), Literal(4), Literal("")), + "xxxx-XXXX-nnnn-4321") + checkEvaluation(MaskShowLastN(Literal("abcd-EFGH-8765-4321"), 1000, "", "", ""), + "abcd-EFGH-8765-4321") + checkEvaluation(MaskShowLastN(Literal("abcd-EFGH-8765-4321"), -1, "", "", ""), + "xxxx-XXXX-nnnn-nnnn") + // scalastyle:off nonascii + checkEvaluation(MaskShowLastN(Literal("Ul9U"), 2, "\u2200", null, null), "\u2200x9U") + checkEvaluation(new MaskShowLastN(Literal("あ, 𠀋, Hello World"), Literal(10)), + "あ, 𠀋, Xello World") + // scalastyle:on nonascii + } + + test("mask_hash") { + checkEvaluation(MaskHash(Literal("abcd-EFGH-8765-4321")), "60c713f5ec6912229d2060df1c322776") + checkEvaluation(MaskHash(Literal("")), "d41d8cd98f00b204e9800998ecf8427e") + checkEvaluation(MaskHash(Literal(null, StringType)), null) + // scalastyle:off nonascii + checkEvaluation(MaskHash(Literal("\u2200x9U")), "f1243ef123d516b1f32a3a75309e5711") + // scalastyle:on nonascii + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index f1a6f9b8889fa..aa334e040d5fc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -706,6 +706,30 @@ 
class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { "15,159,339,180,002,773.2778") checkEvaluation(FormatNumber(Literal.create(null, IntegerType), Literal(3)), null) assert(FormatNumber(Literal.create(null, NullType), Literal(3)).resolved === false) + + checkEvaluation(FormatNumber(Literal(12332.123456), Literal("##############.###")), "12332.123") + checkEvaluation(FormatNumber(Literal(12332.123456), Literal("##.###")), "12332.123") + checkEvaluation(FormatNumber(Literal(4.asInstanceOf[Byte]), Literal("##.####")), "4") + checkEvaluation(FormatNumber(Literal(4.asInstanceOf[Short]), Literal("##.####")), "4") + checkEvaluation(FormatNumber(Literal(4.0f), Literal("##.###")), "4") + checkEvaluation(FormatNumber(Literal(4), Literal("##.###")), "4") + checkEvaluation(FormatNumber(Literal(12831273.23481d), + Literal("###,###,###,###,###.###")), "12,831,273.235") + checkEvaluation(FormatNumber(Literal(12831273.83421d), Literal("")), "12,831,274") + checkEvaluation(FormatNumber(Literal(123123324123L), Literal("###,###,###,###,###.###")), + "123,123,324,123") + checkEvaluation( + FormatNumber(Literal(Decimal(123123324123L) * Decimal(123123.21234d)), + Literal("###,###,###,###,###.####")), "15,159,339,180,002,773.2778") + checkEvaluation(FormatNumber(Literal.create(null, IntegerType), Literal("##.###")), null) + assert(FormatNumber(Literal.create(null, NullType), Literal("##.###")).resolved === false) + + checkEvaluation(FormatNumber(Literal(12332.123456), Literal("#,###,###,###,###,###,##0")), + "12,332") + checkEvaluation(FormatNumber( + Literal.create(null, IntegerType), Literal.create(null, StringType)), null) + checkEvaluation(FormatNumber( + Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null) } test("find in set") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala index 89903c2825125..ff0de0fb7c1f0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala @@ -51,7 +51,7 @@ class TableIdentifierParserSuite extends SparkFunSuite { "rollup", "row", "rows", "set", "smallint", "table", "timestamp", "to", "trigger", "true", "truncate", "update", "user", "values", "with", "regexp", "rlike", "bigint", "binary", "boolean", "current_date", "current_timestamp", "date", "double", "float", - "int", "smallint", "timestamp", "at", "position", "both", "leading", "trailing") + "int", "smallint", "timestamp", "at", "position", "both", "leading", "trailing", "extract") val hiveStrictNonReservedKeyword = Seq("anti", "full", "inner", "left", "semi", "right", "natural", "union", "intersect", "except", "database", "on", "join", "cross", "select", "from", diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index ad0efbae89830..b3e59f53ee3de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql +import scala.collection.JavaConverters._ import scala.language.implicitConversions import org.apache.spark.annotation.InterfaceStability @@ -786,6 +787,24 @@ class Column(val expr: Expression) extends Logging { @scala.annotation.varargs def isin(list: Any*): 
Column = withExpr { In(expr, list.map(lit(_).expr)) } + /** + * A boolean expression that is evaluated to true if the value of this expression is contained + * by the provided collection. + * + * @group expr_ops + * @since 2.4.0 + */ + def isInCollection(values: scala.collection.Iterable[_]): Column = isin(values.toSeq: _*) + + /** + * A boolean expression that is evaluated to true if the value of this expression is contained + * by the provided collection. + * + * @group java_expr_ops + * @since 2.4.0 + */ + def isInCollection(values: java.lang.Iterable[_]): Column = isInCollection(values.asScala) + /** * SQL like expression. Returns a boolean column based on a SQL LIKE match. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index ac4580a0919ad..de6be5f76e15a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -22,6 +22,7 @@ import java.util.{Locale, Properties} import scala.collection.JavaConverters._ import com.fasterxml.jackson.databind.ObjectMapper +import com.univocity.parsers.csv.CsvParser import org.apache.spark.Partition import org.apache.spark.annotation.InterfaceStability @@ -474,6 +475,9 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * it determines the columns as string types and it reads only the first line to determine the * names and the number of fields. * + * If the enforceSchema is set to `false`, only the CSV header in the first line is checked + * to conform specified or inferred schema. + * * @param csvDataset input Dataset with one CSV row per record * @since 2.2.0 */ @@ -499,6 +503,13 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { StructType(schema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)) val linesWithoutHeader: RDD[String] = maybeFirstLine.map { firstLine => + CSVDataSource.checkHeader( + firstLine, + new CsvParser(parsedOptions.asParserSettings), + actualSchema, + csvDataset.getClass.getCanonicalName, + parsedOptions.enforceSchema, + sparkSession.sessionState.conf.caseSensitiveAnalysis) filteredLines.rdd.mapPartitions(CSVUtils.filterHeaderLine(_, firstLine, parsedOptions)) }.getOrElse(filteredLines.rdd) @@ -539,6 +550,13 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { *
  • `comment` (default empty string): sets a single character used for skipping lines * beginning with this character. By default, it is disabled.
  • *
  • `header` (default `false`): uses the first line as names of columns.
  • + *
  • `enforceSchema` (default `true`): If it is set to `true`, the specified or inferred schema + * will be forcibly applied to datasource files, and headers in CSV files will be ignored. + * If the option is set to `false`, the schema will be validated against all headers in CSV files + * when the `header` option is set to `true`. Field names in the schema and column names in CSV + * headers are checked by their positions, taking `spark.sql.caseSensitive` into account. + * Although the default value is `true`, it is recommended to disable the `enforceSchema` option + * to avoid incorrect results.
  • *
  • `inferSchema` (default `false`): infers the input schema automatically from data. It * requires one extra pass over the data.
  • *
  • `samplingRatio` (default is 1.0): defines the fraction of rows used for schema inference.
  • @@ -583,6 +601,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * created by `PERMISSIVE` mode. This overrides `spark.sql.columnNameOfCorruptRecord`. *
  • `multiLine` (default `false`): parse one record, which may span multiple lines.
  • * + * * @since 2.0.0 */ @scala.annotation.varargs diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index abb5ae53f4d73..f5526104690d2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -231,16 +231,17 @@ class Dataset[T] private[sql]( } /** - * Compose the string representing rows for output + * Get rows represented in Sequence by specific truncate and vertical requirement. * - * @param _numRows Number of rows to show + * @param numRows Number of rows to return * @param truncate If set to more than 0, truncates strings to `truncate` characters and * all cells will be aligned right. - * @param vertical If set to true, prints output rows vertically (one line per column value). + * @param vertical If set to true, the rows to return do not need truncate. */ - private[sql] def showString( - _numRows: Int, truncate: Int = 20, vertical: Boolean = false): String = { - val numRows = _numRows.max(0).min(Int.MaxValue - 1) + private[sql] def getRows( + numRows: Int, + truncate: Int, + vertical: Boolean): Seq[Seq[String]] = { val newDf = toDF() val castCols = newDf.logicalPlan.output.map { col => // Since binary types in top-level schema fields have a specific format to print, @@ -251,14 +252,12 @@ class Dataset[T] private[sql]( Column(col).cast(StringType) } } - val takeResult = newDf.select(castCols: _*).take(numRows + 1) - val hasMoreData = takeResult.length > numRows - val data = takeResult.take(numRows) + val data = newDf.select(castCols: _*).take(numRows + 1) // For array values, replace Seq and Array with square brackets // For cells that are beyond `truncate` characters, replace it with the // first `truncate-3` and "..." - val rows: Seq[Seq[String]] = schema.fieldNames.toSeq +: data.map { row => + schema.fieldNames.toSeq +: data.map { row => row.toSeq.map { cell => val str = cell match { case null => "null" @@ -274,6 +273,26 @@ class Dataset[T] private[sql]( } }: Seq[String] } + } + + /** + * Compose the string representing rows for output + * + * @param _numRows Number of rows to show + * @param truncate If set to more than 0, truncates strings to `truncate` characters and + * all cells will be aligned right. + * @param vertical If set to true, prints output rows vertically (one line per column value). + */ + private[sql] def showString( + _numRows: Int, + truncate: Int = 20, + vertical: Boolean = false): String = { + val numRows = _numRows.max(0).min(Int.MaxValue - 1) + // Get rows represented by Seq[Seq[String]], we may get one more line if it has more data. 
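+ // The extra row is only used to detect truncation: tmpRows holds the header plus at most
+ // numRows + 1 data rows, so hasMoreData below is true exactly when the query returned more
+ // than numRows rows.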
+ val tmpRows = getRows(numRows, truncate, vertical) + + val hasMoreData = tmpRows.length - 1 > numRows + val rows = tmpRows.take(numRows + 1) val sb = new StringBuilder val numCols = schema.fieldNames.length @@ -291,31 +310,25 @@ class Dataset[T] private[sql]( } } + val paddedRows = rows.map { row => + row.zipWithIndex.map { case (cell, i) => + if (truncate > 0) { + StringUtils.leftPad(cell, colWidths(i)) + } else { + StringUtils.rightPad(cell, colWidths(i)) + } + } + } + // Create SeparateLine val sep: String = colWidths.map("-" * _).addString(sb, "+", "+", "+\n").toString() // column names - rows.head.zipWithIndex.map { case (cell, i) => - if (truncate > 0) { - StringUtils.leftPad(cell, colWidths(i)) - } else { - StringUtils.rightPad(cell, colWidths(i)) - } - }.addString(sb, "|", "|", "|\n") - + paddedRows.head.addString(sb, "|", "|", "|\n") sb.append(sep) // data - rows.tail.foreach { - _.zipWithIndex.map { case (cell, i) => - if (truncate > 0) { - StringUtils.leftPad(cell.toString, colWidths(i)) - } else { - StringUtils.rightPad(cell.toString, colWidths(i)) - } - }.addString(sb, "|", "|", "|\n") - } - + paddedRows.tail.foreach(_.addString(sb, "|", "|", "|\n")) sb.append(sep) } else { // Extended display mode enabled @@ -346,7 +359,7 @@ class Dataset[T] private[sql]( } // Print a footer - if (vertical && data.isEmpty) { + if (vertical && rows.tail.isEmpty) { // In a vertical mode, print an empty row set explicitly sb.append("(0 rows)\n") } else if (hasMoreData) { @@ -3209,6 +3222,19 @@ class Dataset[T] private[sql]( } } + private[sql] def getRowsToPython( + _numRows: Int, + truncate: Int, + vertical: Boolean): Array[Any] = { + EvaluatePython.registerPicklers() + val numRows = _numRows.max(0).min(Int.MaxValue - 1) + val rows = getRows(numRows, truncate, vertical).map(_.toArray).toArray + val toJava: (Any) => Any = EvaluatePython.toJava(_, ArrayType(ArrayType(StringType))) + val iter: Iterator[Array[Byte]] = new SerDeUtil.AutoBatchedPickler( + rows.iterator.map(toJava)) + PythonRDD.serveIterator(iter, "serve-GetRows") + } + /** * Collect a Dataset as ArrowPayload byte arrays and serve to PySpark. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index b97a87a122406..be34387f6a874 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -384,9 +384,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { val (functionsWithDistinct, functionsWithoutDistinct) = aggregateExpressions.partition(_.isDistinct) - if (functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1) { + if (functionsWithDistinct.map(_.aggregateFunction.children.toSet).distinct.length > 1) { // This is a sanity check. We should not reach here when we have multiple distinct - // column sets. Our MultipleDistinctRewriter should take care this case. + // column sets. Our `RewriteDistinctAggregates` should take care this case. sys.error("You hit a query analyzer bug. 
Please report your query to " + "Spark user mailing list.") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala index 69c03d862391e..ba7d2b7cbdb1a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala @@ -31,7 +31,7 @@ import org.apache.spark.util.SerializableConfiguration /** - * Simple metrics collected during an instance of [[FileFormatWriter.ExecuteWriteTask]]. + * Simple metrics collected during an instance of [[FileFormatDataWriter]]. * These were first introduced in https://github.com/apache/spark/pull/18159 (SPARK-20703). */ case class BasicWriteTaskStats( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala new file mode 100644 index 0000000000000..6499328e89ce7 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala @@ -0,0 +1,313 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.datasources + +import scala.collection.mutable + +import org.apache.hadoop.fs.Path +import org.apache.hadoop.mapreduce.TaskAttemptContext + +import org.apache.spark.internal.io.FileCommitProtocol +import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec +import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types.StringType +import org.apache.spark.util.SerializableConfiguration + +/** + * Abstract class for writing out data in a single Spark task. + * Exceptions thrown by the implementation of this trait will automatically trigger task aborts. + */ +abstract class FileFormatDataWriter( + description: WriteJobDescription, + taskAttemptContext: TaskAttemptContext, + committer: FileCommitProtocol) { + /** + * Max number of files a single task writes out due to file size. In most cases the number of + * files written should be very small. This is just a safe guard to protect some really bad + * settings, e.g. maxRecordsPerFile = 1. 
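+ * With such a setting a task would otherwise create one file per record, so the writers assert
+ * that the per-task file counter stays below this limit (1,000,000 files).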
+ */ + protected val MAX_FILE_COUNTER: Int = 1000 * 1000 + protected val updatedPartitions: mutable.Set[String] = mutable.Set[String]() + protected var currentWriter: OutputWriter = _ + + /** Trackers for computing various statistics on the data as it's being written out. */ + protected val statsTrackers: Seq[WriteTaskStatsTracker] = + description.statsTrackers.map(_.newTaskInstance()) + + protected def releaseResources(): Unit = { + if (currentWriter != null) { + try { + currentWriter.close() + } finally { + currentWriter = null + } + } + } + + /** Writes a record */ + def write(record: InternalRow): Unit + + /** + * Returns the summary of relative information which + * includes the list of partition strings written out. The list of partitions is sent back + * to the driver and used to update the catalog. Other information will be sent back to the + * driver too and used to e.g. update the metrics in UI. + */ + def commit(): WriteTaskResult = { + releaseResources() + val summary = ExecutedWriteSummary( + updatedPartitions = updatedPartitions.toSet, + stats = statsTrackers.map(_.getFinalStats())) + WriteTaskResult(committer.commitTask(taskAttemptContext), summary) + } + + def abort(): Unit = { + try { + releaseResources() + } finally { + committer.abortTask(taskAttemptContext) + } + } +} + +/** FileFormatWriteTask for empty partitions */ +class EmptyDirectoryDataWriter( + description: WriteJobDescription, + taskAttemptContext: TaskAttemptContext, + committer: FileCommitProtocol +) extends FileFormatDataWriter(description, taskAttemptContext, committer) { + override def write(record: InternalRow): Unit = {} +} + +/** Writes data to a single directory (used for non-dynamic-partition writes). */ +class SingleDirectoryDataWriter( + description: WriteJobDescription, + taskAttemptContext: TaskAttemptContext, + committer: FileCommitProtocol) + extends FileFormatDataWriter(description, taskAttemptContext, committer) { + private var fileCounter: Int = _ + private var recordsInFile: Long = _ + // Initialize currentWriter and statsTrackers + newOutputWriter() + + private def newOutputWriter(): Unit = { + recordsInFile = 0 + releaseResources() + + val ext = description.outputWriterFactory.getFileExtension(taskAttemptContext) + val currentPath = committer.newTaskTempFile( + taskAttemptContext, + None, + f"-c$fileCounter%03d" + ext) + + currentWriter = description.outputWriterFactory.newInstance( + path = currentPath, + dataSchema = description.dataColumns.toStructType, + context = taskAttemptContext) + + statsTrackers.foreach(_.newFile(currentPath)) + } + + override def write(record: InternalRow): Unit = { + if (description.maxRecordsPerFile > 0 && recordsInFile >= description.maxRecordsPerFile) { + fileCounter += 1 + assert(fileCounter < MAX_FILE_COUNTER, + s"File counter $fileCounter is beyond max value $MAX_FILE_COUNTER") + + newOutputWriter() + } + + currentWriter.write(record) + statsTrackers.foreach(_.newRow(record)) + recordsInFile += 1 + } +} + +/** + * Writes data to using dynamic partition writes, meaning this single function can write to + * multiple directories (partitions) or files (bucketing). + */ +class DynamicPartitionDataWriter( + description: WriteJobDescription, + taskAttemptContext: TaskAttemptContext, + committer: FileCommitProtocol) + extends FileFormatDataWriter(description, taskAttemptContext, committer) { + + /** Flag saying whether or not the data to be written out is partitioned. 
*/ + private val isPartitioned = description.partitionColumns.nonEmpty + + /** Flag saying whether or not the data to be written out is bucketed. */ + private val isBucketed = description.bucketIdExpression.isDefined + + assert(isPartitioned || isBucketed, + s"""DynamicPartitionWriteTask should be used for writing out data that's either + |partitioned or bucketed. In this case neither is true. + |WriteJobDescription: $description + """.stripMargin) + + private var fileCounter: Int = _ + private var recordsInFile: Long = _ + private var currentPartionValues: Option[UnsafeRow] = None + private var currentBucketId: Option[Int] = None + + /** Extracts the partition values out of an input row. */ + private lazy val getPartitionValues: InternalRow => UnsafeRow = { + val proj = UnsafeProjection.create(description.partitionColumns, description.allColumns) + row => proj(row) + } + + /** Expression that given partition columns builds a path string like: col1=val/col2=val/... */ + private lazy val partitionPathExpression: Expression = Concat( + description.partitionColumns.zipWithIndex.flatMap { case (c, i) => + val partitionName = ScalaUDF( + ExternalCatalogUtils.getPartitionPathString _, + StringType, + Seq(Literal(c.name), Cast(c, StringType, Option(description.timeZoneId)))) + if (i == 0) Seq(partitionName) else Seq(Literal(Path.SEPARATOR), partitionName) + }) + + /** Evaluates the `partitionPathExpression` above on a row of `partitionValues` and returns + * the partition string. */ + private lazy val getPartitionPath: InternalRow => String = { + val proj = UnsafeProjection.create(Seq(partitionPathExpression), description.partitionColumns) + row => proj(row).getString(0) + } + + /** Given an input row, returns the corresponding `bucketId` */ + private lazy val getBucketId: InternalRow => Int = { + val proj = + UnsafeProjection.create(description.bucketIdExpression.toSeq, description.allColumns) + row => proj(row).getInt(0) + } + + /** Returns the data columns to be written given an input row */ + private val getOutputRow = + UnsafeProjection.create(description.dataColumns, description.allColumns) + + /** + * Opens a new OutputWriter given a partition key and/or a bucket id. + * If bucket id is specified, we will append it to the end of the file name, but before the + * file extension, e.g. part-r-00009-ea518ad4-455a-4431-b471-d24e03814677-00002.gz.parquet + * + * @param partitionValues the partition which all tuples being written by this `OutputWriter` + * belong to + * @param bucketId the bucket which all tuples being written by this `OutputWriter` belong to + */ + private def newOutputWriter(partitionValues: Option[InternalRow], bucketId: Option[Int]): Unit = { + recordsInFile = 0 + releaseResources() + + val partDir = partitionValues.map(getPartitionPath(_)) + partDir.foreach(updatedPartitions.add) + + val bucketIdStr = bucketId.map(BucketingUtils.bucketIdToString).getOrElse("") + + // This must be in a form that matches our bucketing format. See BucketingUtils. 
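+ // For example, with no bucketing and fileCounter = 0 this yields ".c000" followed by the
+ // format-specific file extension (e.g. ".snappy.parquet" for Parquet output).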
+ val ext = f"$bucketIdStr.c$fileCounter%03d" + + description.outputWriterFactory.getFileExtension(taskAttemptContext) + + val customPath = partDir.flatMap { dir => + description.customPartitionLocations.get(PartitioningUtils.parsePathFragment(dir)) + } + val currentPath = if (customPath.isDefined) { + committer.newTaskTempFileAbsPath(taskAttemptContext, customPath.get, ext) + } else { + committer.newTaskTempFile(taskAttemptContext, partDir, ext) + } + + currentWriter = description.outputWriterFactory.newInstance( + path = currentPath, + dataSchema = description.dataColumns.toStructType, + context = taskAttemptContext) + + statsTrackers.foreach(_.newFile(currentPath)) + } + + override def write(record: InternalRow): Unit = { + val nextPartitionValues = if (isPartitioned) Some(getPartitionValues(record)) else None + val nextBucketId = if (isBucketed) Some(getBucketId(record)) else None + + if (currentPartionValues != nextPartitionValues || currentBucketId != nextBucketId) { + // See a new partition or bucket - write to a new partition dir (or a new bucket file). + if (isPartitioned && currentPartionValues != nextPartitionValues) { + currentPartionValues = Some(nextPartitionValues.get.copy()) + statsTrackers.foreach(_.newPartition(currentPartionValues.get)) + } + if (isBucketed) { + currentBucketId = nextBucketId + statsTrackers.foreach(_.newBucket(currentBucketId.get)) + } + + fileCounter = 0 + newOutputWriter(currentPartionValues, currentBucketId) + } else if (description.maxRecordsPerFile > 0 && + recordsInFile >= description.maxRecordsPerFile) { + // Exceeded the threshold in terms of the number of records per file. + // Create a new file by increasing the file counter. + fileCounter += 1 + assert(fileCounter < MAX_FILE_COUNTER, + s"File counter $fileCounter is beyond max value $MAX_FILE_COUNTER") + + newOutputWriter(currentPartionValues, currentBucketId) + } + val outputRow = getOutputRow(record) + currentWriter.write(outputRow) + statsTrackers.foreach(_.newRow(outputRow)) + recordsInFile += 1 + } +} + +/** A shared job description for all the write tasks. */ +class WriteJobDescription( + val uuid: String, // prevent collision between different (appending) write jobs + val serializableHadoopConf: SerializableConfiguration, + val outputWriterFactory: OutputWriterFactory, + val allColumns: Seq[Attribute], + val dataColumns: Seq[Attribute], + val partitionColumns: Seq[Attribute], + val bucketIdExpression: Option[Expression], + val path: String, + val customPartitionLocations: Map[TablePartitionSpec, String], + val maxRecordsPerFile: Long, + val timeZoneId: String, + val statsTrackers: Seq[WriteJobStatsTracker]) + extends Serializable { + + assert(AttributeSet(allColumns) == AttributeSet(partitionColumns ++ dataColumns), + s""" + |All columns: ${allColumns.mkString(", ")} + |Partition columns: ${partitionColumns.mkString(", ")} + |Data columns: ${dataColumns.mkString(", ")} + """.stripMargin) +} + +/** The result of a successful write task. */ +case class WriteTaskResult(commitMsg: TaskCommitMessage, summary: ExecutedWriteSummary) + +/** + * Wrapper class for the metrics of writing data out. + * + * @param updatedPartitions the partitions updated during writing data out. Only valid + * for dynamic partition. + * @param stats one `WriteTaskStats` object for every `WriteJobStatsTracker` that the job had. 
+ */ +case class ExecutedWriteSummary( + updatedPartitions: Set[String], + stats: Seq[WriteTaskStats]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala index 401597f967218..52da8356ab835 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala @@ -19,8 +19,6 @@ package org.apache.spark.sql.execution.datasources import java.util.{Date, UUID} -import scala.collection.mutable - import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce._ @@ -30,62 +28,25 @@ import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl import org.apache.spark._ import org.apache.spark.internal.Logging import org.apache.spark.internal.io.{FileCommitProtocol, SparkHadoopWriterUtils} -import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.catalog.{BucketSpec, ExternalCatalogUtils} +import org.apache.spark.sql.catalyst.catalog.BucketSpec import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec -import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, _} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} import org.apache.spark.sql.execution.{SortExec, SparkPlan, SQLExecution} -import org.apache.spark.sql.types.StringType import org.apache.spark.util.{SerializableConfiguration, Utils} /** A helper object for writing FileFormat data out to a location. */ object FileFormatWriter extends Logging { - - /** - * Max number of files a single task writes out due to file size. In most cases the number of - * files written should be very small. This is just a safe guard to protect some really bad - * settings, e.g. maxRecordsPerFile = 1. - */ - private val MAX_FILE_COUNTER = 1000 * 1000 - /** Describes how output files should be placed in the filesystem. */ case class OutputSpec( - outputPath: String, - customPartitionLocations: Map[TablePartitionSpec, String], - outputColumns: Seq[Attribute]) - - /** A shared job description for all the write tasks. */ - private class WriteJobDescription( - val uuid: String, // prevent collision between different (appending) write jobs - val serializableHadoopConf: SerializableConfiguration, - val outputWriterFactory: OutputWriterFactory, - val allColumns: Seq[Attribute], - val dataColumns: Seq[Attribute], - val partitionColumns: Seq[Attribute], - val bucketIdExpression: Option[Expression], - val path: String, - val customPartitionLocations: Map[TablePartitionSpec, String], - val maxRecordsPerFile: Long, - val timeZoneId: String, - val statsTrackers: Seq[WriteJobStatsTracker]) - extends Serializable { - - assert(AttributeSet(allColumns) == AttributeSet(partitionColumns ++ dataColumns), - s""" - |All columns: ${allColumns.mkString(", ")} - |Partition columns: ${partitionColumns.mkString(", ")} - |Data columns: ${dataColumns.mkString(", ")} - """.stripMargin) - } - - /** The result of a successful write task. 
*/ - private case class WriteTaskResult(commitMsg: TaskCommitMessage, summary: ExecutedWriteSummary) + outputPath: String, + customPartitionLocations: Map[TablePartitionSpec, String], + outputColumns: Seq[Attribute]) /** * Basic work flow of this command is: @@ -262,30 +223,27 @@ object FileFormatWriter extends Logging { committer.setupTask(taskAttemptContext) - val writeTask = + val dataWriter = if (sparkPartitionId != 0 && !iterator.hasNext) { // In case of empty job, leave first partition to save meta for file format like parquet. - new EmptyDirectoryWriteTask(description) + new EmptyDirectoryDataWriter(description, taskAttemptContext, committer) } else if (description.partitionColumns.isEmpty && description.bucketIdExpression.isEmpty) { - new SingleDirectoryWriteTask(description, taskAttemptContext, committer) + new SingleDirectoryDataWriter(description, taskAttemptContext, committer) } else { - new DynamicPartitionWriteTask(description, taskAttemptContext, committer) + new DynamicPartitionDataWriter(description, taskAttemptContext, committer) } try { Utils.tryWithSafeFinallyAndFailureCallbacks(block = { // Execute the task to write rows out and commit the task. - val summary = writeTask.execute(iterator) - writeTask.releaseResources() - WriteTaskResult(committer.commitTask(taskAttemptContext), summary) - })(catchBlock = { - // If there is an error, release resource and then abort the task - try { - writeTask.releaseResources() - } finally { - committer.abortTask(taskAttemptContext) - logError(s"Job $jobId aborted.") + while (iterator.hasNext) { + dataWriter.write(iterator.next()) } + dataWriter.commit() + })(catchBlock = { + // If there is an error, abort the task + dataWriter.abort() + logError(s"Job $jobId aborted.") }) } catch { case e: FetchFailedException => @@ -302,7 +260,7 @@ object FileFormatWriter extends Logging { private def processStats( statsTrackers: Seq[WriteJobStatsTracker], statsPerTask: Seq[Seq[WriteTaskStats]]) - : Unit = { + : Unit = { val numStatsTrackers = statsTrackers.length assert(statsPerTask.forall(_.length == numStatsTrackers), @@ -321,281 +279,4 @@ object FileFormatWriter extends Logging { case (statsTracker, stats) => statsTracker.processStats(stats) } } - - /** - * A simple trait for writing out data in a single Spark task, without any concerns about how - * to commit or abort tasks. Exceptions thrown by the implementation of this trait will - * automatically trigger task aborts. - */ - private trait ExecuteWriteTask { - - /** - * Writes data out to files, and then returns the summary of relative information which - * includes the list of partition strings written out. The list of partitions is sent back - * to the driver and used to update the catalog. Other information will be sent back to the - * driver too and used to e.g. update the metrics in UI. - */ - def execute(iterator: Iterator[InternalRow]): ExecutedWriteSummary - def releaseResources(): Unit - } - - /** ExecuteWriteTask for empty partitions */ - private class EmptyDirectoryWriteTask(description: WriteJobDescription) - extends ExecuteWriteTask { - - val statsTrackers: Seq[WriteTaskStatsTracker] = - description.statsTrackers.map(_.newTaskInstance()) - - override def execute(iter: Iterator[InternalRow]): ExecutedWriteSummary = { - ExecutedWriteSummary( - updatedPartitions = Set.empty, - stats = statsTrackers.map(_.getFinalStats())) - } - - override def releaseResources(): Unit = {} - } - - /** Writes data to a single directory (used for non-dynamic-partition writes). 
*/ - private class SingleDirectoryWriteTask( - description: WriteJobDescription, - taskAttemptContext: TaskAttemptContext, - committer: FileCommitProtocol) extends ExecuteWriteTask { - - private[this] var currentWriter: OutputWriter = _ - - val statsTrackers: Seq[WriteTaskStatsTracker] = - description.statsTrackers.map(_.newTaskInstance()) - - private def newOutputWriter(fileCounter: Int): Unit = { - val ext = description.outputWriterFactory.getFileExtension(taskAttemptContext) - val currentPath = committer.newTaskTempFile( - taskAttemptContext, - None, - f"-c$fileCounter%03d" + ext) - - currentWriter = description.outputWriterFactory.newInstance( - path = currentPath, - dataSchema = description.dataColumns.toStructType, - context = taskAttemptContext) - - statsTrackers.map(_.newFile(currentPath)) - } - - override def execute(iter: Iterator[InternalRow]): ExecutedWriteSummary = { - var fileCounter = 0 - var recordsInFile: Long = 0L - newOutputWriter(fileCounter) - - while (iter.hasNext) { - if (description.maxRecordsPerFile > 0 && recordsInFile >= description.maxRecordsPerFile) { - fileCounter += 1 - assert(fileCounter < MAX_FILE_COUNTER, - s"File counter $fileCounter is beyond max value $MAX_FILE_COUNTER") - - recordsInFile = 0 - releaseResources() - newOutputWriter(fileCounter) - } - - val internalRow = iter.next() - currentWriter.write(internalRow) - statsTrackers.foreach(_.newRow(internalRow)) - recordsInFile += 1 - } - releaseResources() - ExecutedWriteSummary( - updatedPartitions = Set.empty, - stats = statsTrackers.map(_.getFinalStats())) - } - - override def releaseResources(): Unit = { - if (currentWriter != null) { - try { - currentWriter.close() - } finally { - currentWriter = null - } - } - } - } - - /** - * Writes data to using dynamic partition writes, meaning this single function can write to - * multiple directories (partitions) or files (bucketing). - */ - private class DynamicPartitionWriteTask( - desc: WriteJobDescription, - taskAttemptContext: TaskAttemptContext, - committer: FileCommitProtocol) extends ExecuteWriteTask { - - /** Flag saying whether or not the data to be written out is partitioned. */ - val isPartitioned = desc.partitionColumns.nonEmpty - - /** Flag saying whether or not the data to be written out is bucketed. */ - val isBucketed = desc.bucketIdExpression.isDefined - - assert(isPartitioned || isBucketed, - s"""DynamicPartitionWriteTask should be used for writing out data that's either - |partitioned or bucketed. In this case neither is true. - |WriteJobDescription: ${desc} - """.stripMargin) - - // currentWriter is initialized whenever we see a new key (partitionValues + BucketId) - private var currentWriter: OutputWriter = _ - - /** Trackers for computing various statistics on the data as it's being written out. */ - private val statsTrackers: Seq[WriteTaskStatsTracker] = - desc.statsTrackers.map(_.newTaskInstance()) - - /** Extracts the partition values out of an input row. */ - private lazy val getPartitionValues: InternalRow => UnsafeRow = { - val proj = UnsafeProjection.create(desc.partitionColumns, desc.allColumns) - row => proj(row) - } - - /** Expression that given partition columns builds a path string like: col1=val/col2=val/... 
*/ - private lazy val partitionPathExpression: Expression = Concat( - desc.partitionColumns.zipWithIndex.flatMap { case (c, i) => - val partitionName = ScalaUDF( - ExternalCatalogUtils.getPartitionPathString _, - StringType, - Seq(Literal(c.name), Cast(c, StringType, Option(desc.timeZoneId)))) - if (i == 0) Seq(partitionName) else Seq(Literal(Path.SEPARATOR), partitionName) - }) - - /** Evaluates the `partitionPathExpression` above on a row of `partitionValues` and returns - * the partition string. */ - private lazy val getPartitionPath: InternalRow => String = { - val proj = UnsafeProjection.create(Seq(partitionPathExpression), desc.partitionColumns) - row => proj(row).getString(0) - } - - /** Given an input row, returns the corresponding `bucketId` */ - private lazy val getBucketId: InternalRow => Int = { - val proj = UnsafeProjection.create(desc.bucketIdExpression.toSeq, desc.allColumns) - row => proj(row).getInt(0) - } - - /** Returns the data columns to be written given an input row */ - private val getOutputRow = UnsafeProjection.create(desc.dataColumns, desc.allColumns) - - /** - * Opens a new OutputWriter given a partition key and/or a bucket id. - * If bucket id is specified, we will append it to the end of the file name, but before the - * file extension, e.g. part-r-00009-ea518ad4-455a-4431-b471-d24e03814677-00002.gz.parquet - * - * @param partitionValues the partition which all tuples being written by this `OutputWriter` - * belong to - * @param bucketId the bucket which all tuples being written by this `OutputWriter` belong to - * @param fileCounter the number of files that have been written in the past for this specific - * partition. This is used to limit the max number of records written for a - * single file. The value should start from 0. - * @param updatedPartitions the set of updated partition paths, we should add the new partition - * path of this writer to it. - */ - private def newOutputWriter( - partitionValues: Option[InternalRow], - bucketId: Option[Int], - fileCounter: Int, - updatedPartitions: mutable.Set[String]): Unit = { - - val partDir = partitionValues.map(getPartitionPath(_)) - partDir.foreach(updatedPartitions.add) - - val bucketIdStr = bucketId.map(BucketingUtils.bucketIdToString).getOrElse("") - - // This must be in a form that matches our bucketing format. See BucketingUtils. - val ext = f"$bucketIdStr.c$fileCounter%03d" + - desc.outputWriterFactory.getFileExtension(taskAttemptContext) - - val customPath = partDir.flatMap { dir => - desc.customPartitionLocations.get(PartitioningUtils.parsePathFragment(dir)) - } - val currentPath = if (customPath.isDefined) { - committer.newTaskTempFileAbsPath(taskAttemptContext, customPath.get, ext) - } else { - committer.newTaskTempFile(taskAttemptContext, partDir, ext) - } - - currentWriter = desc.outputWriterFactory.newInstance( - path = currentPath, - dataSchema = desc.dataColumns.toStructType, - context = taskAttemptContext) - - statsTrackers.foreach(_.newFile(currentPath)) - } - - override def execute(iter: Iterator[InternalRow]): ExecutedWriteSummary = { - // If anything below fails, we should abort the task. 
- var recordsInFile: Long = 0L - var fileCounter = 0 - val updatedPartitions = mutable.Set[String]() - var currentPartionValues: Option[UnsafeRow] = None - var currentBucketId: Option[Int] = None - - for (row <- iter) { - val nextPartitionValues = if (isPartitioned) Some(getPartitionValues(row)) else None - val nextBucketId = if (isBucketed) Some(getBucketId(row)) else None - - if (currentPartionValues != nextPartitionValues || currentBucketId != nextBucketId) { - // See a new partition or bucket - write to a new partition dir (or a new bucket file). - if (isPartitioned && currentPartionValues != nextPartitionValues) { - currentPartionValues = Some(nextPartitionValues.get.copy()) - statsTrackers.foreach(_.newPartition(currentPartionValues.get)) - } - if (isBucketed) { - currentBucketId = nextBucketId - statsTrackers.foreach(_.newBucket(currentBucketId.get)) - } - - recordsInFile = 0 - fileCounter = 0 - - releaseResources() - newOutputWriter(currentPartionValues, currentBucketId, fileCounter, updatedPartitions) - } else if (desc.maxRecordsPerFile > 0 && - recordsInFile >= desc.maxRecordsPerFile) { - // Exceeded the threshold in terms of the number of records per file. - // Create a new file by increasing the file counter. - recordsInFile = 0 - fileCounter += 1 - assert(fileCounter < MAX_FILE_COUNTER, - s"File counter $fileCounter is beyond max value $MAX_FILE_COUNTER") - - releaseResources() - newOutputWriter(currentPartionValues, currentBucketId, fileCounter, updatedPartitions) - } - val outputRow = getOutputRow(row) - currentWriter.write(outputRow) - statsTrackers.foreach(_.newRow(outputRow)) - recordsInFile += 1 - } - releaseResources() - - ExecutedWriteSummary( - updatedPartitions = updatedPartitions.toSet, - stats = statsTrackers.map(_.getFinalStats())) - } - - override def releaseResources(): Unit = { - if (currentWriter != null) { - try { - currentWriter.close() - } finally { - currentWriter = null - } - } - } - } } - -/** - * Wrapper class for the metrics of writing data out. - * - * @param updatedPartitions the partitions updated during writing data out. Only valid - * for dynamic partition. - * @param stats one `WriteTaskStats` object for every `WriteJobStatsTracker` that the job had. - */ -case class ExecutedWriteSummary( - updatedPartitions: Set[String], - stats: Seq[WriteTaskStats]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala index dc54d182651b1..82322df407521 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala @@ -30,6 +30,7 @@ import org.apache.hadoop.mapreduce.lib.input.FileInputFormat import org.apache.spark.TaskContext import org.apache.spark.input.{PortableDataStream, StreamInputFormat} +import org.apache.spark.internal.Logging import org.apache.spark.rdd.{BinaryFileRDD, RDD} import org.apache.spark.sql.{Dataset, Encoders, SparkSession} import org.apache.spark.sql.catalyst.InternalRow @@ -50,7 +51,10 @@ abstract class CSVDataSource extends Serializable { conf: Configuration, file: PartitionedFile, parser: UnivocityParser, - schema: StructType): Iterator[InternalRow] + requiredSchema: StructType, + // Actual schema of data in the csv file + dataSchema: StructType, + caseSensitive: Boolean): Iterator[InternalRow] /** * Infers the schema from `inputPaths` files. 
@@ -110,7 +114,7 @@ abstract class CSVDataSource extends Serializable { } } -object CSVDataSource { +object CSVDataSource extends Logging { def apply(options: CSVOptions): CSVDataSource = { if (options.multiLine) { MultiLineCSVDataSource @@ -118,6 +122,84 @@ object CSVDataSource { TextInputCSVDataSource } } + + /** + * Checks that column names in a CSV header and field names in the schema are the same + * by taking into account case sensitivity. + * + * @param schema - provided (or inferred) schema to which CSV must conform. + * @param columnNames - names of CSV columns that must be checked against the schema. + * @param fileName - name of the CSV file that is currently checked. It is used in error messages. + * @param enforceSchema - if it is `true`, column names are ignored; otherwise the CSV column + * names are checked for conformance to the schema. If the column names + * don't conform to the schema, an exception is thrown. + * @param caseSensitive - if it is set to `false`, comparison of column names and schema field + * names is not case sensitive. + */ + def checkHeaderColumnNames( + schema: StructType, + columnNames: Array[String], + fileName: String, + enforceSchema: Boolean, + caseSensitive: Boolean): Unit = { + if (columnNames != null) { + val fieldNames = schema.map(_.name).toIndexedSeq + val (headerLen, schemaSize) = (columnNames.size, fieldNames.length) + var errorMessage: Option[String] = None + + if (headerLen == schemaSize) { + var i = 0 + while (errorMessage.isEmpty && i < headerLen) { + var (nameInSchema, nameInHeader) = (fieldNames(i), columnNames(i)) + if (!caseSensitive) { + nameInSchema = nameInSchema.toLowerCase + nameInHeader = nameInHeader.toLowerCase + } + if (nameInHeader != nameInSchema) { + errorMessage = Some( + s"""|CSV header does not conform to the schema. + | Header: ${columnNames.mkString(", ")} + | Schema: ${fieldNames.mkString(", ")} + |Expected: ${fieldNames(i)} but found: ${columnNames(i)} + |CSV file: $fileName""".stripMargin) + } + i += 1 + } + } else { + errorMessage = Some( + s"""|Number of column in CSV header is not equal to number of fields in the schema: + | Header length: $headerLen, schema size: $schemaSize + |CSV file: $fileName""".stripMargin) + } + + errorMessage.foreach { msg => + if (enforceSchema) { + logWarning(msg) + } else { + throw new IllegalArgumentException(msg) + } + } + } + } + + /** + * Checks that the CSV header contains the same column names as field names in the given schema + * by taking into account case sensitivity.
+ */ + def checkHeader( + header: String, + parser: CsvParser, + schema: StructType, + fileName: String, + enforceSchema: Boolean, + caseSensitive: Boolean): Unit = { + checkHeaderColumnNames( + schema, + parser.parseLine(header), + fileName, + enforceSchema, + caseSensitive) + } } object TextInputCSVDataSource extends CSVDataSource { @@ -127,7 +209,9 @@ object TextInputCSVDataSource extends CSVDataSource { conf: Configuration, file: PartitionedFile, parser: UnivocityParser, - schema: StructType): Iterator[InternalRow] = { + requiredSchema: StructType, + dataSchema: StructType, + caseSensitive: Boolean): Iterator[InternalRow] = { val lines = { val linesReader = new HadoopFileLinesReader(file, conf) Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close())) @@ -136,8 +220,24 @@ object TextInputCSVDataSource extends CSVDataSource { } } - val shouldDropHeader = parser.options.headerFlag && file.start == 0 - UnivocityParser.parseIterator(lines, shouldDropHeader, parser, schema) + val hasHeader = parser.options.headerFlag && file.start == 0 + if (hasHeader) { + // Checking that column names in the header are matched to field names of the schema. + // The header will be removed from lines. + // Note: if there are only comments in the first block, the header would probably + // be not extracted. + CSVUtils.extractHeader(lines, parser.options).foreach { header => + CSVDataSource.checkHeader( + header, + parser.tokenizer, + dataSchema, + file.filePath, + parser.options.enforceSchema, + caseSensitive) + } + } + + UnivocityParser.parseIterator(lines, parser, requiredSchema) } override def infer( @@ -206,12 +306,24 @@ object MultiLineCSVDataSource extends CSVDataSource { conf: Configuration, file: PartitionedFile, parser: UnivocityParser, - schema: StructType): Iterator[InternalRow] = { + requiredSchema: StructType, + dataSchema: StructType, + caseSensitive: Boolean): Iterator[InternalRow] = { + def checkHeader(header: Array[String]): Unit = { + CSVDataSource.checkHeaderColumnNames( + dataSchema, + header, + file.filePath, + parser.options.enforceSchema, + caseSensitive) + } + UnivocityParser.parseStream( CodecStreams.createInputStreamWithCloseResource(conf, new Path(new URI(file.filePath))), parser.options.headerFlag, parser, - schema) + requiredSchema, + checkHeader) } override def infer( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala index 21279d6daf7ad..b90275de9f40a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala @@ -130,6 +130,7 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { "df.filter($\"_corrupt_record\".isNotNull).count()." 
) } + val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis (file: PartitionedFile) => { val conf = broadcastedHadoopConf.value.value @@ -137,7 +138,13 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { StructType(dataSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)), StructType(requiredSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)), parsedOptions) - CSVDataSource(parsedOptions).readFile(conf, file, parser, requiredSchema) + CSVDataSource(parsedOptions).readFile( + conf, + file, + parser, + requiredSchema, + dataSchema, + caseSensitive) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala index 7119189a4e131..fab8d62da0c1d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala @@ -156,6 +156,12 @@ class CSVOptions( val samplingRatio = parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0) + /** + * Forcibly apply the specified or inferred schema to datasource files. + * If the option is enabled, headers of CSV files will be ignored. + */ + val enforceSchema = getBool("enforceSchema", default = true) + def asWriterSettings: CsvWriterSettings = { val writerSettings = new CsvWriterSettings() val format = writerSettings.getFormat diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala index 9dae41b63e810..1012e774118e2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala @@ -68,12 +68,8 @@ object CSVUtils { } } - /** - * Drop header line so that only data can remain. - * This is similar with `filterHeaderLine` above and currently being used in CSV reading path. - */ - def dropHeaderLine(iter: Iterator[String], options: CSVOptions): Iterator[String] = { - val nonEmptyLines = if (options.isCommentSet) { + def skipComments(iter: Iterator[String], options: CSVOptions): Iterator[String] = { + if (options.isCommentSet) { val commentPrefix = options.comment.toString iter.dropWhile { line => line.trim.isEmpty || line.trim.startsWith(commentPrefix) @@ -81,11 +77,19 @@ object CSVUtils { } else { iter.dropWhile(_.trim.isEmpty) } - - if (nonEmptyLines.hasNext) nonEmptyLines.drop(1) - iter } + /** + * Extracts header and moves iterator forward so that only data remains in it + */ + def extractHeader(iter: Iterator[String], options: CSVOptions): Option[String] = { + val nonEmptyLines = skipComments(iter, options) + if (nonEmptyLines.hasNext) { + Some(nonEmptyLines.next()) + } else { + None + } + } /** * Helper method that converts string representation of a character to actual character. 
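For context, a minimal usage sketch of the new `enforceSchema` CSV option wired up above (hypothetical path and schema, not part of this patch): with `header=true` and `enforceSchema=false`, the column names found in the CSV header are validated against the user-supplied (or inferred) schema instead of being silently ignored, and a mismatch fails the read.

// Hypothetical example; only the option name and error behaviour come from this patch.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.{DoubleType, StructType}

val spark = SparkSession.builder().master("local[1]").appName("csv-header-check").getOrCreate()
val schema = new StructType().add("f1", DoubleType).add("f2", DoubleType)

// Fails with "CSV header does not conform to the schema" if the header names
// (subject to spark.sql.caseSensitive) or their order do not match f1, f2.
val df = spark.read
  .schema(schema)
  .option("header", true)
  .option("enforceSchema", false)
  .csv("/tmp/example.csv")  // hypothetical input path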
* It handles some Java escaped strings and throws exception if given string is longer than one diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala index 4f00cc5eb3f39..5f7d5696b71a6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala @@ -45,7 +45,7 @@ class UnivocityParser( // A `ValueConverter` is responsible for converting the given value to a desired type. private type ValueConverter = String => Any - private val tokenizer = { + val tokenizer = { val parserSetting = options.asParserSettings if (options.columnPruning && requiredSchema.length < dataSchema.length) { val tokenIndexArr = requiredSchema.map(f => java.lang.Integer.valueOf(dataSchema.indexOf(f))) @@ -250,14 +250,15 @@ private[csv] object UnivocityParser { inputStream: InputStream, shouldDropHeader: Boolean, parser: UnivocityParser, - schema: StructType): Iterator[InternalRow] = { + schema: StructType, + checkHeader: Array[String] => Unit): Iterator[InternalRow] = { val tokenizer = parser.tokenizer val safeParser = new FailureSafeParser[Array[String]]( input => Seq(parser.convert(input)), parser.options.parseMode, schema, parser.options.columnNameOfCorruptRecord) - convertStream(inputStream, shouldDropHeader, tokenizer) { tokens => + convertStream(inputStream, shouldDropHeader, tokenizer, checkHeader) { tokens => safeParser.parse(tokens) }.flatten } @@ -265,11 +266,14 @@ private[csv] object UnivocityParser { private def convertStream[T]( inputStream: InputStream, shouldDropHeader: Boolean, - tokenizer: CsvParser)(convert: Array[String] => T) = new Iterator[T] { + tokenizer: CsvParser, + checkHeader: Array[String] => Unit = _ => ())( + convert: Array[String] => T) = new Iterator[T] { tokenizer.beginParsing(inputStream) private var nextRecord = { if (shouldDropHeader) { - tokenizer.parseNext() + val firstRecord = tokenizer.parseNext() + checkHeader(firstRecord) } tokenizer.parseNext() } @@ -291,21 +295,11 @@ private[csv] object UnivocityParser { */ def parseIterator( lines: Iterator[String], - shouldDropHeader: Boolean, parser: UnivocityParser, schema: StructType): Iterator[InternalRow] = { val options = parser.options - val linesWithoutHeader = if (shouldDropHeader) { - // Note that if there are only comments in the first block, the header would probably - // be not dropped. - CSVUtils.dropHeaderLine(lines, options) - } else { - lines - } - - val filteredLines: Iterator[String] = - CSVUtils.filterCommentAndEmpty(linesWithoutHeader, options) + val filteredLines: Iterator[String] = CSVUtils.filterCommentAndEmpty(lines, options) val safeParser = new FailureSafeParser[String]( input => Seq(parser.parse(input)), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala index ea283ed77efda..ea4bda327f36f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala @@ -116,7 +116,9 @@ object DataWritingSparkTask extends Logging { // write the data and commit this writer. 
Utils.tryWithSafeFinallyAndFailureCallbacks(block = { - iter.foreach(dataWriter.write) + while (iter.hasNext) { + dataWriter.write(iter.next()) + } val msg = if (useCommitCoordinator) { val coordinator = SparkEnv.get.outputCommitCoordinator diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 7817360810bde..17ffa2a517312 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -126,6 +126,12 @@ class MicroBatchExecution( _logicalPlan } + /** + * Signifies whether current batch (i.e. for the batch `currentBatchId`) has been constructed + * (i.e. written to the offsetLog) and is ready for execution. + */ + private var isCurrentBatchConstructed = false + /** * Signals to the thread executing micro-batches that it should stop running after the next * batch. This method blocks until the thread stops running. @@ -154,7 +160,6 @@ class MicroBatchExecution( triggerExecutor.execute(() => { if (isActive) { - var currentBatchIsRunnable = false // Whether the current batch is runnable / has been run var currentBatchHasNewData = false // Whether the current batch had new data startTrigger() @@ -175,7 +180,9 @@ class MicroBatchExecution( // new data to process as `constructNextBatch` may decide to run a batch for // state cleanup, etc. `isNewDataAvailable` will be updated to reflect whether new data // is available or not. - currentBatchIsRunnable = constructNextBatch(noDataBatchesEnabled) + if (!isCurrentBatchConstructed) { + isCurrentBatchConstructed = constructNextBatch(noDataBatchesEnabled) + } // Remember whether the current batch has data or not. This will be required later // for bookkeeping after running the batch, when `isNewDataAvailable` will have changed @@ -183,7 +190,7 @@ class MicroBatchExecution( currentBatchHasNewData = isNewDataAvailable currentStatus = currentStatus.copy(isDataAvailable = isNewDataAvailable) - if (currentBatchIsRunnable) { + if (isCurrentBatchConstructed) { if (currentBatchHasNewData) updateStatusMessage("Processing new data") else updateStatusMessage("No new data but cleaning up state") runBatch(sparkSessionForStream) @@ -194,9 +201,12 @@ class MicroBatchExecution( finishTrigger(currentBatchHasNewData) // Must be outside reportTimeTaken so it is recorded - // If the current batch has been executed, then increment the batch id, else there was - // no data to execute the batch - if (currentBatchIsRunnable) currentBatchId += 1 else Thread.sleep(pollingDelayMs) + // If the current batch has been executed, then increment the batch id and reset flag. + // Otherwise, there was no data to execute the batch and sleep for some time + if (isCurrentBatchConstructed) { + currentBatchId += 1 + isCurrentBatchConstructed = false + } else Thread.sleep(pollingDelayMs) } updateStatusMessage("Waiting for next trigger") isActive @@ -231,6 +241,7 @@ class MicroBatchExecution( /* First assume that we are re-executing the latest known batch * in the offset log */ currentBatchId = latestBatchId + isCurrentBatchConstructed = true availableOffsets = nextOffsets.toStreamProgress(sources) /* Initialize committed offsets to a committed batch, which at this * is the second latest batch id in the offset log. */ @@ -269,6 +280,7 @@ class MicroBatchExecution( // here, so we do nothing here. 
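As an aside on the new `isCurrentBatchConstructed` flag above, a deliberately simplified sketch of its intended lifecycle (hypothetical code, not the real MicroBatchExecution): a batch is constructed at most once per `currentBatchId`, and the flag is only reset after that batch has actually run and the id has been incremented, so a restarted query re-runs an already-constructed but uncommitted no-data batch rather than planning it again.

// Simplified, hypothetical sketch of the flag's lifecycle.
object MicroBatchFlagSketch {
  var isCurrentBatchConstructed = false
  var currentBatchId = 0L

  def onTrigger(constructNextBatch: () => Boolean, runBatch: () => Unit): Unit = {
    if (!isCurrentBatchConstructed) {
      // Construct (commit the offset range to the offset log) at most once per batch id.
      isCurrentBatchConstructed = constructNextBatch()
    }
    if (isCurrentBatchConstructed) {
      runBatch()
      currentBatchId += 1               // only now move on to the next batch id
      isCurrentBatchConstructed = false // and allow the next batch to be constructed
    }
  }
}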
} currentBatchId = latestCommittedBatchId + 1 + isCurrentBatchConstructed = false committedOffsets ++= availableOffsets // Construct a new batch be recomputing availableOffsets } else if (latestCommittedBatchId < latestBatchId - 1) { @@ -313,11 +325,8 @@ class MicroBatchExecution( * - If either of the above is true, then construct the next batch by committing to the offset * log that range of offsets that the next batch will process. */ - private def constructNextBatch(noDataBatchesEnables: Boolean): Boolean = withProgressLocked { - // If new data is already available that means this method has already been called before - // and it must have already committed the offset range of next batch to the offset log. - // Hence do nothing, just return true. - if (isNewDataAvailable) return true + private def constructNextBatch(noDataBatchesEnabled: Boolean): Boolean = withProgressLocked { + if (isCurrentBatchConstructed) return true // Generate a map from each unique source to the next available offset. val latestOffsets: Map[BaseStreamingSource, Option[Offset]] = uniqueSources.map { @@ -348,9 +357,14 @@ class MicroBatchExecution( batchTimestampMs = triggerClock.getTimeMillis()) // Check whether next batch should be constructed - val lastExecutionRequiresAnotherBatch = noDataBatchesEnables && + val lastExecutionRequiresAnotherBatch = noDataBatchesEnabled && Option(lastExecution).exists(_.shouldRunAnotherBatch(offsetSeqMetadata)) val shouldConstructNextBatch = isNewDataAvailable || lastExecutionRequiresAnotherBatch + logTrace( + s"noDataBatchesEnabled = $noDataBatchesEnabled, " + + s"lastExecutionRequiresAnotherBatch = $lastExecutionRequiresAnotherBatch, " + + s"isNewDataAvailable = $isNewDataAvailable, " + + s"shouldConstructNextBatch = $shouldConstructNextBatch") if (shouldConstructNextBatch) { // Commit the next batch offset range to the offset log diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index d16b24c89ebef..e3d0cea608b2a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -318,9 +318,14 @@ class ContinuousExecution( } } - if (minLogEntriesToMaintain < currentBatchId) { - offsetLog.purge(currentBatchId - minLogEntriesToMaintain) - commitLog.purge(currentBatchId - minLogEntriesToMaintain) + // Since currentBatchId increases independently in continuous processing mode, the current + // committed epoch may be far behind currentBatchId. It is not safe to discard the metadata + // with thresholdBatchId computed based on currentBatchId. Since minLogEntriesToMaintain is + // used to keep the minimum number of batches that must be retained and made recoverable, we + // should keep the specified number of committed metadata entries.
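To make the retention reasoning above concrete, a small worked example with hypothetical values (not taken from this patch):

// Hypothetical numbers, purely to illustrate the purge threshold.
val minLogEntriesToMaintain = 2
val latestCommittedEpoch = 10L
val purgeThreshold = latestCommittedEpoch + 1 - minLogEntriesToMaintain  // == 9
// offsetLog.purge(9) and commitLog.purge(9) drop entries with id < 9, so epochs
// 9 and 10 (the two most recently committed ones) remain recoverable, no matter
// how far ahead currentBatchId has advanced.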
+ if (minLogEntriesToMaintain <= epoch) { + offsetLog.purge(epoch + 1 - minLogEntriesToMaintain) + commitLog.purge(epoch + 1 - minLogEntriesToMaintain) } awaitProgressLock.lock() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 5ab9cb3fb86a5..a2aae9a708ff3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3169,6 +3169,15 @@ object functions { */ def array_sort(e: Column): Column = withExpr { ArraySort(e.expr) } + /** + * Remove all elements that equal to element from the given array. + * @group collection_funcs + * @since 2.4.0 + */ + def array_remove(column: Column, element: Any): Column = withExpr { + ArrayRemove(column.expr, Literal(element)) + } + /** * Creates a new row for each element in the given array or map column. * @@ -3499,6 +3508,125 @@ object functions { */ def map_entries(e: Column): Column = withExpr { MapEntries(e.expr) } + ////////////////////////////////////////////////////////////////////////////////////////////// + // Mask functions + ////////////////////////////////////////////////////////////////////////////////////////////// + /** + * Returns a string which is the masked representation of the input. + * @group mask_funcs + * @since 2.4.0 + */ + def mask(e: Column): Column = withExpr { new Mask(e.expr) } + + /** + * Returns a string which is the masked representation of the input, using `upper`, `lower` and + * `digit` as replacement characters. + * @group mask_funcs + * @since 2.4.0 + */ + def mask(e: Column, upper: String, lower: String, digit: String): Column = withExpr { + Mask(e.expr, upper, lower, digit) + } + + /** + * Returns a string with the first `n` characters masked. + * @group mask_funcs + * @since 2.4.0 + */ + def mask_first_n(e: Column, n: Int): Column = withExpr { new MaskFirstN(e.expr, Literal(n)) } + + /** + * Returns a string with the first `n` characters masked, using `upper`, `lower` and `digit` as + * replacement characters. + * @group mask_funcs + * @since 2.4.0 + */ + def mask_first_n( + e: Column, + n: Int, + upper: String, + lower: String, + digit: String): Column = withExpr { + MaskFirstN(e.expr, n, upper, lower, digit) + } + + /** + * Returns a string with the last `n` characters masked. + * @group mask_funcs + * @since 2.4.0 + */ + def mask_last_n(e: Column, n: Int): Column = withExpr { new MaskLastN(e.expr, Literal(n)) } + + /** + * Returns a string with the last `n` characters masked, using `upper`, `lower` and `digit` as + * replacement characters. + * @group mask_funcs + * @since 2.4.0 + */ + def mask_last_n( + e: Column, + n: Int, + upper: String, + lower: String, + digit: String): Column = withExpr { + MaskLastN(e.expr, n, upper, lower, digit) + } + + /** + * Returns a string with all but the first `n` characters masked. + * @group mask_funcs + * @since 2.4.0 + */ + def mask_show_first_n(e: Column, n: Int): Column = withExpr { + new MaskShowFirstN(e.expr, Literal(n)) + } + + /** + * Returns a string with all but the first `n` characters masked, using `upper`, `lower` and + * `digit` as replacement characters. + * @group mask_funcs + * @since 2.4.0 + */ + def mask_show_first_n( + e: Column, + n: Int, + upper: String, + lower: String, + digit: String): Column = withExpr { + MaskShowFirstN(e.expr, n, upper, lower, digit) + } + + /** + * Returns a string with all but the last `n` characters masked. 
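For the `array_remove` and `mask*` functions added to `functions.scala` above, a brief usage sketch (hypothetical DataFrame; the expected values in the comments follow the behaviour exercised by the test suites later in this patch):

// Hypothetical example data; the function names come from this patch.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

val spark = SparkSession.builder().master("local[1]").appName("functions-sketch").getOrCreate()
import spark.implicits._

val df = Seq((Seq(2, 1, 2, 3), "TestString-123")).toDF("arr", "s")

// array_remove drops every element equal to the given value: [1, 3]
df.select(array_remove($"arr", 2)).show()

// mask replaces upper-case, lower-case and digit characters
// (by default with 'X', 'x' and 'n' respectively).
df.select(mask($"s")).show()                  // XxxxXxxxxx-nnn
df.select(mask_first_n($"s", 4)).show()       // XxxxString-123
df.select(mask_show_first_n($"s", 4)).show()  // TestXxxxxx-nnn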
+ * @group mask_funcs + * @since 2.4.0 + */ + def mask_show_last_n(e: Column, n: Int): Column = withExpr { + new MaskShowLastN(e.expr, Literal(n)) + } + + /** + * Returns a string with all but the last `n` characters masked, using `upper`, `lower` and + * `digit` as replacement characters. + * @group mask_funcs + * @since 2.4.0 + */ + def mask_show_last_n( + e: Column, + n: Int, + upper: String, + lower: String, + digit: String): Column = withExpr { + MaskShowLastN(e.expr, n, upper, lower, digit) + } + + /** + * Returns a hashed value based on the input column. + * @group mask_funcs + * @since 2.4.0 + */ + def mask_hash(e: Column): Column = withExpr { MaskHash(e.expr) } + // scalastyle:off line.size.limit // scalastyle:off parameter.number diff --git a/sql/core/src/test/resources/sql-tests/inputs/extract.sql b/sql/core/src/test/resources/sql-tests/inputs/extract.sql new file mode 100644 index 0000000000000..9adf5d70056e2 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/extract.sql @@ -0,0 +1,21 @@ +CREATE TEMPORARY VIEW t AS select '2011-05-06 07:08:09.1234567' as c; + +select extract(year from c) from t; + +select extract(quarter from c) from t; + +select extract(month from c) from t; + +select extract(week from c) from t; + +select extract(day from c) from t; + +select extract(dayofweek from c) from t; + +select extract(hour from c) from t; + +select extract(minute from c) from t; + +select extract(second from c) from t; + +select extract(not_supported from c) from t; diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql index c5070b734d521..2c18d6aaabdba 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql @@ -68,4 +68,8 @@ SELECT 1 from ( FROM (select 1 as x) a WHERE false ) b -where b.z != b.z +where b.z != b.z; + +-- SPARK-24369 multiple distinct aggregations having the same argument set +SELECT corr(DISTINCT x, y), corr(DISTINCT y, x), count(*) + FROM (VALUES (1, 1), (2, 2), (2, 2)) t(x, y); diff --git a/sql/core/src/test/resources/sql-tests/results/extract.sql.out b/sql/core/src/test/resources/sql-tests/results/extract.sql.out new file mode 100644 index 0000000000000..160e4c7d78455 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/extract.sql.out @@ -0,0 +1,96 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 11 + + +-- !query 0 +CREATE TEMPORARY VIEW t AS select '2011-05-06 07:08:09.1234567' as c +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +select extract(year from c) from t +-- !query 1 schema +struct +-- !query 1 output +2011 + + +-- !query 2 +select extract(quarter from c) from t +-- !query 2 schema +struct +-- !query 2 output +2 + + +-- !query 3 +select extract(month from c) from t +-- !query 3 schema +struct +-- !query 3 output +5 + + +-- !query 4 +select extract(week from c) from t +-- !query 4 schema +struct +-- !query 4 output +18 + + +-- !query 5 +select extract(day from c) from t +-- !query 5 schema +struct +-- !query 5 output +6 + + +-- !query 6 +select extract(dayofweek from c) from t +-- !query 6 schema +struct +-- !query 6 output +6 + + +-- !query 7 +select extract(hour from c) from t +-- !query 7 schema +struct +-- !query 7 output +7 + + +-- !query 8 +select extract(minute from c) from t +-- !query 8 schema +struct +-- !query 8 output +8 + + +-- !query 9 +select extract(second from c) from t +-- !query 9 schema +struct +-- !query 
9 output +9 + + +-- !query 10 +select extract(not_supported from c) from t +-- !query 10 schema +struct<> +-- !query 10 output +org.apache.spark.sql.catalyst.parser.ParseException + +Literals of type 'NOT_SUPPORTED' are currently not supported.(line 1, pos 7) + +== SQL == +select extract(not_supported from c) from t +-------^^^ diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out index c1abc6dff754b..581aa1754ce14 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 26 +-- Number of queries: 27 -- !query 0 @@ -241,3 +241,12 @@ where b.z != b.z struct<1:int> -- !query 25 output + + +-- !query 26 +SELECT corr(DISTINCT x, y), corr(DISTINCT y, x), count(*) + FROM (VALUES (1, 1), (2, 2), (2, 2)) t(x, y) +-- !query 26 schema +struct +-- !query 26 output +1.0 1.0 3 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 7c45be21961d3..2182bd7eadd63 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -17,6 +17,10 @@ package org.apache.spark.sql +import java.util.Locale + +import scala.collection.JavaConverters._ + import org.apache.hadoop.io.{LongWritable, Text} import org.apache.hadoop.mapreduce.lib.input.{TextInputFormat => NewTextInputFormat} import org.scalatest.Matchers._ @@ -390,11 +394,67 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { checkAnswer(df.filter($"b".isin("z", "y")), df.collect().toSeq.filter(r => r.getString(1) == "z" || r.getString(1) == "y")) + // Auto casting should work with mixture of different types in collections + checkAnswer(df.filter($"a".isin(1.toShort, "2")), + df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2)) + checkAnswer(df.filter($"a".isin("3", 2.toLong)), + df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 2)) + checkAnswer(df.filter($"a".isin(3, "1")), + df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1)) + val df2 = Seq((1, Seq(1)), (2, Seq(2)), (3, Seq(3))).toDF("a", "b") - intercept[AnalysisException] { + val e = intercept[AnalysisException] { df2.filter($"a".isin($"b")) } + Seq("cannot resolve", "due to data type mismatch: Arguments must be same type but were") + .foreach { s => + assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT))) + } + } + + test("isInCollection: Scala Collection") { + val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b") + // Test with different types of collections + checkAnswer(df.filter($"a".isInCollection(Seq(3, 1))), + df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1)) + checkAnswer(df.filter($"a".isInCollection(Seq(1, 2).toSet)), + df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2)) + checkAnswer(df.filter($"a".isInCollection(Seq(3, 2).toArray)), + df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 2)) + checkAnswer(df.filter($"a".isInCollection(Seq(3, 1).toList)), + df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1)) + + val df2 = Seq((1, Seq(1)), (2, Seq(2)), (3, Seq(3))).toDF("a", "b") + + val e = intercept[AnalysisException] { + 
df2.filter($"a".isInCollection(Seq($"b"))) + } + Seq("cannot resolve", "due to data type mismatch: Arguments must be same type but were") + .foreach { s => + assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT))) + } + } + + test("isInCollection: Java Collection") { + val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b") + // Test with different types of collections + checkAnswer(df.filter($"a".isInCollection(Seq(1, 2).asJava)), + df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2)) + checkAnswer(df.filter($"a".isInCollection(Seq(1, 2).toSet.asJava)), + df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2)) + checkAnswer(df.filter($"a".isInCollection(Seq(3, 1).toList.asJava)), + df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1)) + + val df2 = Seq((1, Seq(1)), (2, Seq(2)), (3, Seq(3))).toDF("a", "b") + + val e = intercept[AnalysisException] { + df2.filter($"a".isInCollection(Seq($"b").asJava)) + } + Seq("cannot resolve", "due to data type mismatch: Arguments must be same type but were") + .foreach { s => + assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT))) + } } test("&&") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 96c28961e5aaf..f495a949ebc5a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -19,8 +19,8 @@ package org.apache.spark.sql import scala.util.Random -import org.apache.spark.sql.catalyst.expressions.{Alias, Literal} -import org.apache.spark.sql.catalyst.expressions.aggregate.Count +import org.scalatest.Matchers.the + import org.apache.spark.sql.execution.WholeStageCodegenExec import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec} import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec @@ -687,4 +687,34 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { } } } + + test("SPARK-21896: Window functions inside aggregate functions") { + def checkWindowError(df: => DataFrame): Unit = { + val thrownException = the [AnalysisException] thrownBy { + df.queryExecution.analyzed + } + assert(thrownException.message.contains("not allowed to use a window function")) + } + + checkWindowError(testData2.select(min(avg('b).over(Window.partitionBy('a))))) + checkWindowError(testData2.agg(sum('b), max(rank().over(Window.orderBy('a))))) + checkWindowError(testData2.groupBy('a).agg(sum('b), max(rank().over(Window.orderBy('b))))) + checkWindowError(testData2.groupBy('a).agg(max(sum(sum('b)).over(Window.orderBy('a))))) + checkWindowError( + testData2.groupBy('a).agg(sum('b).as("s"), max(count("*").over())).where('s === 3)) + checkAnswer( + testData2.groupBy('a).agg(max('b), sum('b).as("s"), count("*").over()).where('s === 3), + Row(1, 2, 3, 3) :: Row(2, 2, 3, 3) :: Row(3, 2, 3, 3) :: Nil) + + checkWindowError(sql("SELECT MIN(AVG(b) OVER(PARTITION BY a)) FROM testData2")) + checkWindowError(sql("SELECT SUM(b), MAX(RANK() OVER(ORDER BY a)) FROM testData2")) + checkWindowError(sql("SELECT SUM(b), MAX(RANK() OVER(ORDER BY b)) FROM testData2 GROUP BY a")) + checkWindowError(sql("SELECT MAX(SUM(SUM(b)) OVER(ORDER BY a)) FROM testData2 GROUP BY a")) + checkWindowError( + sql("SELECT MAX(RANK() OVER(ORDER BY b)) FROM testData2 GROUP BY a HAVING SUM(b) = 3")) + 
checkAnswer( + sql("SELECT a, MAX(b), RANK() OVER(ORDER BY a) FROM testData2 GROUP BY a HAVING SUM(b) = 3"), + Row(1, 2, 1) :: Row(2, 2, 2) :: Row(3, 2, 3) :: Nil) + } + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 79e743d961af8..59119bbbd8a2c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -276,6 +276,113 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { ) } + test("mask functions") { + val df = Seq("TestString-123", "", null).toDF("a") + checkAnswer(df.select(mask($"a")), Seq(Row("XxxxXxxxxx-nnn"), Row(""), Row(null))) + checkAnswer(df.select(mask_first_n($"a", 4)), Seq(Row("XxxxString-123"), Row(""), Row(null))) + checkAnswer(df.select(mask_last_n($"a", 4)), Seq(Row("TestString-nnn"), Row(""), Row(null))) + checkAnswer(df.select(mask_show_first_n($"a", 4)), + Seq(Row("TestXxxxxx-nnn"), Row(""), Row(null))) + checkAnswer(df.select(mask_show_last_n($"a", 4)), + Seq(Row("XxxxXxxxxx-123"), Row(""), Row(null))) + checkAnswer(df.select(mask_hash($"a")), + Seq(Row("dd78d68ad1b23bde126812482dd70ac6"), + Row("d41d8cd98f00b204e9800998ecf8427e"), + Row(null))) + + checkAnswer(df.select(mask($"a", "U", "l", "#")), + Seq(Row("UlllUlllll-###"), Row(""), Row(null))) + checkAnswer(df.select(mask_first_n($"a", 4, "U", "l", "#")), + Seq(Row("UlllString-123"), Row(""), Row(null))) + checkAnswer(df.select(mask_last_n($"a", 4, "U", "l", "#")), + Seq(Row("TestString-###"), Row(""), Row(null))) + checkAnswer(df.select(mask_show_first_n($"a", 4, "U", "l", "#")), + Seq(Row("TestUlllll-###"), Row(""), Row(null))) + checkAnswer(df.select(mask_show_last_n($"a", 4, "U", "l", "#")), + Seq(Row("UlllUlllll-123"), Row(""), Row(null))) + + checkAnswer( + df.selectExpr("mask(a)", "mask(a, 'U')", "mask(a, 'U', 'l')", "mask(a, 'U', 'l', '#')"), + Seq(Row("XxxxXxxxxx-nnn", "UxxxUxxxxx-nnn", "UlllUlllll-nnn", "UlllUlllll-###"), + Row("", "", "", ""), + Row(null, null, null, null))) + checkAnswer(sql("select mask(null)"), Row(null)) + checkAnswer(sql("select mask('AAaa11', null, null, null)"), Row("XXxxnn")) + intercept[AnalysisException] { + checkAnswer(df.selectExpr("mask(a, a)"), Seq(Row("XxxxXxxxxx-nnn"), Row(""), Row(null))) + } + + checkAnswer( + df.selectExpr( + "mask_first_n(a)", + "mask_first_n(a, 6)", + "mask_first_n(a, 6, 'U')", + "mask_first_n(a, 6, 'U', 'l')", + "mask_first_n(a, 6, 'U', 'l', '#')"), + Seq(Row("XxxxString-123", "XxxxXxring-123", "UxxxUxring-123", "UlllUlring-123", + "UlllUlring-123"), + Row("", "", "", "", ""), + Row(null, null, null, null, null))) + checkAnswer(sql("select mask_first_n(null)"), Row(null)) + checkAnswer(sql("select mask_first_n('A1aA1a', null, null, null, null)"), Row("XnxX1a")) + intercept[AnalysisException] { + checkAnswer(spark.range(1).selectExpr("mask_first_n('A1aA1a', id)"), Row("XnxX1a")) + } + + checkAnswer( + df.selectExpr( + "mask_last_n(a)", + "mask_last_n(a, 6)", + "mask_last_n(a, 6, 'U')", + "mask_last_n(a, 6, 'U', 'l')", + "mask_last_n(a, 6, 'U', 'l', '#')"), + Seq(Row("TestString-nnn", "TestStrixx-nnn", "TestStrixx-nnn", "TestStrill-nnn", + "TestStrill-###"), + Row("", "", "", "", ""), + Row(null, null, null, null, null))) + checkAnswer(sql("select mask_last_n(null)"), Row(null)) + checkAnswer(sql("select mask_last_n('A1aA1a', null, null, null, null)"), Row("A1xXnx")) + 
intercept[AnalysisException] { + checkAnswer(spark.range(1).selectExpr("mask_last_n('A1aA1a', id)"), Row("A1xXnx")) + } + + checkAnswer( + df.selectExpr( + "mask_show_first_n(a)", + "mask_show_first_n(a, 6)", + "mask_show_first_n(a, 6, 'U')", + "mask_show_first_n(a, 6, 'U', 'l')", + "mask_show_first_n(a, 6, 'U', 'l', '#')"), + Seq(Row("TestXxxxxx-nnn", "TestStxxxx-nnn", "TestStxxxx-nnn", "TestStllll-nnn", + "TestStllll-###"), + Row("", "", "", "", ""), + Row(null, null, null, null, null))) + checkAnswer(sql("select mask_show_first_n(null)"), Row(null)) + checkAnswer(sql("select mask_show_first_n('A1aA1a', null, null, null, null)"), Row("A1aAnx")) + intercept[AnalysisException] { + checkAnswer(spark.range(1).selectExpr("mask_show_first_n('A1aA1a', id)"), Row("A1aAnx")) + } + + checkAnswer( + df.selectExpr( + "mask_show_last_n(a)", + "mask_show_last_n(a, 6)", + "mask_show_last_n(a, 6, 'U')", + "mask_show_last_n(a, 6, 'U', 'l')", + "mask_show_last_n(a, 6, 'U', 'l', '#')"), + Seq(Row("XxxxXxxxxx-123", "XxxxXxxxng-123", "UxxxUxxxng-123", "UlllUlllng-123", + "UlllUlllng-123"), + Row("", "", "", "", ""), + Row(null, null, null, null, null))) + checkAnswer(sql("select mask_show_last_n(null)"), Row(null)) + checkAnswer(sql("select mask_show_last_n('A1aA1a', null, null, null, null)"), Row("XnaA1a")) + intercept[AnalysisException] { + checkAnswer(spark.range(1).selectExpr("mask_show_last_n('A1aA1a', id)"), Row("XnaA1a")) + } + + checkAnswer(sql("select mask_hash(null)"), Row(null)) + } + test("sort_array/array_sort functions") { val df = Seq( (Array[Int](2, 1, 3), Array("b", "c", "a")), @@ -1003,6 +1110,35 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { } + test("array remove") { + val df = Seq( + (Array[Int](2, 1, 2, 3), Array("a", "b", "c", "a"), Array("", "")), + (Array.empty[Int], Array.empty[String], Array.empty[String]), + (null, null, null) + ).toDF("a", "b", "c") + checkAnswer( + df.select(array_remove($"a", 2), array_remove($"b", "a"), array_remove($"c", "")), + Seq( + Row(Seq(1, 3), Seq("b", "c"), Seq.empty[String]), + Row(Seq.empty[Int], Seq.empty[String], Seq.empty[String]), + Row(null, null, null)) + ) + + checkAnswer( + df.selectExpr("array_remove(a, 2)", "array_remove(b, \"a\")", + "array_remove(c, \"\")"), + Seq( + Row(Seq(1, 3), Seq("b", "c"), Seq.empty[String]), + Row(Seq.empty[Int], Seq.empty[String], Seq.empty[String]), + Row(null, null, null)) + ) + + val e = intercept[AnalysisException] { + Seq(("a string element", "a")).toDF().selectExpr("array_remove(_1, _2)") + } + assert(e.message.contains("argument 1 requires array type, however, '`_1`' is of string type")) + } + private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = { import DataFrameFunctionsSuite.CodegenFallbackExpr for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index b2aba8e72c5db..98a50fbd52b4d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -69,6 +69,27 @@ class PlannerSuite extends SharedSQLContext { testPartialAggregationPlan(query) } + test("mixed aggregates with same distinct columns") { + def assertNoExpand(plan: SparkPlan): Unit = { + assert(plan.collect { case e: ExpandExec => e }.isEmpty) + } + + withTempView("v") { + Seq((1, 1.0, 1.0), (1, 
2.0, 2.0)).toDF("i", "j", "k").createTempView("v") + // one distinct column + val query1 = sql("SELECT sum(DISTINCT j), max(DISTINCT j) FROM v GROUP BY i") + assertNoExpand(query1.queryExecution.executedPlan) + + // 2 distinct columns + val query2 = sql("SELECT corr(DISTINCT j, k), count(DISTINCT j, k) FROM v GROUP BY i") + assertNoExpand(query2.queryExecution.executedPlan) + + // 2 distinct columns with different order + val query3 = sql("SELECT corr(DISTINCT j, k), count(DISTINCT k, j) FROM v GROUP BY i") + assertNoExpand(query3.queryExecution.executedPlan) + } + } + test("sizeInBytes estimation of limit operator for broadcast hash join optimization") { def checkPlan(fieldTypes: Seq[DataType]): Unit = { withTempView("testLimit") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceWriteBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceWriteBenchmark.scala new file mode 100644 index 0000000000000..2d2cdebd067c1 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceWriteBenchmark.scala @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.benchmark + +import org.apache.spark.SparkConf +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.util.Benchmark + +/** + * Benchmark to measure data source write performance. + * By default it measures 4 data source formats: Parquet, ORC, JSON, CSV: + * spark-submit --class <this class> <spark sql test jar> + * To measure specified formats, run it with arguments: + * spark-submit --class <this class> <spark sql test jar> format1 [format2] [...] + */ +object DataSourceWriteBenchmark { + val conf = new SparkConf() + .setAppName("DataSourceWriteBenchmark") + .setIfMissing("spark.master", "local[1]") + .set("spark.sql.parquet.compression.codec", "snappy") + .set("spark.sql.orc.compression.codec", "snappy") + + val spark = SparkSession.builder.config(conf).getOrCreate() + + // Set default configs. Individual cases will change them if necessary.
+ spark.conf.set(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, "true") + + val tempTable = "temp" + val numRows = 1024 * 1024 * 15 + + def withTempTable(tableNames: String*)(f: => Unit): Unit = { + try f finally tableNames.foreach(spark.catalog.dropTempView) + } + + def withTable(tableNames: String*)(f: => Unit): Unit = { + try f finally { + tableNames.foreach { name => + spark.sql(s"DROP TABLE IF EXISTS $name") + } + } + } + + def writeNumeric(table: String, format: String, benchmark: Benchmark, dataType: String): Unit = { + spark.sql(s"create table $table(id $dataType) using $format") + benchmark.addCase(s"Output Single $dataType Column") { _ => + spark.sql(s"INSERT OVERWRITE TABLE $table SELECT CAST(id AS $dataType) AS c1 FROM $tempTable") + } + } + + def writeIntString(table: String, format: String, benchmark: Benchmark): Unit = { + spark.sql(s"CREATE TABLE $table(c1 INT, c2 STRING) USING $format") + benchmark.addCase("Output Int and String Column") { _ => + spark.sql(s"INSERT OVERWRITE TABLE $table SELECT CAST(id AS INT) AS " + + s"c1, CAST(id AS STRING) AS c2 FROM $tempTable") + } + } + + def writePartition(table: String, format: String, benchmark: Benchmark): Unit = { + spark.sql(s"CREATE TABLE $table(p INT, id INT) USING $format PARTITIONED BY (p)") + benchmark.addCase("Output Partitions") { _ => + spark.sql(s"INSERT OVERWRITE TABLE $table SELECT CAST(id AS INT) AS id," + + s" CAST(id % 2 AS INT) AS p FROM $tempTable") + } + } + + def writeBucket(table: String, format: String, benchmark: Benchmark): Unit = { + spark.sql(s"CREATE TABLE $table(c1 INT, c2 INT) USING $format CLUSTERED BY (c2) INTO 2 BUCKETS") + benchmark.addCase("Output Buckets") { _ => + spark.sql(s"INSERT OVERWRITE TABLE $table SELECT CAST(id AS INT) AS " + + s"c1, CAST(id AS INT) AS c2 FROM $tempTable") + } + } + + def main(args: Array[String]): Unit = { + val tableInt = "tableInt" + val tableDouble = "tableDouble" + val tableIntString = "tableIntString" + val tablePartition = "tablePartition" + val tableBucket = "tableBucket" + val formats: Seq[String] = if (args.isEmpty) { + Seq("Parquet", "ORC", "JSON", "CSV") + } else { + args + } + /* + Intel(R) Core(TM) i7-6920HQ CPU @ 2.90GHz + Parquet writer benchmark: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Output Single Int Column 1815 / 1932 8.7 115.4 1.0X + Output Single Double Column 1877 / 1878 8.4 119.3 1.0X + Output Int and String Column 6265 / 6543 2.5 398.3 0.3X + Output Partitions 4067 / 4457 3.9 258.6 0.4X + Output Buckets 5608 / 5820 2.8 356.6 0.3X + + ORC writer benchmark: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Output Single Int Column 1201 / 1239 13.1 76.3 1.0X + Output Single Double Column 1542 / 1600 10.2 98.0 0.8X + Output Int and String Column 6495 / 6580 2.4 412.9 0.2X + Output Partitions 3648 / 3842 4.3 231.9 0.3X + Output Buckets 5022 / 5145 3.1 319.3 0.2X + + JSON writer benchmark: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Output Single Int Column 1988 / 2093 7.9 126.4 1.0X + Output Single Double Column 2854 / 2911 5.5 181.4 0.7X + Output Int and String Column 6467 / 6653 2.4 411.1 0.3X + Output Partitions 4548 / 5055 3.5 289.1 0.4X + Output Buckets 5664 / 5765 2.8 360.1 0.4X + + CSV writer benchmark: Best/Avg Time(ms) Rate(M/s) Per 
Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Output Single Int Column 3025 / 3190 5.2 192.3 1.0X + Output Single Double Column 3575 / 3634 4.4 227.3 0.8X + Output Int and String Column 7313 / 7399 2.2 464.9 0.4X + Output Partitions 5105 / 5190 3.1 324.6 0.6X + Output Buckets 6986 / 6992 2.3 444.1 0.4X + */ + withTempTable(tempTable) { + spark.range(numRows).createOrReplaceTempView(tempTable) + formats.foreach { format => + withTable(tableInt, tableDouble, tableIntString, tablePartition, tableBucket) { + val benchmark = new Benchmark(s"$format writer benchmark", numRows) + writeNumeric(tableInt, format, benchmark, "Int") + writeNumeric(tableDouble, format, benchmark, "Double") + writeIntString(tableIntString, format, benchmark) + writePartition(tablePartition, format, benchmark) + writeBucket(tableBucket, format, benchmark) + benchmark.run() + } + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index afe10bdc4de26..d2f166c7d1877 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -23,9 +23,13 @@ import java.sql.{Date, Timestamp} import java.text.SimpleDateFormat import java.util.Locale +import scala.collection.JavaConverters._ + import org.apache.commons.lang3.time.FastDateFormat import org.apache.hadoop.io.SequenceFile.CompressionType import org.apache.hadoop.io.compress.GzipCodec +import org.apache.log4j.{AppenderSkeleton, LogManager} +import org.apache.log4j.spi.LoggingEvent import org.apache.spark.SparkException import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row, UDT} @@ -1410,4 +1414,192 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te checkAnswer(idf, List(Row(15, 10, 5), Row(-15, -10, -5))) } } + + def checkHeader(multiLine: Boolean): Unit = { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + withTempPath { path => + val oschema = new StructType().add("f1", DoubleType).add("f2", DoubleType) + val odf = spark.createDataFrame(List(Row(1.0, 1234.5)).asJava, oschema) + odf.write.option("header", true).csv(path.getCanonicalPath) + val ischema = new StructType().add("f2", DoubleType).add("f1", DoubleType) + val exception = intercept[SparkException] { + spark.read + .schema(ischema) + .option("multiLine", multiLine) + .option("header", true) + .option("enforceSchema", false) + .csv(path.getCanonicalPath) + .collect() + } + assert(exception.getMessage.contains("CSV header does not conform to the schema")) + + val shortSchema = new StructType().add("f1", DoubleType) + val exceptionForShortSchema = intercept[SparkException] { + spark.read + .schema(shortSchema) + .option("multiLine", multiLine) + .option("header", true) + .option("enforceSchema", false) + .csv(path.getCanonicalPath) + .collect() + } + assert(exceptionForShortSchema.getMessage.contains( + "Number of column in CSV header is not equal to number of fields in the schema")) + + val longSchema = new StructType() + .add("f1", DoubleType) + .add("f2", DoubleType) + .add("f3", DoubleType) + + val exceptionForLongSchema = intercept[SparkException] { + spark.read + .schema(longSchema) + .option("multiLine", multiLine) + .option("header", true) + .option("enforceSchema", false) + .csv(path.getCanonicalPath) + 
.collect() + } + assert(exceptionForLongSchema.getMessage.contains("Header length: 2, schema size: 3")) + + val caseSensitiveSchema = new StructType().add("F1", DoubleType).add("f2", DoubleType) + val caseSensitiveException = intercept[SparkException] { + spark.read + .schema(caseSensitiveSchema) + .option("multiLine", multiLine) + .option("header", true) + .option("enforceSchema", false) + .csv(path.getCanonicalPath) + .collect() + } + assert(caseSensitiveException.getMessage.contains( + "CSV header does not conform to the schema")) + } + } + } + + test(s"SPARK-23786: Checking column names against schema in the multiline mode") { + checkHeader(multiLine = true) + } + + test(s"SPARK-23786: Checking column names against schema in the per-line mode") { + checkHeader(multiLine = false) + } + + test("SPARK-23786: CSV header must not be checked if it doesn't exist") { + withTempPath { path => + val oschema = new StructType().add("f1", DoubleType).add("f2", DoubleType) + val odf = spark.createDataFrame(List(Row(1.0, 1234.5)).asJava, oschema) + odf.write.option("header", false).csv(path.getCanonicalPath) + val ischema = new StructType().add("f2", DoubleType).add("f1", DoubleType) + val idf = spark.read + .schema(ischema) + .option("header", false) + .option("enforceSchema", false) + .csv(path.getCanonicalPath) + + checkAnswer(idf, odf) + } + } + + test("SPARK-23786: Ignore column name case if spark.sql.caseSensitive is false") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + withTempPath { path => + val oschema = new StructType().add("A", StringType) + val odf = spark.createDataFrame(List(Row("0")).asJava, oschema) + odf.write.option("header", true).csv(path.getCanonicalPath) + val ischema = new StructType().add("a", StringType) + val idf = spark.read.schema(ischema) + .option("header", true) + .option("enforceSchema", false) + .csv(path.getCanonicalPath) + checkAnswer(idf, odf) + } + } + } + + test("SPARK-23786: check header on parsing of dataset of strings") { + val ds = Seq("columnA,columnB", "1.0,1000.0").toDS() + val ischema = new StructType().add("columnB", DoubleType).add("columnA", DoubleType) + val exception = intercept[IllegalArgumentException] { + spark.read.schema(ischema).option("header", true).option("enforceSchema", false).csv(ds) + } + + assert(exception.getMessage.contains("CSV header does not conform to the schema")) + } + + test("SPARK-23786: enforce inferred schema") { + val expectedSchema = new StructType().add("_c0", DoubleType).add("_c1", StringType) + val withHeader = spark.read + .option("inferSchema", true) + .option("enforceSchema", false) + .option("header", true) + .csv(Seq("_c0,_c1", "1.0,a").toDS()) + assert(withHeader.schema == expectedSchema) + checkAnswer(withHeader, Seq(Row(1.0, "a"))) + + // Ignore the inferSchema flag if an user sets a schema + val schema = new StructType().add("colA", DoubleType).add("colB", StringType) + val ds = spark.read + .option("inferSchema", true) + .option("enforceSchema", false) + .option("header", true) + .schema(schema) + .csv(Seq("colA,colB", "1.0,a").toDS()) + assert(ds.schema == schema) + checkAnswer(ds, Seq(Row(1.0, "a"))) + + val exception = intercept[IllegalArgumentException] { + spark.read + .option("inferSchema", true) + .option("enforceSchema", false) + .option("header", true) + .schema(schema) + .csv(Seq("col1,col2", "1.0,a").toDS()) + } + assert(exception.getMessage.contains("CSV header does not conform to the schema")) + } + + test("SPARK-23786: warning should be printed if CSV header doesn't conform to 
schema") { + class TestAppender extends AppenderSkeleton { + var events = new java.util.ArrayList[LoggingEvent] + override def close(): Unit = {} + override def requiresLayout: Boolean = false + protected def append(event: LoggingEvent): Unit = events.add(event) + } + + val testAppender1 = new TestAppender + LogManager.getRootLogger.addAppender(testAppender1) + try { + val ds = Seq("columnA,columnB", "1.0,1000.0").toDS() + val ischema = new StructType().add("columnB", DoubleType).add("columnA", DoubleType) + + spark.read.schema(ischema).option("header", true).option("enforceSchema", true).csv(ds) + } finally { + LogManager.getRootLogger.removeAppender(testAppender1) + } + assert(testAppender1.events.asScala + .exists(msg => msg.getRenderedMessage.contains("CSV header does not conform to the schema"))) + + val testAppender2 = new TestAppender + LogManager.getRootLogger.addAppender(testAppender2) + try { + withTempPath { path => + val oschema = new StructType().add("f1", DoubleType).add("f2", DoubleType) + val odf = spark.createDataFrame(List(Row(1.0, 1234.5)).asJava, oschema) + odf.write.option("header", true).csv(path.getCanonicalPath) + val ischema = new StructType().add("f2", DoubleType).add("f1", DoubleType) + spark.read + .schema(ischema) + .option("header", true) + .option("enforceSchema", true) + .csv(path.getCanonicalPath) + .collect() + } + } finally { + LogManager.getRootLogger.removeAppender(testAppender2) + } + assert(testAppender2.events.asScala + .exists(msg => msg.getRenderedMessage.contains("CSV header does not conform to the schema"))) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecutionSuite.scala new file mode 100644 index 0000000000000..c228740df07c8 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecutionSuite.scala @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import org.scalatest.BeforeAndAfter
+
+import org.apache.spark.sql.functions.{count, window}
+import org.apache.spark.sql.streaming.StreamTest
+
+class MicroBatchExecutionSuite extends StreamTest with BeforeAndAfter {
+
+  import testImplicits._
+
+  after {
+    sqlContext.streams.active.foreach(_.stop())
+  }
+
+  test("SPARK-24156: do not plan a no-data batch again after it has already been planned") {
+    val inputData = MemoryStream[Int]
+    val df = inputData.toDF()
+      .withColumn("eventTime", $"value".cast("timestamp"))
+      .withWatermark("eventTime", "10 seconds")
+      .groupBy(window($"eventTime", "5 seconds") as 'window)
+      .agg(count("*") as 'count)
+      .select($"window".getField("start").cast("long").as[Long], $"count".as[Long])
+
+    testStream(df)(
+      AddData(inputData, 10, 11, 12, 13, 14, 15), // Set watermark to 5
+      CheckAnswer(),
+      AddData(inputData, 25), // Set watermark to 15 to make MicroBatchExecution run no-data batch
+      CheckAnswer((10, 5)), // Last batch should be a no-data batch
+      StopStream,
+      Execute { q =>
+        // Delete the last committed batch from the commit log to signify that the last batch
+        // (a no-data batch) never completed
+        val commit = q.commitLog.getLatest().map(_._1).getOrElse(-1L)
+        q.commitLog.purgeAfter(commit - 1)
+      },
+      // Add data before restarting so that MicroBatchExecution has a new batch to plan. It must
+      // not plan that new batch immediately; it should first re-run the incomplete no-data batch
+      // and only then run a new batch to process the new data.
+      AddData(inputData, 30),
+      StartStream(),
+      CheckNewAnswer((15, 1)), // This should not throw the error reported in SPARK-24156
+      StopStream,
+      Execute { q =>
+        // Delete the entire commit log
+        val commit = q.commitLog.getLatest().map(_._1).getOrElse(-1L)
+        q.commitLog.purge(commit + 1)
+      },
+      AddData(inputData, 50),
+      StartStream(),
+      CheckNewAnswer((25, 1), (30, 1)) // This should not throw the error reported in SPARK-24156
+    )
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
index f348dac1319cb..4c3fd58cb2e45 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
@@ -292,7 +292,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
   /** Execute arbitrary code */
   object Execute {
     def apply(func: StreamExecution => Any): AssertOnQuery =
-      AssertOnQuery(query => { func(query); true })
+      AssertOnQuery(query => { func(query); true }, "Execute")
   }
 
   object AwaitEpoch {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala
index cd1704ac2fdad..4980b0cd41f81 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala
@@ -297,3 +297,49 @@ class ContinuousStressSuite extends ContinuousSuiteBase {
       CheckAnswerRowsContains(scala.Range(0, 25000).map(Row(_))))
   }
 }
+
+class ContinuousMetaSuite extends ContinuousSuiteBase {
+  import testImplicits._
+
+  // We need to specify spark.sql.streaming.minBatchesToRetain to do the following test.
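+  // That config sets how many recent batches must keep their entries in the offset and commit
+  // logs before older entries can be purged; the small value (2) used below makes the purge
+  // behavior observable within a short-running test.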
+  override protected def createSparkSession = new TestSparkSession(
+    new SparkContext(
+      "local[10]",
+      "continuous-stream-test-sql-context",
+      sparkConf.set("spark.sql.testkey", "true")
+        .set("spark.sql.streaming.minBatchesToRetain", "2")))
+
+  test("SPARK-24351: check offsetLog/commitLog retained in the checkpoint directory") {
+    withTempDir { checkpointDir =>
+      val input = ContinuousMemoryStream[Int]
+      val df = input.toDF().mapPartitions(iter => {
+        // Sleep the task thread for 300 ms to make sure the epoch processing time is 3 times
+        // longer than the epoch creation interval, so that the gap between the last committed
+        // epoch and currentBatchId grows over time.
+        Thread.sleep(300)
+        iter.map(row => row.getInt(0) * 2)
+      })
+
+      testStream(df)(
+        StartStream(trigger = Trigger.Continuous(100),
+          checkpointLocation = checkpointDir.getAbsolutePath),
+        AddData(input, 1),
+        CheckAnswer(2),
+        // Make sure epoch 2 has been committed before the following validation.
+        AwaitEpoch(2),
+        StopStream,
+        AssertOnQuery(q => {
+          q.commitLog.getLatest() match {
+            case Some((latestEpochId, _)) =>
+              val commitLogValidateResult = q.commitLog.get(latestEpochId - 1).isDefined &&
+                q.commitLog.get(latestEpochId - 2).isEmpty
+              val offsetLogValidateResult = q.offsetLog.get(latestEpochId - 1).isDefined &&
+                q.offsetLog.get(latestEpochId - 2).isEmpty
+              commitLogValidateResult && offsetLogValidateResult
+            case None => false
+          }
+        })
+      )
+    }
+  }
+}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
index 948ba542b5733..130e258e78ca2 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
@@ -24,7 +24,6 @@ import java.util.{ArrayList => JArrayList, List => JList, Locale, Map => JMap, S
 import java.util.concurrent.TimeUnit
 
 import scala.collection.JavaConverters._
-import scala.util.Try
 import scala.util.control.NonFatal
 
 import org.apache.hadoop.fs.Path
@@ -657,17 +656,31 @@ private[client] class Shim_v0_13 extends Shim_v0_12 {
     val useAdvanced = SQLConf.get.advancedPartitionPredicatePushdownEnabled
 
+    object ExtractAttribute {
+      def unapply(expr: Expression): Option[Attribute] = {
+        expr match {
+          case attr: Attribute => Some(attr)
+          case Cast(child, dt, _) if !Cast.mayTruncate(child.dataType, dt) => unapply(child)
+          case _ => None
+        }
+      }
+    }
+
     def convert(expr: Expression): Option[String] = expr match {
-      case In(NonVarcharAttribute(name), ExtractableLiterals(values)) if useAdvanced =>
+      case In(ExtractAttribute(NonVarcharAttribute(name)), ExtractableLiterals(values))
+          if useAdvanced =>
         Some(convertInToOr(name, values))
-      case InSet(NonVarcharAttribute(name), ExtractableValues(values)) if useAdvanced =>
+      case InSet(ExtractAttribute(NonVarcharAttribute(name)), ExtractableValues(values))
+          if useAdvanced =>
         Some(convertInToOr(name, values))
-      case op @ SpecialBinaryComparison(NonVarcharAttribute(name), ExtractableLiteral(value)) =>
+      case op @ SpecialBinaryComparison(
+          ExtractAttribute(NonVarcharAttribute(name)), ExtractableLiteral(value)) =>
         Some(s"$name ${op.symbol} $value")
-      case op @ SpecialBinaryComparison(ExtractableLiteral(value), NonVarcharAttribute(name)) =>
+      case op @ SpecialBinaryComparison(
+          ExtractableLiteral(value), ExtractAttribute(NonVarcharAttribute(name))) =>
         Some(s"$value ${op.symbol} $name")
 
       case And(expr1, expr2) if useAdvanced =>
diff --git
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala index f991352b207d4..55275f6b37945 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala @@ -22,13 +22,13 @@ import org.apache.hadoop.hive.conf.HiveConf import org.scalatest.BeforeAndAfterAll import org.apache.spark.sql.catalyst.catalog._ -import org.apache.spark.sql.catalyst.expressions.{EmptyRow, Expression, In, InSet} -import org.apache.spark.sql.catalyst.parser.CatalystSqlParser +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types.LongType // TODO: Refactor this to `HivePartitionFilteringSuite` class HiveClientSuite(version: String) extends HiveVersionSuite(version) with BeforeAndAfterAll { - import CatalystSqlParser._ private val tryDirectSqlKey = HiveConf.ConfVars.METASTORE_TRY_DIRECT_SQL.varname @@ -46,8 +46,7 @@ class HiveClientSuite(version: String) val hadoopConf = new Configuration() hadoopConf.setBoolean(tryDirectSqlKey, tryDirectSql) val client = buildClient(hadoopConf) - client - .runSqlHive("CREATE TABLE test (value INT) PARTITIONED BY (ds INT, h INT, chunk STRING)") + client.runSqlHive("CREATE TABLE test (value INT) PARTITIONED BY (ds INT, h INT, chunk STRING)") val partitions = for { @@ -66,6 +65,15 @@ class HiveClientSuite(version: String) client } + private def attr(name: String): Attribute = { + client.getTable("default", "test").partitionSchema.fields + .find(field => field.name.equals(name)) match { + case Some(field) => AttributeReference(field.name, field.dataType)() + case None => + fail(s"Illegal name of partition attribute: $name") + } + } + override def beforeAll() { super.beforeAll() client = init(true) @@ -74,7 +82,7 @@ class HiveClientSuite(version: String) test(s"getPartitionsByFilter returns all partitions when $tryDirectSqlKey=false") { val client = init(false) val filteredPartitions = client.getPartitionsByFilter(client.getTable("default", "test"), - Seq(parseExpression("ds=20170101"))) + Seq(attr("ds") === 20170101)) assert(filteredPartitions.size == testPartitionCount) } @@ -82,7 +90,7 @@ class HiveClientSuite(version: String) test("getPartitionsByFilter: ds<=>20170101") { // Should return all partitions where <=> is not supported testMetastorePartitionFiltering( - "ds<=>20170101", + attr("ds") <=> 20170101, 20170101 to 20170103, 0 to 23, "aa" :: "ab" :: "ba" :: "bb" :: Nil) @@ -90,7 +98,7 @@ class HiveClientSuite(version: String) test("getPartitionsByFilter: ds=20170101") { testMetastorePartitionFiltering( - "ds=20170101", + attr("ds") === 20170101, 20170101 to 20170101, 0 to 23, "aa" :: "ab" :: "ba" :: "bb" :: Nil) @@ -100,7 +108,7 @@ class HiveClientSuite(version: String) // Should return all partitions where h=0 because getPartitionsByFilter does not support // comparisons to non-literal values testMetastorePartitionFiltering( - "ds=(20170101 + 1) and h=0", + attr("ds") === (Literal(20170101) + 1) && attr("h") === 0, 20170101 to 20170103, 0 to 0, "aa" :: "ab" :: "ba" :: "bb" :: Nil) @@ -108,7 +116,7 @@ class HiveClientSuite(version: String) test("getPartitionsByFilter: chunk='aa'") { testMetastorePartitionFiltering( - "chunk='aa'", + attr("chunk") === "aa", 20170101 to 20170103, 0 to 23, "aa" :: Nil) @@ -116,7 +124,7 @@ class HiveClientSuite(version: String) 
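+  // A literal on the left-hand side of a comparison should be pushed down the same way as one
+  // on the right.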
test("getPartitionsByFilter: 20170101=ds") { testMetastorePartitionFiltering( - "20170101=ds", + Literal(20170101) === attr("ds"), 20170101 to 20170101, 0 to 23, "aa" :: "ab" :: "ba" :: "bb" :: Nil) @@ -124,7 +132,15 @@ class HiveClientSuite(version: String) test("getPartitionsByFilter: ds=20170101 and h=10") { testMetastorePartitionFiltering( - "ds=20170101 and h=10", + attr("ds") === 20170101 && attr("h") === 10, + 20170101 to 20170101, + 10 to 10, + "aa" :: "ab" :: "ba" :: "bb" :: Nil) + } + + test("getPartitionsByFilter: chunk in cast(ds as long)=20170101L") { + testMetastorePartitionFiltering( + attr("ds").cast(LongType) === 20170101L && attr("h") === 10, 20170101 to 20170101, 10 to 10, "aa" :: "ab" :: "ba" :: "bb" :: Nil) @@ -132,7 +148,7 @@ class HiveClientSuite(version: String) test("getPartitionsByFilter: ds=20170101 or ds=20170102") { testMetastorePartitionFiltering( - "ds=20170101 or ds=20170102", + attr("ds") === 20170101 || attr("ds") === 20170102, 20170101 to 20170102, 0 to 23, "aa" :: "ab" :: "ba" :: "bb" :: Nil) @@ -140,7 +156,15 @@ class HiveClientSuite(version: String) test("getPartitionsByFilter: ds in (20170102, 20170103) (using IN expression)") { testMetastorePartitionFiltering( - "ds in (20170102, 20170103)", + attr("ds").in(20170102, 20170103), + 20170102 to 20170103, + 0 to 23, + "aa" :: "ab" :: "ba" :: "bb" :: Nil) + } + + test("getPartitionsByFilter: cast(ds as long) in (20170102L, 20170103L) (using IN expression)") { + testMetastorePartitionFiltering( + attr("ds").cast(LongType).in(20170102L, 20170103L), 20170102 to 20170103, 0 to 23, "aa" :: "ab" :: "ba" :: "bb" :: Nil) @@ -148,7 +172,19 @@ class HiveClientSuite(version: String) test("getPartitionsByFilter: ds in (20170102, 20170103) (using INSET expression)") { testMetastorePartitionFiltering( - "ds in (20170102, 20170103)", + attr("ds").in(20170102, 20170103), + 20170102 to 20170103, + 0 to 23, + "aa" :: "ab" :: "ba" :: "bb" :: Nil, { + case expr @ In(v, list) if expr.inSetConvertible => + InSet(v, list.map(_.eval(EmptyRow)).toSet) + }) + } + + test("getPartitionsByFilter: cast(ds as long) in (20170102L, 20170103L) (using INSET expression)") + { + testMetastorePartitionFiltering( + attr("ds").cast(LongType).in(20170102L, 20170103L), 20170102 to 20170103, 0 to 23, "aa" :: "ab" :: "ba" :: "bb" :: Nil, { @@ -159,7 +195,7 @@ class HiveClientSuite(version: String) test("getPartitionsByFilter: chunk in ('ab', 'ba') (using IN expression)") { testMetastorePartitionFiltering( - "chunk in ('ab', 'ba')", + attr("chunk").in("ab", "ba"), 20170101 to 20170103, 0 to 23, "ab" :: "ba" :: Nil) @@ -167,7 +203,7 @@ class HiveClientSuite(version: String) test("getPartitionsByFilter: chunk in ('ab', 'ba') (using INSET expression)") { testMetastorePartitionFiltering( - "chunk in ('ab', 'ba')", + attr("chunk").in("ab", "ba"), 20170101 to 20170103, 0 to 23, "ab" :: "ba" :: Nil, { @@ -179,26 +215,24 @@ class HiveClientSuite(version: String) test("getPartitionsByFilter: (ds=20170101 and h>=8) or (ds=20170102 and h<8)") { val day1 = (20170101 to 20170101, 8 to 23, Seq("aa", "ab", "ba", "bb")) val day2 = (20170102 to 20170102, 0 to 7, Seq("aa", "ab", "ba", "bb")) - testMetastorePartitionFiltering( - "(ds=20170101 and h>=8) or (ds=20170102 and h<8)", - day1 :: day2 :: Nil) + testMetastorePartitionFiltering((attr("ds") === 20170101 && attr("h") >= 8) || + (attr("ds") === 20170102 && attr("h") < 8), day1 :: day2 :: Nil) } test("getPartitionsByFilter: (ds=20170101 and h>=8) or (ds=20170102 and h<(7+1))") { val day1 = (20170101 to 20170101, 8 
to 23, Seq("aa", "ab", "ba", "bb")) // Day 2 should include all hours because we can't build a filter for h<(7+1) val day2 = (20170102 to 20170102, 0 to 23, Seq("aa", "ab", "ba", "bb")) - testMetastorePartitionFiltering( - "(ds=20170101 and h>=8) or (ds=20170102 and h<(7+1))", - day1 :: day2 :: Nil) + testMetastorePartitionFiltering((attr("ds") === 20170101 && attr("h") >= 8) || + (attr("ds") === 20170102 && attr("h") < (Literal(7) + 1)), day1 :: day2 :: Nil) } test("getPartitionsByFilter: " + "chunk in ('ab', 'ba') and ((ds=20170101 and h>=8) or (ds=20170102 and h<8))") { val day1 = (20170101 to 20170101, 8 to 23, Seq("ab", "ba")) val day2 = (20170102 to 20170102, 0 to 7, Seq("ab", "ba")) - testMetastorePartitionFiltering( - "chunk in ('ab', 'ba') and ((ds=20170101 and h>=8) or (ds=20170102 and h<8))", + testMetastorePartitionFiltering(attr("chunk").in("ab", "ba") && + ((attr("ds") === 20170101 && attr("h") >= 8) || (attr("ds") === 20170102 && attr("h") < 8)), day1 :: day2 :: Nil) } @@ -207,41 +241,41 @@ class HiveClientSuite(version: String) } private def testMetastorePartitionFiltering( - filterString: String, + filterExpr: Expression, expectedDs: Seq[Int], expectedH: Seq[Int], expectedChunks: Seq[String]): Unit = { testMetastorePartitionFiltering( - filterString, + filterExpr, (expectedDs, expectedH, expectedChunks) :: Nil, identity) } private def testMetastorePartitionFiltering( - filterString: String, + filterExpr: Expression, expectedDs: Seq[Int], expectedH: Seq[Int], expectedChunks: Seq[String], transform: Expression => Expression): Unit = { testMetastorePartitionFiltering( - filterString, + filterExpr, (expectedDs, expectedH, expectedChunks) :: Nil, - identity) + transform) } private def testMetastorePartitionFiltering( - filterString: String, + filterExpr: Expression, expectedPartitionCubes: Seq[(Seq[Int], Seq[Int], Seq[String])]): Unit = { - testMetastorePartitionFiltering(filterString, expectedPartitionCubes, identity) + testMetastorePartitionFiltering(filterExpr, expectedPartitionCubes, identity) } private def testMetastorePartitionFiltering( - filterString: String, + filterExpr: Expression, expectedPartitionCubes: Seq[(Seq[Int], Seq[Int], Seq[String])], transform: Expression => Expression): Unit = { val filteredPartitions = client.getPartitionsByFilter(client.getTable("default", "test"), Seq( - transform(parseExpression(filterString)) + transform(filterExpr) )) val expectedPartitionCount = expectedPartitionCubes.map {
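+    // Each expected partition "cube" enumerates the ds, h, and chunk values that should remain
+    // after the metastore-side filtering.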