diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 3d46b852c52e1..6dbdb3adb295d 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -30,6 +30,8 @@
 import functools
 import time
 import datetime
+import array
+import math
 
 import py4j
 try:
@@ -1735,6 +1737,60 @@ def test_BinaryType_serialization(self):
         df = self.spark.createDataFrame(data, schema=schema)
         df.collect()
 
+    # test for SPARK-16542
+    def test_array_types(self):
+        int_types = set(['b', 'h', 'i', 'l'])
+        float_types = set(['f', 'd'])
+        unsupported_types = set(array.typecodes) - int_types - float_types
+
+        def collected(a):
+            row = Row(myarray=a)
+            rdd = self.sc.parallelize([row])
+            df = self.spark.createDataFrame(rdd)
+            return df.collect()[0]["myarray"][0]
+
+        # test whether pyspark can correctly handle int types
+        for t in int_types:
+            # test positive numbers
+            a = array.array(t, [1])
+            while True:
+                try:
+                    self.assertEqual(collected(a), a[0])
+                    a[0] *= 2
+                except OverflowError:
+                    break
+            # test negative numbers
+            a = array.array(t, [-1])
+            while True:
+                try:
+                    self.assertEqual(collected(a), a[0])
+                    a[0] *= 2
+                except OverflowError:
+                    break
+
+        # test whether pyspark can correctly handle float types
+        for t in float_types:
+            # test upper bound and precision
+            a = array.array(t, [1.0])
+            while not math.isinf(a[0]):
+                self.assertEqual(collected(a), a[0])
+                a[0] *= 2
+                a[0] += 1
+            # test lower bound
+            a = array.array(t, [1.0])
+            while a[0] != 0:
+                self.assertEqual(collected(a), a[0])
+                a[0] /= 2
+
+        # test whether pyspark can correctly handle unsupported types
+        for t in unsupported_types:
+            try:
+                a = array.array(t)
+                c = collected(a)
+                self.assertTrue(False)  # if no exception is thrown, fail the test
+            except TypeError:
+                pass  # catch the expected exception and do nothing
+            except:
+                # if an unexpected exception is thrown, fail the test
+                self.assertTrue(False)
+
 
 class HiveSparkSubmitTests(SparkSubmitTests):
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 4a023123b6eca..e8aeef1a51b7a 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -929,6 +929,16 @@ def _parse_datatype_json_value(json_value):
     datetime.time: TimestampType,
 }
 
+# Mapping Python array types to Spark SQL DataType
+_array_type_mappings = {
+    'b': ByteType,
+    'h': ShortType,
+    'i': IntegerType,
+    'l': LongType,
+    'f': FloatType,
+    'd': DoubleType
+}
+
 if sys.version < "3":
     _type_mappings.update({
         unicode: StringType,
@@ -958,12 +968,17 @@ def _infer_type(obj):
                 return MapType(_infer_type(key), _infer_type(value), True)
         else:
             return MapType(NullType(), NullType(), True)
-    elif isinstance(obj, (list, array)):
+    elif isinstance(obj, list):
         for v in obj:
             if v is not None:
                 return ArrayType(_infer_type(obj[0]), True)
         else:
             return ArrayType(NullType(), True)
+    elif isinstance(obj, array):
+        if obj.typecode in _array_type_mappings:
+            return ArrayType(_array_type_mappings[obj.typecode](), True)
+        else:
+            raise TypeError("not supported type: array(%s)" % obj.typecode)
     else:
         try:
             return _infer_schema(obj)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
index 724025b4647f4..90726a2e01365 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
@@ -91,20 +91,30 @@ object EvaluatePython {
 
     case (c: Boolean, BooleanType) => c
 
+    case (c: Byte, ByteType) => c
+    case (c: Short, ByteType) => c.toByte
     case (c: Int, ByteType) => c.toByte
     case (c: Long, ByteType) => c.toByte
 
+    case (c: Byte, ShortType) => c.toShort
+    case (c: Short, ShortType) => c
     case (c: Int, ShortType) => c.toShort
     case (c: Long, ShortType) => c.toShort
 
+    case (c: Byte, IntegerType) => c.toInt
+    case (c: Short, IntegerType) => c.toInt
     case (c: Int, IntegerType) => c
     case (c: Long, IntegerType) => c.toInt
 
+    case (c: Byte, LongType) => c.toLong
+    case (c: Short, LongType) => c.toLong
     case (c: Int, LongType) => c.toLong
     case (c: Long, LongType) => c
 
+    case (c: Float, FloatType) => c
     case (c: Double, FloatType) => c.toFloat
 
+    case (c: Float, DoubleType) => c.toDouble
     case (c: Double, DoubleType) => c
 
     case (c: java.math.BigDecimal, dt: DecimalType) => Decimal(c, dt.precision, dt.scale)
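
For reference, the behavior this patch enables can be exercised from a PySpark shell roughly as follows. This is a minimal sketch assuming the patch above is applied; `spark` is the shell's default SparkSession, and the column name `myarray` is an arbitrary choice for illustration.

    import array

    # Supported typecodes are inferred through _array_type_mappings,
    # e.g. 'i' -> IntegerType and 'd' -> DoubleType.
    df = spark.createDataFrame([(array.array('i', [1, 2, 3]),)], ['myarray'])
    df.printSchema()
    # root
    #  |-- myarray: array (nullable = true)
    #  |    |-- element: integer (containsNull = true)

    # Typecodes outside the mapping now fail fast in _infer_type with
    # TypeError("not supported type: array(u)") instead of being mis-handled.
    try:
        spark.createDataFrame([(array.array('u', u'ab'),)], ['myarray'])
    except TypeError as e:
        print(e)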