@@ -362,44 +362,50 @@ def takeSample(self, withReplacement, num, seed=None):
362362 Return a fixed-size sampled subset of this RDD (currently requires
363363 numpy).
364364
365- >>> sc.parallelize(range(0, 10)).takeSample(True, 10, 1) #doctest: +SKIP
366- [4, 2, 1, 8, 2, 7, 0, 4, 1, 4]
365+ >>> rdd = sc.parallelize(range(0, 10))
366+ >>> len(rdd.takeSample(True, 20, 1))
367+ 20
368+ >>> len(rdd.takeSample(False, 5, 2))
369+ 5
370+ >>> len(rdd.takeSample(False, 15, 3))
371+ 10
367372 """
368373
369- numStDev = 10.0
370- initialCount = self .count ()
371-
372374 if num < 0 :
373- raise ValueError
375+ raise ValueError ("Sample size cannot be negative." )
376+ elif num == 0 :
377+ return []
374378
375- if initialCount == 0 or num == 0 :
376- return list ()
379+ initialCount = self .count ()
380+ if initialCount == 0 :
381+ return []
377382
378383 rand = Random (seed )
379- if (not withReplacement ) and num > initialCount :
384+
385+ if (not withReplacement ) and num >= initialCount :
380386 # shuffle current RDD and return
381387 samples = self .collect ()
382- fraction = float (num ) / initialCount
383- num = initialCount
384- else :
385- maxSampleSize = sys .maxint - int (numStDev * sqrt (sys .maxint ))
386- if num > maxSampleSize :
387- raise ValueError
388-
389- fraction = self ._computeFractionForSampleSize (num , initialCount , withReplacement )
388+ rand .shuffle (samples )
389+ return samples
390390
391+ numStDev = 10.0
392+ maxSampleSize = sys .maxint - int (numStDev * sqrt (sys .maxint ))
393+ if num > maxSampleSize :
394+ raise ValueError ("Sample size cannot be greater than %d." % maxSampleSize )
395+
396+ fraction = RDD ._computeFractionForSampleSize (num , initialCount , withReplacement )
397+ samples = self .sample (withReplacement , fraction , seed ).collect ()
398+
399+ # If the first sample didn't turn out large enough, keep trying to take samples;
400+ # this shouldn't happen often because we use a big multiplier for their initial size.
401+ # See: scala/spark/RDD.scala
402+ while len (samples ) < num :
403+ # TODO: add log warning for when more than one iteration was run
404+ seed = rand .randint (0 , sys .maxint )
391405 samples = self .sample (withReplacement , fraction , seed ).collect ()
392406
393- # If the first sample didn't turn out large enough, keep trying to take samples;
394- # this shouldn't happen often because we use a big multiplier for their initial size.
395- # See: scala/spark/RDD.scala
396- while len (samples ) < num :
397- #TODO add log warning for when more than one iteration was run
398- seed = rand .randint (0 , sys .maxint )
399- samples = self .sample (withReplacement , fraction , seed ).collect ()
407+ rand .shuffle (samples )
400408
401- sampler = RDDSampler (withReplacement , fraction , rand .randint (0 , sys .maxint ))
402- sampler .shuffle (samples )
403409 return samples [0 :num ]
404410
405411 @staticmethod
0 commit comments