@@ -43,7 +43,7 @@ import org.apache.spark.partial.PartialResult
4343import org .apache .spark .storage .StorageLevel
4444import org .apache .spark .util .{BoundedPriorityQueue , SerializableHyperLogLog , Utils }
4545import org .apache .spark .util .collection .OpenHashMap
46- import org .apache .spark .util .random .{BernoulliSampler , PoissonSampler }
46+ import org .apache .spark .util .random .{BernoulliSampler , PoissonSampler , SamplingUtils }
4747
4848/**
4949 * A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. Represents an immutable,
@@ -400,12 +400,21 @@ abstract class RDD[T: ClassTag](
400400 throw new IllegalArgumentException (" Negative number of elements requested" )
401401 }
402402
403+ if (! withReplacement && num > initialCount) {
404+ throw new IllegalArgumentException (" Cannot create sample larger than the original when " +
405+ " sampling without replacement" )
406+ }
407+
403408 if (initialCount == 0 ) {
404409 return new Array [T ](0 )
405410 }
406411
407412 if (initialCount > Integer .MAX_VALUE - 1 ) {
408- maxSelected = Integer .MAX_VALUE - 1
413+ maxSelected = Integer .MAX_VALUE - (5.0 * math.sqrt(Integer .MAX_VALUE )).toInt
414+ if (num > maxSelected) {
415+ throw new IllegalArgumentException (" Cannot support a sample size > Integer.MAX_VALUE - " +
416+ " 5.0 * math.sqrt(Integer.MAX_VALUE)" )
417+ }
409418 } else {
410419 maxSelected = initialCount.toInt
411420 }
@@ -415,7 +424,7 @@ abstract class RDD[T: ClassTag](
415424 total = maxSelected
416425 fraction = multiplier * (maxSelected + 1 ) / initialCount
417426 } else {
418- fraction = computeFraction(num, initialCount, withReplacement)
427+ fraction = SamplingUtils . computeFraction(num, initialCount, withReplacement)
419428 total = num
420429 }
421430
@@ -431,35 +440,6 @@ abstract class RDD[T: ClassTag](
431440 Utils .randomizeInPlace(samples, rand).take(total)
432441 }
433442
434- /**
435- * Let p = num / total, where num is the sample size and total is the total number of
436- * datapoints in the RDD. We're trying to compute q > p such that
437- * - when sampling with replacement, we're drawing each datapoint with prob_i ~ Pois(q),
438- * where we want to guarantee Pr[s < num] < 0.0001 for s = sum(prob_i for i from 0 to total),
439- * i.e. the failure rate of not having a sufficiently large sample < 0.0001.
440- * Setting q = p + 5 * sqrt(p/total) is sufficient to guarantee 0.9999 success rate for
441- * num > 12, but we need a slightly larger q (9 empirically determined).
442- * - when sampling without replacement, we're drawing each datapoint with prob_i
443- * ~ Binomial(total, fraction) and our choice of q guarantees 1-delta, or 0.9999 success
444- * rate, where success rate is defined the same as in sampling with replacement.
445- *
446- * @param num sample size
447- * @param total size of RDD
448- * @param withReplacement whether sampling with replacement
449- * @return a sampling rate that guarantees sufficient sample size with 99.99% success rate
450- */
451- private [rdd] def computeFraction (num : Int , total : Long , withReplacement : Boolean ): Double = {
452- val fraction = num.toDouble / total
453- if (withReplacement) {
454- val numStDev = if (num < 12 ) 9 else 5
455- fraction + numStDev * math.sqrt(fraction / total)
456- } else {
457- val delta = 1e-4
458- val gamma = - math.log(delta) / total
459- math.min(1 , fraction + gamma + math.sqrt(gamma * gamma + 2 * gamma * fraction))
460- }
461- }
462-
463443 /**
464444 * Return the union of this RDD and another one. Any identical elements will appear multiple
465445 * times (use `.distinct()` to eliminate them).
0 commit comments