Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -99,14 +99,18 @@ object MergeScalarSubqueries extends Rule[LogicalPlan] with PredicateHelper {
}
}

type Header = Project
case class Header(elements: Seq[(Literal, Attribute)], child: LogicalPlan) {
def merged: Boolean = elements.size > 1
}

private def extractCommonScalarSubqueries(plan: LogicalPlan) = {
// Plan of subqueries and a flag if the plan is merged
val cache = ListBuffer.empty[(Header, Boolean)]
val cache = ListBuffer.empty[Header]
val newPlan = removeReferences(insertReferences(plan, cache), cache)
if (cache.nonEmpty) {
val scalarSubqueries = cache.map { case (header, _) => ScalarSubquery(header) }.toSeq
val scalarSubqueries = cache.map { case Header(elements, child) =>
ScalarSubquery(createProject(elements, child))
}.toSeq
CommonScalarSubqueries(scalarSubqueries, newPlan)
} else {
newPlan
Expand All @@ -116,7 +120,7 @@ object MergeScalarSubqueries extends Rule[LogicalPlan] with PredicateHelper {
// First traversal builds up the cache and inserts `ScalarSubqueryReference`s to the plan.
private def insertReferences(
plan: LogicalPlan,
cache: ListBuffer[(Header, Boolean)]): LogicalPlan = {
cache: ListBuffer[Header]): LogicalPlan = {
plan.transformAllExpressionsWithPruning(_.containsAnyPattern(SCALAR_SUBQUERY)) {
case s: ScalarSubquery if s.children.isEmpty =>
val (subqueryIndex, headerIndex) = cacheSubquery(s.plan, cache)
Expand All @@ -129,28 +133,28 @@ object MergeScalarSubqueries extends Rule[LogicalPlan] with PredicateHelper {
// [[Project]] node.
private def cacheSubquery(
plan: LogicalPlan,
cache: ListBuffer[(Header, Boolean)]): (Int, Int) = {
cache: ListBuffer[Header]): (Int, Int) = {
val firstOutput = plan.output.head
cache.zipWithIndex.collectFirst(Function.unlift { case ((header, merged), subqueryIndex) =>
cache.zipWithIndex.collectFirst(Function.unlift { case (header, subqueryIndex) =>
checkIdenticalPlans(plan, header.child)
.map((subqueryIndex, header, header.child, _, merged))
.map((subqueryIndex, header, header.child, _))
.orElse(tryMergePlans(plan, header.child).map {
case (mergedPlan, outputMap) => (subqueryIndex, header, mergedPlan, outputMap, true)
case (mergedPlan, outputMap) => (subqueryIndex, header, mergedPlan, outputMap)
})
}).map { case (subqueryIndex, header, mergedPlan, outputMap, merged) =>
}).map { case (subqueryIndex, header, mergedPlan, outputMap) =>
val mappedFirstOutput = mapAttributes(firstOutput, outputMap)
val headerElements = getHeaderElements(header)
var headerIndex = headerElements.indexWhere {
var headerIndex = header.elements.indexWhere {
case (_, attribute) => attribute.exprId == mappedFirstOutput.exprId
}
if (headerIndex == -1) {
val newHeaderElements = headerElements :+ (Literal(firstOutput.name) -> mappedFirstOutput)
cache(subqueryIndex) = createHeader(newHeaderElements, mergedPlan) -> merged
headerIndex = headerElements.size
val newHeaderElements =
header.elements :+ (Literal(firstOutput.name) -> mappedFirstOutput)
cache(subqueryIndex) = Header(newHeaderElements, mergedPlan)
headerIndex = header.elements.size
}
subqueryIndex -> headerIndex
}.getOrElse {
cache += createHeader(Seq(Literal(firstOutput.name) -> firstOutput), plan) -> false
cache += Header(Seq(Literal(firstOutput.name) -> firstOutput), plan)
cache.length - 1 -> 0
}
}
Expand Down Expand Up @@ -281,24 +285,16 @@ object MergeScalarSubqueries extends Rule[LogicalPlan] with PredicateHelper {
})
}

private def createHeader(headerElements: Seq[(Literal, Attribute)], plan: LogicalPlan): Header = {
private def createProject(elements: Seq[(Literal, Attribute)], plan: LogicalPlan): Project = {
Project(
Seq(Alias(
CreateNamedStruct(headerElements.flatMap {
CreateNamedStruct(elements.flatMap {
case (name, attribute) => Seq(name, attribute)
}),
"mergedValue")()),
plan)
}

private def getHeaderElements(header: Header) = {
val mergedValue =
header.projectList.head.asInstanceOf[Alias].child.asInstanceOf[CreateNamedStruct]
mergedValue.children.grouped(2).map {
case Seq(name: Literal, attribute: Attribute) => name -> attribute
}.toSeq
}

private def mapAttributes[T <: Expression](expr: T, outputMap: AttributeMap[Attribute]) = {
expr.transform {
case a: Attribute => outputMap.getOrElse(a, a)
Expand Down Expand Up @@ -373,16 +369,16 @@ object MergeScalarSubqueries extends Rule[LogicalPlan] with PredicateHelper {
// merged from multiple subqueries or `ScalarSubquery(original plan)` if it isn't.
private def removeReferences(
plan: LogicalPlan,
cache: ListBuffer[(Header, Boolean)]): LogicalPlan = {
cache: ListBuffer[Header]): LogicalPlan = {
val nonMergedSubqueriesBefore = cache.scanLeft(0) {
case (nonMergedSubqueriesBefore, (_, merged)) =>
nonMergedSubqueriesBefore + (if (merged) 0 else 1)
case (nonMergedSubqueriesBefore, header) =>
nonMergedSubqueriesBefore + (if (header.merged) 0 else 1)
}.toArray
val newPlan =
plan.transformAllExpressionsWithPruning(_.containsAnyPattern(SCALAR_SUBQUERY_REFERENCE)) {
case ssr: ScalarSubqueryReference =>
val (header, merged) = cache(ssr.subqueryIndex)
if (merged) {
val header = cache(ssr.subqueryIndex)
if (header.merged) {
if (nonMergedSubqueriesBefore(ssr.subqueryIndex) > 0) {
ssr.copy(subqueryIndex =
ssr.subqueryIndex - nonMergedSubqueriesBefore(ssr.subqueryIndex))
Expand All @@ -394,7 +390,7 @@ object MergeScalarSubqueries extends Rule[LogicalPlan] with PredicateHelper {
}
}
cache.zipWithIndex.collect {
case ((_, merged), i) if !merged => i
case (header, i) if !header.merged => i
}.reverse.foreach(cache.remove)
newPlan
}
Expand Down