Skip to content

Commit 5fc35a3

Browse files
committed
Add SQL config for Arrow safe type check.
1 parent 409569b commit 5fc35a3

File tree

5 files changed

+37
-12
lines changed

5 files changed

+37
-12
lines changed

python/pyspark/serializers.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,7 @@ def __repr__(self):
245245
return "ArrowStreamSerializer"
246246

247247

248-
def _create_batch(series, timezone):
248+
def _create_batch(series, timezone, runner_conf):
249249
"""
250250
Create an Arrow record batch from the given pandas.Series or list of Series, with optional type.
251251
@@ -284,12 +284,17 @@ def create_array(s, t):
284284
elif LooseVersion(pa.__version__) < LooseVersion("0.11.0"):
285285
# TODO: see ARROW-1949. Remove when the minimum PyArrow version becomes 0.11.0.
286286
return pa.Array.from_pandas(s, mask=mask, type=t)
287+
288+
enabledArrowSafeTypeCheck = \
289+
runner_conf.get("spark.sql.execution.pandas.arrowSafeTypeConversion", "true") == 'true'
287290
try:
288-
array = pa.Array.from_pandas(s, mask=mask, type=t, safe=True)
291+
array = pa.Array.from_pandas(s, mask=mask, type=t, safe=enabledArrowSafeTypeCheck)
289292
except pa.ArrowException as e:
290293
error_msg = "Exception thrown when converting pandas.Series (%s) to Arrow " + \
291294
"Array (%s). It can be caused by overflows or other unsafe " + \
292-
"conversions warned by Arrow."
295+
"conversions warned by Arrow. Arrow safe type check can be " + \
296+
"disabled by using SQL config " + \
297+
"`spark.sql.execution.pandas.arrowSafeTypeConversion`."
293298
raise RuntimeError(error_msg % (s.dtype, t), e)
294299
return array
295300

@@ -302,9 +307,10 @@ class ArrowStreamPandasSerializer(Serializer):
302307
Serializes Pandas.Series as Arrow data with Arrow streaming format.
303308
"""
304309

305-
def __init__(self, timezone):
310+
def __init__(self, timezone, runner_conf):
306311
super(ArrowStreamPandasSerializer, self).__init__()
307312
self._timezone = timezone
313+
self._runner_conf = runner_conf
308314

309315
def arrow_to_pandas(self, arrow_column):
310316
from pyspark.sql.types import from_arrow_type, \
@@ -324,7 +330,7 @@ def dump_stream(self, iterator, stream):
324330
writer = None
325331
try:
326332
for series in iterator:
327-
batch = _create_batch(series, self._timezone)
333+
batch = _create_batch(series, self._timezone, self._runner_conf)
328334
if writer is None:
329335
write_int(SpecialLengths.START_ARROW_STREAM, stream)
330336
writer = pa.RecordBatchStreamWriter(stream, batch.schema)

python/pyspark/sql/tests/test_pandas_udf.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -212,13 +212,18 @@ def test_pandas_udf_detect_unsafe_type_conversion(self):
212212
def udf(column):
213213
return pd.Series(np.linspace(0, 1, 3))
214214

215-
udf_df = df.select(['A']).withColumn('udf', udf('A'))
216-
217215
# Since 0.11.0, PyArrow supports the feature to raise an error for unsafe cast.
218216
if LooseVersion(pa.__version__) >= LooseVersion("0.11.0"):
219-
with self.assertRaisesRegexp(Exception,
220-
"Exception thrown when converting pandas.Series"):
221-
udf_df.collect()
217+
with self.sql_conf({
218+
"spark.sql.execution.pandas.arrowSafeTypeConversion": True}):
219+
with self.assertRaisesRegexp(Exception,
220+
"Exception thrown when converting pandas.Series"):
221+
df.select(['A']).withColumn('udf', udf('A')).collect()
222+
223+
# Disabling Arrow safe type check.
224+
with self.sql_conf({
225+
"spark.sql.execution.pandas.arrowSafeTypeConversion": False}):
226+
df.select(['A']).withColumn('udf', udf('A')).collect()
222227

223228

224229
if __name__ == "__main__":

python/pyspark/worker.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,7 @@ def read_udfs(pickleSer, infile, eval_type):
253253

254254
# NOTE: if timezone is set here, that implies respectSessionTimeZone is True
255255
timezone = runner_conf.get("spark.sql.session.timeZone", None)
256-
ser = ArrowStreamPandasSerializer(timezone)
256+
ser = ArrowStreamPandasSerializer(timezone, runner_conf)
257257
else:
258258
ser = BatchedSerializer(PickleSerializer(), 100)
259259

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1331,6 +1331,16 @@ object SQLConf {
13311331
.booleanConf
13321332
.createWithDefault(true)
13331333

1334+
val PANDAS_ARROW_SAFE_TYPE_CONVERSION =
1335+
buildConf("spark.sql.execution.pandas.arrowSafeTypeConversion")
1336+
.internal()
1337+
.doc("When true, enabling Arrow do safe type conversion check when converting " +
1338+
"Pandas.Series to Arrow Array during serialization. Arrow will raise errors " +
1339+
"when detecting unsafe type conversion. When false, disabling Arrow's type " +
1340+
"check and do type conversions anyway.")
1341+
.booleanConf
1342+
.createWithDefault(true)
1343+
13341344
val REPLACE_EXCEPT_WITH_FILTER = buildConf("spark.sql.optimizer.replaceExceptWithFilter")
13351345
.internal()
13361346
.doc("When true, the apply function of the rule verifies whether the right node of the" +
@@ -2005,6 +2015,8 @@ class SQLConf extends Serializable with Logging {
20052015
def pandasGroupedMapAssignColumnsByName: Boolean =
20062016
getConf(SQLConf.PANDAS_GROUPED_MAP_ASSIGN_COLUMNS_BY_NAME)
20072017

2018+
def arrowSafeTypeConversion: Boolean = getConf(SQLConf.PANDAS_ARROW_SAFE_TYPE_CONVERSION)
2019+
20082020
def replaceExceptWithFilter: Boolean = getConf(REPLACE_EXCEPT_WITH_FILTER)
20092021

20102022
def decimalOperationsAllowPrecisionLoss: Boolean = getConf(DECIMAL_OPERATIONS_ALLOW_PREC_LOSS)

sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,8 @@ object ArrowUtils {
133133
}
134134
val pandasColsByName = Seq(SQLConf.PANDAS_GROUPED_MAP_ASSIGN_COLUMNS_BY_NAME.key ->
135135
conf.pandasGroupedMapAssignColumnsByName.toString)
136-
Map(timeZoneConf ++ pandasColsByName: _*)
136+
val arrowSafeTypeCheck = Seq(SQLConf.PANDAS_ARROW_SAFE_TYPE_CONVERSION.key ->
137+
conf.arrowSafeTypeConversion.toString)
138+
Map(timeZoneConf ++ pandasColsByName ++ arrowSafeTypeCheck: _*)
137139
}
138140
}

0 commit comments

Comments
 (0)