@@ -275,7 +275,7 @@ def test_func(dstream):
275275 self .assertEqual (expected_output , output )
276276
277277 def test_mapPartitions_batch (self ):
278- """Basic operation test for DStream.mapPartitions with batch deserializer"""
278+ """Basic operation test for DStream.mapPartitions with batch deserializer. """
279279 test_input = [range (1 , 5 ), range (5 , 9 ), range (9 , 13 )]
280280 numSlices = 2
281281
@@ -288,7 +288,7 @@ def f(iterator):
288288 self .assertEqual (expected_output , output )
289289
290290 def test_mapPartitions_unbatch (self ):
291- """Basic operation test for DStream.mapPartitions with unbatch deserializer"""
291+ """Basic operation test for DStream.mapPartitions with unbatch deserializer. """
292292 test_input = [range (1 , 4 ), range (4 , 7 ), range (7 , 10 )]
293293 numSlices = 2
294294
@@ -301,8 +301,8 @@ def f(iterator):
301301 self .assertEqual (expected_output , output )
302302
303303 def test_countByValue_batch (self ):
304- """Basic operation test for DStream.countByValue with batch deserializer"""
305- test_input = [range (1 , 5 ) + range (1 ,5 ), range (5 , 7 ) + range (5 , 9 ), ["a" ] * 2 + [ "b" ] + [ "" ] ]
304+ """Basic operation test for DStream.countByValue with batch deserializer. """
305+ test_input = [range (1 , 5 ) + range (1 ,5 ), range (5 , 7 ) + range (5 , 9 ), ["a" , "a" , "b" , "" ]]
306306
307307 def test_func (dstream ):
308308 return dstream .countByValue ()
@@ -315,7 +315,7 @@ def test_func(dstream):
315315 self .assertEqual (expected_output , output )
316316
317317 def test_countByValue_unbatch (self ):
318- """Basic operation test for DStream.countByValue with unbatch deserializer"""
318+ """Basic operation test for DStream.countByValue with unbatch deserializer. """
319319 test_input = [range (1 , 4 ), [1 , 1 , "" ], ["a" , "a" , "b" ]]
320320
321321 def test_func (dstream ):
@@ -328,30 +328,72 @@ def test_func(dstream):
328328 self ._sort_result_based_on_key (result )
329329 self .assertEqual (expected_output , output )
330330
331+ def test_groupByKey_batch (self ):
332+ """Basic operation test for DStream.groupByKey with batch deserializer."""
333+ test_input = [range (1 , 5 ), [1 , 1 , 1 , 2 , 2 , 3 ], ["a" , "a" , "b" , "" , "" , "" ]]
334+ def test_func (dstream ):
335+ return dstream .map (lambda x : (x ,1 )).groupByKey ()
336+ expected_output = [[(1 , [1 ]), (2 , [1 ]), (3 , [1 ]), (4 , [1 ])],
337+ [(1 , [1 , 1 , 1 ]), (2 , [1 , 1 ]), (3 , [1 ])],
338+ [("a" , [1 , 1 ]), ("b" , [1 ]), ("" , [1 , 1 , 1 ])]]
339+ scattered_output = self ._run_stream (test_input , test_func , expected_output )
340+ output = self ._convert_iter_value_to_list (scattered_output )
341+ for result in (output , expected_output ):
342+ self ._sort_result_based_on_key (result )
343+ self .assertEqual (expected_output , output )
344+
345+ def test_groupByKey_unbatch (self ):
346+ """Basic operation test for DStream.groupByKey with unbatch deserializer."""
347+ test_input = [range (1 , 4 ), [1 , 1 , "" ], ["a" , "a" , "b" ]]
348+ def test_func (dstream ):
349+ return dstream .map (lambda x : (x ,1 )).groupByKey ()
350+ expected_output = [[(1 , [1 ]), (2 , [1 ]), (3 , [1 ])],
351+ [(1 , [1 , 1 ]), ("" , [1 ])],
352+ [("a" , [1 , 1 ]), ("b" , [1 ])]]
353+ scattered_output = self ._run_stream (test_input , test_func , expected_output )
354+ output = self ._convert_iter_value_to_list (scattered_output )
355+ for result in (output , expected_output ):
356+ self ._sort_result_based_on_key (result )
357+ self .assertEqual (expected_output , output )
358+
359+ def _convert_iter_value_to_list (self , outputs ):
360+ """Return key value pair list. Each value is converted from an iterator to a list."""
361+ result = list ()
362+ for output in outputs :
363+ result .append (map (lambda (x , y ): (x , list (y )), output ))
364+ return result
365+
331366 def _sort_result_based_on_key (self , outputs ):
367+ """Sort the list based on first value."""
332368 for output in outputs :
333369 output .sort (key = lambda x : x [0 ])
334370
335371 def _run_stream (self , test_input , test_func , expected_output , numSlices = None ):
336- """Start stream and return the output"""
337- # Generate input stream with user-defined input
372+ """
373+ Start stream and return the output.
374+ @param test_input: dataset for the test. This should be list of lists.
375+ @param test_func: wrapped test_function. This function should return a PythonDstream object.
376+ @param expected_output: expected output for this testcase.
377+ @param numSlices: the number of slices in the rdd in the dstream.
378+ """
379+ # Generate input stream with user-defined input.
338380 numSlices = numSlices or self .numInputPartitions
339381 test_input_stream = self .ssc ._testInputStream (test_input , numSlices )
340- # Apply test function to stream
382+ # Apply test function to stream.
341383 test_stream = test_func (test_input_stream )
342- # Add job to get output from stream
384+ # Add job to get output from stream.
343385 test_stream ._test_output (self .result )
344386 self .ssc .start ()
345387
346388 start_time = time .time ()
347- # loop until get the result from stream
389+ # Loop until we get the expected number of results from the stream.
348390 while True :
349391 current_time = time .time ()
350- # check time out
392+ # Check time out.
351393 if (current_time - start_time ) > self .timeout :
352394 break
353395 self .ssc .awaitTermination (50 )
354- # check if the output is the same length of expexted output
396+ # Check if the output is the same length as expected output.
355397 if len (expected_output ) == len (self .result ):
356398 break
357399
@@ -372,9 +414,5 @@ def tearDownClass(cls):
372414 PySparkStreamingTestCase .tearDownClass ()
373415
374416
375-
376-
377-
378-
379417if __name__ == "__main__" :
380418 unittest .main ()
0 commit comments