Skip to content

Commit 828b1f9

Browse files
bogao007 authored and HeartSaVioR committed
[SPARK-49463] Support ListState for TransformWithStateInPandas
### What changes were proposed in this pull request?

Support ListState for TransformWithStateInPandas

### Why are the changes needed?

Adding new functionality for TransformWithStateInPandas

### Does this PR introduce _any_ user-facing change?

Yes

### How was this patch tested?

Added new unit tests.

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #47933 from bogao007/list-state.

Authored-by: bogao007 <[email protected]>
Signed-off-by: Jungtaek Lim <[email protected]>
1 parent 0c234bb commit 828b1f9

File tree

15 files changed

+5570
-645
lines changed

15 files changed

+5570
-645
lines changed

python/pyspark/sql/pandas/types.py

Lines changed: 36 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -53,12 +53,17 @@
5353
)
5454
from pyspark.errors import PySparkTypeError, UnsupportedOperationException, PySparkValueError
5555
from pyspark.loose_version import LooseVersion
56+
from pyspark.sql.utils import has_numpy
57+
58+
if has_numpy:
59+
import numpy as np
5660

5761
if TYPE_CHECKING:
5862
import pandas as pd
5963
import pyarrow as pa
6064

6165
from pyspark.sql.pandas._typing import SeriesLike as PandasSeriesLike
66+
from pyspark.sql.pandas._typing import DataFrameLike as PandasDataFrameLike
6267

6368

6469
def to_arrow_type(
@@ -1344,3 +1349,34 @@ def _deduplicate_field_names(dt: DataType) -> DataType:
13441349
)
13451350
else:
13461351
return dt
1352+
1353+
1354+
def _to_numpy_type(type: DataType) -> Optional["np.dtype"]:
    """Return the NumPy dtype that corresponds to a Spark numeric type.

    The following Spark types are mapped; anything else yields ``None``:

    * ``ByteType``    -> ``int8``
    * ``ShortType``   -> ``int16``
    * ``IntegerType`` -> ``int32``
    * ``LongType``    -> ``int64``
    * ``FloatType``   -> ``float32``
    * ``DoubleType``  -> ``float64``
    """
    import numpy as np

    # Ordered (Spark type class, NumPy dtype name) pairs. Each candidate is
    # instantiated and compared with ``==`` in this order, so the first
    # match wins.
    spark_to_numpy = (
        (ByteType, "int8"),
        (ShortType, "int16"),
        (IntegerType, "int32"),
        (LongType, "int64"),
        (FloatType, "float32"),
        (DoubleType, "float64"),
    )
    for spark_cls, dtype_name in spark_to_numpy:
        if type == spark_cls():
            return np.dtype(dtype_name)
    return None
1371+
1372+
1373+
def convert_pandas_using_numpy_type(
    df: "PandasDataFrameLike", schema: StructType
) -> "PandasDataFrameLike":
    """Cast the numeric columns of ``df`` to the NumPy dtypes implied by ``schema``.

    For every field in ``schema`` whose data type is a Byte, Short, Integer,
    Long, Float or Double type, the column of the same name in ``df`` is
    replaced by the result of ``astype`` with the dtype returned by
    ``_to_numpy_type``. Fields of any other type are left untouched.

    The columns are assigned back into ``df`` itself, and that same object
    is returned.
    """
    castable_types = (ByteType, ShortType, LongType, FloatType, DoubleType, IntegerType)
    for column in schema.fields:
        spark_type = column.dataType
        if not isinstance(spark_type, castable_types):
            # Not one of the fixed-width numeric types: keep the column as is.
            continue
        target_dtype = _to_numpy_type(spark_type)
        df[column.name] = df[column.name].astype(target_dtype)
    return df

python/pyspark/sql/streaming/StateMessage_pb2.py

Lines changed: 41 additions & 30 deletions
Original file line number | Diff line number | Diff line change
@@ -16,58 +16,69 @@
1616
#
1717
# -*- coding: utf-8 -*-
1818
# Generated by the protocol buffer compiler. DO NOT EDIT!
19+
# NO CHECKED-IN PROTOBUF GENCODE
1920
# source: StateMessage.proto
21+
# Protobuf Python Version: 5.27.3
2022
"""Generated protocol buffer code."""
21-
from google.protobuf.internal import builder as _builder
2223
from google.protobuf import descriptor as _descriptor
2324
from google.protobuf import descriptor_pool as _descriptor_pool
2425
from google.protobuf import symbol_database as _symbol_database
26+
from google.protobuf.internal import builder as _builder
2527

