Commit 25892a7

[SPARK-52949][PYTHON] Avoid roundtrip between RecordBatch and Table in Arrow-optimized Python UDTF

### What changes were proposed in this pull request?

Avoids the roundtrip between `RecordBatch` and `Table` in Arrow-optimized Python UDTF.

### Why are the changes needed?

In the Arrow-optimized Python UDTF code path, there are unnecessary roundtrips between `RecordBatch` and `Table`. We can defer converting to `RecordBatch` until after the result verification.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

The existing tests.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #51659 from ueshin/issues/SPARK-52949/arrow_udtf.

Authored-by: Takuya Ueshin <[email protected]>
Signed-off-by: Takuya Ueshin <[email protected]>
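To illustrate the change, here is a minimal standalone sketch of the pattern being removed versus the one being introduced. It is not code from the patch: the schema, the sample rows, and the `make_batches_*` helpers are invented for the example, and the `assert` stands in for the real `verify_arrow_result` call.

```python
import pyarrow as pa

schema = pa.schema([("id", pa.int64()), ("name", pa.string())])
rows = [{"id": 1, "name": "a"}, {"id": 2, "name": "b"}]

def make_batches_old(rows):
    # Before: convert to RecordBatch up front, then rebuild a Table
    # only so the verification step can inspect it.
    batch = pa.RecordBatch.from_pylist(rows, schema=schema)
    table = pa.Table.from_batches([batch], schema=schema)  # extra roundtrip
    assert table.schema.equals(schema)  # stand-in for verify_arrow_result
    return [batch]

def make_batches_new(rows):
    # After: stay on pa.Table, verify it directly, and defer the
    # RecordBatch conversion to a single to_batches() call at the end.
    table = pa.Table.from_pylist(rows, schema=schema)
    assert table.schema.equals(schema)  # stand-in for verify_arrow_result
    return table.to_batches()

# Both paths yield the same batches; the new one skips a conversion.
assert [b.to_pylist() for b in make_batches_old(rows)] == [
    b.to_pylist() for b in make_batches_new(rows)
]
```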
1 parent: 8b6b28b

File tree: 1 file changed (+13, −15)

python/pyspark/worker.py

Lines changed: 13 additions & 15 deletions
```diff
@@ -1638,7 +1638,7 @@ def wrap_arrow_udtf(f, return_type):
     import pandas as pd
 
     arrow_return_type = to_arrow_type(
-        return_type, prefers_large_types=use_large_var_types(runner_conf)
+        return_type, prefers_large_types=prefers_large_var_types
     )
     return_type_size = len(return_type)
 
@@ -1757,12 +1757,12 @@ def wrap_arrow_udtf(f, return_type):
     import pyarrow as pa
 
     arrow_return_type = to_arrow_type(
-        return_type, prefers_large_types=use_large_var_types(runner_conf)
+        return_type, prefers_large_types=prefers_large_var_types
     )
     return_type_size = len(return_type)
 
     def verify_result(result):
-        if not isinstance(result, pa.RecordBatch):
+        if not isinstance(result, pa.Table):
             raise PySparkTypeError(
                 errorClass="INVALID_ARROW_UDTF_RETURN_TYPE",
                 messageParameters={
@@ -1776,20 +1776,20 @@ def verify_result(result):
         # rows or columns. Note that we avoid using `df.empty` here because the
         # result dataframe may contain an empty row. For example, when a UDTF is
         # defined as follows: def eval(self): yield tuple().
-        if len(result) > 0 or len(result.columns) > 0:
-            if len(result.columns) != return_type_size:
+        if result.num_rows > 0 or result.num_columns > 0:
+            if result.num_columns != return_type_size:
                 raise PySparkRuntimeError(
                     errorClass="UDTF_RETURN_SCHEMA_MISMATCH",
                     messageParameters={
                         "expected": str(return_type_size),
-                        "actual": str(len(result.columns)),
+                        "actual": str(result.num_columns),
                         "func": f.__name__,
                     },
                 )
 
             # Verify the type and the schema of the result.
             verify_arrow_result(
-                pa.Table.from_batches([result], schema=pa.schema(list(arrow_return_type))),
+                result,
                 assign_cols_by_name=False,
                 expected_cols_and_types=[
                     (field.name, field.type) for field in arrow_return_type
@@ -1832,9 +1832,7 @@ def check_return_value(res):
     def convert_to_arrow(data: Iterable):
         data = list(check_return_value(data))
         if len(data) == 0:
-            return [
-                pa.RecordBatch.from_pylist(data, schema=pa.schema(list(arrow_return_type)))
-            ]
+            return pa.Table.from_pylist(data, schema=pa.schema(list(arrow_return_type)))
 
     def raise_conversion_error(original_exception):
         raise PySparkRuntimeError(
@@ -1849,7 +1847,7 @@ def raise_conversion_error(original_exception):
         try:
             return LocalDataToArrowConversion.convert(
                 data, return_type, prefers_large_var_types
-            ).to_batches()
+            )
         except PySparkValueError as e:
             if e.getErrorClass() == "AXIS_LENGTH_MISMATCH":
                 raise PySparkRuntimeError(
@@ -1871,8 +1869,8 @@ def raise_conversion_error(original_exception):
 
     def evaluate(*args: pa.ChunkedArray):
         if len(args) == 0:
-            for batch in convert_to_arrow(func()):
-                yield verify_result(batch), arrow_return_type
+            for batch in verify_result(convert_to_arrow(func())).to_batches():
+                yield batch, arrow_return_type
 
         else:
             list_args = list(args)
@@ -1883,8 +1881,8 @@
                 t, schema=schema, return_as_tuples=True
             )
             for row in rows:
-                for batch in convert_to_arrow(func(*row)):
-                    yield verify_result(batch), arrow_return_type
+                for batch in verify_result(convert_to_arrow(func(*row))).to_batches():
+                    yield batch, arrow_return_type
 
     return evaluate
```
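For readers less familiar with the pyarrow surface this patch touches, a short hedged illustration of the properties and the conversion the new code relies on; the schema and rows below are invented for the example.

```python
import pyarrow as pa

schema = pa.schema([("id", pa.int64()), ("name", pa.string())])
table = pa.Table.from_pylist([{"id": 1, "name": "a"}], schema=schema)

# pa.Table exposes num_rows/num_columns directly, which is why the
# verification switches away from len(result)/len(result.columns)
# (the RecordBatch-oriented spellings).
assert (table.num_rows, table.num_columns) == (1, 2)

# The RecordBatch conversion now happens once, after verification:
# to_batches() slices the verified Table into pa.RecordBatch objects.
batches = table.to_batches()
assert all(isinstance(b, pa.RecordBatch) for b in batches)
assert sum(b.num_rows for b in batches) == table.num_rows
```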