diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index c832b0b182147..9722e9e9cae22 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -528,76 +528,76 @@ def check_datatype(datatype):
         _verify_type(PythonOnlyPoint(1.0, 2.0), PythonOnlyUDT())
         self.assertRaises(ValueError, lambda: _verify_type([1.0, 2.0], PythonOnlyUDT()))
 
-    # def test_infer_schema_with_udt(self):
-    #     from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
-    #     row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
-    #     df = self.sqlCtx.createDataFrame([row])
-    #     schema = df.schema
-    #     field = [f for f in schema.fields if f.name == "point"][0]
-    #     self.assertEqual(type(field.dataType), ExamplePointUDT)
-    #     df.registerTempTable("labeled_point")
-    #     point = self.sqlCtx.sql("SELECT point FROM labeled_point").head().point
-    #     self.assertEqual(point, ExamplePoint(1.0, 2.0))
-
-    #     row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0))
-    #     df = self.sqlCtx.createDataFrame([row])
-    #     schema = df.schema
-    #     field = [f for f in schema.fields if f.name == "point"][0]
-    #     self.assertEqual(type(field.dataType), PythonOnlyUDT)
-    #     df.registerTempTable("labeled_point")
-    #     point = self.sqlCtx.sql("SELECT point FROM labeled_point").head().point
-    #     self.assertEqual(point, PythonOnlyPoint(1.0, 2.0))
-
-    # def test_apply_schema_with_udt(self):
-    #     from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
-    #     row = (1.0, ExamplePoint(1.0, 2.0))
-    #     schema = StructType([StructField("label", DoubleType(), False),
-    #                          StructField("point", ExamplePointUDT(), False)])
-    #     df = self.sqlCtx.createDataFrame([row], schema)
-    #     point = df.head().point
-    #     self.assertEqual(point, ExamplePoint(1.0, 2.0))
-
-    #     row = (1.0, PythonOnlyPoint(1.0, 2.0))
-    #     schema = StructType([StructField("label", DoubleType(), False),
-    #                          StructField("point", PythonOnlyUDT(), False)])
-    #     df = self.sqlCtx.createDataFrame([row], schema)
-    #     point = df.head().point
-    #     self.assertEqual(point, PythonOnlyPoint(1.0, 2.0))
-
-    # def test_udf_with_udt(self):
-    #     from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
-    #     row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
-    #     df = self.sqlCtx.createDataFrame([row])
-    #     self.assertEqual(1.0, df.rdd.map(lambda r: r.point.x).first())
-    #     udf = UserDefinedFunction(lambda p: p.y, DoubleType())
-    #     self.assertEqual(2.0, df.select(udf(df.point)).first()[0])
-    #     udf2 = UserDefinedFunction(lambda p: ExamplePoint(p.x + 1, p.y + 1), ExamplePointUDT())
-    #     self.assertEqual(ExamplePoint(2.0, 3.0), df.select(udf2(df.point)).first()[0])
-
-    #     row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0))
-    #     df = self.sqlCtx.createDataFrame([row])
-    #     self.assertEqual(1.0, df.rdd.map(lambda r: r.point.x).first())
-    #     udf = UserDefinedFunction(lambda p: p.y, DoubleType())
-    #     self.assertEqual(2.0, df.select(udf(df.point)).first()[0])
-    #     udf2 = UserDefinedFunction(lambda p: PythonOnlyPoint(p.x + 1, p.y + 1), PythonOnlyUDT())
-    #     self.assertEqual(PythonOnlyPoint(2.0, 3.0), df.select(udf2(df.point)).first()[0])
-
-    # def test_parquet_with_udt(self):
-    #     from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
-    #     row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
-    #     df0 = self.sqlCtx.createDataFrame([row])
-    #     output_dir = os.path.join(self.tempdir.name, "labeled_point")
-    #     df0.write.parquet(output_dir)
-    #     df1 = self.sqlCtx.read.parquet(output_dir)
-    #     point = df1.head().point
-    #     self.assertEqual(point, ExamplePoint(1.0, 2.0))
-
-    #     row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0))
-    #     df0 = self.sqlCtx.createDataFrame([row])
-    #     df0.write.parquet(output_dir, mode='overwrite')
-    #     df1 = self.sqlCtx.read.parquet(output_dir)
-    #     point = df1.head().point
-    #     self.assertEqual(point, PythonOnlyPoint(1.0, 2.0))
+    def test_infer_schema_with_udt(self):
+        from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
+        row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
+        df = self.sqlCtx.createDataFrame([row])
+        schema = df.schema
+        field = [f for f in schema.fields if f.name == "point"][0]
+        self.assertEqual(type(field.dataType), ExamplePointUDT)
+        df.registerTempTable("labeled_point")
+        point = self.sqlCtx.sql("SELECT point FROM labeled_point").head().point
+        self.assertEqual(point, ExamplePoint(1.0, 2.0))
+
+        row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0))
+        df = self.sqlCtx.createDataFrame([row])
+        schema = df.schema
+        field = [f for f in schema.fields if f.name == "point"][0]
+        self.assertEqual(type(field.dataType), PythonOnlyUDT)
+        df.registerTempTable("labeled_point")
+        point = self.sqlCtx.sql("SELECT point FROM labeled_point").head().point
+        self.assertEqual(point, PythonOnlyPoint(1.0, 2.0))
+
+    def test_apply_schema_with_udt(self):
+        from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
+        row = (1.0, ExamplePoint(1.0, 2.0))
+        schema = StructType([StructField("label", DoubleType(), False),
+                             StructField("point", ExamplePointUDT(), False)])
+        df = self.sqlCtx.createDataFrame([row], schema)
+        point = df.head().point
+        self.assertEqual(point, ExamplePoint(1.0, 2.0))
+
+        row = (1.0, PythonOnlyPoint(1.0, 2.0))
+        schema = StructType([StructField("label", DoubleType(), False),
+                             StructField("point", PythonOnlyUDT(), False)])
+        df = self.sqlCtx.createDataFrame([row], schema)
+        point = df.head().point
+        self.assertEqual(point, PythonOnlyPoint(1.0, 2.0))
+
+    def test_udf_with_udt(self):
+        from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
+        row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
+        df = self.sqlCtx.createDataFrame([row])
+        self.assertEqual(1.0, df.rdd.map(lambda r: r.point.x).first())
+        udf = UserDefinedFunction(lambda p: p.y, DoubleType())
+        self.assertEqual(2.0, df.select(udf(df.point)).first()[0])
+        udf2 = UserDefinedFunction(lambda p: ExamplePoint(p.x + 1, p.y + 1), ExamplePointUDT())
+        self.assertEqual(ExamplePoint(2.0, 3.0), df.select(udf2(df.point)).first()[0])
+
+        row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0))
+        df = self.sqlCtx.createDataFrame([row])
+        self.assertEqual(1.0, df.rdd.map(lambda r: r.point.x).first())
+        udf = UserDefinedFunction(lambda p: p.y, DoubleType())
+        self.assertEqual(2.0, df.select(udf(df.point)).first()[0])
+        udf2 = UserDefinedFunction(lambda p: PythonOnlyPoint(p.x + 1, p.y + 1), PythonOnlyUDT())
+        self.assertEqual(PythonOnlyPoint(2.0, 3.0), df.select(udf2(df.point)).first()[0])
+
+    def test_parquet_with_udt(self):
+        from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
+        row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
+        df0 = self.sqlCtx.createDataFrame([row])
+        output_dir = os.path.join(self.tempdir.name, "labeled_point")
+        df0.write.parquet(output_dir)
+        df1 = self.sqlCtx.read.parquet(output_dir)
+        point = df1.head().point
+        self.assertEqual(point, ExamplePoint(1.0, 2.0))
+
+        row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0))
+        df0 = self.sqlCtx.createDataFrame([row])
+        df0.write.parquet(output_dir, mode='overwrite')
+        df1 = self.sqlCtx.read.parquet(output_dir)
+        point = df1.head().point
+        self.assertEqual(point, PythonOnlyPoint(1.0, 2.0))
 
     def test_unionAll_with_udt(self):
         from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
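
Note: the tests re-enabled above all exercise the same round trip: wrap a Python value in a UDT column, send it through Catalyst, and check that it comes back intact. A minimal sketch of that pattern, assuming only a live SQLContext bound to sqlCtx as in the test fixture (everything else is taken from the diff itself):

    # Sketch, not part of the patch: infer a schema from a Row carrying a UDT value.
    from pyspark.sql import Row
    from pyspark.sql.tests import ExamplePoint, ExamplePointUDT

    df = sqlCtx.createDataFrame([Row(label=1.0, point=ExamplePoint(1.0, 2.0))])
    field = [f for f in df.schema.fields if f.name == "point"][0]
    assert type(field.dataType) is ExamplePointUDT    # schema keeps the UDT, not its struct
    assert df.head().point == ExamplePoint(1.0, 2.0)  # value survives the round trip
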
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index 5c427845a775b..902644e735ea6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -52,6 +52,8 @@ object RowEncoder {
     case NullType | BooleanType | ByteType | ShortType | IntegerType | LongType |
          FloatType | DoubleType | BinaryType | CalendarIntervalType => inputObject
 
+    case p: PythonUserDefinedType => extractorsFor(inputObject, p.sqlType)
+
     case udt: UserDefinedType[_] =>
       val obj = NewInstance(
         udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(),
@@ -151,10 +153,14 @@ object RowEncoder {
   private def constructorFor(schema: StructType): Expression = {
     val fields = schema.zipWithIndex.map { case (f, i) =>
-      val field = BoundReference(i, f.dataType, f.nullable)
+      val dt = f.dataType match {
+        case p: PythonUserDefinedType => p.sqlType
+        case other => other
+      }
+      val field = BoundReference(i, dt, f.nullable)
       If(
         IsNull(field),
-        Literal.create(null, externalDataTypeFor(f.dataType)),
+        Literal.create(null, externalDataTypeFor(dt)),
         constructorFor(field)
       )
     }
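
Note: the RowEncoder change is what lets the Python-only halves of those tests pass. A PythonUserDefinedType carries no JVM userClass to instantiate, so both the serializer path (extractorsFor) and the deserializer path (constructorFor) now unwrap it to its underlying sqlType instead of taking the Scala UDT reflection branch. A hedged sketch of the behavior this enables, reusing PythonOnlyPoint/PythonOnlyUDT from the tests and again assuming the sqlCtx fixture:

    # Sketch, not part of the patch: a UDT with no Scala counterpart now round-trips,
    # because the JVM side stores it as its sqlType and Python re-wraps the result.
    from pyspark.sql import Row
    from pyspark.sql.tests import PythonOnlyPoint, PythonOnlyUDT

    df = sqlCtx.createDataFrame([Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0))])
    field = [f for f in df.schema.fields if f.name == "point"][0]
    assert type(field.dataType) is PythonOnlyUDT
    assert df.head().point == PythonOnlyPoint(1.0, 2.0)
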