diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 6eaf6794764c7..3b8786cf67a6f 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -1044,7 +1044,7 @@ class SparkContext(config: SparkConf) extends Logging {
   private[spark] def getCallSite(): CallSite = {
     Option(getLocalProperty("externalCallSite")) match {
       case Some(callSite) => CallSite(callSite, longForm = "")
-      case None => Utils.getCallSite
+      case None => Utils.getCallSite()
     }
   }
 
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 1cf55e86f6c81..521317727e788 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.rdd
 
-import java.util.Random
+import java.util.{Properties, Random}
 
 import scala.collection.{mutable, Map}
 import scala.collection.mutable.ArrayBuffer
@@ -41,7 +41,7 @@ import org.apache.spark.partial.CountEvaluator
 import org.apache.spark.partial.GroupedCountEvaluator
 import org.apache.spark.partial.PartialResult
 import org.apache.spark.storage.StorageLevel
-import org.apache.spark.util.{BoundedPriorityQueue, Utils}
+import org.apache.spark.util.{BoundedPriorityQueue, Utils, CallSite}
 import org.apache.spark.util.collection.OpenHashMap
 import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler, SamplingUtils}
 
@@ -1220,7 +1220,15 @@ abstract class RDD[T: ClassTag](
   private var storageLevel: StorageLevel = StorageLevel.NONE
 
   /** User code that created this RDD (e.g. `textFile`, `parallelize`). */
-  @transient private[spark] val creationSite = Utils.getCallSite
+  @transient private[spark] val creationSite = {
+    val short: String = sc.getLocalProperty(Utils.CALL_SITE_SHORT)
+    if (short != null) {
+      CallSite(short, sc.getLocalProperty(Utils.CALL_SITE_LONG))
+    } else {
+      Utils.getCallSite()
+    }
+  }
+
   private[spark] def getCreationSite: String = Option(creationSite).map(_.shortForm).getOrElse("")
 
   private[spark] def elementClassTag: ClassTag[T] = classTag[T]
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index 0ae28f911e302..46828bfec0593 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -53,6 +53,9 @@ private[spark] case class CallSite(shortForm: String, longForm: String)
 private[spark] object Utils extends Logging {
   val random = new Random()
 
+  private[spark] val CALL_SITE_SHORT: String = "callSite.short"
+  private[spark] val CALL_SITE_LONG: String = "callSite.long"
+
   /** Serialize an object using Java serialization */
   def serialize[T](o: T): Array[Byte] = {
     val bos = new ByteArrayOutputStream()
@@ -856,13 +859,19 @@ private[spark] object Utils extends Logging {
    * finding the call site of a method.
    */
   private val SPARK_CLASS_REGEX = """^org\.apache\.spark(\.api\.java)?(\.util)?(\.rdd)?\.[A-Z]""".r
+  val SCALA_CLASS_REGEX = """^scala""".r
+
+  private def defaultRegexFunc(className: String): Boolean = {
+    SPARK_CLASS_REGEX.findFirstIn(className).isDefined ||
+      SCALA_CLASS_REGEX.findFirstIn(className).isDefined
+  }
 
   /**
    * When called inside a class in the spark package, returns the name of the user code class
    * (outside the spark package) that called into Spark, as well as which Spark method they called.
    * This is used, for example, to tell users where in their code each RDD got created.
    */
-  def getCallSite: CallSite = {
+  def getCallSite(regexFunc: String => Boolean = defaultRegexFunc(_)): CallSite = {
     val trace = Thread.currentThread.getStackTrace()
       .filterNot { ste:StackTraceElement =>
         // When running under some profilers, the current stack trace might contain some bogus
@@ -883,8 +892,8 @@ private[spark] object Utils extends Logging {
 
     for (el <- trace) {
       if (insideSpark) {
-        if (SPARK_CLASS_REGEX.findFirstIn(el.getClassName).isDefined) {
-          lastSparkMethod = if (el.getMethodName == "<init>") {
+        if (regexFunc(el.getClassName)) {
+          lastSparkMethod = if (el.getMethodName == "<init>") {
             // Spark method is a constructor; get its class name
             el.getClassName.substring(el.getClassName.lastIndexOf('.') + 1)
           } else {
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
index 101cec1c7a7c2..b1fbbe354cc96 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
@@ -38,7 +38,7 @@ import org.apache.spark.streaming.dstream._
 import org.apache.spark.streaming.receiver.{ActorSupervisorStrategy, ActorReceiver, Receiver}
 import org.apache.spark.streaming.scheduler._
 import org.apache.spark.streaming.ui.StreamingTab
-import org.apache.spark.util.MetadataCleaner
+import org.apache.spark.util.Utils
 
 /**
  * Main entry point for Spark Streaming functionality. It provides methods used to create
@@ -441,6 +441,9 @@ class StreamingContext private[streaming] (
       throw new SparkException("StreamingContext has already been stopped")
     }
     validate()
+    sc.setCallSite(
+      Utils.getCallSite(org.apache.spark.streaming.util.Utils.streamingRegexFunc).shortForm
+    )
     scheduler.start()
     state = Started
   }
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala
index e05db236addca..cdaaa55978cf0 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala
@@ -30,7 +30,7 @@ import org.apache.spark.storage.StorageLevel
 import org.apache.spark.streaming._
 import org.apache.spark.streaming.StreamingContext._
 import org.apache.spark.streaming.scheduler.Job
-import org.apache.spark.util.MetadataCleaner
+import org.apache.spark.util.{CallSite, Utils, MetadataCleaner}
 
 /**
  * A Discretized Stream (DStream), the basic abstraction in Spark Streaming, is a continuous
@@ -106,6 +106,21 @@ abstract class DStream[T: ClassTag] (
   /** Return the StreamingContext associated with this DStream */
   def context = ssc
 
+  /* Find the creation callSite */
+  val creationSite = Utils.getCallSite(org.apache.spark.streaming.util.Utils.streamingRegexFunc)
+
+  /* Store the RDD creation callSite in threadlocal */
+  private def setRDDCreationCallSite(callSite: CallSite = creationSite) = {
+    ssc.sparkContext.setLocalProperty(Utils.CALL_SITE_SHORT, callSite.shortForm)
+    ssc.sparkContext.setLocalProperty(Utils.CALL_SITE_LONG, callSite.longForm)
+  }
+
+  /* Return the current callSite */
+  private def getRDDCreationCallSite(): CallSite = {
+    CallSite(ssc.sparkContext.getLocalProperty(Utils.CALL_SITE_SHORT),
+      ssc.sparkContext.getLocalProperty(Utils.CALL_SITE_LONG))
+  }
+
   /** Persist the RDDs of this DStream with the given storage level */
   def persist(level: StorageLevel): DStream[T] = {
     if (this.isInitialized) {
@@ -288,7 +303,9 @@ abstract class DStream[T: ClassTag] (
       // (based on sliding time of this DStream), then generate the RDD
       case None => {
         if (isTimeValid(time)) {
-          compute(time) match {
+          val prevCallSite = getRDDCreationCallSite
+          setRDDCreationCallSite()
+          val rddOption = compute(time) match {
             case Some(newRDD) =>
               if (storageLevel != StorageLevel.NONE) {
                 newRDD.persist(storageLevel)
@@ -304,10 +321,12 @@ abstract class DStream[T: ClassTag] (
               generatedRDDs.put(time, newRDD)
               Some(newRDD)
             case None =>
-              None
+              return None
           }
+          setRDDCreationCallSite(prevCallSite)
+          return rddOption
         } else {
-          None
+          return None
         }
       }
     }
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/Utils.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/Utils.scala
new file mode 100644
index 0000000000000..e0458ea41b9c7
--- /dev/null
+++ b/streaming/src/main/scala/org/apache/spark/streaming/util/Utils.scala
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.util
+
+import org.apache.spark.util.Utils.SCALA_CLASS_REGEX
+
+/**
+ * Utility method used by Spark Streaming.
+ */
+private[streaming] object Utils {
+  private val SPARK_STREAMING_CLASS_REGEX = """^org\.apache\.spark""".r
+  private val SPARK_EXAMPLES_CLASS_REGEX = """^org\.apache\.spark\.examples""".r
+
+  def streamingRegexFunc(className: String): Boolean = {
+    (SPARK_STREAMING_CLASS_REGEX.findFirstIn(className).isDefined &&
+      !SPARK_EXAMPLES_CLASS_REGEX.findFirstIn(className).isDefined) ||
+      SCALA_CLASS_REGEX.findFirstIn(className).isDefined
+  }
+}
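Usage sketch (not part of the patch): a minimal illustration of how the call-site plumbing introduced above might be exercised once the patch is applied. The object name `CallSiteSketch` and its package placement are hypothetical; they are chosen only so the `private[spark]`/`private[streaming]` members touched by this diff are visible from the example.

```scala
// Hypothetical sketch, not part of the patch. It lives under org.apache.spark.streaming
// so that the private[spark] core Utils/CallSite and the private[streaming] helper
// added by this diff are accessible.
package org.apache.spark.streaming

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.util.{CallSite, Utils}
import org.apache.spark.streaming.util.{Utils => StreamingUtils}

object CallSiteSketch {
  def main(args: Array[String]): Unit = {
    // Capture the current call site, skipping Spark (and Scala) frames with the
    // streaming-aware predicate introduced in streaming/util/Utils.scala.
    val site: CallSite = Utils.getCallSite(StreamingUtils.streamingRegexFunc)

    val sc = new SparkContext(
      new SparkConf().setMaster("local[2]").setAppName("callsite-sketch"))

    // Propagate the call site through the thread-local properties that
    // RDD.creationSite now consults before falling back to Utils.getCallSite().
    sc.setLocalProperty(Utils.CALL_SITE_SHORT, site.shortForm)
    sc.setLocalProperty(Utils.CALL_SITE_LONG, site.longForm)

    // RDDs created on this thread should pick up `site` as their creation site,
    // which is what the Stages tab in the web UI displays.
    val rdd = sc.parallelize(1 to 10)
    println(rdd.toDebugString)

    sc.stop()
  }
}
```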