python/pyspark/sql.py (13 changes: 10 additions & 3 deletions)
@@ -574,10 +574,13 @@ def _parse_datatype_json_value(json_value):


 # Mapping Python types to Spark SQL DataType
+# int -> LongType below so that we do not have to deal with
+# the differences between Java ints and Python ints when
+# inferring data types. SPARK-5722
 _type_mappings = {
     type(None): NullType,
     bool: BooleanType,
-    int: IntegerType,
+    int: LongType,
     long: LongType,
     float: DoubleType,
     str: StringType,
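For context on the mapping change above: CPython integers are arbitrary precision, so a value that still fits a Python `int` can already exceed the 32-bit range implied by `IntegerType` on the JVM side. A minimal sketch of the new rule follows (Python 2, to match the `long` usage in this file; `infer_spark_type` is a made-up stand-in for `_infer_type`, not part of the patch):

```python
# Toy stand-in for pyspark.sql._infer_type, illustrating the SPARK-5722 rule:
# every Python int now maps to LongType, regardless of magnitude.
INT32_MAX = 2 ** 31 - 1

def infer_spark_type(value):
    if isinstance(value, bool):         # bool is a subclass of int, so check it first
        return "BooleanType"
    if isinstance(value, (int, long)):  # both map to LongType under the new rule
        return "LongType"
    if isinstance(value, float):
        return "DoubleType"
    return "StringType"

assert infer_spark_type(1) == "LongType"
assert infer_spark_type(INT32_MAX + 1) == "LongType"    # would overflow a 32-bit IntegerType
assert infer_spark_type(100000000000000) == "LongType"  # the value used in the new test below
```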
@@ -681,6 +684,8 @@ def _need_python_to_sql_conversion(dataType):
             _need_python_to_sql_conversion(dataType.valueType)
     elif isinstance(dataType, UserDefinedType):
         return True
+    elif isinstance(dataType, LongType):
+        return True
     else:
         return False

@@ -734,6 +739,8 @@ def converter(obj):
         return lambda m: dict([(key_converter(k), value_converter(v)) for k, v in m.items()])
     elif isinstance(dataType, UserDefinedType):
         return lambda obj: dataType.serialize(obj)
+    elif isinstance(dataType, LongType):
+        return lambda x: long(x)
     else:
         raise ValueError("Unexpected type %r" % dataType)

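The two new branches above work as a pair: `_need_python_to_sql_conversion` now reports that `LongType` fields need conversion, and the converter widens them with `long()` so the JVM side sees a consistent type for `LongType` columns (the Scala-side change at the end of this diff handles the cast). A hedged sketch of that behaviour (Python 2; `make_converter` is an invented name, not the patch's `_python_to_sql_converter`):

```python
# Sketch of the converter branch added above: LongType fields are widened
# with long() before pickling; everything else passes through unchanged.
def make_converter(data_type):
    if data_type == "LongType":
        return lambda x: long(x)   # mirrors "return lambda x: long(x)" in the patch
    return lambda x: x

convert = make_converter("LongType")
assert isinstance(convert(1), long)   # even small ints become longs
assert convert(2 ** 40) == 2 ** 40    # large values survive the widening
```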
@@ -926,11 +933,11 @@ def _infer_schema_type(obj, dataType):
>>> schema = _parse_schema_abstract("a b c d")
>>> row = (1, 1.0, "str", datetime.date(2014, 10, 10))
>>> _infer_schema_type(row, schema)
StructType...IntegerType...DoubleType...StringType...DateType...
StructType...LongType...DoubleType...StringType...DateType...
>>> row = [[1], {"key": (1, 2.0)}]
>>> schema = _parse_schema_abstract("a[] b{c d}")
>>> _infer_schema_type(row, schema)
StructType...a,ArrayType...b,MapType(StringType,...c,IntegerType...
StructType...a,ArrayType...b,MapType(StringType,...c,LongType...
"""
if dataType is None:
return _infer_type(obj)
python/pyspark/tests.py (24 changes: 23 additions & 1 deletion)
@@ -51,7 +51,7 @@
     CloudPickleSerializer, CompressedSerializer
 from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter
 from pyspark.sql import SQLContext, IntegerType, Row, ArrayType, StructType, StructField, \
-    UserDefinedType, DoubleType
+    UserDefinedType, DoubleType, LongType, _infer_type
 from pyspark import shuffle

 _have_scipy = False
@@ -923,6 +923,28 @@ def test_infer_schema(self):
         result = self.sqlCtx.sql("SELECT l[0].a from test2 where d['key'].d = '2'")
         self.assertEqual(1, result.first()[0])

+    # SPARK-5722
+    def test_infer_long_type(self):
+        longrow = [Row(f1='a', f2=100000000000000)]
+        lrdd = self.sc.parallelize(longrow)
+        slrdd = self.sqlCtx.inferSchema(lrdd)
+        self.assertEqual(slrdd.schema().fields[1].dataType, LongType())
+
+        # saving as Parquet caused issues as well
+        output_dir = os.path.join(self.tempdir.name, "infer_long_type")
+        slrdd.saveAsParquetFile(output_dir)
+        df1 = self.sqlCtx.parquetFile(output_dir)
+        self.assertEquals('a', df1.first().f1)
+        self.assertEquals(100000000000000, df1.first().f2)
+
+        self.assertEqual(_infer_type(1), LongType())
+        self.assertEqual(_infer_type(2**10), LongType())
+        self.assertEqual(_infer_type(2**20), LongType())
+        self.assertEqual(_infer_type(2**31 - 1), LongType())
+        self.assertEqual(_infer_type(2**31), LongType())
+        self.assertEqual(_infer_type(2**61), LongType())
+        self.assertEqual(_infer_type(2**71), LongType())
+
     def test_convert_row_to_dict(self):
         row = Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})
         self.assertEqual(1, row.asDict()['l'][0].a)
@@ -139,6 +139,8 @@ object EvaluatePython {

     case (dec: BigDecimal, dt: DecimalType) => dec.underlying() // Pyrolite can handle BigDecimal

+    case (_, LongType) => obj.asInstanceOf[Long]
+
     // Pyrolite can handle Timestamp
     case (other, _) => other
   }