diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala
index 9bd8a0f98d..c66759f59f 100644
--- a/core/src/main/scala/spark/RDD.scala
+++ b/core/src/main/scala/spark/RDD.scala
@@ -1,13 +1,10 @@
 package spark
 
-import java.net.URL
-import java.util.{Date, Random}
-import java.util.{HashMap => JHashMap}
+import java.util.Random
 
 import scala.collection.Map
 import scala.collection.JavaConversions.mapAsScalaMap
 import scala.collection.mutable.ArrayBuffer
-import scala.collection.mutable.HashMap
 
 import org.apache.hadoop.io.BytesWritable
 import org.apache.hadoop.io.NullWritable
@@ -28,6 +25,7 @@ import spark.rdd.FlatMappedRDD
 import spark.rdd.GlommedRDD
 import spark.rdd.MappedRDD
 import spark.rdd.MapPartitionsRDD
+import spark.rdd.MapPartitionsWithSetupAndCleanup
 import spark.rdd.MapPartitionsWithIndexRDD
 import spark.rdd.PipedRDD
 import spark.rdd.SampledRDD
@@ -36,7 +34,8 @@ import spark.rdd.UnionRDD
 import spark.rdd.ZippedRDD
 import spark.storage.StorageLevel
 
-import SparkContext._
+import spark.SparkContext._
+import spark.RDD.PartitionMapper
 
 /**
  * A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. Represents an immutable,
@@ -364,6 +363,18 @@ abstract class RDD[T: ClassManifest](
       preservesPartitioning: Boolean = false): RDD[U] =
     new MapPartitionsWithIndexRDD(this, sc.clean(f), preservesPartitioning)
 
+  /**
+   * Return a new RDD by applying a function to every element of this RDD, with extra setup and
+   * cleanup at the beginning and end of processing each partition.
+   *
+   * This can be useful if you need to set up some resource per task and clean it up at the end,
+   * e.g. a database connection.
+   */
+  def mapWithSetup[U: ClassManifest](
+      m: PartitionMapper[T,U],
+      preservesPartitioning: Boolean = false): RDD[U] =
+    new MapPartitionsWithSetupAndCleanup(this, m, preservesPartitioning)
+
   /**
    * Zips this RDD with another one, returning key-value pairs with the first element in each RDD,
    * second element in each RDD, etc. Assumes that the two RDDs have the *same number of
@@ -746,3 +757,28 @@ abstract class RDD[T: ClassManifest](
       origin)
   }
 }
+
+object RDD {
+
+  /**
+   * Defines a map function over the elements of an RDD, with extra setup and cleanup
+   * that happens before and after each partition is processed.
+   */
+  trait PartitionMapper[T,U] extends Serializable {
+    /**
+     * Called at the start of processing each partition.
+     */
+    def setup(partition: Int)
+
+    /**
+     * Transform one element of the partition.
+     */
+    @throws(classOf[Exception]) // for the Java API
+    def map(t: T): U
+
+    /**
+     * Called at the end of each partition. This will be called even if the map failed (e.g. an exception was thrown).
+     */
+    def cleanup
+  }
+}
diff --git a/core/src/main/scala/spark/api/java/JavaDoublePartitionMapper.java b/core/src/main/scala/spark/api/java/JavaDoublePartitionMapper.java
new file mode 100644
index 0000000000..5d07c10572
--- /dev/null
+++ b/core/src/main/scala/spark/api/java/JavaDoublePartitionMapper.java
@@ -0,0 +1,4 @@
+package spark.api.java;
+
+public abstract class JavaDoublePartitionMapper<T> implements spark.RDD.PartitionMapper<T, Double> {
+}
diff --git a/core/src/main/scala/spark/api/java/JavaPairPartitionMapper.java b/core/src/main/scala/spark/api/java/JavaPairPartitionMapper.java
new file mode 100644
index 0000000000..a09bf66cf1
--- /dev/null
+++ b/core/src/main/scala/spark/api/java/JavaPairPartitionMapper.java
@@ -0,0 +1,6 @@
+package spark.api.java;
+
+import scala.Tuple2;
+
+public abstract class JavaPairPartitionMapper<T, K, V> implements spark.RDD.PartitionMapper<T, Tuple2<K, V>> {
+}
diff --git a/core/src/main/scala/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/spark/api/java/JavaRDDLike.scala
index d884529d7a..4dcf705e85 100644
--- a/core/src/main/scala/spark/api/java/JavaRDDLike.scala
+++ b/core/src/main/scala/spark/api/java/JavaRDDLike.scala
@@ -10,6 +10,8 @@ import spark.api.java.function.{Function2 => JFunction2, Function => JFunction,
 import spark.partial.{PartialResult, BoundedDouble}
 import spark.storage.StorageLevel
 import com.google.common.base.Optional
+import spark.RDD.PartitionMapper
+import spark.api.java.ManifestHelper.fakeManifest
 
 trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
 
@@ -117,6 +119,31 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
     JavaPairRDD.fromRDD(rdd.mapPartitions(fn))(f.keyType(), f.valueType())
   }
 
+  /**
+   * Return a new RDD by applying a function to each element of the RDD, with additional
+   * setup & cleanup that happen before & after computing each partition.
+   */
+  def mapWithSetup[U](m: PartitionMapper[T,U]): JavaRDD[U] = {
+    JavaRDD.fromRDD(rdd.mapWithSetup(m)(fakeManifest[U]))(fakeManifest[U])
+  }
+
+  /**
+   * Return a new RDD by applying a function to each element of the RDD, with additional
+   * setup & cleanup that happen before & after computing each partition.
+   */
+  def mapWithSetup[K,V](m: JavaPairPartitionMapper[T,K,V]): JavaPairRDD[K,V] = {
+    JavaPairRDD.fromRDD(rdd.mapWithSetup(m)(fakeManifest[(K,V)]))(
+      fakeManifest[K], fakeManifest[V])
+  }
+
+  /**
+   * Return a new RDD by applying a function to each element of the RDD, with additional
+   * setup & cleanup that happen before & after computing each partition.
+   */
+  def mapWithSetup(m: JavaDoublePartitionMapper[T]): JavaDoubleRDD = {
+    JavaDoubleRDD.fromRDD(rdd.mapWithSetup(m)(manifest[java.lang.Double]).asInstanceOf[RDD[Double]])
+  }
+
   /**
    * Return an RDD created by coalescing all elements within each partition into an array.
    */
diff --git a/core/src/main/scala/spark/api/java/ManifestHelper.java b/core/src/main/scala/spark/api/java/ManifestHelper.java
new file mode 100644
index 0000000000..f2ca6da022
--- /dev/null
+++ b/core/src/main/scala/spark/api/java/ManifestHelper.java
@@ -0,0 +1,11 @@
+package spark.api.java;
+
+import scala.reflect.ClassManifest;
+import scala.reflect.ClassManifest$;
+
+class ManifestHelper {
+
+  public static <T> ClassManifest<T> fakeManifest() {
+    return (ClassManifest<T>) ClassManifest$.MODULE$.fromClass(Object.class);
+  }
+}
diff --git a/core/src/main/scala/spark/rdd/MapPartitionsWithSetupAndCleanup.scala b/core/src/main/scala/spark/rdd/MapPartitionsWithSetupAndCleanup.scala
new file mode 100644
index 0000000000..e818b5e9b3
--- /dev/null
+++ b/core/src/main/scala/spark/rdd/MapPartitionsWithSetupAndCleanup.scala
@@ -0,0 +1,26 @@
+package spark.rdd
+
+import spark.{TaskContext, Partition, RDD}
+import spark.RDD.PartitionMapper
+
+/**
+ * Applies a PartitionMapper to each element of the parent RDD, running setup before and cleanup after every partition (cleanup runs via a task-completion callback, even if the computation fails).
+ */
+
+class MapPartitionsWithSetupAndCleanup[U: ClassManifest, T: ClassManifest](
+    prev: RDD[T],
+    m: PartitionMapper[T,U],
+    preservesPartitioning: Boolean
+  ) extends RDD[U](prev) {
+
+  override def getPartitions = firstParent[T].partitions
+
+  override val partitioner = if (preservesPartitioning) prev.partitioner else None
+
+  override def compute(split: Partition, context: TaskContext) = {
+    context.addOnCompleteCallback(m.cleanup _)
+    m.setup(split.index)
+    firstParent[T].iterator(split, context).map(m.map _)
+  }
+
+}
diff --git a/core/src/test/scala/spark/JavaAPISuite.java b/core/src/test/scala/spark/JavaAPISuite.java
index 26e3ab72c0..8520ae911c 100644
--- a/core/src/test/scala/spark/JavaAPISuite.java
+++ b/core/src/test/scala/spark/JavaAPISuite.java
@@ -19,10 +19,7 @@ import org.junit.Before;
 import org.junit.Test;
 
-import spark.api.java.JavaDoubleRDD;
-import spark.api.java.JavaPairRDD;
-import spark.api.java.JavaRDD;
-import spark.api.java.JavaSparkContext;
+import spark.api.java.*;
 import spark.api.java.function.*;
 import spark.partial.BoundedDouble;
 import spark.partial.PartialResult;
 
@@ -400,6 +397,54 @@ public Iterable<Integer> call(Iterator<Integer> iter) {
     Assert.assertEquals("[3, 7]", partitionSums.collect().toString());
   }
 
+  @Test
+  public void mapPartitionsWithSetupAndCleanup() {
+    // The real test of the behavior is in the Scala suite; this just checks that the Java API wrappers work.
+    JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10), 4);
+
+    JavaPairRDD<Integer, String> pairRdd = rdd.mapWithSetup(new JavaPairPartitionMapper<Integer, Integer, String>() {
+      @Override
+      public void setup(int partition) {
+        System.out.println("setup " + partition);
+      }
+
+      @Override
+      public Tuple2<Integer, String> map(Integer integer) {
+        return new Tuple2<Integer, String>(integer, integer + "_");
+      }
+
+      @Override
+      public void cleanup() {
+        System.out.println("cleanup");
+      }
+    });
+    Assert.assertEquals(
+        "[(1,1_), (2,2_), (3,3_), (4,4_), (5,5_), (6,6_), (7,7_), (8,8_), (9,9_), (10,10_)]",
+        pairRdd.collect().toString());
+
+
+    JavaDoubleRDD doubleRdd = rdd.mapWithSetup(new JavaDoublePartitionMapper<Integer>() {
+      @Override
+      public void setup(int partition) {
+        System.out.println("setup " + partition);
+      }
+
+      @Override
+      public Double map(Integer integer) throws Exception {
+        return integer.doubleValue();
+      }
+
+      @Override
+      public void cleanup() {
+        System.out.println("cleanup");
+      }
+    });
+    Assert.assertEquals(
+        "[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]",
+        doubleRdd.collect().toString());
+
+  }
+
   @Test
   public void persist() {
     JavaDoubleRDD doubleRDD = sc.parallelizeDoubles(Arrays.asList(1.0, 1.0, 2.0, 3.0, 5.0, 8.0));
diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala
index 9739ba869b..db30fbbdbd 100644
--- a/core/src/test/scala/spark/RDDSuite.scala
+++ b/core/src/test/scala/spark/RDDSuite.scala
@@ -1,9 +1,11 @@
 package spark
 
 import scala.collection.mutable.HashMap
+import scala.collection.Set
 import org.scalatest.FunSuite
 import spark.SparkContext._
 import spark.rdd.{CoalescedRDD, PartitionPruningRDD}
+import spark.RDD.PartitionMapper
 
 class RDDSuite extends FunSuite with LocalSparkContext {
 
@@ -178,4 +180,48 @@ class RDDSuite extends FunSuite with LocalSparkContext {
     assert(prunedData.size === 1)
     assert(prunedData(0) === 10)
   }
+
+  test("mapPartitionWithSetupAndCleanup") {
+    sc = new SparkContext("local[4]", "test")
+    val data = sc.parallelize(1 to 100, 4)
+    val acc = sc.accumulableCollection(new HashMap[Int,Set[Int]]())
+    val mapped = data.mapWithSetup(new PartitionMapper[Int,Int](){
+      var partition = -1
+      var values = Set[Int]()
+      def setup(partition: Int) { this.partition = partition }
+      def map(i: Int) = { values += i; i * 2 }
+      def cleanup = {
+        // The purpose of this odd code is just to verify that cleanup is called only
+        // after the data has been iterated through completely.
+        acc.localValue += (partition -> values)
+      }
+    }).collect
+
+    assert(mapped.toSet === (1 to 100).map{_ * 2}.toSet)
+    assert(acc.value.keySet === (0 to 3).toSet)
+    acc.value.foreach { case(partition, values) =>
+      assert(values.size === 25)
+    }
+
+
+    // The naive alternative does not work:
+    val acc2 = sc.accumulableCollection(new HashMap[Int,Set[Int]]())
+    val m2 = data.mapPartitionsWithSplit{
+      case (partition, itr) =>
+        var values = Set[Int]()
+        val mItr = itr.map{i => values += i; i * 2}
+        // Nothing has been added to values yet, because itr.map just defines another
+        // iterator, which is computed lazily, so the Set is still empty here.
+        acc2.localValue += (partition -> values)
+        mItr
+    }.collect
+
+    assert(m2.toSet === (1 to 100).map{_ * 2}.toSet)
+    assert(acc2.value.keySet === (0 to 3).toSet)
+    // This condition would fail:
+//    acc2.value.foreach { case(partition, values) =>
+//      assert(values.size === 25)
+//    }
+
+  }
 }
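
For illustration, here is a rough usage sketch of the new API along the lines of the "db connection" use case mentioned in the mapWithSetup scaladoc. It is not part of the patch; DbClient, DbLookupMapper, and MapWithSetupExample are made-up names standing in for a real per-partition resource.

// Usage sketch only, not part of this diff.
import spark.SparkContext
import spark.RDD.PartitionMapper

// Hypothetical stand-in for a real database connection.
class DbClient {
  def lookup(id: Int): String = "row_" + id
  def close() {}
}

// Opens one DbClient per partition, reuses it for every element, and closes it when
// the partition finishes (cleanup is registered as a task-completion callback, so it
// runs even if map throws).
class DbLookupMapper extends PartitionMapper[Int, (Int, String)] {
  @transient var db: DbClient = null
  def setup(partition: Int) { db = new DbClient }
  def map(id: Int) = (id, db.lookup(id))
  def cleanup { if (db != null) db.close() }
}

object MapWithSetupExample {
  def main(args: Array[String]) {
    val sc = new SparkContext("local[4]", "mapWithSetup example")
    val pairs = sc.parallelize(1 to 100, 4).mapWithSetup(new DbLookupMapper)
    pairs.collect().foreach(println)
    sc.stop()
  }
}

The point of the sketch is that a single DbClient is created per task rather than per element, and it is released even when map fails, since MapPartitionsWithSetupAndCleanup registers the cleanup with addOnCompleteCallback before iterating.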