diff --git a/python/ray/autoscaler/_private/kuberay/node_provider.py b/python/ray/autoscaler/_private/kuberay/node_provider.py
index 351764e1b018..823afe34e548 100644
--- a/python/ray/autoscaler/_private/kuberay/node_provider.py
+++ b/python/ray/autoscaler/_private/kuberay/node_provider.py
@@ -45,6 +45,9 @@
 RAY_HEAD_POD_NAME = os.getenv("RAY_HEAD_POD_NAME")
 
+# Key for the GKE label that identifies which multi-host replica a Pod belongs to.
+REPLICA_INDEX_KEY = "replicaIndex"
+
 # Design:
 # Each modification the autoscaler wants to make is posted to the API server goal state
@@ -79,7 +82,10 @@ def node_data_from_pod(pod: Dict[str, Any]) -> NodeData:
     kind, type = kind_and_type(pod)
     status = status_tag(pod)
     ip = pod_ip(pod)
-    return NodeData(kind=kind, type=type, status=status, ip=ip)
+    replica_index = _replica_index_label(pod)
+    return NodeData(
+        kind=kind, type=type, replica_index=replica_index, status=status, ip=ip
+    )
 
 
 def kind_and_type(pod: Dict[str, Any]) -> Tuple[NodeKind, NodeType]:
@@ -96,6 +102,16 @@ def kind_and_type(pod: Dict[str, Any]) -> Tuple[NodeKind, NodeType]:
     return kind, type
 
 
+def _replica_index_label(pod: Dict[str, Any]) -> Optional[str]:
+    """Returns the replicaIndex label for a Pod in a multi-host TPU worker group.
+    The replicaIndex label is set by the GKE TPU Ray webhook and is of
+    the form {$WORKER_GROUP_NAME-$REPLICA_INDEX}, where $REPLICA_INDEX
+    is an integer from 0 to Replicas-1.
+    """
+    labels = pod["metadata"]["labels"]
+    return labels.get(REPLICA_INDEX_KEY, None)
+
+
 def pod_ip(pod: Dict[str, Any]) -> NodeIP:
     return pod["status"].get("podIP", "IP not yet assigned")
 
diff --git a/python/ray/autoscaler/batching_node_provider.py b/python/ray/autoscaler/batching_node_provider.py
index 7c7061c5cf50..f3a26d085952 100644
--- a/python/ray/autoscaler/batching_node_provider.py
+++ b/python/ray/autoscaler/batching_node_provider.py
@@ -14,6 +14,7 @@
     NODE_KIND_HEAD,
     TAG_RAY_NODE_KIND,
     TAG_RAY_NODE_STATUS,
+    TAG_RAY_REPLICA_INDEX,
     TAG_RAY_USER_NODE_TYPE,
 )
 
@@ -43,6 +44,8 @@ class NodeData:
     Attributes:
         kind: Whether the node is the head or a worker.
         type: The user-defined type of the node.
+        replica_index: An identifier for nodes in a replica of a TPU worker group.
+            This value is set as a Pod label by a GKE webhook when TPUs are requested.
         ip: Cluster-internal ip of the node. ip can be None if the ip
             has not yet been assigned.
         status: The status of the node. You must adhere to the following semantics
@@ -58,6 +61,7 @@
     type: NodeType
     ip: Optional[NodeIP]
     status: NodeStatus
+    replica_index: Optional[str] = None
 
 
 class BatchingNodeProvider(NodeProvider):
@@ -116,6 +120,9 @@ def __init__(
 
         self.scale_request = ScaleRequest()
 
+        # Initialize the map from replica index to the nodes in that replica.
+        self.replica_index_to_nodes = defaultdict(list)
+
     def get_node_data(self) -> Dict[NodeID, NodeData]:
         """Queries cluster manager for node info. Returns a mapping from node id
         to NodeData.
         """
@@ -160,6 +167,12 @@ def non_terminated_nodes(self, tag_filters: Dict[str, str]) -> List[str]:
             workers_to_delete=set(),  # No workers to delete yet
         )
         all_nodes = list(self.node_data_dict.keys())
+        self.replica_index_to_nodes.clear()
+        for node_id in all_nodes:
+            replica_index = self.node_data_dict[node_id].replica_index
+            # Only add the node to the map if it belongs to a multi-host podslice.
+            if replica_index is not None:
+                self.replica_index_to_nodes[replica_index].append(node_id)
         # Support filtering by TAG_RAY_NODE_KIND, TAG_RAY_NODE_STATUS, and
         # TAG_RAY_USER_NODE_TYPE.
        # The autoscaler only uses tag_filters={},
