Commit 250e0b8

Address comments.

Parent: 68b0a3a

6 files changed: +49, -16

python/pyspark/serializers.py

Lines changed: 5 additions & 7 deletions
@@ -245,7 +245,7 @@ def __repr__(self):
         return "ArrowStreamSerializer"
 
 
-def _create_batch(series, timezone, runner_conf):
+def _create_batch(series, timezone, safecheck):
     """
     Create an Arrow record batch from the given pandas.Series or list of Series, with optional type.
@@ -285,10 +285,8 @@ def create_array(s, t):
             # TODO: see ARROW-1949. Remove when the minimum PyArrow version becomes 0.11.0.
             return pa.Array.from_pandas(s, mask=mask, type=t)
 
-        enabledArrowSafeTypeCheck = \
-            runner_conf.get("spark.sql.execution.pandas.arrowSafeTypeConversion", "true") == 'true'
         try:
-            array = pa.Array.from_pandas(s, mask=mask, type=t, safe=enabledArrowSafeTypeCheck)
+            array = pa.Array.from_pandas(s, mask=mask, type=t, safe=safecheck)
         except pa.ArrowException as e:
             error_msg = "Exception thrown when converting pandas.Series (%s) to Arrow " + \
                         "Array (%s). It can be caused by overflows or other unsafe " + \
@@ -307,10 +305,10 @@ class ArrowStreamPandasSerializer(Serializer):
     Serializes Pandas.Series as Arrow data with Arrow streaming format.
     """
 
-    def __init__(self, timezone, runner_conf):
+    def __init__(self, timezone, safecheck):
         super(ArrowStreamPandasSerializer, self).__init__()
         self._timezone = timezone
-        self._runner_conf = runner_conf
+        self._safecheck = safecheck
 
     def arrow_to_pandas(self, arrow_column):
         from pyspark.sql.types import from_arrow_type, \
@@ -330,7 +328,7 @@ def dump_stream(self, iterator, stream):
         writer = None
         try:
             for series in iterator:
-                batch = _create_batch(series, self._timezone, self._runner_conf)
+                batch = _create_batch(series, self._timezone, self._safecheck)
                 if writer is None:
                     write_int(SpecialLengths.START_ARROW_STREAM, stream)
                     writer = pa.RecordBatchStreamWriter(stream, batch.schema)
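For context, here is a minimal standalone sketch of the `safe=` flag behavior this serializer now relies on (a sketch, assuming pandas and pyarrow >= 0.11.0 are installed; everything below is illustrative and not part of the patch):

    import pandas as pd
    import pyarrow as pa

    s = pd.Series([128])  # out of range for a signed byte (int8: -128..127)

    # safe=True: Arrow rejects the lossy cast; this is what the serializer's
    # error path above catches and rewraps with a friendlier message.
    try:
        pa.Array.from_pandas(s, type=pa.int8(), safe=True)
    except pa.ArrowException as e:
        print("unsafe conversion detected:", e)

    # safe=False: Arrow casts anyway and the value silently wraps to -128.
    print(pa.Array.from_pandas(s, type=pa.int8(), safe=False))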

python/pyspark/sql/session.py

Lines changed: 2 additions & 3 deletions
@@ -556,10 +556,9 @@ def _create_from_pandas_with_arrow(self, pdf, schema, timezone):
         pdf_slices = (pdf[start:start + step] for start in xrange(0, len(pdf), step))
 
         # Create Arrow record batches
-        runner_conf = {"spark.sql.execution.pandas.arrowSafeTypeConversion":
-                       self._wrapped._conf.arrowSafeTypeConversion()}
+        safecheck = self._wrapped._conf.arrowSafeTypeConversion()
         batches = [_create_batch([(c, t) for (_, c), t in zip(pdf_slice.iteritems(), arrow_types)],
-                                 timezone, runner_conf)
+                                 timezone, safecheck)
                    for pdf_slice in pdf_slices]
 
         # Create the Spark schema from the first Arrow batch (always at least 1 batch after slicing)
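On the driver side, the flag now comes straight from the session's SQL conf instead of a one-entry dict. A hedged usage sketch (assumes an active SparkSession named `spark`; conf names as of this patch's era):

    import pandas as pd

    # Enable Arrow-based conversion and allow lossy casts during it.
    spark.conf.set("spark.sql.execution.arrow.enabled", "true")
    spark.conf.set("spark.sql.execution.pandas.arrowSafeTypeConversion", "false")

    pdf = pd.DataFrame({"id": [1, 2, 3]})
    df = spark.createDataFrame(pdf)  # each pdf slice becomes one Arrow record batch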

python/pyspark/sql/tests/test_arrow.py

Lines changed: 4 additions & 2 deletions
@@ -288,8 +288,10 @@ def test_createDataFrame_does_not_modify_input(self):
         # Integers with nulls will get NaNs filled with 0 and will be casted
         pdf.ix[1, '2_int_t'] = None
         pdf_copy = pdf.copy(deep=True)
-        self.spark.createDataFrame(pdf, schema=self.schema)
-        self.assertTrue(pdf.equals(pdf_copy))
+        with self.sql_conf({
+                "spark.sql.execution.pandas.arrowSafeTypeConversion": False}):
+            self.spark.createDataFrame(pdf, schema=self.schema)
+            self.assertTrue(pdf.equals(pdf_copy))
 
     def test_schema_conversion_roundtrip(self):
         from pyspark.sql.types import from_arrow_schema, to_arrow_schema
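The `self.sql_conf` helper used here is a context manager that sets the given SQL confs on entry and restores the previous values on exit. Roughly (a sketch of the pattern, not the exact pyspark test-utility implementation):

    from contextlib import contextmanager

    @contextmanager
    def sql_conf(spark, pairs):
        """Temporarily set SQL confs, restoring the old values afterwards."""
        keys = list(pairs.keys())
        old_values = [spark.conf.get(key, None) for key in keys]
        for key in keys:
            spark.conf.set(key, str(pairs[key]))
        try:
            yield
        finally:
            for key, old in zip(keys, old_values):
                if old is None:
                    spark.conf.unset(key)
                else:
                    spark.conf.set(key, old)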

python/pyspark/sql/tests/test_pandas_udf.py

Lines changed: 32 additions & 0 deletions
@@ -225,6 +225,38 @@ def udf(column):
                 "spark.sql.execution.pandas.arrowSafeTypeConversion": False}):
             df.select(['A']).withColumn('udf', udf('A')).collect()
 
+    def test_pandas_udf_arrow_overflow(self):
+        from distutils.version import LooseVersion
+        from pyspark.sql.functions import pandas_udf
+        import pandas as pd
+        import pyarrow as pa
+
+        df = self.spark.range(0, 1)
+
+        @pandas_udf(returnType="byte")
+        def udf(column):
+            return pd.Series([128])
+
+        # Arrow 0.11.0+ allows enabling or disabling the safe type check.
+        if LooseVersion(pa.__version__) >= LooseVersion("0.11.0"):
+            # With the safe type check enabled, Arrow 0.11.0+ disallows the overflowing cast.
+            with self.sql_conf({
+                    "spark.sql.execution.pandas.arrowSafeTypeConversion": True}):
+                with self.assertRaisesRegexp(Exception,
+                                             "Exception thrown when converting pandas.Series"):
+                    df.withColumn('udf', udf('id')).collect()
+
+            # With the safe type check disabled, Arrow does the cast anyway.
+            with self.sql_conf({"spark.sql.execution.pandas.arrowSafeTypeConversion": False}):
+                df.withColumn('udf', udf('id')).collect()
+        else:
+            # The `arrowSafeTypeConversion` SQL config does not matter for older Arrow:
+            # an overflowing cast always causes an error.
+            with self.sql_conf({"spark.sql.execution.pandas.arrowSafeTypeConversion": False}):
+                with self.assertRaisesRegexp(Exception,
+                                             "Integer value out of bounds"):
+                    df.withColumn('udf', udf('id')).collect()
+
 
 if __name__ == "__main__":
     from pyspark.sql.tests.test_pandas_udf import *

python/pyspark/worker.py

Lines changed: 3 additions & 1 deletion
@@ -253,7 +253,9 @@ def read_udfs(pickleSer, infile, eval_type):
 
         # NOTE: if timezone is set here, that implies respectSessionTimeZone is True
         timezone = runner_conf.get("spark.sql.session.timeZone", None)
-        ser = ArrowStreamPandasSerializer(timezone, runner_conf)
+        safecheck = runner_conf.get("spark.sql.execution.pandas.arrowSafeTypeConversion",
+                                    "true").lower() == 'true'
+        ser = ArrowStreamPandasSerializer(timezone, safecheck)
     else:
         ser = BatchedSerializer(PickleSerializer(), 100)
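The worker receives every runner conf value from the JVM as a string, so the boolean has to be parsed on the Python side. A small sketch of the normalization above (the dict literal stands in for the conf map the worker reads off the socket):

    # Values arrive as strings such as "true"/"True"; a missing key defaults to "true".
    runner_conf = {"spark.sql.execution.pandas.arrowSafeTypeConversion": "True"}
    safecheck = runner_conf.get("spark.sql.execution.pandas.arrowSafeTypeConversion",
                                "true").lower() == 'true'
    assert safecheck is True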

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 3 additions & 3 deletions
@@ -1334,9 +1334,9 @@ object SQLConf {
   val PANDAS_ARROW_SAFE_TYPE_CONVERSION =
     buildConf("spark.sql.execution.pandas.arrowSafeTypeConversion")
       .internal()
-      .doc("When true, enabling Arrow do safe type conversion check when converting" +
-        "Pandas.Series to Arrow Array during serialization. Arrow will raise errors " +
-        "when detecting unsafe type conversion. When false, disabling Arrow's type " +
+      .doc("When true, Arrow will perform safe type conversion when converting " +
+        "Pandas.Series to Arrow array during serialization. Arrow will raise errors " +
+        "when detecting unsafe type conversion like overflow. When false, disabling Arrow's type " +
         "check and do type conversions anyway.")
       .booleanConf
       .createWithDefault(true)
