
Commit d7dbce8

brkyvz authored and rxin committed
[SPARK-7156][SQL] support RandomSplit in DataFrames
This is built on top of kaka1992's PR #5711, using logical plans.

Author: Burak Yavuz <[email protected]>

Closes #5761 from brkyvz/random-sample and squashes the following commits:

a1fb0aa [Burak Yavuz] remove unrelated file
69669c3 [Burak Yavuz] fix broken test
1ddb3da [Burak Yavuz] copy base
6000328 [Burak Yavuz] added python api and fixed test
3c11d1b [Burak Yavuz] fixed broken test
f400ade [Burak Yavuz] fix build errors
2384266 [Burak Yavuz] addressed comments v0.1
e98ebac [Burak Yavuz] [SPARK-7156][SQL] support RandomSplit in DataFrames
1 parent c9d530e commit d7dbce8
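For orientation, a minimal usage sketch of the API this commit adds. The table name and the existing `sqlContext` are assumptions for illustration; the `randomSplit` signatures come from the DataFrame.scala diff below.

// Hedged usage sketch: assumes an existing SQLContext named `sqlContext`
// and some DataFrame; only the randomSplit call is new in this commit.
val df = sqlContext.table("people")

// Split into ~60% / ~40%; weights are normalized, so Array(3, 2) works too.
val Array(train, test) = df.randomSplit(Array(0.6, 0.4), seed = 42L)

// Sampling is without replacement, so for a fixed seed the two splits are
// disjoint and together contain every row of `df` exactly once.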

File tree: 10 files changed, +130 -22 lines changed


core/src/main/scala/org/apache/spark/rdd/RDD.scala

Lines changed: 17 additions & 2 deletions

@@ -407,11 +407,26 @@ abstract class RDD[T: ClassTag](
     val sum = weights.sum
     val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _)
     normalizedCumWeights.sliding(2).map { x =>
-      new PartitionwiseSampledRDD[T, T](
-        this, new BernoulliCellSampler[T](x(0), x(1)), true, seed)
+      randomSampleWithRange(x(0), x(1), seed)
     }.toArray
   }
 
+  /**
+   * Internal method exposed for Random Splits in DataFrames. Samples an RDD given a probability
+   * range.
+   * @param lb lower bound to use for the Bernoulli sampler
+   * @param ub upper bound to use for the Bernoulli sampler
+   * @param seed the seed for the Random number generator
+   * @return A random sub-sample of the RDD without replacement.
+   */
+  private[spark] def randomSampleWithRange(lb: Double, ub: Double, seed: Long): RDD[T] = {
+    this.mapPartitionsWithIndex { case (index, partition) =>
+      val sampler = new BernoulliCellSampler[T](lb, ub)
+      sampler.setSeed(seed + index)
+      sampler.sample(partition)
+    }
+  }
+
   /**
    * Return a fixed-size sampled subset of this RDD in an array
    *
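To make the `sliding(2)` trick above concrete, a small worked example of the weight arithmetic (plain Scala, runnable without Spark):

// weights (1, 2, 3) sum to 6, so the normalized cumulative weights are
// [0, 1/6, 1/2, 1], and sliding(2) yields one [lb, ub) pair per split.
val weights = Array(1.0, 2.0, 3.0)
val normalizedCumWeights = weights.map(_ / weights.sum).scanLeft(0.0d)(_ + _)
// Array(0.0, 0.1666..., 0.5, 1.0)
val ranges = normalizedCumWeights.sliding(2).map(x => (x(0), x(1))).toList
// List((0.0, 0.1666...), (0.1666..., 0.5), (0.5, 1.0))
ranges.foreach(println)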

core/src/test/java/org/apache/spark/JavaAPISuite.java

Lines changed: 4 additions & 4 deletions

@@ -157,11 +157,11 @@ public void sample() {
   public void randomSplit() {
     List<Integer> ints = Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
     JavaRDD<Integer> rdd = sc.parallelize(ints);
-    JavaRDD<Integer>[] splits = rdd.randomSplit(new double[] { 0.4, 0.6, 1.0 }, 11);
+    JavaRDD<Integer>[] splits = rdd.randomSplit(new double[] { 0.4, 0.6, 1.0 }, 31);
     Assert.assertEquals(3, splits.length);
-    Assert.assertEquals(2, splits[0].count());
-    Assert.assertEquals(3, splits[1].count());
-    Assert.assertEquals(5, splits[2].count());
+    Assert.assertEquals(1, splits[0].count());
+    Assert.assertEquals(2, splits[1].count());
+    Assert.assertEquals(7, splits[2].count());
   }
 
   @Test
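A note on this test update: the expected counts appear to change because the non-replacement sampling path now seeds each partition as `seed + index` (see the RDD.scala diff above), producing a different random stream than the old `PartitionwiseSampledRDD` path; the seed was presumably bumped from 11 to 31 to keep all three splits non-empty.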

python/pyspark/sql/dataframe.py

Lines changed: 17 additions & 1 deletion

@@ -426,14 +426,30 @@ def distinct(self):
     def sample(self, withReplacement, fraction, seed=None):
         """Returns a sampled subset of this :class:`DataFrame`.
 
-        >>> df.sample(False, 0.5, 97).count()
+        >>> df.sample(False, 0.5, 42).count()
         1
         """
         assert fraction >= 0.0, "Negative fraction value: %s" % fraction
         seed = seed if seed is not None else random.randint(0, sys.maxsize)
         rdd = self._jdf.sample(withReplacement, fraction, long(seed))
         return DataFrame(rdd, self.sql_ctx)
 
+    def randomSplit(self, weights, seed=None):
+        """Randomly splits this :class:`DataFrame` with the provided weights.
+
+        >>> splits = df4.randomSplit([1.0, 2.0], 24)
+        >>> splits[0].count()
+        1
+
+        >>> splits[1].count()
+        3
+        """
+        for w in weights:
+            assert w >= 0.0, "Negative weight value: %s" % w
+        seed = seed if seed is not None else random.randint(0, sys.maxsize)
+        rdd_array = self._jdf.randomSplit(_to_seq(self.sql_ctx._sc, weights), long(seed))
+        return [DataFrame(rdd, self.sql_ctx) for rdd in rdd_array]
+
     @property
     def dtypes(self):
         """Returns all column names and their data types as a list.

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala

Lines changed: 0 additions & 6 deletions

@@ -278,12 +278,6 @@ package object dsl {
     def sfilter[T1](arg1: Symbol)(udf: (T1) => Boolean): LogicalPlan =
       Filter(ScalaUdf(udf, BooleanType, Seq(UnresolvedAttribute(arg1.name))), logicalPlan)
 
-    def sample(
-        fraction: Double,
-        withReplacement: Boolean = true,
-        seed: Int = (math.random * 1000).toInt): LogicalPlan =
-      Sample(fraction, withReplacement, seed, logicalPlan)
-
     // TODO specify the output column names
     def generate(
         generator: Generator,

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala

Lines changed: 16 additions & 2 deletions

@@ -300,8 +300,22 @@ case class Subquery(alias: String, child: LogicalPlan) extends UnaryNode {
   override def output: Seq[Attribute] = child.output.map(_.withQualifiers(alias :: Nil))
 }
 
-case class Sample(fraction: Double, withReplacement: Boolean, seed: Long, child: LogicalPlan)
-  extends UnaryNode {
+/**
+ * Sample the dataset.
+ *
+ * @param lowerBound Lower-bound of the sampling probability (usually 0.0)
+ * @param upperBound Upper-bound of the sampling probability. The expected fraction sampled
+ *                   will be ub - lb.
+ * @param withReplacement Whether to sample with replacement.
+ * @param seed the random seed
+ * @param child the LogicalPlan
+ */
+case class Sample(
+    lowerBound: Double,
+    upperBound: Double,
+    withReplacement: Boolean,
+    seed: Long,
+    child: LogicalPlan) extends UnaryNode {
 
   override def output: Seq[Attribute] = child.output
 }

sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala

Lines changed: 37 additions & 1 deletion

@@ -706,7 +706,7 @@ class DataFrame private[sql](
    * @group dfops
    */
   def sample(withReplacement: Boolean, fraction: Double, seed: Long): DataFrame = {
-    Sample(fraction, withReplacement, seed, logicalPlan)
+    Sample(0.0, fraction, withReplacement, seed, logicalPlan)
   }
 
   /**
@@ -720,6 +720,42 @@
     sample(withReplacement, fraction, Utils.random.nextLong)
   }
 
+  /**
+   * Randomly splits this [[DataFrame]] with the provided weights.
+   *
+   * @param weights weights for splits, will be normalized if they don't sum to 1.
+   * @param seed Seed for sampling.
+   * @group dfops
+   */
+  def randomSplit(weights: Array[Double], seed: Long): Array[DataFrame] = {
+    val sum = weights.sum
+    val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _)
+    normalizedCumWeights.sliding(2).map { x =>
+      new DataFrame(sqlContext, Sample(x(0), x(1), false, seed, logicalPlan))
+    }.toArray
+  }
+
+  /**
+   * Randomly splits this [[DataFrame]] with the provided weights.
+   *
+   * @param weights weights for splits, will be normalized if they don't sum to 1.
+   * @group dfops
+   */
+  def randomSplit(weights: Array[Double]): Array[DataFrame] = {
+    randomSplit(weights, Utils.random.nextLong)
+  }
+
+  /**
+   * Randomly splits this [[DataFrame]] with the provided weights. Provided for the Python Api.
+   *
+   * @param weights weights for splits, will be normalized if they don't sum to 1.
+   * @param seed Seed for sampling.
+   * @group dfops
+   */
+  def randomSplit(weights: List[Double], seed: Long): Array[DataFrame] = {
+    randomSplit(weights.toArray, seed)
+  }
+
   /**
    * (Scala-specific) Returns a new [[DataFrame]] where each row has been expanded to zero or more
    * rows by the provided function. This is similar to a `LATERAL VIEW` in HiveQL. The columns of
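Why do adjacent [lb, ub) ranges over the same seed and the same logical plan make the splits disjoint and exhaustive? A plain-Scala simulation of the idea; the per-row uniform draw below is a hypothetical stand-in for what the Bernoulli cell sampler does per element:

import scala.util.Random

// Each "row" draws one uniform number determined by the shared seed; a split
// keeps a row only if the draw lands in its own [lb, ub) range. Adjacent
// ranges tile [0, 1), so each row lands in exactly one split.
val seed = 42L
val rows = 1 to 20
def draw(row: Int): Double = new Random(seed + row).nextDouble()

val ranges = List((0.0, 1.0 / 6), (1.0 / 6, 0.5), (0.5, 1.0))
val splits = ranges.map { case (lb, ub) =>
  rows.filter { r => val x = draw(r); x >= lb && x < ub }
}
assert(splits.flatten.sorted == rows.toList) // disjoint and exhaustive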

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 2 additions & 2 deletions

@@ -303,8 +303,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         execution.Expand(projections, output, planLater(child)) :: Nil
       case logical.Aggregate(group, agg, child) =>
         execution.Aggregate(partial = false, group, agg, planLater(child)) :: Nil
-      case logical.Sample(fraction, withReplacement, seed, child) =>
-        execution.Sample(fraction, withReplacement, seed, planLater(child)) :: Nil
+      case logical.Sample(lb, ub, withReplacement, seed, child) =>
+        execution.Sample(lb, ub, withReplacement, seed, planLater(child)) :: Nil
       case logical.LocalRelation(output, data) =>
         LocalTableScan(output, data) :: Nil
       case logical.Limit(IntegerLiteral(limit), child) =>

sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala

Lines changed: 18 additions & 2 deletions

@@ -63,16 +63,32 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode {
 
 /**
  * :: DeveloperApi ::
+ * Sample the dataset.
+ * @param lowerBound Lower-bound of the sampling probability (usually 0.0)
+ * @param upperBound Upper-bound of the sampling probability. The expected fraction sampled
+ *                   will be ub - lb.
+ * @param withReplacement Whether to sample with replacement.
+ * @param seed the random seed
+ * @param child the QueryPlan
  */
 @DeveloperApi
-case class Sample(fraction: Double, withReplacement: Boolean, seed: Long, child: SparkPlan)
+case class Sample(
+    lowerBound: Double,
+    upperBound: Double,
+    withReplacement: Boolean,
+    seed: Long,
+    child: SparkPlan)
   extends UnaryNode
 {
   override def output: Seq[Attribute] = child.output
 
   // TODO: How to pick seed?
   override def execute(): RDD[Row] = {
-    child.execute().map(_.copy()).sample(withReplacement, fraction, seed)
+    if (withReplacement) {
+      child.execute().map(_.copy()).sample(withReplacement, upperBound - lowerBound, seed)
+    } else {
+      child.execute().map(_.copy()).randomSampleWithRange(lowerBound, upperBound, seed)
+    }
   }
 }
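Note the asymmetry in the new `execute()`: only the without-replacement branch can honor a probability range, because each element makes a single uniform draw that either falls in [lowerBound, upperBound) or not. With replacement an element can be emitted multiple times, so there is no single per-element draw to partition; that branch falls back to plain `sample()` with fraction `upperBound - lowerBound`. This is sufficient for `randomSplit`, which always constructs its `Sample` plans with `withReplacement = false`.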

sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala

Lines changed: 17 additions & 0 deletions

@@ -510,6 +510,23 @@ class DataFrameSuite extends QueryTest {
     assert(df.schema.map(_.name).toSeq === Seq("key", "valueRenamed", "newCol"))
   }
 
+  test("randomSplit") {
+    val n = 600
+    val data = TestSQLContext.sparkContext.parallelize(1 to n, 2).toDF("id")
+    for (seed <- 1 to 5) {
+      val splits = data.randomSplit(Array[Double](1, 2, 3), seed)
+      assert(splits.length == 3, "wrong number of splits")
+
+      assert(splits.reduce((a, b) => a.unionAll(b)).sort("id").collect().toList ==
+        data.collect().toList, "incomplete or wrong split")
+
+      val s = splits.map(_.count())
+      assert(math.abs(s(0) - 100) < 50) // std = 9.13
+      assert(math.abs(s(1) - 200) < 50) // std = 11.55
+      assert(math.abs(s(2) - 300) < 50) // std = 12.25
+    }
+  }
+
   test("describe") {
     val describeTestData = Seq(
       ("Bob", 16, 176),

sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala

Lines changed: 2 additions & 2 deletions

@@ -887,13 +887,13 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
           fraction.toDouble >= (0.0 - RandomSampler.roundingEpsilon)
             && fraction.toDouble <= (100.0 + RandomSampler.roundingEpsilon),
           s"Sampling fraction ($fraction) must be on interval [0, 100]")
-        Sample(fraction.toDouble / 100, withReplacement = false, (math.random * 1000).toInt,
+        Sample(0.0, fraction.toDouble / 100, withReplacement = false, (math.random * 1000).toInt,
           relation)
       case Token("TOK_TABLEBUCKETSAMPLE",
         Token(numerator, Nil) ::
         Token(denominator, Nil) :: Nil) =>
         val fraction = numerator.toDouble / denominator.toDouble
-        Sample(fraction, withReplacement = false, (math.random * 1000).toInt, relation)
+        Sample(0.0, fraction, withReplacement = false, (math.random * 1000).toInt, relation)
       case a: ASTNode =>
         throw new NotImplementedError(
           s"""No parse rules for sampling clause: ${a.getType}, text: ${a.getText} :
