From a127486d59528eae452dcbcc2ccfb68fdd7769b7 Mon Sep 17 00:00:00 2001
From: Xiang Gao
Date: Fri, 8 Jul 2016 20:58:14 -0400
Subject: [PATCH 1/6] use array.typecode to infer type

Python's array has more types than Python itself: for example, Python only
has float, while array supports both 'f' (float) and 'd' (double).
Switching to array.typecode helps Spark make a better inference.

For example, for the code:

    from pyspark.sql.types import _infer_type
    from array import array
    a = array('f',[1,2,3,4,5,6])
    _infer_type(a)

we will get ArrayType(DoubleType,true) before the change, but
ArrayType(FloatType,true) after the change.
---
 python/pyspark/sql/types.py | 21 ++++++++++++++++++++-
 1 file changed, 20 insertions(+), 1 deletion(-)

diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index a3679873e1d8d..f048bf6d0b954 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -926,6 +926,23 @@ def _parse_datatype_json_value(json_value):
     datetime.time: TimestampType,
 }
 
+# Mapping Python array types to Spark SQL DataType
+_array_type_mappings = {
+    'b': ByteType,
+    'B': ShortType,
+    'u': StringType,
+    'h': ShortType,
+    'H': IntegerType,
+    'i': IntegerType,
+    'I': LongType,
+    'l': LongType,
+    'L': LongType,
+    'q': LongType,
+    'Q': LongType,
+    'f': FloatType,
+    'd': DoubleType
+}
+
 if sys.version < "3":
     _type_mappings.update({
         unicode: StringType,
@@ -955,12 +972,14 @@ def _infer_type(obj):
                 return MapType(_infer_type(key), _infer_type(value), True)
         else:
             return MapType(NullType(), NullType(), True)
-    elif isinstance(obj, (list, array)):
+    elif isinstance(obj, list):
         for v in obj:
             if v is not None:
                 return ArrayType(_infer_type(obj[0]), True)
         else:
             return ArrayType(NullType(), True)
+    elif isinstance(obj, array):
+        return ArrayType(_array_type_mappings[obj.typecode](), True)
     else:
         try:
             return _infer_schema(obj)

From 05979ca6eabf723cf3849ec2bf6f6e9de26cb138 Mon Sep 17 00:00:00 2001
From: Xiang Gao
Date: Thu, 14 Jul 2016 16:07:12 +0800
Subject: [PATCH 2/6] add case (c: Float, FloatType) to fromJava
---
 .../apache/spark/sql/execution/python/EvaluatePython.scala | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
index cf68ed4ec36a8..7a86695a8ca90 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
@@ -118,6 +118,10 @@ object EvaluatePython {
 
     case (c: Double, DoubleType) => c
 
+    case (c: Float, FloatType) => c
+
+    case (c: Float, DoubleType) => c.toDouble
+
     case (c: java.math.BigDecimal, dt: DecimalType) => Decimal(c, dt.precision, dt.scale)
 
     case (c: Int, DateType) => c
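
[Illustration only, not part of the patch series: a minimal sketch of the
inference behaviour after PATCH 1, assuming a PySpark build with it applied.
_infer_type is the same internal helper exercised in the commit message
above; the exact printed form may vary by Spark version.]

    from array import array
    from pyspark.sql.types import _infer_type

    _infer_type(array('f', [1.0]))  # ArrayType(FloatType,true) with the patch
    _infer_type(array('d', [1.0]))  # ArrayType(DoubleType,true), as before
    _infer_type([1.0])              # plain lists still infer ArrayType(DoubleType,true)
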
From cd2ec6bc707fb6e7255b3a6a6822c3667866c63c Mon Sep 17 00:00:00 2001
From: Xiang Gao
Date: Sun, 16 Oct 2016 22:44:48 -0400
Subject: [PATCH 3/6] add test for array in dataframe
---
 python/pyspark/sql/tests.py | 55 +++++++++++++++++++++++++++++++++++++
 1 file changed, 55 insertions(+)

diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index a8ca386e1ce31..227f74155e69f 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -30,6 +30,8 @@
 import functools
 import time
 import datetime
+import array
+import math
 import py4j
 
 try:
@@ -1630,6 +1632,59 @@ def test_cache(self):
             "does_not_exist",
             lambda: spark.catalog.uncacheTable("does_not_exist"))
 
+    # test for SPARK-16542
+    def test_array_types():
+        int_types = set([ 'b', 'B', 'h', 'H', 'i', 'I', 'l', 'L', 'q', 'Q' ])
+        float_types = set([ 'f', 'd' ])
+        unsupported_types = set(array.typecodes) - int_types - float_types
+        def collected(a):
+            row = Row(myarray=a)
+            rdd = self.sc.parallelize([ row ])
+            df = self.spark.createDataFrame(rdd)
+            return df.collect()[0]["myarray"][0]
+        # test whether pyspark can correctly handle int types
+        for t in int_types:
+            is_unsigned = t.isupper()
+            # test positive numbers
+            a = array.array(t,[1])
+            while True:
+                try:
+                    self.assertEqual(collected(a),a[0])
+                    a[0] *= 2
+                except OverflowError:
+                    break
+            # test negative numbers
+            if not is_unsigned:
+                a = array.array(t,[-1])
+                while True:
+                    try:
+                        self.assertEqual(collected(a),a[0])
+                        a[0] *= 2
+                    except OverflowError:
+                        break
+        # test whether pyspark can correctly handle float types
+        for t in float_types:
+            # test upper bound and precision
+            a = array.array(t,[1.0])
+            while not math.isinf(a[0]):
+                self.assertEqual(collected(a),a[0])
+                a[0] *= 2
+                a[0] += 1
+            # test lower bound
+            a = array.array(t,[1.0])
+            while a[0]!=0:
+                self.assertEqual(collected(a),a[0])
+                a[0] /= 2
+        # test whether pyspark can correctly handle unsupported types
+        for t in unsupported_types:
+            try:
+                c = collected(a)
+                self.assertTrue(False) # if no exception thrown, fail the test
+            except TypeError:
+                pass # catch the expected exception and do nothing
+            except:
+                self.assertTrue(False) # if incorrect exception thrown, fail the test
+
 
 
 class HiveSparkSubmitTests(SparkSubmitTests):

From 82223c02082793b899c7eeca70f7bbfcea516c28 Mon Sep 17 00:00:00 2001
From: Xiang Gao
Date: Sun, 16 Oct 2016 23:35:47 -0400
Subject: [PATCH 4/6] set unsigned types and Py_UNICODE as unsupported
---
 python/pyspark/sql/tests.py |  4 ++--
 python/pyspark/sql/types.py | 12 ++++--------
 2 files changed, 6 insertions(+), 10 deletions(-)

diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index d2adb66d12cfb..108f1755b7981 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -1738,8 +1738,8 @@ def test_BinaryType_serialization(self):
         df.collect()
 
     # test for SPARK-16542
-    def test_array_types():
-        int_types = set([ 'b', 'B', 'h', 'H', 'i', 'I', 'l', 'L', 'q', 'Q' ])
+    def test_array_types(self):
+        int_types = set([ 'b', 'h', 'i', 'l' ])
         float_types = set([ 'f', 'd' ])
         unsupported_types = set(array.typecodes) - int_types - float_types
         def collected(a):

diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index fa67a6ff257b8..e8aeef1a51b7a 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -932,16 +932,9 @@ def _parse_datatype_json_value(json_value):
 # Mapping Python array types to Spark SQL DataType
 _array_type_mappings = {
     'b': ByteType,
-    'B': ShortType,
-    'u': StringType,
     'h': ShortType,
-    'H': IntegerType,
     'i': IntegerType,
-    'I': LongType,
     'l': LongType,
-    'L': LongType,
-    'q': LongType,
-    'Q': LongType,
     'f': FloatType,
     'd': DoubleType
 }
@@ -982,7 +975,10 @@ def _infer_type(obj):
         else:
             return ArrayType(NullType(), True)
     elif isinstance(obj, array):
-        return ArrayType(_array_type_mappings[obj.typecode](), True)
+        if obj.typecode in _array_type_mappings:
+            return ArrayType(_array_type_mappings[obj.typecode](), True)
+        else:
+            raise TypeError("not supported type: array(%s)" % obj.typecode)
     else:
         try:
             return _infer_schema(obj)
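
[Illustration only, not part of the patch series: a minimal sketch of the
fail-fast behaviour after PATCH 4, assuming a build with patches 1-4 applied.
Unsigned and Py_UNICODE typecodes now raise TypeError instead of being
silently mapped to a Spark type that may overflow.]

    from array import array
    from pyspark.sql.types import _infer_type

    _infer_type(array('i', [1]))      # ArrayType(IntegerType,true)
    try:
        _infer_type(array('L', [1]))  # unsigned long: no safe Spark SQL type
    except TypeError as e:
        print(e)                      # not supported type: array(L)
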
From 0a967e280b3250bf7217e61905ad28f010c4ed40 Mon Sep 17 00:00:00 2001
From: Xiang Gao
Date: Mon, 17 Oct 2016 13:46:35 -0400
Subject: [PATCH 5/6] fix code style
---
 python/pyspark/sql/tests.py | 32 +++++++++++++++++---------------
 1 file changed, 17 insertions(+), 15 deletions(-)

diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 108f1755b7981..6aa293fc93137 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -1739,56 +1739,58 @@ def test_BinaryType_serialization(self):
 
     # test for SPARK-16542
     def test_array_types(self):
-        int_types = set([ 'b', 'h', 'i', 'l' ])
-        float_types = set([ 'f', 'd' ])
+        int_types = set(['b', 'h', 'i', 'l'])
+        float_types = set(['f', 'd'])
         unsupported_types = set(array.typecodes) - int_types - float_types
+
         def collected(a):
             row = Row(myarray=a)
-            rdd = self.sc.parallelize([ row ])
+            rdd = self.sc.parallelize([row])
             df = self.spark.createDataFrame(rdd)
             return df.collect()[0]["myarray"][0]
         # test whether pyspark can correctly handle int types
         for t in int_types:
             is_unsigned = t.isupper()
             # test positive numbers
-            a = array.array(t,[1])
+            a = array.array(t, [1])
             while True:
                 try:
-                    self.assertEqual(collected(a),a[0])
+                    self.assertEqual(collected(a), a[0])
                     a[0] *= 2
                 except OverflowError:
                     break
             # test negative numbers
             if not is_unsigned:
-                a = array.array(t,[-1])
+                a = array.array(t, [-1])
                 while True:
                     try:
-                        self.assertEqual(collected(a),a[0])
+                        self.assertEqual(collected(a), a[0])
                         a[0] *= 2
                     except OverflowError:
                         break
         # test whether pyspark can correctly handle float types
         for t in float_types:
             # test upper bound and precision
-            a = array.array(t,[1.0])
+            a = array.array(t, [1.0])
             while not math.isinf(a[0]):
-                self.assertEqual(collected(a),a[0])
+                self.assertEqual(collected(a), a[0])
                 a[0] *= 2
                 a[0] += 1
             # test lower bound
-            a = array.array(t,[1.0])
-            while a[0]!=0:
-                self.assertEqual(collected(a),a[0])
+            a = array.array(t, [1.0])
+            while a[0] != 0:
+                self.assertEqual(collected(a), a[0])
                 a[0] /= 2
         # test whether pyspark can correctly handle unsupported types
         for t in unsupported_types:
             try:
                 c = collected(a)
-                self.assertTrue(False) # if no exception thrown, fail the test
+                self.assertTrue(False)  # if no exception thrown, fail the test
             except TypeError:
-                pass # catch the expected exception and do nothing
+                pass  # catch the expected exception and do nothing
             except:
-                self.assertTrue(False) # if incorrect exception thrown, fail the test
+                # if incorrect exception thrown, fail the test
+                self.assertTrue(False)
 
 
 class HiveSparkSubmitTests(SparkSubmitTests):
From 2059435b45ed1f6337a4f935adcd029084cfec91 Mon Sep 17 00:00:00 2001
From: Xiang Gao
Date: Mon, 17 Oct 2016 20:11:05 -0400
Subject: [PATCH 6/6] fix the same problem for byte and short
---
 python/pyspark/sql/tests.py                        | 17 ++++++++---------
 .../sql/execution/python/EvaluatePython.scala      | 14 ++++++++++----
 2 files changed, 18 insertions(+), 13 deletions(-)

diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 6aa293fc93137..6dbdb3adb295d 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -1750,7 +1750,6 @@ def collected(a):
             return df.collect()[0]["myarray"][0]
         # test whether pyspark can correctly handle int types
         for t in int_types:
-            is_unsigned = t.isupper()
            # test positive numbers
             a = array.array(t, [1])
             while True:
@@ -1757,35 +1756,35 @@ def collected(a):
                 try:
                     self.assertEqual(collected(a), a[0])
                     a[0] *= 2
                 except OverflowError:
                     break
             # test negative numbers
-            if not is_unsigned:
-                a = array.array(t, [-1])
-                while True:
-                    try:
-                        self.assertEqual(collected(a), a[0])
-                        a[0] *= 2
-                    except OverflowError:
-                        break
+            a = array.array(t, [-1])
+            while True:
+                try:
+                    self.assertEqual(collected(a), a[0])
+                    a[0] *= 2
+                except OverflowError:
+                    break
         # test whether pyspark can correctly handle float types
         for t in float_types:
             # test upper bound and precision
             a = array.array(t, [1.0])
             while not math.isinf(a[0]):
                 self.assertEqual(collected(a), a[0])
                 a[0] *= 2
                 a[0] += 1
             # test lower bound
             a = array.array(t, [1.0])
             while a[0] != 0:
                 self.assertEqual(collected(a), a[0])
                 a[0] /= 2
         # test whether pyspark can correctly handle unsupported types
         for t in unsupported_types:
             try:
+                a = array.array(t)
                 c = collected(a)
                 self.assertTrue(False)  # if no exception thrown, fail the test
             except TypeError:
                 pass  # catch the expected exception and do nothing
             except:

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
index 25e8cf1b29dc7..90726a2e01365 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
@@ -91,25 +91,31 @@ object EvaluatePython {
 
     case (c: Boolean, BooleanType) => c
 
+    case (c: Byte, ByteType) => c
+    case (c: Short, ByteType) => c.toByte
     case (c: Int, ByteType) => c.toByte
     case (c: Long, ByteType) => c.toByte
 
+    case (c: Byte, ShortType) => c.toShort
+    case (c: Short, ShortType) => c
     case (c: Int, ShortType) => c.toShort
     case (c: Long, ShortType) => c.toShort
 
+    case (c: Byte, IntegerType) => c.toInt
+    case (c: Short, IntegerType) => c.toInt
     case (c: Int, IntegerType) => c
     case (c: Long, IntegerType) => c.toInt
 
+    case (c: Byte, LongType) => c.toLong
+    case (c: Short, LongType) => c.toLong
     case (c: Int, LongType) => c.toLong
     case (c: Long, LongType) => c
 
-    case (c: Double, FloatType) => c.toFloat
-
-    case (c: Double, DoubleType) => c
-
     case (c: Float, FloatType) => c
+    case (c: Double, FloatType) => c.toFloat
 
     case (c: Float, DoubleType) => c.toDouble
+    case (c: Double, DoubleType) => c
 
     case (c: java.math.BigDecimal, dt: DecimalType) => Decimal(c, dt.precision, dt.scale)
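
[Illustration only, not part of the patch series: an end-to-end sketch
assuming all six patches are applied and a local SparkSession is available.
Column names and the exact printSchema layout are illustrative.]

    from array import array
    from pyspark.sql import Row, SparkSession

    spark = SparkSession.builder.master("local[1]").getOrCreate()
    df = spark.createDataFrame([Row(bytes=array('b', [127]),
                                    floats=array('f', [0.5]))])
    df.printSchema()
    # root
    #  |-- bytes: array (nullable = true)
    #  |    |-- element: byte (containsNull = true)
    #  |-- floats: array (nullable = true)
    #  |    |-- element: float (containsNull = true)

    # collect() round-trips the values through EvaluatePython.fromJava,
    # which now has exact cases for Byte, Short and Float as well.
    df.collect()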