
Commit d36cce1

BryanCutler authored and HyukjinKwon committed
[SPARK-27276][PYTHON][SQL] Increase minimum version of pyarrow to 0.12.1 and remove prior workarounds
## What changes were proposed in this pull request?

This increases the minimum supported version of pyarrow to 0.12.1 and removes the workarounds in pyspark that kept it compatible with older versions. Users will now need at least pyarrow 0.12.1 installed and available in the cluster, or an `ImportError` will be raised to indicate that an upgrade is needed.

## How was this patch tested?

Existing tests using:

Python 2.7.15, pyarrow 0.12.1, pandas 0.24.2
Python 3.6.7, pyarrow 0.12.1, pandas 0.24.0

Closes #24298 from BryanCutler/arrow-bump-min-pyarrow-SPARK-27276.

Authored-by: Bryan Cutler <[email protected]>
Signed-off-by: HyukjinKwon <[email protected]>
1 parent 777b797 commit d36cce1
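For context, the only runtime guard left after this change is a minimum-version check that raises `ImportError`. Below is a minimal sketch of that kind of guard; the helper name and messages are illustrative, not the exact code in `pyspark.sql.utils`.

```python
# Illustrative sketch of a minimum pyarrow version guard (names and messages
# are assumptions, not the exact pyspark implementation).
from distutils.version import LooseVersion


def require_minimum_pyarrow_version(minimum="0.12.1"):
    """Raise ImportError if pyarrow is missing or older than `minimum`."""
    try:
        import pyarrow
    except ImportError:
        raise ImportError("PyArrow >= %s must be installed, but it was not found." % minimum)
    if LooseVersion(pyarrow.__version__) < LooseVersion(minimum):
        raise ImportError("PyArrow >= %s must be installed, but version %s was found."
                          % (minimum, pyarrow.__version__))
```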

File tree

10 files changed (+67, -218 lines)


python/pyspark/serializers.py

Lines changed: 17 additions & 36 deletions
@@ -260,10 +260,14 @@ def __init__(self, timezone, safecheck, assign_cols_by_name):
         self._safecheck = safecheck
         self._assign_cols_by_name = assign_cols_by_name
 
-    def arrow_to_pandas(self, arrow_column, data_type):
-        from pyspark.sql.types import _arrow_column_to_pandas, _check_series_localize_timestamps
+    def arrow_to_pandas(self, arrow_column):
+        from pyspark.sql.types import _check_series_localize_timestamps
+
+        # If the given column is a date type column, creates a series of datetime.date directly
+        # instead of creating datetime64[ns] as intermediate data to avoid overflow caused by
+        # datetime64[ns] type handling.
+        s = arrow_column.to_pandas(date_as_object=True)
 
-        s = _arrow_column_to_pandas(arrow_column, data_type)
         s = _check_series_localize_timestamps(s, self._timezone)
         return s
 
@@ -275,8 +279,6 @@ def _create_batch(self, series):
         :param series: A single pandas.Series, list of Series, or list of (series, arrow_type)
         :return: Arrow RecordBatch
         """
-        import decimal
-        from distutils.version import LooseVersion
         import pandas as pd
         import pyarrow as pa
         from pyspark.sql.types import _check_series_convert_timestamps_internal
@@ -289,24 +291,10 @@ def _create_batch(self, series):
         def create_array(s, t):
             mask = s.isnull()
             # Ensure timestamp series are in expected form for Spark internal representation
-            # TODO: maybe don't need None check anymore as of Arrow 0.9.1
             if t is not None and pa.types.is_timestamp(t):
                 s = _check_series_convert_timestamps_internal(s.fillna(0), self._timezone)
                 # TODO: need cast after Arrow conversion, ns values cause error with pandas 0.19.2
                 return pa.Array.from_pandas(s, mask=mask).cast(t, safe=False)
-            elif t is not None and pa.types.is_string(t) and sys.version < '3':
-                # TODO: need decode before converting to Arrow in Python 2
-                # TODO: don't need as of Arrow 0.9.1
-                return pa.Array.from_pandas(s.apply(
-                    lambda v: v.decode("utf-8") if isinstance(v, str) else v), mask=mask, type=t)
-            elif t is not None and pa.types.is_decimal(t) and \
-                    LooseVersion("0.9.0") <= LooseVersion(pa.__version__) < LooseVersion("0.10.0"):
-                # TODO: see ARROW-2432. Remove when the minimum PyArrow version becomes 0.10.0.
-                return pa.Array.from_pandas(s.apply(
-                    lambda v: decimal.Decimal('NaN') if v is None else v), mask=mask, type=t)
-            elif LooseVersion(pa.__version__) < LooseVersion("0.11.0"):
-                # TODO: see ARROW-1949. Remove when the minimum PyArrow version becomes 0.11.0.
-                return pa.Array.from_pandas(s, mask=mask, type=t)
 
             try:
                 array = pa.Array.from_pandas(s, mask=mask, type=t, safe=self._safecheck)
@@ -340,12 +328,7 @@ def create_array(s, t):
                               for i, field in enumerate(t)]
 
                 struct_arrs, struct_names = zip(*arrs_names)
-
-                # TODO: from_arrays args switched for v0.9.0, remove when bump min pyarrow version
-                if LooseVersion(pa.__version__) < LooseVersion("0.9.0"):
-                    arrs.append(pa.StructArray.from_arrays(struct_names, struct_arrs))
-                else:
-                    arrs.append(pa.StructArray.from_arrays(struct_arrs, struct_names))
+                arrs.append(pa.StructArray.from_arrays(struct_arrs, struct_names))
             else:
                 arrs.append(create_array(s, t))
 
@@ -365,10 +348,8 @@ def load_stream(self, stream):
         """
         batches = super(ArrowStreamPandasSerializer, self).load_stream(stream)
         import pyarrow as pa
-        from pyspark.sql.types import from_arrow_type
        for batch in batches:
-            yield [self.arrow_to_pandas(c, from_arrow_type(c.type))
-                   for c in pa.Table.from_batches([batch]).itercolumns()]
+            yield [self.arrow_to_pandas(c) for c in pa.Table.from_batches([batch]).itercolumns()]
 
     def __repr__(self):
         return "ArrowStreamPandasSerializer"
@@ -384,17 +365,17 @@ def __init__(self, timezone, safecheck, assign_cols_by_name, df_for_struct=False
             .__init__(timezone, safecheck, assign_cols_by_name)
         self._df_for_struct = df_for_struct
 
-    def arrow_to_pandas(self, arrow_column, data_type):
-        from pyspark.sql.types import StructType, \
-            _arrow_column_to_pandas, _check_dataframe_localize_timestamps
+    def arrow_to_pandas(self, arrow_column):
+        import pyarrow.types as types
 
-        if self._df_for_struct and type(data_type) == StructType:
+        if self._df_for_struct and types.is_struct(arrow_column.type):
             import pandas as pd
-            series = [_arrow_column_to_pandas(column, field.dataType).rename(field.name)
-                      for column, field in zip(arrow_column.flatten(), data_type)]
-            s = _check_dataframe_localize_timestamps(pd.concat(series, axis=1), self._timezone)
+            series = [super(ArrowStreamPandasUDFSerializer, self).arrow_to_pandas(column)
+                      .rename(field.name)
+                      for column, field in zip(arrow_column.flatten(), arrow_column.type)]
+            s = pd.concat(series, axis=1)
         else:
-            s = super(ArrowStreamPandasUDFSerializer, self).arrow_to_pandas(arrow_column, data_type)
+            s = super(ArrowStreamPandasUDFSerializer, self).arrow_to_pandas(arrow_column)
         return s
 
     def dump_stream(self, iterator, stream):
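The rewritten `arrow_to_pandas` relies on `to_pandas(date_as_object=True)` returning `datetime.date` values instead of a `datetime64[ns]` intermediate. A standalone sketch of that behavior, using a `pyarrow.Table` rather than a table column purely for illustration:

```python
import datetime

import pandas as pd
import pyarrow as pa

# 2262-04-12 is just past the datetime64[ns] range, like the last row in test_arrow.py.
pdf = pd.DataFrame({"d": [datetime.date(2262, 4, 12)]})
table = pa.Table.from_pandas(pdf, preserve_index=False)

# date_as_object=True keeps dates as datetime.date objects, avoiding the
# datetime64[ns] intermediate that would overflow for this value.
result = table.to_pandas(date_as_object=True)
print(type(result["d"][0]))  # <class 'datetime.date'>
```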

python/pyspark/sql/dataframe.py

Lines changed: 5 additions & 3 deletions
@@ -2138,13 +2138,15 @@ def toPandas(self):
         # of PyArrow is found, if 'spark.sql.execution.arrow.enabled' is enabled.
         if use_arrow:
             try:
-                from pyspark.sql.types import _arrow_table_to_pandas, \
-                    _check_dataframe_localize_timestamps
+                from pyspark.sql.types import _check_dataframe_localize_timestamps
                 import pyarrow
                 batches = self._collectAsArrow()
                 if len(batches) > 0:
                     table = pyarrow.Table.from_batches(batches)
-                    pdf = _arrow_table_to_pandas(table, self.schema)
+                    # Pandas DataFrame created from PyArrow uses datetime64[ns] for date type
+                    # values, but we should use datetime.date to match the behavior with when
+                    # Arrow optimization is disabled.
+                    pdf = table.to_pandas(date_as_object=True)
                     return _check_dataframe_localize_timestamps(pdf, timezone)
                 else:
                     return pd.DataFrame.from_records([], columns=self.columns)
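For reference, the Arrow path in `toPandas()` remains opt-in via the SQL config already named in this hunk; a short usage sketch (a running SparkSession is assumed):

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Enable the Arrow-based conversion used by DataFrame.toPandas(). With this
# commit, a pyarrow older than 0.12.1 makes this path raise ImportError
# instead of falling back to version-specific workarounds.
spark.conf.set("spark.sql.execution.arrow.enabled", "true")

pdf = spark.range(5).toPandas()
print(pdf)
```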

python/pyspark/sql/session.py

Lines changed: 1 addition & 6 deletions
@@ -530,7 +530,6 @@ def _create_from_pandas_with_arrow(self, pdf, schema, timezone):
         to Arrow data, then sending to the JVM to parallelize. If a schema is passed in, the
         data types will be used to coerce the data in Pandas to Arrow conversion.
         """
-        from distutils.version import LooseVersion
         from pyspark.serializers import ArrowStreamPandasSerializer
         from pyspark.sql.types import from_arrow_type, to_arrow_type, TimestampType
         from pyspark.sql.utils import require_minimum_pandas_version, \
@@ -544,11 +543,7 @@ def _create_from_pandas_with_arrow(self, pdf, schema, timezone):
 
         # Create the Spark schema from list of names passed in with Arrow types
         if isinstance(schema, (list, tuple)):
-            if LooseVersion(pa.__version__) < LooseVersion("0.12.0"):
-                temp_batch = pa.RecordBatch.from_pandas(pdf[0:100], preserve_index=False)
-                arrow_schema = temp_batch.schema
-            else:
-                arrow_schema = pa.Schema.from_pandas(pdf, preserve_index=False)
+            arrow_schema = pa.Schema.from_pandas(pdf, preserve_index=False)
             struct = StructType()
             for name, field in zip(schema, arrow_schema):
                 struct.add(name, from_arrow_type(field.type), nullable=field.nullable)
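Since the minimum pyarrow is now past 0.12.0, `pa.Schema.from_pandas` can infer the Arrow schema from the whole DataFrame, replacing the sampled `RecordBatch` workaround. A standalone sketch:

```python
import pandas as pd
import pyarrow as pa

pdf = pd.DataFrame({"name": ["a", "b"], "age": [1, 2]})

# Infer an Arrow schema directly from the pandas DataFrame instead of
# building a RecordBatch from a sample of rows just to read its schema.
arrow_schema = pa.Schema.from_pandas(pdf, preserve_index=False)
for field in arrow_schema:
    print(field.name, field.type, field.nullable)
```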

python/pyspark/sql/tests/test_arrow.py

Lines changed: 6 additions & 33 deletions
@@ -46,7 +46,6 @@ class ArrowTests(ReusedSQLTestCase):
     def setUpClass(cls):
         from datetime import date, datetime
         from decimal import Decimal
-        from distutils.version import LooseVersion
         super(ArrowTests, cls).setUpClass()
         cls.warnings_lock = threading.Lock()
 
@@ -68,23 +67,16 @@ def setUpClass(cls):
             StructField("5_double_t", DoubleType(), True),
             StructField("6_decimal_t", DecimalType(38, 18), True),
             StructField("7_date_t", DateType(), True),
-            StructField("8_timestamp_t", TimestampType(), True)])
+            StructField("8_timestamp_t", TimestampType(), True),
+            StructField("9_binary_t", BinaryType(), True)])
         cls.data = [(u"a", 1, 10, 0.2, 2.0, Decimal("2.0"),
-                     date(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1)),
+                     date(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1), bytearray(b"a")),
                     (u"b", 2, 20, 0.4, 4.0, Decimal("4.0"),
-                     date(2012, 2, 2), datetime(2012, 2, 2, 2, 2, 2)),
+                     date(2012, 2, 2), datetime(2012, 2, 2, 2, 2, 2), bytearray(b"bb")),
                     (u"c", 3, 30, 0.8, 6.0, Decimal("6.0"),
-                     date(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3)),
+                     date(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3), bytearray(b"ccc")),
                     (u"d", 4, 40, 1.0, 8.0, Decimal("8.0"),
-                     date(2262, 4, 12), datetime(2262, 3, 3, 3, 3, 3))]
-
-        # TODO: remove version check once minimum pyarrow version is 0.10.0
-        if LooseVersion("0.10.0") <= LooseVersion(pa.__version__):
-            cls.schema.add(StructField("9_binary_t", BinaryType(), True))
-            cls.data[0] = cls.data[0] + (bytearray(b"a"),)
-            cls.data[1] = cls.data[1] + (bytearray(b"bb"),)
-            cls.data[2] = cls.data[2] + (bytearray(b"ccc"),)
-            cls.data[3] = cls.data[3] + (bytearray(b"dddd"),)
+                     date(2262, 4, 12), datetime(2262, 3, 3, 3, 3, 3), bytearray(b"dddd"))]
 
     @classmethod
     def tearDownClass(cls):
@@ -123,23 +115,13 @@ def test_toPandas_fallback_enabled(self):
             assert_frame_equal(pdf, pd.DataFrame({u'map': [{u'a': 1}]}))
 
     def test_toPandas_fallback_disabled(self):
-        from distutils.version import LooseVersion
-
         schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)])
         df = self.spark.createDataFrame([(None,)], schema=schema)
         with QuietTest(self.sc):
             with self.warnings_lock:
                 with self.assertRaisesRegexp(Exception, 'Unsupported type'):
                     df.toPandas()
 
-        # TODO: remove BinaryType check once minimum pyarrow version is 0.10.0
-        if LooseVersion(pa.__version__) < LooseVersion("0.10.0"):
-            schema = StructType([StructField("binary", BinaryType(), True)])
-            df = self.spark.createDataFrame([(None,)], schema=schema)
-            with QuietTest(self.sc):
-                with self.assertRaisesRegexp(Exception, 'Unsupported type.*BinaryType'):
-                    df.toPandas()
-
     def test_null_conversion(self):
         df_null = self.spark.createDataFrame([tuple([None for _ in range(len(self.data[0]))])] +
                                              self.data)
@@ -348,20 +330,11 @@ def test_createDataFrame_fallback_enabled(self):
         self.assertEqual(df.collect(), [Row(a={u'a': 1})])
 
     def test_createDataFrame_fallback_disabled(self):
-        from distutils.version import LooseVersion
-
         with QuietTest(self.sc):
             with self.assertRaisesRegexp(TypeError, 'Unsupported type'):
                 self.spark.createDataFrame(
                     pd.DataFrame([[{u'a': 1}]]), "a: map<string, int>")
 
-        # TODO: remove BinaryType check once minimum pyarrow version is 0.10.0
-        if LooseVersion(pa.__version__) < LooseVersion("0.10.0"):
-            with QuietTest(self.sc):
-                with self.assertRaisesRegexp(TypeError, 'Unsupported type.*BinaryType'):
-                    self.spark.createDataFrame(
-                        pd.DataFrame([[{'a': b'aaa'}]]), "a: binary")
-
     # Regression test for SPARK-23314
     def test_timestamp_dst(self):
         # Daylight saving time for Los Angeles for 2015 is Sun, Nov 1 at 2:00 am
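These tests assert that unsupported types still fail on the Arrow path when fallback is off. A hedged sketch of the disabled-fallback scenario outside the test harness; the fallback config name is assumed from this Spark line, so verify it before relying on it:

```python
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, MapType, StringType, IntegerType

spark = SparkSession.builder.getOrCreate()
spark.conf.set("spark.sql.execution.arrow.enabled", "true")
# Assumed config name: disables falling back to the non-Arrow conversion path.
spark.conf.set("spark.sql.execution.arrow.fallback.enabled", "false")

schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)])
df = spark.createDataFrame([({"a": 1},)], schema=schema)

try:
    df.toPandas()  # MapType is not supported by the Arrow conversion here
except Exception as e:
    print("Arrow conversion failed as expected:", e)
```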

python/pyspark/sql/tests/test_pandas_udf.py

Lines changed: 17 additions & 31 deletions
@@ -198,62 +198,48 @@ def foofoo(x, y):
         )
 
     def test_pandas_udf_detect_unsafe_type_conversion(self):
-        from distutils.version import LooseVersion
         import pandas as pd
         import numpy as np
-        import pyarrow as pa
 
         values = [1.0] * 3
         pdf = pd.DataFrame({'A': values})
         df = self.spark.createDataFrame(pdf).repartition(1)
 
         @pandas_udf(returnType="int")
         def udf(column):
-            return pd.Series(np.linspace(0, 1, 3))
+            return pd.Series(np.linspace(0, 1, len(column)))
 
         # Since 0.11.0, PyArrow supports the feature to raise an error for unsafe cast.
-        if LooseVersion(pa.__version__) >= LooseVersion("0.11.0"):
-            with self.sql_conf({
-                    "spark.sql.execution.pandas.arrowSafeTypeConversion": True}):
-                with self.assertRaisesRegexp(Exception,
-                                             "Exception thrown when converting pandas.Series"):
-                    df.select(['A']).withColumn('udf', udf('A')).collect()
+        with self.sql_conf({
+                "spark.sql.execution.pandas.arrowSafeTypeConversion": True}):
+            with self.assertRaisesRegexp(Exception,
+                                         "Exception thrown when converting pandas.Series"):
+                df.select(['A']).withColumn('udf', udf('A')).collect()
 
         # Disabling Arrow safe type check.
         with self.sql_conf({
                 "spark.sql.execution.pandas.arrowSafeTypeConversion": False}):
             df.select(['A']).withColumn('udf', udf('A')).collect()
 
     def test_pandas_udf_arrow_overflow(self):
-        from distutils.version import LooseVersion
         import pandas as pd
-        import pyarrow as pa
 
         df = self.spark.range(0, 1)
 
         @pandas_udf(returnType="byte")
         def udf(column):
-            return pd.Series([128])
-
-        # Arrow 0.11.0+ allows enabling or disabling safe type check.
-        if LooseVersion(pa.__version__) >= LooseVersion("0.11.0"):
-            # When enabling safe type check, Arrow 0.11.0+ disallows overflow cast.
-            with self.sql_conf({
-                    "spark.sql.execution.pandas.arrowSafeTypeConversion": True}):
-                with self.assertRaisesRegexp(Exception,
-                                             "Exception thrown when converting pandas.Series"):
-                    df.withColumn('udf', udf('id')).collect()
-
-        # Disabling safe type check, let Arrow do the cast anyway.
-        with self.sql_conf({"spark.sql.execution.pandas.arrowSafeTypeConversion": False}):
+            return pd.Series([128] * len(column))
+
+        # When enabling safe type check, Arrow 0.11.0+ disallows overflow cast.
+        with self.sql_conf({
+                "spark.sql.execution.pandas.arrowSafeTypeConversion": True}):
+            with self.assertRaisesRegexp(Exception,
+                                         "Exception thrown when converting pandas.Series"):
                 df.withColumn('udf', udf('id')).collect()
-        else:
-            # SQL config `arrowSafeTypeConversion` no matters for older Arrow.
-            # Overflow cast causes an error.
-            with self.sql_conf({"spark.sql.execution.pandas.arrowSafeTypeConversion": False}):
-                with self.assertRaisesRegexp(Exception,
-                                             "Integer value out of bounds"):
-                    df.withColumn('udf', udf('id')).collect()
+
+        # Disabling safe type check, let Arrow do the cast anyway.
+        with self.sql_conf({"spark.sql.execution.pandas.arrowSafeTypeConversion": False}):
+            df.withColumn('udf', udf('id')).collect()
 
 
 if __name__ == "__main__":
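The overflow behavior these tests now assume unconditionally can be reproduced with pyarrow alone, using the same `safe=` flag the serializer passes to `pa.Array.from_pandas`. A minimal sketch:

```python
import pandas as pd
import pyarrow as pa

s = pd.Series([128])  # does not fit into a signed byte (int8)

try:
    pa.Array.from_pandas(s, type=pa.int8(), safe=True)
except Exception as e:  # pyarrow reports an invalid/unsafe cast
    print("safe cast rejected the overflow:", e)

# With safe=False the cast is performed anyway and the value wraps around.
print(pa.Array.from_pandas(s, type=pa.int8(), safe=False))
```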

python/pyspark/sql/tests/test_pandas_udf_grouped_map.py

Lines changed: 9 additions & 19 deletions
@@ -21,7 +21,6 @@
 
 from collections import OrderedDict
 from decimal import Decimal
-from distutils.version import LooseVersion
 
 from pyspark.sql import Row
 from pyspark.sql.functions import array, explode, col, lit, udf, sum, pandas_udf, PandasUDFType
@@ -65,20 +64,17 @@ def test_supported_types(self):
             1, 2, 3,
             4, 5, 1.1,
             2.2, Decimal(1.123),
-            [1, 2, 2], True, 'hello'
+            [1, 2, 2], True, 'hello',
+            bytearray([0x01, 0x02])
         ]
         output_fields = [
             ('id', IntegerType()), ('byte', ByteType()), ('short', ShortType()),
             ('int', IntegerType()), ('long', LongType()), ('float', FloatType()),
             ('double', DoubleType()), ('decim', DecimalType(10, 3)),
-            ('array', ArrayType(IntegerType())), ('bool', BooleanType()), ('str', StringType())
+            ('array', ArrayType(IntegerType())), ('bool', BooleanType()), ('str', StringType()),
+            ('bin', BinaryType())
         ]
 
-        # TODO: Add BinaryType to variables above once minimum pyarrow version is 0.10.0
-        if LooseVersion(pa.__version__) >= LooseVersion("0.10.0"):
-            values.append(bytearray([0x01, 0x02]))
-            output_fields.append(('bin', BinaryType()))
-
         output_schema = StructType([StructField(*x) for x in output_fields])
         df = self.spark.createDataFrame([values], schema=output_schema)
 
@@ -95,6 +91,7 @@ def test_supported_types(self):
                 bool=False if pdf.bool else True,
                 str=pdf.str + 'there',
                 array=pdf.array,
+                bin=pdf.bin
             ),
             output_schema,
             PandasUDFType.GROUPED_MAP
@@ -112,6 +109,7 @@ def test_supported_types(self):
                 bool=False if pdf.bool else True,
                 str=pdf.str + 'there',
                 array=pdf.array,
+                bin=pdf.bin
             ),
             output_schema,
             PandasUDFType.GROUPED_MAP
@@ -130,6 +128,7 @@ def test_supported_types(self):
                 bool=False if pdf.bool else True,
                 str=pdf.str + 'there',
                 array=pdf.array,
+                bin=pdf.bin
             ),
             output_schema,
             PandasUDFType.GROUPED_MAP
@@ -291,10 +290,6 @@ def test_unsupported_types(self):
             StructField('struct', StructType([StructField('l', LongType())])),
         ]
 
-        # TODO: Remove this if-statement once minimum pyarrow version is 0.10.0
-        if LooseVersion(pa.__version__) < LooseVersion("0.10.0"):
-            unsupported_types.append(StructField('bin', BinaryType()))
-
         for unsupported_type in unsupported_types:
             schema = StructType([StructField('id', LongType(), True), unsupported_type])
             with QuietTest(self.sc):
@@ -466,13 +461,8 @@ def invalid_positional_types(pdf):
         with QuietTest(self.sc):
             with self.assertRaisesRegexp(Exception, "KeyError: 'id'"):
                 grouped_df.apply(column_name_typo).collect()
-            if LooseVersion(pa.__version__) < LooseVersion("0.11.0"):
-                # TODO: see ARROW-1949. Remove when the minimum PyArrow version becomes 0.11.0.
-                with self.assertRaisesRegexp(Exception, "No cast implemented"):
-                    grouped_df.apply(invalid_positional_types).collect()
-            else:
-                with self.assertRaisesRegexp(Exception, "an integer is required"):
-                    grouped_df.apply(invalid_positional_types).collect()
+            with self.assertRaisesRegexp(Exception, "an integer is required"):
+                grouped_df.apply(invalid_positional_types).collect()
 
     def test_positional_assignment_conf(self):
         with self.sql_conf({
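For readers unfamiliar with the API under test, a GROUPED_MAP pandas UDF is applied once per group and returns a pandas DataFrame matching a declared schema; a small illustrative sketch (data and schema are made up):

```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import pandas_udf, PandasUDFType

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(1, 1.0), (1, 2.0), (2, 3.0)], ("id", "v"))

@pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP)
def subtract_mean(pdf):
    # pdf is a pandas.DataFrame holding all rows of one group; the returned
    # DataFrame must match the declared output schema.
    return pdf.assign(v=pdf.v - pdf.v.mean())

df.groupby("id").apply(subtract_mean).show()
```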