2628
# @@protoc_insertion_point(imports)
2729

2830
_sym_db = _symbol_database.Default()
2931

3032

3133
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
32-
b'\n\x12StateMessage.proto\x12.org.apache.spark.sql.execution.streaming.state"\xe9\x02\n\x0cStateRequest\x12\x0f\n\x07version\x18\x01 \x01(\x05\x12\x66\n\x15statefulProcessorCall\x18\x02 \x01(\x0b\x32\x45.org.apache.spark.sql.execution.streaming.state.StatefulProcessorCallH\x00\x12\x64\n\x14stateVariableRequest\x18\x03 \x01(\x0b\x32\x44.org.apache.spark.sql.execution.streaming.state.StateVariableRequestH\x00\x12p\n\x1aimplicitGroupingKeyRequest\x18\x04 \x01(\x0b\x32J.org.apache.spark.sql.execution.streaming.state.ImplicitGroupingKeyRequestH\x00\x42\x08\n\x06method"H\n\rStateResponse\x12\x12\n\nstatusCode\x18\x01 \x01(\x05\x12\x14\n\x0c\x65rrorMessage\x18\x02 \x01(\t\x12\r\n\x05value\x18\x03 \x01(\x0c"\x89\x03\n\x15StatefulProcessorCall\x12X\n\x0esetHandleState\x18\x01 \x01(\x0b\x32>.org.apache.spark.sql.execution.streaming.state.SetHandleStateH\x00\x12Y\n\rgetValueState\x18\x02 \x01(\x0b\x32@.org.apache.spark.sql.execution.streaming.state.StateCallCommandH\x00\x12X\n\x0cgetListState\x18\x03 \x01(\x0b\x32@.org.apache.spark.sql.execution.streaming.state.StateCallCommandH\x00\x12W\n\x0bgetMapState\x18\x04 \x01(\x0b\x32@.org.apache.spark.sql.execution.streaming.state.StateCallCommandH\x00\x42\x08\n\x06method"z\n\x14StateVariableRequest\x12X\n\x0evalueStateCall\x18\x01 \x01(\x0b\x32>.org.apache.spark.sql.execution.streaming.state.ValueStateCallH\x00\x42\x08\n\x06method"\xe0\x01\n\x1aImplicitGroupingKeyRequest\x12X\n\x0esetImplicitKey\x18\x01 \x01(\x0b\x32>.org.apache.spark.sql.execution.streaming.state.SetImplicitKeyH\x00\x12^\n\x11removeImplicitKey\x18\x02 \x01(\x0b\x32\x41.org.apache.spark.sql.execution.streaming.state.RemoveImplicitKeyH\x00\x42\x08\n\x06method"}\n\x10StateCallCommand\x12\x11\n\tstateName\x18\x01 \x01(\t\x12\x0e\n\x06schema\x18\x02 \x01(\t\x12\x46\n\x03ttl\x18\x03 \x01(\x0b\x32\x39.org.apache.spark.sql.execution.streaming.state.TTLConfig"\xe1\x02\n\x0eValueStateCall\x12\x11\n\tstateName\x18\x01 \x01(\t\x12H\n\x06\x65xists\x18\x02 
\x01(\x0b\x32\x36.org.apache.spark.sql.execution.streaming.state.ExistsH\x00\x12\x42\n\x03get\x18\x03 \x01(\x0b\x32\x33.org.apache.spark.sql.execution.streaming.state.GetH\x00\x12\\\n\x10valueStateUpdate\x18\x04 \x01(\x0b\x32@.org.apache.spark.sql.execution.streaming.state.ValueStateUpdateH\x00\x12\x46\n\x05\x63lear\x18\x05 \x01(\x0b\x32\x35.org.apache.spark.sql.execution.streaming.state.ClearH\x00\x42\x08\n\x06method"\x1d\n\x0eSetImplicitKey\x12\x0b\n\x03key\x18\x01 \x01(\x0c"\x13\n\x11RemoveImplicitKey"\x08\n\x06\x45xists"\x05\n\x03Get"!\n\x10ValueStateUpdate\x12\r\n\x05value\x18\x01 \x01(\x0c"\x07\n\x05\x43lear"\\\n\x0eSetHandleState\x12J\n\x05state\x18\x01 \x01(\x0e\x32;.org.apache.spark.sql.execution.streaming.state.HandleState"\x1f\n\tTTLConfig\x12\x12\n\ndurationMs\x18\x01 \x01(\x05*K\n\x0bHandleState\x12\x0b\n\x07\x43REATED\x10\x00\x12\x0f\n\x0bINITIALIZED\x10\x01\x12\x12\n\x0e\x44\x41TA_PROCESSED\x10\x02\x12\n\n\x06\x43LOSED\x10\x03\x62\x06proto3' # noqa: E501
34+
b'\n\x12StateMessage.proto\x12.org.apache.spark.sql.execution.streaming.state"\xe9\x02\n\x0cStateRequest\x12\x0f\n\x07version\x18\x01 \x01(\x05\x12\x66\n\x15statefulProcessorCall\x18\x02 \x01(\x0b\x32\x45.org.apache.spark.sql.execution.streaming.state.StatefulProcessorCallH\x00\x12\x64\n\x14stateVariableRequest\x18\x03 \x01(\x0b\x32\x44.org.apache.spark.sql.execution.streaming.state.StateVariableRequestH\x00\x12p\n\x1aimplicitGroupingKeyRequest\x18\x04 \x01(\x0b\x32J.org.apache.spark.sql.execution.streaming.state.ImplicitGroupingKeyRequestH\x00\x42\x08\n\x06method"H\n\rStateResponse\x12\x12\n\nstatusCode\x18\x01 \x01(\x05\x12\x14\n\x0c\x65rrorMessage\x18\x02 \x01(\t\x12\r\n\x05value\x18\x03 \x01(\x0c"\x89\x03\n\x15StatefulProcessorCall\x12X\n\x0esetHandleState\x18\x01 \x01(\x0b\x32>.org.apache.spark.sql.execution.streaming.state.SetHandleStateH\x00\x12Y\n\rgetValueState\x18\x02 \x01(\x0b\x32@.org.apache.spark.sql.execution.streaming.state.StateCallCommandH\x00\x12X\n\x0cgetListState\x18\x03 \x01(\x0b\x32@.org.apache.spark.sql.execution.streaming.state.StateCallCommandH\x00\x12W\n\x0bgetMapState\x18\x04 \x01(\x0b\x32@.org.apache.spark.sql.execution.streaming.state.StateCallCommandH\x00\x42\x08\n\x06method"\xd2\x01\n\x14StateVariableRequest\x12X\n\x0evalueStateCall\x18\x01 \x01(\x0b\x32>.org.apache.spark.sql.execution.streaming.state.ValueStateCallH\x00\x12V\n\rlistStateCall\x18\x02 \x01(\x0b\x32=.org.apache.spark.sql.execution.streaming.state.ListStateCallH\x00\x42\x08\n\x06method"\xe0\x01\n\x1aImplicitGroupingKeyRequest\x12X\n\x0esetImplicitKey\x18\x01 \x01(\x0b\x32>.org.apache.spark.sql.execution.streaming.state.SetImplicitKeyH\x00\x12^\n\x11removeImplicitKey\x18\x02 \x01(\x0b\x32\x41.org.apache.spark.sql.execution.streaming.state.RemoveImplicitKeyH\x00\x42\x08\n\x06method"}\n\x10StateCallCommand\x12\x11\n\tstateName\x18\x01 \x01(\t\x12\x0e\n\x06schema\x18\x02 \x01(\t\x12\x46\n\x03ttl\x18\x03 
\x01(\x0b\x32\x39.org.apache.spark.sql.execution.streaming.state.TTLConfig"\xe1\x02\n\x0eValueStateCall\x12\x11\n\tstateName\x18\x01 \x01(\t\x12H\n\x06\x65xists\x18\x02 \x01(\x0b\x32\x36.org.apache.spark.sql.execution.streaming.state.ExistsH\x00\x12\x42\n\x03get\x18\x03 \x01(\x0b\x32\x33.org.apache.spark.sql.execution.streaming.state.GetH\x00\x12\\\n\x10valueStateUpdate\x18\x04 \x01(\x0b\x32@.org.apache.spark.sql.execution.streaming.state.ValueStateUpdateH\x00\x12\x46\n\x05\x63lear\x18\x05 \x01(\x0b\x32\x35.org.apache.spark.sql.execution.streaming.state.ClearH\x00\x42\x08\n\x06method"\x90\x04\n\rListStateCall\x12\x11\n\tstateName\x18\x01 \x01(\t\x12H\n\x06\x65xists\x18\x02 \x01(\x0b\x32\x36.org.apache.spark.sql.execution.streaming.state.ExistsH\x00\x12T\n\x0clistStateGet\x18\x03 \x01(\x0b\x32<.org.apache.spark.sql.execution.streaming.state.ListStateGetH\x00\x12T\n\x0clistStatePut\x18\x04 \x01(\x0b\x32<.org.apache.spark.sql.execution.streaming.state.ListStatePutH\x00\x12R\n\x0b\x61ppendValue\x18\x05 \x01(\x0b\x32;.org.apache.spark.sql.execution.streaming.state.AppendValueH\x00\x12P\n\nappendList\x18\x06 \x01(\x0b\x32:.org.apache.spark.sql.execution.streaming.state.AppendListH\x00\x12\x46\n\x05\x63lear\x18\x07 \x01(\x0b\x32\x35.org.apache.spark.sql.execution.streaming.state.ClearH\x00\x42\x08\n\x06method"\x1d\n\x0eSetImplicitKey\x12\x0b\n\x03key\x18\x01 \x01(\x0c"\x13\n\x11RemoveImplicitKey"\x08\n\x06\x45xists"\x05\n\x03Get"!\n\x10ValueStateUpdate\x12\r\n\x05value\x18\x01 \x01(\x0c"\x07\n\x05\x43lear""\n\x0cListStateGet\x12\x12\n\niteratorId\x18\x01 \x01(\t"\x0e\n\x0cListStatePut"\x1c\n\x0b\x41ppendValue\x12\r\n\x05value\x18\x01 \x01(\x0c"\x0c\n\nAppendList"\\\n\x0eSetHandleState\x12J\n\x05state\x18\x01 \x01(\x0e\x32;.org.apache.spark.sql.execution.streaming.state.HandleState"\x1f\n\tTTLConfig\x12\x12\n\ndurationMs\x18\x01 
\x01(\x05*K\n\x0bHandleState\x12\x0b\n\x07\x43REATED\x10\x00\x12\x0f\n\x0bINITIALIZED\x10\x01\x12\x12\n\x0e\x44\x41TA_PROCESSED\x10\x02\x12\n\n\x06\x43LOSED\x10\x03\x62\x06proto3' # noqa: E501
3335
)
3436

3537
_globals = globals()
36-
3738
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
3839
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "StateMessage_pb2", _globals)
3940
if not _descriptor._USE_C_DESCRIPTORS:
40-
DESCRIPTOR._options = None
41-
_globals["_HANDLESTATE"]._serialized_start = 1978
42-
_globals["_HANDLESTATE"]._serialized_end = 2053
41+
DESCRIPTOR._loaded_options = None
42+
_globals["_HANDLESTATE"]._serialized_start = 2694
43+
_globals["_HANDLESTATE"]._serialized_end = 2769
4344
_globals["_STATEREQUEST"]._serialized_start = 71
4445
_globals["_STATEREQUEST"]._serialized_end = 432
4546
_globals["_STATERESPONSE"]._serialized_start = 434
4647
_globals["_STATERESPONSE"]._serialized_end = 506
4748
_globals["_STATEFULPROCESSORCALL"]._serialized_start = 509
4849
_globals["_STATEFULPROCESSORCALL"]._serialized_end = 902
49-
_globals["_STATEVARIABLEREQUEST"]._serialized_start = 904
50-
_globals["_STATEVARIABLEREQUEST"]._serialized_end = 1026
51-
_globals["_IMPLICITGROUPINGKEYREQUEST"]._serialized_start = 1029
52-
_globals["_IMPLICITGROUPINGKEYREQUEST"]._serialized_end = 1253
53-
_globals["_STATECALLCOMMAND"]._serialized_start = 1255
54-
_globals["_STATECALLCOMMAND"]._serialized_end = 1380
55-
_globals["_VALUESTATECALL"]._serialized_start = 1383
56-
_globals["_VALUESTATECALL"]._serialized_end = 1736
57-
_globals["_SETIMPLICITKEY"]._serialized_start = 1738
58-
_globals["_SETIMPLICITKEY"]._serialized_end = 1767
59-
_globals["_REMOVEIMPLICITKEY"]._serialized_start = 1769
60-
_globals["_REMOVEIMPLICITKEY"]._serialized_end = 1788
61-
_globals["_EXISTS"]._serialized_start = 1790
62-
_globals["_EXISTS"]._serialized_end = 1798
63-
_globals["_GET"]._serialized_start = 1800
64-
_globals["_GET"]._serialized_end = 1805
65-
_globals["_VALUESTATEUPDATE"]._serialized_start = 1807
66-
_globals["_VALUESTATEUPDATE"]._serialized_end = 1840
67-
_globals["_CLEAR"]._serialized_start = 1842
68-
_globals["_CLEAR"]._serialized_end = 1849
69-
_globals["_SETHANDLESTATE"]._serialized_start = 1851
70-
_globals["_SETHANDLESTATE"]._serialized_end = 1943
71-
_globals["_TTLCONFIG"]._serialized_start = 1945
72-
_globals["_TTLCONFIG"]._serialized_end = 1976
50+
_globals["_STATEVARIABLEREQUEST"]._serialized_start = 905
51+
_globals["_STATEVARIABLEREQUEST"]._serialized_end = 1115
52+
_globals["_IMPLICITGROUPINGKEYREQUEST"]._serialized_start = 1118
53+
_globals["_IMPLICITGROUPINGKEYREQUEST"]._serialized_end = 1342
54+
_globals["_STATECALLCOMMAND"]._serialized_start = 1344
55+
_globals["_STATECALLCOMMAND"]._serialized_end = 1469
56+
_globals["_VALUESTATECALL"]._serialized_start = 1472
57+
_globals["_VALUESTATECALL"]._serialized_end = 1825
58+
_globals["_LISTSTATECALL"]._serialized_start = 1828
59+
_globals["_LISTSTATECALL"]._serialized_end = 2356
60+
_globals["_SETIMPLICITKEY"]._serialized_start = 2358
61+
_globals["_SETIMPLICITKEY"]._serialized_end = 2387
62+
_globals["_REMOVEIMPLICITKEY"]._serialized_start = 2389
63+
_globals["_REMOVEIMPLICITKEY"]._serialized_end = 2408
64+
_globals["_EXISTS"]._serialized_start = 2410
65+
_globals["_EXISTS"]._serialized_end = 2418
66+
_globals["_GET"]._serialized_start = 2420
67+
_globals["_GET"]._serialized_end = 2425
68+
_globals["_VALUESTATEUPDATE"]._serialized_start = 2427
69+
_globals["_VALUESTATEUPDATE"]._serialized_end = 2460
70+
_globals["_CLEAR"]._serialized_start = 2462
71+
_globals["_CLEAR"]._serialized_end = 2469
72+
_globals["_LISTSTATEGET"]._serialized_start = 2471
73+
_globals["_LISTSTATEGET"]._serialized_end = 2505
74+
_globals["_LISTSTATEPUT"]._serialized_start = 2507
75+
_globals["_LISTSTATEPUT"]._serialized_end = 2521
76+
_globals["_APPENDVALUE"]._serialized_start = 2523
77+
_globals["_APPENDVALUE"]._serialized_end = 2551
78+
_globals["_APPENDLIST"]._serialized_start = 2553
79+
_globals["_APPENDLIST"]._serialized_end = 2565
80+
_globals["_SETHANDLESTATE"]._serialized_start = 2567
81+
_globals["_SETHANDLESTATE"]._serialized_end = 2659
82+
_globals["_TTLCONFIG"]._serialized_start = 2661
83+
_globals["_TTLCONFIG"]._serialized_end = 2692
7384
# @@protoc_insertion_point(module_scope)

0 commit comments

Comments
 (0)