
Commit 6b0e391

viirya authored and dongjoon-hyun committed
[SPARK-29427][SQL] Add API to convert RelationalGroupedDataset to KeyValueGroupedDataset
### What changes were proposed in this pull request?

This PR proposes to add an `as` API to RelationalGroupedDataset. It creates a KeyValueGroupedDataset instance using the given grouping expressions, instead of a typed function as in the groupByKey API. Because it can leverage existing columns, it can reuse an existing data partitioning, if any, when doing operations like cogroup.

### Why are the changes needed?

Currently, if users want to cogroup DataFrames, there is no good way to do it other than KeyValueGroupedDataset:

1. KeyValueGroupedDataset ignores an existing data partitioning, if any. That is a problem.
2. groupByKey calls a typed function to create additional key columns. You cannot reuse existing columns if you just need to group by them.

```scala
// df1 and df2 are certainly partitioned and sorted.
val df1 = Seq((1, 2, 3), (2, 3, 4)).toDF("a", "b", "c")
  .repartition($"a").sortWithinPartitions("a")
val df2 = Seq((1, 2, 4), (2, 3, 5)).toDF("a", "b", "c")
  .repartition($"a").sortWithinPartitions("a")
```

```scala
// This groupBy.as.cogroup won't unnecessarily repartition the data
val df3 = df1.groupBy("a").as[Int]
  .cogroup(df2.groupBy("a").as[Int]) { case (key, data1, data2) =>
    data1.zip(data2).map { p =>
      p._1.getInt(2) + p._2.getInt(2)
    }
  }
```

```
== Physical Plan ==
*(5) SerializeFromObject [input[0, int, false] AS value#11247]
+- CoGroup org.apache.spark.sql.DataFrameSuite$$Lambda$4922/12067092816eec1b6f, a#11209: int, createexternalrow(a#11209, b#11210, c#11211, StructField(a,IntegerType,false), StructField(b,IntegerType,false), StructField(c,IntegerType,false)), createexternalrow(a#11225, b#11226, c#11227, StructField(a,IntegerType,false), StructField(b,IntegerType,false), StructField(c,IntegerType,false)), [a#11209], [a#11225], [a#11209, b#11210, c#11211], [a#11225, b#11226, c#11227], obj#11246: int
   :- *(2) Sort [a#11209 ASC NULLS FIRST], false, 0
   :  +- Exchange hashpartitioning(a#11209, 5), false, [id=#10218]
   :     +- *(1) Project [_1#11202 AS a#11209, _2#11203 AS b#11210, _3#11204 AS c#11211]
   :        +- *(1) LocalTableScan [_1#11202, _2#11203, _3#11204]
   +- *(4) Sort [a#11225 ASC NULLS FIRST], false, 0
      +- Exchange hashpartitioning(a#11225, 5), false, [id=#10223]
         +- *(3) Project [_1#11218 AS a#11225, _2#11219 AS b#11226, _3#11220 AS c#11227]
            +- *(3) LocalTableScan [_1#11218, _2#11219, _3#11220]
```

```scala
// Current approach creates additional AppendColumns and repartitions the data again
val df4 = df1.groupByKey(r => r.getInt(0))
  .cogroup(df2.groupByKey(r => r.getInt(0))) { case (key, data1, data2) =>
    data1.zip(data2).map { p =>
      p._1.getInt(2) + p._2.getInt(2)
    }
  }
```

```
== Physical Plan ==
*(7) SerializeFromObject [input[0, int, false] AS value#11257]
+- CoGroup org.apache.spark.sql.DataFrameSuite$$Lambda$4933/138102700737171997, value#11252: int, createexternalrow(a#11209, b#11210, c#11211, StructField(a,IntegerType,false), StructField(b,IntegerType,false), StructField(c,IntegerType,false)), createexternalrow(a#11225, b#11226, c#11227, StructField(a,IntegerType,false), StructField(b,IntegerType,false), StructField(c,IntegerType,false)), [value#11252], [value#11254], [a#11209, b#11210, c#11211], [a#11225, b#11226, c#11227], obj#11256: int
   :- *(3) Sort [value#11252 ASC NULLS FIRST], false, 0
   :  +- Exchange hashpartitioning(value#11252, 5), true, [id=#10302]
   :     +- AppendColumns org.apache.spark.sql.DataFrameSuite$$Lambda$4930/19529195347ce07f47, createexternalrow(a#11209, b#11210, c#11211, StructField(a,IntegerType,false), StructField(b,IntegerType,false), StructField(c,IntegerType,false)), [input[0, int, false] AS value#11252]
   :        +- *(2) Sort [a#11209 ASC NULLS FIRST], false, 0
   :           +- Exchange hashpartitioning(a#11209, 5), false, [id=#10297]
   :              +- *(1) Project [_1#11202 AS a#11209, _2#11203 AS b#11210, _3#11204 AS c#11211]
   :                 +- *(1) LocalTableScan [_1#11202, _2#11203, _3#11204]
   +- *(6) Sort [value#11254 ASC NULLS FIRST], false, 0
      +- Exchange hashpartitioning(value#11254, 5), true, [id=#10312]
         +- AppendColumns org.apache.spark.sql.DataFrameSuite$$Lambda$4932/15265288491f0e0c1f, createexternalrow(a#11225, b#11226, c#11227, StructField(a,IntegerType,false), StructField(b,IntegerType,false), StructField(c,IntegerType,false)), [input[0, int, false] AS value#11254]
            +- *(5) Sort [a#11225 ASC NULLS FIRST], false, 0
               +- Exchange hashpartitioning(a#11225, 5), false, [id=#10307]
                  +- *(4) Project [_1#11218 AS a#11225, _2#11219 AS b#11226, _3#11220 AS c#11227]
                     +- *(4) LocalTableScan [_1#11218, _2#11219, _3#11220]
```

### Does this PR introduce any user-facing change?

Yes, this adds a new `as` API to RelationalGroupedDataset. Users can use it to create a KeyValueGroupedDataset and do cogroup.

### How was this patch tested?

Unit tests.

Closes #26509 from viirya/SPARK-29427-2.

Lead-authored-by: Liang-Chi Hsieh <[email protected]>
Co-authored-by: Liang-Chi Hsieh <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
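Note that the snippets above abbreviate the call as `as[Int]`; the `as` added in the diff below takes both a key and a value encoder (`as[K, T]`). A minimal sketch of a complete call against untyped rows, modeled on the new `DataFrameSuite` test (a SparkSession named `spark` and its implicits are assumed to be in scope):

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.encoders.RowEncoder

// Assumes a SparkSession named `spark` is already available.
import spark.implicits._

val df1 = Seq((1, 2, 3), (2, 3, 4)).toDF("a", "b", "c")
  .repartition($"a").sortWithinPartitions("a")
val df2 = Seq((1, 2, 4), (2, 3, 5)).toDF("a", "b", "c")
  .repartition($"a").sortWithinPartitions("a")

// Untyped Rows need an explicit encoder for the value side of as[K, T].
implicit val valueEncoder = RowEncoder(df1.schema)

// Cogroup on the existing "a" column; the existing hash partitioning is reused.
val summed = df1.groupBy("a").as[Int, Row]
  .cogroup(df2.groupBy("a").as[Int, Row]) { case (_, data1, data2) =>
    data1.zip(data2).map { case (r1, r2) => r1.getInt(2) + r2.getInt(2) }
  }
```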
1 parent 6e581cf commit 6b0e391

File tree: 3 files changed (+112, -0 lines)


sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala

Lines changed: 32 additions & 0 deletions
@@ -26,6 +26,7 @@ import org.apache.spark.annotation.Stable
 import org.apache.spark.api.python.PythonEvalType
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction}
+import org.apache.spark.sql.catalyst.encoders.encoderFor
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.logical._
@@ -129,6 +130,37 @@ class RelationalGroupedDataset protected[sql](
     (inputExpr: Expression) => exprToFunc(inputExpr)
   }
 
+  /**
+   * Returns a `KeyValueGroupedDataset` where the data is grouped by the grouping expressions
+   * of current `RelationalGroupedDataset`.
+   *
+   * @since 3.0.0
+   */
+  def as[K: Encoder, T: Encoder]: KeyValueGroupedDataset[K, T] = {
+    val keyEncoder = encoderFor[K]
+    val valueEncoder = encoderFor[T]
+
+    // Resolves grouping expressions.
+    val dummyPlan = Project(groupingExprs.map(alias), LocalRelation(df.logicalPlan.output))
+    val analyzedPlan = df.sparkSession.sessionState.analyzer.execute(dummyPlan)
+      .asInstanceOf[Project]
+    df.sparkSession.sessionState.analyzer.checkAnalysis(analyzedPlan)
+    val aliasedGroupings = analyzedPlan.projectList
+
+    // Adds the grouping expressions that are not in base DataFrame into outputs.
+    val addedCols = aliasedGroupings.filter(g => !df.logicalPlan.outputSet.contains(g.toAttribute))
+    val qe = Dataset.ofRows(
+      df.sparkSession,
+      Project(df.logicalPlan.output ++ addedCols, df.logicalPlan)).queryExecution
+
+    new KeyValueGroupedDataset(
+      keyEncoder,
+      valueEncoder,
+      qe,
+      df.logicalPlan.output,
+      aliasedGroupings.map(_.toAttribute))
+  }
+
   /**
    * (Scala-specific) Compute aggregates by specifying the column names and
    * aggregate methods. The resulting `DataFrame` will also contain the grouping columns.
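Since `as` returns an ordinary `KeyValueGroupedDataset`, the result can feed any of its typed operations, not only `cogroup`. A small sketch under the same assumptions (a `spark` session with implicits in scope; the `sales` example data is hypothetical):

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.encoders.RowEncoder

import spark.implicits._

val sales = Seq((1, 10.0), (1, 2.5), (2, 7.0)).toDF("id", "amount")
implicit val rowEncoder = RowEncoder(sales.schema)

// Group by the existing "id" column, then run a typed per-group aggregation.
val perKeyTotal = sales.groupBy("id").as[Int, Row]
  .mapGroups((id, rows) => (id, rows.map(_.getDouble(1)).sum))
```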

sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala

Lines changed: 59 additions & 0 deletions
@@ -30,6 +30,7 @@ import org.scalatest.Matchers._
 import org.apache.spark.SparkException
 import org.apache.spark.scheduler.{SparkListener, SparkListenerJobEnd}
 import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.encoders.RowEncoder
 import org.apache.spark.sql.catalyst.expressions.Uuid
 import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation
 import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Union}
@@ -2221,4 +2222,62 @@ class DataFrameSuite extends QueryTest with SharedSparkSession {
     val idTuples = sampled.collect().map(row => row.getLong(0) -> row.getLong(1))
     assert(idTuples.length == idTuples.toSet.size)
   }
+
+  test("groupBy.as") {
+    val df1 = Seq((1, 2, 3), (2, 3, 4)).toDF("a", "b", "c")
+      .repartition($"a", $"b").sortWithinPartitions("a", "b")
+    val df2 = Seq((1, 2, 4), (2, 3, 5)).toDF("a", "b", "c")
+      .repartition($"a", $"b").sortWithinPartitions("a", "b")
+
+    implicit val valueEncoder = RowEncoder(df1.schema)
+
+    val df3 = df1.groupBy("a", "b").as[GroupByKey, Row]
+      .cogroup(df2.groupBy("a", "b").as[GroupByKey, Row]) { case (_, data1, data2) =>
+        data1.zip(data2).map { p =>
+          p._1.getInt(2) + p._2.getInt(2)
+        }
+      }.toDF
+
+    checkAnswer(df3.sort("value"), Row(7) :: Row(9) :: Nil)
+
+    // Assert that no extra shuffle introduced by cogroup.
+    val exchanges = df3.queryExecution.executedPlan.collect {
+      case h: ShuffleExchangeExec => h
+    }
+    assert(exchanges.size == 2)
+  }
+
+  test("groupBy.as: custom grouping expressions") {
+    val df1 = Seq((1, 2, 3), (2, 3, 4)).toDF("a1", "b", "c")
+      .repartition($"a1", $"b").sortWithinPartitions("a1", "b")
+    val df2 = Seq((1, 2, 4), (2, 3, 5)).toDF("a1", "b", "c")
+      .repartition($"a1", $"b").sortWithinPartitions("a1", "b")
+
+    implicit val valueEncoder = RowEncoder(df1.schema)
+
+    val groupedDataset1 = df1.groupBy(($"a1" + 1).as("a"), $"b").as[GroupByKey, Row]
+    val groupedDataset2 = df2.groupBy(($"a1" + 1).as("a"), $"b").as[GroupByKey, Row]
+
+    val df3 = groupedDataset1
+      .cogroup(groupedDataset2) { case (_, data1, data2) =>
+        data1.zip(data2).map { p =>
+          p._1.getInt(2) + p._2.getInt(2)
+        }
+      }.toDF
+
+    checkAnswer(df3.sort("value"), Row(7) :: Row(9) :: Nil)
+  }
+
+  test("groupBy.as: throw AnalysisException for unresolved grouping expr") {
+    val df = Seq((1, 2, 3), (2, 3, 4)).toDF("a", "b", "c")
+
+    implicit val valueEncoder = RowEncoder(df.schema)
+
+    val err = intercept[AnalysisException] {
+      df.groupBy($"d", $"b").as[GroupByKey, Row]
+    }
+    assert(err.getMessage.contains("cannot resolve '`d`'"))
+  }
 }
+
+case class GroupByKey(a: Int, b: Int)
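The key type in these tests is the `GroupByKey(a, b)` case class declared at the end of the file; the grouping columns (or their aliases, as in the custom-expression test) line up with its fields by name. A sketch of reading that key through another typed operation, reusing the test's `df1` and `valueEncoder` (illustrative only, not part of the diff):

```scala
// Keys arrive as GroupByKey instances; values stay untyped Rows.
val keyStats = df1.groupBy("a", "b").as[GroupByKey, Row]
  .mapGroups((key, rows) => (key.a, key.b, rows.length))
```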

sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala

Lines changed: 21 additions & 0 deletions
@@ -1861,6 +1861,27 @@ class DatasetSuite extends QueryTest with SharedSparkSession {
       }
     }
   }
+
+  test("groupBy.as") {
+    val df1 = Seq(DoubleData(1, "one"), DoubleData(2, "two"), DoubleData( 3, "three")).toDS()
+      .repartition($"id").sortWithinPartitions("id")
+    val df2 = Seq(DoubleData(5, "one"), DoubleData(1, "two"), DoubleData( 3, "three")).toDS()
+      .repartition($"id").sortWithinPartitions("id")
+
+    val df3 = df1.groupBy("id").as[Int, DoubleData]
+      .cogroup(df2.groupBy("id").as[Int, DoubleData]) { case (key, data1, data2) =>
+        if (key == 1) {
+          Iterator(DoubleData(key, (data1 ++ data2).foldLeft("")((cur, next) => cur + next.val1)))
+        } else Iterator.empty
+      }
+    checkDataset(df3, DoubleData(1, "onetwo"))
+
+    // Assert that no extra shuffle introduced by cogroup.
+    val exchanges = df3.queryExecution.executedPlan.collect {
+      case h: ShuffleExchangeExec => h
+    }
+    assert(exchanges.size == 2)
+  }
 }
 
 object AssertExecutionId {
