| 
 | 1 | +/*  | 
 | 2 | + * Licensed to the Apache Software Foundation (ASF) under one or more  | 
 | 3 | + * contributor license agreements.  See the NOTICE file distributed with  | 
 | 4 | + * this work for additional information regarding copyright ownership.  | 
 | 5 | + * The ASF licenses this file to You under the Apache License, Version 2.0  | 
 | 6 | + * (the "License"); you may not use this file except in compliance with  | 
 | 7 | + * the License.  You may obtain a copy of the License at  | 
 | 8 | + *  | 
 | 9 | + *    http://www.apache.org/licenses/LICENSE-2.0  | 
 | 10 | + *  | 
 | 11 | + * Unless required by applicable law or agreed to in writing, software  | 
 | 12 | + * distributed under the License is distributed on an "AS IS" BASIS,  | 
 | 13 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  | 
 | 14 | + * See the License for the specific language governing permissions and  | 
 | 15 | + * limitations under the License.  | 
 | 16 | + */  | 
 | 17 | + | 
 | 18 | +package org.apache.spark;  | 
 | 19 | + | 
 | 20 | +import java.io.Serializable;  | 
 | 21 | +import java.util.ArrayList;  | 
 | 22 | +import java.util.Collections;  | 
 | 23 | +import java.util.List;  | 
 | 24 | + | 
 | 25 | +import scala.Function0;  | 
 | 26 | +import scala.Function1;  | 
 | 27 | +import scala.Unit;  | 
 | 28 | +import scala.collection.JavaConversions;  | 
 | 29 | + | 
 | 30 | +import org.apache.spark.annotation.DeveloperApi;  | 
 | 31 | +import org.apache.spark.executor.TaskMetrics;  | 
 | 32 | +import org.apache.spark.util.TaskCompletionListener;  | 
 | 33 | +import org.apache.spark.util.TaskCompletionListenerException;  | 
 | 34 | + | 
 | 35 | +/**  | 
 | 36 | +* :: DeveloperApi ::  | 
 | 37 | +* Contextual information about a task which can be read or mutated during execution.  | 
 | 38 | +*/  | 
 | 39 | +@DeveloperApi  | 
 | 40 | +public class TaskContext implements Serializable {  | 
 | 41 | + | 
 | 42 | +  private int stageId;  | 
 | 43 | +  private int partitionId;  | 
 | 44 | +  private long attemptId;  | 
 | 45 | +  private boolean runningLocally;  | 
 | 46 | +  private TaskMetrics taskMetrics;  | 
 | 47 | + | 
 | 48 | +  /**  | 
 | 49 | +   * :: DeveloperApi ::  | 
 | 50 | +   * Contextual information about a task which can be read or mutated during execution.  | 
 | 51 | +   *  | 
 | 52 | +   * @param stageId stage id  | 
 | 53 | +   * @param partitionId index of the partition  | 
 | 54 | +   * @param attemptId the number of attempts to execute this task  | 
 | 55 | +   * @param runningLocally whether the task is running locally in the driver JVM  | 
 | 56 | +   * @param taskMetrics performance metrics of the task  | 
 | 57 | +   */  | 
 | 58 | +  @DeveloperApi  | 
 | 59 | +  public TaskContext(int stageId, int partitionId, long attemptId, boolean runningLocally,  | 
 | 60 | +                     TaskMetrics taskMetrics) {  | 
 | 61 | +    this.attemptId = attemptId;  | 
 | 62 | +    this.partitionId = partitionId;  | 
 | 63 | +    this.runningLocally = runningLocally;  | 
 | 64 | +    this.stageId = stageId;  | 
 | 65 | +    this.taskMetrics = taskMetrics;  | 
 | 66 | +  }  | 
 | 67 | + | 
 | 68 | +  /**  | 
 | 69 | +   * :: DeveloperApi ::  | 
 | 70 | +   * Contextual information about a task which can be read or mutated during execution.  | 
 | 71 | +   *  | 
 | 72 | +   * @param stageId stage id  | 
 | 73 | +   * @param partitionId index of the partition  | 
 | 74 | +   * @param attemptId the number of attempts to execute this task  | 
 | 75 | +   * @param runningLocally whether the task is running locally in the driver JVM  | 
 | 76 | +   */  | 
 | 77 | +  @DeveloperApi  | 
 | 78 | +  public TaskContext(int stageId, int partitionId, long attemptId, boolean runningLocally) {  | 
 | 79 | +    this.attemptId = attemptId;  | 
 | 80 | +    this.partitionId = partitionId;  | 
 | 81 | +    this.runningLocally = runningLocally;  | 
 | 82 | +    this.stageId = stageId;  | 
 | 83 | +    this.taskMetrics = TaskMetrics.empty();  | 
 | 84 | +  }  | 
 | 85 | + | 
 | 86 | +  /**  | 
 | 87 | +   * :: DeveloperApi ::  | 
 | 88 | +   * Contextual information about a task which can be read or mutated during execution.  | 
 | 89 | +   *  | 
 | 90 | +   * @param stageId stage id  | 
 | 91 | +   * @param partitionId index of the partition  | 
 | 92 | +   * @param attemptId the number of attempts to execute this task  | 
 | 93 | +   */  | 
 | 94 | +  @DeveloperApi  | 
 | 95 | +  public TaskContext(int stageId, int partitionId, long attemptId) {  | 
 | 96 | +    this.attemptId = attemptId;  | 
 | 97 | +    this.partitionId = partitionId;  | 
 | 98 | +    this.runningLocally = false;  | 
 | 99 | +    this.stageId = stageId;  | 
 | 100 | +    this.taskMetrics = TaskMetrics.empty();  | 
 | 101 | +  }  | 
 | 102 | + | 
 | 103 | +  private static ThreadLocal<TaskContext> taskContext =  | 
 | 104 | +    new ThreadLocal<TaskContext>();  | 
 | 105 | + | 
 | 106 | +  /**  | 
 | 107 | +   * :: Internal API ::  | 
 | 108 | +   * This is spark internal API, not intended to be called from user programs.  | 
 | 109 | +   */  | 
 | 110 | +  public static void setTaskContext(TaskContext tc) {  | 
 | 111 | +    taskContext.set(tc);  | 
 | 112 | +  }  | 
 | 113 | + | 
 | 114 | +  public static TaskContext get() {  | 
 | 115 | +    return taskContext.get();  | 
 | 116 | +  }  | 
 | 117 | + | 
 | 118 | +  /** :: Internal API ::  */  | 
 | 119 | +  public static void unset() {  | 
 | 120 | +    taskContext.remove();  | 
 | 121 | +  }  | 
 | 122 | + | 
 | 123 | +  // List of callback functions to execute when the task completes.  | 
 | 124 | +  private transient List<TaskCompletionListener> onCompleteCallbacks =  | 
 | 125 | +    new ArrayList<TaskCompletionListener>();  | 
 | 126 | + | 
 | 127 | +  // Whether the corresponding task has been killed.  | 
 | 128 | +  private volatile boolean interrupted = false;  | 
 | 129 | + | 
 | 130 | +  // Whether the task has completed.  | 
 | 131 | +  private volatile boolean completed = false;  | 
 | 132 | + | 
 | 133 | +  /**  | 
 | 134 | +   * Checks whether the task has completed.  | 
 | 135 | +   */  | 
 | 136 | +  public boolean isCompleted() {  | 
 | 137 | +    return completed;  | 
 | 138 | +  }  | 
 | 139 | + | 
 | 140 | +  /**  | 
 | 141 | +   * Checks whether the task has been killed.  | 
 | 142 | +   */  | 
 | 143 | +  public boolean isInterrupted() {  | 
 | 144 | +    return interrupted;  | 
 | 145 | +  }  | 
 | 146 | + | 
 | 147 | +  /**  | 
 | 148 | +   * Add a (Java friendly) listener to be executed on task completion.  | 
 | 149 | +   * This will be called in all situation - success, failure, or cancellation.  | 
 | 150 | +   * <p/>  | 
 | 151 | +   * An example use is for HadoopRDD to register a callback to close the input stream.  | 
 | 152 | +   */  | 
 | 153 | +  public TaskContext addTaskCompletionListener(TaskCompletionListener listener) {  | 
 | 154 | +    onCompleteCallbacks.add(listener);  | 
 | 155 | +    return this;  | 
 | 156 | +  }  | 
 | 157 | + | 
 | 158 | +  /**  | 
 | 159 | +   * Add a listener in the form of a Scala closure to be executed on task completion.  | 
 | 160 | +   * This will be called in all situations - success, failure, or cancellation.  | 
 | 161 | +   * <p/>  | 
 | 162 | +   * An example use is for HadoopRDD to register a callback to close the input stream.  | 
 | 163 | +   */  | 
 | 164 | +  public TaskContext addTaskCompletionListener(final Function1<TaskContext, Unit> f) {  | 
 | 165 | +    onCompleteCallbacks.add(new TaskCompletionListener() {  | 
 | 166 | +      @Override  | 
 | 167 | +      public void onTaskCompletion(TaskContext context) {  | 
 | 168 | +        f.apply(context);  | 
 | 169 | +      }  | 
 | 170 | +    });  | 
 | 171 | +    return this;  | 
 | 172 | +  }  | 
 | 173 | + | 
 | 174 | +  /**  | 
 | 175 | +   * Add a callback function to be executed on task completion. An example use  | 
 | 176 | +   * is for HadoopRDD to register a callback to close the input stream.  | 
 | 177 | +   * Will be called in any situation - success, failure, or cancellation.  | 
 | 178 | +   *  | 
 | 179 | +   * Deprecated: use addTaskCompletionListener  | 
 | 180 | +   *   | 
 | 181 | +   * @param f Callback function.  | 
 | 182 | +   */  | 
 | 183 | +  @Deprecated  | 
 | 184 | +  public void addOnCompleteCallback(final Function0<Unit> f) {  | 
 | 185 | +    onCompleteCallbacks.add(new TaskCompletionListener() {  | 
 | 186 | +      @Override  | 
 | 187 | +      public void onTaskCompletion(TaskContext context) {  | 
 | 188 | +        f.apply();  | 
 | 189 | +      }  | 
 | 190 | +    });  | 
 | 191 | +  }  | 
 | 192 | + | 
 | 193 | +  /**  | 
 | 194 | +   * ::Internal API::  | 
 | 195 | +   * Marks the task as completed and triggers the listeners.  | 
 | 196 | +   */  | 
 | 197 | +  public void markTaskCompleted() throws TaskCompletionListenerException {  | 
 | 198 | +    completed = true;  | 
 | 199 | +    List<String> errorMsgs = new ArrayList<String>(2);  | 
 | 200 | +    // Process complete callbacks in the reverse order of registration  | 
 | 201 | +    List<TaskCompletionListener> revlist =  | 
 | 202 | +      new ArrayList<TaskCompletionListener>(onCompleteCallbacks);  | 
 | 203 | +    Collections.reverse(revlist);  | 
 | 204 | +    for (TaskCompletionListener tcl: revlist) {  | 
 | 205 | +      try {  | 
 | 206 | +        tcl.onTaskCompletion(this);  | 
 | 207 | +      } catch (Throwable e) {  | 
 | 208 | +        errorMsgs.add(e.getMessage());  | 
 | 209 | +      }  | 
 | 210 | +    }  | 
 | 211 | + | 
 | 212 | +    if (!errorMsgs.isEmpty()) {  | 
 | 213 | +      throw new TaskCompletionListenerException(JavaConversions.asScalaBuffer(errorMsgs));  | 
 | 214 | +    }  | 
 | 215 | +  }  | 
 | 216 | + | 
 | 217 | +  /**  | 
 | 218 | +   * ::Internal API::  | 
 | 219 | +   * Marks the task for interruption, i.e. cancellation.  | 
 | 220 | +   */  | 
 | 221 | +  public void markInterrupted() {  | 
 | 222 | +    interrupted = true;  | 
 | 223 | +  }  | 
 | 224 | + | 
 | 225 | +  @Deprecated  | 
 | 226 | +  /** Deprecated: use getStageId() */  | 
 | 227 | +  public int stageId() {  | 
 | 228 | +    return stageId;  | 
 | 229 | +  }  | 
 | 230 | + | 
 | 231 | +  @Deprecated  | 
 | 232 | +  /** Deprecated: use getPartitionId() */  | 
 | 233 | +  public int partitionId() {  | 
 | 234 | +    return partitionId;  | 
 | 235 | +  }  | 
 | 236 | + | 
 | 237 | +  @Deprecated  | 
 | 238 | +  /** Deprecated: use getAttemptId() */  | 
 | 239 | +  public long attemptId() {  | 
 | 240 | +    return attemptId;  | 
 | 241 | +  }  | 
 | 242 | + | 
 | 243 | +  @Deprecated  | 
 | 244 | +  /** Deprecated: use isRunningLocally() */  | 
 | 245 | +  public boolean runningLocally() {  | 
 | 246 | +    return runningLocally;  | 
 | 247 | +  }  | 
 | 248 | + | 
 | 249 | +  public boolean isRunningLocally() {  | 
 | 250 | +    return runningLocally;  | 
 | 251 | +  }  | 
 | 252 | + | 
 | 253 | +  public int getStageId() {  | 
 | 254 | +    return stageId;  | 
 | 255 | +  }  | 
 | 256 | + | 
 | 257 | +  public int getPartitionId() {  | 
 | 258 | +    return partitionId;  | 
 | 259 | +  }  | 
 | 260 | + | 
 | 261 | +  public long getAttemptId() {  | 
 | 262 | +    return attemptId;  | 
 | 263 | +  }    | 
 | 264 | + | 
 | 265 | +  /** ::Internal API:: */  | 
 | 266 | +  public TaskMetrics taskMetrics() {  | 
 | 267 | +    return taskMetrics;  | 
 | 268 | +  }  | 
 | 269 | +}  | 
0 commit comments