From ef633f5e4857400c8711ee800b01016b6bd406b2 Mon Sep 17 00:00:00 2001 From: Prashant Sharma Date: Mon, 13 Oct 2014 18:11:11 +0530 Subject: [PATCH 1/6] SPARK-3874, Provide stable TaskContext API --- .../java/org/apache/spark/TaskContext.java | 74 ++----------------- .../org/apache/spark/rdd/HadoopRDD.scala | 2 +- .../apache/spark/rdd/PairRDDFunctions.scala | 8 +- .../apache/spark/scheduler/DAGScheduler.scala | 6 +- .../org/apache/spark/scheduler/Task.scala | 8 +- .../java/org/apache/spark/JavaAPISuite.java | 2 +- .../util/JavaTaskCompletionListenerImpl.java | 4 +- .../org/apache/spark/CacheManagerSuite.scala | 8 +- .../org/apache/spark/rdd/PipedRDDSuite.scala | 2 +- .../spark/scheduler/TaskContextSuite.scala | 2 +- .../ShuffleBlockFetcherIteratorSuite.scala | 8 +- .../sql/parquet/ParquetTableOperations.scala | 4 +- 12 files changed, 32 insertions(+), 96 deletions(-) diff --git a/core/src/main/java/org/apache/spark/TaskContext.java b/core/src/main/java/org/apache/spark/TaskContext.java index 4e6d708af0ea7..a18ddeb49db52 100644 --- a/core/src/main/java/org/apache/spark/TaskContext.java +++ b/core/src/main/java/org/apache/spark/TaskContext.java @@ -37,7 +37,7 @@ * Contextual information about a task which can be read or mutated during execution. */ @DeveloperApi -public class TaskContext implements Serializable { +public abstract class TaskContext implements Serializable { private int stageId; private int partitionId; @@ -45,19 +45,8 @@ public class TaskContext implements Serializable { private boolean runningLocally; private TaskMetrics taskMetrics; - /** - * :: DeveloperApi :: - * Contextual information about a task which can be read or mutated during execution. - * - * @param stageId stage id - * @param partitionId index of the partition - * @param attemptId the number of attempts to execute this task - * @param runningLocally whether the task is running locally in the driver JVM - * @param taskMetrics performance metrics of the task - */ - @DeveloperApi - public TaskContext(int stageId, int partitionId, long attemptId, boolean runningLocally, - TaskMetrics taskMetrics) { + TaskContext(int stageId, int partitionId, long attemptId, boolean runningLocally, + TaskMetrics taskMetrics) { this.attemptId = attemptId; this.partitionId = partitionId; this.runningLocally = runningLocally; @@ -65,41 +54,6 @@ public TaskContext(int stageId, int partitionId, long attemptId, boolean running this.taskMetrics = taskMetrics; } - /** - * :: DeveloperApi :: - * Contextual information about a task which can be read or mutated during execution. - * - * @param stageId stage id - * @param partitionId index of the partition - * @param attemptId the number of attempts to execute this task - * @param runningLocally whether the task is running locally in the driver JVM - */ - @DeveloperApi - public TaskContext(int stageId, int partitionId, long attemptId, boolean runningLocally) { - this.attemptId = attemptId; - this.partitionId = partitionId; - this.runningLocally = runningLocally; - this.stageId = stageId; - this.taskMetrics = TaskMetrics.empty(); - } - - /** - * :: DeveloperApi :: - * Contextual information about a task which can be read or mutated during execution. 
- * - * @param stageId stage id - * @param partitionId index of the partition - * @param attemptId the number of attempts to execute this task - */ - @DeveloperApi - public TaskContext(int stageId, int partitionId, long attemptId) { - this.attemptId = attemptId; - this.partitionId = partitionId; - this.runningLocally = false; - this.stageId = stageId; - this.taskMetrics = TaskMetrics.empty(); - } - private static ThreadLocal taskContext = new ThreadLocal(); @@ -107,7 +61,7 @@ public TaskContext(int stageId, int partitionId, long attemptId) { * :: Internal API :: * This is spark internal API, not intended to be called from user programs. */ - public static void setTaskContext(TaskContext tc) { + static void setTaskContext(TaskContext tc) { taskContext.set(tc); } @@ -116,7 +70,7 @@ public static TaskContext get() { } /** :: Internal API :: */ - public static void unset() { + static void unset() { taskContext.remove(); } @@ -222,20 +176,14 @@ public void markInterrupted() { interrupted = true; } - @Deprecated - /** Deprecated: use getStageId() */ public int stageId() { return stageId; } - @Deprecated - /** Deprecated: use getPartitionId() */ public int partitionId() { return partitionId; } - @Deprecated - /** Deprecated: use getAttemptId() */ public long attemptId() { return attemptId; } @@ -250,18 +198,6 @@ public boolean isRunningLocally() { return runningLocally; } - public int getStageId() { - return stageId; - } - - public int getPartitionId() { - return partitionId; - } - - public long getAttemptId() { - return attemptId; - } - /** ::Internal API:: */ public TaskMetrics taskMetrics() { return taskMetrics; diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 6b63eb23e9ee1..8010dd90082f8 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -196,7 +196,7 @@ class HadoopRDD[K, V]( val jobConf = getJobConf() val inputFormat = getInputFormat(jobConf) HadoopRDD.addLocalConfiguration(new SimpleDateFormat("yyyyMMddHHmm").format(createTime), - context.getStageId, theSplit.index, context.getAttemptId.toInt, jobConf) + context.stageId, theSplit.index, context.attemptId.toInt, jobConf) reader = inputFormat.getRecordReader(split.inputSplit.value, jobConf, Reporter.NULL) // Register an on-task-completion callback to close the input stream. diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 0d97506450a7f..929ded58a3bd5 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -956,9 +956,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val writeShard = (context: TaskContext, iter: Iterator[(K,V)]) => { // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it // around by taking a mod. We expect that no task will be attempted 2 billion times. 
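The wraparound described in the comment above is worth seeing in isolation. A minimal, self-contained sketch of the same arithmetic (the function name is illustrative, not from the patch):

```scala
// Hadoop task attempt IDs are 32-bit; Spark's are 64-bit. Taking the value
// mod Int.MaxValue keeps the result non-negative for any non-negative input.
def hadoopAttemptNumber(sparkAttemptId: Long): Int =
  (sparkAttemptId % Int.MaxValue).toInt

hadoopAttemptNumber(3L)                  // 3
hadoopAttemptNumber(Int.MaxValue + 5L)   // wraps around to 5
```

The hunk below applies exactly this to `context.attemptId`: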
- val attemptNumber = (context.getAttemptId % Int.MaxValue).toInt + val attemptNumber = (context.attemptId % Int.MaxValue).toInt /* "reduce task" */ - val attemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = false, context.getPartitionId, + val attemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = false, context.partitionId, attemptNumber) val hadoopContext = newTaskAttemptContext(wrappedConf.value, attemptId) val format = outfmt.newInstance @@ -1027,9 +1027,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val writeToFile = (context: TaskContext, iter: Iterator[(K, V)]) => { // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it // around by taking a mod. We expect that no task will be attempted 2 billion times. - val attemptNumber = (context.getAttemptId % Int.MaxValue).toInt + val attemptNumber = (context.attemptId % Int.MaxValue).toInt - writer.setup(context.getStageId, context.getPartitionId, attemptNumber) + writer.setup(context.stageId, context.partitionId, attemptNumber) writer.open() try { var count = 0 diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 788eb1ff4e455..f81fa6d8089fc 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -633,14 +633,14 @@ class DAGScheduler( val rdd = job.finalStage.rdd val split = rdd.partitions(job.partitions(0)) val taskContext = - new TaskContext(job.finalStage.id, job.partitions(0), 0, true) - TaskContext.setTaskContext(taskContext) + new TaskContextImpl(job.finalStage.id, job.partitions(0), 0, true) + TaskContextHelper.setTaskContext(taskContext) try { val result = job.func(taskContext, rdd.iterator(split, taskContext)) job.listener.taskSucceeded(0, result) } finally { taskContext.markTaskCompleted() - TaskContext.unset() + TaskContextHelper.unset() } } catch { case e: Exception => diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index c6e47c84a0cb2..78213686d6813 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -22,7 +22,7 @@ import java.nio.ByteBuffer import scala.collection.mutable.HashMap -import org.apache.spark.TaskContext +import org.apache.spark.{TaskContextHelper, TaskContextImpl, TaskContext} import org.apache.spark.executor.TaskMetrics import org.apache.spark.serializer.SerializerInstance import org.apache.spark.util.ByteBufferInputStream @@ -45,8 +45,8 @@ import org.apache.spark.util.Utils private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) extends Serializable { final def run(attemptId: Long): T = { - context = new TaskContext(stageId, partitionId, attemptId, false) - TaskContext.setTaskContext(context) + context = new TaskContextImpl(stageId, partitionId, attemptId, false) + TaskContextHelper.setTaskContext(context) context.taskMetrics.hostname = Utils.localHostName() taskThread = Thread.currentThread() if (_killed) { @@ -56,7 +56,7 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex runTask(context) } finally { context.markTaskCompleted() - TaskContext.unset() + TaskContextHelper.unset() } } diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index 4a078435447e5..b8fa822ae4bd8 100644 
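The Task.scala and DAGScheduler hunks above both follow the same lifecycle: build a TaskContextImpl, publish it through the package-private ThreadLocal via TaskContextHelper, and tear everything down in a finally block. A condensed sketch of that pattern, assuming Spark-internal access (`runWithContext` is a hypothetical helper, not part of the patch):

```scala
// Condensed sketch of the lifecycle shared by Task.run and the DAGScheduler
// local-task path. Only compiles inside the org.apache.spark package, since
// TaskContextHelper and TaskContextImpl are private[spark].
def runWithContext[T](context: TaskContextImpl)(body: TaskContext => T): T = {
  TaskContextHelper.setTaskContext(context)  // makes TaskContext.get() return it on this thread
  try {
    body(context)
  } finally {
    context.markTaskCompleted()   // fire completion listeners exactly once, even on failure
    TaskContextHelper.unset()     // clear the ThreadLocal so nothing leaks across tasks
  }
}
```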
--- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -776,7 +776,7 @@ public void persist() { @Test public void iterator() { JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2); - TaskContext context = new TaskContext(0, 0, 0L, false, new TaskMetrics()); + TaskContext context = new TaskContextImpl(0, 0, 0L, false, new TaskMetrics()); Assert.assertEquals(1, rdd.iterator(rdd.partitions().get(0), context).next().intValue()); } diff --git a/core/src/test/java/org/apache/spark/util/JavaTaskCompletionListenerImpl.java b/core/src/test/java/org/apache/spark/util/JavaTaskCompletionListenerImpl.java index 0944bf8cd5c71..e9ec700e32e15 100644 --- a/core/src/test/java/org/apache/spark/util/JavaTaskCompletionListenerImpl.java +++ b/core/src/test/java/org/apache/spark/util/JavaTaskCompletionListenerImpl.java @@ -30,8 +30,8 @@ public class JavaTaskCompletionListenerImpl implements TaskCompletionListener { public void onTaskCompletion(TaskContext context) { context.isCompleted(); context.isInterrupted(); - context.getStageId(); - context.getPartitionId(); + context.stageId(); + context.partitionId(); context.isRunningLocally(); context.addTaskCompletionListener(this); } diff --git a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala index d735010d7c9d5..c0735f448d193 100644 --- a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala @@ -66,7 +66,7 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar // in blockManager.put is a losing battle. You have been warned. blockManager = sc.env.blockManager cacheManager = sc.env.cacheManager - val context = new TaskContext(0, 0, 0) + val context = new TaskContextImpl(0, 0, 0) val computeValue = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY) val getValue = blockManager.get(RDDBlockId(rdd.id, split.index)) assert(computeValue.toList === List(1, 2, 3, 4)) @@ -81,7 +81,7 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar } whenExecuting(blockManager) { - val context = new TaskContext(0, 0, 0) + val context = new TaskContextImpl(0, 0, 0) val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY) assert(value.toList === List(5, 6, 7)) } @@ -94,7 +94,7 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar } whenExecuting(blockManager) { - val context = new TaskContext(0, 0, 0, true) + val context = new TaskContextImpl(0, 0, 0, true) val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY) assert(value.toList === List(1, 2, 3, 4)) } @@ -102,7 +102,7 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar test("verify task metrics updated correctly") { cacheManager = sc.env.cacheManager - val context = new TaskContext(0, 0, 0) + val context = new TaskContextImpl(0, 0, 0) cacheManager.getOrCompute(rdd3, split, context, StorageLevel.MEMORY_ONLY) assert(context.taskMetrics.updatedBlocks.getOrElse(Seq()).size === 2) } diff --git a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala index be972c5e97a7e..271a90c6646bb 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala @@ -174,7 +174,7 @@ class 
PipedRDDSuite extends FunSuite with SharedSparkContext { } val hadoopPart1 = generateFakeHadoopPartition() val pipedRdd = new PipedRDD(nums, "printenv " + varName) - val tContext = new TaskContext(0, 0, 0) + val tContext = new TaskContextImpl(0, 0, 0) val rddIter = pipedRdd.compute(hadoopPart1, tContext) val arr = rddIter.toArray assert(arr(0) == "/some/path") diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala index faba5508c906c..561a5e9cd90c4 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala @@ -51,7 +51,7 @@ class TaskContextSuite extends FunSuite with BeforeAndAfter with LocalSparkConte } test("all TaskCompletionListeners should be called even if some fail") { - val context = new TaskContext(0, 0, 0) + val context = new TaskContextImpl(0, 0, 0) val listener = mock(classOf[TaskCompletionListener]) context.addTaskCompletionListener(_ => throw new Exception("blah")) context.addTaskCompletionListener(listener) diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 809bd70929656..a8c049d749015 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.storage -import org.apache.spark.TaskContext +import org.apache.spark.{TaskContextImpl, TaskContext} import org.apache.spark.network.{BlockFetchingListener, BlockTransferService} import org.mockito.Mockito._ @@ -62,7 +62,7 @@ class ShuffleBlockFetcherIteratorSuite extends FunSuite { ) val iterator = new ShuffleBlockFetcherIterator( - new TaskContext(0, 0, 0), + new TaskContextImpl(0, 0, 0), transfer, blockManager, blocksByAddress, @@ -120,7 +120,7 @@ class ShuffleBlockFetcherIteratorSuite extends FunSuite { ) val iterator = new ShuffleBlockFetcherIterator( - new TaskContext(0, 0, 0), + new TaskContextImpl(0, 0, 0), transfer, blockManager, blocksByAddress, @@ -169,7 +169,7 @@ class ShuffleBlockFetcherIteratorSuite extends FunSuite { (bmId, Seq((blId1, 1L), (blId2, 1L)))) val iterator = new ShuffleBlockFetcherIterator( - new TaskContext(0, 0, 0), + new TaskContextImpl(0, 0, 0), transfer, blockManager, blocksByAddress, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala index ffb732347d30a..d39e31a7fa195 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala @@ -289,9 +289,9 @@ case class InsertIntoParquetTable( def writeShard(context: TaskContext, iter: Iterator[Row]): Int = { // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it // around by taking a mod. We expect that no task will be attempted 2 billion times. 
- val attemptNumber = (context.getAttemptId % Int.MaxValue).toInt + val attemptNumber = (context.attemptId % Int.MaxValue).toInt /* "reduce task" */ - val attemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = false, context.getPartitionId, + val attemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = false, context.partitionId, attemptNumber) val hadoopContext = newTaskAttemptContext(wrappedConf.value, attemptId) val format = new AppendingParquetOutputFormat(taskIdOffset) From bbd9e057a24cd25336a806dce41b2cbd1ebc3233 Mon Sep 17 00:00:00 2001 From: Prashant Sharma Date: Mon, 13 Oct 2014 18:43:33 +0530 Subject: [PATCH 2/6] adding missed out files to git. --- .../org/apache/spark/TaskContextHelper.scala | 26 +++++++++++++++++ .../org/apache/spark/TaskContextImpl.scala | 28 +++++++++++++++++++ 2 files changed, 54 insertions(+) create mode 100644 core/src/main/scala/org/apache/spark/TaskContextHelper.scala create mode 100644 core/src/main/scala/org/apache/spark/TaskContextImpl.scala diff --git a/core/src/main/scala/org/apache/spark/TaskContextHelper.scala b/core/src/main/scala/org/apache/spark/TaskContextHelper.scala new file mode 100644 index 0000000000000..4bc3ffeb4e87e --- /dev/null +++ b/core/src/main/scala/org/apache/spark/TaskContextHelper.scala @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +private [spark] object TaskContextHelper { + + def setTaskContext(tc: TaskContext): Unit = TaskContext.setTaskContext(tc) + + def unset(): Unit = TaskContext.unset() + +} diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala new file mode 100644 index 0000000000000..37135a2d85a99 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark + +import org.apache.spark.executor.TaskMetrics + +private[spark] class TaskContextImpl(stageId: Int, + partitionId: Int, + attemptId: Long, + runningLocally: Boolean = false, + taskMetrics: TaskMetrics = TaskMetrics.empty) + extends TaskContext(stageId, partitionId, attemptId, runningLocally, taskMetrics); + From 7ecc2fe712a40d73192ed9e28d8eca6d48c941f3 Mon Sep 17 00:00:00 2001 From: Prashant Sharma Date: Tue, 14 Oct 2014 14:00:38 +0530 Subject: [PATCH 3/6] CR, Moved implementations to TaskContextImpl --- .../java/org/apache/spark/TaskContext.java | 115 ++---------------- .../org/apache/spark/TaskContextHelper.scala | 2 +- .../org/apache/spark/TaskContextImpl.scala | 86 ++++++++++++- 3 files changed, 93 insertions(+), 110 deletions(-) diff --git a/core/src/main/java/org/apache/spark/TaskContext.java b/core/src/main/java/org/apache/spark/TaskContext.java index a18ddeb49db52..6f5a66a4f31d9 100644 --- a/core/src/main/java/org/apache/spark/TaskContext.java +++ b/core/src/main/java/org/apache/spark/TaskContext.java @@ -39,21 +39,6 @@ @DeveloperApi public abstract class TaskContext implements Serializable { - private int stageId; - private int partitionId; - private long attemptId; - private boolean runningLocally; - private TaskMetrics taskMetrics; - - TaskContext(int stageId, int partitionId, long attemptId, boolean runningLocally, - TaskMetrics taskMetrics) { - this.attemptId = attemptId; - this.partitionId = partitionId; - this.runningLocally = runningLocally; - this.stageId = stageId; - this.taskMetrics = taskMetrics; - } - private static ThreadLocal taskContext = new ThreadLocal(); @@ -74,29 +59,15 @@ static void unset() { taskContext.remove(); } - // List of callback functions to execute when the task completes. - private transient List onCompleteCallbacks = - new ArrayList(); - - // Whether the corresponding task has been killed. - private volatile boolean interrupted = false; - - // Whether the task has completed. - private volatile boolean completed = false; - /** * Checks whether the task has completed. */ - public boolean isCompleted() { - return completed; - } + public abstract boolean isCompleted(); /** * Checks whether the task has been killed. */ - public boolean isInterrupted() { - return interrupted; - } + public abstract boolean isInterrupted(); /** * Add a (Java friendly) listener to be executed on task completion. @@ -104,10 +75,7 @@ public boolean isInterrupted() { *
<p/>
* An example use is for HadoopRDD to register a callback to close the input stream. */ - public TaskContext addTaskCompletionListener(TaskCompletionListener listener) { - onCompleteCallbacks.add(listener); - return this; - } + public abstract TaskContext addTaskCompletionListener(TaskCompletionListener listener); /** * Add a listener in the form of a Scala closure to be executed on task completion. @@ -115,15 +83,7 @@ public TaskContext addTaskCompletionListener(TaskCompletionListener listener) { *
<p/>
* An example use is for HadoopRDD to register a callback to close the input stream. */ - public TaskContext addTaskCompletionListener(final Function1 f) { - onCompleteCallbacks.add(new TaskCompletionListener() { - @Override - public void onTaskCompletion(TaskContext context) { - f.apply(context); - } - }); - return this; - } + public abstract TaskContext addTaskCompletionListener(final Function1 f); /** * Add a callback function to be executed on task completion. An example use @@ -131,75 +91,24 @@ public void onTaskCompletion(TaskContext context) { * Will be called in any situation - success, failure, or cancellation. * * Deprecated: use addTaskCompletionListener - * + * * @param f Callback function. */ @Deprecated - public void addOnCompleteCallback(final Function0 f) { - onCompleteCallbacks.add(new TaskCompletionListener() { - @Override - public void onTaskCompletion(TaskContext context) { - f.apply(); - } - }); - } + public abstract void addOnCompleteCallback(final Function0 f); - /** - * ::Internal API:: - * Marks the task as completed and triggers the listeners. - */ - public void markTaskCompleted() throws TaskCompletionListenerException { - completed = true; - List errorMsgs = new ArrayList(2); - // Process complete callbacks in the reverse order of registration - List revlist = - new ArrayList(onCompleteCallbacks); - Collections.reverse(revlist); - for (TaskCompletionListener tcl: revlist) { - try { - tcl.onTaskCompletion(this); - } catch (Throwable e) { - errorMsgs.add(e.getMessage()); - } - } - - if (!errorMsgs.isEmpty()) { - throw new TaskCompletionListenerException(JavaConversions.asScalaBuffer(errorMsgs)); - } - } - - /** - * ::Internal API:: - * Marks the task for interruption, i.e. cancellation. - */ - public void markInterrupted() { - interrupted = true; - } - - public int stageId() { - return stageId; - } + public abstract int stageId(); - public int partitionId() { - return partitionId; - } + public abstract int partitionId(); - public long attemptId() { - return attemptId; - } + public abstract long attemptId(); @Deprecated /** Deprecated: use isRunningLocally() */ - public boolean runningLocally() { - return runningLocally; - } + public abstract boolean runningLocally(); - public boolean isRunningLocally() { - return runningLocally; - } + public abstract boolean isRunningLocally(); /** ::Internal API:: */ - public TaskMetrics taskMetrics() { - return taskMetrics; - } + public abstract TaskMetrics taskMetrics(); } diff --git a/core/src/main/scala/org/apache/spark/TaskContextHelper.scala b/core/src/main/scala/org/apache/spark/TaskContextHelper.scala index 4bc3ffeb4e87e..a18e0586fca12 100644 --- a/core/src/main/scala/org/apache/spark/TaskContextHelper.scala +++ b/core/src/main/scala/org/apache/spark/TaskContextHelper.scala @@ -22,5 +22,5 @@ private [spark] object TaskContextHelper { def setTaskContext(tc: TaskContext): Unit = TaskContext.setTaskContext(tc) def unset(): Unit = TaskContext.unset() - + } diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala index 37135a2d85a99..244a02a298cf9 100644 --- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala +++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala @@ -18,11 +18,85 @@ package org.apache.spark import org.apache.spark.executor.TaskMetrics +import org.apache.spark.util.{TaskCompletionListener, TaskCompletionListenerException} -private[spark] class TaskContextImpl(stageId: Int, - partitionId: Int, - attemptId: Long, - 
runningLocally: Boolean = false, - taskMetrics: TaskMetrics = TaskMetrics.empty) - extends TaskContext(stageId, partitionId, attemptId, runningLocally, taskMetrics); +import scala.collection.mutable.ArrayBuffer + +private[spark] class TaskContextImpl(_stageId: Int, + _partitionId: Int, + _attemptId: Long, + _runningLocally: Boolean = false, + _taskMetrics: TaskMetrics = TaskMetrics.empty) + extends TaskContext(_stageId, _partitionId, _attemptId, _runningLocally, _taskMetrics) + with Logging { + + // List of callback functions to execute when the task completes. + @transient private val onCompleteCallbacks = new ArrayBuffer[TaskCompletionListener] + + // Whether the corresponding task has been killed. + @volatile private var interrupted: Boolean = false + + // Whether the task has completed. + @volatile private var completed: Boolean = false + + override def addTaskCompletionListener(listener: TaskCompletionListener): this.type = { + onCompleteCallbacks += listener + this + } + + override def addTaskCompletionListener(f: TaskContext => Unit): this.type = { + onCompleteCallbacks += new TaskCompletionListener { + override def onTaskCompletion(context: TaskContext): Unit = f(context) + } + this + } + + @deprecated("use addTaskCompletionListener", "1.1.0") + override def addOnCompleteCallback(f: () => Unit) { + onCompleteCallbacks += new TaskCompletionListener { + override def onTaskCompletion(context: TaskContext): Unit = f() + } + } + + /** Marks the task as completed and triggers the listeners. */ + private[spark] def markTaskCompleted(): Unit = { + completed = true + val errorMsgs = new ArrayBuffer[String](2) + // Process complete callbacks in the reverse order of registration + onCompleteCallbacks.reverse.foreach { listener => + try { + listener.onTaskCompletion(this) + } catch { + case e: Throwable => + errorMsgs += e.getMessage + logError("Error in TaskCompletionListener", e) + } + } + if (errorMsgs.nonEmpty) { + throw new TaskCompletionListenerException(errorMsgs) + } + } + + /** Marks the task for interruption, i.e. cancellation. */ + private[spark] def markInterrupted(): Unit = { + interrupted = true + } + + override def isCompleted: Boolean = completed + + override def taskMetrics(): TaskMetrics = _taskMetrics + + override def isRunningLocally: Boolean = _runningLocally + + override def runningLocally(): Boolean = _runningLocally + + override def isInterrupted: Boolean = interrupted + + override def partitionId(): Int = _partitionId + + override def attemptId(): Long = _attemptId + + override def stageId(): Int = _stageId + +} From facf3b1987055a60e8e22f0a7a3bdc64ce250e3f Mon Sep 17 00:00:00 2001 From: Prashant Sharma Date: Tue, 14 Oct 2014 14:06:14 +0530 Subject: [PATCH 4/6] Fixed the mima issue. 
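The TaskContextImpl introduced in patch 3 above runs completion listeners in reverse registration order and keeps going past individual failures, collecting their messages into a TaskCompletionListenerException. A minimal sketch of that contract, with illustrative listener bodies (markTaskCompleted is private[spark], so this only runs from Spark-internal code):

```scala
// Listeners fire in reverse registration order, so resources registered
// later (typically acquired later) are released first.
val ctx = new TaskContextImpl(0, 0, 0L)
ctx.addTaskCompletionListener { (_: TaskContext) => println("registered first, runs last") }
ctx.addTaskCompletionListener { (_: TaskContext) => println("registered last, runs first") }
ctx.markTaskCompleted()
// prints:
//   registered last, runs first
//   registered first, runs last
```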
--- project/MimaBuild.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/project/MimaBuild.scala b/project/MimaBuild.scala index 39f8ba4745737..88a3bab6adf80 100644 --- a/project/MimaBuild.scala +++ b/project/MimaBuild.scala @@ -39,7 +39,8 @@ object MimaBuild { ProblemFilters.exclude[MissingFieldProblem](fullName), ProblemFilters.exclude[IncompatibleResultTypeProblem](fullName), ProblemFilters.exclude[IncompatibleMethTypeProblem](fullName), - ProblemFilters.exclude[IncompatibleFieldTypeProblem](fullName) + ProblemFilters.exclude[IncompatibleFieldTypeProblem](fullName), + ProblemFilters.exclude[AbstractClassProblem](fullName) ) // Exclude a single class and its corresponding object From df261d0f41c41d0fe05c418aab46c82ff2183b7b Mon Sep 17 00:00:00 2001 From: Prashant Sharma Date: Tue, 14 Oct 2014 14:12:52 +0530 Subject: [PATCH 5/6] Josh's suggestion --- .../org/apache/spark/TaskContextImpl.scala | 25 ++++++------------- 1 file changed, 7 insertions(+), 18 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala index 244a02a298cf9..01508f2ed32d2 100644 --- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala +++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala @@ -22,12 +22,12 @@ import org.apache.spark.util.{TaskCompletionListener, TaskCompletionListenerExce import scala.collection.mutable.ArrayBuffer -private[spark] class TaskContextImpl(_stageId: Int, - _partitionId: Int, - _attemptId: Long, - _runningLocally: Boolean = false, - _taskMetrics: TaskMetrics = TaskMetrics.empty) - extends TaskContext(_stageId, _partitionId, _attemptId, _runningLocally, _taskMetrics) +private[spark] class TaskContextImpl(val stageId: Int, + val partitionId: Int, + val attemptId: Long, + val runningLocally: Boolean = false, + val taskMetrics: TaskMetrics = TaskMetrics.empty) + extends TaskContext(stageId, partitionId, attemptId, runningLocally, taskMetrics) with Logging { // List of callback functions to execute when the task completes. @@ -84,19 +84,8 @@ private[spark] class TaskContextImpl(_stageId: Int, override def isCompleted: Boolean = completed - override def taskMetrics(): TaskMetrics = _taskMetrics - - override def isRunningLocally: Boolean = _runningLocally - - override def runningLocally(): Boolean = _runningLocally + override def isRunningLocally: Boolean = runningLocally override def isInterrupted: Boolean = interrupted - - override def partitionId(): Int = _partitionId - - override def attemptId(): Long = _attemptId - - override def stageId(): Int = _stageId - } From ed551cee5a50c92d8b4793383b1403e1e64868ce Mon Sep 17 00:00:00 2001 From: Prashant Sharma Date: Tue, 14 Oct 2014 14:24:24 +0530 Subject: [PATCH 6/6] Fixed a typo --- project/MimaBuild.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/project/MimaBuild.scala b/project/MimaBuild.scala index 88a3bab6adf80..81542a8c11839 100644 --- a/project/MimaBuild.scala +++ b/project/MimaBuild.scala @@ -32,7 +32,7 @@ object MimaBuild { ProblemFilters.exclude[MissingMethodProblem](fullName), // Sometimes excluded methods have default arguments and // they are translated into public methods/fields($default$) in generated - // bytecode. It is not possible to exhustively list everything. + // bytecode. It is not possible to exhaustively list everything. // But this should be okay. 
ProblemFilters.exclude[MissingMethodProblem](fullName+"$default$2"), ProblemFilters.exclude[MissingMethodProblem](fullName+"$default$1"),
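With the series applied, user code never constructs a TaskContext; it reaches the current one through the ThreadLocal-backed accessor and the small set of stable methods kept on the abstract class. A sketch of the resulting user-facing usage, assuming an existing SparkContext `sc` (the RDD and the listener body are illustrative):

```scala
import org.apache.spark.TaskContext

sc.parallelize(1 to 100, 4).mapPartitions { iter =>
  // Set by the scheduler (via TaskContextHelper) before the task body runs.
  val ctx = TaskContext.get()
  ctx.addTaskCompletionListener { (_: TaskContext) =>
    println(s"partition ${ctx.partitionId()} of stage ${ctx.stageId()} done")
  }
  iter.map(_ * 2)
}.count()
```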