@@ -0,0 +1,38 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.connector.read;

import java.util.Optional;

import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.connector.catalog.SupportsRead;

/**
 * A mix-in interface for {@link Scan}. Data sources can implement this interface to indicate
 * that their {@link Scan}s can be merged.
 *
 * @since 3.4.0
 */
@Evolving
public interface SupportsMerge extends Scan {

/**
 * Returns the merged scan if this scan can be merged with {@code other}, or an empty
 * {@code Optional} if the two scans cannot be merged.
 */
Optional<SupportsMerge> mergeWith(SupportsMerge other, SupportsRead table);
}
@@ -24,6 +24,9 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, CTERelationDef, CTERelationRef, Filter, Join, LogicalPlan, Project, Subquery, WithCTE}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreePattern.{SCALAR_SUBQUERY, SCALAR_SUBQUERY_REFERENCE, TreePattern}
import org.apache.spark.sql.connector.catalog.SupportsRead
import org.apache.spark.sql.connector.read.SupportsMerge
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.DataType

@@ -279,6 +282,42 @@ object MergeScalarSubqueries extends Rule[LogicalPlan] {
}
}

case (
DataSourceV2ScanRelation(newRelation, newScan: SupportsMerge, newOutput,
newKeyGroupedPartitioning, newOrdering),
DataSourceV2ScanRelation(cachedRelation, cachedScan: SupportsMerge, cachedOutput,
cachedKeyGroupedPartitioning, cachedOrdering)) =>
checkIdenticalPlans(newRelation, cachedRelation).flatMap { outputMap =>
val mappedNewKeyGroupedPartitioning =
newKeyGroupedPartitioning.map(_.map(mapAttributes(_, outputMap)))
if (mappedNewKeyGroupedPartitioning.map(_.map(_.canonicalized)) ==
cachedKeyGroupedPartitioning.map(_.map(_.canonicalized))) {
val mappedNewOrdering = newOrdering.map(_.map(mapAttributes(_, outputMap)))
if (mappedNewOrdering.map(_.map(_.canonicalized)) ==
Comment on lines +293 to +296:

[minor] Can we simplify the if/else structure here? Something like:

if (isKeyGroupPartitioningSame && isOrderingSame) {
  // merge scans and update cachedRelation
} else {
  None
}
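
Spelled out against the patch's code, that could look like the following (a sketch only; the val names follow the suggestion above and the merge body stays as in the patch):

val isKeyGroupPartitioningSame = mappedNewKeyGroupedPartitioning.map(_.map(_.canonicalized)) ==
  cachedKeyGroupedPartitioning.map(_.map(_.canonicalized))
val isOrderingSame = mappedNewOrdering.map(_.map(_.canonicalized)) ==
  cachedOrdering.map(_.map(_.canonicalized))
if (isKeyGroupPartitioningSame && isOrderingSame) {
  // merge the scans and rebuild the DataSourceV2ScanRelation, unchanged from the patch
} else {
  None
}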

cachedOrdering.map(_.map(_.canonicalized))) {
Option(cachedScan.mergeWith(newScan,
cachedRelation.table.asInstanceOf[SupportsRead]).orElse(null)).map { mergedScan =>
// Keep the original attributes of cached in merged
val mergedAttributes = mergedScan.readSchema().toAttributes
val cachedOutputNameMap = cachedOutput.map(a => a.name -> a).toMap
val mergedOutput = mergedAttributes.map {
case a => cachedOutputNameMap.getOrElse(a.name, a)
}
// Build the map from new to merged
val mergedOutputNameMap = mergedOutput.map(a => a.name -> a).toMap
val newOutputMap =
AttributeMap(newOutput.map(a => a -> mergedOutputNameMap(a.name).toAttribute))
DataSourceV2ScanRelation(cachedRelation, mergedScan, mergedOutput,
cachedKeyGroupedPartitioning, cachedOrdering) -> newOutputMap
}
} else {
None
}
} else {
None
}
}

// Otherwise merging is not possible.
case _ => None
})
@@ -82,7 +82,7 @@ trait FileScan extends Scan

protected def seqToString(seq: Seq[Any]): String = seq.mkString("[", ", ", "]")

private lazy val (normalizedPartitionFilters, normalizedDataFilters) = {
protected lazy val (normalizedPartitionFilters, normalizedDataFilters) = {
val partitionFilterAttributes = AttributeSet(partitionFilters).map(a => a.name -> a).toMap
val normalizedPartitionFilters = ExpressionSet(partitionFilters.map(
QueryPlan.normalizeExpressions(_, fileIndex.partitionSchema.toAttributes
@@ -16,6 +16,8 @@
*/
package org.apache.spark.sql.execution.datasources.v2.parquet

import java.util.Optional

import scala.collection.JavaConverters._

import org.apache.hadoop.conf.Configuration
@@ -24,8 +26,9 @@ import org.apache.parquet.hadoop.ParquetInputFormat

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.connector.catalog.SupportsRead
import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
import org.apache.spark.sql.connector.read.PartitionReaderFactory
import org.apache.spark.sql.connector.read.{PartitionReaderFactory, SupportsMerge}
import org.apache.spark.sql.execution.datasources.{AggregatePushDownUtils, PartitioningAwareFileIndex, RowIndexUtil}
import org.apache.spark.sql.execution.datasources.parquet.{ParquetOptions, ParquetReadSupport, ParquetWriteSupport}
import org.apache.spark.sql.execution.datasources.v2.FileScan
@@ -46,7 +49,7 @@ case class ParquetScan(
options: CaseInsensitiveStringMap,
pushedAggregate: Option[Aggregation] = None,
partitionFilters: Seq[Expression] = Seq.empty,
dataFilters: Seq[Expression] = Seq.empty) extends FileScan {
dataFilters: Seq[Expression] = Seq.empty) extends FileScan with SupportsMerge {
override def isSplitable(path: Path): Boolean = {
// If aggregate is pushed down, only the file footer will be read once,
// so file should not be split across multiple tasks.
@@ -106,15 +109,18 @@ case class ParquetScan(
new ParquetOptions(options.asCaseSensitiveMap.asScala.toMap, sqlConf))
}

private def pushedDownAggEqual(p: ParquetScan) = {
if (pushedAggregate.nonEmpty && p.pushedAggregate.nonEmpty) {
AggregatePushDownUtils.equivalentAggregations(pushedAggregate.get, p.pushedAggregate.get)
} else {
pushedAggregate.isEmpty && p.pushedAggregate.isEmpty
}
}
Comment on lines +112 to +118:

Should we move this to FileScan itself? OrcScan also has some duplicate code.
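
If it were moved up, one rough shape in FileScan could be the following (a hypothetical sketch; it assumes FileScan exposes the pushed-down aggregation, which today only ParquetScan and OrcScan carry):

// Hypothetical members in FileScan; pushedAggregate does not exist on the trait today.
protected def pushedAggregate: Option[Aggregation] = None

protected def pushedDownAggEqual(other: FileScan): Boolean =
  (pushedAggregate, other.pushedAggregate) match {
    case (Some(a), Some(b)) => AggregatePushDownUtils.equivalentAggregations(a, b)
    case (None, None) => true
    case _ => false
  }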


override def equals(obj: Any): Boolean = obj match {
case p: ParquetScan =>
val pushedDownAggEqual = if (pushedAggregate.nonEmpty && p.pushedAggregate.nonEmpty) {
AggregatePushDownUtils.equivalentAggregations(pushedAggregate.get, p.pushedAggregate.get)
} else {
pushedAggregate.isEmpty && p.pushedAggregate.isEmpty
}
super.equals(p) && dataSchema == p.dataSchema && options == p.options &&
equivalentFilters(pushedFilters, p.pushedFilters) && pushedDownAggEqual
equivalentFilters(pushedFilters, p.pushedFilters) && pushedDownAggEqual(p)
case _ => false
}

@@ -138,4 +144,29 @@ case class ParquetScan(
Map("PushedAggregation" -> pushedAggregationsStr) ++
Map("PushedGroupBy" -> pushedGroupByStr)
}

override def mergeWith(other: SupportsMerge, table: SupportsRead): Optional[SupportsMerge] = {
if (other.isInstanceOf[ParquetScan]) {

Can replace this with a case match.
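
A pattern-match version of the same logic might look like this (a sketch that reuses exactly the checks and builder calls from the patch below):

override def mergeWith(other: SupportsMerge, table: SupportsRead): Optional[SupportsMerge] =
  other match {
    case o: ParquetScan if fileIndex == o.fileIndex &&
        options == o.options &&
        dataSchema == o.dataSchema &&
        equivalentFilters(pushedFilters, o.pushedFilters) &&
        pushedDownAggEqual(o) &&
        normalizedPartitionFilters == o.normalizedPartitionFilters &&
        normalizedDataFilters == o.normalizedDataFilters =>
      val builder = table.newScanBuilder(options).asInstanceOf[ParquetScanBuilder]
      pushedAggregate.map(builder.pushAggregation)
      builder.pushFilters(dataFilters ++ partitionFilters)
      builder.pruneColumns(readSchema().merge(o.readSchema()))
      Optional.of(builder.build().asInstanceOf[ParquetScan])
    case _ => Optional.empty()
  }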

val o = other.asInstanceOf[ParquetScan]
if (fileIndex == o.fileIndex &&
options == o.options &&
dataSchema == o.dataSchema &&
equivalentFilters(pushedFilters, o.pushedFilters) &&
pushedDownAggEqual(o) &&
normalizedPartitionFilters == o.normalizedPartitionFilters &&
normalizedDataFilters == o.normalizedDataFilters) {
Comment on lines +156 to +157:

[question] Should we just disjunct the differing filters from the two scans and run a boolean simplification on top of the result, to handle cases where the scans have different partition and data filters?

Are we expecting some heuristic here for when combining the filters would be useful?
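
If that direction were explored, one conceivable shape (purely illustrative, not part of this patch; And, Or and Literal here are the Catalyst expression classes) would be to OR the two scans' conjunctive filter sets and rely on a later simplification pass:

// Hypothetical helper: disjunct two conjunctive filter sets instead of requiring equality.
// BooleanSimplification (or a manual simplification) would need to run on the result.
private def disjunct(a: Seq[Expression], b: Seq[Expression]): Seq[Expression] = {
  val left = a.reduceOption(And).getOrElse(Literal.TrueLiteral)
  val right = b.reduceOption(And).getOrElse(Literal.TrueLiteral)
  Seq(Or(left, right))
}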

val builder = table.newScanBuilder(options).asInstanceOf[ParquetScanBuilder]

[question] Should we add an assertion that table.newScanBuilder returns an instance of ParquetScanBuilder?
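
If so, a lightweight option would be a require on the builder (hypothetical, not in the patch):

val builder = table.newScanBuilder(options)
require(builder.isInstanceOf[ParquetScanBuilder],
  s"Expected ParquetScanBuilder for a Parquet table but got ${builder.getClass.getName}")
val parquetBuilder = builder.asInstanceOf[ParquetScanBuilder]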

pushedAggregate.map(builder.pushAggregation)
builder.pushFilters(dataFilters ++ partitionFilters)
builder.pruneColumns(readSchema().merge(o.readSchema()))
val scan = builder.build().asInstanceOf[ParquetScan]

Optional.of(scan)
} else {
Optional.empty()
}
} else {
Optional.empty()
}
}
}
40 changes: 40 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
@@ -2177,6 +2177,46 @@ class SubquerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
}
}

test("SPARK-40259: Merge non-correlated scalar subqueries with Parquet DSv2 sources") {
withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> "") {
withTempPath { path =>
testData
.withColumn("partition", $"key" % 10)
.write
.mode(SaveMode.Overwrite)
.partitionBy("partition")
.parquet(path.getCanonicalPath)
withTempView("td") {
spark.read.parquet(path.getCanonicalPath).createOrReplaceTempView("td")
Seq(false, true).foreach { enableAQE =>
withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAQE.toString) {
val df = sql(
"""
|SELECT
| (SELECT sum(key) FROM td WHERE partition < 5),
| (SELECT sum(key) FROM td WHERE partition >= 5),
| (SELECT sum(value) FROM td WHERE partition < 5),
| (SELECT sum(value) FROM td WHERE partition >= 5)
""".stripMargin)

checkAnswer(df, Row(2450, 2600, 2450.0, 2600.0) :: Nil)

val plan = df.queryExecution.executedPlan
val subqueryIds = collectWithSubqueries(plan) { case s: SubqueryExec => s.id }
val reusedSubqueryIds = collectWithSubqueries(plan) {
case rs: ReusedSubqueryExec => rs.child.id
}

assert(subqueryIds.size == 2, "Missing or unexpected SubqueryExec in the plan")
assert(reusedSubqueryIds.size == 2,
"Missing or unexpected reused ReusedSubqueryExec in the plan")
}
}
}
}
}
}

test("SPARK-39355: Single column uses quoted to construct UnresolvedAttribute") {
checkAnswer(
sql("""