
Commit f92d276

viirya authored and HyukjinKwon committed
[SPARK-25811][PYSPARK] Raise a proper error when unsafe cast is detected by PyArrow
## What changes were proposed in this pull request?

Since 0.11.0, PyArrow supports raising an error for unsafe casts ([PR](apache/arrow#2504)). We should use it to raise a proper error for pandas UDF users when such a cast is detected. This patch adds a SQL config `spark.sql.execution.pandas.arrowSafeTypeConversion` to enable or disable the Arrow safe type check.

## How was this patch tested?

Added tests and tested manually.

Closes #22807 from viirya/SPARK-25811.

Authored-by: Liang-Chi Hsieh <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
1 parent 098a2c4 commit f92d276
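
The user-facing effect is easiest to see with a pandas UDF whose result does not fit the declared return type. The sketch below is adapted from the tests added in this patch; it assumes an active `SparkSession` named `spark` and PyArrow 0.11.0 or newer installed alongside PySpark.

```python
import pandas as pd
from pyspark.sql.functions import pandas_udf

df = spark.range(1)  # assumes an existing SparkSession named `spark`


@pandas_udf(returnType="byte")
def overflow_udf(column):
    # 128 does not fit into a signed byte.
    return pd.Series([128] * len(column))


# With the safe check enabled, collecting fails with an error whose message
# contains "Exception thrown when converting pandas.Series".
spark.conf.set("spark.sql.execution.pandas.arrowSafeTypeConversion", "true")
# df.withColumn("b", overflow_udf("id")).collect()  # raises

# With the check disabled (the default), the unsafe cast is silently allowed.
spark.conf.set("spark.sql.execution.pandas.arrowSafeTypeConversion", "false")
df.withColumn("b", overflow_udf("id")).collect()
```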

File tree

7 files changed (+141 / -7 lines)


docs/sql-migration-guide-upgrade.md

Lines changed: 48 additions & 0 deletions
@@ -41,6 +41,54 @@ displayTitle: Spark SQL Upgrading Guide
 
   - Since Spark 3.0, JSON datasource and JSON function `schema_of_json` infer TimestampType from string values if they match the pattern defined by the JSON option `timestampFormat`. Set the JSON option `inferTimestamp` to `false` to disable such type inference.
 
+  - In PySpark, when Arrow optimization is enabled and the installed PyArrow version is 0.11.0 or higher, Arrow can perform safe type conversion when converting pandas.Series to an Arrow array during serialization. Arrow raises an error when it detects an unsafe type conversion such as an overflow. Setting `spark.sql.execution.pandas.arrowSafeTypeConversion` to true enables this check; the default is false. PySpark's behavior across PyArrow versions is illustrated in the table below:
+  <table class="table">
+    <tr>
+      <th><b>PyArrow version</b></th>
+      <th><b>Integer Overflow</b></th>
+      <th><b>Floating Point Truncation</b></th>
+    </tr>
+    <tr>
+      <th><b>version < 0.11.0</b></th>
+      <th><b>Raise error</b></th>
+      <th><b>Silently allows</b></th>
+    </tr>
+    <tr>
+      <th><b>version >= 0.11.0, arrowSafeTypeConversion=false</b></th>
+      <th><b>Silent overflow</b></th>
+      <th><b>Silently allows</b></th>
+    </tr>
+    <tr>
+      <th><b>version >= 0.11.0, arrowSafeTypeConversion=true</b></th>
+      <th><b>Raise error</b></th>
+      <th><b>Raise error</b></th>
+    </tr>
+  </table>
+
   - In Spark version 2.4 and earlier, if `org.apache.spark.sql.functions.udf(Any, DataType)` gets a Scala closure with a primitive-type argument, the returned UDF returns null if the input value is null. Since Spark 3.0, the UDF returns the default value of the Java type if the input value is null. For example, with `val f = udf((x: Int) => x, IntegerType)`, `f($"x")` returns null in Spark 2.4 and earlier if column `x` is null, and returns 0 in Spark 3.0. This behavior change is introduced because Spark 3.0 is built with Scala 2.12 by default.
 
 ## Upgrading From Spark SQL 2.3 to 2.4
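
To make the new migration-guide entry concrete, the following is a raw PyArrow sketch (outside Spark) of the two conversion failure modes listed in the table above. It assumes PyArrow 0.11.0 or newer, where `pa.Array.from_pandas` accepts the `safe` keyword; the variable names are illustrative.

```python
import pandas as pd
import pyarrow as pa

ints = pd.Series([128])          # integer overflow: 128 does not fit in int8
floats = pd.Series([0.5, 1.5])   # floating point truncation when cast to int

# safe=True: both conversions are rejected with an ArrowException subclass.
for series, target in [(ints, pa.int8()), (floats, pa.int64())]:
    try:
        pa.Array.from_pandas(series, type=target, safe=True)
    except pa.ArrowException as exc:
        print("rejected:", exc)

# safe=False: both conversions are silently allowed (values overflow or truncate).
print(pa.Array.from_pandas(ints, type=pa.int8(), safe=False))
print(pa.Array.from_pandas(floats, type=pa.int64(), safe=False))
```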

python/pyspark/serializers.py

Lines changed: 15 additions & 4 deletions
@@ -245,7 +245,7 @@ def __repr__(self):
         return "ArrowStreamSerializer"
 
 
-def _create_batch(series, timezone):
+def _create_batch(series, timezone, safecheck):
     """
     Create an Arrow record batch from the given pandas.Series or list of Series, with optional type.
 
@@ -284,7 +284,17 @@ def create_array(s, t):
         elif LooseVersion(pa.__version__) < LooseVersion("0.11.0"):
             # TODO: see ARROW-1949. Remove when the minimum PyArrow version becomes 0.11.0.
             return pa.Array.from_pandas(s, mask=mask, type=t)
-        return pa.Array.from_pandas(s, mask=mask, type=t, safe=False)
+
+        try:
+            array = pa.Array.from_pandas(s, mask=mask, type=t, safe=safecheck)
+        except pa.ArrowException as e:
+            error_msg = "Exception thrown when converting pandas.Series (%s) to Arrow " + \
+                        "Array (%s). It can be caused by overflows or other unsafe " + \
+                        "conversions warned by Arrow. Arrow safe type check can be " + \
+                        "disabled by using SQL config " + \
+                        "`spark.sql.execution.pandas.arrowSafeTypeConversion`."
+            raise RuntimeError(error_msg % (s.dtype, t), e)
+        return array
 
     arrs = [create_array(s, t) for s, t in series]
     return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in xrange(len(arrs))])
@@ -295,9 +305,10 @@ class ArrowStreamPandasSerializer(Serializer):
     Serializes Pandas.Series as Arrow data with Arrow streaming format.
     """
 
-    def __init__(self, timezone):
+    def __init__(self, timezone, safecheck):
         super(ArrowStreamPandasSerializer, self).__init__()
         self._timezone = timezone
+        self._safecheck = safecheck
 
     def arrow_to_pandas(self, arrow_column):
         from pyspark.sql.types import from_arrow_type, \
@@ -317,7 +328,7 @@ def dump_stream(self, iterator, stream):
         writer = None
         try:
             for series in iterator:
-                batch = _create_batch(series, self._timezone)
+                batch = _create_batch(series, self._timezone, self._safecheck)
                 if writer is None:
                     write_int(SpecialLengths.START_ARROW_STREAM, stream)
                     writer = pa.RecordBatchStreamWriter(stream, batch.schema)
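
For readers who want to experiment with the new conversion path outside of Spark, here is a standalone sketch of what `create_array` now does. The function name `convert_series` and its arguments are illustrative, not part of the patch.

```python
from distutils.version import LooseVersion

import pyarrow as pa


def convert_series(s, t, mask=None, safecheck=False):
    """Convert a pandas.Series `s` to an Arrow array of type `t`, mirroring
    the version-gated logic added to _create_batch above."""
    if LooseVersion(pa.__version__) < LooseVersion("0.11.0"):
        # Older PyArrow has no `safe` keyword; see ARROW-1949.
        return pa.Array.from_pandas(s, mask=mask, type=t)
    try:
        return pa.Array.from_pandas(s, mask=mask, type=t, safe=safecheck)
    except pa.ArrowException as e:
        raise RuntimeError(
            "Exception thrown when converting pandas.Series (%s) to Arrow Array (%s). "
            "It can be caused by overflows or other unsafe conversions warned by Arrow. "
            "The Arrow safe type check can be disabled by the SQL config "
            "`spark.sql.execution.pandas.arrowSafeTypeConversion`." % (s.dtype, t), e)
```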

python/pyspark/sql/session.py

Lines changed: 2 additions & 1 deletion
@@ -556,8 +556,9 @@ def _create_from_pandas_with_arrow(self, pdf, schema, timezone):
         pdf_slices = (pdf[start:start + step] for start in xrange(0, len(pdf), step))
 
         # Create Arrow record batches
+        safecheck = self._wrapped._conf.arrowSafeTypeConversion()
         batches = [_create_batch([(c, t) for (_, c), t in zip(pdf_slice.iteritems(), arrow_types)],
-                                 timezone)
+                                 timezone, safecheck)
                    for pdf_slice in pdf_slices]
 
         # Create the Spark schema from the first Arrow batch (always at least 1 batch after slicing)

python/pyspark/sql/tests/test_pandas_udf.py

Lines changed: 58 additions & 0 deletions
@@ -197,6 +197,64 @@ def foofoo(x, y):
             ).collect
         )
 
+    def test_pandas_udf_detect_unsafe_type_conversion(self):
+        from distutils.version import LooseVersion
+        import pandas as pd
+        import numpy as np
+        import pyarrow as pa
+
+        values = [1.0] * 3
+        pdf = pd.DataFrame({'A': values})
+        df = self.spark.createDataFrame(pdf).repartition(1)
+
+        @pandas_udf(returnType="int")
+        def udf(column):
+            return pd.Series(np.linspace(0, 1, 3))
+
+        # Since 0.11.0, PyArrow supports raising an error for unsafe casts.
+        if LooseVersion(pa.__version__) >= LooseVersion("0.11.0"):
+            with self.sql_conf({
+                    "spark.sql.execution.pandas.arrowSafeTypeConversion": True}):
+                with self.assertRaisesRegexp(Exception,
+                                             "Exception thrown when converting pandas.Series"):
+                    df.select(['A']).withColumn('udf', udf('A')).collect()
+
+        # Disabling the Arrow safe type check.
+        with self.sql_conf({
+                "spark.sql.execution.pandas.arrowSafeTypeConversion": False}):
+            df.select(['A']).withColumn('udf', udf('A')).collect()
+
+    def test_pandas_udf_arrow_overflow(self):
+        from distutils.version import LooseVersion
+        import pandas as pd
+        import pyarrow as pa
+
+        df = self.spark.range(0, 1)
+
+        @pandas_udf(returnType="byte")
+        def udf(column):
+            return pd.Series([128])
+
+        # Arrow 0.11.0+ allows enabling or disabling the safe type check.
+        if LooseVersion(pa.__version__) >= LooseVersion("0.11.0"):
+            # When the safe type check is enabled, Arrow 0.11.0+ disallows an overflowing cast.
+            with self.sql_conf({
+                    "spark.sql.execution.pandas.arrowSafeTypeConversion": True}):
+                with self.assertRaisesRegexp(Exception,
+                                             "Exception thrown when converting pandas.Series"):
+                    df.withColumn('udf', udf('id')).collect()
+
+            # Disabling the safe type check lets Arrow do the cast anyway.
+            with self.sql_conf({"spark.sql.execution.pandas.arrowSafeTypeConversion": False}):
+                df.withColumn('udf', udf('id')).collect()
+        else:
+            # The SQL config `arrowSafeTypeConversion` does not matter for older Arrow;
+            # the overflowing cast causes an error regardless.
+            with self.sql_conf({"spark.sql.execution.pandas.arrowSafeTypeConversion": False}):
+                with self.assertRaisesRegexp(Exception,
+                                             "Integer value out of bounds"):
+                    df.withColumn('udf', udf('id')).collect()
+
 
 if __name__ == "__main__":
     from pyspark.sql.tests.test_pandas_udf import *

python/pyspark/worker.py

Lines changed: 3 additions & 1 deletion
@@ -252,7 +252,9 @@ def read_udfs(pickleSer, infile, eval_type):
 
         # NOTE: if timezone is set here, that implies respectSessionTimeZone is True
         timezone = runner_conf.get("spark.sql.session.timeZone", None)
-        ser = ArrowStreamPandasSerializer(timezone)
+        safecheck = runner_conf.get("spark.sql.execution.pandas.arrowSafeTypeConversion",
+                                    "false").lower() == 'true'
+        ser = ArrowStreamPandasSerializer(timezone, safecheck)
     else:
         ser = BatchedSerializer(PickleSerializer(), 100)
 

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 12 additions & 0 deletions
@@ -1324,6 +1324,16 @@ object SQLConf {
     .booleanConf
     .createWithDefault(true)
 
+  val PANDAS_ARROW_SAFE_TYPE_CONVERSION =
+    buildConf("spark.sql.execution.pandas.arrowSafeTypeConversion")
+      .internal()
+      .doc("When true, Arrow will perform safe type conversion when converting " +
+        "Pandas.Series to Arrow array during serialization. Arrow will raise errors " +
+        "when detecting unsafe type conversions like overflow. When false, Arrow's type " +
+        "check is disabled and type conversions are done anyway. This config only works " +
+        "for Arrow 0.11.0+.")
+      .booleanConf
+      .createWithDefault(false)
+
   val REPLACE_EXCEPT_WITH_FILTER = buildConf("spark.sql.optimizer.replaceExceptWithFilter")
     .internal()
     .doc("When true, the apply function of the rule verifies whether the right node of the" +
@@ -1998,6 +2008,8 @@ class SQLConf extends Serializable with Logging {
   def pandasGroupedMapAssignColumnsByName: Boolean =
     getConf(SQLConf.PANDAS_GROUPED_MAP_ASSIGN_COLUMNS_BY_NAME)
 
+  def arrowSafeTypeConversion: Boolean = getConf(SQLConf.PANDAS_ARROW_SAFE_TYPE_CONVERSION)
+
   def replaceExceptWithFilter: Boolean = getConf(REPLACE_EXCEPT_WITH_FILTER)
 
   def decimalOperationsAllowPrecisionLoss: Boolean = getConf(DECIMAL_OPERATIONS_ALLOW_PREC_LOSS)
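
Since the new entry is an internal config that is shipped to the Python workers per query (through the conf map built in ArrowUtils.scala below), it can be set at session build time or toggled at runtime. A minimal PySpark sketch, assuming PySpark with PyArrow 0.11.0+ installed; the app name is illustrative.

```python
from pyspark.sql import SparkSession

# Set the flag when the session is built ...
spark = (SparkSession.builder
         .appName("arrow-safe-cast-demo")
         .config("spark.sql.execution.pandas.arrowSafeTypeConversion", "true")
         .getOrCreate())

# ... or flip it at runtime; the value is read again for each pandas UDF query.
spark.conf.set("spark.sql.execution.pandas.arrowSafeTypeConversion", "false")
```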

sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala

Lines changed: 3 additions & 1 deletion
@@ -133,6 +133,8 @@ object ArrowUtils {
     }
     val pandasColsByName = Seq(SQLConf.PANDAS_GROUPED_MAP_ASSIGN_COLUMNS_BY_NAME.key ->
       conf.pandasGroupedMapAssignColumnsByName.toString)
-    Map(timeZoneConf ++ pandasColsByName: _*)
+    val arrowSafeTypeCheck = Seq(SQLConf.PANDAS_ARROW_SAFE_TYPE_CONVERSION.key ->
+      conf.arrowSafeTypeConversion.toString)
+    Map(timeZoneConf ++ pandasColsByName ++ arrowSafeTypeCheck: _*)
   }
 }
