Skip to content

Commit a12e29b

Browse files
Ngone51HyukjinKwon
authored andcommitted
[SPARK-34319][SQL] Resolve duplicate attributes for FlatMapCoGroupsInPandas/MapInPandas
### What changes were proposed in this pull request? Resolve duplicate attributes for `FlatMapCoGroupsInPandas`. ### Why are the changes needed? When performing self-join on top of `FlatMapCoGroupsInPandas`, analysis can fail because of conflicting attributes. For example, ```scala df = spark.createDataFrame([(1, 1)], ("column", "value")) row = df.groupby("ColUmn").cogroup( df.groupby("COLUMN") ).applyInPandas(lambda r, l: r + l, "column long, value long") row.join(row).show() ``` error: ```scala ... Conflicting attributes: column#163321L,value#163322L ;; ’Join Inner :- FlatMapCoGroupsInPandas [ColUmn#163312L], [COLUMN#163312L], <lambda>(column#163312L, value#163313L, column#163312L, value#163313L), [column#163321L, value#163322L] : :- Project [ColUmn#163312L, column#163312L, value#163313L] : : +- LogicalRDD [column#163312L, value#163313L], false : +- Project [COLUMN#163312L, column#163312L, value#163313L] : +- LogicalRDD [column#163312L, value#163313L], false +- FlatMapCoGroupsInPandas [ColUmn#163312L], [COLUMN#163312L], <lambda>(column#163312L, value#163313L, column#163312L, value#163313L), [column#163321L, value#163322L] :- Project [ColUmn#163312L, column#163312L, value#163313L] : +- LogicalRDD [column#163312L, value#163313L], false +- Project [COLUMN#163312L, column#163312L, value#163313L] +- LogicalRDD [column#163312L, value#163313L], false ... ``` ### Does this PR introduce _any_ user-facing change? yes, the query like the above example won't fail. ### How was this patch tested? Adde unit tests. Closes #31429 from Ngone51/fix-conflcting-attrs-of-FlatMapCoGroupsInPandas. Lead-authored-by: yi.wu <[email protected]> Co-authored-by: wuyi <[email protected]> Signed-off-by: HyukjinKwon <[email protected]> (cherry picked from commit e9362c2) Signed-off-by: HyukjinKwon <[email protected]>
1 parent 6831308 commit a12e29b

File tree

4 files changed

+70
-0
lines changed

4 files changed

+70
-0
lines changed

python/pyspark/sql/tests/test_pandas_cogrouped_map.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,18 @@ def test_case_insensitive_grouping_column(self):
203203
).applyInPandas(lambda r, l: r + l, "column long, value long").first()
204204
self.assertEqual(row.asDict(), Row(column=2, value=2).asDict())
205205

206+
def test_self_join(self):
207+
# SPARK-34319: self-join with FlatMapCoGroupsInPandas
208+
df = self.spark.createDataFrame([(1, 1)], ("column", "value"))
209+
210+
row = df.groupby("ColUmn").cogroup(
211+
df.groupby("COLUMN")
212+
).applyInPandas(lambda r, l: r + l, "column long, value long")
213+
214+
row = row.join(row).first()
215+
216+
self.assertEqual(row.asDict(), Row(column=2, value=2).asDict())
217+
206218
@staticmethod
207219
def _test_with_key(left, right, isLeft):
208220

python/pyspark/sql/tests/test_pandas_map.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,14 @@ def func(iterator):
112112
expected = df.collect()
113113
self.assertEqual(actual, expected)
114114

115+
def test_self_join(self):
116+
# SPARK-34319: self-join with MapInPandas
117+
df1 = self.spark.range(10)
118+
df2 = df1.mapInPandas(lambda iter: iter, 'id long')
119+
actual = df2.join(df2).collect()
120+
expected = df1.join(df1).collect()
121+
self.assertEqual(sorted(actual), sorted(expected))
122+
115123

116124
if __name__ == "__main__":
117125
from pyspark.sql.tests.test_pandas_map import * # noqa: F401

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1365,6 +1365,14 @@ class Analyzer(override val catalogManager: CatalogManager)
13651365
if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
13661366
Seq((oldVersion, oldVersion.copy(output = output.map(_.newInstance()))))
13671367

1368+
case oldVersion @ FlatMapCoGroupsInPandas(_, _, _, output, _, _)
1369+
if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
1370+
Seq((oldVersion, oldVersion.copy(output = output.map(_.newInstance()))))
1371+
1372+
case oldVersion @ MapInPandas(_, output, _)
1373+
if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
1374+
Seq((oldVersion, oldVersion.copy(output = output.map(_.newInstance()))))
1375+
13681376
case oldVersion: Generate
13691377
if oldVersion.producedAttributes.intersect(conflictingAttributes).nonEmpty =>
13701378
val newOutput = oldVersion.generatorOutput.map(_.newInstance())

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -631,6 +631,48 @@ class AnalysisSuite extends AnalysisTest with Matchers {
631631
Project(Seq(UnresolvedAttribute("temp0.a"), UnresolvedAttribute("temp1.a")), join))
632632
}
633633

634+
test("SPARK-34319: analysis fails on self-join with FlatMapCoGroupsInPandas") {
635+
val pythonUdf = PythonUDF("pyUDF", null,
636+
StructType(Seq(StructField("a", LongType))),
637+
Seq.empty,
638+
PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
639+
true)
640+
val output = pythonUdf.dataType.asInstanceOf[StructType].toAttributes
641+
val project1 = Project(Seq(UnresolvedAttribute("a")), testRelation)
642+
val project2 = Project(Seq(UnresolvedAttribute("a")), testRelation2)
643+
val flatMapGroupsInPandas = FlatMapCoGroupsInPandas(
644+
Seq(UnresolvedAttribute("a")),
645+
Seq(UnresolvedAttribute("a")),
646+
pythonUdf,
647+
output,
648+
project1,
649+
project2)
650+
val left = SubqueryAlias("temp0", flatMapGroupsInPandas)
651+
val right = SubqueryAlias("temp1", flatMapGroupsInPandas)
652+
val join = Join(left, right, Inner, None, JoinHint.NONE)
653+
assertAnalysisSuccess(
654+
Project(Seq(UnresolvedAttribute("temp0.a"), UnresolvedAttribute("temp1.a")), join))
655+
}
656+
657+
test("SPARK-34319: analysis fails on self-join with MapInPandas") {
658+
val pythonUdf = PythonUDF("pyUDF", null,
659+
StructType(Seq(StructField("a", LongType))),
660+
Seq.empty,
661+
PythonEvalType.SQL_MAP_PANDAS_ITER_UDF,
662+
true)
663+
val output = pythonUdf.dataType.asInstanceOf[StructType].toAttributes
664+
val project = Project(Seq(UnresolvedAttribute("a")), testRelation)
665+
val mapInPandas = MapInPandas(
666+
pythonUdf,
667+
output,
668+
project)
669+
val left = SubqueryAlias("temp0", mapInPandas)
670+
val right = SubqueryAlias("temp1", mapInPandas)
671+
val join = Join(left, right, Inner, None, JoinHint.NONE)
672+
assertAnalysisSuccess(
673+
Project(Seq(UnresolvedAttribute("temp0.a"), UnresolvedAttribute("temp1.a")), join))
674+
}
675+
634676
test("SPARK-24488 Generator with multiple aliases") {
635677
assertAnalysisSuccess(
636678
listRelation.select(Explode($"list").as("first_alias").as("second_alias")))

0 commit comments

Comments
 (0)