diff --git a/.gitignore b/.gitignore index b36301c7c44..2e62c56aced 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,8 @@ __pycache__/ *.py[cod] *$py.class *.metaflow +*.metaflow_spin +metaflow_card_cache/ build/ dist/ diff --git a/metaflow/__init__.py b/metaflow/__init__.py index 0eba0da3f33..9a0b005e286 100644 --- a/metaflow/__init__.py +++ b/metaflow/__init__.py @@ -146,6 +146,7 @@ class and related decorators. metadata, get_metadata, default_metadata, + inspect_spin, Metaflow, Flow, Run, diff --git a/metaflow/cli.py b/metaflow/cli.py index cb9a0bc1ac9..63fb39c925a 100644 --- a/metaflow/cli.py +++ b/metaflow/cli.py @@ -1,3 +1,4 @@ +import os import functools import inspect import os @@ -7,7 +8,6 @@ import metaflow.tracing as tracing from metaflow._vendor import click -from metaflow.system import _system_logger, _system_monitor from . import decorators, lint, metaflow_version, parameters, plugins from .cli_args import cli_args @@ -27,6 +27,8 @@ DEFAULT_PACKAGE_SUFFIXES, ) from .metaflow_current import current +from .metaflow_profile import from_start +from metaflow.system import _system_monitor, _system_logger from .metaflow_environment import MetaflowEnvironment from .packaging_sys import MetaflowCodeContent from .plugins import ( @@ -38,9 +40,9 @@ ) from .pylint_wrapper import PyLint from .R import metaflow_r_version, use_r +from .util import get_latest_run_id, resolve_identity, decompress_list from .user_configs.config_options import LocalFileInput, config_options from .user_configs.config_parameters import ConfigValue -from .util import get_latest_run_id, resolve_identity ERASE_TO_EOL = "\033[K" HIGHLIGHT = "red" @@ -125,6 +127,8 @@ def logger(body="", system_msg=False, head="", bad=False, timestamp=True, nl=Tru "step": "metaflow.cli_components.step_cmd.step", "run": "metaflow.cli_components.run_cmds.run", "resume": "metaflow.cli_components.run_cmds.resume", + "spin": "metaflow.cli_components.run_cmds.spin", + "spin-step": "metaflow.cli_components.step_cmd.spin_step", }, ) def cli(ctx): @@ -318,6 +322,13 @@ def version(obj): hidden=True, is_eager=True, ) +@click.option( + "--spin-mode", + is_flag=True, + default=False, + help="Enable spin mode for metaflow cli commands. Setting this flag will result " + "in using spin metadata and spin datastore for executions" +) @click.pass_context def start( ctx, @@ -335,6 +346,7 @@ def start( local_config_file=None, config=None, config_value=None, + spin_mode=False, **deco_options ): if quiet: @@ -347,6 +359,7 @@ def start( if use_r(): version = metaflow_r_version() + from_start("MetaflowCLI: Starting") echo("Metaflow %s" % version, fg="magenta", bold=True, nl=False) echo(" executing *%s*" % ctx.obj.flow.name, fg="magenta", nl=False) echo(" for *%s*" % resolve_identity(), fg="magenta") @@ -366,6 +379,7 @@ def start( ctx.obj.check = functools.partial(_check, echo) ctx.obj.top_cli = cli ctx.obj.package_suffixes = package_suffixes.split(",") + ctx.obj.spin_mode = spin_mode ctx.obj.datastore_impl = [d for d in DATASTORES if d.TYPE == datastore][0] @@ -472,19 +486,12 @@ def start( # set force rebuild flag for environments that support it. 
ctx.obj.environment._force_rebuild = force_rebuild_environments ctx.obj.environment.validate_environment(ctx.obj.logger, datastore) - ctx.obj.event_logger = LOGGING_SIDECARS[event_logger]( flow=ctx.obj.flow, env=ctx.obj.environment ) - ctx.obj.event_logger.start() - _system_logger.init_system_logger(ctx.obj.flow.name, ctx.obj.event_logger) - ctx.obj.monitor = MONITOR_SIDECARS[monitor]( flow=ctx.obj.flow, env=ctx.obj.environment ) - ctx.obj.monitor.start() - _system_monitor.init_system_monitor(ctx.obj.flow.name, ctx.obj.monitor) - ctx.obj.metadata = [m for m in METADATA_PROVIDERS if m.TYPE == metadata][0]( ctx.obj.environment, ctx.obj.flow, ctx.obj.event_logger, ctx.obj.monitor ) @@ -498,6 +505,52 @@ def start( ) ctx.obj.config_options = config_options + ctx.obj.is_spin = False + ctx.obj.skip_decorators = False + + # Override values for spin steps, or if we are in spin mode + if hasattr(ctx, "saved_args") and ctx.saved_args and "spin" in ctx.saved_args[0] or ctx.obj.spin_mode: + # To minimize side effects for spin, we will only use the following: + # - local metadata provider, + # - local datastore, + # - local environment, + # - null event logger, + # - null monitor + ctx.obj.is_spin = True + if "--skip-decorators" in ctx.saved_args: + ctx.obj.skip_decorators = True + + ctx.obj.event_logger = LOGGING_SIDECARS["nullSidecarLogger"]( + flow=ctx.obj.flow, env=ctx.obj.environment + ) + ctx.obj.monitor = MONITOR_SIDECARS["nullSidecarMonitor"]( + flow=ctx.obj.flow, env=ctx.obj.environment + ) + # Use spin metadata, spin datastore, and spin datastore root + ctx.obj.metadata = [m for m in METADATA_PROVIDERS if m.TYPE == "spin"][0]( + ctx.obj.environment, ctx.obj.flow, ctx.obj.event_logger, ctx.obj.monitor + ) + ctx.obj.datastore_impl = [d for d in DATASTORES if d.TYPE == "spin"][0] + datastore_root = ctx.obj.datastore_impl.get_datastore_root_from_config( + ctx.obj.echo, create_on_absent=True + ) + ctx.obj.datastore_impl.datastore_root = datastore_root + + ctx.obj.flow_datastore = FlowDataStore( + ctx.obj.flow.name, + ctx.obj.environment, # Same environment as run/resume + ctx.obj.metadata, # local metadata + ctx.obj.event_logger, # null event logger + ctx.obj.monitor, # null monitor + storage_impl=ctx.obj.datastore_impl, + ) + + # Start event logger and monitor + ctx.obj.event_logger.start() + _system_logger.init_system_logger(ctx.obj.flow.name, ctx.obj.event_logger) + + ctx.obj.monitor.start() + _system_monitor.init_system_monitor(ctx.obj.flow.name, ctx.obj.monitor) decorators._init(ctx.obj.flow) @@ -512,9 +565,11 @@ def start( ctx.obj.logger, echo, deco_options, + ctx.obj.is_spin, + ctx.obj.skip_decorators, ) - # In the case of run/resume, we will want to apply the TL decospecs + # In the case of run/resume/spin, we will want to apply the TL decospecs # *after* the run decospecs so that they don't take precedence. In other # words, for the same decorator, we want `myflow.py run --with foo` to # take precedence over any other `foo` decospec @@ -542,11 +597,10 @@ def start( if ( hasattr(ctx, "saved_args") and ctx.saved_args - and ctx.saved_args[0] not in ("run", "resume") + and ctx.saved_args[0] not in ("run", "resume", "spin") ): - # run/resume are special cases because they can add more decorators with --with, + # run/resume/spin are special cases because they can add more decorators with --with, # so they have to take care of themselves. 
- all_decospecs = ctx.obj.tl_decospecs + list( ctx.obj.environment.decospecs() or [] ) @@ -556,6 +610,9 @@ def start( # or a scheduler setting them up in their own way. if ctx.saved_args[0] not in ("step", "init"): all_decospecs += DEFAULT_DECOSPECS.split() + elif ctx.saved_args[0] == "spin-step": + # If we are in spin-args, we will not attach any decorators + all_decospecs = [] if all_decospecs: decorators._attach_decorators(ctx.obj.flow, all_decospecs) decorators._init(ctx.obj.flow) @@ -569,6 +626,9 @@ def start( ctx.obj.environment, ctx.obj.flow_datastore, ctx.obj.logger, + # The last two arguments are only used for spin steps + ctx.obj.is_spin, + ctx.obj.skip_decorators, ) # Check the graph again (mutators may have changed it) diff --git a/metaflow/cli_components/run_cmds.py b/metaflow/cli_components/run_cmds.py index 159e2764303..af4e1f2f234 100644 --- a/metaflow/cli_components/run_cmds.py +++ b/metaflow/cli_components/run_cmds.py @@ -9,20 +9,21 @@ from ..graph import FlowGraph from ..metaflow_current import current from ..metaflow_config import DEFAULT_DECOSPECS, FEAT_ALWAYS_UPLOAD_CODE_PACKAGE +from ..metaflow_profile import from_start from ..package import MetaflowPackage -from ..runtime import NativeRuntime +from ..runtime import NativeRuntime, SpinRuntime from ..system import _system_logger # from ..client.core import Run from ..tagging_util import validate_tags -from ..util import get_latest_run_id, write_latest_run_id +from ..util import get_latest_run_id, write_latest_run_id, parse_spin_pathspec -def before_run(obj, tags, decospecs): +def before_run(obj, tags, decospecs, skip_decorators=False): validate_tags(tags) - # There's a --with option both at the top-level and for the run + # There's a --with option both at the top-level and for the run/resume/spin # subcommand. Why? # # "run --with shoes" looks so much better than "--with shoes run". 
@@ -36,26 +37,36 @@ def before_run(obj, tags, decospecs): # - run level decospecs # - top level decospecs # - environment decospecs - all_decospecs = ( - list(decospecs or []) - + obj.tl_decospecs - + list(obj.environment.decospecs() or []) - ) - if all_decospecs: - # These decospecs are the ones from run/resume PLUS the ones from the - # environment (for example the @conda) - decorators._attach_decorators(obj.flow, all_decospecs) - decorators._init(obj.flow) - # Regenerate graph if we attached more decorators - obj.flow.__class__._init_graph() - obj.graph = obj.flow._graph - - obj.check(obj.graph, obj.flow, obj.environment, pylint=obj.pylint) - # obj.environment.init_environment(obj.logger) - - decorators._init_step_decorators( - obj.flow, obj.graph, obj.environment, obj.flow_datastore, obj.logger + from_start( + f"Inside before_run, skip_decorators={skip_decorators}, is_spin={obj.is_spin}" ) + if not skip_decorators: + all_decospecs = ( + list(decospecs or []) + + obj.tl_decospecs + + list(obj.environment.decospecs() or []) + ) + if all_decospecs: + # These decospecs are the ones from run/resume/spin PLUS the ones from the + # environment (for example the @conda) + decorators._attach_decorators(obj.flow, all_decospecs) + decorators._init(obj.flow) + # Regenerate graph if we attached more decorators + obj.flow.__class__._init_attrs() + obj.graph = obj.flow._graph + + obj.check(obj.graph, obj.flow, obj.environment, pylint=obj.pylint) + # obj.environment.init_environment(obj.logger) + + decorators._init_step_decorators( + obj.flow, + obj.graph, + obj.environment, + obj.flow_datastore, + obj.logger, + obj.is_spin, + skip_decorators, + ) # Re-read graph since it may have been modified by mutators obj.graph = obj.flow._graph @@ -73,6 +84,29 @@ def before_run(obj, tags, decospecs): ) +def common_runner_options(func): + @click.option( + "--run-id-file", + default=None, + show_default=True, + type=str, + help="Write the ID of this run to the file specified.", + ) + @click.option( + "--runner-attribute-file", + default=None, + show_default=True, + type=str, + help="Write the metadata and pathspec of this run to the file specified. Used internally " + "for Metaflow's Runner API.", + ) + @wraps(func) + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + + return wrapper + + def write_file(file_path, content): if file_path is not None: with open(file_path, "w", encoding="utf-8") as f: @@ -137,20 +171,6 @@ def common_run_options(func): "in steps.", callback=config_callback, ) - @click.option( - "--run-id-file", - default=None, - show_default=True, - type=str, - help="Write the ID of this run to the file specified.", - ) - @click.option( - "--runner-attribute-file", - default=None, - show_default=True, - type=str, - help="Write the metadata and pathspec of this run to the file specified. 
Used internally for Metaflow's Runner API.", - ) @wraps(func) def wrapper(*args, **kwargs): return func(*args, **kwargs) @@ -195,6 +215,7 @@ def wrapper(*args, **kwargs): @click.command(help="Resume execution of a previous run of this flow.") @tracing.cli("cli/resume") @common_run_options +@common_runner_options @click.pass_obj def resume( obj, @@ -326,6 +347,7 @@ def resume( @click.command(help="Run the workflow locally.") @tracing.cli("cli/run") @common_run_options +@common_runner_options @click.option( "--namespace", "user_namespace", @@ -348,7 +370,7 @@ def run( run_id_file=None, runner_attribute_file=None, user_namespace=None, - **kwargs + **kwargs, ): if user_namespace is not None: namespace(user_namespace or None) @@ -401,3 +423,109 @@ def run( ) with runtime.run_heartbeat(): runtime.execute() + + +@parameters.add_custom_parameters(deploy_mode=True) +@click.command(help="Spins up a task for a given step from a previous run locally.") +@tracing.cli("cli/spin") +@click.argument("pathspec") +@click.option( + "--skip-decorators/--no-skip-decorators", + is_flag=True, + default=False, + show_default=True, + help="Skip decorators attached to the step or flow.", +) +@click.option( + "--artifacts-module", + default=None, + show_default=True, + help="Path to a module that contains artifacts to be used in the spun step. " + "The artifacts should be defined as a dictionary called ARTIFACTS with keys as " + "the artifact names and values as the artifact values. The artifact values will " + "overwrite the default values of the artifacts used in the spun step.", +) +@click.option( + "--persist/--no-persist", + "persist", + default=True, + show_default=True, + help="Whether to persist the artifacts in the spun step. If set to False, " + "the artifacts will not be persisted and will not be available in the spun step's " + "datastore.", +) +@click.option( + "--max-log-size", + default=10, + show_default=True, + help="Maximum size of stdout and stderr captured in " + "megabytes. If a step outputs more than this to " + "stdout/stderr, its output will be truncated.", +) +@common_runner_options +@click.pass_obj +def spin( + obj, + pathspec, + persist=True, + artifacts_module=None, + skip_decorators=False, + max_log_size=None, + run_id_file=None, + runner_attribute_file=None, + **kwargs, +): + # Parse the pathspec argument to extract step name and full pathspec + step_name, parsed_pathspec = parse_spin_pathspec(pathspec, obj.flow.name) + + before_run(obj, [], [], skip_decorators) + obj.echo(f"Spinning up step *{step_name}* locally for flow *{obj.flow.name}*") + obj.flow._set_constants(obj.graph, kwargs, obj.config_options) + step_func = getattr(obj.flow, step_name, None) + if step_func is None: + raise CommandException( + f"Step '{step_name}' not found in flow '{obj.flow.name}'. " + "Please provide a valid step name." + ) + from_start("Spin: before spin runtime init") + spin_runtime = SpinRuntime( + obj.flow, + obj.graph, + obj.flow_datastore, + obj.metadata, + obj.environment, + obj.package, + obj.logger, + obj.entrypoint, + obj.event_logger, + obj.monitor, + step_func, + step_name, + parsed_pathspec, + skip_decorators, + artifacts_module, + persist, + max_log_size * 1024 * 1024, + ) + write_latest_run_id(obj, spin_runtime.run_id) + write_file(run_id_file, spin_runtime.run_id) + # We only need the root for the metadata, i.e. 
the portion before DATASTORE_LOCAL_DIR + datastore_root = spin_runtime._flow_datastore._storage_impl.datastore_root + orig_task_metadata_root = datastore_root.rsplit("/", 1)[0] + from_start("Spin: going to execute") + spin_runtime.execute() + from_start("Spin: after spin runtime execute") + + if runner_attribute_file: + with open(runner_attribute_file, "w") as f: + json.dump( + { + "task_id": spin_runtime.task.task_id, + "step_name": step_name, + "run_id": spin_runtime.run_id, + "flow_name": obj.flow.name, + # Store metadata in a format that can be used by the Runner API + "metadata": f"{obj.metadata.__class__.TYPE}@{orig_task_metadata_root}", + }, + f, + ) diff --git a/metaflow/cli_components/step_cmd.py b/metaflow/cli_components/step_cmd.py index f4bef099e42..79cc78ad584 100644 --- a/metaflow/cli_components/step_cmd.py +++ b/metaflow/cli_components/step_cmd.py @@ -1,12 +1,17 @@ from metaflow._vendor import click -from .. import decorators, namespace +from .. import namespace from ..cli import echo_always, echo_dev_null from ..cli_args import cli_args +from ..datastore.flow_datastore import FlowDataStore from ..exception import CommandException +from ..client.filecache import FileCache, FileBlobCache, TaskMetadataCache +from ..metaflow_config import SPIN_ALLOWED_DECORATORS +from ..metaflow_profile import from_start +from ..plugins import DATASTORES from ..task import MetaflowTask from ..unbounded_foreach import UBF_CONTROL, UBF_TASK -from ..util import decompress_list +from ..util import decompress_list, read_artifacts_module import metaflow.tracing as tracing @@ -109,7 +114,6 @@ def step( ubf_context="none", num_parallel=None, ): - if ctx.obj.is_quiet: echo = echo_dev_null else: @@ -118,7 +122,7 @@ def step( if ubf_context == "none": ubf_context = None if opt_namespace is not None: - namespace(opt_namespace or None) + namespace(opt_namespace) func = None try: @@ -176,3 +180,155 @@ def step( ) echo("Success", fg="green", bold=True, indent=True) + + +@click.command(help="Internal command to spin a single task.", hidden=True) +@click.argument("step-name") +@click.option( + "--run-id", + default=None, + required=True, + help="Run ID for the step that's about to be spun", +) +@click.option( + "--task-id", + default=None, + required=True, + help="Task ID for the step that's about to be spun", +) +@click.option( + "--orig-flow-datastore", + show_default=True, + help="Original datastore for the flow from which a task is being spun", +) +@click.option( + "--input-paths", + help="A comma-separated list of pathspecs specifying inputs for this step.", +) +@click.option( + "--split-index", + type=int, + default=None, + show_default=True, + help="Index of this foreach split.", +) +@click.option( + "--retry-count", + default=0, + help="How many times we have attempted to run this task.", +) +@click.option( + "--max-user-code-retries", + default=0, + help="How many times we should attempt running the user code.", +) +@click.option( + "--namespace", + "opt_namespace", + default=None, + help="Change namespace from the default (your username) to the specified tag.", +) +@click.option( + "--skip-decorators/--no-skip-decorators", + is_flag=True, + default=False, + show_default=True, + help="Skip decorators attached to the step or flow.", +) +@click.option( + "--persist/--no-persist", + "persist", + default=True, + show_default=True, + help="Whether to persist the artifacts in the spun step. 
If set to false, the artifacts will not" + " be persisted and will not be available in the spun step's datastore.", +) +@click.option( + "--artifacts-module", + default=None, + show_default=True, + help="Path to a module that contains artifacts to be used in the spun step. The artifacts should " + "be defined as a dictionary called ARTIFACTS with keys as the artifact names and values as the " + "artifact values. The artifact values will overwrite the default values of the artifacts used in " + "the spun step.", +) +@click.pass_context +def spin_step( + ctx, + step_name, + orig_flow_datastore, + run_id=None, + task_id=None, + input_paths=None, + split_index=None, + retry_count=None, + max_user_code_retries=None, + opt_namespace=None, + skip_decorators=False, + artifacts_module=None, + persist=True, +): + import time + + if ctx.obj.is_quiet: + echo = echo_dev_null + else: + echo = echo_always + + if opt_namespace is not None: + namespace(opt_namespace) + + input_paths = decompress_list(input_paths) if input_paths else [] + + skip_decorators = skip_decorators + whitelist_decorators = [] if skip_decorators else SPIN_ALLOWED_DECORATORS + from_start("SpinStep: initialized decorators") + spin_artifacts = read_artifacts_module(artifacts_module) if artifacts_module else {} + from_start("SpinStep: read artifacts module") + + ds_type, ds_root = orig_flow_datastore.split("@") + orig_datastore_impl = [d for d in DATASTORES if d.TYPE == ds_type][0] + orig_datastore_impl.datastore_root = ds_root + orig_flow_datastore = FlowDataStore( + ctx.obj.flow.name, + environment=None, + storage_impl=orig_datastore_impl, + ds_root=ds_root, + ) + + filecache = FileCache() + orig_flow_datastore.set_metadata_cache( + TaskMetadataCache(filecache, ds_type, ds_root, ctx.obj.flow.name) + ) + orig_flow_datastore.ca_store.set_blob_cache( + FileBlobCache( + filecache, FileCache.flow_ds_id(ds_type, ds_root, ctx.obj.flow.name) + ) + ) + + task = MetaflowTask( + ctx.obj.flow, + ctx.obj.flow_datastore, + ctx.obj.metadata, + ctx.obj.environment, + echo, + ctx.obj.event_logger, + ctx.obj.monitor, + None, # no unbounded foreach context + orig_flow_datastore=orig_flow_datastore, + spin_artifacts=spin_artifacts, + ) + from_start("SpinStep: initialized task") + task.run_step( + step_name, + run_id, + task_id, + None, + input_paths, + split_index, + retry_count, + max_user_code_retries, + whitelist_decorators, + persist, + ) + from_start("SpinStep: ran step") diff --git a/metaflow/client/__init__.py b/metaflow/client/__init__.py index a06fbd290ba..9acf7c44c88 100644 --- a/metaflow/client/__init__.py +++ b/metaflow/client/__init__.py @@ -6,6 +6,7 @@ metadata, get_metadata, default_metadata, + inspect_spin, Metaflow, Flow, Run, diff --git a/metaflow/client/core.py b/metaflow/client/core.py index 8f4d89d555f..59469be7737 100644 --- a/metaflow/client/core.py +++ b/metaflow/client/core.py @@ -207,6 +207,15 @@ def default_namespace() -> str: return get_namespace() +def inspect_spin(datastore_root): + """ + Set metadata provider to spin metadata so that users can inspect spin + steps, tasks, and artifacts. + """ + metadata_str = f"spin@{datastore_root}" + metadata(metadata_str) + + MetaflowArtifacts = NamedTuple @@ -277,6 +286,7 @@ def __init__( self._attempt = attempt self._current_namespace = _current_namespace or get_namespace() self._namespace_check = _namespace_check + # If the current namespace is False, we disable checking for namespace for this # and all children objects. 
Not setting namespace_check to False has the consequence # of preventing access to children objects after the namespace changes @@ -1181,149 +1191,191 @@ class Task(MetaflowObject): _PARENT_CLASS = "step" _CHILD_CLASS = "artifact" - def __init__(self, *args, **kwargs): - super(Task, self).__init__(*args, **kwargs) - def _iter_filter(self, x): # exclude private data artifacts return x.id[0] != "_" - def _iter_matching_tasks(self, steps, metadata_key, metadata_pattern): + def _get_matching_pathspecs(self, steps, metadata_key, metadata_pattern): """ - Yield tasks from specified steps matching a foreach path pattern. + Yield pathspecs of tasks from specified steps that match a given metadata pattern. Parameters ---------- steps : List[str] - List of step names to search for tasks - pattern : str - Regex pattern to match foreach-indices metadata + List of Step objects to search for tasks. + metadata_key : str + Metadata key to filter tasks on (e.g., 'foreach-execution-path'). + metadata_pattern : str + Regular expression pattern to match against the metadata value. - Returns - ------- - Iterator[Task] - Tasks matching the foreach path pattern + Yields + ------ + str + Pathspec of each task whose metadata value for the specified key matches the pattern. """ flow_id, run_id, _, _ = self.path_components - for step in steps: task_pathspecs = self._metaflow.metadata.filter_tasks_by_metadata( - flow_id, run_id, step.id, metadata_key, metadata_pattern + flow_id, run_id, step, metadata_key, metadata_pattern ) for task_pathspec in task_pathspecs: - yield Task(pathspec=task_pathspec, _namespace_check=False) + yield task_pathspec + + @staticmethod + def _get_previous_steps(graph_info, step_name): + # Get the parent steps + steps = [] + for node_name, attributes in graph_info["steps"].items(): + if step_name in attributes["next"]: + steps.append(node_name) + return steps @property - def parent_tasks(self) -> Iterator["Task"]: + def parent_task_pathspecs(self) -> Iterator[str]: """ - Yields all parent tasks of the current task if one exists. + Yields pathspecs of all parent tasks of the current task. 
Yields ------ - Task - Parent task of the current task - + str + Pathspec of the parent task of the current task """ - flow_id, run_id, _, _ = self.path_components - - steps = list(self.parent.parent_steps) - if not steps: - return [] + _, _, step_name, _ = self.path_components + metadata_dict = self.metadata_dict + graph_info = self["_graph_info"].data - current_path = self.metadata_dict.get("foreach-execution-path", "") + # Get the parent steps + steps = self._get_previous_steps(graph_info, step_name) + node_type = graph_info["steps"][step_name]["type"] + metadata_key = "foreach-execution-path" + current_path = metadata_dict.get(metadata_key) if len(steps) > 1: # Static join - use exact path matching pattern = current_path or ".*" - yield from self._iter_matching_tasks( - steps, "foreach-execution-path", pattern - ) - return - - # Handle single step case - target_task = Step( - f"{flow_id}/{run_id}/{steps[0].id}", _namespace_check=False - ).task - target_path = target_task.metadata_dict.get("foreach-execution-path") - - if not target_path or not current_path: - # (Current task, "A:10") and (Parent task, "") - # Pattern: ".*" - pattern = ".*" else: - current_depth = len(current_path.split(",")) - target_depth = len(target_path.split(",")) - - if current_depth < target_depth: - # Foreach join - # (Current task, "A:10,B:13") and (Parent task, "A:10,B:13,C:21") - # Pattern: "A:10,B:13,.*" - pattern = f"{current_path},.*" + if not steps: + return # No parent steps, yield nothing + + if not current_path: + # Current task is not part of a foreach + # Pattern: ".*" + pattern = ".*" else: - # Foreach split or linear step - # Option 1: - # (Current task, "A:10,B:13,C:21") and (Parent task, "A:10,B:13") - # Option 2: - # (Current task, "A:10,B:13") and (Parent task, "A:10,B:13") - # Pattern: "A:10,B:13" - pattern = ",".join(current_path.split(",")[:target_depth]) + current_depth = len(current_path.split(",")) + if node_type == "join": + # Foreach join + # (Current task, "A:10,B:13") and (Parent task, "A:10,B:13,C:21") + # Pattern: "A:10,B:13,.*" + pattern = f"{current_path},.*" + else: + # Foreach split or linear step + # Pattern: "A:10,B:13" + parent_step_type = graph_info["steps"][steps[0]]["type"] + target_depth = current_depth + if parent_step_type == "split-foreach" and current_depth == 1: + # (Current task, "A:10") and (Parent task, "") + pattern = ".*" + else: + # (Current task, "A:10,B:13,C:21") and (Parent task, "A:10,B:13") + # (Current task, "A:10,B:13") and (Parent task, "A:10,B:13") + if parent_step_type == "split-foreach": + target_depth = current_depth - 1 + pattern = ",".join(current_path.split(",")[:target_depth]) - yield from self._iter_matching_tasks(steps, "foreach-execution-path", pattern) + for pathspec in self._get_matching_pathspecs(steps, metadata_key, pattern): + yield pathspec @property - def child_tasks(self) -> Iterator["Task"]: + def child_task_pathspecs(self) -> Iterator[str]: """ - Yield all child tasks of the current task if one exists. + Yields pathspecs of all child tasks of the current task. 
Yields ------ - Task - Child task of the current task + str + Pathspec of the child task of the current task """ - flow_id, run_id, _, _ = self.path_components - steps = list(self.parent.child_steps) - if not steps: - return [] + flow_id, run_id, step_name, _ = self.path_components + metadata_dict = self.metadata_dict + graph_info = self["_graph_info"].data + + # Get the child steps + steps = graph_info["steps"][step_name]["next"] - current_path = self.metadata_dict.get("foreach-execution-path", "") + node_type = graph_info["steps"][step_name]["type"] + metadata_key = "foreach-execution-path" + current_path = metadata_dict.get(metadata_key) if len(steps) > 1: # Static split - use exact path matching pattern = current_path or ".*" - yield from self._iter_matching_tasks( - steps, "foreach-execution-path", pattern - ) - return - - # Handle single step case - target_task = Step( - f"{flow_id}/{run_id}/{steps[0].id}", _namespace_check=False - ).task - target_path = target_task.metadata_dict.get("foreach-execution-path") - - if not target_path or not current_path: - # (Current task, "A:10") and (Child task, "") - # Pattern: ".*" - pattern = ".*" else: - current_depth = len(current_path.split(",")) - target_depth = len(target_path.split(",")) - - if current_depth < target_depth: - # Foreach split - # (Current task, "A:10,B:13") and (Child task, "A:10,B:13,C:21") - # Pattern: "A:10,B:13,.*" - pattern = f"{current_path},.*" + if not steps: + return # No child steps, yield nothing + + if not current_path: + # Current task is not part of a foreach + # Pattern: ".*" + pattern = ".*" else: - # Foreach join or linear step - # Option 1: - # (Current task, "A:10,B:13,C:21") and (Child task, "A:10,B:13") - # Option 2: - # (Current task, "A:10,B:13") and (Child task, "A:10,B:13") - # Pattern: "A:10,B:13" - pattern = ",".join(current_path.split(",")[:target_depth]) - - yield from self._iter_matching_tasks(steps, "foreach-execution-path", pattern) + current_depth = len(current_path.split(",")) + if node_type == "split-foreach": + # Foreach split + # (Current task, "A:10,B:13") and (Child task, "A:10,B:13,C:21") + # Pattern: "A:10,B:13,.*" + pattern = f"{current_path},.*" + else: + # Foreach join or linear step + # Pattern: "A:10,B:13" + child_step_type = graph_info["steps"][steps[0]]["type"] + + # We need to know if the child step is a foreach join or a static join + child_step_prev_steps = self._get_previous_steps( + graph_info, steps[0] + ) + if len(child_step_prev_steps) > 1: + child_step_type = "static-join" + target_depth = current_depth + if child_step_type == "join" and current_depth == 1: + # (Current task, "A:10") and (Child task, "") + pattern = ".*" + else: + # (Current task, "A:10,B:13,C:21") and (Child task, "A:10,B:13") + # (Current task, "A:10,B:13") and (Child task, "A:10,B:13") + if child_step_type == "join": + target_depth = current_depth - 1 + pattern = ",".join(current_path.split(",")[:target_depth]) + + for pathspec in self._get_matching_pathspecs(steps, metadata_key, pattern): + yield pathspec + + @property + def parent_tasks(self) -> Iterator["Task"]: + """ + Yields all parent tasks of the current task if one exists. + + Yields + ------ + Task + Parent task of the current task + """ + parent_task_pathspecs = self.parent_task_pathspecs + for pathspec in parent_task_pathspecs: + yield Task(pathspec=pathspec, _namespace_check=False) + + @property + def child_tasks(self) -> Iterator["Task"]: + """ + Yields all child tasks of the current task if one exists. 
+ + Yields + ------ + Task + Child task of the current task + """ + for pathspec in self.child_task_pathspecs: + yield Task(pathspec=pathspec, _namespace_check=False) @property def metadata(self) -> List[Metadata]: diff --git a/metaflow/client/filecache.py b/metaflow/client/filecache.py index 83a38811eff..980b5f34cf0 100644 --- a/metaflow/client/filecache.py +++ b/metaflow/client/filecache.py @@ -1,5 +1,6 @@ from __future__ import print_function from collections import OrderedDict +import json import os import sys import time @@ -10,13 +11,14 @@ from metaflow.datastore import FlowDataStore from metaflow.datastore.content_addressed_store import BlobCache +from metaflow.datastore.flow_datastore import MetadataCache from metaflow.exception import MetaflowException from metaflow.metaflow_config import ( CLIENT_CACHE_PATH, CLIENT_CACHE_MAX_SIZE, CLIENT_CACHE_MAX_FLOWDATASTORE_COUNT, - CLIENT_CACHE_MAX_TASKDATASTORE_COUNT, ) +from metaflow.metaflow_profile import from_start from metaflow.plugins import DATASTORES @@ -63,8 +65,8 @@ def __init__(self, cache_dir=None, max_size=None): # when querying for sizes of artifacts. Once we have queried for the size # of one artifact in a TaskDatastore, caching this means that any # queries on that same TaskDatastore will be quick (since we already - # have all the metadata) - self._task_metadata_caches = OrderedDict() + # have all the metadata). We keep track of this in a file so it persists + # across processes. @property def cache_dir(self): @@ -87,7 +89,7 @@ def get_log_legacy( ): ds_cls = self._get_datastore_storage_impl(ds_type) ds_root = ds_cls.path_join(*ds_cls.path_split(location)[:-5]) - cache_id = self._flow_ds_id(ds_type, ds_root, flow_name) + cache_id = self.flow_ds_id(ds_type, ds_root, flow_name) token = ( "%s.cached" @@ -311,13 +313,13 @@ def _index_objects(self): self._objects = sorted(objects, reverse=False) @staticmethod - def _flow_ds_id(ds_type, ds_root, flow_name): + def flow_ds_id(ds_type, ds_root, flow_name): p = urlparse(ds_root) sanitized_root = (p.netloc + p.path).replace("/", "_") return ".".join([ds_type, sanitized_root, flow_name]) @staticmethod - def _task_ds_id(ds_type, ds_root, flow_name, run_id, step_name, task_id, attempt): + def task_ds_id(ds_type, ds_root, flow_name, run_id, step_name, task_id, attempt): p = urlparse(ds_root) sanitized_root = (p.netloc + p.path).replace("/", "_") return ".".join( @@ -365,7 +367,7 @@ def _get_datastore_storage_impl(ds_type): return storage_impl[0] def _get_flow_datastore(self, ds_type, ds_root, flow_name): - cache_id = self._flow_ds_id(ds_type, ds_root, flow_name) + cache_id = self.flow_ds_id(ds_type, ds_root, flow_name) cached_flow_datastore = self._store_caches.get(cache_id) if cached_flow_datastore: @@ -380,9 +382,14 @@ def _get_flow_datastore(self, ds_type, ds_root, flow_name): ds_root=ds_root, ) blob_cache = self._blob_caches.setdefault( - cache_id, FileBlobCache(self, cache_id) + cache_id, + ( + FileBlobCache(self, cache_id), + TaskMetadataCache(self, ds_type, ds_root, flow_name), + ), ) - cached_flow_datastore.ca_store.set_blob_cache(blob_cache) + cached_flow_datastore.ca_store.set_blob_cache(blob_cache[0]) + cached_flow_datastore.set_metadata_cache(blob_cache[1]) self._store_caches[cache_id] = cached_flow_datastore if len(self._store_caches) > CLIENT_CACHE_MAX_FLOWDATASTORE_COUNT: cache_id_to_remove, _ = self._store_caches.popitem(last=False) @@ -393,32 +400,49 @@ def _get_task_datastore( self, ds_type, ds_root, flow_name, run_id, step_name, task_id, attempt ): flow_ds = 
self._get_flow_datastore(ds_type, ds_root, flow_name) - cached_metadata = None - if attempt is not None: - cache_id = self._task_ds_id( - ds_type, ds_root, flow_name, run_id, step_name, task_id, attempt - ) - cached_metadata = self._task_metadata_caches.get(cache_id) - if cached_metadata: - od_move_to_end(self._task_metadata_caches, cache_id) - return flow_ds.get_task_datastore( - run_id, - step_name, - task_id, - attempt=attempt, - data_metadata=cached_metadata, - ) - # If we are here, we either have attempt=None or nothing in the cache - task_ds = flow_ds.get_task_datastore( - run_id, step_name, task_id, attempt=attempt + + return flow_ds.get_task_datastore(run_id, step_name, task_id, attempt=attempt) + + +class TaskMetadataCache(MetadataCache): + def __init__(self, filecache, ds_type, ds_root, flow_name): + self._ds_type = ds_type + self._ds_root = ds_root + self._flow_name = flow_name + self._filecache = filecache + + def _path(self, run_id, step_name, task_id, attempt): + if attempt is None: + return None + cache_id = self._filecache.task_ds_id( + self._ds_type, + self._ds_root, + self._flow_name, + run_id, + step_name, + task_id, + attempt, + ) + token = ( + "%s.cached" + % sha1( + os.path.join( + run_id, step_name, task_id, str(attempt), "metadata" + ).encode("utf-8") + ).hexdigest() ) - cache_id = self._task_ds_id( - ds_type, ds_root, flow_name, run_id, step_name, task_id, task_ds.attempt + return os.path.join(self._filecache.cache_dir, cache_id, token[:2], token) + + def load_metadata(self, run_id, step_name, task_id, attempt): + d = self._filecache.read_file(self._path(run_id, step_name, task_id, attempt)) + if d: + return json.loads(d) + + def store_metadata(self, run_id, step_name, task_id, attempt, metadata_dict): + self._filecache.create_file( + self._path(run_id, step_name, task_id, attempt), + json.dumps(metadata_dict).encode("utf-8"), ) - self._task_metadata_caches[cache_id] = task_ds.ds_metadata - if len(self._task_metadata_caches) > CLIENT_CACHE_MAX_TASKDATASTORE_COUNT: - self._task_metadata_caches.popitem(last=False) - return task_ds class FileBlobCache(BlobCache): diff --git a/metaflow/datastore/__init__.py b/metaflow/datastore/__init__.py index 793251b0cff..65bb33b0eb9 100644 --- a/metaflow/datastore/__init__.py +++ b/metaflow/datastore/__init__.py @@ -2,3 +2,4 @@ from .flow_datastore import FlowDataStore from .datastore_set import TaskDataStoreSet from .task_datastore import TaskDataStore +from .spin_datastore import SpinTaskDatastore diff --git a/metaflow/datastore/content_addressed_store.py b/metaflow/datastore/content_addressed_store.py index e0533565ffa..a8f2e0e4805 100644 --- a/metaflow/datastore/content_addressed_store.py +++ b/metaflow/datastore/content_addressed_store.py @@ -38,7 +38,7 @@ def __init__(self, prefix, storage_impl): def set_blob_cache(self, blob_cache): self._blob_cache = blob_cache - def save_blobs(self, blob_iter, raw=False, len_hint=0): + def save_blobs(self, blob_iter, raw=False, len_hint=0, _is_transfer=False): """ Saves blobs of data to the datastore @@ -65,6 +65,9 @@ def save_blobs(self, blob_iter, raw=False, len_hint=0): Whether to save the bytes directly or process them, by default False len_hint : Hint of the number of blobs that will be produced by the iterator, by default 0 + _is_transfer : bool, default False + If True, this indicates we are saving blobs directly from the output of another + content addressed store's Returns ------- @@ -76,6 +79,20 @@ def save_blobs(self, blob_iter, raw=False, len_hint=0): def packing_iter(): for 
blob in blob_iter: + if _is_transfer: + key, blob_data, meta = blob + path = self._storage_impl.path_join(self._prefix, key[:2], key) + # Transfer data is always raw/decompressed, so mark it as such + meta_corrected = {"cas_raw": True, "cas_version": 1} + + results.append( + self.save_blobs_result( + uri=self._storage_impl.full_uri(path), + key=key, + ) + ) + yield path, (BytesIO(blob_data), meta_corrected) + continue sha = sha1(blob).hexdigest() path = self._storage_impl.path_join(self._prefix, sha[:2], sha) results.append( @@ -100,7 +117,7 @@ def packing_iter(): self._storage_impl.save_bytes(packing_iter(), overwrite=True, len_hint=len_hint) return results - def load_blobs(self, keys, force_raw=False): + def load_blobs(self, keys, force_raw=False, _is_transfer=False): """ Mirror function of save_blobs @@ -111,15 +128,20 @@ def load_blobs(self, keys, force_raw=False): ---------- keys : List of string Key describing the object to load - force_raw : bool, optional + force_raw : bool, default False Support for backward compatibility with previous datastores. If True, this will force the key to be loaded as is (raw). By default, False + _is_transfer : bool, default False + If True, this indicates we are loading blobs to transfer them directly + to another datastore. We will, in this case, also transfer the metdata + and do minimal processing. This is for internal use only. Returns ------- Returns an iterator of (string, bytes) tuples; the iterator may return keys - in a different order than were passed in. + in a different order than were passed in. If _is_transfer is True, the tuple + has three elements with the third one being the metadata. """ load_paths = [] for key in keys: @@ -127,7 +149,11 @@ def load_blobs(self, keys, force_raw=False): if self._blob_cache: blob = self._blob_cache.load_key(key) if blob is not None: - yield key, blob + if _is_transfer: + # Cached blobs are decompressed/processed bytes regardless of original format + yield key, blob, {"cas_raw": False, "cas_version": 1} + else: + yield key, blob else: path = self._storage_impl.path_join(self._prefix, key[:2], key) load_paths.append((key, path)) @@ -169,7 +195,10 @@ def load_blobs(self, keys, force_raw=False): if self._blob_cache: self._blob_cache.store_key(key, blob) - yield key, blob + if _is_transfer: + yield key, blob, meta # Preserve exact original metadata from storage + else: + yield key, blob def _unpack_backward_compatible(self, blob): # This is the backward compatible unpack diff --git a/metaflow/datastore/datastore_set.py b/metaflow/datastore/datastore_set.py index f60642de73f..80cc4c690a4 100644 --- a/metaflow/datastore/datastore_set.py +++ b/metaflow/datastore/datastore_set.py @@ -21,9 +21,18 @@ def __init__( pathspecs=None, prefetch_data_artifacts=None, allow_not_done=False, + join_type=None, + orig_flow_datastore=None, + spin_artifacts=None, ): self.task_datastores = flow_datastore.get_task_datastores( - run_id, steps=steps, pathspecs=pathspecs, allow_not_done=allow_not_done + run_id, + steps=steps, + pathspecs=pathspecs, + allow_not_done=allow_not_done, + join_type=join_type, + orig_flow_datastore=orig_flow_datastore, + spin_artifacts=spin_artifacts, ) if prefetch_data_artifacts: diff --git a/metaflow/datastore/flow_datastore.py b/metaflow/datastore/flow_datastore.py index 16318ed7693..4e1a73657c5 100644 --- a/metaflow/datastore/flow_datastore.py +++ b/metaflow/datastore/flow_datastore.py @@ -1,10 +1,13 @@ import itertools import json +from abc import ABC, abstractmethod from .. 
import metaflow_config from .content_addressed_store import ContentAddressedStore from .task_datastore import TaskDataStore +from .spin_datastore import SpinTaskDatastore +from ..metaflow_profile import from_start class FlowDataStore(object): @@ -63,10 +66,16 @@ def __init__( self._storage_impl.path_join(self.flow_name, "data"), self._storage_impl ) + # Private + self._metadata_cache = None + @property def datastore_root(self): return self._storage_impl.datastore_root + def set_metadata_cache(self, cache): + self._metadata_cache = cache + def get_task_datastores( self, run_id=None, @@ -76,6 +85,9 @@ def get_task_datastores( attempt=None, include_prior=False, mode="r", + join_type=None, + orig_flow_datastore=None, + spin_artifacts=None, ): """ Return a list of TaskDataStore for a subset of the tasks. @@ -95,7 +107,7 @@ def get_task_datastores( Steps to get the tasks from. If run_id is specified, this must also be specified, by default None pathspecs : List[str], optional - Full task specs (run_id/step_name/task_id). Can be used instead of + Full task specs (run_id/step_name/task_id[/attempt]). Can be used instead of specifying run_id and steps, by default None allow_not_done : bool, optional If True, returns the latest attempt of a task even if that attempt @@ -106,6 +118,16 @@ def get_task_datastores( If True, returns all attempts up to and including attempt. mode : str, default "r" Mode to initialize the returned TaskDataStores in. + join_type : str, optional + If specified, the join type for the task. This is used to determine + the user specified artifacts for the task in case of a spin task. + orig_flow_datastore : MetadataProvider, optional + The metadata provider in case of a spin task. If provided, the + returned TaskDataStore will be a SpinTaskDatastore instead of a + TaskDataStore. + spin_artifacts : Dict[str, Any], optional + Artifacts provided by user that can override the artifacts fetched via the + spin pathspec. Returns ------- @@ -145,7 +167,13 @@ def get_task_datastores( if attempt is not None and attempt <= metaflow_config.MAX_ATTEMPTS - 1: attempt_range = range(attempt + 1) if include_prior else [attempt] for task_url in task_urls: - for attempt in attempt_range: + task_splits = task_url.split("/") + # Usually it is flow, run, step, task (so 4 components) -- if we have a + # fifth one, there is a specific attempt number listed as well. + task_attempt_range = attempt_range + if len(task_splits) == 5: + task_attempt_range = [int(task_splits[4])] + for attempt in task_attempt_range: for suffix in [ TaskDataStore.METADATA_DATA_SUFFIX, TaskDataStore.METADATA_ATTEMPT_SUFFIX, @@ -198,7 +226,18 @@ def get_task_datastores( else (latest_started_attempts & done_attempts) ) latest_to_fetch = [ - (v[0], v[1], v[2], v[3], data_objs.get(v), mode, allow_not_done) + ( + v[0], + v[1], + v[2], + v[3], + data_objs.get(v), + mode, + allow_not_done, + join_type, + orig_flow_datastore, + spin_artifacts, + ) for v in latest_to_fetch ] return list(itertools.starmap(self.get_task_datastore, latest_to_fetch)) @@ -212,8 +251,64 @@ def get_task_datastore( data_metadata=None, mode="r", allow_not_done=False, + join_type=None, + orig_flow_datastore=None, + spin_artifacts=None, + persist=True, ): - return TaskDataStore( + if orig_flow_datastore is not None: + # In spin step subprocess, use SpinTaskDatastore for accessing artifacts + if join_type is not None: + # If join_type is specified, we need to use the artifacts corresponding + # to that particular join index, specified by the parent task pathspec. 
+ spin_artifacts = spin_artifacts.get( + f"{run_id}/{step_name}/{task_id}", {} + ) + from_start( + "FlowDataStore: get_task_datastore for spin task for type %s %s metadata" + % (self.TYPE, "without" if data_metadata is None else "with") + ) + # Get the task datastore for the spun task. + orig_datastore = orig_flow_datastore.get_task_datastore( + run_id, + step_name, + task_id, + attempt=attempt, + data_metadata=data_metadata, + mode=mode, + allow_not_done=allow_not_done, + join_type=join_type, + persist=persist, + ) + + return SpinTaskDatastore( + self.flow_name, + run_id, + step_name, + task_id, + orig_datastore, + spin_artifacts, + ) + + cache_hit = False + if ( + self._metadata_cache is not None + and data_metadata is None + and attempt is not None + and allow_not_done is False + ): + # If we have a metadata cache, we can try to load the metadata + # from the cache if it is not provided. + data_metadata = self._metadata_cache.load_metadata( + run_id, step_name, task_id, attempt + ) + cache_hit = data_metadata is not None + + from_start( + "FlowDataStore: get_task_datastore for regular task for type %s %s metadata" + % (self.TYPE, "without" if data_metadata is None else "with") + ) + task_datastore = TaskDataStore( self, run_id, step_name, @@ -222,8 +317,23 @@ def get_task_datastore( data_metadata=data_metadata, mode=mode, allow_not_done=allow_not_done, + persist=persist, ) + # Only persist in cache if it is non-changing (so done only) and we have + # a non-None attempt + if ( + not cache_hit + and self._metadata_cache is not None + and allow_not_done is False + and attempt is not None + ): + self._metadata_cache.store_metadata( + run_id, step_name, task_id, attempt, task_datastore.ds_metadata + ) + + return task_datastore + def save_data(self, data_iter, len_hint=0): """Saves data to the underlying content-addressed store @@ -265,3 +375,13 @@ def load_data(self, keys, force_raw=False): """ for key, blob in self.ca_store.load_blobs(keys, force_raw=force_raw): yield key, blob + + +class MetadataCache(ABC): + @abstractmethod + def load_metadata(self, run_id, step_name, task_id, attempt): + raise NotImplementedError() + + @abstractmethod + def store_metadata(self, run_id, step_name, task_id, attempt, metadata_dict): + raise NotImplementedError() diff --git a/metaflow/datastore/spin_datastore.py b/metaflow/datastore/spin_datastore.py new file mode 100644 index 00000000000..f45856c4c51 --- /dev/null +++ b/metaflow/datastore/spin_datastore.py @@ -0,0 +1,91 @@ +from typing import Dict, Any +from .task_datastore import TaskDataStore, require_mode +from ..metaflow_profile import from_start + + +class SpinTaskDatastore(object): + def __init__( + self, + flow_name: str, + run_id: str, + step_name: str, + task_id: str, + orig_datastore: TaskDataStore, + spin_artifacts: Dict[str, Any], + ): + """ + SpinTaskDatastore is a datastore for a task that is used to retrieve + artifacts and attributes for a spin step. It uses the task pathspec + from a previous execution of the step to access the artifacts and attributes. + + Parameters: + ----------- + flow_name : str + Name of the flow + run_id : str + Run ID of the flow + step_name : str + Name of the step + task_id : str + Task ID of the step + orig_datastore : TaskDataStore + The datastore for the underlying task that is being spun. + spin_artifacts : Dict[str, Any] + User provided artifacts that are to be used in the spin task. This is a dictionary + where keys are artifact names and values are the actual data or metadata. 
+ """ + self.flow_name = flow_name + self.run_id = run_id + self.step_name = step_name + self.task_id = task_id + self.orig_datastore = orig_datastore + self.spin_artifacts = spin_artifacts + self._task = None + + # Update _objects and _info in order to persist artifacts + # See `persist` method in `TaskDatastore` for more details + self._objects = self.orig_datastore._objects.copy() + self._info = self.orig_datastore._info.copy() + + # We strip out some of the control ones + for key in ("_transition",): + if key in self._objects: + del self._objects[key] + del self._info[key] + + from_start("SpinTaskDatastore: Initialized artifacts") + + @require_mode(None) + def __getitem__(self, name): + try: + # Check if it's an artifact in the spin_artifacts + return self.spin_artifacts[name] + except KeyError: + try: + # Check if it's an attribute of the task + # _foreach_stack, _foreach_index, ... + return self.orig_datastore[name] + except (KeyError, AttributeError) as e: + raise KeyError( + f"Attribute '{name}' not found in the previous execution " + f"of the tasks for `{self.step_name}`." + ) from e + + @require_mode(None) + def is_none(self, name): + val = self.__getitem__(name) + return val is None + + @require_mode(None) + def __contains__(self, name): + try: + _ = self.__getitem__(name) + return True + except KeyError: + return False + + @require_mode(None) + def items(self): + if self._objects: + return self._objects.items() + return {} diff --git a/metaflow/datastore/task_datastore.py b/metaflow/datastore/task_datastore.py index ebfed2d55d4..0d846e7c88c 100644 --- a/metaflow/datastore/task_datastore.py +++ b/metaflow/datastore/task_datastore.py @@ -6,6 +6,7 @@ from functools import wraps from io import BufferedIOBase, FileIO, RawIOBase +from typing import List from types import MethodType, FunctionType from .. import metaflow_config @@ -98,6 +99,7 @@ def __init__( data_metadata=None, mode="r", allow_not_done=False, + persist=True, ): self._storage_impl = flow_datastore._storage_impl self.TYPE = self._storage_impl.TYPE @@ -113,6 +115,7 @@ def __init__( self._attempt = attempt self._metadata = flow_datastore.metadata self._parent = flow_datastore + self._persist = persist # The GZIP encodings are for backward compatibility self._encodings = {"pickle-v2", "gzip+pickle-v2"} @@ -148,6 +151,8 @@ def __init__( ) if self.has_metadata(check_meta, add_attempt=False): max_attempt = i + elif max_attempt is not None: + break if self._attempt is None: self._attempt = max_attempt elif max_attempt is None or self._attempt > max_attempt: @@ -253,6 +258,70 @@ def init_task(self): """ self.save_metadata({self.METADATA_ATTEMPT_SUFFIX: {"time": time.time()}}) + @only_if_not_done + @require_mode("w") + def transfer_artifacts(self, other_datastore : "TaskDataStore", names : List[str] =None): + """ + Copies the blobs from other_datastore to this datastore if the datastore roots + are different. + + This is used specifically for spin so we can bring in artifacts from the original + datastore. + + Parameters + ---------- + other_datastore : TaskDataStore + Other datastore from which to copy artifacts from + names : List[str], optional, default None + If provided, only transfer the artifacts with these names. If None, + transfer all artifacts from the other datastore. 
+ """ + if ( + other_datastore.TYPE == self.TYPE + and other_datastore._storage_impl.datastore_root + == self._storage_impl.datastore_root + ): + # Nothing to transfer -- artifacts are already saved properly + return + + # Determine which artifacts need to be transferred + if names is None: + # Transfer all artifacts from other datastore + artifacts_to_transfer = list(other_datastore._objects.keys()) + else: + # Transfer only specified artifacts + artifacts_to_transfer = [ + name for name in names if name in other_datastore._objects + ] + + if not artifacts_to_transfer: + return + + # Get SHA keys for artifacts to transfer + shas_to_transfer = [ + other_datastore._objects[name] for name in artifacts_to_transfer + ] + + # Check which blobs are missing locally + missing_shas = [] + for sha in shas_to_transfer: + local_path = self._ca_store._storage_impl.path_join( + self._ca_store._prefix, sha[:2], sha + ) + if not self._ca_store._storage_impl.is_file([local_path])[0]: + missing_shas.append(sha) + + if not missing_shas: + return # All blobs already exist locally + + # Load blobs from other datastore in transfer mode + transfer_blobs = other_datastore._ca_store.load_blobs( + missing_shas, _is_transfer=True + ) + + # Save blobs to local datastore in transfer mode + self._ca_store.save_blobs(transfer_blobs, _is_transfer=True) + @only_if_not_done @require_mode("w") def save_artifacts(self, artifacts_iter, len_hint=0): @@ -683,14 +752,16 @@ def persist(self, flow): flow : FlowSpec Flow to persist """ + if not self._persist: + return if flow._datastore: self._objects.update(flow._datastore._objects) self._info.update(flow._datastore._info) - # we create a list of valid_artifacts in advance, outside of - # artifacts_iter, so we can provide a len_hint below + # Scan flow object FIRST valid_artifacts = [] + current_artifact_names = set() for var in dir(flow): if var.startswith("__") or var in flow._EPHEMERAL: continue @@ -707,6 +778,16 @@ def persist(self, flow): or isinstance(val, Parameter) ): valid_artifacts.append((var, val)) + current_artifact_names.add(var) + + # Transfer ONLY artifacts that aren't being overridden + if hasattr(flow._datastore, "orig_datastore"): + parent_artifacts = set(flow._datastore._objects.keys()) + unchanged_artifacts = parent_artifacts - current_artifact_names + if unchanged_artifacts: + self.transfer_artifacts( + flow._datastore.orig_datastore, names=list(unchanged_artifacts) + ) def artifacts_iter(): # we consume the valid_artifacts list destructively to @@ -722,6 +803,7 @@ def artifacts_iter(): delattr(flow, var) yield var, val + # Save current artifacts self.save_artifacts(artifacts_iter(), len_hint=len(valid_artifacts)) @only_if_not_done diff --git a/metaflow/decorators.py b/metaflow/decorators.py index 583f8a4515e..d6f91cd3066 100644 --- a/metaflow/decorators.py +++ b/metaflow/decorators.py @@ -27,7 +27,7 @@ UserStepDecoratorBase, UserStepDecoratorMeta, ) - +from .metaflow_config import SPIN_ALLOWED_DECORATORS from metaflow._vendor import click @@ -658,6 +658,50 @@ def _attach_decorators_to_step(step, decospecs): step_deco.add_or_raise(step, False, 1, None) +def _should_skip_decorator_for_spin( + deco, is_spin, skip_decorators, logger, decorator_type="decorator" +): + """ + Determine if a decorator should be skipped for spin steps. 
+ + Parameters: + ----------- + deco : Decorator + The decorator instance to check + is_spin : bool + Whether this is a spin step + skip_decorators : bool + Whether to skip all decorators + logger : callable + Logger function for warnings + decorator_type : str + Type of decorator ("Flow decorator" or "Step decorator") for logging + + Returns: + -------- + bool + True if the decorator should be skipped, False otherwise + """ + if not is_spin: + return False + + # Skip all decorator hooks if skip_decorators is True + if skip_decorators: + return True + + # Run decorator hooks for spin steps only if they are in the whitelist + if deco.name not in SPIN_ALLOWED_DECORATORS: + logger( + f"[Warning] Ignoring {decorator_type} '{deco.name}' as it is not supported in spin steps.", + system_msg=True, + timestamp=False, + bad=True, + ) + return True + + return False + + def _init(flow, only_non_static=False): for decorators in flow._flow_decorators.values(): for deco in decorators: @@ -673,7 +717,16 @@ def _init(flow, only_non_static=False): def _init_flow_decorators( - flow, graph, environment, flow_datastore, metadata, logger, echo, deco_options + flow, + graph, + environment, + flow_datastore, + metadata, + logger, + echo, + deco_options, + is_spin=False, + skip_decorators=False, ): # Since all flow decorators are stored as `{key:[deco]}` we iterate through each of them. for decorators in flow._flow_decorators.values(): @@ -702,6 +755,10 @@ def _init_flow_decorators( for option, option_info in deco.options.items() } for deco in decorators: + if _should_skip_decorator_for_spin( + deco, is_spin, skip_decorators, logger, "Flow decorator" + ): + continue deco.flow_init( flow, graph, @@ -714,8 +771,16 @@ def _init_flow_decorators( ) -def _init_step_decorators(flow, graph, environment, flow_datastore, logger): - # NOTE: We don't need graph but keeping it for backwards compatibility with +def _init_step_decorators( + flow, + graph, + environment, + flow_datastore, + logger, + is_spin=False, + skip_decorators=False, +): + # NOTE: We don't need the graph but keeping it for backwards compatibility with # extensions that use it directly. We will remove it at some point. # We call the mutate method for both the flow and step mutators. @@ -785,6 +850,10 @@ def _init_step_decorators(flow, graph, environment, flow_datastore, logger): for step in flow: for deco in step.decorators: + if _should_skip_decorator_for_spin( + deco, is_spin, skip_decorators, logger, "Step decorator" + ): + continue deco.step_init( flow, graph, diff --git a/metaflow/metaflow_config.py b/metaflow/metaflow_config.py index 0a54486357c..d8a2d46d335 100644 --- a/metaflow/metaflow_config.py +++ b/metaflow/metaflow_config.py @@ -21,6 +21,7 @@ # Path to the local directory to store artifacts for 'local' datastore. 
DATASTORE_LOCAL_DIR = ".metaflow" +DATASTORE_SPIN_LOCAL_DIR = ".metaflow_spin" # Local configuration file (in .metaflow) containing overrides per-project LOCAL_CONFIG_FILE = "config.json" @@ -47,6 +48,24 @@ "DEFAULT_FROM_DEPLOYMENT_IMPL", "argo-workflows" ) +### +# Spin configuration +### +SPIN_ALLOWED_DECORATORS = from_conf( + "SPIN_ALLOWED_DECORATORS", + [ + "conda", + "pypi", + "conda_base", + "pypi_base", + "environment", + "project", + "timeout", + "conda_env_internal", + "card", + ], +) + ### # User configuration ### @@ -57,6 +76,7 @@ # Datastore configuration ### DATASTORE_SYSROOT_LOCAL = from_conf("DATASTORE_SYSROOT_LOCAL") +DATASTORE_SYSROOT_SPIN = from_conf("DATASTORE_SYSROOT_SPIN") # S3 bucket and prefix to store artifacts for 's3' datastore. DATASTORE_SYSROOT_S3 = from_conf("DATASTORE_SYSROOT_S3") # Azure Blob Storage container and blob prefix @@ -461,6 +481,10 @@ ### FEAT_ALWAYS_UPLOAD_CODE_PACKAGE = from_conf("FEAT_ALWAYS_UPLOAD_CODE_PACKAGE", False) ### +# Profile +### +PROFILE_FROM_START = from_conf("PROFILE_FROM_START", False) +### # Debug configuration ### DEBUG_OPTIONS = [ diff --git a/metaflow/metaflow_profile.py b/metaflow/metaflow_profile.py index 39ecf42cdc3..1757aedf3fb 100644 --- a/metaflow/metaflow_profile.py +++ b/metaflow/metaflow_profile.py @@ -2,6 +2,24 @@ from contextlib import contextmanager +from .metaflow_config import PROFILE_FROM_START + +init_time = None + + +if PROFILE_FROM_START: + + def from_start(msg: str): + global init_time + if init_time is None: + init_time = time.time() + print("From start: %s took %dms" % (msg, int((time.time() - init_time) * 1000))) + +else: + + def from_start(_msg: str): + pass + @contextmanager def profile(label, stats_dict=None): diff --git a/metaflow/plugins/__init__.py b/metaflow/plugins/__init__.py index bbab35885e8..2a1565dae97 100644 --- a/metaflow/plugins/__init__.py +++ b/metaflow/plugins/__init__.py @@ -83,11 +83,13 @@ METADATA_PROVIDERS_DESC = [ ("service", ".metadata_providers.service.ServiceMetadataProvider"), ("local", ".metadata_providers.local.LocalMetadataProvider"), + ("spin", ".metadata_providers.spin.SpinMetadataProvider"), ] # Add datastore here DATASTORES_DESC = [ ("local", ".datastores.local_storage.LocalStorage"), + ("spin", ".datastores.spin_storage.SpinStorage"), ("s3", ".datastores.s3_storage.S3Storage"), ("azure", ".datastores.azure_storage.AzureStorage"), ("gs", ".datastores.gs_storage.GSStorage"), diff --git a/metaflow/plugins/cards/card_cli.py b/metaflow/plugins/cards/card_cli.py index 9cb8b4bbb9d..766f23a70d8 100644 --- a/metaflow/plugins/cards/card_cli.py +++ b/metaflow/plugins/cards/card_cli.py @@ -335,7 +335,6 @@ def list_many_cards( def cli(): pass - @cli.group(help="Commands related to @card decorator.") @click.pass_context def card(ctx): @@ -343,7 +342,6 @@ def card(ctx): # Can work with the Metaflow client. # If we don't set the metadata here than the metaflow client picks the defaults when calling the `Task`/`Run` objects. 
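Since the spin knobs above go through `from_conf`, they should be overridable the same way as other Metaflow settings. A rough sketch, assuming the usual `METAFLOW_` environment-variable prefix and profile-JSON conventions (the exact variable names follow that convention and are not separately documented here):

```python
import json
import os

# Environment-variable style, assuming the usual METAFLOW_ prefix convention:
os.environ["METAFLOW_DATASTORE_SYSROOT_SPIN"] = "/tmp/spin-datastore"
os.environ["METAFLOW_PROFILE_FROM_START"] = "1"

# Config-file style: the same keys, without the prefix, in the Metaflow profile JSON.
profile = {
    "DATASTORE_SYSROOT_SPIN": "/tmp/spin-datastore",
    "SPIN_ALLOWED_DECORATORS": ["conda", "pypi", "environment", "card"],
}
print(json.dumps(profile, indent=2))
```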
These defaults can come from the `config.json` file or based on the `METAFLOW_PROFILE` from metaflow import metadata - setting_metadata = "@".join( [ctx.obj.metadata.TYPE, ctx.obj.metadata.default_info()] ) diff --git a/metaflow/plugins/cards/card_datastore.py b/metaflow/plugins/cards/card_datastore.py index f70f608c372..beda97661ba 100644 --- a/metaflow/plugins/cards/card_datastore.py +++ b/metaflow/plugins/cards/card_datastore.py @@ -13,6 +13,7 @@ CARD_S3ROOT, CARD_LOCALROOT, DATASTORE_LOCAL_DIR, + DATASTORE_SPIN_LOCAL_DIR, CARD_SUFFIX, CARD_AZUREROOT, CARD_GSROOT, @@ -62,12 +63,17 @@ def get_storage_root(cls, storage_type): return CARD_AZUREROOT elif storage_type == "gs": return CARD_GSROOT - elif storage_type == "local": + elif storage_type == "local" or storage_type == "spin": # Borrowing some of the logic from LocalStorage.get_storage_root result = CARD_LOCALROOT + local_dir = ( + DATASTORE_SPIN_LOCAL_DIR + if storage_type == "spin" + else DATASTORE_LOCAL_DIR + ) if result is None: current_path = os.getcwd() - check_dir = os.path.join(current_path, DATASTORE_LOCAL_DIR, CARD_SUFFIX) + check_dir = os.path.join(current_path, local_dir, CARD_SUFFIX) check_dir = os.path.realpath(check_dir) orig_path = check_dir while not os.path.isdir(check_dir): @@ -75,9 +81,7 @@ def get_storage_root(cls, storage_type): if new_path == current_path: break # We are no longer making upward progress current_path = new_path - check_dir = os.path.join( - current_path, DATASTORE_LOCAL_DIR, CARD_SUFFIX - ) + check_dir = os.path.join(current_path, local_dir, CARD_SUFFIX) result = orig_path return result diff --git a/metaflow/plugins/cards/card_decorator.py b/metaflow/plugins/cards/card_decorator.py index 28c0c7f8f10..daa667fa2a7 100644 --- a/metaflow/plugins/cards/card_decorator.py +++ b/metaflow/plugins/cards/card_decorator.py @@ -171,6 +171,7 @@ def step_init( self._flow_datastore = flow_datastore self._environment = environment self._logger = logger + self.card_options = None # We check for configuration options. We do this here before they are diff --git a/metaflow/plugins/cards/card_modules/basic.py b/metaflow/plugins/cards/card_modules/basic.py index 19d8532ff9b..1e9c4ab7e4a 100644 --- a/metaflow/plugins/cards/card_modules/basic.py +++ b/metaflow/plugins/cards/card_modules/basic.py @@ -491,9 +491,13 @@ def render(self): ) # ignore the name as a parameter - param_ids = [ - p.id for p in self._task.parent.parent["_parameters"].task if p.id != "name" - ] + if "_parameters" not in self._task.parent.parent: + # In case of spin steps, there is no _parameters task + param_ids = [] + else: + param_ids = [ + p.id for p in self._task.parent.parent["_parameters"].task if p.id != "name" + ] if len(param_ids) > 0: # Extract parameter from the Parameter Task. That is less brittle. 
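The storage-root discovery above (shared by the card datastore and, later in this diff, by `LocalStorage`) is an upward directory walk that stops at the first matching datastore directory. A simplified sketch of that walk, without the create-on-absent behavior or the card suffix handling:

```python
import os
from typing import Optional


def find_datastore_root(start_dir: str, datastore_dir: str) -> Optional[str]:
    """Walk upward from start_dir until a directory named `datastore_dir` is found."""
    current = os.path.realpath(start_dir)
    while True:
        candidate = os.path.join(current, datastore_dir)
        if os.path.isdir(candidate):
            return candidate
        parent = os.path.dirname(current)
        if parent == current:        # reached the filesystem root: nothing found
            return None
        current = parent


# The same walk serves both stores; only the directory name differs.
print(find_datastore_root(os.getcwd(), ".metaflow"))
print(find_datastore_root(os.getcwd(), ".metaflow_spin"))
```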
parameter_data = TaskToDict( diff --git a/metaflow/plugins/datastores/local_storage.py b/metaflow/plugins/datastores/local_storage.py index 4077a9404dd..bb4791df8d3 100644 --- a/metaflow/plugins/datastores/local_storage.py +++ b/metaflow/plugins/datastores/local_storage.py @@ -1,24 +1,29 @@ import json import os -from metaflow.metaflow_config import DATASTORE_LOCAL_DIR, DATASTORE_SYSROOT_LOCAL +from metaflow.metaflow_config import ( + DATASTORE_LOCAL_DIR, + DATASTORE_SYSROOT_LOCAL, +) from metaflow.datastore.datastore_storage import CloseAfterUse, DataStoreStorage class LocalStorage(DataStoreStorage): TYPE = "local" METADATA_DIR = "_meta" + DATASTORE_DIR = DATASTORE_LOCAL_DIR # ".metaflow" + SYSROOT_VAR = DATASTORE_SYSROOT_LOCAL @classmethod def get_datastore_root_from_config(cls, echo, create_on_absent=True): - result = DATASTORE_SYSROOT_LOCAL + result = cls.SYSROOT_VAR if result is None: try: # Python2 current_path = os.getcwdu() except: # noqa E722 current_path = os.getcwd() - check_dir = os.path.join(current_path, DATASTORE_LOCAL_DIR) + check_dir = os.path.join(current_path, cls.DATASTORE_DIR) check_dir = os.path.realpath(check_dir) orig_path = check_dir top_level_reached = False @@ -28,12 +33,13 @@ def get_datastore_root_from_config(cls, echo, create_on_absent=True): top_level_reached = True break # We are no longer making upward progress current_path = new_path - check_dir = os.path.join(current_path, DATASTORE_LOCAL_DIR) + check_dir = os.path.join(current_path, cls.DATASTORE_DIR) if top_level_reached: if create_on_absent: # Could not find any directory to use so create a new one echo( - "Creating local datastore in current directory (%s)" % orig_path + "Creating %s datastore in current directory (%s)" + % (cls.TYPE, orig_path) ) os.mkdir(orig_path) result = orig_path @@ -42,7 +48,7 @@ def get_datastore_root_from_config(cls, echo, create_on_absent=True): else: result = check_dir else: - result = os.path.join(result, DATASTORE_LOCAL_DIR) + result = os.path.join(result, cls.DATASTORE_DIR) return result @staticmethod diff --git a/metaflow/plugins/datastores/spin_storage.py b/metaflow/plugins/datastores/spin_storage.py new file mode 100644 index 00000000000..d0f39baf62b --- /dev/null +++ b/metaflow/plugins/datastores/spin_storage.py @@ -0,0 +1,12 @@ +from metaflow.metaflow_config import ( + DATASTORE_SPIN_LOCAL_DIR, + DATASTORE_SYSROOT_SPIN, +) +from metaflow.plugins.datastores.local_storage import LocalStorage + + +class SpinStorage(LocalStorage): + TYPE = "spin" + METADATA_DIR = "_meta" + DATASTORE_DIR = DATASTORE_SPIN_LOCAL_DIR # ".metaflow_spin" + SYSROOT_VAR = DATASTORE_SYSROOT_SPIN diff --git a/metaflow/plugins/metadata_providers/local.py b/metaflow/plugins/metadata_providers/local.py index 74de40a61e8..424812c810f 100644 --- a/metaflow/plugins/metadata_providers/local.py +++ b/metaflow/plugins/metadata_providers/local.py @@ -18,6 +18,14 @@ class LocalMetadataProvider(MetadataProvider): TYPE = "local" + DATASTORE_DIR = DATASTORE_LOCAL_DIR # ".metaflow" + + @classmethod + def _get_storage_class(cls): + # This method is meant to be overridden + from metaflow.plugins.datastores.local_storage import LocalStorage + + return LocalStorage def __init__(self, environment, flow, event_logger, monitor): super(LocalMetadataProvider, self).__init__( @@ -26,30 +34,28 @@ def __init__(self, environment, flow, event_logger, monitor): @classmethod def compute_info(cls, val): - from metaflow.plugins.datastores.local_storage import LocalStorage + storage_class = cls._get_storage_class() - v = 
os.path.realpath(os.path.join(val, DATASTORE_LOCAL_DIR)) + v = os.path.realpath(os.path.join(val, cls.DATASTORE_DIR)) if os.path.isdir(v): - LocalStorage.datastore_root = v + storage_class.datastore_root = v return val raise ValueError( - "Could not find directory %s in directory %s" % (DATASTORE_LOCAL_DIR, val) + "Could not find directory %s in directory %s" % (cls.DATASTORE_DIR, val) ) @classmethod def default_info(cls): - from metaflow.plugins.datastores.local_storage import LocalStorage + storage_class = cls._get_storage_class() def print_clean(line, **kwargs): print(line) - v = LocalStorage.get_datastore_root_from_config( + v = storage_class.get_datastore_root_from_config( print_clean, create_on_absent=False ) if v is None: - return ( - "" % DATASTORE_LOCAL_DIR - ) + return "" % cls.DATASTORE_DIR return os.path.dirname(v) def version(self): @@ -102,7 +108,7 @@ def register_task_id( def register_data_artifacts( self, run_id, step_name, task_id, attempt_id, artifacts ): - meta_dir = self._create_and_get_metadir( + meta_dir = self.__class__._create_and_get_metadir( self._flow_name, run_id, step_name, task_id ) artlist = self._artifacts_to_json( @@ -112,7 +118,7 @@ def register_data_artifacts( self._save_meta(meta_dir, artdict) def register_metadata(self, run_id, step_name, task_id, metadata): - meta_dir = self._create_and_get_metadir( + meta_dir = self.__class__._create_and_get_metadir( self._flow_name, run_id, step_name, task_id ) metalist = self._metadata_to_json(run_id, step_name, task_id, metadata) @@ -132,9 +138,7 @@ def _mutate_user_tags_for_run( def _optimistically_mutate(): # get existing tags - run = LocalMetadataProvider.get_object( - "run", "self", {}, None, flow_id, run_id - ) + run = cls.get_object("run", "self", {}, None, flow_id, run_id) if not run: raise MetaflowTaggingError( msg="Run not found (%s, %s)" % (flow_id, run_id) @@ -167,15 +171,13 @@ def _optimistically_mutate(): validate_tags(next_user_tags_set, existing_tags=existing_user_tag_set) # write new tag set to file system - LocalMetadataProvider._persist_tags_for_run( + cls._persist_tags_for_run( flow_id, run_id, next_user_tags_set, existing_system_tag_set ) # read tags back from file system to see if our optimism is misplaced # I.e. did a concurrent mutate overwrite our change - run = LocalMetadataProvider.get_object( - "run", "self", {}, None, flow_id, run_id - ) + run = cls.get_object("run", "self", {}, None, flow_id, run_id) if not run: raise MetaflowTaggingError( msg="Run not found for read-back check (%s, %s)" % (flow_id, run_id) @@ -279,8 +281,6 @@ def _get_object_internal( if obj_type not in ("root", "flow", "run", "step", "task", "artifact"): raise MetaflowInternalError(msg="Unexpected object type %s" % obj_type) - from metaflow.plugins.datastores.local_storage import LocalStorage - if obj_type == "artifact": # Artifacts are actually part of the tasks in the filesystem # E.g. 
we get here for (obj_type, sub_type) == (artifact, self) @@ -307,13 +307,13 @@ def _get_object_internal( # Special handling of self, artifact, and metadata if sub_type == "self": - meta_path = LocalMetadataProvider._get_metadir(*args[:obj_order]) + meta_path = cls._get_metadir(*args[:obj_order]) if meta_path is None: return None self_file = os.path.join(meta_path, "_self.json") if os.path.isfile(self_file): obj = MetadataProvider._apply_filter( - [LocalMetadataProvider._read_json_file(self_file)], filters + [cls._read_json_file(self_file)], filters )[0] # For non-descendants of a run, we are done @@ -324,7 +324,7 @@ def _get_object_internal( raise MetaflowInternalError( msg="Unexpected object type %s" % obj_type ) - run = LocalMetadataProvider.get_object( + run = cls.get_object( "run", "self", {}, None, *args[:RUN_ORDER] # *[flow_id, run_id] ) if not run: @@ -341,7 +341,7 @@ def _get_object_internal( if obj_type not in ("root", "flow", "run", "step", "task"): raise MetaflowInternalError(msg="Unexpected object type %s" % obj_type) - meta_path = LocalMetadataProvider._get_metadir(*args[:obj_order]) + meta_path = cls._get_metadir(*args[:obj_order]) result = [] if meta_path is None: return result @@ -352,9 +352,7 @@ def _get_object_internal( attempts_done = sorted(glob.iglob(attempt_done_files)) if attempts_done: successful_attempt = int( - LocalMetadataProvider._read_json_file(attempts_done[-1])[ - "value" - ] + cls._read_json_file(attempts_done[-1])["value"] ) if successful_attempt is not None: which_artifact = "*" @@ -365,10 +363,10 @@ def _get_object_internal( "%d_artifact_%s.json" % (successful_attempt, which_artifact), ) for obj in glob.iglob(artifact_files): - result.append(LocalMetadataProvider._read_json_file(obj)) + result.append(cls._read_json_file(obj)) # We are getting artifacts. 
We should overlay with ancestral run's tags - run = LocalMetadataProvider.get_object( + run = cls.get_object( "run", "self", {}, None, *args[:RUN_ORDER] # *[flow_id, run_id] ) if not run: @@ -388,12 +386,12 @@ def _get_object_internal( if obj_type not in ("root", "flow", "run", "step", "task"): raise MetaflowInternalError(msg="Unexpected object type %s" % obj_type) result = [] - meta_path = LocalMetadataProvider._get_metadir(*args[:obj_order]) + meta_path = cls._get_metadir(*args[:obj_order]) if meta_path is None: return result files = os.path.join(meta_path, "sysmeta_*") for obj in glob.iglob(files): - result.append(LocalMetadataProvider._read_json_file(obj)) + result.append(cls._read_json_file(obj)) return result # For the other types, we locate all the objects we need to find and return them @@ -401,14 +399,13 @@ def _get_object_internal( raise MetaflowInternalError(msg="Unexpected object type %s" % obj_type) if sub_type not in ("flow", "run", "step", "task"): raise MetaflowInternalError(msg="unexpected sub type %s" % sub_type) - obj_path = LocalMetadataProvider._make_path( - *args[:obj_order], create_on_absent=False - ) + obj_path = cls._make_path(*args[:obj_order], create_on_absent=False) result = [] if obj_path is None: return result skip_dirs = "*/" * (sub_order - obj_order) - all_meta = os.path.join(obj_path, skip_dirs, LocalStorage.METADATA_DIR) + storage_class = cls._get_storage_class() + all_meta = os.path.join(obj_path, skip_dirs, storage_class.METADATA_DIR) SelfInfo = collections.namedtuple("SelfInfo", ["filepath", "run_id"]) self_infos = [] for meta_path in glob.iglob(all_meta): @@ -418,9 +415,7 @@ def _get_object_internal( run_id = None # flow and run do not need info from ancestral run if sub_type in ("step", "task"): - run_id = LocalMetadataProvider._deduce_run_id_from_meta_dir( - meta_path, sub_type - ) + run_id = cls._deduce_run_id_from_meta_dir(meta_path, sub_type) # obj_type IS run, or more granular than run, let's do sanity check vs args if obj_order >= RUN_ORDER: if run_id != args[RUN_ORDER - 1]: @@ -430,10 +425,10 @@ def _get_object_internal( self_infos.append(SelfInfo(filepath=self_file, run_id=run_id)) for self_info in self_infos: - obj = LocalMetadataProvider._read_json_file(self_info.filepath) + obj = cls._read_json_file(self_info.filepath) if self_info.run_id: flow_id_from_args = args[0] - run = LocalMetadataProvider.get_object( + run = cls.get_object( "run", "self", {}, @@ -452,8 +447,8 @@ def _get_object_internal( return MetadataProvider._apply_filter(result, filters) - @staticmethod - def _deduce_run_id_from_meta_dir(meta_dir_path, sub_type): + @classmethod + def _deduce_run_id_from_meta_dir(cls, meta_dir_path, sub_type): curr_order = ObjectOrder.type_to_order(sub_type) levels_to_ascend = curr_order - ObjectOrder.type_to_order("run") if levels_to_ascend < 0: @@ -468,8 +463,8 @@ def _deduce_run_id_from_meta_dir(meta_dir_path, sub_type): ) return run_id - @staticmethod - def _makedirs(path): + @classmethod + def _makedirs(cls, path): # this is for python2 compatibility. # Python3 has os.makedirs(exist_ok=True). 
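The staticmethod-to-classmethod conversion running through this provider is what makes the spin metadata provider, added later in this diff, a near-trivial subclass: every path helper now resolves through `cls`, so overriding `DATASTORE_DIR` and `_get_storage_class` is enough to redirect all metadata into the spin directory. A toy sketch of the pattern with stub classes (not the real providers):

```python
class LocalStorageStub:
    METADATA_DIR = "_meta"


class SpinStorageStub(LocalStorageStub):
    pass


class LocalLikeProvider:
    DATASTORE_DIR = ".metaflow"

    @classmethod
    def _get_storage_class(cls):
        # Hook point: subclasses swap in their own storage backend here.
        return LocalStorageStub

    @classmethod
    def metadata_dir(cls, root):
        storage = cls._get_storage_class()
        return f"{root}/{cls.DATASTORE_DIR}/{storage.METADATA_DIR}"


class SpinLikeProvider(LocalLikeProvider):
    DATASTORE_DIR = ".metaflow_spin"

    @classmethod
    def _get_storage_class(cls):
        return SpinStorageStub


print(LocalLikeProvider.metadata_dir("/project"))  # /project/.metaflow/_meta
print(SpinLikeProvider.metadata_dir("/project"))   # /project/.metaflow_spin/_meta
```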
try: @@ -481,17 +476,15 @@ def _makedirs(path): else: raise - @staticmethod - def _persist_tags_for_run(flow_id, run_id, tags, system_tags): - subpath = LocalMetadataProvider._create_and_get_metadir( - flow_name=flow_id, run_id=run_id - ) + @classmethod + def _persist_tags_for_run(cls, flow_id, run_id, tags, system_tags): + subpath = cls._create_and_get_metadir(flow_name=flow_id, run_id=run_id) selfname = os.path.join(subpath, "_self.json") if not os.path.isfile(selfname): raise MetaflowInternalError( msg="Could not verify Run existence on disk - missing %s" % selfname ) - LocalMetadataProvider._save_meta( + cls._save_meta( subpath, { "_self": MetadataProvider._run_to_json_static( @@ -508,11 +501,11 @@ def _ensure_meta( tags = set() if sys_tags is None: sys_tags = set() - subpath = self._create_and_get_metadir( + subpath = self.__class__._create_and_get_metadir( self._flow_name, run_id, step_name, task_id ) selfname = os.path.join(subpath, "_self.json") - self._makedirs(subpath) + self.__class__._makedirs(subpath) if os.path.isfile(selfname): # There is a race here, but we are not aiming to make this as solid as # the metadata service. This is used primarily for concurrent resumes, @@ -549,26 +542,31 @@ def _new_task( self._register_system_metadata(run_id, step_name, task_id, attempt) return to_return - @staticmethod + @classmethod def _make_path( - flow_name=None, run_id=None, step_name=None, task_id=None, create_on_absent=True + cls, + flow_name=None, + run_id=None, + step_name=None, + task_id=None, + create_on_absent=True, ): - from metaflow.plugins.datastores.local_storage import LocalStorage + storage_class = cls._get_storage_class() - if LocalStorage.datastore_root is None: + if storage_class.datastore_root is None: def print_clean(line, **kwargs): print(line) - LocalStorage.datastore_root = LocalStorage.get_datastore_root_from_config( + storage_class.datastore_root = storage_class.get_datastore_root_from_config( print_clean, create_on_absent=create_on_absent ) - if LocalStorage.datastore_root is None: + if storage_class.datastore_root is None: return None if flow_name is None: - return LocalStorage.datastore_root + return storage_class.datastore_root components = [] if flow_name: components.append(flow_name) @@ -578,37 +576,35 @@ def print_clean(line, **kwargs): components.append(step_name) if task_id: components.append(task_id) - return LocalStorage().full_uri(LocalStorage.path_join(*components)) + return storage_class().full_uri(storage_class.path_join(*components)) - @staticmethod + @classmethod def _create_and_get_metadir( - flow_name=None, run_id=None, step_name=None, task_id=None + cls, flow_name=None, run_id=None, step_name=None, task_id=None ): - from metaflow.plugins.datastores.local_storage import LocalStorage + storage_class = cls._get_storage_class() - root_path = LocalMetadataProvider._make_path( - flow_name, run_id, step_name, task_id - ) - subpath = os.path.join(root_path, LocalStorage.METADATA_DIR) - LocalMetadataProvider._makedirs(subpath) + root_path = cls._make_path(flow_name, run_id, step_name, task_id) + subpath = os.path.join(root_path, storage_class.METADATA_DIR) + cls._makedirs(subpath) return subpath - @staticmethod - def _get_metadir(flow_name=None, run_id=None, step_name=None, task_id=None): - from metaflow.plugins.datastores.local_storage import LocalStorage + @classmethod + def _get_metadir(cls, flow_name=None, run_id=None, step_name=None, task_id=None): + storage_class = cls._get_storage_class() - root_path = LocalMetadataProvider._make_path( + 
root_path = cls._make_path( flow_name, run_id, step_name, task_id, create_on_absent=False ) if root_path is None: return None - subpath = os.path.join(root_path, LocalStorage.METADATA_DIR) + subpath = os.path.join(root_path, storage_class.METADATA_DIR) if os.path.isdir(subpath): return subpath return None - @staticmethod - def _dump_json_to_file(filepath, data, allow_overwrite=False): + @classmethod + def _dump_json_to_file(cls, filepath, data, allow_overwrite=False): if os.path.isfile(filepath) and not allow_overwrite: return try: @@ -622,15 +618,13 @@ def _dump_json_to_file(filepath, data, allow_overwrite=False): if f and os.path.isfile(f.name): os.remove(f.name) - @staticmethod - def _read_json_file(filepath): + @classmethod + def _read_json_file(cls, filepath): with open(filepath, "r") as f: return json.load(f) - @staticmethod - def _save_meta(root_dir, metadict, allow_overwrite=False): + @classmethod + def _save_meta(cls, root_dir, metadict, allow_overwrite=False): for name, datum in metadict.items(): filename = os.path.join(root_dir, "%s.json" % name) - LocalMetadataProvider._dump_json_to_file( - filename, datum, allow_overwrite=allow_overwrite - ) + cls._dump_json_to_file(filename, datum, allow_overwrite=allow_overwrite) diff --git a/metaflow/plugins/metadata_providers/spin.py b/metaflow/plugins/metadata_providers/spin.py new file mode 100644 index 00000000000..e32fdc8ffe6 --- /dev/null +++ b/metaflow/plugins/metadata_providers/spin.py @@ -0,0 +1,16 @@ +from metaflow.plugins.metadata_providers.local import LocalMetadataProvider +from metaflow.metaflow_config import DATASTORE_SPIN_LOCAL_DIR + + +class SpinMetadataProvider(LocalMetadataProvider): + TYPE = "spin" + DATASTORE_DIR = DATASTORE_SPIN_LOCAL_DIR # ".metaflow_spin" + + @classmethod + def _get_storage_class(cls): + from metaflow.plugins.datastores.spin_storage import SpinStorage + + return SpinStorage + + def version(self): + return "spin" diff --git a/metaflow/runner/metaflow_runner.py b/metaflow/runner/metaflow_runner.py index 84eb8ac4284..ba663edae68 100644 --- a/metaflow/runner/metaflow_runner.py +++ b/metaflow/runner/metaflow_runner.py @@ -6,7 +6,7 @@ from typing import Dict, Iterator, Optional, Tuple -from metaflow import Run +from metaflow import Run, Task from metaflow.metaflow_config import CLICK_API_PROCESS_CONFIG @@ -21,30 +21,36 @@ from .subprocess_manager import CommandManager, SubprocessManager -class ExecutingRun(object): +class ExecutingProcess(object): """ - This class contains a reference to a `metaflow.Run` object representing - the currently executing or finished run, as well as metadata related - to the process. + This is a base class for `ExecutingRun` and `ExecutingTask` classes. + The `ExecutingRun` and `ExecutingTask` classes are returned by methods + in `Runner` and `NBRunner`, and they are subclasses of this class. - `ExecutingRun` is returned by methods in `Runner` and `NBRunner`. It is not - meant to be instantiated directly. + The `ExecutingRun` class for instance contains a reference to a `metaflow.Run` + object representing the currently executing or finished run, as well as the metadata + related to the process. + + Similarly, the `ExecutingTask` class contains a reference to a `metaflow.Task` + object representing the currently executing or finished task, as well as the metadata + related to the process. + + This class or its subclasses are not meant to be instantiated directly. 
The class + works as a context manager, allowing you to use a pattern like: - This class works as a context manager, allowing you to use a pattern like ```python with Runner(...).run() as running: ... ``` - Note that you should use either this object as the context manager or - `Runner`, not both in a nested manner. + + Note that you should use either this object as the context manager or `Runner`, not both + in a nested manner. """ - def __init__( - self, runner: "Runner", command_obj: CommandManager, run_obj: Run - ) -> None: + def __init__(self, runner: "Runner", command_obj: CommandManager) -> None: """ Create a new ExecutingRun -- this should not be done by the user directly but - instead user Runner.run() + instead use Runner.run() Parameters ---------- @@ -57,9 +63,8 @@ def __init__( """ self.runner = runner self.command_obj = command_obj - self.run = run_obj - def __enter__(self) -> "ExecutingRun": + def __enter__(self) -> "ExecutingProcess": return self def __exit__(self, exc_type, exc_value, traceback): @@ -67,7 +72,7 @@ def __exit__(self, exc_type, exc_value, traceback): async def wait( self, timeout: Optional[float] = None, stream: Optional[str] = None - ) -> "ExecutingRun": + ) -> "ExecutingProcess": """ Wait for this run to finish, optionally with a timeout and optionally streaming its output. @@ -86,7 +91,7 @@ async def wait( Returns ------- - ExecutingRun + ExecutingProcess This object, allowing you to chain calls. """ await self.command_obj.wait(timeout, stream) @@ -193,6 +198,76 @@ async def stream_log( yield position, line +class ExecutingTask(ExecutingProcess): + """ + This class contains a reference to a `metaflow.Task` object representing + the currently executing or finished task, as well as metadata related + to the process. + `ExecutingTask` is returned by methods in `Runner` and `NBRunner`. It is not + meant to be instantiated directly. + This class works as a context manager, allowing you to use a pattern like + ```python + with Runner(...).spin() as running: + ... + ``` + Note that you should use either this object as the context manager or + `Runner`, not both in a nested manner. + """ + + def __init__( + self, runner: "Runner", command_obj: CommandManager, task_obj: Task + ) -> None: + """ + Create a new ExecutingTask -- this should not be done by the user directly but + instead use Runner.spin() + Parameters + ---------- + runner : Runner + Parent runner for this task. + command_obj : CommandManager + CommandManager containing the subprocess executing this task. + task_obj : Task + Task object corresponding to this task. + """ + super().__init__(runner, command_obj) + self.task = task_obj + + +class ExecutingRun(ExecutingProcess): + """ + This class contains a reference to a `metaflow.Run` object representing + the currently executing or finished run, as well as metadata related + to the process. + `ExecutingRun` is returned by methods in `Runner` and `NBRunner`. It is not + meant to be instantiated directly. + This class works as a context manager, allowing you to use a pattern like + ```python + with Runner(...).run() as running: + ... + ``` + Note that you should use either this object as the context manager or + `Runner`, not both in a nested manner. + """ + + def __init__( + self, runner: "Runner", command_obj: CommandManager, run_obj: Run + ) -> None: + """ + Create a new ExecutingRun -- this should not be done by the user directly but + instead use Runner.run() + Parameters + ---------- + runner : Runner + Parent runner for this run. 
+ command_obj : CommandManager + CommandManager containing the subprocess executing this run. + run_obj : Run + Run object corresponding to this run. + """ + super().__init__(runner, command_obj) + self.run = run_obj + + class RunnerMeta(type): def __new__(mcs, name, bases, dct): cls = super().__new__(mcs, name, bases, dct) @@ -275,7 +350,7 @@ def __init__( env: Optional[Dict[str, str]] = None, cwd: Optional[str] = None, file_read_timeout: int = 3600, - **kwargs + **kwargs, ): # these imports are required here and not at the top # since they interfere with the user defined Parameters @@ -397,6 +472,78 @@ def run(self, **kwargs) -> ExecutingRun: return self.__get_executing_run(attribute_file_fd, command_obj) + def __get_executing_task(self, attribute_file_fd, command_obj): + content = handle_timeout(attribute_file_fd, command_obj, self.file_read_timeout) + + command_obj.sync_wait() + + content = json.loads(content) + pathspec = f"{content.get('flow_name')}/{content.get('run_id')}/{content.get('step_name')}/{content.get('task_id')}" + + # Set the correct metadata from the runner_attribute file corresponding to this run. + metadata_for_flow = content.get("metadata") + + task_object = Task( + pathspec, _namespace_check=False, _current_metadata=metadata_for_flow + ) + return ExecutingTask(self, command_obj, task_object) + + async def __async_get_executing_task(self, attribute_file_fd, command_obj): + content = await async_handle_timeout( + attribute_file_fd, command_obj, self.file_read_timeout + ) + content = json.loads(content) + pathspec = f"{content.get('flow_name')}/{content.get('run_id')}/{content.get('step_name')}/{content.get('task_id')}" + + # Set the correct metadata from the runner_attribute file corresponding to this run. + metadata_for_flow = content.get("metadata") + + task_object = Task( + pathspec, _namespace_check=False, _current_metadata=metadata_for_flow + ) + return ExecutingTask(self, command_obj, task_object) + + def spin(self, pathspec, **kwargs) -> ExecutingTask: + """ + Blocking spin execution of the run. + This method will wait until the spun run has completed execution. + Parameters + ---------- + pathspec : str + The pathspec of the step/task to spin. + **kwargs : Any + Additional arguments that you would pass to `python ./myflow.py` after + the `spin` command. + Returns + ------- + ExecutingTask + ExecutingTask containing the results of the spun task. + """ + with temporary_fifo() as (attribute_file_path, attribute_file_fd): + if CLICK_API_PROCESS_CONFIG: + with with_dir(self.cwd): + command = self.api(**self.top_level_kwargs).spin( + pathspec=pathspec, + runner_attribute_file=attribute_file_path, + **kwargs, + ) + else: + command = self.api(**self.top_level_kwargs).spin( + pathspec=pathspec, + runner_attribute_file=attribute_file_path, + **kwargs, + ) + + pid = self.spm.run_command( + [sys.executable, *command], + env=self.env_vars, + cwd=self.cwd, + show_output=self.show_output, + ) + command_obj = self.spm.get(pid) + + return self.__get_executing_task(attribute_file_fd, command_obj) + def resume(self, **kwargs) -> ExecutingRun: """ Blocking resume execution of the run. @@ -510,6 +657,50 @@ async def async_resume(self, **kwargs) -> ExecutingRun: return await self.__async_get_executing_run(attribute_file_fd, command_obj) + async def async_spin(self, pathspec, **kwargs) -> ExecutingTask: + """ + Non-blocking spin execution of the run. + This method will return as soon as the spun task has launched. + + Note that this method is asynchronous and needs to be `await`ed. 
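As a usage sketch of the new `spin` and `async_spin` entry points introduced here (the flow file, pathspec, and printed attributes are placeholders, not a prescribed workflow):

```python
import asyncio

from metaflow import Runner

# Blocking: wait for the spun task to finish, then inspect it with the client API.
with Runner("hello_spin_flow.py").spin("HelloSpinFlow/1234/start/5678") as running:
    print(running.task.pathspec)   # metaflow.Task for the spun step
    # running.task.data / .artifacts expose what the step produced


# Non-blocking: launch the spin, then await completion while streaming stderr.
async def spin_async():
    executing = await Runner("hello_spin_flow.py").async_spin("HelloSpinFlow/1234/start")
    await executing.wait(stream="stderr")
    return executing.task


task = asyncio.run(spin_async())
print(task.pathspec)
```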
+ + Parameters + ---------- + pathspec : str + The pathspec of the step/task to spin. + **kwargs : Any + Additional arguments that you would pass to `python ./myflow.py` after + the `spin` command. + + Returns + ------- + ExecutingTask + ExecutingTask representing the spun task that was started. + """ + with temporary_fifo() as (attribute_file_path, attribute_file_fd): + if CLICK_API_PROCESS_CONFIG: + with with_dir(self.cwd): + command = self.api(**self.top_level_kwargs).spin( + pathspec=pathspec, + runner_attribute_file=attribute_file_path, + **kwargs, + ) + else: + command = self.api(**self.top_level_kwargs).spin( + pathspec=pathspec, + runner_attribute_file=attribute_file_path, + **kwargs, + ) + + pid = await self.spm.async_run_command( + [sys.executable, *command], + env=self.env_vars, + cwd=self.cwd, + ) + command_obj = self.spm.get(pid) + + return await self.__async_get_executing_task(attribute_file_fd, command_obj) + def __exit__(self, exc_type, exc_value, traceback): self.spm.cleanup() diff --git a/metaflow/runtime.py b/metaflow/runtime.py index 69c7ff4256d..e6388da3442 100644 --- a/metaflow/runtime.py +++ b/metaflow/runtime.py @@ -26,20 +26,28 @@ from contextlib import contextmanager from . import get_namespace +from .client.filecache import FileCache, FileBlobCache, TaskMetadataCache from .metadata_provider import MetaDatum -from .metaflow_config import FEAT_ALWAYS_UPLOAD_CODE_PACKAGE, MAX_ATTEMPTS, UI_URL +from .metaflow_config import ( + FEAT_ALWAYS_UPLOAD_CODE_PACKAGE, + MAX_ATTEMPTS, + UI_URL, + SPIN_ALLOWED_DECORATORS, +) +from .metaflow_profile import from_start +from .plugins import DATASTORES from .exception import ( MetaflowException, MetaflowInternalError, METAFLOW_EXIT_DISALLOW_RETRY, ) from . import procpoll -from .datastore import TaskDataStoreSet +from .datastore import FlowDataStore, TaskDataStoreSet from .debug import debug from .decorators import flow_decorators from .flowspec import _FlowState from .mflog import mflog, RUNTIME_LOG_SOURCE -from .util import to_unicode, compress_list, unicode_type +from .util import to_unicode, compress_list, unicode_type, get_latest_task_pathspec from .clone_util import clone_task_helper from .unbounded_foreach import ( CONTROL_TASK_TAG, @@ -85,6 +93,282 @@ class LoopBehavior(Enum): # TODO option: output dot graph periodically about execution +class SpinRuntime(object): + def __init__( + self, + flow, + graph, + flow_datastore, + metadata, + environment, + package, + logger, + entrypoint, + event_logger, + monitor, + step_func, + step_name, + spin_pathspec, + skip_decorators=False, + artifacts_module=None, + persist=True, + max_log_size=MAX_LOG_SIZE, + ): + from metaflow import Task + + self._flow = flow + self._graph = graph + self._flow_datastore = flow_datastore + self._metadata = metadata + self._environment = environment + self._package = package + self._logger = logger + self._entrypoint = entrypoint + self._event_logger = event_logger + self._monitor = monitor + + self._step_func = step_func + + # Determine if we have a complete pathspec or need to get the task + if spin_pathspec: + parts = spin_pathspec.split("/") + if len(parts) == 4: + # Complete pathspec: flow/run/step/task_id + try: + task = Task(spin_pathspec, _namespace_check=False) + except Exception: + raise MetaflowException( + f"Invalid pathspec: {spin_pathspec} for step: {step_name}" + ) + elif len(parts) == 3: + # Partial pathspec: flow/run/step - need to get the task + _, run_id, _ = parts + task = get_latest_task_pathspec(flow.name, step_name, run_id=run_id) 
+ logger( + f"To make spin even faster, provide complete pathspec with task_id: {task.pathspec}", + system_msg=True, + ) + else: + raise MetaflowException( + f"Invalid pathspec format: {spin_pathspec}. Expected flow/run/step or flow/run/step/task_id" + ) + else: + # No pathspec provided, get latest task for this step + task = get_latest_task_pathspec(flow.name, step_name) + logger( + f"To make spin even faster, provide complete pathspec {task.pathspec}", + system_msg=True, + ) + from_start("SpinRuntime: after getting task") + + # Get the original FlowDatastore so we can use it to access artifacts from the + # spun task + meta_dict = task.metadata_dict + ds_type = meta_dict["ds-type"] + ds_root = meta_dict["ds-root"] + orig_datastore_impl = [d for d in DATASTORES if d.TYPE == ds_type][0] + orig_datastore_impl.datastore_root = ds_root + spin_pathspec = task.pathspec + orig_flow_datastore = FlowDataStore( + flow.name, + environment=None, + storage_impl=orig_datastore_impl, + ds_root=ds_root, + ) + + self._filecache = FileCache() + orig_flow_datastore.set_metadata_cache( + TaskMetadataCache(self._filecache, ds_type, ds_root, flow.name) + ) + orig_flow_datastore.ca_store.set_blob_cache( + FileBlobCache( + self._filecache, FileCache.flow_ds_id(ds_type, ds_root, flow.name) + ) + ) + + self._orig_flow_datastore = orig_flow_datastore + self._spin_pathspec = spin_pathspec + self._persist = persist + self._spin_task = task + self._input_paths = None + self._split_index = None + self._whitelist_decorators = None + self._config_file_name = None + self._skip_decorators = skip_decorators + self._artifacts_module = artifacts_module + self._max_log_size = max_log_size + self._encoding = sys.stdout.encoding or "UTF-8" + + # If no artifacts module is provided, create a temporary one with parameter values + if not self._artifacts_module and hasattr(flow, "_get_parameters"): + import tempfile + import os + + # Collect parameter values from the flow + param_artifacts = {} + for var, param in flow._get_parameters(): + if hasattr(flow, var): + value = getattr(flow, var) + # Only add if it's an actual value, not the Parameter object + if value is not None and not hasattr(value, "IS_PARAMETER"): + param_artifacts[var] = value + + # If we have parameter values, create a temp module + if param_artifacts: + with tempfile.NamedTemporaryFile( + mode="w", suffix=".py", delete=False + ) as f: + f.write( + "# Auto-generated artifacts module for spin step parameters\n" + ) + f.write("ARTIFACTS = {\n") + for key, value in param_artifacts.items(): + f.write(f" {repr(key)}: {repr(value)},\n") + f.write("}\n") + self._artifacts_module = f.name + self._temp_artifacts_file = f.name # Store for cleanup later + + # Create a new run_id for the spin task + self.run_id = self._metadata.new_run_id() + for deco in self.whitelist_decorators: + deco.runtime_init(flow, graph, package, self.run_id) + from_start("SpinRuntime: after init decorators") + + @property + def split_index(self): + """ + Returns the split index, caching the result after the first access. 
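The original task's datastore location is recovered purely from its registered metadata (`ds-type` and `ds-root`), which is also visible through the client API. A quick illustration with a placeholder pathspec:

```python
from metaflow import Task

# Illustrative pathspec -- substitute a task from one of your previous runs.
task = Task("ComplexDAGFlow/221165/step_c/1350987", _namespace_check=False)
meta = task.metadata_dict
print(meta.get("ds-type"), meta.get("ds-root"))  # e.g. 'local', '/project/.metaflow'
```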
+ """ + if self._split_index is None: + self._split_index = getattr(self._spin_task, "index", None) + + return self._split_index + + @property + def input_paths(self): + def _format_input_paths(task_pathspec, attempt): + _, run_id, step_name, task_id = task_pathspec.split("/") + return f"{run_id}/{step_name}/{task_id}/{attempt}" + + if self._input_paths: + return self._input_paths + + if self._step_func.name == "start": + from metaflow import Step + + flow_name, run_id, _, _ = self._spin_pathspec.split("/") + task = Step( + f"{flow_name}/{run_id}/_parameters", _namespace_check=False + ).task + self._input_paths = [ + _format_input_paths(task.pathspec, task.current_attempt) + ] + else: + parent_tasks = self._spin_task.parent_tasks + self._input_paths = [ + _format_input_paths(t.pathspec, t.current_attempt) for t in parent_tasks + ] + return self._input_paths + + @property + def whitelist_decorators(self): + if self._skip_decorators: + self._whitelist_decorators = [] + return self._whitelist_decorators + if self._whitelist_decorators: + return self._whitelist_decorators + self._whitelist_decorators = [ + deco + for deco in self._step_func.decorators + if any(deco.name.startswith(prefix) for prefix in SPIN_ALLOWED_DECORATORS) + ] + return self._whitelist_decorators + + def _new_task(self, step, input_paths=None, **kwargs): + return Task( + flow_datastore=self._flow_datastore, + flow=self._flow, + step=step, + run_id=self.run_id, + metadata=self._metadata, + environment=self._environment, + entrypoint=self._entrypoint, + event_logger=self._event_logger, + monitor=self._monitor, + input_paths=input_paths, + decos=self.whitelist_decorators, + logger=self._logger, + split_index=self.split_index, + **kwargs, + ) + + def execute(self): + exception = None + with tempfile.NamedTemporaryFile(mode="w", encoding="utf-8") as config_file: + config_value = dump_config_values(self._flow) + if config_value: + json.dump(config_value, config_file) + config_file.flush() + self._config_file_name = config_file.name + else: + self._config_file_name = None + from_start("SpinRuntime: config values processed") + self.task = self._new_task(self._step_func.name, self.input_paths) + try: + self._launch_and_monitor_task() + except Exception as ex: + self._logger("Task failed.", system_msg=True, bad=True) + exception = ex + raise + finally: + for deco in self.whitelist_decorators: + deco.runtime_finished(exception) + # Clean up temporary artifacts file if we created one + if hasattr(self, "_temp_artifacts_file"): + import os + + try: + os.unlink(self._temp_artifacts_file) + except: + pass + + def _launch_and_monitor_task(self): + worker = Worker( + self.task, + self._max_log_size, + self._config_file_name, + orig_flow_datastore=self._orig_flow_datastore, + spin_pathspec=self._spin_pathspec, + artifacts_module=self._artifacts_module, + persist=self._persist, + skip_decorators=self._skip_decorators, + ) + from_start("SpinRuntime: created worker") + + poll = procpoll.make_poll() + fds = worker.fds() + for fd in fds: + poll.add(fd) + + active_fds = set(fds) + + while active_fds: + events = poll.poll(POLL_TIMEOUT) + for event in events: + if event.can_read: + worker.read_logline(event.fd) + if event.is_terminated: + poll.remove(event.fd) + active_fds.remove(event.fd) + from_start("SpinRuntime: read loglines") + returncode = worker.terminate() + from_start("SpinRuntime: worker terminated") + if returncode != 0: + raise TaskFailed(self.task, f"Task failed with return code {returncode}") + else: + self._logger("Task finished 
successfully.", system_msg=True) + + class NativeRuntime(object): def __init__( self, @@ -1757,8 +2041,27 @@ class CLIArgs(object): for step execution in StepDecorator.runtime_step_cli(). """ - def __init__(self, task): + def __init__( + self, + task, + orig_flow_datastore=None, + spin_pathspec=None, + artifacts_module=None, + persist=True, + skip_decorators=False, + ): self.task = task + if orig_flow_datastore is not None: + self.orig_flow_datastore = "%s@%s" % ( + orig_flow_datastore.TYPE, + orig_flow_datastore.datastore_root, + ) + else: + self.orig_flow_datastore = None + self.spin_pathspec = spin_pathspec + self.artifacts_module = artifacts_module + self.persist = persist + self.skip_decorators = skip_decorators self.entrypoint = list(task.entrypoint) step_obj = getattr(self.task.flow, self.task.step) self.top_level_options = { @@ -1796,21 +2099,49 @@ def __init__(self, task): (k, ConfigInput.make_key_name(k)) for k in configs ] + if spin_pathspec: + self.spin_args() + else: + self.default_args() + + def default_args(self): self.commands = ["step"] self.command_args = [self.task.step] self.command_options = { - "run-id": task.run_id, - "task-id": task.task_id, - "input-paths": compress_list(task.input_paths), - "split-index": task.split_index, - "retry-count": task.retries, - "max-user-code-retries": task.user_code_retries, - "tag": task.tags, + "run-id": self.task.run_id, + "task-id": self.task.task_id, + "input-paths": compress_list(self.task.input_paths), + "split-index": self.task.split_index, + "retry-count": self.task.retries, + "max-user-code-retries": self.task.user_code_retries, + "tag": self.task.tags, "namespace": get_namespace() or "", - "ubf-context": task.ubf_context, + "ubf-context": self.task.ubf_context, } self.env = {} + def spin_args(self): + self.commands = ["spin-step"] + self.command_args = [self.task.step] + + self.command_options = { + "run-id": self.task.run_id, + "task-id": self.task.task_id, + "input-paths": compress_list(self.task.input_paths), + "split-index": self.task.split_index, + "retry-count": self.task.retries, + "max-user-code-retries": self.task.user_code_retries, + "namespace": get_namespace() or "", + "orig-flow-datastore": self.orig_flow_datastore, + "artifacts-module": self.artifacts_module, + "skip-decorators": self.skip_decorators, + } + if self.persist: + self.command_options["persist"] = True + else: + self.command_options["no-persist"] = True + self.env = {} + def get_args(self): # TODO: Make one with dict_to_cli_options; see cli_args.py for more detail def _options(mapping): @@ -1849,9 +2180,24 @@ def __str__(self): class Worker(object): - def __init__(self, task, max_logs_size, config_file_name): + def __init__( + self, + task, + max_logs_size, + config_file_name, + orig_flow_datastore=None, + spin_pathspec=None, + artifacts_module=None, + persist=True, + skip_decorators=False, + ): self.task = task self._config_file_name = config_file_name + self._orig_flow_datastore = orig_flow_datastore + self._spin_pathspec = spin_pathspec + self._artifacts_module = artifacts_module + self._skip_decorators = skip_decorators + self._persist = persist self._proc = self._launch() if task.retries > task.user_code_retries: @@ -1883,7 +2229,14 @@ def __init__(self, task, max_logs_size, config_file_name): # not it is properly shut down) def _launch(self): - args = CLIArgs(self.task) + args = CLIArgs( + self.task, + orig_flow_datastore=self._orig_flow_datastore, + spin_pathspec=self._spin_pathspec, + artifacts_module=self._artifacts_module, + 
persist=self._persist, + skip_decorators=self._skip_decorators, + ) env = dict(os.environ) if self.task.clone_run_id: @@ -1916,6 +2269,7 @@ def _launch(self): # by read_logline() below that relies on readline() not blocking # print('running', args) cmdline = args.get_args() + from_start(f"Command line: {' '.join(cmdline)}") debug.subcommand_exec(cmdline) return subprocess.Popen( cmdline, @@ -2038,13 +2392,14 @@ def terminate(self): else: self.emit_log(b"Task failed.", self._stderr, system_msg=True) else: - num = self.task.results["_foreach_num_splits"] - if num: - self.task.log( - "Foreach yields %d child steps." % num, - system_msg=True, - pid=self._proc.pid, - ) + if not self._spin_pathspec: + num = self.task.results["_foreach_num_splits"] + if num: + self.task.log( + "Foreach yields %d child steps." % num, + system_msg=True, + pid=self._proc.pid, + ) self.task.log( "Task finished successfully.", system_msg=True, pid=self._proc.pid ) diff --git a/metaflow/task.py b/metaflow/task.py index 548e241722a..c1c21e87710 100644 --- a/metaflow/task.py +++ b/metaflow/task.py @@ -6,14 +6,15 @@ import time import traceback - from types import MethodType, FunctionType from metaflow.sidecar import Message, MessageTypes from metaflow.datastore.exceptions import DataException +from metaflow.plugins import METADATA_PROVIDERS from .metaflow_config import MAX_ATTEMPTS from .metadata_provider import MetaDatum +from .metaflow_profile import from_start from .mflog import TASK_LOG_SOURCE from .datastore import Inputs, TaskDataStoreSet from .exception import ( @@ -49,6 +50,8 @@ def __init__( event_logger, monitor, ubf_context, + orig_flow_datastore=None, + spin_artifacts=None, ): self.flow = flow self.flow_datastore = flow_datastore @@ -58,6 +61,8 @@ def __init__( self.event_logger = event_logger self.monitor = monitor self.ubf_context = ubf_context + self.orig_flow_datastore = orig_flow_datastore + self.spin_artifacts = spin_artifacts def _exec_step_function(self, step_function, orig_step_func, input_obj=None): wrappers_stack = [] @@ -233,7 +238,6 @@ def property_setter( lambda _, parameter_ds=parameter_ds: parameter_ds["_graph_info"], ) all_vars.append("_graph_info") - if passdown: self.flow._datastore.passdown_partial(parameter_ds, all_vars) return param_only_vars @@ -261,6 +265,9 @@ def _init_data(self, run_id, join_type, input_paths): run_id, pathspecs=input_paths, prefetch_data_artifacts=prefetch_data_artifacts, + join_type=join_type, + orig_flow_datastore=self.orig_flow_datastore, + spin_artifacts=self.spin_artifacts, ) ds_list = [ds for ds in datastore_set] if len(ds_list) != len(input_paths): @@ -272,10 +279,27 @@ def _init_data(self, run_id, join_type, input_paths): # initialize directly in the single input case. 
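For reference, the `spin-step` invocation assembled above is just the option dict rendered as argv-style flags. A simplified stand-in for that rendering (quoting and config handling omitted; option values below are illustrative):

```python
def options_to_argv(mapping):
    # dict -> ['--key', 'value', ...]; True becomes a bare flag, None/False are dropped.
    argv = []
    for key, value in mapping.items():
        if value is None or value is False:
            continue
        option = "--%s" % key
        if value is True:
            argv.append(option)
        elif isinstance(value, (list, tuple)):
            for item in value:
                argv.extend([option, str(item)])
        else:
            argv.extend([option, str(value)])
    return argv


print(options_to_argv({
    "run-id": "17",
    "task-id": "3",
    "input-paths": "16/start/2/0",
    "persist": True,
    "skip-decorators": False,
    "orig-flow-datastore": "local@/project/.metaflow",
}))
# ['--run-id', '17', '--task-id', '3', '--input-paths', '16/start/2/0',
#  '--persist', '--orig-flow-datastore', 'local@/project/.metaflow']
```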
ds_list = [] for input_path in input_paths: - run_id, step_name, task_id = input_path.split("/") + parts = input_path.split("/") + if len(parts) == 3: + run_id, step_name, task_id = parts + attempt = None + else: + run_id, step_name, task_id, attempt = parts + attempt = int(attempt) + ds_list.append( - self.flow_datastore.get_task_datastore(run_id, step_name, task_id) + self.flow_datastore.get_task_datastore( + run_id, + step_name, + task_id, + attempt=attempt, + join_type=join_type, + orig_flow_datastore=self.orig_flow_datastore, + spin_artifacts=self.spin_artifacts, + ) ) + from_start("MetaflowTask: got datastore for input path %s" % input_path) + if not ds_list: # this guards against errors in input paths raise MetaflowDataMissing( @@ -546,6 +570,8 @@ def run_step( split_index, retry_count, max_user_code_retries, + whitelist_decorators=None, + persist=True, ): if run_id and task_id: self.metadata.register_run_id(run_id) @@ -604,7 +630,12 @@ def run_step( step_func = getattr(self.flow, step_name) decorators = step_func.decorators - + if self.orig_flow_datastore: + # We filter only the whitelisted decorators in case of spin step. + decorators = [] if not whitelist_decorators else [ + deco for deco in decorators if deco.name in whitelist_decorators + ] + from_start("MetaflowTask: decorators initialized") node = self.flow._graph[step_name] join_type = None if node.type == "join": @@ -612,17 +643,20 @@ def run_step( # 1. initialize output datastore output = self.flow_datastore.get_task_datastore( - run_id, step_name, task_id, attempt=retry_count, mode="w" + run_id, step_name, task_id, attempt=retry_count, mode="w", persist=persist ) output.init_task() + from_start("MetaflowTask: output datastore initialized") if input_paths: # 2. initialize input datastores inputs = self._init_data(run_id, join_type, input_paths) + from_start("MetaflowTask: input datastores initialized") # 3. initialize foreach state self._init_foreach(step_name, join_type, inputs, split_index) + from_start("MetaflowTask: foreach state initialized") # 4. initialize the iteration state is_recursive_step = ( @@ -681,7 +715,7 @@ def run_step( ), ] ) - + from_start("MetaflowTask: finished input processing") self.metadata.register_metadata( run_id, step_name, @@ -735,8 +769,11 @@ def run_step( "project_flow_name": current.get("project_flow_name"), "trace_id": trace_id or None, } + + from_start("MetaflowTask: task metadata initialized") start = time.time() self.metadata.start_task_heartbeat(self.flow.name, run_id, step_name, task_id) + from_start("MetaflowTask: heartbeat started") with self.monitor.measure("metaflow.task.duration"): try: with self.monitor.count("metaflow.task.start"): @@ -756,7 +793,6 @@ def run_step( # should either be set prior to running the user code or listed in # FlowSpec._EPHEMERAL to allow for proper merging/importing of # user artifacts in the user's step code. 
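Input paths handed to a spin step carry an explicit attempt number (`run/step/task/attempt`) so the datastore can be opened without probing for the latest attempt, while regular steps keep the three-part form. The parsing above amounts to this small sketch:

```python
def split_input_path(path):
    parts = path.split("/")
    if len(parts) == 3:
        run_id, step_name, task_id = parts
        attempt = None                 # latest attempt resolved by the datastore
    else:
        run_id, step_name, task_id, attempt = parts
        attempt = int(attempt)         # spin fast path: attempt pinned up front
    return run_id, step_name, task_id, attempt


print(split_input_path("17/start/42"))    # ('17', 'start', '42', None)
print(split_input_path("17/start/42/0"))  # ('17', 'start', '42', 0)
```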
- if join_type: # Join step: @@ -815,11 +851,19 @@ def run_step( "graph_info": self.flow._graph_info, } ) + from_start("MetaflowTask: before pre-step decorators") for deco in decorators: + if deco.name == "card" and self.orig_flow_datastore: + # if spin step and card decorator, pass spin metadata + metadata = [m for m in METADATA_PROVIDERS if m.TYPE == "spin"][ + 0 + ](self.environment, self.flow, self.event_logger, self.monitor) + else: + metadata = self.metadata deco.task_pre_step( step_name, output, - self.metadata, + metadata, run_id, task_id, self.flow, @@ -845,12 +889,12 @@ def run_step( max_user_code_retries, self.ubf_context, ) - + from_start("MetaflowTask: finished decorator processing") if join_type: self._exec_step_function(step_func, orig_step_func, input_obj) else: self._exec_step_function(step_func, orig_step_func) - + from_start("MetaflowTask: step function executed") for deco in decorators: deco.task_post_step( step_name, @@ -893,6 +937,7 @@ def run_step( raise finally: + from_start("MetaflowTask: decorators finalized") if self.ubf_context == UBF_CONTROL: self._finalize_control_task() @@ -932,7 +977,7 @@ def run_step( ) output.save_metadata({"task_end": {}}) - + from_start("MetaflowTask: output persisted") # this writes a success marker indicating that the # "transaction" is done output.done() @@ -961,3 +1006,4 @@ def run_step( name="duration", payload={**task_payload, "msg": str(duration)}, ) + from_start("MetaflowTask: task run completed") diff --git a/metaflow/util.py b/metaflow/util.py index c0383766b5d..82d282f1707 100644 --- a/metaflow/util.py +++ b/metaflow/util.py @@ -7,6 +7,7 @@ from functools import wraps from io import BytesIO from itertools import takewhile +from typing import Dict, Any, Tuple import re @@ -178,6 +179,117 @@ def resolve_identity(): return "%s:%s" % (identity_type, identity_value) +def parse_spin_pathspec(pathspec: str, flow_name: str) -> Tuple: + """ + Parse various pathspec formats for the spin command. + + Parameters + ---------- + pathspec : str + The pathspec string in one of the following formats: + - step_name (e.g., 'start') + - run_id/step_name (e.g., '221165/start') + - run_id/step_name/task_id (e.g., '221165/start/1350987') + - flow_name/run_id/step_name (e.g., 'ScalableFlow/221165/start') + - flow_name/run_id/step_name/task_id (e.g., 'ScalableFlow/221165/start/1350987') + flow_name : str + The name of the current flow. + + Returns + ------- + Tuple + A tuple of (step_name, full_pathspec_or_none) + + Raises + ------ + CommandException + If the pathspec format is invalid or flow name doesn't match. 
+ """ + from .exception import CommandException + + parts = pathspec.split("/") + + if len(parts) == 1: + # Just step name: 'start' + step_name = parts[0] + parsed_pathspec = None + elif len(parts) == 2: + # run_id/step_name: '221165/start' + run_id, step_name = parts + parsed_pathspec = f"{flow_name}/{run_id}/{step_name}" + elif len(parts) == 3: + # Could be run_id/step_name/task_id or flow_name/run_id/step_name + if parts[0] == flow_name: + # flow_name/run_id/step_name + _, run_id, step_name = parts + parsed_pathspec = f"{flow_name}/{run_id}/{step_name}" + else: + # run_id/step_name/task_id + run_id, step_name, task_id = parts + parsed_pathspec = f"{flow_name}/{run_id}/{step_name}/{task_id}" + elif len(parts) == 4: + # flow_name/run_id/step_name/task_id + parsed_flow_name, run_id, step_name, task_id = parts + if parsed_flow_name != flow_name: + raise CommandException( + f"Flow name '{parsed_flow_name}' in pathspec does not match current flow '{flow_name}'." + ) + parsed_pathspec = pathspec + else: + raise CommandException( + f"Invalid pathspec format: '{pathspec}'. \n" + "Expected formats:\n" + " - step_name (e.g., 'start')\n" + " - run_id/step_name (e.g., '221165/start')\n" + " - run_id/step_name/task_id (e.g., '221165/start/1350987')\n" + " - flow_name/run_id/step_name (e.g., 'ScalableFlow/221165/start')\n" + " - flow_name/run_id/step_name/task_id (e.g., 'ScalableFlow/221165/start/1350987')" + ) + + return step_name, parsed_pathspec + + +def get_latest_task_pathspec(flow_name: str, step_name: str, run_id: str = None): + """ + Returns a task pathspec from the latest run (or specified run) of the flow for the queried step. + If the queried step has several tasks, the task pathspec of the first task is returned. + + Parameters + ---------- + flow_name : str + The name of the flow. + step_name : str + The name of the step. + run_id : str, optional + The run ID to use. If None, uses the latest run. + + Returns + ------- + Task + A Metaflow Task instance containing the latest task for the queried step. + + Raises + ------ + MetaflowNotFound + If no task or run is found for the queried step. + """ + from metaflow import Flow, Step + from metaflow.exception import MetaflowNotFound + + if not run_id: + flow = Flow(flow_name, _namespace_check=False) + run = flow.latest_run + if run is None: + raise MetaflowNotFound(f"No run found for flow {flow_name}") + run_id = run.id + + try: + task = Step(f"{flow_name}/{run_id}/{step_name}", _namespace_check=False).task + return task + except: + raise MetaflowNotFound(f"No task found for step {step_name} in run {run_id}") + + def get_latest_run_id(echo, flow_name): from metaflow.plugins.datastores.local_storage import LocalStorage @@ -471,3 +583,40 @@ def to_pod(value): from metaflow._vendor.packaging.version import parse as version_parse + + +def read_artifacts_module(file_path: str) -> Dict[str, Any]: + """ + Read a Python module from the given file path and return its ARTIFACTS variable. + + Parameters + ---------- + file_path : str + The path to the Python file containing the ARTIFACTS variable. + + Returns + ------- + Dict[str, Any] + A dictionary containing the ARTIFACTS variable from the module. + + Raises + ------- + MetaflowInternalError + If the file cannot be read or does not contain the ARTIFACTS variable. 
+ """ + import importlib.util + import os + + try: + module_name = os.path.splitext(os.path.basename(file_path))[0] + spec = importlib.util.spec_from_file_location(module_name, file_path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + variables = vars(module) + if "ARTIFACTS" not in variables: + raise MetaflowInternalError( + f"Module {file_path} does not contain ARTIFACTS variable" + ) + return variables.get("ARTIFACTS") + except Exception as e: + raise MetaflowInternalError(f"Error reading file {file_path}") from e diff --git a/test/core/tests/runtime_dag.py b/test/core/tests/runtime_dag.py index 8bc54985d20..cfad22ee17e 100644 --- a/test/core/tests/runtime_dag.py +++ b/test/core/tests/runtime_dag.py @@ -71,7 +71,9 @@ def _equals_task(task1, task2): if name not in [ "parent_tasks", + "parent_task_pathspecs", "child_tasks", + "child_task_pathspecs", "metadata", "data", "artifacts", diff --git a/test/unit/spin/artifacts/complex_dag_step_a.py b/test/unit/spin/artifacts/complex_dag_step_a.py new file mode 100644 index 00000000000..b7e81bf1b6f --- /dev/null +++ b/test/unit/spin/artifacts/complex_dag_step_a.py @@ -0,0 +1 @@ +ARTIFACTS = {"my_output": [10, 11, 12]} diff --git a/test/unit/spin/artifacts/complex_dag_step_d.py b/test/unit/spin/artifacts/complex_dag_step_d.py new file mode 100644 index 00000000000..20bb0376e8d --- /dev/null +++ b/test/unit/spin/artifacts/complex_dag_step_d.py @@ -0,0 +1,3 @@ +# This file is kept for backwards compatibility but should not be used directly +# The artifacts are now generated dynamically via pytest fixtures +ARTIFACTS = {} diff --git a/test/unit/spin/conftest.py b/test/unit/spin/conftest.py new file mode 100644 index 00000000000..a4f90937543 --- /dev/null +++ b/test/unit/spin/conftest.py @@ -0,0 +1,90 @@ +import pytest +from metaflow import Runner, Flow +import os + +# Get the directory containing the flows +FLOWS_DIR = os.path.join(os.path.dirname(__file__), "flows") + + +def pytest_addoption(parser): + """Add custom command line options.""" + parser.addoption( + "--use-latest", + action="store_true", + default=False, + help="Use latest run of each flow instead of running new ones", + ) + + +def create_flow_fixture(flow_name, flow_file, run_params=None, runner_params=None): + """ + Factory function to create flow fixtures with common logic. 
+ + Parameters + ----------- + flow_name: str + Name of the flow class + flow_file: str + Python file containing the flow + run_params: dict, optional + Parameters to pass to .run() method + runner_params: dict, optional + Parameters to pass to Runner() + """ + + def flow_fixture(request): + if request.config.getoption("--use-latest"): + flow = Flow(flow_name, _namespace_check=False) + return flow.latest_run + else: + flow_path = os.path.join(FLOWS_DIR, flow_file) + runner_params_dict = runner_params or {} + runner_params_dict["cwd"] = FLOWS_DIR # Always set cwd to FLOWS_DIR + run_params_dict = run_params or {} + + with Runner(flow_path, **runner_params_dict).run( + **run_params_dict + ) as running: + return running.run + + return flow_fixture + + +# Create all the flow fixtures using the factory +complex_dag_run = pytest.fixture(scope="session")( + create_flow_fixture( + "ComplexDAGFlow", "complex_dag_flow.py", runner_params={"environment": "conda"} + ) +) + +merge_artifacts_run = pytest.fixture(scope="session")( + create_flow_fixture("MergeArtifactsFlow", "merge_artifacts_flow.py") +) + +simple_parameter_run = pytest.fixture(scope="session")( + create_flow_fixture( + "SimpleParameterFlow", "simple_parameter_flow.py", run_params={"alpha": 0.05} + ) +) + +simple_card_run = pytest.fixture(scope="session")( + create_flow_fixture( + "SimpleCardFlow", "simple_card_flow.py", + ) +) + +simple_config_run = pytest.fixture(scope="session")( + create_flow_fixture( + "TimeoutConfigFlow", + "simple_config_flow.py", + ) +) + + +@pytest.fixture +def complex_dag_step_d_artifacts(complex_dag_run): + """Generate dynamic artifacts for complex_dag step_d tests.""" + task = complex_dag_run["step_d"].task + task_pathspec = next(task.parent_task_pathspecs) + _, inp_path = task_pathspec.split("/", 1) + return {inp_path: {"my_output": [-1]}} diff --git a/test/unit/spin/flows/complex_dag_flow.py b/test/unit/spin/flows/complex_dag_flow.py new file mode 100644 index 00000000000..04b185fe40f --- /dev/null +++ b/test/unit/spin/flows/complex_dag_flow.py @@ -0,0 +1,116 @@ +from metaflow import FlowSpec, step, project, conda, Task, pypi + + +class ComplexDAGFlow(FlowSpec): + @step + def start(self): + self.split_start = [1, 2, 3] + self.my_output = [] + print("My output is: ", self.my_output) + self.next(self.step_a, foreach="split_start") + + @step + def step_a(self): + self.split_a = [4, 5] + self.my_output = self.my_output + [self.input] + print("My output is: ", self.my_output) + self.next(self.step_b, foreach="split_a") + + @step + def step_b(self): + self.split_b = [6, 7, 8] + self.my_output = self.my_output + [self.input] + print("My output is: ", self.my_output) + self.next(self.step_c, foreach="split_b") + + @conda(libraries={"numpy": "2.1.1"}) + @step + def step_c(self): + import numpy as np + + self.np_version = np.__version__ + print(f"numpy version: {self.np_version}") + self.my_output = self.my_output + [self.input] + [9, 10] + print("My output is: ", self.my_output) + self.next(self.step_d) + + @step + def step_d(self, inputs): + self.my_output = sorted([inp.my_output for inp in inputs])[0] + print("My output is: ", self.my_output) + self.next(self.step_e) + + @step + def step_e(self): + print(f"I am step E. 
Input is: {self.input}") + self.split_e = [9, 10] + print("My output is: ", self.my_output) + self.next(self.step_f, foreach="split_e") + + @step + def step_f(self): + self.my_output = self.my_output + [self.input] + print("My output is: ", self.my_output) + self.next(self.step_g) + + @step + def step_g(self): + print("My output is: ", self.my_output) + self.next(self.step_h) + + @step + def step_h(self, inputs): + self.my_output = sorted([inp.my_output for inp in inputs])[0] + print("My output is: ", self.my_output) + self.next(self.step_i) + + @step + def step_i(self, inputs): + self.my_output = sorted([inp.my_output for inp in inputs])[0] + print("My output is: ", self.my_output) + self.next(self.step_j) + + @step + def step_j(self): + print("My output is: ", self.my_output) + self.next(self.step_k, self.step_l) + + @step + def step_k(self): + self.my_output = self.my_output + [11] + print("My output is: ", self.my_output) + self.next(self.step_m) + + @step + def step_l(self): + print(f"I am step L. Input is: {self.input}") + self.my_output = self.my_output + [12] + print("My output is: ", self.my_output) + self.next(self.step_m) + + @conda(libraries={"scikit-learn": "1.3.0"}) + @step + def step_m(self, inputs): + import sklearn + + self.sklearn_version = sklearn.__version__ + self.my_output = sorted([inp.my_output for inp in inputs])[0] + print("Sklearn version: ", self.sklearn_version) + print("My output is: ", self.my_output) + self.next(self.step_n) + + @step + def step_n(self, inputs): + self.my_output = sorted([inp.my_output for inp in inputs])[0] + print("My output is: ", self.my_output) + self.next(self.end) + + @step + def end(self): + self.my_output = self.my_output + [13] + print("My output is: ", self.my_output) + print("Flow is complete!") + + +if __name__ == "__main__": + ComplexDAGFlow() diff --git a/test/unit/spin/flows/hello_spin_flow.py b/test/unit/spin/flows/hello_spin_flow.py new file mode 100644 index 00000000000..2df4a6aeee6 --- /dev/null +++ b/test/unit/spin/flows/hello_spin_flow.py @@ -0,0 +1,26 @@ +from metaflow import FlowSpec, step +import random + + +class HelloSpinFlow(FlowSpec): + + @step + def start(self): + chunk_size = 1024 * 1024 # 1 MB + total_size = 1024 * 1024 * 1000 # 1000 MB + + data = bytearray() + for _ in range(total_size // chunk_size): + data.extend(random.randbytes(chunk_size)) + + self.a = data + self.next(self.end) + + @step + def end(self): + print(f"Size of artifact a: {len(self.a)} bytes") + print("HelloSpinFlow completed.") + + +if __name__ == "__main__": + HelloSpinFlow() \ No newline at end of file diff --git a/test/unit/spin/flows/merge_artifacts_flow.py b/test/unit/spin/flows/merge_artifacts_flow.py new file mode 100644 index 00000000000..fe49f8c10be --- /dev/null +++ b/test/unit/spin/flows/merge_artifacts_flow.py @@ -0,0 +1,62 @@ +from metaflow import FlowSpec, step + + +class MergeArtifactsFlow(FlowSpec): + + @step + def start(self): + self.pass_down = "a" + self.next(self.a, self.b) + + @step + def a(self): + self.common = 5 + self.x = 1 + self.y = 3 + self.from_a = 6 + self.next(self.join) + + @step + def b(self): + self.common = 5 + self.x = 2 + self.y = 4 + self.next(self.join) + + @step + def join(self, inputs): + self.x = inputs.a.x + self.merge_artifacts(inputs, exclude=["y"]) + print("x is %s" % self.x) + print("pass_down is %s" % self.pass_down) + print("common is %d" % self.common) + print("from_a is %d" % self.from_a) + self.next(self.c) + + @step + def c(self): + self.next(self.d, self.e) + + @step + def d(self): + 
self.conflicting = 7 + self.next(self.join2) + + @step + def e(self): + self.conflicting = 8 + self.next(self.join2) + + @step + def join2(self, inputs): + self.merge_artifacts(inputs, include=["pass_down", "common"]) + print("Only pass_down and common exist here") + self.next(self.end) + + @step + def end(self): + pass + + +if __name__ == "__main__": + MergeArtifactsFlow() diff --git a/test/unit/spin/flows/myconfig.json b/test/unit/spin/flows/myconfig.json new file mode 100644 index 00000000000..c24b31c1e41 --- /dev/null +++ b/test/unit/spin/flows/myconfig.json @@ -0,0 +1 @@ +{"timeout": 60} \ No newline at end of file diff --git a/test/unit/spin/flows/simple_card_flow.py b/test/unit/spin/flows/simple_card_flow.py new file mode 100644 index 00000000000..83142d08ba8 --- /dev/null +++ b/test/unit/spin/flows/simple_card_flow.py @@ -0,0 +1,46 @@ +from metaflow import FlowSpec, step, card, Parameter, current +from metaflow.cards import Markdown + +import requests, pandas, string + +URL = "https://upload.wikimedia.org/wikipedia/commons/4/45/Blue_Marble_rotating.gif" + + +class SimpleCardFlow(FlowSpec): + number = Parameter("number", default=3) + image_url = Parameter("image_url", default=URL) + + @card(type="blank") + @step + def start(self): + current.card.append(Markdown("# Guess my number")) + if self.number > 5: + current.card.append(Markdown("My number is **smaller** ⬇️")) + elif self.number < 5: + current.card.append(Markdown("My number is **larger** ⬆️")) + else: + current.card.append(Markdown("## Correct! 🎉")) + + self.next(self.a) + + @step + def a(self): + print(f"image: {self.image_url}") + self.image = requests.get( + self.image_url, headers={"user-agent": "metaflow-example"} + ).content + self.dataframe = pandas.DataFrame( + { + "lowercase": list(string.ascii_lowercase), + "uppercase": list(string.ascii_uppercase), + } + ) + self.next(self.end) + + @step + def end(self): + pass + + +if __name__ == "__main__": + SimpleCardFlow() diff --git a/test/unit/spin/flows/simple_config_flow.py b/test/unit/spin/flows/simple_config_flow.py new file mode 100644 index 00000000000..d4e910e1e29 --- /dev/null +++ b/test/unit/spin/flows/simple_config_flow.py @@ -0,0 +1,22 @@ +import time +from metaflow import FlowSpec, step, Config, timeout + + +class TimeoutConfigFlow(FlowSpec): + config = Config("config", default="myconfig.json") + + @timeout(seconds=config.timeout) + @step + def start(self): + print(f"timing out after {self.config.timeout} seconds") + time.sleep(5) + print("success") + self.next(self.end) + + @step + def end(self): + print("full config", self.config) + + +if __name__ == "__main__": + TimeoutConfigFlow() diff --git a/test/unit/spin/flows/simple_parameter_flow.py b/test/unit/spin/flows/simple_parameter_flow.py new file mode 100644 index 00000000000..bf1969326a2 --- /dev/null +++ b/test/unit/spin/flows/simple_parameter_flow.py @@ -0,0 +1,33 @@ +from metaflow import FlowSpec, step, Parameter, current, project + + +@project(name="simple_parameter_flow") +class SimpleParameterFlow(FlowSpec): + alpha = Parameter("alpha", help="Learning rate", default=0.01) + + @step + def start(self): + print("SimpleParameterFlow is starting.") + print(f"Parameter alpha is set to: {self.alpha}") + self.a = 10 + self.b = 20 + self.next(self.end) + + @step + def end(self): + self.a = 50 + self.x = 100 + self.y = 200 + print("Parameter alpha in end step is: ", self.alpha) + print( + f"Pathspec: {current.pathspec}, flow_name: {current.flow_name}, run_id: {current.run_id}" + ) + print(f"step_name: 
{current.step_name}, task_id: {current.task_id}") + print(f"Project name: {current.project_name}, Namespace: {current.namespace}") + del self.a + del self.x + print("SimpleParameterFlow is all done.") + + +if __name__ == "__main__": + SimpleParameterFlow() diff --git a/test/unit/spin/spin_test_helpers.py b/test/unit/spin/spin_test_helpers.py new file mode 100644 index 00000000000..b2c61e37457 --- /dev/null +++ b/test/unit/spin/spin_test_helpers.py @@ -0,0 +1,32 @@ +import os +from metaflow import Runner + +FLOWS_DIR = os.path.join(os.path.dirname(__file__), "flows") +ARTIFACTS_DIR = os.path.join(os.path.dirname(__file__), "artifacts") + + +def assert_artifacts(task, spin_task): + """Assert that artifacts match between original task and spin task.""" + spin_task_artifacts = { + artifact.id: artifact.data for artifact in spin_task.artifacts + } + print(f"Spin task artifacts: {spin_task_artifacts}") + for artifact in task.artifacts: + assert ( + artifact.id in spin_task_artifacts + ), f"Artifact {artifact.id} not found in spin task" + assert ( + artifact.data == spin_task_artifacts[artifact.id] + ), f"Expected {artifact.data} but got {spin_task_artifacts[artifact.id]} for artifact {artifact.id}" + + +def run_step(flow_file, run, step_name, **tl_kwargs): + """Run a step and assert artifacts match.""" + task = run[step_name].task + flow_path = os.path.join(FLOWS_DIR, flow_file) + print(f"FLOWS_DIR: {FLOWS_DIR}") + + with Runner(flow_path, cwd=FLOWS_DIR, **tl_kwargs).spin(task.pathspec) as spin: + print("-" * 50) + print(f"Running test for step: {step_name} with task pathspec: {task.pathspec}") + assert_artifacts(task, spin.task) diff --git a/test/unit/spin/test_spin.py b/test/unit/spin/test_spin.py new file mode 100644 index 00000000000..04e9a852c1e --- /dev/null +++ b/test/unit/spin/test_spin.py @@ -0,0 +1,165 @@ +import pytest +from metaflow import Runner +import os +from spin_test_helpers import assert_artifacts, run_step, FLOWS_DIR, ARTIFACTS_DIR + + +@pytest.mark.parametrize( + "flow_file,fixture_name", + [ + ("merge_artifacts_flow.py", "merge_artifacts_run"), + ("simple_config_flow.py", "simple_config_run"), + ("simple_parameter_flow.py", "simple_parameter_run"), + ("complex_dag_flow.py", "complex_dag_run"), + ], + ids=["merge_artifacts", "simple_config", "simple_parameter", "complex_dag"], +) +def test_simple_flows(flow_file, fixture_name, request): + """Test simple flows that just need artifact validation.""" + run = request.getfixturevalue(fixture_name) + print(f"Running test for {flow_file}: {run}") + for step in run.steps(): + print("-" * 100) + if fixture_name == "complex_dag_run": + run_step(flow_file, run, step.id, environment="conda") + else: + run_step(flow_file, run, step.id) + + +def test_artifacts_module(complex_dag_run): + print(f"Running test for artifacts module in ComplexDAGFlow: {complex_dag_run}") + step_name = "step_a" + task = complex_dag_run[step_name].task + flow_path = os.path.join(FLOWS_DIR, "complex_dag_flow.py") + artifacts_path = os.path.join(ARTIFACTS_DIR, "complex_dag_step_a.py") + + with Runner(flow_path, environment="conda").spin( + task.pathspec, + artifacts_module=artifacts_path, + ) as spin: + print("-" * 50) + print(f"Running test for step: step_a with task pathspec: {task.pathspec}") + spin_task = spin.task + print(f"my_output: {spin_task['my_output']}") + assert spin_task["my_output"].data == [10, 11, 12, 3] + + +def test_artifacts_module_join_step( + complex_dag_run, complex_dag_step_d_artifacts, tmp_path +): + print(f"Running test for artifacts 
module in ComplexDAGFlow: {complex_dag_run}") + step_name = "step_d" + task = complex_dag_run[step_name].task + flow_path = os.path.join(FLOWS_DIR, "complex_dag_flow.py") + + # Create a temporary artifacts file with dynamic data + temp_artifacts_file = tmp_path / "temp_complex_dag_step_d.py" + temp_artifacts_file.write_text(f"ARTIFACTS = {repr(complex_dag_step_d_artifacts)}") + + with Runner(flow_path, environment="conda").spin( + task.pathspec, + artifacts_module=str(temp_artifacts_file), + ) as spin: + print("-" * 50) + print(f"Running test for step: step_d with task pathspec: {task.pathspec}") + spin_task = spin.task + assert spin_task["my_output"].data == [-1] + + +def test_timeout_decorator_enforcement(simple_config_run): + """Test that timeout decorator properly enforces timeout limits.""" + step_name = "start" + task = simple_config_run[step_name].task + flow_path = os.path.join(FLOWS_DIR, "simple_config_flow.py") + + # With decorator enabled (should timeout and raise exception) + with pytest.raises(Exception): + with Runner( + flow_path, cwd=FLOWS_DIR, config_value=[("config", {"timeout": 2})] + ).spin( + task.pathspec, + ): + pass + + +def test_skip_decorators_bypass(simple_config_run): + """Test that skip_decorators successfully bypasses timeout decorator.""" + step_name = "start" + task = simple_config_run[step_name].task + flow_path = os.path.join(FLOWS_DIR, "simple_config_flow.py") + + # With skip_decorators=True (should succeed despite timeout) + with Runner( + flow_path, cwd=FLOWS_DIR, config_value=[("config", {"timeout": 2})] + ).spin( + task.pathspec, + skip_decorators=True, + ) as spin: + print(f"Running test for step: {step_name} with skip_decorators=True") + # Should complete successfully even though sleep(5) > timeout(2) + spin_task = spin.task + assert spin_task.finished + + +def test_hidden_artifacts(simple_parameter_run): + """Test simple flows that just need artifact validation.""" + step_name = "start" + task = simple_parameter_run[step_name].task + flow_path = os.path.join(FLOWS_DIR, "simple_parameter_flow.py") + print(f"Running test for hidden artifacts in {flow_path}: {simple_parameter_run}") + + with Runner(flow_path, cwd=FLOWS_DIR).spin(task.pathspec) as spin: + spin_task = spin.task + assert "_graph_info" in spin_task + assert "_foreach_stack" in spin_task + + +def test_card_flow(simple_card_run): + """Test a simple flow that has @card decorator.""" + step_name = "start" + task = simple_card_run[step_name].task + flow_path = os.path.join(FLOWS_DIR, "simple_card_flow.py") + print(f"Running test for cards in {flow_path}: {simple_card_run}") + + with Runner(flow_path, cwd=FLOWS_DIR).spin(task.pathspec) as spin: + spin_task = spin.task + from metaflow.cards import get_cards + + res = get_cards(spin_task, follow_resumed=False) + print(res) + +def test_inspect_spin_client_access(simple_parameter_run): + """Test accessing spin artifacts using inspect_spin client directly.""" + from metaflow import inspect_spin, Task + import tempfile + + step_name = "start" + task = simple_parameter_run[step_name].task + flow_path = os.path.join(FLOWS_DIR, "simple_parameter_flow.py") + + with tempfile.TemporaryDirectory() as tmpdir: + # Run spin to generate artifacts + with Runner(flow_path, cwd=FLOWS_DIR).spin( + task.pathspec, + ) as spin: + spin_task = spin.task + spin_pathspec = spin_task.pathspec + assert spin_task['a'] is not None + assert spin_task['b'] is not None + + assert spin_pathspec is not None + + # Set metadata provider to spin + inspect_spin(FLOWS_DIR) + 
client_task = Task(spin_pathspec, _namespace_check=False) + + # Verify task is accessible + assert client_task is not None + + # Verify artifacts + assert hasattr(client_task, 'artifacts') + + # Verify artifact data + assert client_task.artifacts.a.data == 10 + assert client_task.artifacts.b.data == 20 + assert client_task.artifacts.alpha.data == 0.05 \ No newline at end of file