@@ -187,11 +200,14 @@ def _cur_num_workers(self, node_data_dict: Dict[str, Any]):
 
     def node_tags(self, node_id: str) -> Dict[str, str]:
         node_data = self.node_data_dict[node_id]
-        return {
+        tags = {
             TAG_RAY_NODE_KIND: node_data.kind,
             TAG_RAY_NODE_STATUS: node_data.status,
             TAG_RAY_USER_NODE_TYPE: node_data.type,
         }
+        if node_data.replica_index is not None:
+            tags[TAG_RAY_REPLICA_INDEX] = node_data.replica_index
+        return tags
 
     def internal_ip(self, node_id: str) -> str:
         return self.node_data_dict[node_id].ip
@@ -230,6 +246,20 @@ def terminate_node(self, node_id: str) -> Optional[Dict[str, Any]]:
                 f"{node_type}. Skipping termination request."
             )
 
+        # Terminate the requested node.
         self.scale_request.desired_num_workers[node_type] -= 1
         self.scale_request.workers_to_delete.add(node_id)
+
+        # If node_id is part of a multi-host podslice, scale down its entire replica.
+        tags = self.node_tags(node_id)
+        if TAG_RAY_REPLICA_INDEX in tags:
+            node_replica_index = tags[TAG_RAY_REPLICA_INDEX]
+            for worker_id in self.replica_index_to_nodes[node_replica_index]:
+                # Skip workers that have already been scheduled for deletion.
+                if worker_id not in self.scale_request.workers_to_delete:
+                    self.scale_request.workers_to_delete.add(worker_id)
+                    logger.info(
+                        f"Autoscaler terminating node {worker_id} "
+                        f"in multi-host replica {node_replica_index}."
+                    )
         self.scale_change_needed = True
diff --git a/python/ray/autoscaler/tags.py b/python/ray/autoscaler/tags.py
index 380b4450d6ec..38d03855040f 100644
--- a/python/ray/autoscaler/tags.py
+++ b/python/ray/autoscaler/tags.py
@@ -13,6 +13,8 @@
 # Tag for user defined node types (e.g., m4xl_spot). This is used for multi
 # node type clusters.
 TAG_RAY_USER_NODE_TYPE = "ray-user-node-type"
+# Tag for the replica index a node belongs to. Used for multi-host worker groups.
+TAG_RAY_REPLICA_INDEX = "ray-replica-index"
 # Tag for autofilled node types for legacy cluster yamls without multi
 # node type defined in the cluster configs.
