diff --git a/doc/source/conf.py b/doc/source/conf.py
index e362f73309a3..2a2b1a37c207 100644
--- a/doc/source/conf.py
+++ b/doc/source/conf.py
@@ -42,6 +42,7 @@
     "ray.core.generated.ClientTableData",
     "ray.core.generated.GcsTableEntry",
     "ray.core.generated.HeartbeatTableData",
+    "ray.core.generated.HeartbeatBatchTableData",
     "ray.core.generated.DriverTableData",
     "ray.core.generated.ErrorTableData",
     "ray.core.generated.ProfileTableData",
diff --git a/python/ray/gcs_utils.py b/python/ray/gcs_utils.py
index bbdbe04cf7fd..347f7ab9f806 100644
--- a/python/ray/gcs_utils.py
+++ b/python/ray/gcs_utils.py
@@ -11,6 +11,7 @@
 from ray.core.generated.ErrorTableData import ErrorTableData
 from ray.core.generated.ProfileTableData import ProfileTableData
 from ray.core.generated.HeartbeatTableData import HeartbeatTableData
+from ray.core.generated.HeartbeatBatchTableData import HeartbeatBatchTableData
 from ray.core.generated.DriverTableData import DriverTableData
 from ray.core.generated.ObjectTableData import ObjectTableData
 from ray.core.generated.ray.protocol.Task import Task
@@ -20,14 +21,16 @@
 
 __all__ = [
     "GcsTableEntry", "ClientTableData", "ErrorTableData", "HeartbeatTableData",
-    "DriverTableData", "ProfileTableData", "ObjectTableData", "Task",
-    "TablePrefix", "TablePubsub", "construct_error_message"
+    "HeartbeatBatchTableData", "DriverTableData", "ProfileTableData",
+    "ObjectTableData", "Task", "TablePrefix", "TablePubsub",
+    "construct_error_message"
 ]
 
 FUNCTION_PREFIX = "RemoteFunction:"
 
 # xray heartbeats
 XRAY_HEARTBEAT_CHANNEL = str(TablePubsub.HEARTBEAT).encode("ascii")
+XRAY_HEARTBEAT_BATCH_CHANNEL = str(TablePubsub.HEARTBEAT_BATCH).encode("ascii")
 
 # xray driver updates
 XRAY_DRIVER_CHANNEL = str(TablePubsub.DRIVER).encode("ascii")
diff --git a/python/ray/monitor.py b/python/ray/monitor.py
index 625641790de9..a37f75de7cf1 100644
--- a/python/ray/monitor.py
+++ b/python/ray/monitor.py
@@ -50,11 +50,6 @@ def __init__(self,
         # Setup subscriptions to the primary Redis server and the Redis shards.
         self.primary_subscribe_client = self.redis.pubsub(
             ignore_subscribe_messages=True)
-        self.shard_subscribe_clients = []
-        for redis_client in self.state.redis_clients:
-            subscribe_client = redis_client.pubsub(
-                ignore_subscribe_messages=True)
-            self.shard_subscribe_clients.append(subscribe_client)
         # Keep a mapping from local scheduler client ID to IP address to use
         # for updating the load metrics.
         self.local_scheduler_id_to_ip_map = {}
@@ -90,49 +85,50 @@ def __init__(self,
                     str(e)))
                 self.issue_gcs_flushes = False
 
-    def subscribe(self, channel, primary=True):
-        """Subscribe to the given channel.
+    def subscribe(self, channel):
+        """Subscribe to the given channel on the primary Redis shard.
 
         Args:
             channel (str): The channel to subscribe to.
