From 020cbdf8814d25121a909b634f908dc8c41bd739 Mon Sep 17 00:00:00 2001 From: jbencook Date: Mon, 22 Dec 2014 10:40:49 -0600 Subject: [PATCH 1/5] [SPARK-4860][pyspark][sql] using Scala implementations of `sample()` and `takeSample()` --- python/pyspark/sql.py | 29 +++++++++++++++++++ .../org/apache/spark/sql/SchemaRDD.scala | 14 +++++++++ 2 files changed, 43 insertions(+) diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 469f82473af97..94051990f8df1 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -2085,6 +2085,35 @@ def subtract(self, other, numPartitions=None): else: raise ValueError("Can only subtract another SchemaRDD") + def sample(self, withReplacement, fraction, seed=None): + """ + Return a sampled subset of this SchemaRDD. + + >>> srdd = sqlCtx.inferSchema(rdd) + >>> srdd.sample(False, 0.5, 97).count() + 2L + """ + assert fraction >= 0.0, "Negative fraction value: %s" % fraction + seed = seed if seed is not None else random.randint(0, sys.maxint) + rdd = self._jschema_rdd.baseSchemaRDD().sample( + withReplacement, fraction, long(seed)) + return SchemaRDD(rdd.toJavaSchemaRDD(), self.sql_ctx) + + def takeSample(self, withReplacement, num, seed=None): + """Return a fixed-size sampled subset of this SchemaRDD. + + >>> srdd = sqlCtx.inferSchema(rdd) + >>> srdd.takeSample(False, 2, 97) + [Row(field1=3, field2=u'row3'), Row(field1=1, field2=u'row1')] + """ + seed = seed if seed is not None else random.randint(0, sys.maxint) + with SCCallSiteSync(self.context) as css: + bytesInJava = self._jschema_rdd.baseSchemaRDD() \ + .takeSampleToPython(withReplacement, num, long(seed)) \ + .iterator() + cls = _create_cls(self.schema()) + return map(cls, self._collect_iterator_through_file(bytesInJava)) + def _test(): import doctest diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala index 7baf8ffcef787..677fb56ba6b52 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala @@ -437,6 +437,20 @@ class SchemaRDD( }.grouped(100).map(batched => pickle.dumps(batched.toArray)).toIterable) } + /** + * Serializes the Array[Row] returned by SchemaRDD's takeSample(), using the same + * format as javaToPython and collectToPython. It is used by pyspark. + */ + private[sql] def takeSampleToPython(withReplacement: Boolean, + num: Int, + seed: Long): JList[Array[Byte]] = { + val fieldTypes = schema.fields.map(_.dataType) + val pickle = new Pickler + new java.util.ArrayList(this.takeSample(withReplacement, num, seed).map { row => + EvaluatePython.rowToArray(row, fieldTypes) + }.grouped(100).map(batched => pickle.dumps(batched.toArray)).toIterable) + } + /** * Creates SchemaRDD by applying own schema to derived RDD. Typically used to wrap return value * of base RDD functions that do not change schema. From b916442fa40dadfee6c58b5ed7600d16d5f39930 Mon Sep 17 00:00:00 2001 From: jbencook Date: Mon, 22 Dec 2014 20:15:44 -0600 Subject: [PATCH 2/5] [SPARK-4860][pyspark][sql] adding sample() to JavaSchemaRDD --- .../scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala index ac4844f9b9290..cc0f22b01a4cc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala @@ -218,4 +218,10 @@ class JavaSchemaRDD( */ def subtract(other: JavaSchemaRDD, p: Partitioner): JavaSchemaRDD = this.baseSchemaRDD.subtract(other.baseSchemaRDD, p).toJavaSchemaRDD + + /** + * Return an RDD with a sampled version of the underlying dataset. + */ + def sample(withReplacement: Boolean, fraction: Double, seed: Long): JavaSchemaRDD = + this.baseSchemaRDD.sample(withReplacement, fraction, seed).toJavaSchemaRDD } From de22f706d8bbe6d80a6ea2e9a5343b77e0695471 Mon Sep 17 00:00:00 2001 From: jbencook Date: Mon, 22 Dec 2014 20:16:42 -0600 Subject: [PATCH 3/5] [SPARK-4860][pyspark][sql] using sample() method from JavaSchemaRDD --- python/pyspark/sql.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 94051990f8df1..9807a84a66f11 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -2095,9 +2095,8 @@ def sample(self, withReplacement, fraction, seed=None): """ assert fraction >= 0.0, "Negative fraction value: %s" % fraction seed = seed if seed is not None else random.randint(0, sys.maxint) - rdd = self._jschema_rdd.baseSchemaRDD().sample( - withReplacement, fraction, long(seed)) - return SchemaRDD(rdd.toJavaSchemaRDD(), self.sql_ctx) + rdd = self._jschema_rdd.sample(withReplacement, fraction, long(seed)) + return SchemaRDD(rdd, self.sql_ctx) def takeSample(self, withReplacement, num, seed=None): """Return a fixed-size sampled subset of this SchemaRDD. From 5170da23afa93f756d92048201769da6c0194753 Mon Sep 17 00:00:00 2001 From: "J. Benjamin Cook" Date: Tue, 23 Dec 2014 03:42:26 -0600 Subject: [PATCH 4/5] [SPARK-4860][pyspark][sql] fixing typo: from RDD to SchemaRDD --- .../scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala index cc0f22b01a4cc..5b9c612487ace 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala @@ -220,7 +220,7 @@ class JavaSchemaRDD( this.baseSchemaRDD.subtract(other.baseSchemaRDD, p).toJavaSchemaRDD /** - * Return an RDD with a sampled version of the underlying dataset. + * Return a SchemaRDD with a sampled version of the underlying dataset. */ def sample(withReplacement: Boolean, fraction: Double, seed: Long): JavaSchemaRDD = this.baseSchemaRDD.sample(withReplacement, fraction, seed).toJavaSchemaRDD From 6fbc76993a3ab8c00afde49234d00cf0128f5db5 Mon Sep 17 00:00:00 2001 From: "J. Benjamin Cook" Date: Tue, 23 Dec 2014 03:45:41 -0600 Subject: [PATCH 5/5] [SPARK-4860][pyspark][sql] fixing sloppy indentation for takeSampleToPython() arguments --- .../src/main/scala/org/apache/spark/sql/SchemaRDD.scala | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala index 677fb56ba6b52..856b10f1a8fd8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala @@ -441,9 +441,10 @@ class SchemaRDD( * Serializes the Array[Row] returned by SchemaRDD's takeSample(), using the same * format as javaToPython and collectToPython. It is used by pyspark. */ - private[sql] def takeSampleToPython(withReplacement: Boolean, - num: Int, - seed: Long): JList[Array[Byte]] = { + private[sql] def takeSampleToPython( + withReplacement: Boolean, + num: Int, + seed: Long): JList[Array[Byte]] = { val fieldTypes = schema.fields.map(_.dataType) val pickle = new Pickler new java.util.ArrayList(this.takeSample(withReplacement, num, seed).map { row =>