NODE_TYPE_LEGACY_HEAD = "ray-legacy-head-node-type" diff --git a/python/ray/tests/kuberay/test_files/podlist2.yaml b/python/ray/tests/kuberay/test_files/podlist2.yaml index 56371fb1b76a..92528b27b0fc 100644 --- a/python/ray/tests/kuberay/test_files/podlist2.yaml +++ b/python/ray/tests/kuberay/test_files/podlist2.yaml @@ -405,6 +405,293 @@ items: - ip: 10.4.0.6 qosClass: Burstable startTime: "2022-11-14T23:13:47Z" +- apiVersion: v1 + kind: Pod + metadata: + annotations: + ray.io/ft-enabled: "false" + creationTimestamp: "2024-06-28T10:11:15Z" + generateName: raycluster-autoscaler-worker-tpu-group- + labels: + app.kubernetes.io/created-by: kuberay-operator + app.kubernetes.io/name: kuberay + ray.io/cluster: raycluster-autoscaler + ray.io/group: tpu-group + ray.io/identifier: raycluster-autoscaler-worker + ray.io/is-ray-node: "yes" + ray.io/node-type: worker + replicaIndex: tpu-group-0 + name: raycluster-autoscaler-worker-fake-tpu-group-xtpcl + namespace: default + ownerReferences: + - apiVersion: ray.io/v1 + blockOwnerDeletion: true + controller: true + kind: RayCluster + name: raycluster-autoscaler + uid: eaac19a2-93e5-420e-98ce-9e47cf9f401f + resourceVersion: "13131412" + uid: a943c7f8-7e93-40c6-b676-9b4d7a0ac8c3 + spec: + affinity: + podAffinity: + requiredDuringSchedulingIgnoredDuringExecution: + - labelSelector: + matchExpressions: + - key: replicaIndex + operator: In + values: + - tpu-group-0 + topologyKey: cloud.google.com/gke-nodepool + containers: + - args: + - 'ulimit -n 65536; ray start --resources="{\"TPU\": 4}" --address=raycluster-autoscaler-head-svc.default.svc.cluster.local:6379 --metrics-export-port=8080 --block --dashboard-agent-listen-port=52365 --num-cpus=1 --memory=40000000000 ' + command: + - /bin/bash + - -lc + - -- + env: + - name: FQ_RAY_IP + value: raycluster-autoscaler-head-svc.default.svc.cluster.local + - name: RAY_IP + value: raycluster-autoscaler-head-svc + - name: RAY_CLUSTER_NAME + valueFrom: + fieldRef: + apiVersion: v1 + fieldPath: metadata.labels['ray.io/cluster'] + - name: RAY_CLOUD_INSTANCE_ID + valueFrom: + fieldRef: + apiVersion: v1 + fieldPath: metadata.name + - name: RAY_NODE_TYPE_NAME + valueFrom: + fieldRef: + apiVersion: v1 + fieldPath: metadata.labels['ray.io/group'] + - name: KUBERAY_GEN_RAY_START_CMD + value: 'ray start --resources="{\"TPU\": 4}" --address=raycluster-autoscaler-head-svc.default.svc.cluster.local:6379 --metrics-export-port=8080 --block --dashboard-agent-listen-port=52365 --num-cpus=1 --memory=40000000000 ' + - name: RAY_PORT + value: "6379" + - name: RAY_ADDRESS + value: raycluster-autoscaler-head-svc.default.svc.cluster.local:6379 + - name: RAY_USAGE_STATS_KUBERAY_IN_USE + value: "1" + - name: REDIS_PASSWORD + - name: RAY_DASHBOARD_ENABLE_K8S_DISK_USAGE + value: "1" + - name: TPU_WORKER_HOSTNAMES + value: tpu-group-0-0.raycluster-autoscaler-headless-worker-svc,tpu-group-0-1.raycluster-autoscaler-headless-worker-svc + - name: TPU_WORKER_ID + value: "0" + - name: TPU_NAME + value: tpu-group-0 + image: rayproject/ray:2.9.0 + imagePullPolicy: Always + lifecycle: + preStop: + exec: + command: + - /bin/sh + - -c + - ray stop + livenessProbe: + exec: + command: + - bash + - -c + - wget -T 2 -q -O- http://localhost:52365/api/local_raylet_healthz | grep + success + failureThreshold: 120 + initialDelaySeconds: 30 + periodSeconds: 5 + successThreshold: 1 + timeoutSeconds: 1 + name: ray-worker + ports: + - containerPort: 8080 + name: metrics + protocol: TCP + readinessProbe: + exec: + command: + - bash + - -c + - wget -T 2 -q -O- 
http://localhost:52365/api/local_raylet_healthz | grep + success + failureThreshold: 10 + initialDelaySeconds: 10 + periodSeconds: 5 + successThreshold: 1 + timeoutSeconds: 1 + resources: + limits: + cpu: "1" + ephemeral-storage: 10Gi + google.com/tpu: "4" + memory: 40G + requests: + cpu: "1" + ephemeral-storage: 10Gi + google.com/tpu: "4" + memory: 40G + terminationMessagePath: /dev/termination-log + terminationMessagePolicy: File + volumeMounts: + - mountPath: /dev/shm + name: shared-mem + - mountPath: /var/run/secrets/kubernetes.io/serviceaccount + name: kube-api-access-65x9l + readOnly: true + dnsPolicy: ClusterFirst + enableServiceLinks: true + hostname: tpu-group-0-0 + initContainers: + - args: + - "\n\t\t\t\t\tSECONDS=0\n\t\t\t\t\twhile true; do\n\t\t\t\t\t\tif (( SECONDS + <= 120 )); then\n\t\t\t\t\t\t\tif ray health-check --address raycluster-autoscaler-head-svc.default.svc.cluster.local:6379 + > /dev/null 2>&1; then\n\t\t\t\t\t\t\t\techo \"GCS is ready.\"\n\t\t\t\t\t\t\t\tbreak\n\t\t\t\t\t\t\tfi\n\t\t\t\t\t\t\techo + \"$SECONDS seconds elapsed: Waiting for GCS to be ready.\"\n\t\t\t\t\t\telse\n\t\t\t\t\t\t\tif + ray health-check --address raycluster-autoscaler-head-svc.default.svc.cluster.local:6379; + then\n\t\t\t\t\t\t\t\techo \"GCS is ready. Any error messages above can be safely + ignored.\"\n\t\t\t\t\t\t\t\tbreak\n\t\t\t\t\t\t\tfi\n\t\t\t\t\t\t\techo \"$SECONDS + seconds elapsed: Still waiting for GCS to be ready. For troubleshooting, refer + to the FAQ at https://github.com/ray-project/kuberay/blob/master/docs/guidance/FAQ.md.\"\n\t\t\t\t\t\tfi\n\t\t\t\t\t\tsleep + 5\t\t\n\t\t\t\t\tdone\n\t\t\t\t" + command: + - /bin/bash + - -lc + - -- + env: + - name: FQ_RAY_IP + value: raycluster-autoscaler-head-svc.default.svc.cluster.local + - name: RAY_IP + value: raycluster-autoscaler-head-svc + image: rayproject/ray:2.9.0 + imagePullPolicy: Always + name: wait-gcs-ready + resources: + limits: + cpu: 200m + memory: 256Mi + requests: + cpu: 200m + memory: 256Mi + terminationMessagePath: /dev/termination-log + terminationMessagePolicy: File + volumeMounts: + - mountPath: /var/run/secrets/kubernetes.io/serviceaccount + name: kube-api-access-65x9l + readOnly: true + nodeName: gke-tpu-0bf19815-10mj + nodeSelector: + cloud.google.com/gke-accelerator-count: "4" + cloud.google.com/gke-tpu-accelerator: tpu-v4-podslice + cloud.google.com/gke-tpu-topology: 2x2x2 + preemptionPolicy: PreemptLowerPriority + priority: 0 + restartPolicy: Always + schedulerName: default-scheduler + securityContext: {} + serviceAccount: default + serviceAccountName: default + subdomain: raycluster-autoscaler-headless-worker-svc + terminationGracePeriodSeconds: 30 + tolerations: + - effect: NoExecute + key: node.kubernetes.io/not-ready + operator: Exists + tolerationSeconds: 300 + - effect: NoExecute + key: node.kubernetes.io/unreachable + operator: Exists + tolerationSeconds: 300 + - effect: NoSchedule + key: google.com/tpu + operator: Exists + volumes: + - emptyDir: + medium: Memory + sizeLimit: 40G + name: shared-mem + - name: kube-api-access-65x9l + projected: + defaultMode: 420 + sources: + - serviceAccountToken: + expirationSeconds: 3607 + path: token + - configMap: + items: + - key: ca.crt + path: ca.crt + name: kube-root-ca.crt + - downwardAPI: + items: + - fieldRef: + apiVersion: v1 + fieldPath: metadata.namespace + path: namespace + status: + conditions: + - lastProbeTime: null + lastTransitionTime: "2024-06-28T10:11:48Z" + status: "True" + type: PodReadyToStartContainers + - lastProbeTime: null + 
lastTransitionTime: "2024-06-28T10:11:57Z" + status: "True" + type: Initialized + - lastProbeTime: null + lastTransitionTime: "2024-06-28T10:12:07Z" + status: "True" + type: Ready + - lastProbeTime: null + lastTransitionTime: "2024-06-28T10:12:07Z" + status: "True" + type: ContainersReady + - lastProbeTime: null + lastTransitionTime: "2024-06-28T10:11:46Z" + status: "True" + type: PodScheduled + containerStatuses: + - containerID: containerd://1e5d9cef5cb10636d44ef2ab6e557e71861f0960d05135df45d9af0c33a06d97 + image: docker.io/rayproject/ray:2.9.0 + imageID: docker.io/rayproject/ray@sha256:e64546fb5c3233bb0f33608e186e285c52cdd7440cae1af18f7fcde1c04e49f2 + lastState: {} + name: ray-worker + ready: true + restartCount: 0 + started: true + state: + running: + startedAt: "2024-06-28T10:11:57Z" + hostIP: 10.0.0.57 + hostIPs: + - ip: 10.0.0.57 + initContainerStatuses: + - containerID: containerd://40257ec805418def64c50b7ce7b59e5eca79bc91754893beb9bde4d4042f819b + image: docker.io/rayproject/ray:2.9.0 + imageID: docker.io/rayproject/ray@sha256:e64546fb5c3233bb0f33608e186e285c52cdd7440cae1af18f7fcde1c04e49f2 + lastState: {} + name: wait-gcs-ready + ready: true + restartCount: 0 + started: false + state: + terminated: + containerID: containerd://40257ec805418def64c50b7ce7b59e5eca79bc91754893beb9bde4d4042f819b + exitCode: 0 + finishedAt: "2024-06-28T10:11:56Z" + reason: Completed + startedAt: "2024-06-28T10:11:47Z" + phase: Running + podIP: 10.136.1.29 + podIPs: + - ip: 10.136.1.29 + qosClass: Guaranteed + startTime: "2024-06-28T10:11:46Z" - apiVersion: v1 kind: Pod metadata: diff --git a/python/ray/tests/kuberay/test_kuberay_node_provider.py b/python/ray/tests/kuberay/test_kuberay_node_provider.py index 46b539d724ef..4b15a1c84731 100644 --- a/python/ray/tests/kuberay/test_kuberay_node_provider.py +++ b/python/ray/tests/kuberay/test_kuberay_node_provider.py @@ -5,6 +5,7 @@ import jsonpatch import pytest +from collections import defaultdict from ray.autoscaler.batching_node_provider import NodeData from ray.autoscaler._private.kuberay.node_provider import ( _worker_group_index, @@ -22,7 +23,9 @@ def _get_basic_ray_cr_workers_to_delete( - cpu_workers_to_delete: List[NodeID], gpu_workers_to_delete: List[NodeID] + cpu_workers_to_delete: List[NodeID], + gpu_workers_to_delete: List[NodeID], + tpu_workers_to_delete: List[NodeID], ): """Generate a Ray cluster with non-empty workersToDelete field.""" raycluster = get_basic_ray_cr() @@ -32,6 +35,9 @@ def _get_basic_ray_cr_workers_to_delete( raycluster["spec"]["workerGroupSpecs"][1]["scaleStrategy"] = { "workersToDelete": gpu_workers_to_delete } + raycluster["spec"]["workerGroupSpecs"][2]["scaleStrategy"] = { + "workersToDelete": tpu_workers_to_delete + } return raycluster @@ -119,10 +125,18 @@ def test_create_node_cap_at_max( "podlist1.yaml", { "raycluster-autoscaler-head-8zsc8": NodeData( - kind="head", type="head-group", ip="10.4.2.6", status="up-to-date" + kind="head", + type="head-group", + replica_index=None, + ip="10.4.2.6", + status="up-to-date", ), # up-to-date status because the Ray container is in running status "raycluster-autoscaler-worker-small-group-dkz2r": NodeData( - kind="worker", type="small-group", ip="10.4.1.8", status="waiting" + kind="worker", + type="small-group", + replica_index=None, + ip="10.4.1.8", + status="waiting", ), # waiting status, because Ray container's state is "waiting". # The pod list includes a worker with non-null deletion timestamp. 
# It is excluded from the node data because it is considered @@ -134,23 +148,37 @@ def test_create_node_cap_at_max( "podlist2.yaml", { "raycluster-autoscaler-head-8zsc8": NodeData( - kind="head", type="head-group", ip="10.4.2.6", status="up-to-date" + kind="head", + type="head-group", + replica_index=None, + ip="10.4.2.6", + status="up-to-date", ), "raycluster-autoscaler-worker-fake-gpu-group-2qnhv": NodeData( kind="worker", type="fake-gpu-group", + replica_index=None, ip="10.4.0.6", status="up-to-date", ), + "raycluster-autoscaler-worker-fake-tpu-group-xtpcl": NodeData( + kind="worker", + type="tpu-group", + replica_index="tpu-group-0", + ip="10.136.1.29", + status="up-to-date", + ), "raycluster-autoscaler-worker-small-group-dkz2r": NodeData( kind="worker", type="small-group", + replica_index=None, ip="10.4.1.8", status="up-to-date", ), "raycluster-autoscaler-worker-small-group-lbfm4": NodeData( kind="worker", type="small-group", + replica_index=None, ip="10.4.0.5", status="up-to-date", ), @@ -175,6 +203,7 @@ def mock_get(node_provider, path): ), mock.patch.object(KubeRayNodeProvider, "_get", mock_get): kr_node_provider = KubeRayNodeProvider(provider_config={}, cluster_name="fake") kr_node_provider.cluster_name = "fake" + kr_node_provider.replica_index_to_nodes = defaultdict(list[str]) nodes = kr_node_provider.non_terminated_nodes({}) assert kr_node_provider.node_data_dict == expected_node_data assert set(nodes) == set(expected_node_data.keys()) @@ -187,23 +216,30 @@ def mock_get(node_provider, path): ( { "raycluster-autoscaler-head-8zsc8": NodeData( - kind="head", type="head-group", ip="10.4.2.6", status="up-to-date" + kind="head", + type="head-group", + replica_index=None, + ip="10.4.2.6", + status="up-to-date", ), "raycluster-autoscaler-worker-fake-gpu-group-2qnhv": NodeData( kind="worker", type="fake-gpu-group", + replica_index=None, ip="10.4.0.6", status="up-to-date", ), "raycluster-autoscaler-worker-small-group-dkz2r": NodeData( kind="worker", type="small-group", + replica_index=None, ip="10.4.1.8", status="up-to-date", ), "raycluster-autoscaler-worker-small-group-lbfm4": NodeData( kind="worker", type="small-group", + replica_index=None, ip="10.4.0.5", status="up-to-date", ), @@ -258,18 +294,20 @@ def test_submit_scale_request(node_data_dict, scale_request, expected_patch_payl @pytest.mark.parametrize("node_set", [{"A", "B", "C", "D", "E"}]) @pytest.mark.parametrize("cpu_workers_to_delete", ["A", "Z"]) @pytest.mark.parametrize("gpu_workers_to_delete", ["B", "Y"]) +@pytest.mark.parametrize("tpu_workers_to_delete", ["C", "X"]) @pytest.mark.skipif(sys.platform.startswith("win"), reason="Not relevant on Windows.") def test_safe_to_scale( node_set: Set[NodeID], cpu_workers_to_delete: List[NodeID], gpu_workers_to_delete: List[NodeID], + tpu_workers_to_delete: List[NodeID], ): # NodeData values unimportant for this test. 
- mock_node_data = NodeData("-", "-", "-", "-") + mock_node_data = NodeData("-", "-", "-", "-", "-") node_data_dict = {node_id: mock_node_data for node_id in node_set} raycluster = _get_basic_ray_cr_workers_to_delete( - cpu_workers_to_delete, gpu_workers_to_delete + cpu_workers_to_delete, gpu_workers_to_delete, tpu_workers_to_delete ) def mock_patch(kuberay_provider, path, patch_payload): @@ -286,12 +324,19 @@ def mock_patch(kuberay_provider, path, patch_payload): kr_node_provider.node_data_dict = node_data_dict actual_safe = kr_node_provider.safe_to_scale() - expected_safe = not any( - cpu_worker_to_delete in node_set - for cpu_worker_to_delete in cpu_workers_to_delete - ) and not any( - gpu_worker_to_delete in node_set - for gpu_worker_to_delete in gpu_workers_to_delete + expected_safe = ( + not any( + cpu_worker_to_delete in node_set + for cpu_worker_to_delete in cpu_workers_to_delete + ) + and not any( + gpu_worker_to_delete in node_set + for gpu_worker_to_delete in gpu_workers_to_delete + ) + and not any( + tpu_worker_to_delete in node_set + for tpu_worker_to_delete in tpu_workers_to_delete + ) ) assert expected_safe is actual_safe patched_cpu_workers_to_delete = kr_node_provider._patched_raycluster["spec"][ @@ -300,15 +345,20 @@ def mock_patch(kuberay_provider, path, patch_payload): patched_gpu_workers_to_delete = kr_node_provider._patched_raycluster["spec"][ "workerGroupSpecs" ][1]["scaleStrategy"]["workersToDelete"] + patched_tpu_workers_to_delete = kr_node_provider._patched_raycluster["spec"][ + "workerGroupSpecs" + ][2]["scaleStrategy"]["workersToDelete"] if expected_safe: # Cleaned up workers to delete assert patched_cpu_workers_to_delete == [] assert patched_gpu_workers_to_delete == [] + assert patched_tpu_workers_to_delete == [] else: # Did not clean up workers to delete assert patched_cpu_workers_to_delete == cpu_workers_to_delete assert patched_gpu_workers_to_delete == gpu_workers_to_delete + assert patched_tpu_workers_to_delete == tpu_workers_to_delete if __name__ == "__main__": diff --git a/python/ray/tests/test_batch_node_provider_unit.py b/python/ray/tests/test_batch_node_provider_unit.py index 55d075be53ad..86d2b5389e57 100644 --- a/python/ray/tests/test_batch_node_provider_unit.py +++ b/python/ray/tests/test_batch_node_provider_unit.py @@ -22,6 +22,7 @@ TAG_RAY_USER_NODE_TYPE, TAG_RAY_NODE_KIND, TAG_RAY_NODE_STATUS, + TAG_RAY_REPLICA_INDEX, NODE_KIND_HEAD, NODE_KIND_WORKER, ) @@ -58,6 +59,9 @@ def get_node_data(self) -> Dict[NodeID, NodeData]: self.num_non_terminated_nodes_calls += 1 return self._node_data_dict + def set_node_replica_index(self, node_id, replica_index): + self._node_data_dict[node_id].replica_index = replica_index + def submit_scale_request(self, scale_request: ScaleRequest) -> None: """Simulate modification of cluster state by an external cluster manager.""" self._scale_request_submitted_count += 1 @@ -415,6 +419,72 @@ def test_terminate_safeguards(): assert len(nodes) == 1 +@pytest.mark.skipif(sys.platform.startswith("win"), reason="Not relevant on Windows.") +def test_terminate_node_in_multihost_replica(): + """Test multi-host replica deletion logic for KubeRay. + + Tests manually deleting a node in a multi-host replica + and verifying that the entire replica is scaled down. + Nodes belonging to the same multi-host replica are identified + through a replicaIndex label set by a GKE webhook. 
+ """ + # create 4 TPU nodes with MockBatchingNodeProvider + node_provider = MockBatchingNodeProvider( + provider_config={ + DISABLE_LAUNCH_CONFIG_CHECK_KEY: True, + DISABLE_NODE_UPDATERS_KEY: True, + FOREGROUND_NODE_LAUNCH_KEY: True, + }, + cluster_name="test-cluster", + _allow_multiple=True, + ) + + num_tpu_workers = 4 + for i in range(num_tpu_workers): + node_provider._add_node(node_type="TPU", node_kind=NODE_KIND_WORKER) + + # Set replica_index in node_data for all workers + workers = node_provider.non_terminated_nodes( + tag_filters={TAG_RAY_NODE_KIND: NODE_KIND_WORKER} + ) + assert len(workers) == num_tpu_workers + for index, node_id in enumerate(workers): + if index < num_tpu_workers // 2: + node_provider.set_node_replica_index(node_id, "tpu-group-0") + else: + node_provider.set_node_replica_index(node_id, "tpu-group-1") + + # Verify RAY_REPLICA_INDEX tag has been set + replicaIndexFilter = {TAG_RAY_REPLICA_INDEX: "tpu-group-0"} + replicaWorkers1 = node_provider.non_terminated_nodes(tag_filters=replicaIndexFilter) + assert len(replicaWorkers1) == num_tpu_workers // 2 + + replicaIndexFilter[TAG_RAY_REPLICA_INDEX] = "tpu-group-1" + replicaWorkers2 = node_provider.non_terminated_nodes(tag_filters=replicaIndexFilter) + assert len(replicaWorkers2) == num_tpu_workers // 2 + + # Verify replica_to_nodes mapping has been populated + assert ( + len(node_provider.replica_index_to_nodes["tpu-group-0"]) == num_tpu_workers // 2 + ) + assert ( + len(node_provider.replica_index_to_nodes["tpu-group-1"]) == num_tpu_workers // 2 + ) + + worker_0 = replicaWorkers1[0] # tpu-group-0 + worker_2 = replicaWorkers2[0] # tpu-group-1 + # Manually delete one TPU worker in tpu-group-0 + # BatchingNodeProvider should scale down all nodes in the replica + assert worker_0 in node_provider.node_data_dict + node_provider.terminate_node(worker_0) + assert len(node_provider.scale_request.workers_to_delete) == num_tpu_workers // 2 + + # Scale down the tpu-group-1 replica + assert worker_2 in node_provider.node_data_dict + node_provider.terminate_node(worker_2) + assert len(node_provider.scale_request.workers_to_delete) == num_tpu_workers + + if __name__ == "__main__": if os.environ.get("PARALLEL_CI"):