From 532d39fe7c98aafdc7d0b8806698027a221196a7 Mon Sep 17 00:00:00 2001 From: Wei Liu Date: Fri, 7 Jun 2024 20:37:01 -0700 Subject: [PATCH 1/2] done --- .../sql/connect/streaming/readwriter.py | 9 +++++++- python/pyspark/sql/streaming/query.py | 6 +++--- .../sql/tests/streaming/test_streaming.py | 21 +++++++++++++++++++ 3 files changed, 32 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql/connect/streaming/readwriter.py b/python/pyspark/sql/connect/streaming/readwriter.py index 4973bb5b6cf73..b5bb7f2a09128 100644 --- a/python/pyspark/sql/connect/streaming/readwriter.py +++ b/python/pyspark/sql/connect/streaming/readwriter.py @@ -446,6 +446,11 @@ def partitionBy(self, *cols: str) -> "DataStreamWriter": # type: ignore[misc] partitionBy.__doc__ = PySparkDataStreamWriter.partitionBy.__doc__ def queryName(self, queryName: str) -> "DataStreamWriter": + if not queryName or type(queryName) != str or len(queryName.strip()) == 0: + raise PySparkValueError( + error_class="VALUE_NOT_NON_EMPTY_STR", + message_parameters={"arg_name": "queryName", "arg_value": str(queryName)}, + ) self._write_proto.query_name = queryName return self @@ -605,7 +610,9 @@ def _start_internal( session=self._session, queryId=start_result.query_id.id, runId=start_result.query_id.run_id, - name=start_result.name, + # A Streaming Query cannot have empty string as name + # Spark throws error in that case, so this cast is safe + name=start_result.name if start_result.name != "" else None, ) if start_result.HasField("query_started_event_json"): diff --git a/python/pyspark/sql/streaming/query.py b/python/pyspark/sql/streaming/query.py index bcab8a104f1d9..d3d58da3562b6 100644 --- a/python/pyspark/sql/streaming/query.py +++ b/python/pyspark/sql/streaming/query.py @@ -114,7 +114,7 @@ def runId(self) -> str: @property def name(self) -> str: """ - Returns the user-specified name of the query, or null if not specified. + Returns the user-specified name of the query, or None if not specified. This name can be specified in the `org.apache.spark.sql.streaming.DataStreamWriter` as `dataframe.writeStream.queryName("query").start()`. This name, if set, must be unique across all active queries. @@ -127,14 +127,14 @@ def name(self) -> str: Returns ------- str - The user-specified name of the query, or null if not specified. + The user-specified name of the query, or None if not specified. Examples -------- >>> sdf = spark.readStream.format("rate").load() >>> sq = sdf.writeStream.format('memory').queryName('this_query').start() - Get the user-specified name of the query, or null if not specified. + Get the user-specified name of the query, or None if not specified. >>> sq.name 'this_query' diff --git a/python/pyspark/sql/tests/streaming/test_streaming.py b/python/pyspark/sql/tests/streaming/test_streaming.py index 1799f0d1336e5..dc1cfacc3f784 100644 --- a/python/pyspark/sql/tests/streaming/test_streaming.py +++ b/python/pyspark/sql/tests/streaming/test_streaming.py @@ -24,6 +24,7 @@ from pyspark.sql.functions import lit from pyspark.sql.types import StructType, StructField, IntegerType, StringType from pyspark.testing.sqlutils import ReusedSQLTestCase +from pyspark.errors import PySparkValueError class StreamingTestsMixin: @@ -58,6 +59,26 @@ def test_streaming_query_functions_basic(self): finally: query.stop() + def test_streaming_query_name_edge_case(self): + # Query name should be None when not specified + q1 = self.spark.readStream.format("rate").load().writeStream.format("noop").start() + self.assertEqual(q1.name, None) + + # Cannot set query name to be an empty string + error_thrown = False + try: + ( + self.spark.readStream.format("rate") + .load() + .writeStream.format("noop") + .queryName("") + .start() + ) + except PySparkValueError as e: + error_thrown = True + + self.assertTrue(error_thrown) + def test_stream_trigger(self): df = self.spark.readStream.format("text").load("python/test_support/sql/streaming") From cc5aa0aceda2635696ce6151ced057d98737c68e Mon Sep 17 00:00:00 2001 From: Wei Liu Date: Sun, 9 Jun 2024 11:18:17 -0700 Subject: [PATCH 2/2] lint --- python/pyspark/sql/tests/streaming/test_streaming.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/tests/streaming/test_streaming.py b/python/pyspark/sql/tests/streaming/test_streaming.py index dc1cfacc3f784..ea5ccb3630882 100644 --- a/python/pyspark/sql/tests/streaming/test_streaming.py +++ b/python/pyspark/sql/tests/streaming/test_streaming.py @@ -74,7 +74,7 @@ def test_streaming_query_name_edge_case(self): .queryName("") .start() ) - except PySparkValueError as e: + except PySparkValueError: error_thrown = True self.assertTrue(error_thrown)