1818import sys
1919
2020from py4j .java_collections import ListConverter
21- from py4j .java_gateway import java_import
21+ from py4j .java_gateway import java_import , JavaObject
2222
2323from pyspark import RDD , SparkConf
2424from pyspark .serializers import UTF8Deserializer , CloudPickleSerializer
@@ -38,6 +38,8 @@ def _daemonize_callback_server():
3838 from exiting if it's not shut down. The following code replaces `start()`
3939 of CallbackServer with a new version, which sets daemon=True for this
4040 thread.
41+
42+ Also, it will update the port number (0) with the real port.
4143 """
4244 # TODO: create a patch for Py4J
4345 import socket
@@ -54,8 +56,11 @@ def start(self):
5456 1 )
5557 try :
5658 self .server_socket .bind ((self .address , self .port ))
57- except Exception :
58- msg = 'An error occurred while trying to start the callback server'
59+ if not self .port :
60+ # update port with real port
61+ self .port = self .server_socket .getsockname ()[1 ]
62+ except Exception as e :
63+ msg = 'An error occurred while trying to start the callback server: %s' % e
5964 logger .exception (msg )
6065 raise Py4JNetworkError (msg )
6166
@@ -105,15 +110,24 @@ def _jduration(self, seconds):
105110 def _ensure_initialized (cls ):
106111 SparkContext ._ensure_initialized ()
107112 gw = SparkContext ._gateway
108- # start callback server
109- # getattr will fallback to JVM
110- if "_callback_server" not in gw .__dict__ :
111- _daemonize_callback_server ()
112- gw ._start_callback_server (gw ._python_proxy_port )
113113
114114 java_import (gw .jvm , "org.apache.spark.streaming.*" )
115115 java_import (gw .jvm , "org.apache.spark.streaming.api.java.*" )
116116 java_import (gw .jvm , "org.apache.spark.streaming.api.python.*" )
117+
118+ # start callback server
119+ # getattr will fall back to JVM, so we cannot test by hasattr()
120+ if "_callback_server" not in gw .__dict__ :
121+ _daemonize_callback_server ()
122+ # use random port
123+ gw ._start_callback_server (0 )
124+ # gateway with real port
125+ gw ._python_proxy_port = gw ._callback_server .port
126+ # get the GatewayServer object in JVM by ID
127+ jgws = JavaObject ("GATEWAY_SERVER" , gw ._gateway_client )
128+ # update the port of CallbackClient with real port
129+ gw .jvm .PythonDStream .updatePythonGatewayPort (jgws , gw ._python_proxy_port )
130+
117131 # register serializer for TransformFunction
118132 # it happens before creating SparkContext when loading from checkpointing
119133 cls ._transformerSerializer = TransformFunctionSerializer (
0 commit comments