     NODE_KIND_HEAD,
     TAG_RAY_NODE_KIND,
     TAG_RAY_NODE_STATUS,
+    TAG_RAY_REPLICA_INDEX,
     TAG_RAY_USER_NODE_TYPE,
 )

@@ -43,6 +44,8 @@ class NodeData:
     Attributes:
         kind: Whether the node is the head or a worker.
         type: The user-defined type of the node.
47+ replica_index: An identifier for nodes in a replica of a TPU worker group.
48+ This value is set as a Pod label by a GKE webhook when TPUs are requested
4649 ip: Cluster-internal ip of the node. ip can be None if the ip
4750 has not yet been assigned.
4851 status: The status of the node. You must adhere to the following semantics
@@ -58,6 +61,7 @@ class NodeData:
5861 type : NodeType
5962 ip : Optional [NodeIP ]
6063 status : NodeStatus
64+ replica_index : Optional [str ] = None
6165
6266
6367class BatchingNodeProvider (NodeProvider ):
@@ -116,6 +120,9 @@ def __init__(
116120
117121 self .scale_request = ScaleRequest ()
118122
123+ # Initialize map of replica indices to nodes in that replica
124+ self .replica_index_to_nodes = defaultdict (list [str ])
125+
119126 def get_node_data (self ) -> Dict [NodeID , NodeData ]:
120127 """Queries cluster manager for node info. Returns a mapping from node id to
121128 NodeData.
@@ -160,6 +167,12 @@ def non_terminated_nodes(self, tag_filters: Dict[str, str]) -> List[str]:
160167 workers_to_delete = set (), # No workers to delete yet
161168 )
162169 all_nodes = list (self .node_data_dict .keys ())
170+ self .replica_index_to_nodes .clear ()
171+ for node_id in all_nodes :
172+ replica_index = self .node_data_dict [node_id ].replica_index
173+ # Only add node to map if it belongs to a multi-host podslice
174+ if replica_index is not None :
175+ self .replica_index_to_nodes [replica_index ].append (node_id )
163176 # Support filtering by TAG_RAY_NODE_KIND, TAG_RAY_NODE_STATUS, and
164177 # TAG_RAY_USER_NODE_TYPE.
165178 # The autoscaler only uses tag_filters={},
@@ -187,11 +200,14 @@ def _cur_num_workers(self, node_data_dict: Dict[str, Any]):
187200
188201 def node_tags (self , node_id : str ) -> Dict [str , str ]:
189202 node_data = self .node_data_dict [node_id ]
190- return {
203+ tags = {
191204 TAG_RAY_NODE_KIND : node_data .kind ,
192205 TAG_RAY_NODE_STATUS : node_data .status ,
193206 TAG_RAY_USER_NODE_TYPE : node_data .type ,
194207 }
208+ if node_data .replica_index is not None :
209+ tags [TAG_RAY_REPLICA_INDEX ] = node_data .replica_index
210+ return tags
195211
196212 def internal_ip (self , node_id : str ) -> str :
197213 return self .node_data_dict [node_id ].ip
@@ -230,6 +246,20 @@ def terminate_node(self, node_id: str) -> Optional[Dict[str, Any]]:
230246 f"{ node_type } . Skipping termination request."
231247 )
232248
249+ # Terminate node
233250 self .scale_request .desired_num_workers [node_type ] -= 1
234251 self .scale_request .workers_to_delete .add (node_id )
252+
253+ # Scale down all nodes in replica if node_id is part of a multi-host podslice
254+ tags = self .node_tags (node_id )
255+ if TAG_RAY_REPLICA_INDEX in tags :
256+ node_replica_index = tags [TAG_RAY_REPLICA_INDEX ]
257+ for worker_id in self .replica_index_to_nodes [node_replica_index ]:
258+ # Check if worker has already been scheduled to delete
259+ if worker_id not in self .scale_request .workers_to_delete :
260+ self .scale_request .workers_to_delete .add (worker_id )
261+ logger .info (
262+ f"Autoscaler terminating node { worker_id } "
263+ f"in multi-host replica { node_replica_index } ."
264+ )
235265 self .scale_change_needed = True
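
To make the intended behavior of the replica handling above easier to follow, here is a minimal standalone sketch. It is not the Ray provider itself: the class name, node IDs, and the plain set standing in for ScaleRequest.workers_to_delete are all hypothetical, and it only mimics the two pieces this diff adds, the replica_index_to_nodes map rebuilt in non_terminated_nodes and the cascade in terminate_node.

from collections import defaultdict
from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeNode:
    # Mirrors the new NodeData.replica_index field: None for ordinary workers,
    # a shared identifier for workers in the same multi-host TPU podslice.
    replica_index: Optional[str] = None


# Hypothetical cluster state: two workers in multi-host replica "tpu-group-0"
# and one ordinary worker with no replica index.
node_data = {
    "tpu-worker-0": FakeNode(replica_index="tpu-group-0"),
    "tpu-worker-1": FakeNode(replica_index="tpu-group-0"),
    "cpu-worker-0": FakeNode(),
}

# Rebuilt from the current node set, the way non_terminated_nodes clears and
# refills the map on every poll.
replica_index_to_nodes = defaultdict(list)
for node_id, data in node_data.items():
    if data.replica_index is not None:
        replica_index_to_nodes[data.replica_index].append(node_id)

workers_to_delete = set()


def terminate(node_id: str) -> None:
    # Mimics the new terminate_node branch: deleting any node that carries a
    # replica index schedules every node in that replica for deletion.
    workers_to_delete.add(node_id)
    replica = node_data[node_id].replica_index
    if replica is not None:
        for worker_id in replica_index_to_nodes[replica]:
            workers_to_delete.add(worker_id)


terminate("tpu-worker-0")
print(sorted(workers_to_delete))  # ['tpu-worker-0', 'tpu-worker-1']

terminate("cpu-worker-0")
print(sorted(workers_to_delete))  # ['cpu-worker-0', 'tpu-worker-0', 'tpu-worker-1']

As in the diff, the map is rebuilt from the live node set before any termination decision is made, which is why non_terminated_nodes calls replica_index_to_nodes.clear() at the start of each pass.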