1515# limitations under the License.
1616#
1717
18- from base64 import standard_b64encode as b64enc
19- import copy
2018from collections import defaultdict
2119from collections import namedtuple
2220from itertools import chain , ifilter , imap
@@ -364,8 +362,8 @@ def takeSample(self, withReplacement, num, seed=None):
364362 [4, 2, 1, 8, 2, 7, 0, 4, 1, 4]
365363 """
366364
367- fraction = 0.0
368- total = 0
365+ #TODO remove
366+ logging . basicConfig ( level = logging . INFO )
369367 numStDev = 10.0
370368 initialCount = self .count ()
371369
@@ -378,38 +376,53 @@ def takeSample(self, withReplacement, num, seed=None):
378376 if (not withReplacement ) and num > initialCount :
379377 raise ValueError
380378
381- if initialCount > sys .maxint - 1 :
382- maxSelected = sys .maxint - int (numStDev * sqrt (sys .maxint ))
383- if num > maxSelected :
384- raise ValueError
385-
386- fraction = self ._computeFraction (num , initialCount , withReplacement )
387- total = num
379+ maxSampleSize = sys .maxint - int (numStDev * sqrt (sys .maxint ))
380+ if num > maxSampleSize :
381+ raise ValueError
388382
383+ fraction = self ._computeFractionForSampleSize (num , initialCount , withReplacement )
384+
389385 samples = self .sample (withReplacement , fraction , seed ).collect ()
390386
391387 # If the first sample didn't turn out large enough, keep trying to take samples;
392388 # this shouldn't happen often because we use a big multiplier for their initial size.
393389 # See: scala/spark/RDD.scala
394390 rand = Random (seed )
395- while len (samples ) < total :
391+ while len (samples ) < num :
396392 samples = self .sample (withReplacement , fraction , rand .randint (0 , sys .maxint )).collect ()
397393
398394 sampler = RDDSampler (withReplacement , fraction , rand .randint (0 , sys .maxint ))
399395 sampler .shuffle (samples )
400- return samples [0 :total ]
401-
402- def _computeFraction (self , num , total , withReplacement ):
403- fraction = float (num )/ total
396+ return samples [0 :num ]
397+
@staticmethod
def _computeFractionForSampleSize(sampleSizeLowerBound, total, withReplacement):
    """
    Compute an inflated sampling rate for drawing at least
    sampleSizeLowerBound elements out of total.

    The naive rate is p = sampleSizeLowerBound / total. Sampling at exactly p
    would fall short of the requested size roughly half of the time, so a
    larger rate q > p is returned. The padding is chosen with the aim that a
    single pass yields enough elements about 99.99% of the time.

    - With replacement: q = p + k * sqrt(p / total), i.e. p padded by k
      standard-deviation-like steps. k is 5 in general; for requests smaller
      than 12 elements the larger, empirically determined k = 9 is used.
      This branch applies no upper cap to the result.
    - Without replacement: with delta = 0.00005 and
      gamma = -log(delta) / total, the rate is
      q = p + gamma + sqrt(gamma^2 + 2 * gamma * p), capped at 1.

    :param sampleSizeLowerBound: minimum number of elements wanted
    :param total: number of elements in the dataset being sampled; a value
                  of 0 raises ZeroDivisionError
    :param withReplacement: whether elements may be drawn more than once
    :return: the sampling rate to use (float, or the int 1 when the
             without-replacement rate hits the cap)
    """
    baseRate = float(sampleSizeLowerBound) / total

    if not withReplacement:
        # Failure-probability budget used to derive the additive slack term.
        failureProb = 0.00005
        slack = - log(failureProb) / total
        padded = baseRate + slack + sqrt(slack * slack + 2 * slack * baseRate)
        # A rate above 1 is meaningless when each element can be kept at most once.
        return min(1, padded)

    # Small requests need a wider safety margin than large ones.
    spread = 9 if sampleSizeLowerBound < 12 else 5
    return baseRate + spread * sqrt(baseRate / total)
413426
414427 def union (self , other ):
415428 """
0 commit comments