diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index 618c2799d9c9..29a4142eb122 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -1303,12 +1303,25 @@ case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends OrderPr
     copy(child = newChild)
 }
 
+object OffsetAndLimit {
+  def unapply(p: GlobalLimit): Option[(Int, Int, LogicalPlan)] = {
+    p match {
+      // Optimizer pushes local limit through offset, so we need to match the plan this way.
+      case GlobalLimit(IntegerLiteral(globalLimit),
+             Offset(IntegerLiteral(offset),
+               LocalLimit(IntegerLiteral(localLimit), child)))
+          if globalLimit + offset == localLimit =>
+        Some((offset, globalLimit, child))
+      case _ => None
+    }
+  }
+}
+
 object LimitAndOffset {
-  def unapply(p: GlobalLimit): Option[(Expression, Expression, LogicalPlan)] = {
+  def unapply(p: Offset): Option[(Int, Int, LogicalPlan)] = {
     p match {
-      case GlobalLimit(le1, Offset(le2, LocalLimit(le3, child))) if le1.eval().asInstanceOf[Int]
-          + le2.eval().asInstanceOf[Int] == le3.eval().asInstanceOf[Int] =>
-        Some((le1, le2, child))
+      case Offset(IntegerLiteral(offset), Limit(IntegerLiteral(limit), child)) =>
+        Some((limit, offset, child))
       case _ => None
     }
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index f939ea1882b1..34891b3d1ab7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -81,55 +81,56 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
    */
   object SpecialLimits extends Strategy {
     override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
-      case ReturnAnswer(rootPlan) => rootPlan match {
-        case Limit(IntegerLiteral(limit), Sort(order, true, child))
-            if limit < conf.topKSortFallbackThreshold =>
-          TakeOrderedAndProjectExec(limit, order, child.output, planLater(child)) :: Nil
-        case Limit(IntegerLiteral(limit), Project(projectList, Sort(order, true, child)))
-            if limit < conf.topKSortFallbackThreshold =>
-          TakeOrderedAndProjectExec(limit, order, projectList, planLater(child)) :: Nil
+      // Call `planTakeOrdered` first which matches a larger plan.
+      case ReturnAnswer(rootPlan) => planTakeOrdered(rootPlan).getOrElse(rootPlan match {
+        // We should match the combination of limit and offset first, to get the optimal physical
+        // plan, instead of planning limit and offset separately.
+        case LimitAndOffset(limit, offset, child) =>
+          CollectLimitExec(limit = limit, child = planLater(child), offset = offset)
+        case OffsetAndLimit(offset, limit, child) =>
+          // 'Offset a' then 'Limit b' is the same as 'Limit a + b' then 'Offset a'.
+          CollectLimitExec(limit = offset + limit, child = planLater(child), offset = offset)
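
The `OffsetAndLimit` extractor and the branch above both lean on the identity stated in the comment: applying OFFSET a and then LIMIT b returns the same rows as keeping the first a + b rows and then dropping the first a. A minimal standalone Scala sketch of that identity (illustrative only, not Spark code):

  object OffsetLimitIdentity extends App {
    val rows = 1 to 10
    val a = 3 // OFFSET
    val b = 4 // LIMIT
    val offsetThenLimit = rows.drop(a).take(b)     // OFFSET a, then LIMIT b
    val limitThenOffset = rows.take(a + b).drop(a) // LIMIT a + b, then OFFSET a
    assert(offsetThenLimit == limitThenOffset)     // both yield 4, 5, 6, 7
  }

This is why the plan can fold the pair into a single CollectLimitExec with limit = offset + limit.
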
         case Limit(IntegerLiteral(limit), child) =>
-          CollectLimitExec(limit, planLater(child)) :: Nil
-        case LimitAndOffset(IntegerLiteral(limit), IntegerLiteral(offset),
-            Sort(order, true, child)) if limit + offset < conf.topKSortFallbackThreshold =>
-          TakeOrderedAndProjectExec(
-            limit, order, child.output, planLater(child), offset) :: Nil
-        case LimitAndOffset(IntegerLiteral(limit), IntegerLiteral(offset),
-            Project(projectList, Sort(order, true, child)))
-            if limit + offset < conf.topKSortFallbackThreshold =>
-          TakeOrderedAndProjectExec(
-            limit, order, projectList, planLater(child), offset) :: Nil
-        case LimitAndOffset(IntegerLiteral(limit), IntegerLiteral(offset), child) =>
-          CollectLimitExec(limit, planLater(child), offset) :: Nil
+          CollectLimitExec(limit = limit, child = planLater(child))
         case logical.Offset(IntegerLiteral(offset), child) =>
-          CollectLimitExec(child = planLater(child), offset = offset) :: Nil
+          CollectLimitExec(child = planLater(child), offset = offset)
         case Tail(IntegerLiteral(limit), child) =>
-          CollectTailExec(limit, planLater(child)) :: Nil
-        case other => planLater(other) :: Nil
-      }
+          CollectTailExec(limit, planLater(child))
+        case other => planLater(other)
+      }) :: Nil
+
+      case other => planTakeOrdered(other).toSeq
+    }
+
+    private def planTakeOrdered(plan: LogicalPlan): Option[SparkPlan] = plan match {
+      // We should match the combination of limit and offset first, to get the optimal physical
+      // plan, instead of planning limit and offset separately.
+      case LimitAndOffset(limit, offset, Sort(order, true, child))
+          if limit < conf.topKSortFallbackThreshold =>
+        Some(TakeOrderedAndProjectExec(
+          limit, order, child.output, planLater(child), offset))
+      case LimitAndOffset(limit, offset, Project(projectList, Sort(order, true, child)))
+          if limit < conf.topKSortFallbackThreshold =>
+        Some(TakeOrderedAndProjectExec(
+          limit, order, projectList, planLater(child), offset))
+      // 'Offset a' then 'Limit b' is the same as 'Limit a + b' then 'Offset a'.
+      case OffsetAndLimit(offset, limit, Sort(order, true, child))
+          if offset + limit < conf.topKSortFallbackThreshold =>
+        Some(TakeOrderedAndProjectExec(
+          offset + limit, order, child.output, planLater(child), offset))
+      case OffsetAndLimit(offset, limit, Project(projectList, Sort(order, true, child)))
+          if offset + limit < conf.topKSortFallbackThreshold =>
+        Some(TakeOrderedAndProjectExec(
+          offset + limit, order, projectList, planLater(child), offset))
       case Limit(IntegerLiteral(limit), Sort(order, true, child))
           if limit < conf.topKSortFallbackThreshold =>
-        TakeOrderedAndProjectExec(limit, order, child.output, planLater(child)) :: Nil
+        Some(TakeOrderedAndProjectExec(
+          limit, order, child.output, planLater(child)))
       case Limit(IntegerLiteral(limit), Project(projectList, Sort(order, true, child)))
           if limit < conf.topKSortFallbackThreshold =>
-        TakeOrderedAndProjectExec(limit, order, projectList, planLater(child)) :: Nil
-      // This is a global LIMIT and OFFSET over a logical sorting operator,
-      // where the sum of specified limit and specified offset is less than a heuristic threshold.
-      // In this case we generate a physical top-K sorting operator, passing down
-      // the limit and offset values to be evaluated inline during the physical
-      // sorting operation for greater efficiency.
-      case LimitAndOffset(IntegerLiteral(limit), IntegerLiteral(offset),
-          Sort(order, true, child)) if limit + offset < conf.topKSortFallbackThreshold =>
-        TakeOrderedAndProjectExec(
-          limit, order, child.output, planLater(child), offset) :: Nil
-      case LimitAndOffset(IntegerLiteral(limit), IntegerLiteral(offset),
-          Project(projectList, Sort(order, true, child)))
-          if limit + offset < conf.topKSortFallbackThreshold =>
-        TakeOrderedAndProjectExec(limit, order, projectList, planLater(child), offset) :: Nil
-      case LimitAndOffset(IntegerLiteral(limit), IntegerLiteral(offset), child) =>
-        GlobalLimitAndOffsetExec(limit, offset, planLater(child)) :: Nil
-      case _ =>
-        Nil
+        Some(TakeOrderedAndProjectExec(
+          limit, order, projectList, planLater(child)))
+      case _ => None
     }
   }
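
In the sorted cases of `planTakeOrdered` above, the `limit` handed to TakeOrderedAndProjectExec already counts the rows the offset will skip, so the operator only ever keeps the top `limit` rows and drops the first `offset` of them at the end. A rough collection-based sketch of that top-K-then-drop shape (hypothetical helper, not the Spark operator):

  object TopKThenDrop extends App {
    def takeOrderedWithOffset[T: Ordering](rows: Seq[T], limit: Int, offset: Int): Seq[T] =
      rows.sorted.take(limit).drop(offset) // `limit` already includes the skipped rows

    val rows = Seq(9, 1, 7, 3, 5, 2, 8)
    // ORDER BY v LIMIT 2 OFFSET 2: the planner passes down limit = 2 + 2 = 4.
    println(takeOrderedWithOffset(rows, limit = 4, offset = 2)) // List(3, 5)
  }
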
@@ -814,12 +815,19 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       case logical.LocalRelation(output, data, _) =>
         LocalTableScanExec(output, data) :: Nil
       case CommandResult(output, _, plan, data) => CommandResultExec(output, plan, data) :: Nil
+      // We should match the combination of limit and offset first, to get the optimal physical
+      // plan, instead of planning limit and offset separately.
+      case LimitAndOffset(limit, offset, child) =>
+        GlobalLimitExec(limit, planLater(child), offset) :: Nil
+      case OffsetAndLimit(offset, limit, child) =>
+        // 'Offset a' then 'Limit b' is the same as 'Limit a + b' then 'Offset a'.
+        GlobalLimitExec(limit = offset + limit, child = planLater(child), offset = offset) :: Nil
       case logical.LocalLimit(IntegerLiteral(limit), child) =>
         execution.LocalLimitExec(limit, planLater(child)) :: Nil
       case logical.GlobalLimit(IntegerLiteral(limit), child) =>
         execution.GlobalLimitExec(limit, planLater(child)) :: Nil
       case logical.Offset(IntegerLiteral(offset), child) =>
-        GlobalLimitAndOffsetExec(offset = offset, child = planLater(child)) :: Nil
+        GlobalLimitExec(child = planLater(child), offset = offset) :: Nil
       case union: logical.Union =>
         execution.UnionExec(union.children.map(planLater)) :: Nil
       case g @ logical.Generate(generator, _, outer, _, _, child) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
index caffe3ff8555..dbba19002c56 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
@@ -37,11 +37,11 @@ trait LimitExec extends UnaryExecNode {
 }
 
 /**
- * Take the first `limit` + `offset` elements and collect them to a single partition and then to
- * drop the first `offset` elements.
+ * Take the first `limit` elements, collect them to a single partition, and then drop the
+ * first `offset` elements.
  *
- * This operator will be used when a logical `Limit` operation is the final operator in an
- * logical plan, which happens when the user is collecting results back to the driver.
+ * This operator will be used when a logical `Limit` and/or `Offset` operation is the final operator
+ * in a logical plan, which happens when the user is collecting results back to the driver.
  */
 case class CollectLimitExec(limit: Int = -1, child: SparkPlan, offset: Int = 0) extends LimitExec {
   assert(limit >= 0 || (limit == -1 && offset > 0))
@@ -56,7 +56,7 @@ case class CollectLimitExec(limit: Int = -1, child: SparkPlan, offset: Int = 0)
     // Then [1, 2, 3] will be taken and output [3].
     if (limit >= 0) {
       if (offset > 0) {
-        child.executeTake(limit + offset).drop(offset)
+        child.executeTake(limit).drop(offset)
       } else {
         child.executeTake(limit)
       }
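
The context comment above is the tail of an example in executeCollect: with limit = 3 and offset = 2 over input [1, 2, 3, 4, 5], the first `limit` rows are taken and the first `offset` of those are dropped. The same semantics on a plain Scala collection (a sketch, not the executeTake code path):

  object CollectLimitSemantics extends App {
    val rows = Seq(1, 2, 3, 4, 5)
    println(rows.take(3).drop(2)) // List(3): [1, 2, 3] is taken, output is [3]
  }
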
@@ -79,11 +79,7 @@ case class CollectLimitExec(limit: Int = -1, child: SparkPlan, offset: Int = 0)
         childRDD
       } else {
         val locallyLimited = if (limit >= 0) {
-          if (offset > 0) {
-            childRDD.mapPartitionsInternal(_.take(limit + offset))
-          } else {
-            childRDD.mapPartitionsInternal(_.take(limit))
-          }
+          childRDD.mapPartitionsInternal(_.take(limit))
         } else {
           childRDD
         }
@@ -98,7 +94,7 @@ case class CollectLimitExec(limit: Int = -1, child: SparkPlan, offset: Int = 0)
       }
       if (limit >= 0) {
         if (offset > 0) {
-          singlePartitionRDD.mapPartitionsInternal(_.slice(offset, offset + limit))
+          singlePartitionRDD.mapPartitionsInternal(_.slice(offset, limit))
         } else {
           singlePartitionRDD.mapPartitionsInternal(_.take(limit))
         }
@@ -164,8 +160,8 @@ trait BaseLimitExec extends LimitExec with CodegenSupport {
 
   override def outputOrdering: Seq[SortOrder] = child.outputOrdering
 
-  protected override def doExecute(): RDD[InternalRow] = child.execute().mapPartitions { iter =>
-    iter.take(limit)
+  protected override def doExecute(): RDD[InternalRow] = child.execute().mapPartitionsInternal {
+    iter => iter.take(limit)
   }
 
   override def inputRDDs(): Seq[RDD[InternalRow]] = {
@@ -215,61 +211,52 @@ case class LocalLimitExec(limit: Int, child: SparkPlan) extends BaseLimitExec {
 }
 
 /**
- * Take the first `limit` elements of the child's single output partition.
+ * Take the first `limit` elements and then drop the first `offset` elements in the child's single
+ * output partition.
  */
-case class GlobalLimitExec(limit: Int, child: SparkPlan) extends BaseLimitExec {
-
-  override def requiredChildDistribution: List[Distribution] = AllTuples :: Nil
-
-  override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
-    copy(child = newChild)
-}
-
-/**
- * Skip the first `offset` elements then take the first `limit` of the following elements in
- * the child's single output partition.
- */
-case class GlobalLimitAndOffsetExec(
-    limit: Int = -1,
-    offset: Int,
-    child: SparkPlan) extends BaseLimitExec {
-  assert(offset > 0)
+case class GlobalLimitExec(limit: Int = -1, child: SparkPlan, offset: Int = 0)
+  extends BaseLimitExec {
+  assert(limit >= 0 || (limit == -1 && offset > 0))
 
   override def requiredChildDistribution: List[Distribution] = AllTuples :: Nil
 
-  override def doExecute(): RDD[InternalRow] = if (limit >= 0) {
-    child.execute().mapPartitionsInternal(iter => iter.slice(offset, limit + offset))
-  } else {
-    child.execute().mapPartitionsInternal(iter => iter.drop(offset))
+  override def doExecute(): RDD[InternalRow] = {
+    if (offset > 0) {
+      if (limit >= 0) {
+        child.execute().mapPartitionsInternal(iter => iter.slice(offset, limit))
+      } else {
+        child.execute().mapPartitionsInternal(iter => iter.drop(offset))
+      }
+    } else {
+      super.doExecute()
+    }
   }
 
-  private lazy val skipTerm = BaseLimitExec.newLimitCountTerm()
-
   override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
-    ctx.addMutableState(
-      CodeGenerator.JAVA_INT, skipTerm, forceInline = true, useFreshName = false)
-    if (limit >= 0) {
-      // The counter name is already obtained by the upstream operators via `limitNotReachedChecks`.
-      // Here we have to inline it to not change its name. This is fine as we won't have many limit
-      // operators in one query.
-      ctx.addMutableState(
-        CodeGenerator.JAVA_INT, countTerm, forceInline = true, useFreshName = false)
-      s"""
-         | if ($skipTerm < $offset) {
-         |   $skipTerm += 1;
-         | } else if ($countTerm < $limit) {
-         |   $countTerm += 1;
-         |   ${consume(ctx, input)}
-         | }
+    if (offset > 0) {
+      val skipTerm = ctx.addMutableState(CodeGenerator.JAVA_INT, "rowsSkipped", forceInline = true)
+      if (limit > 0) {
+        // In codegen, we skip the first `offset` rows, then take the first `limit - offset` rows.
+        val finalLimit = limit - offset
+        s"""
+           | if ($skipTerm < $offset) {
+           |   $skipTerm += 1;
+           | } else if ($countTerm < $finalLimit) {
+           |   $countTerm += 1;
+           |   ${consume(ctx, input)}
+           | }
       """.stripMargin
+      } else {
+        s"""
+           | if ($skipTerm < $offset) {
+           |   $skipTerm += 1;
+           | } else {
+           |   ${consume(ctx, input)}
+           | }
+         """.stripMargin
+      }
     } else {
-      s"""
-         | if ($skipTerm < $offset) {
-         |   $skipTerm += 1;
-         | } else {
-         |   ${consume(ctx, input)}
-         | }
-       """.stripMargin
+      super.doConsume(ctx, input, row)
     }
   }
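
Both the removed template and its replacement above generate the same per-row control flow: count skipped rows until `offset` is reached, then count emitted rows against the remaining limit. The equivalent logic in plain Scala (variable names are illustrative, not the generated Java):

  object SkipThenCount extends App {
    val offset = 2
    val finalLimit = 3 // limit - offset, i.e. rows actually emitted
    var rowsSkipped = 0
    var rowsEmitted = 0
    val input = Iterator.from(1) // stands in for the child's row stream

    while (rowsEmitted < finalLimit && input.hasNext) {
      val row = input.next()
      if (rowsSkipped < offset) {
        rowsSkipped += 1 // drop the first `offset` rows
      } else {
        rowsEmitted += 1 // emit at most `finalLimit` rows
        println(s"consume($row)") // prints rows 3, 4, 5
      }
    }
  }
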
@@ -278,9 +265,9 @@ case class GlobalLimitAndOffsetExec(
 }
 
 /**
- * Take the first limit elements as defined by the sortOrder, and do projection if needed.
- * This is logically equivalent to having a Limit operator after a [[SortExec]] operator,
- * or having a [[ProjectExec]] operator between them.
+ * Take the first `limit` elements as defined by the sortOrder, then drop the first `offset`
+ * elements, and do projection if needed. This is logically equivalent to having a Limit and/or
+ * Offset operator after a [[SortExec]] operator, or having a [[ProjectExec]] operator between them.
  * This could have been named TopK, but Spark's top operator does the opposite in ordering
  * so we name it TakeOrdered to avoid confusion.
  */
@@ -297,12 +284,8 @@ case class TakeOrderedAndProjectExec(
 
   override def executeCollect(): Array[InternalRow] = {
     val ord = new LazilyGeneratedOrdering(sortOrder, child.output)
-    val data = if (offset > 0) {
-      child.execute().mapPartitionsInternal(_.map(_.copy()))
-        .takeOrdered(limit + offset)(ord).drop(offset)
-    } else {
-      child.execute().mapPartitionsInternal(_.map(_.copy())).takeOrdered(limit)(ord)
-    }
+    val limited = child.execute().mapPartitionsInternal(_.map(_.copy())).takeOrdered(limit)(ord)
+    val data = if (offset > 0) limited.drop(offset) else limited
     if (projectList != child.output) {
       val proj = UnsafeProjection.create(projectList, child.output)
       data.map(r => proj(r).copy())
@@ -328,15 +311,10 @@ case class TakeOrderedAndProjectExec(
     val singlePartitionRDD = if (childRDD.getNumPartitions == 1) {
       childRDD
     } else {
-      val localTopK = if (offset > 0) {
-        childRDD.mapPartitionsInternal { iter =>
-          Utils.takeOrdered(iter.map(_.copy()), limit + offset)(ord)
-        }
-      } else {
-        childRDD.mapPartitionsInternal { iter =>
-          Utils.takeOrdered(iter.map(_.copy()), limit)(ord)
-        }
+      val localTopK = childRDD.mapPartitionsInternal { iter =>
+        Utils.takeOrdered(iter.map(_.copy()), limit)(ord)
       }
+
       new ShuffledRowRDD(
         ShuffleExchangeExec.prepareShuffleDependency(
           localTopK,
           child.output,
           SinglePartition,
           serializer,
           writeMetrics),
         readMetrics)
     }
     singlePartitionRDD.mapPartitionsInternal { iter =>
-      val topK = if (offset > 0) {
-        Utils.takeOrdered(iter.map(_.copy()), limit + offset)(ord).drop(offset)
-      } else {
-        Utils.takeOrdered(iter.map(_.copy()), limit)(ord)
-      }
+      val limited = Utils.takeOrdered(iter.map(_.copy()), limit)(ord)
+      val topK = if (offset > 0) limited.drop(offset) else limited
       if (projectList != child.output) {
         val proj = UnsafeProjection.create(projectList, child.output)
         topK.map(r => proj(r))
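
The doExecute above runs in two phases: each partition computes a local top `limit`, the partials are shuffled into a single partition, and the final top `limit` is taken there before the `offset` rows are dropped and the projection applied. A condensed sketch of those phases on Scala collections (`partitions` stands in for the child RDD; not the Spark implementation):

  object TwoPhaseTopK extends App {
    val partitions = Seq(Seq(9, 4, 6), Seq(1, 8, 2), Seq(7, 3, 5))
    val limit = 4 // already includes the offset rows
    val offset = 1

    // Phase 1: local top-K per partition (Utils.takeOrdered in Spark).
    val localTopK = partitions.map(_.sorted.take(limit))
    // Phase 2: merge into one partition, take the global top-K, then drop the offset.
    val topK = localTopK.flatten.sorted.take(limit).drop(offset)
    println(topK) // List(2, 3, 4): global top-4 is [1, 2, 3, 4], first row dropped
  }
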