-            primary: If True, then we only subscribe to the primary Redis
-                shard. Otherwise we subscribe to all of the other shards but
-                not the primary.
 
         Raises:
             Exception: An exception is raised if the subscription fails.
         """
-        if primary:
-            self.primary_subscribe_client.subscribe(channel)
-        else:
-            for subscribe_client in self.shard_subscribe_clients:
-                subscribe_client.subscribe(channel)
+        self.primary_subscribe_client.subscribe(channel)
 
-    def xray_heartbeat_handler(self, unused_channel, data):
-        """Handle an xray heartbeat message from Redis."""
+    def xray_heartbeat_batch_handler(self, unused_channel, data):
+        """Handle an xray heartbeat batch message from Redis."""
         gcs_entries = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry(
             data, 0)
         heartbeat_data = gcs_entries.Entries(0)
-        message = ray.gcs_utils.HeartbeatTableData.GetRootAsHeartbeatTableData(
-            heartbeat_data, 0)
-        num_resources = message.ResourcesAvailableLabelLength()
-        static_resources = {}
-        dynamic_resources = {}
-        for i in range(num_resources):
-            dyn = message.ResourcesAvailableLabel(i)
-            static = message.ResourcesTotalLabel(i)
-            dynamic_resources[dyn] = message.ResourcesAvailableCapacity(i)
-            static_resources[static] = message.ResourcesTotalCapacity(i)
-
-        # Update the load metrics for this local scheduler.
-        client_id = ray.utils.binary_to_hex(message.ClientId())
-        ip = self.local_scheduler_id_to_ip_map.get(client_id)
-        if ip:
-            self.load_metrics.update(ip, static_resources, dynamic_resources)
-        else:
-            print("Warning: could not find ip for client {} in {}.".format(
-                client_id, self.local_scheduler_id_to_ip_map))
+
+        message = (ray.gcs_utils.HeartbeatBatchTableData.
+                   GetRootAsHeartbeatBatchTableData(heartbeat_data, 0))
+
+        for j in range(message.BatchLength()):
+            heartbeat_message = message.Batch(j)
+
+            num_resources = heartbeat_message.ResourcesAvailableLabelLength()
+            static_resources = {}
+            dynamic_resources = {}
+            for i in range(num_resources):
+                dyn = heartbeat_message.ResourcesAvailableLabel(i)
+                static = heartbeat_message.ResourcesTotalLabel(i)
+                dynamic_resources[dyn] = (
+                    heartbeat_message.ResourcesAvailableCapacity(i))
+                static_resources[static] = (
+                    heartbeat_message.ResourcesTotalCapacity(i))
+
+            # Update the load metrics for this local scheduler.
+            client_id = ray.utils.binary_to_hex(heartbeat_message.ClientId())
+            ip = self.local_scheduler_id_to_ip_map.get(client_id)
+            if ip:
+                self.load_metrics.update(ip, static_resources,
+                                         dynamic_resources)
+            else:
+                print("Warning: could not find ip for client {} in {}.".format(
+                    client_id, self.local_scheduler_id_to_ip_map))
 
     def _xray_clean_up_entries_for_driver(self, driver_id):
         """Remove this driver's object/task entries from redis.
@@ -222,8 +218,7 @@ def process_messages(self, max_messages=10000):
             max_messages: The maximum number of messages to process before
                 returning.
         """
-        subscribe_clients = (
-            [self.primary_subscribe_client] + self.shard_subscribe_clients)
+        subscribe_clients = [self.primary_subscribe_client]
         for subscribe_client in subscribe_clients:
             for _ in range(max_messages):
                 message = subscribe_client.get_message()
@@ -237,9 +232,9 @@ def process_messages(self, max_messages=10000):
 
                 # Determine the appropriate message handler.
                 message_handler = None
-                if channel == ray.gcs_utils.XRAY_HEARTBEAT_CHANNEL:
+                if channel == ray.gcs_utils.XRAY_HEARTBEAT_BATCH_CHANNEL:
                     # Similar functionality as local scheduler info channel
-                    message_handler = self.xray_heartbeat_handler
+                    message_handler = self.xray_heartbeat_batch_handler
                 elif channel == ray.gcs_utils.XRAY_DRIVER_CHANNEL:
                     # Handles driver death.
                     message_handler = self.xray_driver_removed_handler
@@ -299,7 +294,7 @@ def run(self):
        clients and cleaning up state accordingly.
        """
        # Initialize the subscription channel.
-        self.subscribe(ray.gcs_utils.XRAY_HEARTBEAT_CHANNEL, primary=False)
+        self.subscribe(ray.gcs_utils.XRAY_HEARTBEAT_BATCH_CHANNEL)
        self.subscribe(ray.gcs_utils.XRAY_DRIVER_CHANNEL)
 
        # TODO(rkn): If there were any dead clients at startup, we should clean
diff --git a/test/failure_test.py b/test/failure_test.py
index 7895cbf1ac96..027ed38d6411 100644
--- a/test/failure_test.py
+++ b/test/failure_test.py
@@ -500,10 +500,10 @@ def test_warning_monitor_died(shutdown_only):
     # addition to the monitor.
     fake_id = 20 * b"\x00"
     malformed_message = "asdf"
-    redis_client = ray.worker.global_state.redis_clients[0]
+    redis_client = ray.worker.global_worker.redis_client
     redis_client.execute_command(
-        "RAY.TABLE_ADD", ray.gcs_utils.TablePrefix.HEARTBEAT,
-        ray.gcs_utils.TablePubsub.HEARTBEAT, fake_id, malformed_message)
+        "RAY.TABLE_ADD", ray.gcs_utils.TablePrefix.HEARTBEAT_BATCH,
+        ray.gcs_utils.TablePubsub.HEARTBEAT_BATCH, fake_id, malformed_message)
 
     wait_for_errors(ray_constants.MONITOR_DIED_ERROR, 1)