diff --git a/metaflow/plugins/aws/batch/batch_cli.py b/metaflow/plugins/aws/batch/batch_cli.py
index 6c8153da3a4..ec4f5304044 100644
--- a/metaflow/plugins/aws/batch/batch_cli.py
+++ b/metaflow/plugins/aws/batch/batch_cli.py
@@ -248,8 +248,15 @@ def echo(msg, stream="stderr", batch_id=None, **kwargs):
         }
         kwargs["input_paths"] = "".join("${%s}" % s for s in split_vars.keys())

-    step_args = " ".join(util.dict_to_cli_options(kwargs))
+    # For multinode, build the step command from a modified copy of kwargs
     num_parallel = num_parallel or 0
+    step_kwargs = kwargs.copy()
+    if num_parallel and num_parallel > 1:
+        # Pass task_id via an env var so the shell can expand the node index at
+        # runtime. A value starting with '$' is left unquoted by dict_to_cli_options.
+        step_kwargs["task_id"] = "$MF_TASK_ID_BASE[NODE-INDEX]"
+
+    step_args = " ".join(util.dict_to_cli_options(step_kwargs))
     if num_parallel and num_parallel > 1:
         # For multinode, add a placeholder that the caller can rewrite per node
         step_args += " [multinode-args]"
@@ -270,15 +277,26 @@ def echo(msg, stream="stderr", batch_id=None, **kwargs):
             retry_deco[0].attributes.get("minutes_between_retries", 1)
         )

-    # Set batch attributes
+    # Set Batch attributes; for multinode use the modified task_id so MF_PATHSPEC carries the node-index placeholder
+    task_spec_task_id = (
+        step_kwargs["task_id"] if num_parallel > 1 else kwargs["task_id"]
+    )
     task_spec = {
+        "flow_name": ctx.obj.flow.name,
+        "step_name": step_name,
+        "run_id": kwargs["run_id"],
+        "task_id": task_spec_task_id,
+        "retry_count": str(retry_count),
+    }
+    # Keep attrs keyed on the original task_id so metadata stays clean
+    main_task_spec = {
         "flow_name": ctx.obj.flow.name,
         "step_name": step_name,
         "run_id": kwargs["run_id"],
         "task_id": kwargs["task_id"],
         "retry_count": str(retry_count),
     }
-    attrs = {"metaflow.%s" % k: v for k, v in task_spec.items()}
+    attrs = {"metaflow.%s" % k: v for k, v in main_task_spec.items()}
     attrs["metaflow.user"] = util.get_username()
     attrs["metaflow.version"] = ctx.obj.environment.get_environment_info()[
         "metaflow_version"
@@ -302,6 +320,11 @@ def echo(msg, stream="stderr", batch_id=None, **kwargs):
     if split_vars:
         env.update(split_vars)

+    # For multinode, provide the base task id to be expanded in the container
+    if num_parallel and num_parallel > 1:
+        # Do not carry a possible 'control-' prefix into worker task ids
+        env["MF_TASK_ID_BASE"] = str(kwargs["task_id"]).replace("control-", "")
+
     if retry_count:
         ctx.obj.echo_always(
             "Sleeping %d minutes before the next AWS Batch retry"
diff --git a/metaflow/plugins/aws/batch/batch_client.py b/metaflow/plugins/aws/batch/batch_client.py
index bf0f6a824e7..4c3eb9c750b 100644
--- a/metaflow/plugins/aws/batch/batch_client.py
+++ b/metaflow/plugins/aws/batch/batch_client.py
@@ -4,6 +4,7 @@
 import random
 import time
 import hashlib
+import os

 try:
     unicode
@@ -19,7 +20,36 @@ class BatchClient(object):
     def __init__(self):
         from ..aws_client import get_aws_client

-        self._client = get_aws_client("batch")
+        # Prefer the task role by default when running inside AWS Batch containers
+        # by temporarily removing higher-precedence env credentials for this process.
+        # This prevents AMI-injected AWS_* env vars from overriding the task role.
+        # Outside of Batch, env vars are left untouched.
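+        # boto3 resolves credentials from env vars before the ECS container
+        # credential provider, so popping them lets the task role win.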
+ if "AWS_BATCH_JOB_ID" in os.environ: + _aws_env_keys = [ + "AWS_ACCESS_KEY_ID", + "AWS_SECRET_ACCESS_KEY", + "AWS_SESSION_TOKEN", + "AWS_PROFILE", + "AWS_DEFAULT_PROFILE", + ] + _present = [k for k in _aws_env_keys if k in os.environ] + print( + "[Metaflow] AWS credential-related env vars present before Batch client init:", + _present, + ) + _saved_env = { + k: os.environ.pop(k) for k in _aws_env_keys if k in os.environ + } + try: + self._client = get_aws_client("batch") + finally: + # Restore prior env for the rest of the process + for k, v in _saved_env.items(): + os.environ[k] = v + else: + self._client = get_aws_client("batch") def active_job_queues(self): paginator = self._client.get_paginator("describe_job_queues") @@ -96,6 +124,8 @@ def execute(self): commands = self.payload["containerOverrides"]["command"][-1] # add split-index as this worker is also an ubf_task commands = commands.replace("[multinode-args]", "--split-index 0") + # For main node, remove the placeholder since it keeps the original task ID + commands = commands.replace("[NODE-INDEX]", "") main_task_override["command"][-1] = commands # secondary tasks @@ -103,18 +133,12 @@ def execute(self): self.payload["containerOverrides"] ) secondary_commands = self.payload["containerOverrides"]["command"][-1] - # other tasks do not have control- prefix, and have the split id appended to the task -id - secondary_commands = secondary_commands.replace( - self._task_id, - self._task_id.replace("control-", "") - + "-node-$AWS_BATCH_JOB_NODE_INDEX", - ) - secondary_commands = secondary_commands.replace( - "ubf_control", - "ubf_task", - ) - secondary_commands = secondary_commands.replace( - "[multinode-args]", "--split-index $AWS_BATCH_JOB_NODE_INDEX" + # For secondary nodes: remove "control-" prefix and replace placeholders + secondary_commands = ( + secondary_commands.replace("control-", "") + .replace("[NODE-INDEX]", "-node-$AWS_BATCH_JOB_NODE_INDEX") + .replace("ubf_control", "ubf_task") + .replace("[multinode-args]", "--split-index $AWS_BATCH_JOB_NODE_INDEX") ) secondary_task_container_override["command"][-1] = secondary_commands @@ -408,6 +432,14 @@ def _register_job_definition( self.num_parallel = num_parallel or 0 if self.num_parallel >= 1: + # Set the ulimit of number of open files to 65536. This is because we cannot set it easily once worker processes start on Batch. + # job_definition["containerProperties"]["linuxParameters"]["ulimits"] = [ + # { + # "name": "nofile", + # "softLimit": 65536, + # "hardLimit": 65536, + # } + # ] job_definition["type"] = "multinode" job_definition["nodeProperties"] = { "numNodes": self.num_parallel, diff --git a/metaflow/plugins/aws/batch/batch_decorator.py b/metaflow/plugins/aws/batch/batch_decorator.py index 6d64ed994aa..8d51db49d45 100644 --- a/metaflow/plugins/aws/batch/batch_decorator.py +++ b/metaflow/plugins/aws/batch/batch_decorator.py @@ -421,6 +421,89 @@ def _wait_for_mapper_tasks(self, flow, step_name): TIMEOUT = 600 last_completion_timeout = time.time() + TIMEOUT print("Waiting for batch secondary tasks to finish") + + # Prefer Batch API when metadata is local (nodes can't share local metadata files). + # If metadata isn't bound yet but we are on Batch, also prefer Batch API. 
+        md = getattr(self, "metadata", None)
+        if (md is not None and md.TYPE == "local") or (
+            md is None and "AWS_BATCH_JOB_ID" in os.environ
+        ):
+            return self._wait_for_mapper_tasks_datastore(
+                flow, step_name, last_completion_timeout, TIMEOUT
+            )
+        return self._wait_for_mapper_tasks_metadata(
+            flow, step_name, last_completion_timeout, TIMEOUT
+        )
+
+    def _wait_for_mapper_tasks_datastore(
+        self, flow, step_name, last_completion_timeout, timeout
+    ):
+        """
+        Poll the shared datastore (e.g. S3) for the DONE marker of each mapper task.
+        This avoids relying on a metadata service or on local metadata files.
+        """
+        from metaflow.datastore.task_datastore import TaskDataStore
+
+        pathspecs = getattr(flow, "_control_mapper_tasks", [])
+        total = len(pathspecs)
+        if total == 0:
+            print("No mapper tasks discovered for datastore wait; returning")
+            return True
+
+        print("Waiting for mapper DONE markers in datastore for %d tasks" % total)
+        poll_sleep = 3.0
+        while last_completion_timeout > time.time():
+            time.sleep(poll_sleep)
+            completed = 0
+            for ps in pathspecs:
+                try:
+                    parts = ps.split("/")
+                    if len(parts) == 3:
+                        run_id, step, task_id = parts
+                    else:
+                        # Fall back in case of an unexpected pathspec format
+                        run_id, step, task_id = self.run_id, step_name, parts[-1]
+                    tds = TaskDataStore(
+                        self.flow_datastore,
+                        run_id,
+                        step,
+                        task_id,
+                        mode="r",
+                        allow_not_done=True,
+                    )
+                    if tds.has_metadata(TaskDataStore.METADATA_DONE_SUFFIX):
+                        completed += 1
+                except Exception as e:
+                    if os.environ.get("METAFLOW_DEBUG_BATCH_POLL") in (
+                        "1",
+                        "true",
+                        "True",
+                    ):
+                        print("Datastore wait: error checking %s: %s" % (ps, e))
+                    continue
+            if completed == total:
+                print("All mapper tasks have written DONE markers to the datastore")
+                return True
+            print(
+                "Waiting for mapper DONE markers. Finished: %d/%d" % (completed, total)
+            )
+            poll_sleep = min(poll_sleep * 1.25, 10.0)
+
+        raise Exception(
+            "Batch secondary workers did not finish in %d seconds (datastore wait)"
+            % (time.time() - (last_completion_timeout - timeout))
+        )
+
+    def _wait_for_mapper_tasks_metadata(self, flow, step_name, last_completion_timeout, timeout):
+        """
+        Polls Metaflow metadata (via the Step client) for task completion. Works with
+        service-backed metadata providers but can fail with local metadata in
+        multi-node setups because each node writes to its own filesystem.
+        """
+        from metaflow import Step
+
         while last_completion_timeout > time.time():
             time.sleep(2)
             try:
@@ -441,7 +524,8 @@ def _wait_for_mapper_tasks(self, flow, step_name):
             except Exception:
                 pass
         raise Exception(
-            "Batch secondary workers did not finish in %s seconds" % TIMEOUT
+            "Batch secondary workers did not finish in %d seconds"
+            % (time.time() - (last_completion_timeout - timeout))
         )

     @classmethod