 # limitations under the License.
 #
 
-from pyspark.serializers import UTF8Deserializer
+from pyspark import RDD
+from pyspark.serializers import UTF8Deserializer, BatchedSerializer
 from pyspark.context import SparkContext
+from pyspark.storagelevel import StorageLevel
 from pyspark.streaming.dstream import DStream
-from pyspark.streaming.duration import Duration, Seconds
+from pyspark.streaming.duration import Seconds
 
 from py4j.java_collections import ListConverter
 
 
 __all__ = ["StreamingContext"]
 
 
+def _daemonize_callback_server():
+    """
+    Hack Py4J to daemonize the callback server
+    """
+    # TODO: create a patch for Py4J
+    import socket
+    import py4j.java_gateway
+    logger = py4j.java_gateway.logger
+    from py4j.java_gateway import Py4JNetworkError
+    from threading import Thread
+
+    def start(self):
+        """Starts the CallbackServer. This method should be called by the
+        client instead of run()."""
+        self.server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        self.server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR,
+                                      1)
+        try:
+            self.server_socket.bind((self.address, self.port))
+            # self.port = self.server_socket.getsockname()[1]
+        except Exception:
+            msg = 'An error occurred while trying to start the callback server'
+            logger.exception(msg)
+            raise Py4JNetworkError(msg)
+
+        # Maybe the thread needs to be cleaned up?
+        self.thread = Thread(target=self.run)
+        self.thread.daemon = True
+        self.thread.start()
+
+    py4j.java_gateway.CallbackServer.start = start
+
+
+
 class StreamingContext(object):
     """
     Main entry point for Spark Streaming functionality. A StreamingContext represents the
@@ -53,7 +88,9 @@ def _start_callback_server(self):
         gw = self._sc._gateway
         # getattr will fallback to JVM
         if "_callback_server" not in gw.__dict__:
+            _daemonize_callback_server()
             gw._start_callback_server(gw._python_proxy_port)
+            gw._python_proxy_port = gw._callback_server.port  # update port with real port
 
     def _initialize_context(self, sc, duration):
         return self._jvm.JavaStreamingContext(sc._jsc, duration._jduration)
@@ -92,26 +129,44 @@ def stop(self, stopSparkContext=True, stopGraceFully=False):
 
     def remember(self, duration):
         """
-        Set each DStreams in this context to remember RDDs it generated in the last given duration.
-        DStreams remember RDDs only for a limited duration of time and releases them for garbage
-        collection. This method allows the developer to specify how to long to remember the RDDs (
-        if the developer wishes to query old data outside the DStream computation).
-        @param duration pyspark.streaming.duration.Duration object or seconds.
-                        Minimum duration that each DStream should remember its RDDs
+        Set each DStream in this context to remember the RDDs it generated
+        in the last given duration. DStreams remember RDDs only for a
+        limited duration of time and release them for garbage collection.
+        This method allows the developer to specify how long to remember
+        the RDDs (if the developer wishes to query old data outside the
+        DStream computation).
+
+        @param duration Minimum duration (in seconds) that each DStream
+                        should remember its RDDs
         """
         if isinstance(duration, (int, long, float)):
            duration = Seconds(duration)
 
         self._jssc.remember(duration._jduration)
 
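A minimal usage sketch of remember (ssc stands for an already-created StreamingContext; the 60-second value is only illustrative):

    # keep the RDDs generated by every DStream for at least 60 seconds,
    # so old data can still be queried outside the DStream computation
    ssc.remember(60)
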
-    # TODO: add storageLevel
-    def socketTextStream(self, hostname, port):
+    def checkpoint(self, directory):
+        """
+        Sets the context to periodically checkpoint the DStream operations for master
+        fault-tolerance. The graph will be checkpointed every batch interval.
+
+        @param directory HDFS-compatible directory where the checkpoint data
+                         will be reliably stored
+        """
+        self._jssc.checkpoint(directory)
+
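A hedged sketch of enabling checkpointing (the HDFS path is a placeholder):

    # the directory must live on a HDFS-compatible file system;
    # the DStream graph is then checkpointed every batch interval
    ssc.checkpoint("hdfs:///tmp/streaming-checkpoint")
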
+    def socketTextStream(self, hostname, port, storageLevel=StorageLevel.MEMORY_AND_DISK_SER_2):
         """
         Create an input stream from a TCP source hostname:port. Data is received
         using a TCP socket and the received bytes are interpreted as UTF8 encoded,
         '\n' delimited lines.
+
+        @param hostname      Hostname to connect to for receiving data
+        @param port          Port to connect to for receiving data
+        @param storageLevel  Storage level to use for storing the received objects
         """
-        return DStream(self._jssc.socketTextStream(hostname, port), self, UTF8Deserializer())
+        jlevel = self._sc._getJavaStorageLevel(storageLevel)
+        return DStream(self._jssc.socketTextStream(hostname, port, jlevel), self,
+                       UTF8Deserializer())
 
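A usage sketch for the new storageLevel parameter (host, port and the chosen level are assumptions for illustration):

    from pyspark.storagelevel import StorageLevel

    # receive '\n'-delimited UTF8 text from localhost:9999,
    # keeping the received blocks deserialized and in memory only
    lines = ssc.socketTextStream("localhost", 9999,
                                 storageLevel=StorageLevel.MEMORY_ONLY)
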
     def textFileStream(self, directory):
         """
@@ -122,14 +177,52 @@ def textFileStream(self, directory):
         """
         return DStream(self._jssc.textFileStream(directory), self, UTF8Deserializer())
 
-    def _makeStream(self, inputs, numSlices=None):
+    def _check_serialzers(self, rdds):
+        # make sure they all use the same serializer
+        if len(set(rdd._jrdd_deserializer for rdd in rdds)) > 1:
+            for i in range(len(rdds)):
+                # reset them to sc.serializer
+                rdds[i] = rdds[i].map(lambda x: x, preservesPartitioning=True)
+
+    def queueStream(self, queue, oneAtATime=False, default=None):
         """
-        This function is only for unittest.
-        It requires a list as input, and returns the i_th element at the i_th batch
-        under manual clock.
+        Create an input stream from a queue of RDDs or lists. In each batch,
+        it will process either one or all of the RDDs returned by the queue.
+
+        NOTE: changes to the queue after the stream is created will not be recognized.
+        @param queue    Queue of RDDs
+        @tparam T       Type of objects in the RDD
         """
-        rdds = [self._sc.parallelize(input, numSlices) for input in inputs]
+        if queue and not isinstance(queue[0], RDD):
+            rdds = [self._sc.parallelize(input) for input in queue]
+        else:
+            rdds = queue
+        self._check_serialzers(rdds)
         jrdds = ListConverter().convert([r._jrdd for r in rdds],
                                         SparkContext._gateway._gateway_client)
-        jdstream = self._jvm.PythonDataInputStream(self._jssc, jrdds).asJavaDStream()
-        return DStream(jdstream, self, rdds[0]._jrdd_deserializer)
+        jdstream = self._jvm.PythonDataInputStream(self._jssc, jrdds, oneAtATime,
+                                                   default and default._jrdd)
+        return DStream(jdstream.asJavaDStream(), self, rdds[0]._jrdd_deserializer)
+
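A small sketch of queueStream for local testing (the input data is made up):

    # plain lists are parallelized into RDDs by queueStream;
    # with oneAtATime=True exactly one queued RDD is consumed per batch,
    # otherwise all queued RDDs are processed in a single batch
    stream = ssc.queueStream([[1, 2, 3], [4, 5], [6]], oneAtATime=True)
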
+    def transform(self, dstreams, transformFunc):
+        """
+        Create a new DStream in which each RDD is generated by applying a function on RDDs of
+        the DStreams. The order of the JavaRDDs in the transform function parameter will be the
+        same as the order of corresponding DStreams in the list.
+        """
+        # TODO
+
+    def union(self, *dstreams):
+        """
+        Create a unified DStream from multiple DStreams of the same
+        type and same slide duration.
+        """
+        if not dstreams:
+            raise ValueError("should have at least one DStream to union")
+        if len(dstreams) == 1:
+            return dstreams[0]
+        self._check_serialzers(dstreams)
+        first = dstreams[0]
+        jrest = ListConverter().convert([d._jdstream for d in dstreams[1:]],
+                                        SparkContext._gateway._gateway_client)
+        return DStream(self._jssc.union(first._jdstream, jrest), self, first._jrdd_deserializer)
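
A sketch of union over two socket streams (hostnames and ports are placeholders):

    s1 = ssc.socketTextStream("host1", 9999)
    s2 = ssc.socketTextStream("host2", 9999)
    # a single DStream carrying the elements of both inputs;
    # the inputs must share the same slide duration
    merged = ssc.union(s1, s2)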