diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
index 133e964719682..1dc7a3b7eecb3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
@@ -19,8 +19,6 @@ package org.apache.spark.sql.execution.joins
 
 import java.util.concurrent.TimeUnit._
 
-import scala.collection.mutable
-
 import org.apache.spark.TaskContext
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
@@ -31,7 +29,7 @@ import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution.{RowIterator, SparkPlan}
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
-import org.apache.spark.util.collection.BitSet
+import org.apache.spark.util.collection.{BitSet, OpenHashSet}
 
 /**
  * Performs a hash join of two child relations by first shuffling the data using the join keys.
@@ -136,10 +134,10 @@ case class ShuffledHashJoinExec(
    * Full outer shuffled hash join with unique join keys:
    * 1. Process rows from stream side by looking up hash relation.
    *    Mark the matched rows from build side be looked up.
-   *    A `BitSet` is used to track matched rows with key index.
+   *    A bit set is used to track matched rows with key index.
    * 2. Process rows from build side by iterating hash relation.
    *    Filter out rows from build side being matched already,
-   *    by checking key index from `BitSet`.
+   *    by checking key index from bit set.
    */
   private def fullOuterJoinWithUniqueKey(
       streamIter: Iterator[InternalRow],
@@ -150,9 +148,8 @@ case class ShuffledHashJoinExec(
       streamNullJoinRowWithBuild: => InternalRow => JoinedRow,
       buildNullRow: GenericInternalRow,
       streamNullRow: GenericInternalRow): Iterator[InternalRow] = {
-    // TODO(SPARK-32629):record metrics of extra BitSet/HashSet
-    // in full outer shuffled hash join
     val matchedKeys = new BitSet(hashedRelation.maxNumKeysIndex)
+    longMetric("buildDataSize") += matchedKeys.capacity / 8
 
     // Process stream side with looking up hash relation
     val streamResultIter = streamIter.map { srow =>
@@ -198,11 +195,11 @@ case class ShuffledHashJoinExec(
    * Full outer shuffled hash join with non-unique join keys:
    * 1. Process rows from stream side by looking up hash relation.
    *    Mark the matched rows from build side be looked up.
-   *    A `HashSet[Long]` is used to track matched rows with
+   *    A [[OpenHashSet]] (Long) is used to track matched rows with
    *    key index (Int) and value index (Int) together.
    * 2. Process rows from build side by iterating hash relation.
    *    Filter out rows from build side being matched already,
-   *    by checking key index and value index from `HashSet`.
+   *    by checking key index and value index from [[OpenHashSet]].
    *
    * The "value index" is defined as the index of the tuple in the chain
    * of tuples having the same key. For example, if certain key is found thrice,
@@ -218,9 +215,15 @@ case class ShuffledHashJoinExec(
       streamNullJoinRowWithBuild: => InternalRow => JoinedRow,
       buildNullRow: GenericInternalRow,
       streamNullRow: GenericInternalRow): Iterator[InternalRow] = {
-    // TODO(SPARK-32629):record metrics of extra BitSet/HashSet
-    // in full outer shuffled hash join
-    val matchedRows = new mutable.HashSet[Long]
+    val matchedRows = new OpenHashSet[Long]
+    TaskContext.get().addTaskCompletionListener[Unit](_ => {
+      // At the end of the task, update the task's memory usage for this
+      // [[OpenHashSet]] to track matched rows, which has two parts:
+      // [[OpenHashSet._bitset]] and [[OpenHashSet._data]].
+      val bitSetEstimatedSize = matchedRows.getBitSet.capacity / 8
+      val dataEstimatedSize = matchedRows.capacity * 8
+      longMetric("buildDataSize") += bitSetEstimatedSize + dataEstimatedSize
+    })
 
     def markRowMatched(keyIndex: Int, valueIndex: Int): Unit = {
       val rowIndex: Long = (keyIndex.toLong << 32) | valueIndex
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
index 078a3ba029e4b..4e10c27edb0e9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
@@ -29,6 +29,7 @@ import org.apache.spark.sql.execution.{FilterExec, RangeExec, SparkPlan, WholeSt
 import org.apache.spark.sql.execution.adaptive.DisableAdaptiveExecutionSuite
 import org.apache.spark.sql.execution.aggregate.HashAggregateExec
 import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
+import org.apache.spark.sql.execution.joins.ShuffledHashJoinExec
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
@@ -363,6 +364,41 @@ class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils
     }
   }
 
+  test("SPARK-32629: ShuffledHashJoin(full outer) metrics") {
+    val uniqueLeftDf = Seq(("1", "1"), ("11", "11")).toDF("key", "value")
+    val nonUniqueLeftDf = Seq(("1", "1"), ("1", "2"), ("11", "11")).toDF("key", "value")
+    val rightDf = (1 to 10).map(i => (i.toString, i.toString)).toDF("key2", "value")
+    Seq(
+      // Test unique key on build side
+      (uniqueLeftDf, rightDf, 11, 134228048, 10, 134221824),
+      // Test non-unique key on build side
+      (nonUniqueLeftDf, rightDf, 12, 134228552, 11, 134221824)
+    ).foreach { case (leftDf, rightDf, fojRows, fojBuildSize, rojRows, rojBuildSize) =>
+      val fojDf = leftDf.hint("shuffle_hash").join(
+        rightDf, $"key" === $"key2", "full_outer")
+      fojDf.collect()
+      val fojPlan = fojDf.queryExecution.executedPlan.collectFirst {
+        case s: ShuffledHashJoinExec => s
+      }
+      assert(fojPlan.isDefined, "The query plan should have shuffled hash join")
+      testMetricsInSparkPlanOperator(fojPlan.get,
+        Map("numOutputRows" -> fojRows, "buildDataSize" -> fojBuildSize))
+
+      // Test right outer join as well to verify build data size to be different
+      // from full outer join. This makes sure we take extra BitSet/OpenHashSet
+      // for full outer join into account.
+      val rojDf = leftDf.hint("shuffle_hash").join(
+        rightDf, $"key" === $"key2", "right_outer")
+      rojDf.collect()
+      val rojPlan = rojDf.queryExecution.executedPlan.collectFirst {
+        case s: ShuffledHashJoinExec => s
+      }
+      assert(rojPlan.isDefined, "The query plan should have shuffled hash join")
+      testMetricsInSparkPlanOperator(rojPlan.get,
+        Map("numOutputRows" -> rojRows, "buildDataSize" -> rojBuildSize))
+    }
+  }
+
   test("BroadcastHashJoin(outer) metrics") {
     val df1 = Seq((1, "a"), (1, "b"), (4, "c")).toDF("key", "value")
     val df2 = Seq((1, "a"), (1, "b"), (2, "c"), (3, "d")).toDF("key2", "value")
@@ -686,16 +722,6 @@ class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils
   }
 
   test("SPARK-28332: SQLMetric merge should handle -1 properly") {
-    def checkSparkPlanMetrics(plan: SparkPlan, expected: Map[String, Long]): Unit = {
-      expected.foreach { case (metricName: String, metricValue: Long) =>
-        assert(plan.metrics.contains(metricName), s"The query plan should have metric $metricName")
-        val actualMetric = plan.metrics.get(metricName).get
-        assert(actualMetric.value == metricValue,
-          s"The query plan metric $metricName did not match, " +
-          s"expected:$metricValue, actual:${actualMetric.value}")
-      }
-    }
-
     val df = testData.join(testData2.filter('b === 0), $"key" === $"a", "left_outer")
     df.collect()
     val plan = df.queryExecution.executedPlan
@@ -706,7 +732,8 @@ class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils
 
     assert(exchanges.size == 2, "The query plan should have two shuffle exchanges")
 
-    checkSparkPlanMetrics(exchanges(0), Map("dataSize" -> 3200, "shuffleRecordsWritten" -> 100))
-    checkSparkPlanMetrics(exchanges(1), Map("dataSize" -> 0, "shuffleRecordsWritten" -> 0))
+    testMetricsInSparkPlanOperator(exchanges.head,
+      Map("dataSize" -> 3200, "shuffleRecordsWritten" -> 100))
+    testMetricsInSparkPlanOperator(exchanges(1), Map("dataSize" -> 0, "shuffleRecordsWritten" -> 0))
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala
index ce726046c3215..81667d52e16ae 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala
@@ -25,7 +25,7 @@ import org.apache.spark.TestUtils
 import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.catalyst.TableIdentifier
-import org.apache.spark.sql.execution.SparkPlanInfo
+import org.apache.spark.sql.execution.{SparkPlan, SparkPlanInfo}
 import org.apache.spark.sql.execution.ui.{SparkPlanGraph, SQLAppStatusStore}
 import org.apache.spark.sql.internal.SQLConf.WHOLESTAGE_CODEGEN_ENABLED
 import org.apache.spark.sql.test.SQLTestUtils
@@ -254,6 +254,24 @@ trait SQLMetricsTestUtils extends SQLTestUtils {
       }
     }
   }
+
+  /**
+   * Verify if the metrics in `SparkPlan` operator are same as expected metrics.
+   *
+   * @param plan `SparkPlan` operator to check metrics
+   * @param expectedMetrics the expected metrics. The format is `metric name -> metric value`.
+   */
+  protected def testMetricsInSparkPlanOperator(
+      plan: SparkPlan,
+      expectedMetrics: Map[String, Long]): Unit = {
+    expectedMetrics.foreach { case (metricName: String, metricValue: Long) =>
+      assert(plan.metrics.contains(metricName), s"The query plan should have metric $metricName")
+      val actualMetric = plan.metrics(metricName)
+      assert(actualMetric.value == metricValue,
+        s"The query plan metric $metricName did not match, " +
+        s"expected:$metricValue, actual:${actualMetric.value}")
+    }
+  }
 }
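
Reviewer note (not part of the patch): the sketch below is a standalone Scala illustration of the two pieces of bookkeeping the diff adds; the object and helper names are invented here. It mirrors packing a key index and a value index into one Long, as `markRowMatched`/`isRowMatched` do, and the byte estimates charged to "buildDataSize", assuming the bit set rounds its capacity up to whole 64-bit words and the open hash set holds a Long array of `capacity` elements plus a bit set of `capacity` bits, which is how the `org.apache.spark.util.collection` classes used above behave.

```scala
// Standalone sketch, not part of the patch: names below are invented for
// illustration and only plain Scala is used.
object SizeAccountingSketch {

  // Mirrors `(keyIndex.toLong << 32) | valueIndex`: key index in the high
  // 32 bits, (non-negative) value index in the low 32 bits.
  def packRowIndex(keyIndex: Int, valueIndex: Int): Long =
    (keyIndex.toLong << 32) | valueIndex

  // Unique-key path: `new BitSet(maxNumKeysIndex)` rounds up to whole 64-bit
  // words, so `capacity / 8` bytes is what gets added to "buildDataSize".
  def bitSetEstimatedBytes(maxNumKeysIndex: Int): Long = {
    val capacityBits = ((maxNumKeysIndex + 63) / 64) * 64L
    capacityBits / 8
  }

  // Non-unique-key path: an OpenHashSet[Long] of a given capacity carries an
  // internal bit set (capacity / 8 bytes) plus a Long array (capacity * 8
  // bytes), matching bitSetEstimatedSize + dataEstimatedSize in the listener.
  def openHashSetEstimatedBytes(capacity: Int): Long =
    capacity / 8L + capacity * 8L

  def main(args: Array[String]): Unit = {
    val packed = packRowIndex(keyIndex = 3, valueIndex = 2)
    assert((packed >>> 32).toInt == 3 && packed.toInt == 2)
    println(s"bit set for 1000 key indexes ~ ${bitSetEstimatedBytes(1000)} bytes")
    println(s"open hash set of capacity 64 ~ ${openHashSetEstimatedBytes(64)} bytes")
  }
}
```

Packing both indices into one primitive Long is also what lets the matched-row set be an `OpenHashSet[Long]`, which stores unboxed longs and whose footprint reduces to the simple arithmetic above, unlike the `mutable.HashSet[Long]` it replaces.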