From 14419775202e6eef1f0e1f0c74c7be9030aca73d Mon Sep 17 00:00:00 2001 From: Doris Xin Date: Thu, 29 May 2014 15:22:14 -0700 Subject: [PATCH 01/16] SPARK-1939 Refactor takeSample method in RDD to use ScaSRS --- core/pom.xml | 4 ++ .../main/scala/org/apache/spark/rdd/RDD.scala | 32 +++++++++- .../spark/util/random/RandomSampler.scala | 2 +- .../scala/org/apache/spark/rdd/RDDSuite.scala | 63 ++++++++++++++----- pom.xml | 5 ++ project/SparkBuild.scala | 1 + python/pyspark/rdd.py | 15 ++++- 7 files changed, 100 insertions(+), 22 deletions(-) diff --git a/core/pom.xml b/core/pom.xml index bab50f5ce2888..6cb58dbd291c4 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -67,6 +67,10 @@ org.apache.commons commons-lang3 + + org.apache.commons + commons-math3 + com.google.code.findbugs jsr305 diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index aa03e9276fb34..2fdf45a0c8b8e 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -379,8 +379,17 @@ abstract class RDD[T: ClassTag]( }.toArray } - def takeSample(withReplacement: Boolean, num: Int, seed: Long = Utils.random.nextLong): Array[T] = - { + /** + * Return a fixed-size sampled subset of this RDD in an array + * + * @param withReplacement whether sampling is done with replacement + * @param num size of the returned sample + * @param seed seed for the random number generator + * @return sample of specified size in an array + */ + def takeSample(withReplacement: Boolean, + num: Int, + seed: Long = Utils.random.nextLong): Array[T] = { var fraction = 0.0 var total = 0 val multiplier = 3.0 @@ -402,10 +411,11 @@ abstract class RDD[T: ClassTag]( } if (num > initialCount && !withReplacement) { + // special case not covered in computeFraction total = maxSelected fraction = multiplier * (maxSelected + 1) / initialCount } else { - fraction = multiplier * (num + 1) / initialCount + fraction = computeFraction(num, initialCount, withReplacement) total = num } @@ -421,6 +431,22 @@ abstract class RDD[T: ClassTag]( Utils.randomizeInPlace(samples, rand).take(total) } + private[spark] def computeFraction(num: Int, total: Long, withReplacement: Boolean) : Double = { + val fraction = num.toDouble / total + if (withReplacement) { + var numStDev = 5 + if (num < 12) { + // special case to guarantee sample size for small s + numStDev = 9 + } + fraction + numStDev * math.sqrt(fraction / total) + } else { + val delta = 0.00005 + val gamma = - math.log(delta)/total + math.min(1, fraction + gamma + math.sqrt(gamma * gamma + 2 * gamma * fraction)) + } + } + /** * Return the union of this RDD and another one. Any identical elements will appear multiple * times (use `.distinct()` to eliminate them). diff --git a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala index 4dc8ada00a3e8..e53103755b279 100644 --- a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala +++ b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala @@ -70,7 +70,7 @@ class BernoulliSampler[T](lb: Double, ub: Double, complement: Boolean = false) } /** - * Return a sampler with is the complement of the range specified of the current sampler. + * Return a sampler which is the complement of the range specified of the current sampler. 
*/ def cloneComplement(): BernoulliSampler[T] = new BernoulliSampler[T](lb, ub, !complement) diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index e686068f7a99a..5bdcb9bef6d62 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -22,6 +22,7 @@ import scala.reflect.ClassTag import org.scalatest.FunSuite +import org.apache.commons.math3.distribution.PoissonDistribution import org.apache.spark._ import org.apache.spark.SparkContext._ import org.apache.spark.rdd._ @@ -494,56 +495,84 @@ class RDDSuite extends FunSuite with SharedSparkContext { assert(sortedTopK === nums.sorted(ord).take(5)) } + test("computeFraction") { + // test that the computed fraction guarantees enough datapoints in the sample with a failure rate <= 0.0001 + val data = new EmptyRDD[Int](sc) + val n = 100000 + + for (s <- 1 to 15) { + val frac = data.computeFraction(s, n, true) + val qpois = new PoissonDistribution(frac * n) + assert(qpois.inverseCumulativeProbability(0.0001) >= s, "Computed fraction is too low") + } + for (s <- 1 to 15) { + val frac = data.computeFraction(s, n, false) + val qpois = new PoissonDistribution(frac * n) + assert(qpois.inverseCumulativeProbability(0.0001) >= s, "Computed fraction is too low") + } + for (s <- List(1, 10, 100, 1000)) { + val frac = data.computeFraction(s, n, true) + val qpois = new PoissonDistribution(frac * n) + assert(qpois.inverseCumulativeProbability(0.0001) >= s, "Computed fraction is too low") + } + for (s <- List(1, 10, 100, 1000)) { + val frac = data.computeFraction(s, n, false) + val qpois = new PoissonDistribution(frac * n) + assert(qpois.inverseCumulativeProbability(0.0001) >= s, "Computed fraction is too low") + } + } + test("takeSample") { - val data = sc.parallelize(1 to 100, 2) + val n = 1000000 + val data = sc.parallelize(1 to n, 2) for (num <- List(5, 20, 100)) { val sample = data.takeSample(withReplacement=false, num=num) assert(sample.size === num) // Got exactly num elements assert(sample.toSet.size === num) // Elements are distinct - assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]") + assert(sample.forall(x => 1 <= x && x <= n), "elements not in [1, 100]") } for (seed <- 1 to 5) { val sample = data.takeSample(withReplacement=false, 20, seed) assert(sample.size === 20) // Got exactly 20 elements assert(sample.toSet.size === 20) // Elements are distinct - assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]") + assert(sample.forall(x => 1 <= x && x <= n), "elements not in [1, 100]") } for (seed <- 1 to 5) { - val sample = data.takeSample(withReplacement=false, 200, seed) + val sample = data.takeSample(withReplacement=false, 100, seed) assert(sample.size === 100) // Got only 100 elements assert(sample.toSet.size === 100) // Elements are distinct - assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]") + assert(sample.forall(x => 1 <= x && x <= n), "elements not in [1, 100]") } for (seed <- 1 to 5) { val sample = data.takeSample(withReplacement=true, 20, seed) assert(sample.size === 20) // Got exactly 20 elements - assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]") + assert(sample.forall(x => 1 <= x && x <= n), "elements not in [1, 100]") } { val sample = data.takeSample(withReplacement=true, num=20) assert(sample.size === 20) // Got exactly 100 elements assert(sample.toSet.size <= 20, "sampling with replacement 
returned all distinct elements") - assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]") + assert(sample.forall(x => 1 <= x && x <= n), "elements not in [1, 100]") } { - val sample = data.takeSample(withReplacement=true, num=100) - assert(sample.size === 100) // Got exactly 100 elements + val sample = data.takeSample(withReplacement=true, num=n) + assert(sample.size === n) // Got exactly 100 elements // Chance of getting all distinct elements is astronomically low, so test we got < 100 - assert(sample.toSet.size < 100, "sampling with replacement returned all distinct elements") - assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]") + assert(sample.toSet.size < n, "sampling with replacement returned all distinct elements") + assert(sample.forall(x => 1 <= x && x <= n), "elements not in [1, 100]") } for (seed <- 1 to 5) { - val sample = data.takeSample(withReplacement=true, 100, seed) - assert(sample.size === 100) // Got exactly 100 elements + val sample = data.takeSample(withReplacement=true, n, seed) + assert(sample.size === n) // Got exactly 100 elements // Chance of getting all distinct elements is astronomically low, so test we got < 100 - assert(sample.toSet.size < 100, "sampling with replacement returned all distinct elements") + assert(sample.toSet.size < n, "sampling with replacement returned all distinct elements") } for (seed <- 1 to 5) { - val sample = data.takeSample(withReplacement=true, 200, seed) - assert(sample.size === 200) // Got exactly 200 elements + val sample = data.takeSample(withReplacement=true, 2*n, seed) + assert(sample.size === 2*n) // Got exactly 200 elements // Chance of getting all distinct elements is still quite low, so test we got < 100 - assert(sample.toSet.size < 100, "sampling with replacement returned all distinct elements") + assert(sample.toSet.size < n, "sampling with replacement returned all distinct elements") } } diff --git a/pom.xml b/pom.xml index 7bf9f135fd340..01d6eef32be63 100644 --- a/pom.xml +++ b/pom.xml @@ -245,6 +245,11 @@ commons-codec 1.5 + + org.apache.commons + commons-math3 + 3.2 + com.google.code.findbugs jsr305 diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 8ef1e91f609fb..a6b6c26a49395 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -331,6 +331,7 @@ object SparkBuild extends Build { libraryDependencies ++= Seq( "com.google.guava" % "guava" % "14.0.1", "org.apache.commons" % "commons-lang3" % "3.3.2", + "org.apache.commons" % "commons-math3" % "3.2", "com.google.code.findbugs" % "jsr305" % "1.3.9", "log4j" % "log4j" % "1.2.17", "org.slf4j" % "slf4j-api" % slf4jVersion, diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 07578b8d937fc..b400404ad97c7 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -31,6 +31,7 @@ import warnings import heapq from random import Random +from math import sqrt, log, min from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \ BatchedSerializer, CloudPickleSerializer, PairDeserializer, pack_long @@ -374,7 +375,7 @@ def takeSample(self, withReplacement, num, seed=None): total = maxSelected fraction = multiplier * (maxSelected + 1) / initialCount else: - fraction = multiplier * (num + 1) / initialCount + fraction = self._computeFraction(num, initialCount, withReplacement) total = num samples = self.sample(withReplacement, fraction, seed).collect() @@ -390,6 +391,18 @@ def takeSample(self, withReplacement, num, seed=None): sampler.shuffle(samples) return samples[0:total] + 
def _computeFraction(self, num, total, withReplacement): + fraction = float(num)/total + if withReplacement: + numStDev = 5 + if (num < 12): + numStDev = 9 + return fraction + numStDev * sqrt(fraction/total) + else: + delta = 0.00005 + gamma = - log(delta)/total + return min(1, fraction + gamma + sqrt(gamma * gamma + 2* gamma * fraction)) + def union(self, other): """ Return the union of this RDD and another one. From ffea61a67d228edb476d29ca13a84bb3f9a22887 Mon Sep 17 00:00:00 2001 From: Doris Xin Date: Thu, 29 May 2014 17:55:54 -0700 Subject: [PATCH 02/16] SPARK-1939: Refactor takeSample method in RDD Reviewer comments addressed: - commons-math3 is now a test-only dependency. bumped up to v3.3 - comments added to explain what computeFraction is doing - fixed the unit for computeFraction to use BinomialDitro for without replacement sampling - stylistic fixes --- core/pom.xml | 1 + .../main/scala/org/apache/spark/rdd/RDD.scala | 33 +++++++++++------ .../spark/util/random/RandomSampler.scala | 2 +- .../scala/org/apache/spark/rdd/RDDSuite.scala | 36 +++++++++---------- pom.xml | 2 +- project/SparkBuild.scala | 2 +- 6 files changed, 44 insertions(+), 32 deletions(-) diff --git a/core/pom.xml b/core/pom.xml index 6cb58dbd291c4..2b9f750e07d97 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -70,6 +70,7 @@ org.apache.commons commons-math3 + test com.google.code.findbugs diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 2fdf45a0c8b8e..9a5cf13e3ba52 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -388,8 +388,8 @@ abstract class RDD[T: ClassTag]( * @return sample of specified size in an array */ def takeSample(withReplacement: Boolean, - num: Int, - seed: Long = Utils.random.nextLong): Array[T] = { + num: Int, + seed: Long = Utils.random.nextLong): Array[T] = { var fraction = 0.0 var total = 0 val multiplier = 3.0 @@ -431,18 +431,31 @@ abstract class RDD[T: ClassTag]( Utils.randomizeInPlace(samples, rand).take(total) } - private[spark] def computeFraction(num: Int, total: Long, withReplacement: Boolean) : Double = { + /** + * Let p = num / total, where num is the sample size and total is the total number of + * datapoints in the RDD. We're trying to compute q > p such that + * - when sampling with replacement, we're drawing each datapoint with prob_i ~ Pois(q), + * where we want to guarantee Pr[s < num] < 0.0001 for s = sum(prob_i for i from 0 to total), + * i.e. the failure rate of not having a sufficiently large sample < 0.0001. + * Setting q = p + 5 * sqrt(p/total) is sufficient to guarantee 0.9999 success rate for + * num > 12, but we need a slightly larger q (9 empirically determined). + * - when sampling without replacement, we're drawing each datapoint with prob_i + * ~ Binomial(total, fraction) and our choice of q guarantees 1-delta, or 0.9999 success + * rate, where success rate is defined the same as in sampling with replacement. 
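+   * As an illustrative check of the with-replacement case (numbers chosen for
+   * exposition only, not part of the original patch): num = 1000 and total = 100000
+   * give p = 0.01 and q = p + 5 * sqrt(p / total) ~= 0.0116, i.e. an expected sample
+   * of ~1158 elements, about 4.6 standard deviations above num, so Pr[s < num] < 0.0001.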
+ * + * @param num sample size + * @param total size of RDD + * @param withReplacement whether sampling with replacement + * @return a sampling rate that guarantees sufficient sample size with 99.99% success rate + */ + private[rdd] def computeFraction(num: Int, total: Long, withReplacement: Boolean): Double = { val fraction = num.toDouble / total if (withReplacement) { - var numStDev = 5 - if (num < 12) { - // special case to guarantee sample size for small s - numStDev = 9 - } + val numStDev = if (num < 12) 9 else 5 fraction + numStDev * math.sqrt(fraction / total) } else { - val delta = 0.00005 - val gamma = - math.log(delta)/total + val delta = 1e-4 + val gamma = - math.log(delta) / total math.min(1, fraction + gamma + math.sqrt(gamma * gamma + 2 * gamma * fraction)) } } diff --git a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala index e53103755b279..247f10173f1e9 100644 --- a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala +++ b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala @@ -70,7 +70,7 @@ class BernoulliSampler[T](lb: Double, ub: Double, complement: Boolean = false) } /** - * Return a sampler which is the complement of the range specified of the current sampler. + * Return a sampler that is the complement of the range specified of the current sampler. */ def cloneComplement(): BernoulliSampler[T] = new BernoulliSampler[T](lb, ub, !complement) diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index 5bdcb9bef6d62..08b3b93d6a31f 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -22,7 +22,9 @@ import scala.reflect.ClassTag import org.scalatest.FunSuite +import org.apache.commons.math3.distribution.BinomialDistribution import org.apache.commons.math3.distribution.PoissonDistribution + import org.apache.spark._ import org.apache.spark.SparkContext._ import org.apache.spark.rdd._ @@ -496,29 +498,25 @@ class RDDSuite extends FunSuite with SharedSparkContext { } test("computeFraction") { - // test that the computed fraction guarantees enough datapoints in the sample with a failure rate <= 0.0001 + // test that the computed fraction guarantees enough datapoints + // in the sample with a failure rate <= 0.0001 val data = new EmptyRDD[Int](sc) val n = 100000 for (s <- 1 to 15) { val frac = data.computeFraction(s, n, true) - val qpois = new PoissonDistribution(frac * n) - assert(qpois.inverseCumulativeProbability(0.0001) >= s, "Computed fraction is too low") + val poisson = new PoissonDistribution(frac * n) + assert(poisson.inverseCumulativeProbability(0.0001) >= s, "Computed fraction is too low") } - for (s <- 1 to 15) { - val frac = data.computeFraction(s, n, false) - val qpois = new PoissonDistribution(frac * n) - assert(qpois.inverseCumulativeProbability(0.0001) >= s, "Computed fraction is too low") - } - for (s <- List(1, 10, 100, 1000)) { + for (s <- List(20, 100, 1000)) { val frac = data.computeFraction(s, n, true) - val qpois = new PoissonDistribution(frac * n) - assert(qpois.inverseCumulativeProbability(0.0001) >= s, "Computed fraction is too low") + val poisson = new PoissonDistribution(frac * n) + assert(poisson.inverseCumulativeProbability(0.0001) >= s, "Computed fraction is too low") } for (s <- List(1, 10, 100, 1000)) { val frac = data.computeFraction(s, n, false) - val qpois = new 
PoissonDistribution(frac * n) - assert(qpois.inverseCumulativeProbability(0.0001) >= s, "Computed fraction is too low") + val binomial = new BinomialDistribution(n, frac) + assert(binomial.inverseCumulativeProbability(0.0001)*n >= s, "Computed fraction is too low") } } @@ -530,37 +528,37 @@ class RDDSuite extends FunSuite with SharedSparkContext { val sample = data.takeSample(withReplacement=false, num=num) assert(sample.size === num) // Got exactly num elements assert(sample.toSet.size === num) // Elements are distinct - assert(sample.forall(x => 1 <= x && x <= n), "elements not in [1, 100]") + assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]") } for (seed <- 1 to 5) { val sample = data.takeSample(withReplacement=false, 20, seed) assert(sample.size === 20) // Got exactly 20 elements assert(sample.toSet.size === 20) // Elements are distinct - assert(sample.forall(x => 1 <= x && x <= n), "elements not in [1, 100]") + assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]") } for (seed <- 1 to 5) { val sample = data.takeSample(withReplacement=false, 100, seed) assert(sample.size === 100) // Got only 100 elements assert(sample.toSet.size === 100) // Elements are distinct - assert(sample.forall(x => 1 <= x && x <= n), "elements not in [1, 100]") + assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]") } for (seed <- 1 to 5) { val sample = data.takeSample(withReplacement=true, 20, seed) assert(sample.size === 20) // Got exactly 20 elements - assert(sample.forall(x => 1 <= x && x <= n), "elements not in [1, 100]") + assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]") } { val sample = data.takeSample(withReplacement=true, num=20) assert(sample.size === 20) // Got exactly 100 elements assert(sample.toSet.size <= 20, "sampling with replacement returned all distinct elements") - assert(sample.forall(x => 1 <= x && x <= n), "elements not in [1, 100]") + assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]") } { val sample = data.takeSample(withReplacement=true, num=n) assert(sample.size === n) // Got exactly 100 elements // Chance of getting all distinct elements is astronomically low, so test we got < 100 assert(sample.toSet.size < n, "sampling with replacement returned all distinct elements") - assert(sample.forall(x => 1 <= x && x <= n), "elements not in [1, 100]") + assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]") } for (seed <- 1 to 5) { val sample = data.takeSample(withReplacement=true, n, seed) diff --git a/pom.xml b/pom.xml index 01d6eef32be63..64c8cd3c7810a 100644 --- a/pom.xml +++ b/pom.xml @@ -248,7 +248,7 @@ org.apache.commons commons-math3 - 3.2 + 3.3 com.google.code.findbugs diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index a6b6c26a49395..7314168fa8b6e 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -331,7 +331,7 @@ object SparkBuild extends Build { libraryDependencies ++= Seq( "com.google.guava" % "guava" % "14.0.1", "org.apache.commons" % "commons-lang3" % "3.3.2", - "org.apache.commons" % "commons-math3" % "3.2", + "org.apache.commons" % "commons-math3" % "3.3" % "test", "com.google.code.findbugs" % "jsr305" % "1.3.9", "log4j" % "log4j" % "1.2.17", "org.slf4j" % "slf4j-api" % slf4jVersion, From 7cab53a3926f4351432e5e3600b0796b9a4146e4 Mon Sep 17 00:00:00 2001 From: Doris Xin Date: Mon, 2 Jun 2014 12:00:38 -0700 Subject: [PATCH 03/16] fixed import bug in rdd.py --- python/pyspark/rdd.py | 2 +- 1 file changed, 1 
insertion(+), 1 deletion(-) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index b400404ad97c7..8f46e448e643e 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -31,7 +31,7 @@ import warnings import heapq from random import Random -from math import sqrt, log, min +from math import sqrt, log from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \ BatchedSerializer, CloudPickleSerializer, PairDeserializer, pack_long From 9bdd36ede8e3c7e9e2892f327bc8d4b8898f2b7e Mon Sep 17 00:00:00 2001 From: Doris Xin Date: Mon, 9 Jun 2014 13:43:56 -0700 Subject: [PATCH 04/16] Check sample size and move computeFraction Check that the sample size is within supported range. Moved computeFraction int a private util class in util.random --- .../main/scala/org/apache/spark/rdd/RDD.scala | 44 +++++----------- .../spark/util/random/SamplingUtils.scala | 50 +++++++++++++++++++ .../scala/org/apache/spark/rdd/RDDSuite.scala | 23 --------- .../util/random/SamplingUtilsSuite.scala | 46 +++++++++++++++++ 4 files changed, 108 insertions(+), 55 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala create mode 100644 core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 9a5cf13e3ba52..a1b558fba55ac 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -43,7 +43,7 @@ import org.apache.spark.partial.PartialResult import org.apache.spark.storage.StorageLevel import org.apache.spark.util.{BoundedPriorityQueue, SerializableHyperLogLog, Utils} import org.apache.spark.util.collection.OpenHashMap -import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler} +import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler, SamplingUtils} /** * A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. Represents an immutable, @@ -400,12 +400,21 @@ abstract class RDD[T: ClassTag]( throw new IllegalArgumentException("Negative number of elements requested") } + if (!withReplacement && num > initialCount) { + throw new IllegalArgumentException("Cannot create sample larger than the original when " + + "sampling without replacement") + } + if (initialCount == 0) { return new Array[T](0) } if (initialCount > Integer.MAX_VALUE - 1) { - maxSelected = Integer.MAX_VALUE - 1 + maxSelected = Integer.MAX_VALUE - (5.0 * math.sqrt(Integer.MAX_VALUE)).toInt + if (num > maxSelected) { + throw new IllegalArgumentException("Cannot support a sample size > Integer.MAX_VALUE - " + + "5.0 * math.sqrt(Integer.MAX_VALUE)") + } } else { maxSelected = initialCount.toInt } @@ -415,7 +424,7 @@ abstract class RDD[T: ClassTag]( total = maxSelected fraction = multiplier * (maxSelected + 1) / initialCount } else { - fraction = computeFraction(num, initialCount, withReplacement) + fraction = SamplingUtils.computeFraction(num, initialCount, withReplacement) total = num } @@ -431,35 +440,6 @@ abstract class RDD[T: ClassTag]( Utils.randomizeInPlace(samples, rand).take(total) } - /** - * Let p = num / total, where num is the sample size and total is the total number of - * datapoints in the RDD. We're trying to compute q > p such that - * - when sampling with replacement, we're drawing each datapoint with prob_i ~ Pois(q), - * where we want to guarantee Pr[s < num] < 0.0001 for s = sum(prob_i for i from 0 to total), - * i.e. 
the failure rate of not having a sufficiently large sample < 0.0001. - * Setting q = p + 5 * sqrt(p/total) is sufficient to guarantee 0.9999 success rate for - * num > 12, but we need a slightly larger q (9 empirically determined). - * - when sampling without replacement, we're drawing each datapoint with prob_i - * ~ Binomial(total, fraction) and our choice of q guarantees 1-delta, or 0.9999 success - * rate, where success rate is defined the same as in sampling with replacement. - * - * @param num sample size - * @param total size of RDD - * @param withReplacement whether sampling with replacement - * @return a sampling rate that guarantees sufficient sample size with 99.99% success rate - */ - private[rdd] def computeFraction(num: Int, total: Long, withReplacement: Boolean): Double = { - val fraction = num.toDouble / total - if (withReplacement) { - val numStDev = if (num < 12) 9 else 5 - fraction + numStDev * math.sqrt(fraction / total) - } else { - val delta = 1e-4 - val gamma = - math.log(delta) / total - math.min(1, fraction + gamma + math.sqrt(gamma * gamma + 2 * gamma * fraction)) - } - } - /** * Return the union of this RDD and another one. Any identical elements will appear multiple * times (use `.distinct()` to eliminate them). diff --git a/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala b/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala new file mode 100644 index 0000000000000..0905e3dd49dec --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.random + +private[spark] object SamplingUtils { + + /** + * Let p = num / total, where num is the sample size and total is the total number of + * datapoints in the RDD. We're trying to compute q > p such that + * - when sampling with replacement, we're drawing each datapoint with prob_i ~ Pois(q), + * where we want to guarantee Pr[s < num] < 0.0001 for s = sum(prob_i for i from 0 to total), + * i.e. the failure rate of not having a sufficiently large sample < 0.0001. + * Setting q = p + 5 * sqrt(p/total) is sufficient to guarantee 0.9999 success rate for + * num > 12, but we need a slightly larger q (9 empirically determined). + * - when sampling without replacement, we're drawing each datapoint with prob_i + * ~ Binomial(total, fraction) and our choice of q guarantees 1-delta, or 0.9999 success + * rate, where success rate is defined the same as in sampling with replacement. 
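+   * As an illustrative check of the without-replacement case (numbers chosen for
+   * exposition only, not part of the original patch): num = 1000 and total = 100000
+   * give p = 0.01, gamma = -log(1e-4) / total ~= 9.2e-5, and
+   * q = p + gamma + sqrt(gamma * gamma + 2 * gamma * p) ~= 0.0115, i.e. an expected
+   * Binomial sample of ~1145 elements, which falls below num with probability < 0.0001.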
+ * + * @param num sample size + * @param total size of RDD + * @param withReplacement whether sampling with replacement + * @return a sampling rate that guarantees sufficient sample size with 99.99% success rate + */ + def computeFraction(num: Int, total: Long, withReplacement: Boolean): Double = { + val fraction = num.toDouble / total + if (withReplacement) { + val numStDev = if (num < 12) 9 else 5 + fraction + numStDev * math.sqrt(fraction / total) + } else { + val delta = 1e-4 + val gamma = - math.log(delta) / total + math.min(1, fraction + gamma + math.sqrt(gamma * gamma + 2 * gamma * fraction)) + } + } +} diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index 08b3b93d6a31f..f93979a2a745a 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -497,29 +497,6 @@ class RDDSuite extends FunSuite with SharedSparkContext { assert(sortedTopK === nums.sorted(ord).take(5)) } - test("computeFraction") { - // test that the computed fraction guarantees enough datapoints - // in the sample with a failure rate <= 0.0001 - val data = new EmptyRDD[Int](sc) - val n = 100000 - - for (s <- 1 to 15) { - val frac = data.computeFraction(s, n, true) - val poisson = new PoissonDistribution(frac * n) - assert(poisson.inverseCumulativeProbability(0.0001) >= s, "Computed fraction is too low") - } - for (s <- List(20, 100, 1000)) { - val frac = data.computeFraction(s, n, true) - val poisson = new PoissonDistribution(frac * n) - assert(poisson.inverseCumulativeProbability(0.0001) >= s, "Computed fraction is too low") - } - for (s <- List(1, 10, 100, 1000)) { - val frac = data.computeFraction(s, n, false) - val binomial = new BinomialDistribution(n, frac) - assert(binomial.inverseCumulativeProbability(0.0001)*n >= s, "Computed fraction is too low") - } - } - test("takeSample") { val n = 1000000 val data = sc.parallelize(1 to n, 2) diff --git a/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala new file mode 100644 index 0000000000000..3ad46c41fe902 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.util.random + +import org.apache.commons.math3.distribution.{BinomialDistribution, PoissonDistribution} +import org.scalatest.FunSuite + +class SamplingUtilsSuite extends FunSuite{ + + test("computeFraction") { + // test that the computed fraction guarantees enough datapoints + // in the sample with a failure rate <= 0.0001 + val n = 100000 + + for (s <- 1 to 15) { + val frac = SamplingUtils.computeFraction(s, n, true) + val poisson = new PoissonDistribution(frac * n) + assert(poisson.inverseCumulativeProbability(0.0001) >= s, "Computed fraction is too low") + } + for (s <- List(20, 100, 1000)) { + val frac = SamplingUtils.computeFraction(s, n, true) + val poisson = new PoissonDistribution(frac * n) + assert(poisson.inverseCumulativeProbability(0.0001) >= s, "Computed fraction is too low") + } + for (s <- List(1, 10, 100, 1000)) { + val frac = SamplingUtils.computeFraction(s, n, false) + val binomial = new BinomialDistribution(n, frac) + assert(binomial.inverseCumulativeProbability(0.0001)*n >= s, "Computed fraction is too low") + } + } +} From ae3ad049161f470549510daf2eefdbe576fb01e8 Mon Sep 17 00:00:00 2001 From: Doris Xin Date: Tue, 10 Jun 2014 12:02:34 -0700 Subject: [PATCH 05/16] fixed edge cases to prevent overflow --- .../main/scala/org/apache/spark/rdd/RDD.scala | 24 ++++++------------- 1 file changed, 7 insertions(+), 17 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 39933ec340356..fb4c3c6ebd6cc 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -391,41 +391,31 @@ abstract class RDD[T: ClassTag]( seed: Long = Utils.random.nextLong): Array[T] = { var fraction = 0.0 var total = 0 - val multiplier = 3.0 val initialCount = this.count() - var maxSelected = 0 if (num < 0) { throw new IllegalArgumentException("Negative number of elements requested") } + if (initialCount == 0) { + return new Array[T](0) + } + if (!withReplacement && num > initialCount) { throw new IllegalArgumentException("Cannot create sample larger than the original when " + "sampling without replacement") } - if (initialCount == 0) { - return new Array[T](0) - } - if (initialCount > Integer.MAX_VALUE - 1) { - maxSelected = Integer.MAX_VALUE - (5.0 * math.sqrt(Integer.MAX_VALUE)).toInt + val maxSelected = Integer.MAX_VALUE - (5.0 * math.sqrt(Integer.MAX_VALUE)).toInt if (num > maxSelected) { throw new IllegalArgumentException("Cannot support a sample size > Integer.MAX_VALUE - " + "5.0 * math.sqrt(Integer.MAX_VALUE)") } - } else { - maxSelected = initialCount.toInt } - if (num > initialCount && !withReplacement) { - // special case not covered in computeFraction - total = maxSelected - fraction = multiplier * (maxSelected + 1) / initialCount - } else { - fraction = SamplingUtils.computeFraction(num, initialCount, withReplacement) - total = num - } + fraction = SamplingUtils.computeFraction(num, initialCount, withReplacement) + total = num val rand = new Random(seed) var samples = this.sample(withReplacement, fraction, rand.nextInt()).collect() From 0a9b3e3f8a552c02040ad5e8afb8dd03e91a863e Mon Sep 17 00:00:00 2001 From: Doris Xin Date: Tue, 10 Jun 2014 13:24:44 -0700 Subject: [PATCH 06/16] "reviewer comment addressed" --- .../main/scala/org/apache/spark/rdd/RDD.scala | 11 +++++---- .../spark/util/random/SamplingUtils.scala | 8 +++---- .../util/random/SamplingUtilsSuite.scala | 6 ++--- pom.xml | 5 ---- python/pyspark/rdd.py | 24 
+++++++++---------- 5 files changed, 24 insertions(+), 30 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index fb4c3c6ebd6cc..29614376aee37 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -391,6 +391,7 @@ abstract class RDD[T: ClassTag]( seed: Long = Utils.random.nextLong): Array[T] = { var fraction = 0.0 var total = 0 + val numStDev = 10.0 val initialCount = this.count() if (num < 0) { @@ -406,15 +407,15 @@ abstract class RDD[T: ClassTag]( "sampling without replacement") } - if (initialCount > Integer.MAX_VALUE - 1) { - val maxSelected = Integer.MAX_VALUE - (5.0 * math.sqrt(Integer.MAX_VALUE)).toInt + if (initialCount > Int.MaxValue - 1) { + val maxSelected = Int.MaxValue - (numStDev * math.sqrt(Int.MaxValue)).toInt if (num > maxSelected) { - throw new IllegalArgumentException("Cannot support a sample size > Integer.MAX_VALUE - " + - "5.0 * math.sqrt(Integer.MAX_VALUE)") + throw new IllegalArgumentException("Cannot support a sample size > Int.MaxValue - " + + s"$numStDev * math.sqrt(Int.MaxValue)") } } - fraction = SamplingUtils.computeFraction(num, initialCount, withReplacement) + fraction = SamplingUtils.computeFractionForSampleSize(num, initialCount, withReplacement) total = num val rand = new Random(seed) diff --git a/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala b/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala index 0905e3dd49dec..66c93d59a4282 100644 --- a/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala @@ -31,15 +31,15 @@ private[spark] object SamplingUtils { * ~ Binomial(total, fraction) and our choice of q guarantees 1-delta, or 0.9999 success * rate, where success rate is defined the same as in sampling with replacement. 
* - * @param num sample size + * @param sampleSizeLowerBound sample size * @param total size of RDD * @param withReplacement whether sampling with replacement * @return a sampling rate that guarantees sufficient sample size with 99.99% success rate */ - def computeFraction(num: Int, total: Long, withReplacement: Boolean): Double = { - val fraction = num.toDouble / total + def computeFractionForSampleSize(sampleSizeLowerBound: Int, total: Long, withReplacement: Boolean): Double = { + val fraction = sampleSizeLowerBound.toDouble / total if (withReplacement) { - val numStDev = if (num < 12) 9 else 5 + val numStDev = if (sampleSizeLowerBound < 12) 9 else 5 fraction + numStDev * math.sqrt(fraction / total) } else { val delta = 1e-4 diff --git a/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala index 3ad46c41fe902..eace1723a8c3d 100644 --- a/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala @@ -28,17 +28,17 @@ class SamplingUtilsSuite extends FunSuite{ val n = 100000 for (s <- 1 to 15) { - val frac = SamplingUtils.computeFraction(s, n, true) + val frac = SamplingUtils.computeFractionForSampleSize(s, n, true) val poisson = new PoissonDistribution(frac * n) assert(poisson.inverseCumulativeProbability(0.0001) >= s, "Computed fraction is too low") } for (s <- List(20, 100, 1000)) { - val frac = SamplingUtils.computeFraction(s, n, true) + val frac = SamplingUtils.computeFractionForSampleSize(s, n, true) val poisson = new PoissonDistribution(frac * n) assert(poisson.inverseCumulativeProbability(0.0001) >= s, "Computed fraction is too low") } for (s <- List(1, 10, 100, 1000)) { - val frac = SamplingUtils.computeFraction(s, n, false) + val frac = SamplingUtils.computeFractionForSampleSize(s, n, false) val binomial = new BinomialDistribution(n, frac) assert(binomial.inverseCumulativeProbability(0.0001)*n >= s, "Computed fraction is too low") } diff --git a/pom.xml b/pom.xml index ee0ddcc1d845c..0d46bb4114f73 100644 --- a/pom.xml +++ b/pom.xml @@ -256,11 +256,6 @@ commons-codec 1.5 - - org.apache.commons - commons-math3 - 3.3 - com.google.code.findbugs jsr305 diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 8154924f6eba8..898a2e42a7704 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -366,27 +366,25 @@ def takeSample(self, withReplacement, num, seed=None): fraction = 0.0 total = 0 - multiplier = 3.0 + numStDev = 10.0 initialCount = self.count() - maxSelected = 0 - if (num < 0): + if num < 0: raise ValueError - if (initialCount == 0): + if initialCount == 0: return list() + if (not withReplacement) and num > initialCount: + raise ValueError + if initialCount > sys.maxint - 1: - maxSelected = sys.maxint - 1 - else: - maxSelected = initialCount + maxSelected = sys.maxint - int(numStDev * sqrt(sys.maxint)) + if num > maxSelected: + raise ValueError - if num > initialCount and not withReplacement: - total = maxSelected - fraction = multiplier * (maxSelected + 1) / initialCount - else: - fraction = self._computeFraction(num, initialCount, withReplacement) - total = num + fraction = self._computeFraction(num, initialCount, withReplacement) + total = num samples = self.sample(withReplacement, fraction, seed).collect() From ecab5082f20c265cd967f0e2d946a175fc1fd48c Mon Sep 17 00:00:00 2001 From: Doris Xin Date: Tue, 10 Jun 2014 13:33:23 -0700 Subject: [PATCH 07/16] "fixed checkstyle 
violation --- .../scala/org/apache/spark/util/random/SamplingUtils.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala b/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala index 66c93d59a4282..7f1dd0977137f 100644 --- a/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala @@ -36,7 +36,8 @@ private[spark] object SamplingUtils { * @param withReplacement whether sampling with replacement * @return a sampling rate that guarantees sufficient sample size with 99.99% success rate */ - def computeFractionForSampleSize(sampleSizeLowerBound: Int, total: Long, withReplacement: Boolean): Double = { + def computeFractionForSampleSize(sampleSizeLowerBound: Int, total: Long, + withReplacement: Boolean): Double = { val fraction = sampleSizeLowerBound.toDouble / total if (withReplacement) { val numStDev = if (sampleSizeLowerBound < 12) 9 else 5 From eff89e2c7ac824c7a9580b1388a92bb2144dc4dd Mon Sep 17 00:00:00 2001 From: Doris Xin Date: Thu, 12 Jun 2014 14:00:27 -0700 Subject: [PATCH 08/16] addressed reviewer comments. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Note that logging isn’t added to rdd.py because it seemed to be clobbering with the log4j logs --- .../main/scala/org/apache/spark/rdd/RDD.scala | 25 ++++----- .../spark/util/random/SamplingUtils.scala | 4 ++ .../scala/org/apache/spark/rdd/RDDSuite.scala | 4 +- .../util/random/SamplingUtilsSuite.scala | 4 +- python/pyspark/rdd.py | 53 ++++++++++++------- 5 files changed, 54 insertions(+), 36 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 29614376aee37..adf680f602ff9 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -389,8 +389,6 @@ abstract class RDD[T: ClassTag]( def takeSample(withReplacement: Boolean, num: Int, seed: Long = Utils.random.nextLong): Array[T] = { - var fraction = 0.0 - var total = 0 val numStDev = 10.0 val initialCount = this.count() @@ -407,27 +405,30 @@ abstract class RDD[T: ClassTag]( "sampling without replacement") } - if (initialCount > Int.MaxValue - 1) { - val maxSelected = Int.MaxValue - (numStDev * math.sqrt(Int.MaxValue)).toInt - if (num > maxSelected) { - throw new IllegalArgumentException("Cannot support a sample size > Int.MaxValue - " + - s"$numStDev * math.sqrt(Int.MaxValue)") - } + val maxSampleSize = Int.MaxValue - (numStDev * math.sqrt(Int.MaxValue)).toInt + if (num > maxSampleSize) { + throw new IllegalArgumentException("Cannot support a sample size > Int.MaxValue - " + + s"$numStDev * math.sqrt(Int.MaxValue)") } - fraction = SamplingUtils.computeFractionForSampleSize(num, initialCount, withReplacement) - total = num + val fraction = SamplingUtils.computeFractionForSampleSize(num, initialCount, + withReplacement) val rand = new Random(seed) var samples = this.sample(withReplacement, fraction, rand.nextInt()).collect() // If the first sample didn't turn out large enough, keep trying to take samples; // this shouldn't happen often because we use a big multiplier for the initial size - while (samples.length < total) { + var numIters = 0 + while (samples.length < num) { + if (numIters > 0) { + logWarning(s"Needed to re-sample due to insufficient sample size. 
Repeat #$numIters") + } samples = this.sample(withReplacement, fraction, rand.nextInt()).collect() + numIters += 1 } - Utils.randomizeInPlace(samples, rand).take(total) + Utils.randomizeInPlace(samples, rand).take(num) } /** diff --git a/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala b/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala index 7f1dd0977137f..a79e3ee756fc6 100644 --- a/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala @@ -20,6 +20,10 @@ package org.apache.spark.util.random private[spark] object SamplingUtils { /** + * Returns a sampling rate that guarantees a sample of size >= sampleSizeLowerBound 99.99% of + * the time. + * + * How the sampling rate is determined: * Let p = num / total, where num is the sample size and total is the total number of * datapoints in the RDD. We're trying to compute q > p such that * - when sampling with replacement, we're drawing each datapoint with prob_i ~ Pois(q), diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index 81bd00db91a14..a3d54c26e2cf2 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -547,8 +547,8 @@ class RDDSuite extends FunSuite with SharedSparkContext { assert(sample.toSet.size < n, "sampling with replacement returned all distinct elements") } for (seed <- 1 to 5) { - val sample = data.takeSample(withReplacement=true, 2*n, seed) - assert(sample.size === 2*n) // Got exactly 200 elements + val sample = data.takeSample(withReplacement=true, 2 * n, seed) + assert(sample.size === 2 * n) // Got exactly 200 elements // Chance of getting all distinct elements is still quite low, so test we got < 100 assert(sample.toSet.size < n, "sampling with replacement returned all distinct elements") } diff --git a/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala index eace1723a8c3d..accfe2e9b7f2a 100644 --- a/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala @@ -20,10 +20,10 @@ package org.apache.spark.util.random import org.apache.commons.math3.distribution.{BinomialDistribution, PoissonDistribution} import org.scalatest.FunSuite -class SamplingUtilsSuite extends FunSuite{ +class SamplingUtilsSuite extends FunSuite { test("computeFraction") { - // test that the computed fraction guarantees enough datapoints + // test that the computed fraction guarantees enough data points // in the sample with a failure rate <= 0.0001 val n = 100000 diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 898a2e42a7704..e27599af51b7c 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -15,8 +15,6 @@ # limitations under the License. 
# -from base64 import standard_b64encode as b64enc -import copy from collections import defaultdict from collections import namedtuple from itertools import chain, ifilter, imap @@ -364,8 +362,8 @@ def takeSample(self, withReplacement, num, seed=None): [4, 2, 1, 8, 2, 7, 0, 4, 1, 4] """ - fraction = 0.0 - total = 0 + #TODO remove + logging.basicConfig(level=logging.INFO) numStDev = 10.0 initialCount = self.count() @@ -378,38 +376,53 @@ def takeSample(self, withReplacement, num, seed=None): if (not withReplacement) and num > initialCount: raise ValueError - if initialCount > sys.maxint - 1: - maxSelected = sys.maxint - int(numStDev * sqrt(sys.maxint)) - if num > maxSelected: - raise ValueError - - fraction = self._computeFraction(num, initialCount, withReplacement) - total = num + maxSampleSize = sys.maxint - int(numStDev * sqrt(sys.maxint)) + if num > maxSampleSize: + raise ValueError + fraction = self._computeFractionForSampleSize(num, initialCount, withReplacement) + samples = self.sample(withReplacement, fraction, seed).collect() # If the first sample didn't turn out large enough, keep trying to take samples; # this shouldn't happen often because we use a big multiplier for their initial size. # See: scala/spark/RDD.scala rand = Random(seed) - while len(samples) < total: + while len(samples) < num: samples = self.sample(withReplacement, fraction, rand.randint(0, sys.maxint)).collect() sampler = RDDSampler(withReplacement, fraction, rand.randint(0, sys.maxint)) sampler.shuffle(samples) - return samples[0:total] - - def _computeFraction(self, num, total, withReplacement): - fraction = float(num)/total + return samples[0:num] + + @staticmethod + def _computeFractionForSampleSize(sampleSizeLowerBound, total, withReplacement): + """ + Returns a sampling rate that guarantees a sample of size >= sampleSizeLowerBound 99.99% of + the time. + + How the sampling rate is determined: + Let p = num / total, where num is the sample size and total is the total number of + datapoints in the RDD. We're trying to compute q > p such that + - when sampling with replacement, we're drawing each datapoint with prob_i ~ Pois(q), + where we want to guarantee Pr[s < num] < 0.0001 for s = sum(prob_i for i from 0 to + total), i.e. the failure rate of not having a sufficiently large sample < 0.0001. + Setting q = p + 5 * sqrt(p/total) is sufficient to guarantee 0.9999 success rate for + num > 12, but we need a slightly larger q (9 empirically determined). + - when sampling without replacement, we're drawing each datapoint with prob_i + ~ Binomial(total, fraction) and our choice of q guarantees 1-delta, or 0.9999 success + rate, where success rate is defined the same as in sampling with replacement. 
+ """ + fraction = float(sampleSizeLowerBound) / total if withReplacement: numStDev = 5 - if (num < 12): + if (sampleSizeLowerBound < 12): numStDev = 9 - return fraction + numStDev * sqrt(fraction/total) + return fraction + numStDev * sqrt(fraction / total) else: delta = 0.00005 - gamma = - log(delta)/total - return min(1, fraction + gamma + sqrt(gamma * gamma + 2* gamma * fraction)) + gamma = - log(delta) / total + return min(1, fraction + gamma + sqrt(gamma * gamma + 2 * gamma * fraction)) def union(self, other): """ From 55518edb79cc3c47ec682099676df0e1ba235136 Mon Sep 17 00:00:00 2001 From: Doris Xin Date: Thu, 12 Jun 2014 14:38:37 -0700 Subject: [PATCH 09/16] added TODO for logging in rdd.py --- python/pyspark/rdd.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index e27599af51b7c..a064892c454e7 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -389,6 +389,7 @@ def takeSample(self, withReplacement, num, seed=None): # See: scala/spark/RDD.scala rand = Random(seed) while len(samples) < num: + #TODO add log warning for when more than one iteration was run samples = self.sample(withReplacement, fraction, rand.randint(0, sys.maxint)).collect() sampler = RDDSampler(withReplacement, fraction, rand.randint(0, sys.maxint)) From 64e445b943fe8cb0eaf0371017922ebb8a688353 Mon Sep 17 00:00:00 2001 From: Doris Xin Date: Thu, 12 Jun 2014 14:41:10 -0700 Subject: [PATCH 10/16] logwarnning as soon as it enters the while loop --- core/src/main/scala/org/apache/spark/rdd/RDD.scala | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index adf680f602ff9..d69f5136dda12 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -421,9 +421,7 @@ abstract class RDD[T: ClassTag]( // this shouldn't happen often because we use a big multiplier for the initial size var numIters = 0 while (samples.length < num) { - if (numIters > 0) { - logWarning(s"Needed to re-sample due to insufficient sample size. Repeat #$numIters") - } + logWarning(s"Needed to re-sample due to insufficient sample size. Repeat #$numIters") samples = this.sample(withReplacement, fraction, rand.nextInt()).collect() numIters += 1 } From dc699f3586bc8cf6097f74965b76bbcff104963d Mon Sep 17 00:00:00 2001 From: Doris Xin Date: Thu, 12 Jun 2014 14:57:25 -0700 Subject: [PATCH 11/16] give back imports removed by accident in rdd.py --- python/pyspark/rdd.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index a064892c454e7..d1b2219cbf2cb 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -15,6 +15,8 @@ # limitations under the License. 
# +from base64 import standard_b64encode as b64enc +import copy from collections import defaultdict from collections import namedtuple from itertools import chain, ifilter, imap From 1481b01c0f4f26e4472e4e0e57683b9c501c59d4 Mon Sep 17 00:00:00 2001 From: Doris Xin Date: Thu, 12 Jun 2014 15:15:20 -0700 Subject: [PATCH 12/16] washing test tubes and making coffee --- python/pyspark/rdd.py | 100 +++++++++++++++++++++++------------------- 1 file changed, 55 insertions(+), 45 deletions(-) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index d1b2219cbf2cb..3c9575ca942be 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -203,9 +203,9 @@ def cache(self): def persist(self, storageLevel): """ - Set this RDD's storage level to persist its values across operations after the first time - it is computed. This can only be used to assign a new storage level if the RDD does not - have a storage level set yet. + Set this RDD's storage level to persist its values across operations + after the first time it is computed. This can only be used to assign + a new storage level if the RDD does not have a storage level set yet. """ self.is_cached = True javaStorageLevel = self.ctx._getJavaStorageLevel(storageLevel) @@ -214,7 +214,8 @@ def persist(self, storageLevel): def unpersist(self): """ - Mark the RDD as non-persistent, and remove all blocks for it from memory and disk. + Mark the RDD as non-persistent, and remove all blocks for it from + memory and disk. """ self.is_cached = False self._jrdd.unpersist() @@ -358,7 +359,8 @@ def sample(self, withReplacement, fraction, seed=None): # this is ported from scala/spark/RDD.scala def takeSample(self, withReplacement, num, seed=None): """ - Return a fixed-size sampled subset of this RDD (currently requires numpy). + Return a fixed-size sampled subset of this RDD (currently requires + numpy). >>> sc.parallelize(range(0, 10)).takeSample(True, 10, 1) #doctest: +SKIP [4, 2, 1, 8, 2, 7, 0, 4, 1, 4] @@ -401,20 +403,24 @@ def takeSample(self, withReplacement, num, seed=None): @staticmethod def _computeFractionForSampleSize(sampleSizeLowerBound, total, withReplacement): """ - Returns a sampling rate that guarantees a sample of size >= sampleSizeLowerBound 99.99% of - the time. + Returns a sampling rate that guarantees a sample of + size >= sampleSizeLowerBound 99.99% of the time. How the sampling rate is determined: - Let p = num / total, where num is the sample size and total is the total number of - datapoints in the RDD. We're trying to compute q > p such that - - when sampling with replacement, we're drawing each datapoint with prob_i ~ Pois(q), - where we want to guarantee Pr[s < num] < 0.0001 for s = sum(prob_i for i from 0 to - total), i.e. the failure rate of not having a sufficiently large sample < 0.0001. - Setting q = p + 5 * sqrt(p/total) is sufficient to guarantee 0.9999 success rate for - num > 12, but we need a slightly larger q (9 empirically determined). - - when sampling without replacement, we're drawing each datapoint with prob_i - ~ Binomial(total, fraction) and our choice of q guarantees 1-delta, or 0.9999 success - rate, where success rate is defined the same as in sampling with replacement. + Let p = num / total, where num is the sample size and total is the + total number of data points in the RDD. 
We're trying to compute + q > p such that + - when sampling with replacement, we're drawing each data point + with prob_i ~ Pois(q), where we want to guarantee + Pr[s < num] < 0.0001 for s = sum(prob_i for i from 0 to + total), i.e. the failure rate of not having a sufficiently large + sample < 0.0001. Setting q = p + 5 * sqrt(p/total) is sufficient + to guarantee 0.9999 success rate for num > 12, but we need a + slightly larger q (9 empirically determined). + - when sampling without replacement, we're drawing each data point + with prob_i ~ Binomial(total, fraction) and our choice of q + guarantees 1-delta, or 0.9999 success rate, where success rate is + defined the same as in sampling with replacement. """ fraction = float(sampleSizeLowerBound) / total if withReplacement: @@ -449,8 +455,8 @@ def union(self, other): def intersection(self, other): """ - Return the intersection of this RDD and another one. The output will not - contain any duplicate elements, even if the input RDDs did. + Return the intersection of this RDD and another one. The output will + not contain any duplicate elements, even if the input RDDs did. Note that this method performs a shuffle internally. @@ -692,8 +698,8 @@ def aggregate(self, zeroValue, seqOp, combOp): modify C{t2}. The first function (seqOp) can return a different result type, U, than - the type of this RDD. Thus, we need one operation for merging a T into an U - and one operation for merging two U + the type of this RDD. Thus, we need one operation for merging a T into + an U and one operation for merging two U >>> seqOp = (lambda x, y: (x[0] + y, x[1] + 1)) >>> combOp = (lambda x, y: (x[0] + y[0], x[1] + y[1])) @@ -786,8 +792,9 @@ def stdev(self): def sampleStdev(self): """ - Compute the sample standard deviation of this RDD's elements (which corrects for bias in - estimating the standard deviation by dividing by N-1 instead of N). + Compute the sample standard deviation of this RDD's elements (which + corrects for bias in estimating the standard deviation by dividing by + N-1 instead of N). >>> sc.parallelize([1, 2, 3]).sampleStdev() 1.0 @@ -796,8 +803,8 @@ def sampleStdev(self): def sampleVariance(self): """ - Compute the sample variance of this RDD's elements (which corrects for bias in - estimating the variance by dividing by N-1 instead of N). + Compute the sample variance of this RDD's elements (which corrects + for bias in estimating the variance by dividing by N-1 instead of N). >>> sc.parallelize([1, 2, 3]).sampleVariance() 1.0 @@ -849,8 +856,8 @@ def merge(a, b): def takeOrdered(self, num, key=None): """ - Get the N elements from a RDD ordered in ascending order or as specified - by the optional key function. + Get the N elements from a RDD ordered in ascending order or as + specified by the optional key function. >>> sc.parallelize([10, 1, 2, 9, 3, 4, 5, 6, 7]).takeOrdered(6) [1, 2, 3, 4, 5, 6] @@ -939,8 +946,9 @@ def first(self): def saveAsPickleFile(self, path, batchSize=10): """ - Save this RDD as a SequenceFile of serialized objects. The serializer used is - L{pyspark.serializers.PickleSerializer}, default batch size is 10. + Save this RDD as a SequenceFile of serialized objects. The serializer + used is L{pyspark.serializers.PickleSerializer}, default batch size + is 10. 
>>> tmpFile = NamedTemporaryFile(delete=True) >>> tmpFile.close() @@ -1208,9 +1216,10 @@ def _mergeCombiners(iterator): def foldByKey(self, zeroValue, func, numPartitions=None): """ - Merge the values for each key using an associative function "func" and a neutral "zeroValue" - which may be added to the result an arbitrary number of times, and must not change - the result (e.g., 0 for addition, or 1 for multiplication.). + Merge the values for each key using an associative function "func" + and a neutral "zeroValue" which may be added to the result an + arbitrary number of times, and must not change the result + (e.g., 0 for addition, or 1 for multiplication.). >>> rdd = sc.parallelize([("a", 1), ("b", 1), ("a", 1)]) >>> from operator import add @@ -1227,8 +1236,8 @@ def groupByKey(self, numPartitions=None): Hash-partitions the resulting RDD with into numPartitions partitions. Note: If you are grouping in order to perform an aggregation (such as a - sum or average) over each key, using reduceByKey will provide much better - performance. + sum or average) over each key, using reduceByKey will provide much + better performance. >>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 1)]) >>> map((lambda (x,y): (x, list(y))), sorted(x.groupByKey().collect())) @@ -1288,8 +1297,8 @@ def groupWith(self, other): def cogroup(self, other, numPartitions=None): """ For each key k in C{self} or C{other}, return a resulting RDD that - contains a tuple with the list of values for that key in C{self} as well - as C{other}. + contains a tuple with the list of values for that key in C{self} as + well as C{other}. >>> x = sc.parallelize([("a", 1), ("b", 4)]) >>> y = sc.parallelize([("a", 2)]) @@ -1300,8 +1309,8 @@ def cogroup(self, other, numPartitions=None): def subtractByKey(self, other, numPartitions=None): """ - Return each (key, value) pair in C{self} that has no pair with matching key - in C{other}. + Return each (key, value) pair in C{self} that has no pair with matching + key in C{other}. >>> x = sc.parallelize([("a", 1), ("b", 4), ("b", 5), ("a", 2)]) >>> y = sc.parallelize([("a", 3), ("c", None)]) @@ -1339,10 +1348,10 @@ def repartition(self, numPartitions): """ Return a new RDD that has exactly numPartitions partitions. - Can increase or decrease the level of parallelism in this RDD. Internally, this uses - a shuffle to redistribute data. - If you are decreasing the number of partitions in this RDD, consider using `coalesce`, - which can avoid performing a shuffle. + Can increase or decrease the level of parallelism in this RDD. + Internally, this uses a shuffle to redistribute data. + If you are decreasing the number of partitions in this RDD, consider + using `coalesce`, which can avoid performing a shuffle. >>> rdd = sc.parallelize([1,2,3,4,5,6,7], 4) >>> sorted(rdd.glom().collect()) [[1], [2, 3], [4, 5], [6, 7]] @@ -1367,9 +1376,10 @@ def coalesce(self, numPartitions, shuffle=False): def zip(self, other): """ - Zips this RDD with another one, returning key-value pairs with the first element in each RDD - second element in each RDD, etc. Assumes that the two RDDs have the same number of - partitions and the same number of elements in each partition (e.g. one was made through + Zips this RDD with another one, returning key-value pairs with the + first element in each RDD second element in each RDD, etc. Assumes + that the two RDDs have the same number of partitions and the same + number of elements in each partition (e.g. one was made through a map on the other). 
>>> x = sc.parallelize(range(0,5)) From fb1452f387e2a80c10fbec074d38d424d87210e5 Mon Sep 17 00:00:00 2001 From: Doris Xin Date: Thu, 12 Jun 2014 16:22:39 -0700 Subject: [PATCH 13/16] allowing num to be greater than count in all cases --- .../main/scala/org/apache/spark/rdd/RDD.scala | 13 ++++--- python/pyspark/rdd.py | 36 ++++++++++--------- 2 files changed, 25 insertions(+), 24 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index d69f5136dda12..33fd3a17fd5d8 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -396,25 +396,24 @@ abstract class RDD[T: ClassTag]( throw new IllegalArgumentException("Negative number of elements requested") } - if (initialCount == 0) { + if (initialCount == 0 || num == 0) { return new Array[T](0) } - if (!withReplacement && num > initialCount) { - throw new IllegalArgumentException("Cannot create sample larger than the original when " + - "sampling without replacement") - } - val maxSampleSize = Int.MaxValue - (numStDev * math.sqrt(Int.MaxValue)).toInt if (num > maxSampleSize) { throw new IllegalArgumentException("Cannot support a sample size > Int.MaxValue - " + s"$numStDev * math.sqrt(Int.MaxValue)") } + val rand = new Random(seed) + if (!withReplacement && num > initialCount) { + return Utils.randomizeInPlace(this.collect(), rand) + } + val fraction = SamplingUtils.computeFractionForSampleSize(num, initialCount, withReplacement) - val rand = new Random(seed) var samples = this.sample(withReplacement, fraction, rand.nextInt()).collect() // If the first sample didn't turn out large enough, keep trying to take samples; diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 3c9575ca942be..824ce42e494e0 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -366,35 +366,37 @@ def takeSample(self, withReplacement, num, seed=None): [4, 2, 1, 8, 2, 7, 0, 4, 1, 4] """ - #TODO remove - logging.basicConfig(level=logging.INFO) numStDev = 10.0 initialCount = self.count() if num < 0: raise ValueError - if initialCount == 0: + if initialCount == 0 or num == 0: return list() + rand = Random(seed) if (not withReplacement) and num > initialCount: - raise ValueError + # shuffle current RDD and return + samples = self.collect() + fraction = float(num) / initialCount + num = initialCount + else: + maxSampleSize = sys.maxint - int(numStDev * sqrt(sys.maxint)) + if num > maxSampleSize: + raise ValueError - maxSampleSize = sys.maxint - int(numStDev * sqrt(sys.maxint)) - if num > maxSampleSize: - raise ValueError + fraction = self._computeFractionForSampleSize(num, initialCount, withReplacement) - fraction = self._computeFractionForSampleSize(num, initialCount, withReplacement) - - samples = self.sample(withReplacement, fraction, seed).collect() + samples = self.sample(withReplacement, fraction, seed).collect() - # If the first sample didn't turn out large enough, keep trying to take samples; - # this shouldn't happen often because we use a big multiplier for their initial size. - # See: scala/spark/RDD.scala - rand = Random(seed) - while len(samples) < num: - #TODO add log warning for when more than one iteration was run - samples = self.sample(withReplacement, fraction, rand.randint(0, sys.maxint)).collect() + # If the first sample didn't turn out large enough, keep trying to take samples; + # this shouldn't happen often because we use a big multiplier for their initial size. 
+ # See: scala/spark/RDD.scala + while len(samples) < num: + #TODO add log warning for when more than one iteration was run + seed = rand.randint(0, sys.maxint) + samples = self.sample(withReplacement, fraction, seed).collect() sampler = RDDSampler(withReplacement, fraction, rand.randint(0, sys.maxint)) sampler.shuffle(samples) From 48d954dd3bf029aa16ba7a504d10d42abcd90d52 Mon Sep 17 00:00:00 2001 From: Doris Xin Date: Thu, 12 Jun 2014 16:44:57 -0700 Subject: [PATCH 14/16] remove unused imports from RDDSuite --- core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala | 4 ---- 1 file changed, 4 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index a3d54c26e2cf2..2e70a59c2f53e 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -22,12 +22,8 @@ import scala.reflect.ClassTag import org.scalatest.FunSuite -import org.apache.commons.math3.distribution.BinomialDistribution -import org.apache.commons.math3.distribution.PoissonDistribution - import org.apache.spark._ import org.apache.spark.SparkContext._ -import org.apache.spark.rdd._ class RDDSuite extends FunSuite with SharedSparkContext { From 82dde3138958158c4055818987cb99d001e77e39 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Thu, 12 Jun 2014 17:20:50 -0700 Subject: [PATCH 15/16] update pyspark's takeSample --- python/pyspark/rdd.py | 58 ++++++++++++++++++++++++------------------- 1 file changed, 32 insertions(+), 26 deletions(-) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 824ce42e494e0..8eb1604a941cd 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -362,44 +362,50 @@ def takeSample(self, withReplacement, num, seed=None): Return a fixed-size sampled subset of this RDD (currently requires numpy). - >>> sc.parallelize(range(0, 10)).takeSample(True, 10, 1) #doctest: +SKIP - [4, 2, 1, 8, 2, 7, 0, 4, 1, 4] + >>> rdd = sc.parallelize(range(0, 10)) + >>> len(rdd.takeSample(True, 20, 1)) + 20 + >>> len(rdd.takeSample(False, 5, 2)) + 5 + >>> len(rdd.takeSample(False, 15, 3)) + 10 """ - numStDev = 10.0 - initialCount = self.count() - if num < 0: - raise ValueError + raise ValueError("Sample size cannot be negative.") + elif num == 0: + return [] - if initialCount == 0 or num == 0: - return list() + initialCount = self.count() + if initialCount == 0: + return [] rand = Random(seed) - if (not withReplacement) and num > initialCount: + + if (not withReplacement) and num >= initialCount: # shuffle current RDD and return samples = self.collect() - fraction = float(num) / initialCount - num = initialCount - else: - maxSampleSize = sys.maxint - int(numStDev * sqrt(sys.maxint)) - if num > maxSampleSize: - raise ValueError - - fraction = self._computeFractionForSampleSize(num, initialCount, withReplacement) + rand.shuffle(samples) + return samples + numStDev = 10.0 + maxSampleSize = sys.maxint - int(numStDev * sqrt(sys.maxint)) + if num > maxSampleSize: + raise ValueError("Sample size cannot be greater than %d." % maxSampleSize) + + fraction = RDD._computeFractionForSampleSize(num, initialCount, withReplacement) + samples = self.sample(withReplacement, fraction, seed).collect() + + # If the first sample didn't turn out large enough, keep trying to take samples; + # this shouldn't happen often because we use a big multiplier for their initial size. 
+ # See: scala/spark/RDD.scala + while len(samples) < num: + # TODO: add log warning for when more than one iteration was run + seed = rand.randint(0, sys.maxint) samples = self.sample(withReplacement, fraction, seed).collect() - # If the first sample didn't turn out large enough, keep trying to take samples; - # this shouldn't happen often because we use a big multiplier for their initial size. - # See: scala/spark/RDD.scala - while len(samples) < num: - #TODO add log warning for when more than one iteration was run - seed = rand.randint(0, sys.maxint) - samples = self.sample(withReplacement, fraction, seed).collect() + rand.shuffle(samples) - sampler = RDDSampler(withReplacement, fraction, rand.randint(0, sys.maxint)) - sampler.shuffle(samples) return samples[0:num] @staticmethod From 444e7500ca4be8ea7071703ab6d5043d69289a7a Mon Sep 17 00:00:00 2001 From: Doris Xin Date: Thu, 12 Jun 2014 17:39:20 -0700 Subject: [PATCH 16/16] edge cases --- core/src/main/scala/org/apache/spark/rdd/RDD.scala | 8 +++++--- python/pyspark/rdd.py | 2 +- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 33fd3a17fd5d8..fb12738f499f3 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -390,13 +390,15 @@ abstract class RDD[T: ClassTag]( num: Int, seed: Long = Utils.random.nextLong): Array[T] = { val numStDev = 10.0 - val initialCount = this.count() if (num < 0) { throw new IllegalArgumentException("Negative number of elements requested") + } else if (num == 0) { + return new Array[T](0) } - if (initialCount == 0 || num == 0) { + val initialCount = this.count() + if (initialCount == 0) { return new Array[T](0) } @@ -407,7 +409,7 @@ abstract class RDD[T: ClassTag]( } val rand = new Random(seed) - if (!withReplacement && num > initialCount) { + if (!withReplacement && num >= initialCount) { return Utils.randomizeInPlace(this.collect(), rand) } diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 8eb1604a941cd..da364e19874b1 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -370,6 +370,7 @@ def takeSample(self, withReplacement, num, seed=None): >>> len(rdd.takeSample(False, 15, 3)) 10 """ + numStDev = 10.0 if num < 0: raise ValueError("Sample size cannot be negative.") @@ -388,7 +389,6 @@ def takeSample(self, withReplacement, num, seed=None): rand.shuffle(samples) return samples - numStDev = 10.0 maxSampleSize = sys.maxint - int(numStDev * sqrt(sys.maxint)) if num > maxSampleSize: raise ValueError("Sample size cannot be greater than %d." % maxSampleSize)
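
For reference, a minimal standalone sketch of the sampling-rate computation that `_computeFractionForSampleSize` describes in its docstring (PATCH 12/16). The helper name `compute_fraction`, the example numbers, and the concrete without-replacement bound (delta/gamma values) below are illustrative assumptions based on the 99.99% success-rate target stated in the docstring, not a copy of the patch itself:

```python
from math import log, sqrt

def compute_fraction(sample_size_lower_bound, total, with_replacement):
    """Sampling rate q > p = num/total expected to yield >= num elements
    roughly 99.99% of the time (sketch, not the patched implementation)."""
    fraction = float(sample_size_lower_bound) / total
    if with_replacement:
        # Each element is drawn with prob_i ~ Pois(q): pad p by several
        # standard deviations. Per the docstring, 5 std devs suffices for
        # num > 12; a larger padding (9) is used for very small samples.
        num_std_dev = 9 if sample_size_lower_bound < 12 else 5
        return fraction + num_std_dev * sqrt(fraction / total)
    else:
        # Each element is drawn with prob_i ~ Binomial(total, q): inflate p
        # by gamma = -ln(delta)/total so the sample falls short of num with
        # probability < delta. The delta below is an assumed value chosen to
        # match the 0.9999 success-rate target.
        delta = 0.00005
        gamma = -log(delta) / total
        return min(1.0, fraction + gamma + sqrt(gamma * gamma + 2 * gamma * fraction))

if __name__ == "__main__":
    # Drawing 10 of 100000 elements: the computed rate is noticeably larger
    # than the naive 10/100000 = 0.0001, which is why the retry loop in
    # takeSample rarely needs more than one pass.
    print(compute_fraction(10, 100000, True))
    print(compute_fraction(10, 100000, False))
```

Both the Scala `RDD.takeSample` and the pyspark version feed this rate into `sample()` and, if the collected result is still smaller than `num`, resample with a fresh seed before shuffling and truncating to exactly `num` elements.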