Skip to content

Commit fc9a156

Browse files
committed
Added a flag indicating whether this Sample node was created by the parser.
1 parent 12be2c3 commit fc9a156

File tree

7 files changed

+16
-14
lines changed

7 files changed

+16
-14
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -499,12 +499,14 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
499499
s"Sampling fraction ($fraction) must be on interval [0, 100]")
500500
Sample(0.0, fraction.toDouble / 100, withReplacement = false,
501501
(math.random * 1000).toInt,
502-
relation)
502+
relation)(
503+
isTableSample = true)
503504
case Token("TOK_TABLEBUCKETSAMPLE",
504505
Token(numerator, Nil) ::
505506
Token(denominator, Nil) :: Nil) =>
506507
val fraction = numerator.toDouble / denominator.toDouble
507-
Sample(0.0, fraction, withReplacement = false, (math.random * 1000).toInt, relation)
508+
Sample(0.0, fraction, withReplacement = false, (math.random * 1000).toInt, relation)(
509+
isTableSample = true)
508510
case a =>
509511
noParseRule("Sampling", a)
510512
}.getOrElse(relation)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ object SamplePushDown extends Rule[LogicalPlan] {
109109
// Push down projection into sample
110110
case Project(projectList, s @ Sample(lb, up, replace, seed, child)) =>
111111
Sample(lb, up, replace, seed,
112-
Project(projectList, child))
112+
Project(projectList, child))()
113113
}
114114
}
115115

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -561,15 +561,18 @@ case class Subquery(alias: String, child: LogicalPlan) extends UnaryNode {
561561
* @param withReplacement Whether to sample with replacement.
562562
* @param seed the random seed
563563
* @param child the LogicalPlan
564+
* @param isTableSample Whether this node was created from a TABLESAMPLE clause in the parser.
564565
*/
565566
case class Sample(
566567
lowerBound: Double,
567568
upperBound: Double,
568569
withReplacement: Boolean,
569570
seed: Long,
570-
child: LogicalPlan) extends UnaryNode {
571+
child: LogicalPlan)(
572+
val isTableSample: java.lang.Boolean = false) extends UnaryNode {
571573

572574
override def output: Seq[Attribute] = child.output
575+
override protected def otherCopyArgs: Seq[AnyRef] = isTableSample :: Nil
573576
}
574577

575578
/**

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -640,14 +640,14 @@ class FilterPushdownSuite extends PlanTest {
640640
test("push project and filter down into sample") {
641641
val x = testRelation.subquery('x)
642642
val originalQuery =
643-
Sample(0.0, 0.6, false, 11L, x).select('a)
643+
Sample(0.0, 0.6, false, 11L, x)().select('a)
644644

645645
val originalQueryAnalyzed = EliminateSubQueries(analysis.SimpleAnalyzer.execute(originalQuery))
646646

647647
val optimized = Optimize.execute(originalQueryAnalyzed)
648648

649649
val correctAnswer =
650-
Sample(0.0, 0.6, false, 11L, x.select('a))
650+
Sample(0.0, 0.6, false, 11L, x.select('a))()
651651

652652
comparePlans(optimized, correctAnswer.analyze)
653653
}

sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1039,7 +1039,7 @@ class DataFrame private[sql](
10391039
* @since 1.3.0
10401040
*/
10411041
def sample(withReplacement: Boolean, fraction: Double, seed: Long): DataFrame = withPlan {
1042-
Sample(0.0, fraction, withReplacement, seed, logicalPlan)
1042+
Sample(0.0, fraction, withReplacement, seed, logicalPlan)()
10431043
}
10441044

10451045
/**
@@ -1071,7 +1071,7 @@ class DataFrame private[sql](
10711071
val sum = weights.sum
10721072
val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _)
10731073
normalizedCumWeights.sliding(2).map { x =>
1074-
new DataFrame(sqlContext, Sample(x(0), x(1), withReplacement = false, seed, sorted))
1074+
new DataFrame(sqlContext, Sample(x(0), x(1), withReplacement = false, seed, sorted)())
10751075
}.toArray
10761076
}
10771077

sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -564,7 +564,7 @@ class Dataset[T] private[sql](
564564
* @since 1.6.0
565565
*/
566566
def sample(withReplacement: Boolean, fraction: Double, seed: Long) : Dataset[T] =
567-
withPlan(Sample(0.0, fraction, withReplacement, seed, _))
567+
withPlan(Sample(0.0, fraction, withReplacement, seed, _)())
568568

569569
/**
570570
* Returns a new [[Dataset]] by sampling a fraction of records, using a random seed.

sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -81,11 +81,8 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi
8181
case p: Limit =>
8282
s"${toSQL(p.child)} LIMIT ${p.limitExpr.sql}"
8383

84-
// TABLESAMPLE is part of tableSource clause in the parser,
85-
// and thus we must handle it with subquery.
86-
case p @ Sample(lb, ub, withReplacement, _, _)
87-
if !withReplacement && lb <= (ub + RandomSampler.roundingEpsilon) =>
88-
val fraction = math.min(100, math.max(0, (ub - lb) * 100))
84+
case p: Sample if p.isTableSample =>
85+
val fraction = math.min(100, math.max(0, (p.upperBound - p.lowerBound) * 100))
8986
p.child match {
9087
case m: MetastoreRelation =>
9188
val aliasName = m.alias.getOrElse("")

0 commit comments

Comments
 (0)