29 changes: 26 additions & 3 deletions metaflow/plugins/aws/batch/batch_cli.py
@@ -248,8 +248,15 @@ def echo(msg, stream="stderr", batch_id=None, **kwargs):
}
kwargs["input_paths"] = "".join("${%s}" % s for s in split_vars.keys())

step_args = " ".join(util.dict_to_cli_options(kwargs))
# For multinode, create modified kwargs for command construction only
num_parallel = num_parallel or 0
step_kwargs = kwargs.copy()
if num_parallel and num_parallel > 1:
# Pass task_id via an env var so shell can expand node index at runtime.
# Using a value starting with '$' prevents quoting in dict_to_cli_options.
step_kwargs["task_id"] = "$MF_TASK_ID_BASE[NODE-INDEX]"

step_args = " ".join(util.dict_to_cli_options(step_kwargs))
if num_parallel and num_parallel > 1:
# For multinode, we need to add a placeholder that can be mutated by the caller
step_args += " [multinode-args]"
@@ -270,15 +277,26 @@ def echo(msg, stream="stderr", batch_id=None, **kwargs):
retry_deco[0].attributes.get("minutes_between_retries", 1)
)

# Set batch attributes
# Set batch attributes - use the modified task_id for multinode so that MF_PATHSPEC carries the placeholder
task_spec_task_id = (
step_kwargs["task_id"] if num_parallel > 1 else kwargs["task_id"]
)
task_spec = {
"flow_name": ctx.obj.flow.name,
"step_name": step_name,
"run_id": kwargs["run_id"],
"task_id": task_spec_task_id,
"retry_count": str(retry_count),
}
# Keep attrs clean with original task_id for metadata
main_task_spec = {
"flow_name": ctx.obj.flow.name,
"step_name": step_name,
"run_id": kwargs["run_id"],
"task_id": kwargs["task_id"],
"retry_count": str(retry_count),
}
attrs = {"metaflow.%s" % k: v for k, v in task_spec.items()}
attrs = {"metaflow.%s" % k: v for k, v in main_task_spec.items()}
attrs["metaflow.user"] = util.get_username()
attrs["metaflow.version"] = ctx.obj.environment.get_environment_info()[
"metaflow_version"
@@ -302,6 +320,11 @@ def echo(msg, stream="stderr", batch_id=None, **kwargs):
if split_vars:
env.update(split_vars)

# For multinode, provide the base task id to be expanded in the container
if num_parallel and num_parallel > 1:
# Ensure we don't carry a possible 'control-' prefix into worker IDs
env["MF_TASK_ID_BASE"] = str(kwargs["task_id"]).replace("control-", "")

if retry_count:
ctx.obj.echo_always(
"Sleeping %d minutes before the next AWS Batch retry"
58 changes: 45 additions & 13 deletions metaflow/plugins/aws/batch/batch_client.py
@@ -4,6 +4,7 @@
import random
import time
import hashlib
import os

try:
unicode
@@ -19,7 +20,34 @@ class BatchClient(object):
def __init__(self):
from ..aws_client import get_aws_client

self._client = get_aws_client("batch")
# Prefer the task role by default when running inside AWS Batch containers
# by temporarily removing higher-precedence env credentials for this process.
# This prevents AMI-injected AWS_* env vars from overriding the task role.
# Outside of Batch, env vars are left untouched unless explicitly opted in.
if "AWS_BATCH_JOB_ID" in os.environ:
_aws_env_keys = [
"AWS_ACCESS_KEY_ID",
"AWS_SECRET_ACCESS_KEY",
"AWS_SESSION_TOKEN",
"AWS_PROFILE",
"AWS_DEFAULT_PROFILE",
]
_present = [k for k in _aws_env_keys if k in os.environ]
print(
"[Metaflow] AWS credential-related env vars present before Batch client init:",
_present,
)
_saved_env = {
k: os.environ.pop(k) for k in _aws_env_keys if k in os.environ
}
try:
self._client = get_aws_client("batch")
finally:
# Restore prior env for the rest of the process
for k, v in _saved_env.items():
os.environ[k] = v
else:
self._client = get_aws_client("batch")
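The save-and-restore dance could equally be packaged as a context manager; a sketch under the same assumption that only this process's environment needs masking:

import os
from contextlib import contextmanager

@contextmanager
def masked_env(keys):
    # Temporarily drop env vars so boto3's credential chain falls through
    # to the container credentials (i.e. the Batch task role).
    saved = {k: os.environ.pop(k) for k in keys if k in os.environ}
    try:
        yield
    finally:
        os.environ.update(saved)

# Usage sketch:
# with masked_env(["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY",
#                  "AWS_SESSION_TOKEN", "AWS_PROFILE", "AWS_DEFAULT_PROFILE"]):
#     self._client = get_aws_client("batch")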

def active_job_queues(self):
paginator = self._client.get_paginator("describe_job_queues")
@@ -96,25 +124,21 @@ def execute(self):
commands = self.payload["containerOverrides"]["command"][-1]
# add split-index as this worker is also an ubf_task
commands = commands.replace("[multinode-args]", "--split-index 0")
# For main node, remove the placeholder since it keeps the original task ID
commands = commands.replace("[NODE-INDEX]", "")
main_task_override["command"][-1] = commands

# secondary tasks
secondary_task_container_override = copy.deepcopy(
self.payload["containerOverrides"]
)
secondary_commands = self.payload["containerOverrides"]["command"][-1]
# other tasks do not have the control- prefix and have the node index appended to the task id
secondary_commands = secondary_commands.replace(
self._task_id,
self._task_id.replace("control-", "")
+ "-node-$AWS_BATCH_JOB_NODE_INDEX",
)
secondary_commands = secondary_commands.replace(
"ubf_control",
"ubf_task",
)
secondary_commands = secondary_commands.replace(
"[multinode-args]", "--split-index $AWS_BATCH_JOB_NODE_INDEX"
# For secondary nodes: remove "control-" prefix and replace placeholders
secondary_commands = (
secondary_commands.replace("control-", "")
.replace("[NODE-INDEX]", "-node-$AWS_BATCH_JOB_NODE_INDEX")
.replace("ubf_control", "ubf_task")
.replace("[multinode-args]", "--split-index $AWS_BATCH_JOB_NODE_INDEX")
)

secondary_task_container_override["command"][-1] = secondary_commands
@@ -408,6 +432,14 @@ def _register_job_definition(

self.num_parallel = num_parallel or 0
if self.num_parallel >= 1:
# The open-file ulimit could be raised to 65536 here, since it cannot easily
# be changed once worker processes start on Batch. Left disabled for now:
# job_definition["containerProperties"]["linuxParameters"]["ulimits"] = [
# {
# "name": "nofile",
# "softLimit": 65536,
# "hardLimit": 65536,
# }
# ]
job_definition["type"] = "multinode"
job_definition["nodeProperties"] = {
"numNodes": self.num_parallel,
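For reference, the multinode branch produces a job definition of roughly this shape (truncated above; a sketch based on the AWS Batch RegisterJobDefinition API, with invented stand-in values for the container properties built earlier in this method):

num_parallel = 4  # invented for illustration
container_properties = {"image": "python:3.11", "vcpus": 4, "memory": 16384}  # stand-in

job_definition = {
    "type": "multinode",
    "nodeProperties": {
        "numNodes": num_parallel,
        "mainNode": 0,  # node 0 runs the control task
        "nodeRangeProperties": [
            {
                # one range covering all nodes, sharing the same container spec
                "targetNodes": "0:%d" % (num_parallel - 1),
                "container": container_properties,
            }
        ],
    },
}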
86 changes: 85 additions & 1 deletion metaflow/plugins/aws/batch/batch_decorator.py
@@ -421,6 +421,89 @@ def _wait_for_mapper_tasks(self, flow, step_name):
TIMEOUT = 600
last_completion_timeout = time.time() + TIMEOUT
print("Waiting for batch secondary tasks to finish")

# Prefer the datastore-based wait when metadata is local (nodes can't share
# local metadata files). If metadata isn't bound yet but we are on Batch,
# also prefer the datastore-based wait.
md = getattr(self, "metadata", None)
if md is not None and md.TYPE == "local":
return self._wait_for_mapper_tasks_batch_api(
flow, step_name, last_completion_timeout
)
if md is None and "AWS_BATCH_JOB_ID" in os.environ:
return self._wait_for_mapper_tasks_batch_api(
flow, step_name, last_completion_timeout
)
return self._wait_for_mapper_tasks_metadata(
flow, step_name, last_completion_timeout
)

def _wait_for_mapper_tasks_batch_api(
self, flow, step_name, last_completion_timeout
):
"""
Poll the shared datastore (S3) for DONE markers for each mapper task.
This avoids relying on a metadata service or local metadata files.
"""
from metaflow.datastore.task_datastore import TaskDataStore

pathspecs = getattr(flow, "_control_mapper_tasks", [])
total = len(pathspecs)
if total == 0:
print("No mapper tasks discovered for datastore wait; returning")
return True

print("Waiting for mapper DONE markers in datastore for %d tasks" % total)
poll_sleep = 3.0
while last_completion_timeout > time.time():
time.sleep(poll_sleep)
completed = 0
for ps in pathspecs:
try:
parts = ps.split("/")
if len(parts) == 3:
run_id, step, task_id = parts
else:
# Fallback in case of unexpected format
run_id, step, task_id = self.run_id, step_name, parts[-1]
tds = TaskDataStore(
self.flow_datastore,
run_id,
step,
task_id,
mode="r",
allow_not_done=True,
)
if tds.has_metadata(TaskDataStore.METADATA_DONE_SUFFIX):
completed += 1
except Exception as e:
if os.environ.get("METAFLOW_DEBUG_BATCH_POLL") in (
"1",
"true",
"True",
):
print("Datastore wait: error checking %s: %s" % (ps, e))
continue
if completed == total:
print("All mapper tasks have written DONE markers to datastore")
return True
print(
"Waiting for mapper DONE markers. Finished: %d/%d" % (completed, total)
)
poll_sleep = min(poll_sleep * 1.25, 10.0)

raise Exception(
"Batch secondary workers did not finish in %s seconds (datastore wait)"
% (time.time() - (last_completion_timeout - 600))
)
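One stylistic note on the timeout arithmetic used in both wait variants: re-deriving the start time from the deadline (with the hardcoded 600) works but silently breaks if TIMEOUT changes; recording the start explicitly is equivalent and more direct, e.g.:

import time

TIMEOUT = 600
start = time.time()
deadline = start + TIMEOUT
while time.time() < deadline:
    time.sleep(2)
    # ... poll for mapper completion, return True when done ...
raise Exception(
    "Batch secondary workers did not finish in %d seconds" % (time.time() - start)
)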

def _wait_for_mapper_tasks_metadata(self, flow, step_name, last_completion_timeout):
"""
Polls Metaflow metadata (Step client) for task completion.
Works with service-backed metadata providers but can fail with local metadata
in multi-node setups due to isolated per-node filesystems.
"""
from metaflow import Step

while last_completion_timeout > time.time():
time.sleep(2)
try:
@@ -441,7 +524,8 @@ def _wait_for_mapper_tasks(self, flow, step_name):
except Exception:
pass
raise Exception(
"Batch secondary workers did not finish in %s seconds" % TIMEOUT
"Batch secondary workers did not finish in %s seconds"
% (time.time() - (last_completion_timeout - 600))
)

@classmethod