Skip to content

Commit c31d519

Browse files
committed
Disable Arrow safe type check for some tests.
1 parent 5fc35a3 commit c31d519

File tree

1 file changed

+32
-24
lines changed

1 file changed

+32
-24
lines changed

python/pyspark/sql/tests/test_pandas_udf_scalar.py

Lines changed: 32 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -138,36 +138,44 @@ def test_vectorized_udf_null_boolean(self):
138138
self.assertEquals(df.collect(), res.collect())
139139

140140
def test_vectorized_udf_null_byte(self):
141-
data = [(None,), (2,), (3,), (4,)]
142-
schema = StructType().add("byte", ByteType())
143-
df = self.spark.createDataFrame(data, schema)
144-
byte_f = pandas_udf(lambda x: x, ByteType())
145-
res = df.select(byte_f(col('byte')))
146-
self.assertEquals(df.collect(), res.collect())
141+
with self.sql_conf({
142+
"spark.sql.execution.pandas.arrowSafeTypeConversion": False}):
143+
data = [(None,), (2,), (3,), (4,)]
144+
schema = StructType().add("byte", ByteType())
145+
df = self.spark.createDataFrame(data, schema)
146+
byte_f = pandas_udf(lambda x: x, ByteType())
147+
res = df.select(byte_f(col('byte')))
148+
self.assertEquals(df.collect(), res.collect())
147149

148150
def test_vectorized_udf_null_short(self):
149-
data = [(None,), (2,), (3,), (4,)]
150-
schema = StructType().add("short", ShortType())
151-
df = self.spark.createDataFrame(data, schema)
152-
short_f = pandas_udf(lambda x: x, ShortType())
153-
res = df.select(short_f(col('short')))
154-
self.assertEquals(df.collect(), res.collect())
151+
with self.sql_conf({
152+
"spark.sql.execution.pandas.arrowSafeTypeConversion": False}):
153+
data = [(None,), (2,), (3,), (4,)]
154+
schema = StructType().add("short", ShortType())
155+
df = self.spark.createDataFrame(data, schema)
156+
short_f = pandas_udf(lambda x: x, ShortType())
157+
res = df.select(short_f(col('short')))
158+
self.assertEquals(df.collect(), res.collect())
155159

156160
def test_vectorized_udf_null_int(self):
157-
data = [(None,), (2,), (3,), (4,)]
158-
schema = StructType().add("int", IntegerType())
159-
df = self.spark.createDataFrame(data, schema)
160-
int_f = pandas_udf(lambda x: x, IntegerType())
161-
res = df.select(int_f(col('int')))
162-
self.assertEquals(df.collect(), res.collect())
161+
with self.sql_conf({
162+
"spark.sql.execution.pandas.arrowSafeTypeConversion": False}):
163+
data = [(None,), (2,), (3,), (4,)]
164+
schema = StructType().add("int", IntegerType())
165+
df = self.spark.createDataFrame(data, schema)
166+
int_f = pandas_udf(lambda x: x, IntegerType())
167+
res = df.select(int_f(col('int')))
168+
self.assertEquals(df.collect(), res.collect())
163169

164170
def test_vectorized_udf_null_long(self):
165-
data = [(None,), (2,), (3,), (4,)]
166-
schema = StructType().add("long", LongType())
167-
df = self.spark.createDataFrame(data, schema)
168-
long_f = pandas_udf(lambda x: x, LongType())
169-
res = df.select(long_f(col('long')))
170-
self.assertEquals(df.collect(), res.collect())
171+
with self.sql_conf({
172+
"spark.sql.execution.pandas.arrowSafeTypeConversion": False}):
173+
data = [(None,), (2,), (3,), (4,)]
174+
schema = StructType().add("long", LongType())
175+
df = self.spark.createDataFrame(data, schema)
176+
long_f = pandas_udf(lambda x: x, LongType())
177+
res = df.select(long_f(col('long')))
178+
self.assertEquals(df.collect(), res.collect())
171179

172180
def test_vectorized_udf_null_float(self):
173181
data = [(3.0,), (5.0,), (-1.0,), (None,)]

0 commit comments

Comments (0)