@@ -20,9 +20,10 @@ package org.apache.spark.streaming.api.python
 import java.util.{ArrayList => JArrayList}
 import scala.collection.JavaConversions._
 
-import org.apache.spark.rdd.RDD
 import org.apache.spark.api.java._
+import org.apache.spark.api.java.function.{Function2 => JFunction2}
 import org.apache.spark.api.python._
+import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.streaming.{Interval, Duration, Time}
 import org.apache.spark.streaming.dstream._
@@ -35,19 +36,22 @@ trait PythonRDDFunction {
   def call(rdd: JavaRDD[_], rdd2: JavaRDD[_], time: Long): JavaRDD[Array[Byte]]
 }
 
-class RDDFunction(pfunc: PythonRDDFunction) {
-  def apply(rdd: Option[RDD[_]], rdd2: Option[RDD[_]], time: Time): Option[RDD[Array[Byte]]] = {
-    val jrdd = if (rdd.isDefined) {
+class RDDFunction(pfunc: PythonRDDFunction) extends Serializable {
+
+  def apply(rdd: Option[RDD[_]], time: Time): Option[RDD[Array[Byte]]] = {
+    apply(rdd, None, time)
+  }
+
+  def wrapRDD(rdd: Option[RDD[_]]): JavaRDD[_] = {
+    if (rdd.isDefined) {
       JavaRDD.fromRDD(rdd.get)
     } else {
       null
     }
-    val jrdd2 = if (rdd2.isDefined) {
-      JavaRDD.fromRDD(rdd2.get)
-    } else {
-      null
-    }
-    val r = pfunc.call(jrdd, jrdd2, time.milliseconds)
+  }
+
+  def apply(rdd: Option[RDD[_]], rdd2: Option[RDD[_]], time: Time): Option[RDD[Array[Byte]]] = {
+    val r = pfunc.call(wrapRDD(rdd), wrapRDD(rdd2), time.milliseconds)
     if (r != null) {
       Some(r.rdd)
     } else {
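
The refactor above factors the null handling into `wrapRDD`, which bridges Scala's `Option[RDD[_]]` to the nullable `JavaRDD` that the Python-side callback expects. A minimal standalone sketch of that Option-to-null bridging pattern, with plain strings standing in for RDDs and all names (`JCallback`, `CallbackWrapper`, `BridgeDemo`) hypothetical:

```scala
// Stand-ins: JCallback plays the role of PythonRDDFunction,
// CallbackWrapper the role of RDDFunction.
trait JCallback {
  def call(a: String, b: String, time: Long): String
}

class CallbackWrapper(f: JCallback) extends Serializable {
  // Java-side callbacks expect null, not None, for a missing argument.
  private def unwrap(o: Option[String]): String = o.orNull

  // A null result from the Java side maps back to None on the Scala side.
  def apply(a: Option[String], b: Option[String], time: Long): Option[String] =
    Option(f.call(unwrap(a), unwrap(b), time))
}

object BridgeDemo extends App {
  val cb = new JCallback {
    def call(a: String, b: String, time: Long): String =
      if (a == null) null else s"$a@$time"
  }
  val wrapped = new CallbackWrapper(cb)
  println(wrapped(Some("rdd"), None, 42L)) // Some(rdd@42)
  println(wrapped(None, None, 42L))        // None
}
```

Making the wrapper `extends Serializable` matters here because instances end up captured in DStream closures that Spark serializes and ships to executors.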
@@ -66,7 +70,13 @@ abstract class PythonDStream(parent: DStream[_]) extends DStream[Array[Byte]] (p
   val asJavaDStream = JavaDStream.fromDStream(this)
 }
 
-object PythonDStream {
+private[spark] object PythonDStream {
+
+  // helper function for DStream.foreachRDD();
+  // it cannot be named `foreachRDD`, or it would confuse py4j
+  def callForeachRDD(jdstream: JavaDStream[Array[Byte]], pyfunc: PythonRDDFunction): Unit = {
+    jdstream.dstream.foreachRDD((rdd, time) => pyfunc.call(rdd, null, time.milliseconds))
+  }
 
   // convert list of RDD into queue of RDDs, for ssc.queueStream()
   def toRDDQueue(rdds: JArrayList[JavaRDD[Array[Byte]]]): java.util.Queue[JavaRDD[Array[Byte]]] = {
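
`toRDDQueue` exists because `ssc.queueStream()` expects a `java.util.Queue`, while py4j delivers the Python list as a `java.util.ArrayList`. The body is not shown in this hunk, so the following is only a plausible sketch of the conversion, with strings standing in for `JavaRDD[Array[Byte]]`:

```scala
import java.util.{ArrayList => JArrayList, LinkedList => JLinkedList, Queue => JQueue}

object QueueDemo extends App {
  // Copy the py4j-provided ArrayList into a FIFO queue.
  def toQueue(items: JArrayList[String]): JQueue[String] = {
    val queue = new JLinkedList[String]()
    queue.addAll(items)
    queue
  }

  val rdds = new JArrayList[String]()
  rdds.add("rdd@t0")
  rdds.add("rdd@t1")
  println(toQueue(rdds).poll()) // rdd@t0 -- queueStream consumes one per batch
}
```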
@@ -97,7 +107,7 @@ private[spark] class PythonTransformedDStream (parent: DStream[_], pfunc: Python
     if (reuse && lastResult != null) {
       Some(lastResult.copyTo(rdd1.get))
     } else {
-      val r = func(rdd1, None, validTime)
+      val r = func(rdd1, validTime)
       if (reuse && r.isDefined && lastResult == null) {
         r.get match {
           case rdd: PythonRDD =>
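
For context, the `reuse` branch above caches the first `PythonRDD` the Python function produces (`lastResult`) and copies it onto each new batch's input, so the expensive setup of the Python transformation happens only once. A generic sketch of that cache-then-rebind pattern, with all names hypothetical:

```scala
// Generic cache-then-rebind: pay for `expensive` once, then cheaply
// rebind the cached result to each new input.
class ReusableTransform[A, B](expensive: A => B, rebind: (B, A) => B) {
  private var lastResult: Option[B] = None

  def apply(input: A): B = lastResult match {
    case Some(cached) => rebind(cached, input) // cheap path on later batches
    case None =>
      val result = expensive(input)            // first batch pays full cost
      lastResult = Some(result)
      result
  }
}
```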
@@ -206,8 +216,9 @@ class PythonReducedWindowedDStream(parent: DStream[Array[Byte]],
     // Get the RDD of the reduced value of the previous window
     val previousWindowRDD = getOrCompute(previousWindow.endTime)
 
+    // for a small window, reducing once is cheaper than reducing twice
     if (windowDuration > slideDuration * 5 && previousWindowRDD.isDefined) {
-      // subtle the values from old RDDs
+      // subtract the values from old RDDs
       val oldRDDs =
         parent.slice(previousWindow.beginTime, currentWindow.beginTime - parent.slideDuration)
       val subbed = if (oldRDDs.size > 0) {
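
The branch above is the standard incremental (invertible) windowed reduce: when the window is much longer than the slide interval, it is cheaper to take the previous window's result, subtract the batches that slid out, and add the ones that slid in, than to re-reduce the whole window. A toy illustration of the arithmetic with plain collections, sums standing in for the invertible reduce function:

```scala
object WindowDemo extends App {
  // Each Int is one batch's reduced value; a window's value is the sum.
  def slideWindow(prevWindow: Int, leaving: Seq[Int], entering: Seq[Int]): Int =
    prevWindow - leaving.sum + entering.sum

  val w0 = Seq(1, 2, 3, 4, 5).sum          // window over batches 1..5 = 15
  val w1 = slideWindow(w0, Seq(1), Seq(6)) // slide by one batch: drop 1, add 6
  println(w1)                              // 20 == Seq(2, 3, 4, 5, 6).sum
}
```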
@@ -236,22 +247,4 @@ class PythonReducedWindowedDStream(parent: DStream[Array[Byte]],
       }
     }
   }
-}
-
-/**
- * This is used for foreachRDD() in Python
- */
-class PythonForeachDStream(
-    prev: DStream[Array[Byte]],
-    foreachFunction: PythonRDDFunction
-  ) extends ForEachDStream[Array[Byte]](
-    prev,
-    (rdd: RDD[Array[Byte]], time: Time) => {
-      if (rdd != null) {
-        foreachFunction.call(rdd, null, time.milliseconds)
-      }
-    }
-  ) {
-
-  this.register()
 }
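
With `callForeachRDD` in place, the `PythonForeachDStream` subclass removed above becomes unnecessary: `DStream.foreachRDD` already builds and registers the output stream internally, so a closure over the callback suffices. A standalone sketch of that design change, with no Spark dependencies and all names as stand-ins:

```scala
// Stand-ins: ByteCallback plays PythonRDDFunction, OutputOp plays ForEachDStream.
trait ByteCallback {
  def call(data: Array[Byte], time: Long): Unit
}

class OutputOp(f: (Array[Byte], Long) => Unit) {
  def run(data: Array[Byte], time: Long): Unit = f(data, time)
}

// Before: a dedicated subclass per callback, like the removed PythonForeachDStream.
class CallbackOutputOp(cb: ByteCallback)
  extends OutputOp((data, time) => cb.call(data, time))

// After: one helper that wraps any callback in a closure, like callForeachRDD.
object ForeachBridge {
  def attach(cb: ByteCallback): OutputOp =
    new OutputOp((data, time) => cb.call(data, time))
}
```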