diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index a975dc19cb78e..a0a028446d5fd 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -111,10 +111,9 @@ def run(self): java_import(gateway.jvm, "org.apache.spark.api.java.*") java_import(gateway.jvm, "org.apache.spark.api.python.*") java_import(gateway.jvm, "org.apache.spark.mllib.api.python.*") - java_import(gateway.jvm, "org.apache.spark.sql.SQLContext") - java_import(gateway.jvm, "org.apache.spark.sql.hive.HiveContext") - java_import(gateway.jvm, "org.apache.spark.sql.hive.LocalHiveContext") - java_import(gateway.jvm, "org.apache.spark.sql.hive.TestHiveContext") + # TODO(davies): move into sql + java_import(gateway.jvm, "org.apache.spark.sql.*") + java_import(gateway.jvm, "org.apache.spark.sql.hive.*") java_import(gateway.jvm, "scala.Tuple2") return gateway diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 1990323249cf6..f16eb361d306f 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -20,15 +20,19 @@ - L{SQLContext} Main entry point for SQL functionality. - - L{SchemaRDD} + - L{DataFrame} A Resilient Distributed Dataset (RDD) with Schema information for the data contained. In - addition to normal RDD operations, SchemaRDDs also support SQL. + addition to normal RDD operations, DataFrames also support SQL. + - L{GroupedDataFrame} + - L{Column} + Column is a DataFrame with a single column. - L{Row} A Row of data returned by a Spark SQL query. - L{HiveContext} Main entry point for accessing data stored in Apache Hive.. """ +import sys import itertools import decimal import datetime @@ -36,6 +40,9 @@ import warnings import json import re +import random +import os +from tempfile import NamedTemporaryFile from array import array from operator import itemgetter from itertools import imap @@ -43,6 +50,7 @@ from py4j.protocol import Py4JError from py4j.java_collections import ListConverter, MapConverter +from pyspark.context import SparkContext from pyspark.rdd import RDD from pyspark.serializers import BatchedSerializer, AutoBatchedSerializer, PickleSerializer, \ CloudPickleSerializer, UTF8Deserializer @@ -54,7 +62,8 @@ "StringType", "BinaryType", "BooleanType", "DateType", "TimestampType", "DecimalType", "DoubleType", "FloatType", "ByteType", "IntegerType", "LongType", "ShortType", "ArrayType", "MapType", "StructField", "StructType", - "SQLContext", "HiveContext", "SchemaRDD", "Row"] + "SQLContext", "HiveContext", "DataFrame", "GroupedDataFrame", "Column", "Row", + "SchemaRDD"] class DataType(object): @@ -1171,7 +1180,7 @@ def Dict(d): class Row(tuple): - """ Row in SchemaRDD """ + """ Row in DataFrame """ __DATATYPE__ = dataType __FIELDS__ = tuple(f.name for f in dataType.fields) __slots__ = () @@ -1198,7 +1207,7 @@ class SQLContext(object): """Main entry point for Spark SQL functionality. - A SQLContext can be used create L{SchemaRDD}, register L{SchemaRDD} as + A SQLContext can be used create L{DataFrame}, register L{DataFrame} as tables, execute SQL over tables, cache tables, and read parquet files. """ @@ -1209,8 +1218,8 @@ def __init__(self, sparkContext, sqlContext=None): :param sqlContext: An optional JVM Scala SQLContext. If set, we do not instatiate a new SQLContext in the JVM, instead we make all calls to this object. 
- >>> srdd = sqlCtx.inferSchema(rdd) - >>> sqlCtx.inferSchema(srdd) # doctest: +IGNORE_EXCEPTION_DETAIL + >>> df = sqlCtx.inferSchema(rdd) + >>> sqlCtx.inferSchema(df) # doctest: +IGNORE_EXCEPTION_DETAIL Traceback (most recent call last): ... TypeError:... @@ -1225,12 +1234,12 @@ def __init__(self, sparkContext, sqlContext=None): >>> allTypes = sc.parallelize([Row(i=1, s="string", d=1.0, l=1L, ... b=True, list=[1, 2, 3], dict={"s": 0}, row=Row(a=1), ... time=datetime(2014, 8, 1, 14, 1, 5))]) - >>> srdd = sqlCtx.inferSchema(allTypes) - >>> srdd.registerTempTable("allTypes") + >>> df = sqlCtx.inferSchema(allTypes) + >>> df.registerTempTable("allTypes") >>> sqlCtx.sql('select i+1, d+1, not b, list[1], dict["s"], time, row.a ' ... 'from allTypes where b and i > 0').collect() [Row(c0=2, c1=2.0, c2=False, c3=2, c4=0...8, 1, 14, 1, 5), a=1)] - >>> srdd.map(lambda x: (x.i, x.s, x.d, x.l, x.b, x.time, + >>> df.map(lambda x: (x.i, x.s, x.d, x.l, x.b, x.time, ... x.row.a, x.list)).collect() [(1, u'string', 1.0, 1, True, ...(2014, 8, 1, 14, 1, 5), 1, [1, 2, 3])] """ @@ -1309,23 +1318,23 @@ def inferSchema(self, rdd, samplingRatio=None): ... [Row(field1=1, field2="row1"), ... Row(field1=2, field2="row2"), ... Row(field1=3, field2="row3")]) - >>> srdd = sqlCtx.inferSchema(rdd) - >>> srdd.collect()[0] + >>> df = sqlCtx.inferSchema(rdd) + >>> df.collect()[0] Row(field1=1, field2=u'row1') >>> NestedRow = Row("f1", "f2") >>> nestedRdd1 = sc.parallelize([ ... NestedRow(array('i', [1, 2]), {"row1": 1.0}), ... NestedRow(array('i', [2, 3]), {"row2": 2.0})]) - >>> srdd = sqlCtx.inferSchema(nestedRdd1) - >>> srdd.collect() + >>> df = sqlCtx.inferSchema(nestedRdd1) + >>> df.collect() [Row(f1=[1, 2], f2={u'row1': 1.0}), ..., f2={u'row2': 2.0})] >>> nestedRdd2 = sc.parallelize([ ... NestedRow([[1, 2], [2, 3]], [1, 2]), ... NestedRow([[2, 3], [3, 4]], [2, 3])]) - >>> srdd = sqlCtx.inferSchema(nestedRdd2) - >>> srdd.collect() + >>> df = sqlCtx.inferSchema(nestedRdd2) + >>> df.collect() [Row(f1=[[1, 2], [2, 3]], f2=[1, 2]), ..., f2=[2, 3])] >>> from collections import namedtuple @@ -1334,13 +1343,13 @@ def inferSchema(self, rdd, samplingRatio=None): ... [CustomRow(field1=1, field2="row1"), ... CustomRow(field1=2, field2="row2"), ... CustomRow(field1=3, field2="row3")]) - >>> srdd = sqlCtx.inferSchema(rdd) - >>> srdd.collect()[0] + >>> df = sqlCtx.inferSchema(rdd) + >>> df.collect()[0] Row(field1=1, field2=u'row1') """ - if isinstance(rdd, SchemaRDD): - raise TypeError("Cannot apply schema to SchemaRDD") + if isinstance(rdd, DataFrame): + raise TypeError("Cannot apply schema to DataFrame") first = rdd.first() if not first: @@ -1384,10 +1393,10 @@ def applySchema(self, rdd, schema): >>> rdd2 = sc.parallelize([(1, "row1"), (2, "row2"), (3, "row3")]) >>> schema = StructType([StructField("field1", IntegerType(), False), ... StructField("field2", StringType(), False)]) - >>> srdd = sqlCtx.applySchema(rdd2, schema) - >>> sqlCtx.registerRDDAsTable(srdd, "table1") - >>> srdd2 = sqlCtx.sql("SELECT * from table1") - >>> srdd2.collect() + >>> df = sqlCtx.applySchema(rdd2, schema) + >>> sqlCtx.registerRDDAsTable(df, "table1") + >>> df2 = sqlCtx.sql("SELECT * from table1") + >>> df2.collect() [Row(field1=1, field2=u'row1'),..., Row(field1=3, field2=u'row3')] >>> from datetime import date, datetime @@ -1410,15 +1419,15 @@ def applySchema(self, rdd, schema): ... StructType([StructField("b", ShortType(), False)]), False), ... StructField("list", ArrayType(ByteType(), False), False), ... 
StructField("null", DoubleType(), True)]) - >>> srdd = sqlCtx.applySchema(rdd, schema) - >>> results = srdd.map( + >>> df = sqlCtx.applySchema(rdd, schema) + >>> results = df.map( ... lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int, x.float, x.date, ... x.time, x.map["a"], x.struct.b, x.list, x.null)) >>> results.collect()[0] # doctest: +NORMALIZE_WHITESPACE (127, -128, -32768, 32767, 2147483647, 1.0, datetime.date(2010, 1, 1), datetime.datetime(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None) - >>> srdd.registerTempTable("table2") + >>> df.registerTempTable("table2") >>> sqlCtx.sql( ... "SELECT byte1 - 1 AS byte1, byte2 + 1 AS byte2, " + ... "short1 + 1 AS short1, short2 - 1 AS short2, int - 1 AS int, " + @@ -1431,13 +1440,13 @@ def applySchema(self, rdd, schema): >>> abstract = "byte short float time map{} struct(b) list[]" >>> schema = _parse_schema_abstract(abstract) >>> typedSchema = _infer_schema_type(rdd.first(), schema) - >>> srdd = sqlCtx.applySchema(rdd, typedSchema) - >>> srdd.collect() + >>> df = sqlCtx.applySchema(rdd, typedSchema) + >>> df.collect() [Row(byte=127, short=-32768, float=1.0, time=..., list=[1, 2, 3])] """ - if isinstance(rdd, SchemaRDD): - raise TypeError("Cannot apply schema to SchemaRDD") + if isinstance(rdd, DataFrame): + raise TypeError("Cannot apply schema to DataFrame") if not isinstance(schema, StructType): raise TypeError("schema should be StructType") @@ -1457,8 +1466,8 @@ def applySchema(self, rdd, schema): rdd = rdd.map(converter) jrdd = self._jvm.SerDeUtil.toJavaArray(rdd._to_java_object_rdd()) - srdd = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json()) - return SchemaRDD(srdd, self) + df = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json()) + return DataFrame(df, self) def registerRDDAsTable(self, rdd, tableName): """Registers the given RDD as a temporary table in the catalog. @@ -1466,34 +1475,34 @@ def registerRDDAsTable(self, rdd, tableName): Temporary tables exist only during the lifetime of this instance of SQLContext. - >>> srdd = sqlCtx.inferSchema(rdd) - >>> sqlCtx.registerRDDAsTable(srdd, "table1") + >>> df = sqlCtx.inferSchema(rdd) + >>> sqlCtx.registerRDDAsTable(df, "table1") """ - if (rdd.__class__ is SchemaRDD): - srdd = rdd._jschema_rdd.baseSchemaRDD() - self._ssql_ctx.registerRDDAsTable(srdd, tableName) + if (rdd.__class__ is DataFrame): + df = rdd._jdf + self._ssql_ctx.registerRDDAsTable(df, tableName) else: - raise ValueError("Can only register SchemaRDD as table") + raise ValueError("Can only register DataFrame as table") def parquetFile(self, path): - """Loads a Parquet file, returning the result as a L{SchemaRDD}. + """Loads a Parquet file, returning the result as a L{DataFrame}. >>> import tempfile, shutil >>> parquetFile = tempfile.mkdtemp() >>> shutil.rmtree(parquetFile) - >>> srdd = sqlCtx.inferSchema(rdd) - >>> srdd.saveAsParquetFile(parquetFile) - >>> srdd2 = sqlCtx.parquetFile(parquetFile) - >>> sorted(srdd.collect()) == sorted(srdd2.collect()) + >>> df = sqlCtx.inferSchema(rdd) + >>> df.saveAsParquetFile(parquetFile) + >>> df2 = sqlCtx.parquetFile(parquetFile) + >>> sorted(df.collect()) == sorted(df2.collect()) True """ - jschema_rdd = self._ssql_ctx.parquetFile(path) - return SchemaRDD(jschema_rdd, self) + jdf = self._ssql_ctx.parquetFile(path) + return DataFrame(jdf, self) def jsonFile(self, path, schema=None, samplingRatio=1.0): """ Loads a text file storing one JSON object per line as a - L{SchemaRDD}. + L{DataFrame}. 
If the schema is provided, applies the given schema to this JSON dataset. @@ -1508,23 +1517,23 @@ def jsonFile(self, path, schema=None, samplingRatio=1.0): >>> for json in jsonStrings: ... print>>ofn, json >>> ofn.close() - >>> srdd1 = sqlCtx.jsonFile(jsonFile) - >>> sqlCtx.registerRDDAsTable(srdd1, "table1") - >>> srdd2 = sqlCtx.sql( + >>> df1 = sqlCtx.jsonFile(jsonFile) + >>> sqlCtx.registerRDDAsTable(df1, "table1") + >>> df2 = sqlCtx.sql( ... "SELECT field1 AS f1, field2 as f2, field3 as f3, " ... "field6 as f4 from table1") - >>> for r in srdd2.collect(): + >>> for r in df2.collect(): ... print r Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None) Row(f1=2, f2=None, f3=Row(field4=22,..., f4=[Row(field7=u'row2')]) Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None) - >>> srdd3 = sqlCtx.jsonFile(jsonFile, srdd1.schema()) - >>> sqlCtx.registerRDDAsTable(srdd3, "table2") - >>> srdd4 = sqlCtx.sql( + >>> df3 = sqlCtx.jsonFile(jsonFile, df1.schema()) + >>> sqlCtx.registerRDDAsTable(df3, "table2") + >>> df4 = sqlCtx.sql( ... "SELECT field1 AS f1, field2 as f2, field3 as f3, " ... "field6 as f4 from table2") - >>> for r in srdd4.collect(): + >>> for r in df4.collect(): ... print r Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None) Row(f1=2, f2=None, f3=Row(field4=22,..., f4=[Row(field7=u'row2')]) @@ -1536,23 +1545,23 @@ def jsonFile(self, path, schema=None, samplingRatio=1.0): ... StructType([ ... StructField("field5", ... ArrayType(IntegerType(), False), True)]), False)]) - >>> srdd5 = sqlCtx.jsonFile(jsonFile, schema) - >>> sqlCtx.registerRDDAsTable(srdd5, "table3") - >>> srdd6 = sqlCtx.sql( + >>> df5 = sqlCtx.jsonFile(jsonFile, schema) + >>> sqlCtx.registerRDDAsTable(df5, "table3") + >>> df6 = sqlCtx.sql( ... "SELECT field2 AS f1, field3.field5 as f2, " ... "field3.field5[0] as f3 from table3") - >>> srdd6.collect() + >>> df6.collect() [Row(f1=u'row1', f2=None, f3=None)...Row(f1=u'row3', f2=[], f3=None)] """ if schema is None: - srdd = self._ssql_ctx.jsonFile(path, samplingRatio) + df = self._ssql_ctx.jsonFile(path, samplingRatio) else: scala_datatype = self._ssql_ctx.parseDataType(schema.json()) - srdd = self._ssql_ctx.jsonFile(path, scala_datatype) - return SchemaRDD(srdd, self) + df = self._ssql_ctx.jsonFile(path, scala_datatype) + return DataFrame(df, self) def jsonRDD(self, rdd, schema=None, samplingRatio=1.0): - """Loads an RDD storing one JSON object per string as a L{SchemaRDD}. + """Loads an RDD storing one JSON object per string as a L{DataFrame}. If the schema is provided, applies the given schema to this JSON dataset. @@ -1560,23 +1569,23 @@ def jsonRDD(self, rdd, schema=None, samplingRatio=1.0): Otherwise, it samples the dataset with ratio `samplingRatio` to determine the schema. - >>> srdd1 = sqlCtx.jsonRDD(json) - >>> sqlCtx.registerRDDAsTable(srdd1, "table1") - >>> srdd2 = sqlCtx.sql( + >>> df1 = sqlCtx.jsonRDD(json) + >>> sqlCtx.registerRDDAsTable(df1, "table1") + >>> df2 = sqlCtx.sql( ... "SELECT field1 AS f1, field2 as f2, field3 as f3, " ... "field6 as f4 from table1") - >>> for r in srdd2.collect(): + >>> for r in df2.collect(): ... 
print r Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None) Row(f1=2, f2=None, f3=Row(field4=22..., f4=[Row(field7=u'row2')]) Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None) - >>> srdd3 = sqlCtx.jsonRDD(json, srdd1.schema()) - >>> sqlCtx.registerRDDAsTable(srdd3, "table2") - >>> srdd4 = sqlCtx.sql( + >>> df3 = sqlCtx.jsonRDD(json, df1.schema()) + >>> sqlCtx.registerRDDAsTable(df3, "table2") + >>> df4 = sqlCtx.sql( ... "SELECT field1 AS f1, field2 as f2, field3 as f3, " ... "field6 as f4 from table2") - >>> for r in srdd4.collect(): + >>> for r in df4.collect(): ... print r Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None) Row(f1=2, f2=None, f3=Row(field4=22..., f4=[Row(field7=u'row2')]) @@ -1588,12 +1597,12 @@ def jsonRDD(self, rdd, schema=None, samplingRatio=1.0): ... StructType([ ... StructField("field5", ... ArrayType(IntegerType(), False), True)]), False)]) - >>> srdd5 = sqlCtx.jsonRDD(json, schema) - >>> sqlCtx.registerRDDAsTable(srdd5, "table3") - >>> srdd6 = sqlCtx.sql( + >>> df5 = sqlCtx.jsonRDD(json, schema) + >>> sqlCtx.registerRDDAsTable(df5, "table3") + >>> df6 = sqlCtx.sql( ... "SELECT field2 AS f1, field3.field5 as f2, " ... "field3.field5[0] as f3 from table3") - >>> srdd6.collect() + >>> df6.collect() [Row(f1=u'row1', f2=None,...Row(f1=u'row3', f2=[], f3=None)] >>> sqlCtx.jsonRDD(sc.parallelize(['{}', @@ -1615,33 +1624,33 @@ def func(iterator): keyed._bypass_serializer = True jrdd = keyed._jrdd.map(self._jvm.BytesToString()) if schema is None: - srdd = self._ssql_ctx.jsonRDD(jrdd.rdd(), samplingRatio) + df = self._ssql_ctx.jsonRDD(jrdd.rdd(), samplingRatio) else: scala_datatype = self._ssql_ctx.parseDataType(schema.json()) - srdd = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype) - return SchemaRDD(srdd, self) + df = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype) + return DataFrame(df, self) def sql(self, sqlQuery): - """Return a L{SchemaRDD} representing the result of the given query. + """Return a L{DataFrame} representing the result of the given query. - >>> srdd = sqlCtx.inferSchema(rdd) - >>> sqlCtx.registerRDDAsTable(srdd, "table1") - >>> srdd2 = sqlCtx.sql("SELECT field1 AS f1, field2 as f2 from table1") - >>> srdd2.collect() + >>> df = sqlCtx.inferSchema(rdd) + >>> sqlCtx.registerRDDAsTable(df, "table1") + >>> df2 = sqlCtx.sql("SELECT field1 AS f1, field2 as f2 from table1") + >>> df2.collect() [Row(f1=1, f2=u'row1'), Row(f1=2, f2=u'row2'), Row(f1=3, f2=u'row3')] """ - return SchemaRDD(self._ssql_ctx.sql(sqlQuery), self) + return DataFrame(self._ssql_ctx.sql(sqlQuery), self) def table(self, tableName): - """Returns the specified table as a L{SchemaRDD}. + """Returns the specified table as a L{DataFrame}. - >>> srdd = sqlCtx.inferSchema(rdd) - >>> sqlCtx.registerRDDAsTable(srdd, "table1") - >>> srdd2 = sqlCtx.table("table1") - >>> sorted(srdd.collect()) == sorted(srdd2.collect()) + >>> df = sqlCtx.inferSchema(rdd) + >>> sqlCtx.registerRDDAsTable(df, "table1") + >>> df2 = sqlCtx.table("table1") + >>> sorted(df.collect()) == sorted(df2.collect()) True """ - return SchemaRDD(self._ssql_ctx.table(tableName), self) + return DataFrame(self._ssql_ctx.table(tableName), self) def cacheTable(self, tableName): """Caches the specified table in-memory.""" @@ -1707,7 +1716,7 @@ def _create_row(fields, values): class Row(tuple): """ - A row in L{SchemaRDD}. The fields in it can be accessed like attributes. + A row in L{DataFrame}. The fields in it can be accessed like attributes. 
Row can be used to create a row object by using named arguments, the fields will be sorted by names. @@ -1799,16 +1808,15 @@ def inherit_doc(cls): return cls -@inherit_doc -class SchemaRDD(RDD): +class DataFrame(object): """An RDD of L{Row} objects that has an associated schema. - The underlying JVM object is a SchemaRDD, not a PythonRDD, so we can + The underlying JVM object is a DataFrame, so we can utilize the relational query api exposed by Spark SQL. For normal L{pyspark.rdd.RDD} operations (map, count, etc.) the - L{SchemaRDD} is not operated on directly, as it's underlying + L{DataFrame} is not operated on directly, as it's underlying implementation is an RDD composed of Java objects. Instead it is converted to a PythonRDD in the JVM, on which Python operations can be done. @@ -1818,92 +1826,89 @@ class SchemaRDD(RDD): etc) so that PySpark sees them as Row objects with named fields. """ - def __init__(self, jschema_rdd, sql_ctx): + def __init__(self, jdf, sql_ctx): + self._jdf = jdf self.sql_ctx = sql_ctx - self._sc = sql_ctx._sc - clsName = jschema_rdd.getClass().getName() - assert clsName.endswith("SchemaRDD"), "jschema_rdd must be SchemaRDD" - self._jschema_rdd = jschema_rdd - self._id = None + self._sc = sql_ctx and sql_ctx._sc self.is_cached = False - self.is_checkpointed = False - self.ctx = self.sql_ctx._sc - # the _jrdd is created by javaToPython(), serialized by pickle - self._jrdd_deserializer = AutoBatchedSerializer(PickleSerializer()) @property - def _jrdd(self): + def rdd(self): """Lazy evaluation of PythonRDD object. Only done when a user calls methods defined by the L{pyspark.rdd.RDD} super class (map, filter, etc.). """ - if not hasattr(self, '_lazy_jrdd'): - self._lazy_jrdd = self._jschema_rdd.baseSchemaRDD().javaToPython() - return self._lazy_jrdd + if not hasattr(self, '_lazy_rdd'): + jrdd = self._jdf.javaToPython() + rdd = RDD(jrdd, self.sql_ctx._sc, BatchedSerializer(PickleSerializer())) + schema = self.schema() - def id(self): - if self._id is None: - self._id = self._jrdd.id() - return self._id + def applySchema(it): + cls = _create_cls(schema) + return itertools.imap(cls, it) + + self._lazy_rdd = rdd.mapPartitions(applySchema) + + return self._lazy_rdd def limit(self, num): """Limit the result count to the number specified. - >>> srdd = sqlCtx.inferSchema(rdd) - >>> srdd.limit(2).collect() + >>> df = sqlCtx.inferSchema(rdd) + >>> df.limit(2).collect() [Row(field1=1, field2=u'row1'), Row(field1=2, field2=u'row2')] - >>> srdd.limit(0).collect() + >>> df.limit(0).collect() [] """ - rdd = self._jschema_rdd.baseSchemaRDD().limit(num) - return SchemaRDD(rdd, self.sql_ctx) + jdf = self._jdf.limit(num) + return DataFrame(jdf, self.sql_ctx) def toJSON(self, use_unicode=False): - """Convert a SchemaRDD into a MappedRDD of JSON documents; one document per row. + """Convert a DataFrame into a MappedRDD of JSON documents; one document per row. 
- >>> srdd1 = sqlCtx.jsonRDD(json) - >>> sqlCtx.registerRDDAsTable(srdd1, "table1") - >>> srdd2 = sqlCtx.sql( "SELECT * from table1") - >>> srdd2.toJSON().take(1)[0] == '{"field1":1,"field2":"row1","field3":{"field4":11}}' + >>> df1 = sqlCtx.jsonRDD(json) + >>> sqlCtx.registerRDDAsTable(df1, "table1") + >>> df2 = sqlCtx.sql( "SELECT * from table1") + >>> df2.toJSON().take(1)[0] == '{"field1":1,"field2":"row1","field3":{"field4":11}}' True - >>> srdd3 = sqlCtx.sql( "SELECT field3.field4 from table1") - >>> srdd3.toJSON().collect() == ['{"field4":11}', '{"field4":22}', '{"field4":33}'] + >>> df3 = sqlCtx.sql( "SELECT field3.field4 from table1") + >>> df3.toJSON().collect() == ['{"field4":11}', '{"field4":22}', '{"field4":33}'] True """ - rdd = self._jschema_rdd.baseSchemaRDD().toJSON() + rdd = self._jdf.toJSON() return RDD(rdd.toJavaRDD(), self._sc, UTF8Deserializer(use_unicode)) def saveAsParquetFile(self, path): """Save the contents as a Parquet file, preserving the schema. Files that are written out using this method can be read back in as - a SchemaRDD using the L{SQLContext.parquetFile} method. + a DataFrame using the L{SQLContext.parquetFile} method. >>> import tempfile, shutil >>> parquetFile = tempfile.mkdtemp() >>> shutil.rmtree(parquetFile) - >>> srdd = sqlCtx.inferSchema(rdd) - >>> srdd.saveAsParquetFile(parquetFile) - >>> srdd2 = sqlCtx.parquetFile(parquetFile) - >>> sorted(srdd2.collect()) == sorted(srdd.collect()) + >>> df = sqlCtx.inferSchema(rdd) + >>> df.saveAsParquetFile(parquetFile) + >>> df2 = sqlCtx.parquetFile(parquetFile) + >>> sorted(df2.collect()) == sorted(df.collect()) True """ - self._jschema_rdd.saveAsParquetFile(path) + self._jdf.saveAsParquetFile(path) def registerTempTable(self, name): """Registers this RDD as a temporary table using the given name. The lifetime of this temporary table is tied to the L{SQLContext} - that was used to create this SchemaRDD. + that was used to create this DataFrame. - >>> srdd = sqlCtx.inferSchema(rdd) - >>> srdd.registerTempTable("test") - >>> srdd2 = sqlCtx.sql("select * from test") - >>> sorted(srdd.collect()) == sorted(srdd2.collect()) + >>> df = sqlCtx.inferSchema(rdd) + >>> df.registerTempTable("test") + >>> df2 = sqlCtx.sql("select * from test") + >>> sorted(df.collect()) == sorted(df2.collect()) True """ - self._jschema_rdd.registerTempTable(name) + self._jdf.registerTempTable(name) def registerAsTable(self, name): """DEPRECATED: use registerTempTable() instead""" @@ -1911,62 +1916,65 @@ def registerAsTable(self, name): self.registerTempTable(name) def insertInto(self, tableName, overwrite=False): - """Inserts the contents of this SchemaRDD into the specified table. + """Inserts the contents of this DataFrame into the specified table. Optionally overwriting any existing data. 
""" - self._jschema_rdd.insertInto(tableName, overwrite) + self._jdf.insertInto(tableName, overwrite) def saveAsTable(self, tableName): - """Creates a new table with the contents of this SchemaRDD.""" - self._jschema_rdd.saveAsTable(tableName) + """Creates a new table with the contents of this DataFrame.""" + self._jdf.saveAsTable(tableName) def schema(self): - """Returns the schema of this SchemaRDD (represented by + """Returns the schema of this DataFrame (represented by a L{StructType}).""" - return _parse_datatype_json_string(self._jschema_rdd.baseSchemaRDD().schema().json()) - - def schemaString(self): - """Returns the output schema in the tree format.""" - return self._jschema_rdd.schemaString() + return _parse_datatype_json_string(self._jdf.schema().json()) + # def schemaString(self): + # """Returns the output schema in the tree format.""" + # return self._jdf.schemaString() + # def printSchema(self): """Prints out the schema in the tree format.""" - print self.schemaString() + self.printSchema() def count(self): """Return the number of elements in this RDD. Unlike the base RDD implementation of count, this implementation - leverages the query optimizer to compute the count on the SchemaRDD, + leverages the query optimizer to compute the count on the DataFrame, which supports features such as filter pushdown. - >>> srdd = sqlCtx.inferSchema(rdd) - >>> srdd.count() + >>> df = sqlCtx.inferSchema(rdd) + >>> df.count() 3L - >>> srdd.count() == srdd.map(lambda x: x).count() + >>> df.count() == df.map(lambda x: x).count() True """ - return self._jschema_rdd.count() + return self._jdf.count() def collect(self): - """Return a list that contains all of the rows in this RDD. + """Return a list that contains all of the rows. Each object in the list is a Row, the fields can be accessed as attributes. - Unlike the base RDD implementation of collect, this implementation - leverages the query optimizer to perform a collect on the SchemaRDD, - which supports features such as filter pushdown. - - >>> srdd = sqlCtx.inferSchema(rdd) - >>> srdd.collect() + >>> df = sqlCtx.inferSchema(rdd) + >>> df.collect() [Row(field1=1, field2=u'row1'), ..., Row(field1=3, field2=u'row3')] """ - with SCCallSiteSync(self.context) as css: - bytesInJava = self._jschema_rdd.baseSchemaRDD().collectToPython().iterator() + with SCCallSiteSync(self._sc) as css: + bytesInJava = self._jdf.collectToPython().iterator() cls = _create_cls(self.schema()) - return map(cls, self._collect_iterator_through_file(bytesInJava)) + tempFile = NamedTemporaryFile(delete=False, dir=self._sc._temp_dir) + tempFile.close() + self._sc._writeToFile(bytesInJava, tempFile.name) + # Read the data into Python and deserialize it: + with open(tempFile.name, 'rb') as tempFile: + rs = list(BatchedSerializer(PickleSerializer()).load_stream(tempFile)) + os.unlink(tempFile.name) + return [cls(r) for r in rs] def take(self, num): """Take the first num rows of the RDD. @@ -1974,130 +1982,396 @@ def take(self, num): Each object in the list is a Row, the fields can be accessed as attributes. - Unlike the base RDD implementation of take, this implementation - leverages the query optimizer to perform a collect on a SchemaRDD, - which supports features such as filter pushdown. 
- - >>> srdd = sqlCtx.inferSchema(rdd) - >>> srdd.take(2) + >>> df = sqlCtx.inferSchema(rdd) + >>> df.take(2) [Row(field1=1, field2=u'row1'), Row(field1=2, field2=u'row2')] """ return self.limit(num).collect() + def first(self): + return self.rdd.first() + + def map(self, f): + return self.rdd.map(f) + # Convert each object in the RDD to a Row with the right class - # for this SchemaRDD, so that fields can be accessed as attributes. - def mapPartitionsWithIndex(self, f, preservesPartitioning=False): + # for this DataFrame, so that fields can be accessed as attributes. + def mapPartitions(self, f, preservesPartitioning=False): """ Return a new RDD by applying a function to each partition of this RDD, while tracking the index of the original partition. >>> rdd = sc.parallelize([1, 2, 3, 4], 4) - >>> def f(splitIndex, iterator): yield splitIndex - >>> rdd.mapPartitionsWithIndex(f).sum() - 6 + >>> def f(iterator): yield 1 + >>> rdd.mapPartitions(f).sum() + 4 """ - rdd = RDD(self._jrdd, self._sc, self._jrdd_deserializer) - - schema = self.schema() - - def applySchema(_, it): - cls = _create_cls(schema) - return itertools.imap(cls, it) - - objrdd = rdd.mapPartitionsWithIndex(applySchema, preservesPartitioning) - return objrdd.mapPartitionsWithIndex(f, preservesPartitioning) + return self.rdd.mapPartitions(f, preservesPartitioning) # We override the default cache/persist/checkpoint behavior - # as we want to cache the underlying SchemaRDD object in the JVM, + # as we want to cache the underlying DataFrame object in the JVM, # not the PythonRDD checkpointed by the super class def cache(self): self.is_cached = True - self._jschema_rdd.cache() + self._jdf.cache() return self def persist(self, storageLevel=StorageLevel.MEMORY_ONLY_SER): self.is_cached = True - javaStorageLevel = self.ctx._getJavaStorageLevel(storageLevel) - self._jschema_rdd.persist(javaStorageLevel) + javaStorageLevel = self._sc._getJavaStorageLevel(storageLevel) + self._jdf.persist(javaStorageLevel) return self def unpersist(self, blocking=True): self.is_cached = False - self._jschema_rdd.unpersist(blocking) + self._jdf.unpersist(blocking) return self - def checkpoint(self): - self.is_checkpointed = True - self._jschema_rdd.checkpoint() + # def coalesce(self, numPartitions, shuffle=False): + # rdd = self._jdf.coalesce(numPartitions, shuffle, None) + # return DataFrame(rdd, self.sql_ctx) + + # def distinct(self, numPartitions=None): + # if numPartitions is None: + # rdd = self._jdf.distinct() + # else: + # rdd = self._jdf.distinct(numPartitions, None) + # return DataFrame(rdd, self.sql_ctx) + # + # def intersection(self, other): + # if (other.__class__ is DataFrame): + # rdd = self._jdf.intersection(other._jdf) + # return DataFrame(rdd, self.sql_ctx) + # else: + # raise ValueError("Can only intersect with another DataFrame") + + # def repartition(self, numPartitions): + # rdd = self._jdf.repartition(numPartitions, None) + # return DataFrame(rdd, self.sql_ctx) + # + # def subtract(self, other, numPartitions=None): + # if (other.__class__ is DataFrame): + # if numPartitions is None: + # rdd = self._jdf.subtract(other._jdf) + # else: + # rdd = self._jdf.subtract(other._jdf, + # numPartitions) + # return DataFrame(rdd, self.sql_ctx) + # else: + # raise ValueError("Can only subtract another DataFrame") + + def sample(self, withReplacement, fraction, seed=None): + """ + Return a sampled subset of this DataFrame. 
+ + >>> df = sqlCtx.inferSchema(rdd) + >>> df.sample(False, 0.5, 97).count() + 2L + """ + assert fraction >= 0.0, "Negative fraction value: %s" % fraction + seed = seed if seed is not None else random.randint(0, sys.maxint) + rdd = self._jdf.sample(withReplacement, fraction, long(seed)) + return DataFrame(rdd, self.sql_ctx) + + # def takeSample(self, withReplacement, num, seed=None): + # """Return a fixed-size sampled subset of this DataFrame. + # + # >>> df = sqlCtx.inferSchema(rdd) + # >>> df.takeSample(False, 2, 97) + # [Row(field1=3, field2=u'row3'), Row(field1=1, field2=u'row1')] + # """ + # seed = seed if seed is not None else random.randint(0, sys.maxint) + # with SCCallSiteSync(self.context) as css: + # bytesInJava = self._jdf \ + # .takeSampleToPython(withReplacement, num, long(seed)) \ + # .iterator() + # cls = _create_cls(self.schema()) + # return map(cls, self._collect_iterator_through_file(bytesInJava)) - def isCheckpointed(self): - return self._jschema_rdd.isCheckpointed() + @property + def dtypes(self): + return [(f.name, str(f.dataType)) for f in self.schema().fields] - def getCheckpointFile(self): - checkpointFile = self._jschema_rdd.getCheckpointFile() - if checkpointFile.isDefined(): - return checkpointFile.get() + @property + def columns(self): + return [f.name for f in self.schema().fields] - def coalesce(self, numPartitions, shuffle=False): - rdd = self._jschema_rdd.coalesce(numPartitions, shuffle, None) - return SchemaRDD(rdd, self.sql_ctx) + def show(self): + raise NotImplemented - def distinct(self, numPartitions=None): - if numPartitions is None: - rdd = self._jschema_rdd.distinct() + def join(self, other, joinExprs=None, joinType=None): + if joinType is None: + if joinExprs is None: + jdf = self._jdf.join(other._jdf) + else: + jdf = self._jdf.join(other._jdf, joinExprs) else: - rdd = self._jschema_rdd.distinct(numPartitions, None) - return SchemaRDD(rdd, self.sql_ctx) + jdf = self._jdf.join(other._jdf, joinExprs, joinType) + return DataFrame(jdf, self.sql_ctx) + + def sort(self, *cols): + if not cols: + raise ValueError("should sort by at least one column") + for i, c in enumerate(cols): + if isinstance(c, basestring): + cols[i] = Column(c) + jcols = [c._jc for c in cols] + jdf = self._jdf.join(*jcols) + return DataFrame(jdf, self.sql_ctx) + + def tail(self): + raise NotImplemented + + def __getitem__(self, item): + if isinstance(item, basestring): + return Column(self._jdf.apply(item)) + # TODO projection + raise IndexError + + def __getattr__(self, name): + if isinstance(name, basestring): + return Column(self._jdf.apply(name)) + raise AttributeError + + def alias(self, name): + return DataFrame(getattr(self._jdf, "as")(name), self.sql_ctx) + + def select(self, *cols): + if not cols: + cols = ["*"] + if isinstance(cols[0], basestring): + cols = [_create_column_from_name(n) for n in cols] + else: + cols = [c._jc for c in cols] + jcols = ListConverter().convert(cols, self._sc._gateway._gateway_client) + jdf = self._jdf.select(jcols) + return DataFrame(jdf, self.sql_ctx) + + def where(self, cond): + return DataFrame(self._jdf.filter(cond._jc), self.sql_ctx) - def intersection(self, other): - if (other.__class__ is SchemaRDD): - rdd = self._jschema_rdd.intersection(other._jschema_rdd) - return SchemaRDD(rdd, self.sql_ctx) + def filter(self, col): + return DataFrame(self._jdf.filter(col._jc), self.sql_ctx) + + def groupby(self, *cols): + if cols and isinstance(cols[0], basestring): + cols = [_create_column_from_name(n) for n in cols] else: - raise ValueError("Can 
only intersect with another SchemaRDD") + cols = [c._jc for c in cols] + jcols = ListConverter().convert(cols, self._sc._gateway._gateway_client) + jdf = self._jdf.groupby(jcols) + return GroupedDataFrame(jdf, self.sql_ctx) - def repartition(self, numPartitions): - rdd = self._jschema_rdd.repartition(numPartitions, None) - return SchemaRDD(rdd, self.sql_ctx) - def subtract(self, other, numPartitions=None): - if (other.__class__ is SchemaRDD): - if numPartitions is None: - rdd = self._jschema_rdd.subtract(other._jschema_rdd) - else: - rdd = self._jschema_rdd.subtract(other._jschema_rdd, - numPartitions) - return SchemaRDD(rdd, self.sql_ctx) +# make SchemaRDD as an alias of DataFrame for backward compatibility +SchemaRDD = DataFrame + + +def dfapi(f): + def _api(self, *a): + ja = [v._jc if isinstance(v, Column) else v for v in a] + name = f.__name__ + jdf = getattr(self._jdf, name)(*ja) + return DataFrame(jdf, self.sql_ctx) + _api.__name__ = f.__name__ + _api.__doc__ = f.__doc__ + return _api + + +class GroupedDataFrame(object): + def __init__(self, jdf, sql_ctx): + self._jdf = jdf + self.sql_ctx = sql_ctx + + def agg(self, *exprs): + if len(exprs) == 1 and isinstance(exprs[0], dict): + jmap = MapConverter().convert(exprs[0], self.sql_ctx._sc._gateway._gateway_client) + jdf = self._jdf.agg(jmap) else: - raise ValueError("Can only subtract another SchemaRDD") + # Columns + assert all(isinstance(c, Column) for c in exprs), "all exprs should be Columns" + jdf = self._jdf.agg(*exprs) + return DataFrame(jdf, self.sql_ctx) - def sample(self, withReplacement, fraction, seed=None): - """ - Return a sampled subset of this SchemaRDD. + @dfapi + def count(self): + """ """ + + @dfapi + def mean(self): + """""" + + @dfapi + def max(self): + """""" + + @dfapi + def min(self): + """""" + + @dfapi + def sum(self): + """""" + + +SCALA_METHOD_MAPPINGS = { + '=': '$eq', + '>': '$greater', + '<': '$less', + '+': '$plus', + '-': '$minus', + '*': '$times', + '/': '$div', + '!': '$bang', + '@': '$at', + '#': '$hash', + '%': '$percent', + '^': '$up', + '&': '$amp', + '~': '$tilde', + '?': '$qmark', + '|': '$bar', + '\\': '$bslash', + ':': '$colon', +} - >>> srdd = sqlCtx.inferSchema(rdd) - >>> srdd.sample(False, 0.5, 97).count() - 2L - """ - assert fraction >= 0.0, "Negative fraction value: %s" % fraction - seed = seed if seed is not None else random.randint(0, sys.maxint) - rdd = self._jschema_rdd.sample(withReplacement, fraction, long(seed)) - return SchemaRDD(rdd, self.sql_ctx) - def takeSample(self, withReplacement, num, seed=None): - """Return a fixed-size sampled subset of this SchemaRDD. 
+def _create_column_from_literal(literal): + sc = SparkContext._active_spark_context + return sc._jvm.Literal.apply(literal) - >>> srdd = sqlCtx.inferSchema(rdd) - >>> srdd.takeSample(False, 2, 97) - [Row(field1=3, field2=u'row3'), Row(field1=1, field2=u'row1')] - """ - seed = seed if seed is not None else random.randint(0, sys.maxint) - with SCCallSiteSync(self.context) as css: - bytesInJava = self._jschema_rdd.baseSchemaRDD() \ - .takeSampleToPython(withReplacement, num, long(seed)) \ - .iterator() - cls = _create_cls(self.schema()) - return map(cls, self._collect_iterator_through_file(bytesInJava)) + +def _create_column_from_name(name): + sc = SparkContext._active_spark_context + return sc._jvm.Column(name) + + +def scalaMethod(name): + return ''.join(SCALA_METHOD_MAPPINGS.get(c, c) for c in name) + + +def _unary_op(name): + def _(self): + return Column(getattr(self._jc, scalaMethod(name))()) + return _ + + +def _bin_op(name): + def _(self, other): + if isinstance(other, Column): + jc = other._jc + else: + jc = _create_column_from_literal(other) + return Column(getattr(self._jc, scalaMethod(name))(jc)) + return _ + + +def _reverse_op(name): + def _(self, other): + return Column(getattr(_create_column_from_literal(other), scalaMethod(name))(self._jc)) + return _ + + +class Column(DataFrame): + def __init__(self, jc, jdf=None, sql_ctx=None): + self._jc = jc + super(Column, self).__init__(jdf, sql_ctx) + + def __nonzero__(self): + return True + + # arithmetic operators + __neg__ = _unary_op("unary_-") + __add__ = _bin_op("+") + __sub__ = _bin_op("-") + __mul__ = _bin_op("*") + __div__ = _bin_op("/") + __mod__ = _bin_op("%") + __radd__ = _bin_op("+") + __rsub__ = _reverse_op("-") + __rmul__ = _bin_op("*") + __rdiv__ = _reverse_op("/") + __rmod__ = _reverse_op("%") + __abs__ = _unary_op("abs") + abs = _unary_op("abs") + sqrt = _unary_op("sqrt") + + # logistic operators + __eq__ = _bin_op("===") + __ne__ = _bin_op("!==") + __lt__ = _bin_op("<") + __le__ = _bin_op("<=") + __ge__ = _bin_op(">=") + __gt__ = _bin_op(">") + # `and`, `or`, `not` cannot be overloaded in Python + And = _bin_op('&&') + Or = _bin_op('||') + Not = _unary_op('unary_!') + + # bitwise operators + __and__ = _bin_op("&") + __or__ = _bin_op("|") + __invert__ = _unary_op("unary_~") + __xor__ = _bin_op("^") + # __lshift__ = _bin_op("<<") + # __rshift__ = _bin_op(">>") + __rand__ = _bin_op("&") + __ror__ = _bin_op("|") + __rxor__ = _bin_op("^") + # __rlshift__ = _reverse_op("<<") + # __rrshift__ = _reverse_op(">>") + + # container operators + __contains__ = _bin_op("contains") + __getitem__ = _bin_op("getItem") + + def __getslice__(self, a, b): + jc = self._jsc.substr(a, b - a) + return Column(jc) + # __getattr__ = _bin_op("getField") + + # string methods + rlike = _bin_op("rlike") + like = _bin_op("like") + startswith = _bin_op("startsWith") + endswith = _bin_op("endsWith") + upper = _unary_op("upper") + lower = _unary_op("lower") + + # order + asc = _unary_op("asc") + desc = _unary_op("desc") + + # `as` is keyword + def As(self, alias): + return Column(getattr(self._jsc, "as")(alias)) + + def cast(self, dataType): + raise NotImplemented + + +def _help_func(name): + def _(col): + sc = SparkContext._active_spark_context + if isinstance(col, Column): + jcol = col._jc + else: + jcol = _create_column_from_name(col) + # FIXME: can not access dsl.min/max ... 
+ jc = getattr(sc._jvm.org.apache.spark.sql.dsl(), name)(jcol) + return Column(jc) + return staticmethod(_) + + +class Aggregator(object): + # helper functions + max = _help_func("max") + min = _help_func("min") + avg = mean = _help_func("mean") + sum = _help_func("sum") + first = _help_func("first") + last = _help_func("last") + count = _help_func("count") def _test(): diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index b474fcf5bfb7e..142d7ad5f0f9b 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -806,6 +806,9 @@ def tearDownClass(cls): def setUp(self): self.sqlCtx = SQLContext(self.sc) + self.testData = [Row(key=i, value=str(i)) for i in range(100)] + rdd = self.sc.parallelize(self.testData) + self.df = self.sqlCtx.inferSchema(rdd) def test_udf(self): self.sqlCtx.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType()) @@ -821,7 +824,7 @@ def test_udf2(self): def test_udf_with_array_type(self): d = [Row(l=range(3), d={"key": range(5)})] rdd = self.sc.parallelize(d) - srdd = self.sqlCtx.inferSchema(rdd).registerTempTable("test") + self.sqlCtx.inferSchema(rdd).registerTempTable("test") self.sqlCtx.registerFunction("copylist", lambda l: list(l), ArrayType(IntegerType())) self.sqlCtx.registerFunction("maplen", lambda d: len(d), IntegerType()) [(l1, l2)] = self.sqlCtx.sql("select copylist(l), maplen(d) from test").collect() @@ -839,68 +842,51 @@ def test_broadcast_in_udf(self): def test_basic_functions(self): rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}']) - srdd = self.sqlCtx.jsonRDD(rdd) - srdd.count() - srdd.collect() - srdd.schemaString() - srdd.schema() + df = self.sqlCtx.jsonRDD(rdd) + df.count() + df.collect() + df.schema() # cache and checkpoint - self.assertFalse(srdd.is_cached) - srdd.persist() - srdd.unpersist() - srdd.cache() - self.assertTrue(srdd.is_cached) - self.assertFalse(srdd.isCheckpointed()) - self.assertEqual(None, srdd.getCheckpointFile()) - - srdd = srdd.coalesce(2, True) - srdd = srdd.repartition(3) - srdd = srdd.distinct() - srdd.intersection(srdd) - self.assertEqual(2, srdd.count()) - - srdd.registerTempTable("temp") - srdd = self.sqlCtx.sql("select foo from temp") - srdd.count() - srdd.collect() - - def test_distinct(self): - rdd = self.sc.parallelize(['{"a": 1}', '{"b": 2}', '{"c": 3}']*10, 10) - srdd = self.sqlCtx.jsonRDD(rdd) - self.assertEquals(srdd.getNumPartitions(), 10) - self.assertEquals(srdd.distinct().count(), 3) - result = srdd.distinct(5) - self.assertEquals(result.getNumPartitions(), 5) - self.assertEquals(result.count(), 3) + self.assertFalse(df.is_cached) + df.persist() + df.unpersist() + df.cache() + self.assertTrue(df.is_cached) + self.assertEqual(2, df.count()) + + df.registerTempTable("temp") + df = self.sqlCtx.sql("select foo from temp") + df.count() + df.collect() def test_apply_schema_to_row(self): - srdd = self.sqlCtx.jsonRDD(self.sc.parallelize(["""{"a":2}"""])) - srdd2 = self.sqlCtx.applySchema(srdd.map(lambda x: x), srdd.schema()) - self.assertEqual(srdd.collect(), srdd2.collect()) + df = self.sqlCtx.jsonRDD(self.sc.parallelize(["""{"a":2}"""])) + df2 = self.sqlCtx.applySchema(df.map(lambda x: x), df.schema()) + self.assertEqual(df.collect(), df2.collect()) rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x)) - srdd3 = self.sqlCtx.applySchema(rdd, srdd.schema()) - self.assertEqual(10, srdd3.count()) + df3 = self.sqlCtx.applySchema(rdd, df.schema()) + self.assertEqual(10, df3.count()) def test_serialize_nested_array_and_map(self): d = [Row(l=[Row(a=1, b='s')], d={"key": 
Row(c=1.0, d="2")})] rdd = self.sc.parallelize(d) - srdd = self.sqlCtx.inferSchema(rdd) - row = srdd.first() + df = self.sqlCtx.inferSchema(rdd) + row = df.first() self.assertEqual(1, len(row.l)) self.assertEqual(1, row.l[0].a) self.assertEqual("2", row.d["key"].d) - l = srdd.map(lambda x: x.l).first() + l = df.map(lambda x: x.l).first() self.assertEqual(1, len(l)) self.assertEqual('s', l[0].b) - d = srdd.map(lambda x: x.d).first() + d = df.map(lambda x: x.d).first() self.assertEqual(1, len(d)) self.assertEqual(1.0, d["key"].c) - row = srdd.map(lambda x: x.d["key"]).first() + row = df.map(lambda x: x.d["key"]).first() self.assertEqual(1.0, row.c) self.assertEqual("2", row.d) @@ -908,26 +894,26 @@ def test_infer_schema(self): d = [Row(l=[], d={}), Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}, s="")] rdd = self.sc.parallelize(d) - srdd = self.sqlCtx.inferSchema(rdd) - self.assertEqual([], srdd.map(lambda r: r.l).first()) - self.assertEqual([None, ""], srdd.map(lambda r: r.s).collect()) - srdd.registerTempTable("test") + df = self.sqlCtx.inferSchema(rdd) + self.assertEqual([], df.map(lambda r: r.l).first()) + self.assertEqual([None, ""], df.map(lambda r: r.s).collect()) + df.registerTempTable("test") result = self.sqlCtx.sql("SELECT l[0].a from test where d['key'].d = '2'") self.assertEqual(1, result.first()[0]) - srdd2 = self.sqlCtx.inferSchema(rdd, 1.0) - self.assertEqual(srdd.schema(), srdd2.schema()) - self.assertEqual({}, srdd2.map(lambda r: r.d).first()) - self.assertEqual([None, ""], srdd2.map(lambda r: r.s).collect()) - srdd2.registerTempTable("test2") + df2 = self.sqlCtx.inferSchema(rdd, 1.0) + self.assertEqual(df.schema(), df2.schema()) + self.assertEqual({}, df2.map(lambda r: r.d).first()) + self.assertEqual([None, ""], df2.map(lambda r: r.s).collect()) + df2.registerTempTable("test2") result = self.sqlCtx.sql("SELECT l[0].a from test2 where d['key'].d = '2'") self.assertEqual(1, result.first()[0]) def test_struct_in_map(self): d = [Row(m={Row(i=1): Row(s="")})] rdd = self.sc.parallelize(d) - srdd = self.sqlCtx.inferSchema(rdd) - k, v = srdd.first().m.items()[0] + df = self.sqlCtx.inferSchema(rdd) + k, v = df.first().m.items()[0] self.assertEqual(1, k.i) self.assertEqual("", v.s) @@ -935,8 +921,8 @@ def test_convert_row_to_dict(self): row = Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}) self.assertEqual(1, row.asDict()['l'][0].a) rdd = self.sc.parallelize([row]) - srdd = self.sqlCtx.inferSchema(rdd) - srdd.registerTempTable("test") + df = self.sqlCtx.inferSchema(rdd) + df.registerTempTable("test") row = self.sqlCtx.sql("select l, d from test").first() self.assertEqual(1, row.asDict()["l"][0].a) self.assertEqual(1.0, row.asDict()['d']['key'].c) @@ -945,11 +931,11 @@ def test_infer_schema_with_udt(self): from pyspark.tests import ExamplePoint, ExamplePointUDT row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) rdd = self.sc.parallelize([row]) - srdd = self.sqlCtx.inferSchema(rdd) - schema = srdd.schema() + df = self.sqlCtx.inferSchema(rdd) + schema = df.schema() field = [f for f in schema.fields if f.name == "point"][0] self.assertEqual(type(field.dataType), ExamplePointUDT) - srdd.registerTempTable("labeled_point") + df.registerTempTable("labeled_point") point = self.sqlCtx.sql("SELECT point FROM labeled_point").first().point self.assertEqual(point, ExamplePoint(1.0, 2.0)) @@ -959,21 +945,50 @@ def test_apply_schema_with_udt(self): rdd = self.sc.parallelize([row]) schema = StructType([StructField("label", DoubleType(), False), StructField("point", ExamplePointUDT(), 
False)]) - srdd = self.sqlCtx.applySchema(rdd, schema) - point = srdd.first().point + df = self.sqlCtx.applySchema(rdd, schema) + point = df.first().point self.assertEquals(point, ExamplePoint(1.0, 2.0)) def test_parquet_with_udt(self): from pyspark.tests import ExamplePoint row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) rdd = self.sc.parallelize([row]) - srdd0 = self.sqlCtx.inferSchema(rdd) + df0 = self.sqlCtx.inferSchema(rdd) output_dir = os.path.join(self.tempdir.name, "labeled_point") - srdd0.saveAsParquetFile(output_dir) - srdd1 = self.sqlCtx.parquetFile(output_dir) - point = srdd1.first().point + df0.saveAsParquetFile(output_dir) + df1 = self.sqlCtx.parquetFile(output_dir) + point = df1.first().point self.assertEquals(point, ExamplePoint(1.0, 2.0)) + def test_column_operators(self): + from pyspark.sql import Column + ci = self.df.key + cs = self.df.value + c = ci == cs + self.assertTrue(isinstance((- ci - 1 - 2) % 3 * 2.5 / 3.5, Column)) + rcc = (1 + ci), (1 - ci), (1 * ci), (1 / ci), (1 % ci) + self.assertTrue(all(isinstance(c, Column) for c in rcc)) + cb = [ci == 5, ci != 0, ci > 3, ci < 4, ci >= 0, ci <= 7, ci and cs, ci or cs] + self.assertTrue(all(isinstance(c, Column) for c in cb)) + cbit = (ci & ci), (ci | ci), (ci ^ ci), (~ci) + self.assertTrue(all(isinstance(c, Column) for c in cbit)) + css = cs.like('a'), cs.rlike('a'), cs.asc(), cs.desc(), cs.startswith('a'), cs.endswith('a') + self.assertTrue(all(isinstance(c, Column) for c in css)) + + def test_column(self): + df = self.df + self.assertEqual(self.testData, df.select("*").collect()) + self.assertEqual(self.testData, df.select(df.key, df.value).collect()) + self.assertEqual([Row(value='1')], df.where(df.key == 1).select(df.value).collect()) + + def test_aggregator(self): + from pyspark.sql import Aggregator as Agg + df = self.df + g = df.groupby() + self.assertEqual([99, 100], sorted(g.agg({'key': 'max', 'value': 'count'}).collect()[0])) + self.assertEqual([Row(**{"AVG(key#0)": 49.5})], g.mean().collect()) + # self.assertEqual((0, '100'), tuple(g.agg(Agg.first(df.key), Agg.last(df.value)).first())) + class InputFormatTests(ReusedPySparkTestCase): diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 19ccb6ca8f76b..f6178d775a38c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -19,11 +19,17 @@ package org.apache.spark.sql import scala.language.implicitConversions import scala.reflect.ClassTag +import scala.collection.JavaConversions._ + +import java.util.{ArrayList, List => JList} import com.fasterxml.jackson.core.JsonFactory +import net.razorvine.pickle.Pickler import org.apache.spark.annotation.Experimental import org.apache.spark.rdd.RDD +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.api.python.SerDeUtil import org.apache.spark.storage.StorageLevel import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation @@ -31,7 +37,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.{Literal => LiteralExpr} import org.apache.spark.sql.catalyst.plans.{JoinType, Inner} import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.execution.LogicalRDD +import org.apache.spark.sql.execution.{LogicalRDD, EvaluatePython} import org.apache.spark.sql.json.JsonRDD import 
org.apache.spark.sql.types.{NumericType, StructType} import org.apache.spark.util.Utils @@ -560,4 +566,33 @@ class DataFrame protected[sql]( iter.map(JsonRDD.rowToJSON(rowSchema, jsonFactory)) } } + + //////////////////////////////////////////////////////////////////////////// + // for Python API + //////////////////////////////////////////////////////////////////////////// + private[sql] def select(cols: java.util.List[Column]): DataFrame = { + select(cols:_*) + } + private[sql] def groupby(cols: java.util.List[Column]): GroupedDataFrame = { + groupby(cols:_*) + } + + /** + * Converts a JavaRDD to a PythonRDD. It is used by pyspark. + */ + private[sql] def javaToPython: JavaRDD[Array[Byte]] = { + val fieldTypes = schema.fields.map(_.dataType) + val jrdd = this.rdd.map(EvaluatePython.rowToArray(_, fieldTypes)).toJavaRDD() + SerDeUtil.javaToPython(jrdd) + } + /** + * Serializes the Array[Row] returned by collect(), using the same format as javaToPython. + */ + private[sql] def collectToPython: JList[Array[Byte]] = { + val fieldTypes = schema.fields.map(_.dataType) + val pickle = new Pickler + new ArrayList[Array[Byte]](collect().map { row => + EvaluatePython.rowToArray(row, fieldTypes) + }.grouped(100).map(batched => pickle.dumps(batched.toArray)).toIterable) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataFrame.scala index f0cde1a024751..ef5ef6b09dc6c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataFrame.scala @@ -18,6 +18,8 @@ package org.apache.spark.sql import scala.language.implicitConversions +import scala.collection.JavaConverters._ +import scala.collection.JavaConversions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.{Literal => LiteralExpr} @@ -120,4 +122,9 @@ class GroupedDataFrame protected[sql](df: DataFrame, groupingExprs: Seq[Expressi * Compute the sum for each numeric columns for each group. */ override def sum(): DataFrame = aggregateNumericColumns(Sum) + + //// For Python API + private[sql] def agg(exprs: java.util.Map[String, String]): DataFrame = { + agg(exprs.toMap) + } }
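
For reference, a minimal usage sketch of the Python-side API this patch introduces, pieced together from the doctests and the new tests above. It assumes a live SparkContext `sc` and SQLContext `sqlCtx` (exactly as the doctests do) and only calls methods defined in this diff; it is an illustration, not part of the patch.

    from pyspark.sql import Row

    rdd = self_sc = None  # placeholder line removed in practice; see note above
    rdd = sc.parallelize([Row(key=i, value=str(i)) for i in range(100)])
    df = sqlCtx.inferSchema(rdd)          # returns a DataFrame; SchemaRDD is kept as an alias
    df.registerTempTable("t")

    # Column expressions built via __getattr__ and the operator overloads on Column
    df.select(df.key, df.value).collect()
    df.where(df.key == 1).select(df.value).collect()   # equals [Row(value='1')] per test_column

    # Grouped aggregation through the dict form of GroupedDataFrame.agg
    df.groupby().agg({'key': 'max', 'value': 'count'}).collect()

    # Plain SQL still goes through SQLContext.sql and now returns a DataFrame
    sqlCtx.sql("SELECT key FROM t WHERE key > 90").count()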
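
The Column operator overloads dispatch to Scala's symbolic methods (`===`, `<=`, `&&`, ...) through Py4J, which only sees the JVM-mangled names; `scalaMethod` performs that translation using SCALA_METHOD_MAPPINGS. A standalone sketch (no Spark required) that restates the mapping from this patch and shows a few translations:

    # Subset of SCALA_METHOD_MAPPINGS from pyspark/sql.py in this diff.
    SCALA_METHOD_MAPPINGS = {
        '=': '$eq', '>': '$greater', '<': '$less', '+': '$plus', '-': '$minus',
        '*': '$times', '/': '$div', '!': '$bang', '&': '$amp', '|': '$bar',
    }

    def scalaMethod(name):
        # Translate each symbolic character to its JVM-mangled form;
        # alphabetic characters pass through unchanged.
        return ''.join(SCALA_METHOD_MAPPINGS.get(c, c) for c in name)

    assert scalaMethod("===") == "$eq$eq$eq"   # used by Column.__eq__
    assert scalaMethod("<=") == "$less$eq"     # used by Column.__le__
    assert scalaMethod("&&") == "$amp$amp"     # used by Column.And
    assert scalaMethod("rlike") == "rlike"     # non-symbolic names are unchanged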