# Lazy loading implementation to avoid importing heavy dependencies until needed
import sys
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Type hints for IDE support without runtime imports
    from .observability.MonitoringConfig import MonitoringConfig

from .common.lazy_loading import setup_lazy_module


def _build_hyperpod_config():
    """Assemble the lazy-loading configuration for the ``sagemaker.hyperpod`` package.

    Every exported name maps to an import target of the form
    ``"<module path>:<same name>"``, so the mapping is generated from ordered
    name lists instead of being spelled out twice.

    Returns:
        A dict with two keys: ``'exports'`` (the exported names, in order) and
        ``'lazy_imports'`` (exported name -> ``"module:attribute"`` target).
    """
    utils_module = 'sagemaker.hyperpod.common.utils'

    # Common utilities (lazy loaded)
    utility_names = (
        'get_default_namespace',
        'handle_exception',
        'get_eks_name_from_arn',
        'get_region_from_eks_arn',
        'get_jumpstart_model_instance_types',
        'get_cluster_instance_types',
        'setup_logging',
        'is_eks_orchestrator',
        'update_kube_config',
        'set_eks_context',
        'set_cluster_context',
        'get_cluster_context',
        'list_clusters',
        'get_current_cluster',
        'get_current_region',
        'parse_client_kubernetes_version',
        'is_kubernetes_version_compatible',
        'display_formatted_logs',
        'verify_kubernetes_version_compatibility',
    )
    # Constants
    constant_names = (
        'EKS_ARN_PATTERN',
        'CLIENT_VERSION_PATTERN',
        'KUBE_CONFIG_PATH',
    )

    # Insertion order matters: utilities, then observability, then constants.
    targets = {name: f'{utils_module}:{name}' for name in utility_names}
    # Observability
    targets['MonitoringConfig'] = (
        'sagemaker.hyperpod.observability.MonitoringConfig:MonitoringConfig'
    )
    for name in constant_names:
        targets[name] = f'{utils_module}:{name}'

    return {
        'exports': list(targets),
        'lazy_imports': targets,
    }


HYPERPOD_CONFIG = _build_hyperpod_config()

setup_lazy_module(__name__, HYPERPOD_CONFIG)
"""
Command Registry System for SageMaker HyperPod CLI

This module provides a centralized way to register and discover CLI commands,
eliminating hardcoded command mappings throughout the codebase.
"""

from typing import Dict, List, Optional, Tuple, Callable, Any
import importlib
import logging
from dataclasses import dataclass, field

logger = logging.getLogger(__name__)

# Modules whose import triggers self-registration of their commands.
_COMMAND_MODULES = (
    "sagemaker.hyperpod.cli.commands.cluster",
    "sagemaker.hyperpod.cli.commands.training",
    "sagemaker.hyperpod.cli.commands.inference",
)


@dataclass
class CommandMetadata:
    """Metadata for a CLI command"""
    # Command name as typed on the command line (e.g. "list-cluster").
    name: str
    # Help text shown for the command.
    help_text: str
    # Logical module providing the command ("cluster", "training", "inference").
    module_name: str
    # "python.module:function" location of the implementation.
    import_path: str
    # Group the command belongs to; None means a top-level command.
    parent_group: Optional[str] = None


@dataclass
class CommandGroup:
    """Represents a CLI command group"""
    name: str
    help_text: str
    # Commands registered under this group, in registration order.
    commands: List[CommandMetadata] = field(default_factory=list)


class CommandRegistry:
    """
    Central registry for CLI commands that eliminates hardcoded mappings.

    Commands register themselves with metadata, and the CLI dynamically
    discovers and loads them as needed.

    The same command name may be registered under several groups (for example
    ``hyp-jumpstart-endpoint`` under both ``create`` and ``get-logs``), so
    registrations are tracked per ``(parent_group, name)`` pair. Name-only
    lookups return the most recent registration for that name.
    """

    def __init__(self):
        # name -> most recently registered metadata (serves name-only lookups).
        self._commands: Dict[str, CommandMetadata] = {}
        # (parent_group, name) -> metadata; one entry per distinct registration.
        self._registrations: Dict[Tuple[Optional[str], str], CommandMetadata] = {}
        self._groups: Dict[str, CommandGroup] = {}
        # module name -> distinct command names it provides.
        self._module_to_commands: Dict[str, List[str]] = {}
        self._initialized = False

    def register_command(
        self,
        name: str,
        help_text: str,
        module_name: str,
        import_path: str,
        parent_group: Optional[str] = None
    ):
        """Register a command with the registry.

        Registering the same ``(parent_group, name)`` pair again replaces the
        earlier entry instead of adding a duplicate.

        Args:
            name: Command name.
            help_text: Help text for the command.
            module_name: Logical module providing the command.
            import_path: ``"module:function"`` location of the implementation.
            parent_group: Group to attach the command to, or None for a
                top-level command. An unknown group is created on the fly with
                a generated help text.
        """
        cmd = CommandMetadata(
            name=name,
            help_text=help_text,
            module_name=module_name,
            import_path=import_path,
            parent_group=parent_group
        )

        self._registrations[(parent_group, name)] = cmd
        self._commands[name] = cmd

        # Track commands by module, listing each name only once.
        module_commands = self._module_to_commands.setdefault(module_name, [])
        if name not in module_commands:
            module_commands.append(name)

        # Add to group if specified
        if parent_group:
            group = self._groups.get(parent_group)
            if group is None:
                group = CommandGroup(parent_group, f"{parent_group.title()} operations.")
                self._groups[parent_group] = group
            # Replace an earlier registration of this name in place so the
            # group never lists the same subcommand twice.
            for index, existing in enumerate(group.commands):
                if existing.name == name:
                    group.commands[index] = cmd
                    break
            else:
                group.commands.append(cmd)

    def register_group(self, name: str, help_text: str):
        """Register a command group; an already-known group is left untouched."""
        if name not in self._groups:
            self._groups[name] = CommandGroup(name, help_text)

    def get_command_metadata(self, name: str) -> Optional[CommandMetadata]:
        """Get metadata for a specific command.

        When the name is registered under several groups, the most recent
        registration is returned. Returns None for an unknown name.
        """
        return self._commands.get(name)

    def get_commands_by_module(self, module_name: str) -> List[CommandMetadata]:
        """Get all commands for a specific module, one entry per registration."""
        return [
            cmd for cmd in self._registrations.values()
            if cmd.module_name == module_name
        ]

    def get_top_level_commands(self) -> List[str]:
        """Get all top-level commands (no parent group)"""
        return [name for (group, name) in self._registrations if group is None]

    def get_subcommands(self, group_name: str) -> List[str]:
        """Get all subcommands for a group (empty list for an unknown group)"""
        group = self._groups.get(group_name)
        return [cmd.name for cmd in group.commands] if group else []

    def get_all_groups(self) -> List[str]:
        """Get all registered group names"""
        return list(self._groups.keys())

    def get_module_for_command(self, name: str) -> Optional[str]:
        """Get the module name that provides a command (None if unknown)"""
        cmd = self._commands.get(name)
        return cmd.module_name if cmd else None

    def initialize_registry(self):
        """Initialize the registry - commands will self-register via decorators"""
        if self._initialized:
            return

        # Register command groups only - commands will auto-register themselves
        self.register_group('create', 'Create endpoints or pytorch jobs.')
        self.register_group('list', 'List endpoints or pytorch jobs.')
        self.register_group('describe', 'Describe endpoints or pytorch jobs.')
        self.register_group('delete', 'Delete endpoints or pytorch jobs.')
        self.register_group('list-pods', 'List pods for endpoints or pytorch jobs.')
        self.register_group('get-logs', 'Get pod logs for endpoints or pytorch jobs.')
        self.register_group('invoke', 'Invoke model endpoints.')
        self.register_group('get-operator-logs', 'Get operator logs for endpoints.')

        self._initialized = True

    def ensure_commands_loaded(self):
        """Ensure command modules are imported so they can self-register.

        Best effort: each module is imported independently, so one module that
        fails to import does not prevent the others from registering. Import
        failures are logged at debug level rather than raised.
        """
        for module_path in _COMMAND_MODULES:
            try:
                # Import modules to trigger self-registration
                importlib.import_module(module_path)
            except ImportError as exc:
                # Modules will be loaded when needed
                logger.debug("Could not import command module %s: %s", module_path, exc)


# Command Registration Decorators
def register_command(name: str, module_name: str, parent_group: Optional[str] = None):
    """
    Decorator that auto-registers commands with the registry.
    Extracts help text from the decorated function's docstring and returns
    the function wrapped as a Click command.

    Usage:
        @register_command("pytorch-job", "training", "create")
        def pytorch_create():
            '''Create a new PyTorch training job.'''
            pass
    """
    def decorator(func):
        # Extract help text from function docstring; a missing or blank
        # docstring falls back to a title generated from the command name.
        help_text = (func.__doc__ or "").strip() or f"{name.replace('-', ' ').title()} operations."

        # Auto-register with registry (done at import time)
        registry = get_registry()
        registry.register_command(
            name=name,
            help_text=help_text,
            module_name=module_name,
            import_path=f"{func.__module__}:{func.__name__}",
            parent_group=parent_group
        )

        # Import click here to avoid import issues during lazy loading
        import click

        # Create Click command
        return click.command(name)(func)

    return decorator


def register_cluster_command(name: str):
    """Register a top-level cluster command."""
    return register_command(name, 'cluster', parent_group=None)


def register_training_command(name: str, group: str):
    """Register a training command in specified group."""
    return register_command(name, 'training', parent_group=group)


def register_inference_command(name: str, group: str):
    """Register an inference command in specified group."""
    return register_command(name, 'inference', parent_group=group)


# Global registry instance
_registry = CommandRegistry()


def get_registry() -> CommandRegistry:
    """Get the global command registry instance (initialized on first use)"""
    _registry.initialize_registry()
    return _registry
b/src/sagemaker/hyperpod/cli/commands/cluster.py index bd641867..4e236a61 100644 --- a/src/sagemaker/hyperpod/cli/commands/cluster.py +++ b/src/sagemaker/hyperpod/cli/commands/cluster.py @@ -14,20 +14,15 @@ import subprocess import json import sys -import botocore.config from collections import defaultdict from typing import Any, Dict, List, Optional, Tuple -import boto3 import click -from botocore.client import BaseClient -from kubernetes import client from ratelimit import limits, sleep_and_retry from tabulate import tabulate +from sagemaker.hyperpod.common.lazy_loading import LazyDecorator, setup_lazy_module +from sagemaker.hyperpod.cli.command_registry import register_cluster_command -from sagemaker.hyperpod.cli.clients.kubernetes_client import ( - KubernetesClient, -) from sagemaker.hyperpod.cli.constants.command_constants import ( AVAILABLE_ACCELERATOR_DEVICES_KEY, DEEP_HEALTH_CHECK_STATUS_LABEL, @@ -42,36 +37,46 @@ TEMP_KUBE_CONFIG_FILE, OutputFormat, ) -from sagemaker.hyperpod.common.telemetry.user_agent import ( - get_user_agent_extra_suffix, -) -from sagemaker.hyperpod.cli.service.list_pods import ( - ListPods, -) -from sagemaker.hyperpod.cli.utils import ( - get_name_from_arn, - get_sagemaker_client, - setup_logger, - set_logging_level, - store_current_hyperpod_context, -) -from sagemaker.hyperpod.cli.validators.cluster_validator import ( - ClusterValidator, -) -from sagemaker.hyperpod.cli.utils import ( - get_eks_cluster_name, -) -from sagemaker.hyperpod.common.utils import ( - get_cluster_context as get_cluster_context_util, -) -from sagemaker.hyperpod.observability.utils import ( - get_monitoring_config, - is_observability_addon_enabled, -) -from sagemaker.hyperpod.common.telemetry.telemetry_logging import ( - _hyperpod_telemetry_emitter, -) -from sagemaker.hyperpod.common.telemetry.constants import Feature + +from sagemaker.hyperpod.cli.utils import setup_logger +from sagemaker.hyperpod.cli.validators.cluster_validator import ClusterValidator + 
# Lazy-loading configuration for the cluster commands module. It is handed to
# setup_lazy_module() below. Each 'lazy_imports' value is either a bare module
# path or a "module:attribute" target.
# NOTE(review): how setup_lazy_module interprets 'exports' and 'critical_deps'
# is not visible here — confirm against sagemaker.hyperpod.common.lazy_loading.
CLUSTER_CONFIG = {
    'exports': [
        'boto3',
        'botocore',
        'BaseClient',
        'client',
        'KubernetesClient',
        'get_user_agent_extra_suffix',
        'ListPods',
        'get_name_from_arn',
        'get_sagemaker_client',
        'set_logging_level',
        'store_current_hyperpod_context',
        'ClusterValidator',
        'get_eks_cluster_name',
        'get_cluster_context_util',
        'get_monitoring_config',
        'is_observability_addon_enabled',
        '_hyperpod_telemetry_emitter',
        'Feature',
    ],
    'critical_deps': [
        'telemetry_emitter',
        'telemetry_feature',
    ],
    'lazy_imports': {
        'boto3': 'boto3',
        'botocore': 'botocore.config',
        'BaseClient': 'botocore.client:BaseClient',
        'client': 'kubernetes:client',
        'KubernetesClient': 'sagemaker.hyperpod.cli.clients.kubernetes_client:KubernetesClient',
        'get_user_agent_extra_suffix': 'sagemaker.hyperpod.common.telemetry.user_agent:get_user_agent_extra_suffix',
        'ListPods': 'sagemaker.hyperpod.cli.service.list_pods:ListPods',
        'get_name_from_arn': 'sagemaker.hyperpod.cli.utils:get_name_from_arn',
        'get_sagemaker_client': 'sagemaker.hyperpod.cli.utils:get_sagemaker_client',
        'set_logging_level': 'sagemaker.hyperpod.cli.utils:set_logging_level',
        'store_current_hyperpod_context': 'sagemaker.hyperpod.cli.utils:store_current_hyperpod_context',
        'get_eks_cluster_name': 'sagemaker.hyperpod.cli.utils:get_eks_cluster_name',
        'get_cluster_context_util': 'sagemaker.hyperpod.common.utils:get_cluster_context',
        'get_monitoring_config': 'sagemaker.hyperpod.observability.utils:get_monitoring_config',
        'is_observability_addon_enabled': 'sagemaker.hyperpod.observability.utils:is_observability_addon_enabled',
        '_hyperpod_telemetry_emitter': 'sagemaker.hyperpod.common.telemetry.telemetry_logging:_hyperpod_telemetry_emitter',
        'Feature': 'sagemaker.hyperpod.common.telemetry.constants:Feature',
    },
}

setup_lazy_module(__name__, CLUSTER_CONFIG)


# Helper functions for decorators
def _get_telemetry_emitter():
    """Look up ``_hyperpod_telemetry_emitter`` on this module at call time."""
    this_module = sys.modules[__name__]
    return this_module._hyperpod_telemetry_emitter


RATE_LIMIT = 4
RATE_LIMIT_PERIOD = 1 # 1 second @@ -79,7 +84,7 @@ logger = setup_logger(__name__) -@click.command() +@register_cluster_command("list-cluster") @click.option( "--region", type=click.STRING, @@ -112,7 +117,7 @@ multiple=True, help="Optional. The namespace that you want to check the capacity for. Only SageMaker managed namespaces are supported.", ) -@_hyperpod_telemetry_emitter(Feature.HYPERPOD, "list_cluster") +@LazyDecorator(_get_telemetry_emitter, lambda: sys.modules[__name__].Feature.HYPERPOD, "list_cluster") def list_cluster( region: Optional[str], output: Optional[str], @@ -151,7 +156,7 @@ def list_cluster( hyperpod-eks-cluster-a | ml.g5.2xlarge | 2 | 1| 2 | N/A | 1 | 1 """ if debug: - set_logging_level(logger, logging.DEBUG) + sys.modules[__name__].set_logging_level(logger, logging.DEBUG) validator = ClusterValidator() # Make use of user_agent_extra field of the botocore_config object @@ -160,18 +165,18 @@ def list_cluster( # This config will also make sure that user_agent never fails to log the User-Agent string # even if boto User-Agent header format is updated in the future # Ref: https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html - botocore_config = botocore.config.Config( - user_agent_extra=get_user_agent_extra_suffix() + botocore_config = sys.modules[__name__].botocore.config.Config( + user_agent_extra=sys.modules[__name__].get_user_agent_extra_suffix() ) - session = boto3.Session(region_name=region) if region else boto3.Session() + session = sys.modules[__name__].boto3.Session(region_name=region) if region else sys.modules[__name__].boto3.Session() if not validator.validate_aws_credential(session): logger.error("Failed to list clusters capacity due to invalid AWS credentials.") sys.exit(1) try: - sm_client = get_sagemaker_client(session, botocore_config) - except botocore.exceptions.NoRegionError: + sm_client = sys.modules[__name__].get_sagemaker_client(session, botocore_config) + except 
sys.modules[__name__].botocore.exceptions.NoRegionError: logger.error( f"Please ensure you have configured the AWS default region or use the '--region' argument to specify the region." ) @@ -242,7 +247,7 @@ def list_cluster( def rate_limited_operation( cluster_name: str, validator: ClusterValidator, - sm_client: BaseClient, + sm_client: Any, region: Optional[str], temp_config_file: str, cluster_capacities: List[List[str]], @@ -257,9 +262,9 @@ def rate_limited_operation( f"Cannot find EKS cluster behind {cluster_name}, continue..." ) return - eks_cluster_name = get_name_from_arn(eks_cluster_arn) + eks_cluster_name = sys.modules[__name__].get_name_from_arn(eks_cluster_arn) _update_kube_config(eks_cluster_name, region, temp_config_file) - k8s_client = KubernetesClient(is_get_capacity=True) + k8s_client = sys.modules[__name__].KubernetesClient(is_get_capacity=True) nodes = k8s_client.list_node_with_temp_config( temp_config_file, SAGEMAKER_HYPERPOD_NAME_LABEL ) @@ -367,7 +372,7 @@ def _get_available_quota(nominal, usage, flavor, resource_name): return "N/A" -def _get_hyperpod_clusters(sm_client: boto3.client) -> List[str]: +def _get_hyperpod_clusters(sm_client: Any) -> List[str]: cluster_names: List[str] = [] response = sm_client.list_clusters() if "ClusterSummaries" in response: @@ -410,9 +415,9 @@ def _restructure_output(summary_list, namespaces): def _aggregate_nodes_info( - nodes: List[client.V1Node], + nodes: List[Any], ) -> Dict[str, Dict[str, Any]]: - list_pods_service = ListPods() + list_pods_service = sys.modules[__name__].ListPods() nodes_resource_allocated_dict = ( list_pods_service.list_pods_and_get_requested_resources_group_by_node_name() ) @@ -473,7 +478,7 @@ def _aggregate_nodes_info( return nodes_summary -@click.command() +@register_cluster_command("set-cluster-context") @click.option( "--cluster-name", type=click.STRING, @@ -518,31 +523,31 @@ def set_cluster_context( None """ if debug: - set_logging_level(logger, logging.DEBUG) + 
sys.modules[__name__].set_logging_level(logger, logging.DEBUG) validator = ClusterValidator() - botocore_config = botocore.config.Config( - user_agent_extra=get_user_agent_extra_suffix() + botocore_config = sys.modules[__name__].botocore.config.Config( + user_agent_extra=sys.modules[__name__].get_user_agent_extra_suffix() ) - session = boto3.Session(region_name=region) if region else boto3.Session() + session = sys.modules[__name__].boto3.Session(region_name=region) if region else sys.modules[__name__].boto3.Session() if not validator.validate_aws_credential(session): logger.error("Cannot connect to HyperPod cluster due to aws credentials error") sys.exit(1) try: - sm_client = get_sagemaker_client(session, botocore_config) + sm_client = sys.modules[__name__].get_sagemaker_client(session, botocore_config) hp_cluster_details = sm_client.describe_cluster(ClusterName=cluster_name) logger.debug("Fetched hyperpod cluster details") - store_current_hyperpod_context(hp_cluster_details) + sys.modules[__name__].store_current_hyperpod_context(hp_cluster_details) eks_cluster_arn = hp_cluster_details["Orchestrator"]["Eks"]["ClusterArn"] logger.debug( f"hyperpod cluster's EKS orchestrator cluster arn: {eks_cluster_arn}" ) - eks_name = get_name_from_arn(eks_cluster_arn) + eks_name = sys.modules[__name__].get_name_from_arn(eks_cluster_arn) _update_kube_config(eks_name, region, None) - k8s_client = KubernetesClient() + k8s_client = sys.modules[__name__].KubernetesClient() k8s_client.set_context(eks_cluster_arn, namespace) - except botocore.exceptions.NoRegionError: + except sys.modules[__name__].botocore.exceptions.NoRegionError: logger.error( f"Please ensure you configured AWS default region or use '--region' argument to specify the region" ) @@ -554,7 +559,7 @@ def set_cluster_context( sys.exit(1) -@click.command() +@register_cluster_command("get-cluster-context") @click.option( "--debug", is_flag=True, @@ -573,12 +578,12 @@ def get_cluster_context( None """ if debug: - 
set_logging_level(logger, logging.DEBUG) + sys.modules[__name__].set_logging_level(logger, logging.DEBUG) try: - current_context = get_cluster_context_util() + current_context = sys.modules[__name__].get_cluster_context_util() print(f"Cluster context:{current_context}") - except botocore.exceptions.NoRegionError: + except sys.modules[__name__].botocore.exceptions.NoRegionError: logger.error( f"Please ensure you configured AWS default region or use '--region' argument to specify the region" ) @@ -590,7 +595,7 @@ def get_cluster_context( sys.exit(1) -@click.command() +@register_cluster_command("get-monitoring") @click.option("--grafana", is_flag=True, help="Returns Grafana Dashboard URL") @click.option("--prometheus", is_flag=True, help="Returns Prometheus Workspace URL") @click.option("--list", is_flag=True, help="Returns list of available metrics") @@ -601,10 +606,10 @@ def get_monitoring(grafana: bool, prometheus: bool, list: bool) -> None: print("Error: Please select at least one option") print("Usage : hyp get-monitoring --grafana/--prometheus/--list/--help") return - if not is_observability_addon_enabled(get_eks_cluster_name()): + if not sys.modules[__name__].is_observability_addon_enabled(sys.modules[__name__].get_eks_cluster_name()): print("Observability addon is not enabled for this cluster") sys.exit(1) - monitor_config = get_monitoring_config() + monitor_config = sys.modules[__name__].get_monitoring_config() if prometheus: print(f"Prometheus workspace URL: {monitor_config.prometheusURL}") if grafana: diff --git a/src/sagemaker/hyperpod/cli/commands/inference.py b/src/sagemaker/hyperpod/cli/commands/inference.py index cba3e60c..df71f2a0 100644 --- a/src/sagemaker/hyperpod/cli/commands/inference.py +++ b/src/sagemaker/hyperpod/cli/commands/inference.py @@ -1,25 +1,68 @@ import click import json -import boto3 -from typing import Optional +import sys +from typing import Optional, Any from tabulate import tabulate - -from sagemaker.hyperpod.cli.inference_utils 
from sagemaker.hyperpod.common.cli_decorators import handle_cli_exceptions
from sagemaker.hyperpod.common.lazy_loading import (
    LazyDecorator, setup_lazy_module
)
from sagemaker.hyperpod.cli.command_registry import register_inference_command

# Lazy-loading configuration for the inference commands module. It is handed
# to setup_lazy_module() below, after 'extra_setup' has been attached.
# NOTE(review): how 'template_packages' and 'critical_deps' are consumed is
# not visible here — confirm against sagemaker.hyperpod.common.lazy_loading.
INFERENCE_CONFIG = {
    'exports': [
        'boto3',
        'generate_click_command',
        'JS_REG',
        'C_REG',
        'HPJumpStartEndpoint',
        'HPEndpoint',
        'Endpoint',
        '_hyperpod_telemetry_emitter',
        'Feature',
        'display_formatted_logs',
    ],
    'template_packages': {
        'jumpstart_template_package': 'hyperpod_jumpstart_inference_template',
        'custom_template_package': 'hyperpod_custom_inference_template',
        'supported_versions': ['1.0'],
    },
    'critical_deps': [
        'telemetry_emitter',
        'telemetry_feature',
        'inference_utils',
    ],
    'lazy_imports': {
        'boto3': 'boto3',
        'generate_click_command': 'sagemaker.hyperpod.cli.inference_utils:generate_click_command',
        'JS_REG': 'hyperpod_jumpstart_inference_template.registry:SCHEMA_REGISTRY',
        'C_REG': 'hyperpod_custom_inference_template.registry:SCHEMA_REGISTRY',
        'HPJumpStartEndpoint': 'sagemaker.hyperpod.inference.hp_jumpstart_endpoint:HPJumpStartEndpoint',
        'HPEndpoint': 'sagemaker.hyperpod.inference.hp_endpoint:HPEndpoint',
        'Endpoint': 'sagemaker_core.resources:Endpoint',
        '_hyperpod_telemetry_emitter': 'sagemaker.hyperpod.common.telemetry.telemetry_logging:_hyperpod_telemetry_emitter',
        'Feature': 'sagemaker.hyperpod.common.telemetry.constants:Feature',
        'display_formatted_logs': 'sagemaker.hyperpod.common.utils:display_formatted_logs',
    },
}


def _setup_inference_registries(deps):
    """Setup inference-specific registries.

    Builds one ``LazyRegistry`` per template package, then stores each under
    ``JS_REG`` / ``C_REG`` both in ``deps`` and as an attribute of this module.

    Args:
        deps: Mapping that receives the two registries via item assignment.
    """
    from sagemaker.hyperpod.common.lazy_loading import LazyRegistry

    registry_targets = (
        ('JS_REG', 'hyperpod_jumpstart_inference_template.registry:SCHEMA_REGISTRY'),
        ('C_REG', 'hyperpod_custom_inference_template.registry:SCHEMA_REGISTRY'),
    )

    # Construct both registries before publishing either one.
    registries = [
        (attr_name, LazyRegistry(versions=['1.0'], registry_import_path=target))
        for attr_name, target in registry_targets
    ]

    for attr_name, registry in registries:
        deps[attr_name] = registry

    this_module = sys.modules[__name__]
    for attr_name, registry in registries:
        setattr(this_module, attr_name, registry)


# The hook must be attached before the configuration is handed over.
INFERENCE_CONFIG['extra_setup'] = _setup_inference_registries
setup_lazy_module(__name__, INFERENCE_CONFIG)


# Helper functions for decorators
def _get_telemetry_emitter():
    """Look up ``_hyperpod_telemetry_emitter`` on this module at call time."""
    this_module = sys.modules[__name__]
    return this_module._hyperpod_telemetry_emitter


def _get_generate_click_command():
    """Look up ``generate_click_command`` on this module at call time."""
    this_module = sys.modules[__name__]
    return this_module.generate_click_command


# CREATE
Default set to 'default'", ) @click.option("--version", default="1.0", help="Schema version to use") -@generate_click_command( +@LazyDecorator(_get_generate_click_command, schema_pkg="hyperpod_jumpstart_inference_template", - registry=JS_REG, + registry=lambda: sys.modules[__name__].JS_REG, ) -@_hyperpod_telemetry_emitter(Feature.HYPERPOD_CLI, "create_js_endpoint_cli") +@LazyDecorator(_get_telemetry_emitter, lambda: sys.modules[__name__].Feature.HYPERPOD_CLI, "create_js_endpoint_cli") @handle_cli_exceptions() def js_create(name, namespace, version, js_endpoint): """ @@ -42,7 +85,7 @@ def js_create(name, namespace, version, js_endpoint): js_endpoint.create(name=name, namespace=namespace) -@click.command("hyp-custom-endpoint") +@register_inference_command("hyp-custom-endpoint", "create") @click.option( "--namespace", type=click.STRING, @@ -51,11 +94,11 @@ def js_create(name, namespace, version, js_endpoint): help="Optional. The namespace of the jumpstart model endpoint to create. Default set to 'default'", ) @click.option("--version", default="1.0", help="Schema version to use") -@generate_click_command( +@LazyDecorator(_get_generate_click_command, schema_pkg="hyperpod_custom_inference_template", - registry=C_REG, + registry=lambda: sys.modules[__name__].C_REG, ) -@_hyperpod_telemetry_emitter(Feature.HYPERPOD_CLI, "create_custom_endpoint_cli") +@LazyDecorator(_get_telemetry_emitter, lambda: sys.modules[__name__].Feature.HYPERPOD_CLI, "create_custom_endpoint_cli") @handle_cli_exceptions() def custom_create(name, namespace, version, custom_endpoint): """ @@ -86,7 +129,7 @@ def custom_create(name, namespace, version, custom_endpoint): default="application/json", help="Optional. The content type of the request to invoke. 
Default set to 'application/json'", ) -@_hyperpod_telemetry_emitter(Feature.HYPERPOD_CLI, "invoke_custom_endpoint_cli") +@LazyDecorator(_get_telemetry_emitter, lambda: sys.modules[__name__].Feature.HYPERPOD_CLI, "invoke_custom_endpoint_cli") @handle_cli_exceptions() def custom_invoke( endpoint_name: str, @@ -101,10 +144,10 @@ def custom_invoke( except json.JSONDecodeError: raise click.ClickException("--body must be valid JSON") - rt = boto3.client("sagemaker-runtime") + rt = sys.modules[__name__].boto3.client("sagemaker-runtime") try: - endpoint = Endpoint.get(endpoint_name) + endpoint = sys.modules[__name__].Endpoint.get(endpoint_name) except Exception as e: endpoint = None @@ -113,7 +156,7 @@ def custom_invoke( f"Endpoint {endpoint_name} creation has been initated but is currently not in service") elif not endpoint: try: - hp_endpoint = HPEndpoint.get(endpoint_name) + hp_endpoint = sys.modules[__name__].HPEndpoint.get(endpoint_name) except Exception as e: hp_endpoint = None @@ -140,7 +183,7 @@ def custom_invoke( default="default", help="Optional. The namespace of the jumpstart model endpoint to list. Default set to 'default'", ) -@_hyperpod_telemetry_emitter(Feature.HYPERPOD_CLI, "list_js_endpoints_cli") +@LazyDecorator(_get_telemetry_emitter, lambda: sys.modules[__name__].Feature.HYPERPOD_CLI, "list_js_endpoints_cli") @handle_cli_exceptions() def js_list( namespace: Optional[str], @@ -148,7 +191,7 @@ def js_list( """ List all Hyperpod Jumpstart model endpoints. """ - endpoints = HPJumpStartEndpoint.model_construct().list(namespace) + endpoints = sys.modules[__name__].HPJumpStartEndpoint.model_construct().list(namespace) data = [ep.model_dump() for ep in endpoints] if not data: @@ -183,7 +226,7 @@ def js_list( default="default", help="Optional. The namespace of the custom model endpoint to list. 
Default set to 'default'", ) -@_hyperpod_telemetry_emitter(Feature.HYPERPOD_CLI, "list_custom_endpoints_cli") +@LazyDecorator(_get_telemetry_emitter, lambda: sys.modules[__name__].Feature.HYPERPOD_CLI, "list_custom_endpoints_cli") @handle_cli_exceptions() def custom_list( namespace: Optional[str], @@ -191,7 +234,7 @@ def custom_list( """ List all Hyperpod custom model endpoints. """ - endpoints = HPEndpoint.model_construct().list(namespace) + endpoints = sys.modules[__name__].HPEndpoint.model_construct().list(namespace) data = [ep.model_dump() for ep in endpoints] if not data: @@ -240,7 +283,7 @@ def custom_list( required=False, help="Optional. If set to `True`, the full json will be displayed", ) -@_hyperpod_telemetry_emitter(Feature.HYPERPOD_CLI, "get_js_endpoint_cli") +@LazyDecorator(_get_telemetry_emitter, lambda: sys.modules[__name__].Feature.HYPERPOD_CLI, "get_js_endpoint_cli") @handle_cli_exceptions() def js_describe( name: str, @@ -250,7 +293,7 @@ def js_describe( """ Describe a Hyperpod Jumpstart model endpoint. """ - my_endpoint = HPJumpStartEndpoint.model_construct().get(name, namespace) + my_endpoint = sys.modules[__name__].HPJumpStartEndpoint.model_construct().get(name, namespace) data = my_endpoint.model_dump() if full: @@ -262,7 +305,7 @@ def js_describe( click.echo("Invalid data received: expected a dictionary.") return - click.echo("\nDeployment (should be completed in 1-5 min):") + click.echo("\nDeployment (should be completed in 1-5 min):") status = data.get("status") or {} metadata = data.get("metadata") or {} @@ -321,7 +364,7 @@ def js_describe( click.echo() click.echo(click.style("─" * 60, fg="white")) - click.echo("\nSageMaker Endpoint (takes ~10 min to create):") + click.echo("\nSageMaker Endpoint (takes ~10 min to create):") status = data.get("status") or {} endpoints = status.get("endpoints") or {} sagemaker_info = endpoints.get("sagemaker") @@ -389,7 +432,7 @@ def js_describe( required=False, help="Optional. 
If set to `True`, the full json will be displayed", ) -@_hyperpod_telemetry_emitter(Feature.HYPERPOD_CLI, "get_custom_endpoint_cli") +@LazyDecorator(_get_telemetry_emitter, lambda: sys.modules[__name__].Feature.HYPERPOD_CLI, "get_custom_endpoint_cli") @handle_cli_exceptions() def custom_describe( name: str, @@ -399,7 +442,7 @@ def custom_describe( """ Describe a Hyperpod custom model endpoint. """ - my_endpoint = HPEndpoint.model_construct().get(name, namespace) + my_endpoint = sys.modules[__name__].HPEndpoint.model_construct().get(name, namespace) data = my_endpoint.model_dump() if full: @@ -411,7 +454,7 @@ def custom_describe( click.echo("Invalid data received: expected a dictionary.") return - click.echo("\nDeployment (should be completed in 1-5 min):") + click.echo("\nDeployment (should be completed in 1-5 min):") status = data.get("status") or {} metadata = data.get("metadata") or {} @@ -504,7 +547,7 @@ def custom_describe( click.echo() click.echo(click.style("─" * 60, fg="white")) - click.echo("\nSageMaker Endpoint (takes ~10 min to create):") + click.echo("\nSageMaker Endpoint (takes ~10 min to create):") status = data.get("status") or {} endpoints = status.get("endpoints") or {} sagemaker_info = endpoints.get("sagemaker") @@ -564,7 +607,7 @@ def custom_describe( default="default", help="Optional. The namespace of the jumpstart model endpoint to delete. Default set to 'default'.", ) -@_hyperpod_telemetry_emitter(Feature.HYPERPOD_CLI, "delete_js_endpoint_cli") +@LazyDecorator(_get_telemetry_emitter, lambda: sys.modules[__name__].Feature.HYPERPOD_CLI, "delete_js_endpoint_cli") @handle_cli_exceptions() def js_delete( name: str, @@ -575,7 +618,7 @@ def js_delete( """ # Auto-detects the endpoint type and operation # 0Provides 404 message: "❓ JumpStart endpoint 'missing-name' not found..." 
- my_endpoint = HPJumpStartEndpoint.model_construct().get(name, namespace) + my_endpoint = sys.modules[__name__].HPJumpStartEndpoint.model_construct().get(name, namespace) my_endpoint.delete() @@ -593,7 +636,7 @@ def js_delete( default="default", help="Optional. The namespace of the custom model endpoint to delete. Default set to 'default'.", ) -@_hyperpod_telemetry_emitter(Feature.HYPERPOD_CLI, "delete_custom_endpoint_cli") +@LazyDecorator(_get_telemetry_emitter, lambda: sys.modules[__name__].Feature.HYPERPOD_CLI, "delete_custom_endpoint_cli") @handle_cli_exceptions() def custom_delete( name: str, @@ -602,7 +645,7 @@ def custom_delete( """ Delete a Hyperpod custom model endpoint. """ - my_endpoint = HPEndpoint.model_construct().get(name, namespace) + my_endpoint = sys.modules[__name__].HPEndpoint.model_construct().get(name, namespace) my_endpoint.delete() @@ -614,7 +657,7 @@ def custom_delete( default="default", help="Optional. The namespace of the jumpstart model to list pods for. Default set to 'default'.", ) -@_hyperpod_telemetry_emitter(Feature.HYPERPOD_CLI, "list_pods_js_endpoint_cli") +@LazyDecorator(_get_telemetry_emitter, lambda: sys.modules[__name__].Feature.HYPERPOD_CLI, "list_pods_js_endpoint_cli") @handle_cli_exceptions() def js_list_pods( namespace: Optional[str], @@ -622,7 +665,7 @@ def js_list_pods( """ List all pods related to jumpstart model endpoint. """ - my_endpoint = HPJumpStartEndpoint.model_construct() + my_endpoint = sys.modules[__name__].HPJumpStartEndpoint.model_construct() pods = my_endpoint.list_pods(namespace=namespace) click.echo(pods) @@ -635,7 +678,7 @@ def js_list_pods( default="default", help="Optional. The namespace of the custom model to list pods for. 
Default set to 'default'.", ) -@_hyperpod_telemetry_emitter(Feature.HYPERPOD_CLI, "list_pods_custom_endpoint_cli") +@LazyDecorator(_get_telemetry_emitter, lambda: sys.modules[__name__].Feature.HYPERPOD_CLI, "list_pods_custom_endpoint_cli") @handle_cli_exceptions() def custom_list_pods( namespace: Optional[str], @@ -643,12 +686,12 @@ def custom_list_pods( """ List all pods related to custom model endpoint. """ - my_endpoint = HPEndpoint.model_construct() + my_endpoint = sys.modules[__name__].HPEndpoint.model_construct() pods = my_endpoint.list_pods(namespace=namespace) click.echo(pods) -@click.command("hyp-jumpstart-endpoint") +@register_inference_command("hyp-jumpstart-endpoint", "get-logs") @click.option( "--pod-name", type=click.STRING, @@ -668,7 +711,7 @@ def custom_list_pods( default="default", help="Optional. The namespace of the jumpstart model to get logs for. Default set to 'default'.", ) -@_hyperpod_telemetry_emitter(Feature.HYPERPOD_CLI, "get_logs_js_endpoint") +@LazyDecorator(_get_telemetry_emitter, lambda: sys.modules[__name__].Feature.HYPERPOD_CLI, "get_logs_js_endpoint") @handle_cli_exceptions() def js_get_logs( pod_name: str, @@ -678,15 +721,15 @@ def js_get_logs( """ Get specific pod log for jumpstart model endpoint. 
""" - my_endpoint = HPJumpStartEndpoint.model_construct() + my_endpoint = sys.modules[__name__].HPJumpStartEndpoint.model_construct() logs = my_endpoint.get_logs(pod=pod_name, container=container, namespace=namespace) # Use common log display utility for consistent formatting across all job types container_info = f" (container: {container})" if container else "" - display_formatted_logs(logs, title=f"JumpStart Endpoint Logs for {pod_name}{container_info}") + sys.modules[__name__].display_formatted_logs(logs, title=f"JumpStart Endpoint Logs for {pod_name}{container_info}") -@click.command("hyp-custom-endpoint") +@register_inference_command("hyp-custom-endpoint", "get-logs") @click.option( "--pod-name", type=click.STRING, @@ -706,7 +749,7 @@ def js_get_logs( default="default", help="Optional. The namespace of the custom model to get logs for. Default set to 'default'.", ) -@_hyperpod_telemetry_emitter(Feature.HYPERPOD_CLI, "get_logs_custom_endpoint") +@LazyDecorator(_get_telemetry_emitter, lambda: sys.modules[__name__].Feature.HYPERPOD_CLI, "get_logs_custom_endpoint") @handle_cli_exceptions() def custom_get_logs( pod_name: str, @@ -716,12 +759,12 @@ def custom_get_logs( """ Get specific pod log for custom model endpoint. """ - my_endpoint = HPEndpoint.model_construct() + my_endpoint = sys.modules[__name__].HPEndpoint.model_construct() logs = my_endpoint.get_logs(pod=pod_name, container=container, namespace=namespace) # Use common log display utility for consistent formatting across all job types container_info = f" (container: {container})" if container else "" - display_formatted_logs(logs, title=f"Custom Endpoint Logs for {pod_name}{container_info}") + sys.modules[__name__].display_formatted_logs(logs, title=f"Custom Endpoint Logs for {pod_name}{container_info}") @click.command("hyp-jumpstart-endpoint") @@ -731,7 +774,7 @@ def custom_get_logs( required=True, help="Required. 
The time frame to get logs for.", ) -@_hyperpod_telemetry_emitter(Feature.HYPERPOD_CLI, "get_js_operator_logs") +@LazyDecorator(_get_telemetry_emitter, lambda: sys.modules[__name__].Feature.HYPERPOD_CLI, "get_js_operator_logs") @handle_cli_exceptions() def js_get_operator_logs( since_hours: float, @@ -739,7 +782,7 @@ def js_get_operator_logs( """ Get operator logs for jumpstart model endpoint. """ - my_endpoint = HPJumpStartEndpoint.model_construct() + my_endpoint = sys.modules[__name__].HPJumpStartEndpoint.model_construct() logs = my_endpoint.get_operator_logs(since_hours=since_hours) click.echo(logs) @@ -751,7 +794,7 @@ def js_get_operator_logs( required=True, help="Required. The time frame get logs for.", ) -@_hyperpod_telemetry_emitter(Feature.HYPERPOD_CLI, "get_custom_operator_logs") +@LazyDecorator(_get_telemetry_emitter, lambda: sys.modules[__name__].Feature.HYPERPOD_CLI, "get_custom_operator_logs") @handle_cli_exceptions() def custom_get_operator_logs( since_hours: float, @@ -759,6 +802,6 @@ def custom_get_operator_logs( """ Get operator logs for custom model endpoint. 
""" - my_endpoint = HPEndpoint.model_construct() + my_endpoint = sys.modules[__name__].HPEndpoint.model_construct() logs = my_endpoint.get_operator_logs(since_hours=since_hours) click.echo(logs) diff --git a/src/sagemaker/hyperpod/cli/commands/training.py b/src/sagemaker/hyperpod/cli/commands/training.py index bef71203..73ead0d5 100644 --- a/src/sagemaker/hyperpod/cli/commands/training.py +++ b/src/sagemaker/hyperpod/cli/commands/training.py @@ -1,24 +1,59 @@ import click -from sagemaker.hyperpod.training.hyperpod_pytorch_job import HyperPodPytorchJob -from sagemaker.hyperpod.common.config import Metadata -from sagemaker.hyperpod.cli.training_utils import generate_click_command -from hyperpod_pytorch_job_template.registry import SCHEMA_REGISTRY -from sagemaker.hyperpod.common.telemetry.telemetry_logging import ( - _hyperpod_telemetry_emitter, -) -from sagemaker.hyperpod.common.telemetry.constants import Feature +import sys +from typing import Any from sagemaker.hyperpod.common.cli_decorators import handle_cli_exceptions -from sagemaker.hyperpod.common.utils import display_formatted_logs +from sagemaker.hyperpod.common.lazy_loading import LazyDecorator, setup_lazy_module +from sagemaker.hyperpod.cli.command_registry import register_training_command + +TRAINING_CONFIG = { + 'exports': [ + 'HyperPodPytorchJob', 'Metadata', 'generate_click_command', 'SCHEMA_REGISTRY', + '_hyperpod_telemetry_emitter', 'Feature', 'display_formatted_logs' + ], + 'template_packages': { + 'template_package': 'hyperpod_pytorch_job_template', + 'supported_versions': ['1.0', '1.1'], + }, + 'critical_deps': ['telemetry_emitter', 'telemetry_feature', 'training_utils'], + 'lazy_imports': { + 'HyperPodPytorchJob': 'sagemaker.hyperpod.training.hyperpod_pytorch_job:HyperPodPytorchJob', + 'Metadata': 'sagemaker.hyperpod.common.config:Metadata', + 'generate_click_command': 'sagemaker.hyperpod.cli.training_utils:generate_click_command', + 'SCHEMA_REGISTRY': 
'hyperpod_pytorch_job_template.registry:SCHEMA_REGISTRY', + '_hyperpod_telemetry_emitter': 'sagemaker.hyperpod.common.telemetry.telemetry_logging:_hyperpod_telemetry_emitter', + 'Feature': 'sagemaker.hyperpod.common.telemetry.constants:Feature', + 'display_formatted_logs': 'sagemaker.hyperpod.common.utils:display_formatted_logs' + } +} + +def _setup_training_registries(deps): + """Setup training-specific registries.""" + from sagemaker.hyperpod.common.lazy_loading import LazyRegistry + registry = LazyRegistry( + versions=['1.0', '1.1'], + registry_import_path='hyperpod_pytorch_job_template.registry:SCHEMA_REGISTRY' + ) + deps['SCHEMA_REGISTRY'] = registry + setattr(sys.modules[__name__], 'SCHEMA_REGISTRY', registry) + +TRAINING_CONFIG['extra_setup'] = _setup_training_registries +setup_lazy_module(__name__, TRAINING_CONFIG) + +def _get_telemetry_emitter(): + return getattr(sys.modules[__name__], '_hyperpod_telemetry_emitter') + +def _get_generate_click_command(): + return getattr(sys.modules[__name__], 'generate_click_command') -@click.command("hyp-pytorch-job") -@click.option("--version", default="1.0", help="Schema version to use") +@register_training_command("hyp-pytorch-job", "create") +@click.option("--version", default="1.1", help="Schema version to use") @click.option("--debug", default=False, help="Enable debug mode") -@generate_click_command( - schema_pkg="hyperpod_pytorch_job_template", - registry=SCHEMA_REGISTRY, +@LazyDecorator(_get_generate_click_command, + schema_pkg=lambda: sys.modules[__name__]._MODULE_CONFIG["template_package"], + registry=lambda: sys.modules[__name__].SCHEMA_REGISTRY, ) -@_hyperpod_telemetry_emitter(Feature.HYPERPOD_CLI, "create_pytorchjob_cli") +@LazyDecorator(_get_telemetry_emitter, lambda: sys.modules[__name__].Feature.HYPERPOD_CLI, "create_pytorchjob_cli") @handle_cli_exceptions() def pytorch_create(version, debug, config): """Create a PyTorch job.""" @@ -40,7 +75,7 @@ def pytorch_create(version, debug, config): # Prepare job 
kwargs job_kwargs = { - "metadata": Metadata(**metadata_kwargs), + "metadata": sys.modules[__name__].Metadata(**metadata_kwargs), "replica_specs": spec.get("replica_specs"), } @@ -53,22 +88,22 @@ def pytorch_create(version, debug, config): job_kwargs["run_policy"] = spec.get("run_policy") # Create job - job = HyperPodPytorchJob(**job_kwargs) + job = sys.modules[__name__].HyperPodPytorchJob(**job_kwargs) job.create(debug=debug) -@click.command("hyp-pytorch-job") +@register_training_command("hyp-pytorch-job", "list") @click.option( "--namespace", "-n", default="default", help="Optional. The namespace to list jobs from. Defaults to 'default' namespace.", ) -@_hyperpod_telemetry_emitter(Feature.HYPERPOD_CLI, "list_pytorchjobs_cli") +@LazyDecorator(_get_telemetry_emitter, lambda: sys.modules[__name__].Feature.HYPERPOD_CLI, "list_pytorchjobs_cli") @handle_cli_exceptions() def list_jobs(namespace: str): """List all HyperPod PyTorch jobs.""" - jobs = HyperPodPytorchJob.list(namespace=namespace) + jobs = sys.modules[__name__].HyperPodPytorchJob.list(namespace=namespace) if not jobs: click.echo("No jobs found.") @@ -132,7 +167,7 @@ def list_jobs(namespace: str): click.echo() # Add empty line at the end -@click.command("hyp-pytorch-job") +@register_training_command("hyp-pytorch-job", "describe") @click.option( "--job-name", required=True, help="Required. The name of the job to describe" ) @@ -142,11 +177,11 @@ def list_jobs(namespace: str): default="default", help="Optional. The namespace of the job. 
Defaults to 'default' namespace.", ) -@_hyperpod_telemetry_emitter(Feature.HYPERPOD_CLI, "get_pytorchjob_cli") +@LazyDecorator(_get_telemetry_emitter, lambda: sys.modules[__name__].Feature.HYPERPOD_CLI, "get_pytorchjob_cli") @handle_cli_exceptions() def pytorch_describe(job_name: str, namespace: str): """Describe a HyperPod PyTorch job.""" - job = HyperPodPytorchJob.get(name=job_name, namespace=namespace) + job = sys.modules[__name__].HyperPodPytorchJob.get(name=job_name, namespace=namespace) if job is None: raise Exception(f"Job {job_name} not found in namespace {namespace}") @@ -233,7 +268,7 @@ def pytorch_describe(job_name: str, namespace: str): click.echo("No status information available") -@click.command("hyp-pytorch-job") +@register_training_command("hyp-pytorch-job", "delete") @click.option( "--job-name", required=True, help="Required. The name of the job to delete" ) @@ -243,15 +278,15 @@ def pytorch_describe(job_name: str, namespace: str): default="default", help="Optional. The namespace of the job. Defaults to 'default' namespace.", ) -@_hyperpod_telemetry_emitter(Feature.HYPERPOD_CLI, "delete_pytorchjob_cli") +@LazyDecorator(_get_telemetry_emitter, lambda: sys.modules[__name__].Feature.HYPERPOD_CLI, "delete_pytorchjob_cli") @handle_cli_exceptions() def pytorch_delete(job_name: str, namespace: str): """Delete a HyperPod PyTorch job.""" - job = HyperPodPytorchJob.get(name=job_name, namespace=namespace) + job = sys.modules[__name__].HyperPodPytorchJob.get(name=job_name, namespace=namespace) job.delete() -@click.command("hyp-pytorch-job") +@register_training_command("hyp-pytorch-job", "list-pods") @click.option( "--job-name", required=True, @@ -263,11 +298,11 @@ def pytorch_delete(job_name: str, namespace: str): default="default", help="Optional. The namespace of the job. 
Defaults to 'default' namespace.", ) -@_hyperpod_telemetry_emitter(Feature.HYPERPOD_CLI, "list_pods_pytorchjob_cli") +@LazyDecorator(_get_telemetry_emitter, lambda: sys.modules[__name__].Feature.HYPERPOD_CLI, "list_pods_pytorchjob_cli") @handle_cli_exceptions() def pytorch_list_pods(job_name: str, namespace: str): """List all HyperPod PyTorch pods related to the job.""" - job = HyperPodPytorchJob.get(name=job_name, namespace=namespace) + job = sys.modules[__name__].HyperPodPytorchJob.get(name=job_name, namespace=namespace) pods = job.list_pods() if not pods: @@ -292,7 +327,7 @@ def pytorch_list_pods(job_name: str, namespace: str): click.echo() -@click.command("hyp-pytorch-job") +@register_training_command("hyp-pytorch-job", "get-logs") @click.option( "--job-name", required=True, @@ -307,30 +342,30 @@ def pytorch_list_pods(job_name: str, namespace: str): default="default", help="Optional. The namespace of the job. Defaults to 'default' namespace.", ) -@_hyperpod_telemetry_emitter(Feature.HYPERPOD_CLI, "get_pytorchjob_logs_from_pod_cli") +@LazyDecorator(_get_telemetry_emitter, lambda: sys.modules[__name__].Feature.HYPERPOD_CLI, "get_pytorchjob_logs_from_pod_cli") @handle_cli_exceptions() def pytorch_get_logs(job_name: str, pod_name: str, namespace: str): """Get specific pod log for Hyperpod Pytorch job.""" click.echo("Listing logs for pod: " + pod_name) - job = HyperPodPytorchJob.get(name=job_name, namespace=namespace) + job = sys.modules[__name__].HyperPodPytorchJob.get(name=job_name, namespace=namespace) logs = job.get_logs_from_pod(pod_name=pod_name) # Use common log display utility for consistent formatting across all job types - display_formatted_logs(logs, title=f"Pod Logs for {pod_name}") + sys.modules[__name__].display_formatted_logs(logs, title=f"Pod Logs for {pod_name}") -@click.command("hyp-pytorch-job") +@register_training_command("hyp-pytorch-job", "get-operator-logs") @click.option( "--since-hours", type=click.FLOAT, required=True, help="Required. 
The time frame to get logs for.", ) -@_hyperpod_telemetry_emitter(Feature.HYPERPOD_CLI, "get_pytorch_operator_logs") +@LazyDecorator(_get_telemetry_emitter, lambda: sys.modules[__name__].Feature.HYPERPOD_CLI, "get_pytorch_operator_logs") @handle_cli_exceptions() def pytorch_get_operator_logs(since_hours: float): """Get operator logs for pytorch training jobs.""" - logs = HyperPodPytorchJob.get_operator_logs(since_hours=since_hours) + logs = sys.modules[__name__].HyperPodPytorchJob.get_operator_logs(since_hours=since_hours) # Use common log display utility for consistent formatting across all job types - display_formatted_logs(logs, title="PyTorch Operator Logs") + sys.modules[__name__].display_formatted_logs(logs, title="PyTorch Operator Logs") diff --git a/src/sagemaker/hyperpod/cli/hyp_cli.py b/src/sagemaker/hyperpod/cli/hyp_cli.py index c395845d..24a1ede6 100644 --- a/src/sagemaker/hyperpod/cli/hyp_cli.py +++ b/src/sagemaker/hyperpod/cli/hyp_cli.py @@ -1,156 +1,211 @@ import click -import yaml -import json -import os -import subprocess -from pydantic import BaseModel, ValidationError, Field from typing import Optional -from importlib.metadata import version, PackageNotFoundError - -from sagemaker.hyperpod.cli.commands.cluster import list_cluster, set_cluster_context, get_cluster_context, \ - get_monitoring -from sagemaker.hyperpod.cli.commands.training import ( - pytorch_create, - list_jobs, - pytorch_describe, - pytorch_delete, - pytorch_list_pods, - pytorch_get_logs, - pytorch_get_operator_logs, -) -from sagemaker.hyperpod.cli.commands.inference import ( - js_create, - custom_create, - custom_invoke, - js_list, - custom_list, - js_describe, - custom_describe, - js_delete, - custom_delete, - js_list_pods, - custom_list_pods, - js_get_logs, - custom_get_logs, - js_get_operator_logs, - custom_get_operator_logs, -) - - -def get_package_version(package_name): - try: - return version(package_name) - except PackageNotFoundError: - return "Not installed" - -def 
print_version(ctx, param, value): - if not value or ctx.resilient_parsing: - return - - hyp_version = get_package_version("sagemaker-hyperpod") - pytorch_template_version = get_package_version("hyperpod-pytorch-job-template") - custom_inference_version = get_package_version("hyperpod-custom-inference-template") - jumpstart_inference_version = get_package_version("hyperpod-jumpstart-inference-template") - - click.echo(f"hyp version: {hyp_version}") - click.echo(f"hyperpod-pytorch-job-template version: {pytorch_template_version}") - click.echo(f"hyperpod-custom-inference-template version: {custom_inference_version}") - click.echo(f"hyperpod-jumpstart-inference-template version: {jumpstart_inference_version}") - ctx.exit() - -@click.group() -@click.option('--version', is_flag=True, callback=print_version, expose_value=False, is_eager=True, help='Show version information') -def cli(): - pass - +from .command_registry import get_registry + +# Single source of truth for CLI help text +CLI_HELP_TEXT = { + 'create': 'Create endpoints or pytorch jobs.', + 'list': 'List endpoints or pytorch jobs.', + 'describe': 'Describe endpoints or pytorch jobs.', + 'delete': 'Delete endpoints or pytorch jobs.', + 'list-pods': 'List pods for endpoints or pytorch jobs.', + 'get-logs': 'Get pod logs for endpoints or pytorch jobs.', + 'invoke': 'Invoke model endpoints.', + 'get-operator-logs': 'Get operator logs for endpoints.', + 'list-cluster': 'List SageMaker Hyperpod Clusters with metadata.', + 'set-cluster-context': 'Connect to a HyperPod EKS cluster.', + 'get-cluster-context': 'Get context related to the current set cluster.', + 'get-monitoring': 'Get monitoring configurations for Hyperpod cluster.' 
# Custom CLI group that delays command registration until needed
class LazyGroup(click.Group):
    """Root click group that imports command modules only when a command is looked up.

    Command names come from the object returned by ``get_registry()``; the
    click command objects themselves are attached by ``_register_module``,
    which performs the heavy ``sagemaker.hyperpod.cli.commands.*`` imports.
    """

    def __init__(self, *args, **kwargs):
        """Initialise the click group and the bookkeeping used for lazy registration."""
        super().__init__(*args, **kwargs)

        # NOTE(review): the registry contract (ensure_commands_loaded, get_all_groups,
        # get_top_level_commands, get_module_for_command) is defined in
        # command_registry, which is not visible here - confirm there.
        self.registry = get_registry()

        # ensure that duplicate modules aren't loaded
        self.modules_registered = set()

    def list_commands(self, ctx):
        """Return the sorted, de-duplicated names of all groups and top-level commands."""
        # Ensure registry is initialized and commands are discovered
        self.registry.ensure_commands_loaded()

        subgroup_commands = self.registry.get_all_groups()

        top_level_commands = self.registry.get_top_level_commands()

        # Combine and return sorted list
        # (`+` here presumably concatenates two lists - confirm the registry's return types)
        all_commands = subgroup_commands + top_level_commands
        return sorted(set(all_commands))

    def get_help(self, ctx):
        """Render the help page through this class's own ``format_help``."""
        # Generate help without loading heavy modules
        formatter = ctx.make_formatter()
        self.format_help(ctx, formatter)
        return formatter.getvalue()

    def format_usage(self, ctx, formatter):
        """Write the usage line, using the root context's program name."""
        pieces = self.collect_usage_pieces(ctx)
        prog_name = ctx.find_root().info_name
        formatter.write_usage(prog_name, ' '.join(pieces))

    def format_help(self, ctx, formatter):
        """Write usage, help text, options and the command list, in that order."""
        self.format_usage(ctx, formatter)
        self.format_help_text(ctx, formatter)
        self.format_options(ctx, formatter)
        self.format_commands(ctx, formatter)

    def format_commands(self, ctx, formatter):
        """Write the 'Commands' section from ``CLI_HELP_TEXT`` instead of command objects."""
        commands = []
        for name in self.list_commands(ctx):
            # Names without a CLI_HELP_TEXT entry get a generated "<Title Cased Name> operations." line.
            help_text = CLI_HELP_TEXT.get(name, f'{name.replace("-", " ").title()} operations.')
            commands.append((name, help_text))

        if commands:
            with formatter.section('Commands'):
                formatter.write_dl(commands)

    def get_command(self, ctx, name):
        """Register whatever module provides ``name``, then defer to click's lookup."""
        self.registry.ensure_commands_loaded()

        # Register modules when actually needed
        self._register_module_for_command(name)
        return super().get_command(ctx, name)


    def _register_module_for_command(self, name):
        """Register only the module needed for a specific command"""
        module_name = self.registry.get_module_for_command(name)

        if module_name and module_name not in self.modules_registered:
            self._register_module(module_name)
        elif name in self.registry.get_all_groups():
            # These are subgroup commands - register all modules so subcommands show in help
            # (this branch is also reached when the owning module is already registered)
            self._ensure_all_modules_registered()

    def _ensure_all_modules_registered(self):
        """Register all modules - used when actually accessing subcommands"""
        # Trigger self-registration by ensuring commands are loaded
        self.registry.ensure_commands_loaded()

        # The module list is hard-coded; a new command module must be added here
        # and given a branch in _register_module.
        for module in ['training', 'inference', 'cluster']:
            if module not in self.modules_registered:
                self._register_module(module)

    def _register_module(self, module_name):
        """Import one command module and attach its commands to this group.

        ``training`` and ``inference`` commands are added to subgroups looked up
        in ``self.commands`` (``create``, ``list``, ...), so those subgroups must
        already be registered on this group or the lookup raises ``KeyError``.
        ``cluster`` commands are added directly to this group. An unrecognised
        ``module_name`` imports nothing but is still recorded as registered.
        """
        if module_name in self.modules_registered:
            return

        if module_name == 'training':
            from sagemaker.hyperpod.cli.commands.training import (
                pytorch_create, list_jobs, pytorch_describe, pytorch_delete,
                pytorch_list_pods, pytorch_get_logs, pytorch_get_operator_logs
            )

            self.commands['create'].add_command(pytorch_create)
            self.commands['list'].add_command(list_jobs)
            self.commands['describe'].add_command(pytorch_describe)
            self.commands['delete'].add_command(pytorch_delete)
            self.commands['list-pods'].add_command(pytorch_list_pods)
            self.commands['get-logs'].add_command(pytorch_get_logs)
            self.commands['get-operator-logs'].add_command(pytorch_get_operator_logs)

        elif module_name == 'inference':
            from sagemaker.hyperpod.cli.commands.inference import (
                js_create, custom_create, custom_invoke, js_list, custom_list,
                js_describe, custom_describe, js_delete, custom_delete,
                js_list_pods, custom_list_pods, js_get_logs, custom_get_logs,
                js_get_operator_logs, custom_get_operator_logs,
            )

            self.commands['create'].add_command(js_create)
            self.commands['create'].add_command(custom_create)
            self.commands['list'].add_command(js_list)
            self.commands['list'].add_command(custom_list)
            self.commands['describe'].add_command(js_describe)
            self.commands['describe'].add_command(custom_describe)
            self.commands['delete'].add_command(js_delete)
            self.commands['delete'].add_command(custom_delete)
            self.commands['list-pods'].add_command(js_list_pods)
            self.commands['list-pods'].add_command(custom_list_pods)
            self.commands['get-logs'].add_command(js_get_logs)
            self.commands['get-logs'].add_command(custom_get_logs)
            self.commands['get-operator-logs'].add_command(js_get_operator_logs)
            self.commands['get-operator-logs'].add_command(custom_get_operator_logs)
            # NOTE(review): custom_invoke is registered once, under its own command
            # name only - confirm no caller relies on a second alias under `invoke`.
            self.commands['invoke'].add_command(custom_invoke)

        elif module_name == 'cluster':
            from sagemaker.hyperpod.cli.commands.cluster import list_cluster, set_cluster_context, get_cluster_context, get_monitoring
            self.add_command(list_cluster)
            self.add_command(set_cluster_context)
            self.add_command(get_cluster_context)
            self.add_command(get_monitoring)

        self.modules_registered.add(module_name)
"""Get pod logs for endpoints or pytorch jobs.""" pass +get_logs.__doc__ = CLI_HELP_TEXT['get-logs'] @cli.group(cls=CLICommand) def invoke(): - """Invoke model endpoints.""" pass +invoke.__doc__ = CLI_HELP_TEXT['invoke'] @cli.group(cls=CLICommand) def get_operator_logs(): - """Get operator logs for endpoints.""" pass - - -create.add_command(pytorch_create) -create.add_command(js_create) -create.add_command(custom_create) - -list.add_command(list_jobs) -list.add_command(js_list) -list.add_command(custom_list) - -describe.add_command(pytorch_describe) -describe.add_command(js_describe) -describe.add_command(custom_describe) - -delete.add_command(pytorch_delete) -delete.add_command(js_delete) -delete.add_command(custom_delete) - -list_pods.add_command(pytorch_list_pods) -list_pods.add_command(js_list_pods) -list_pods.add_command(custom_list_pods) - -get_logs.add_command(pytorch_get_logs) -get_logs.add_command(js_get_logs) -get_logs.add_command(custom_get_logs) - -get_operator_logs.add_command(pytorch_get_operator_logs) -get_operator_logs.add_command(js_get_operator_logs) -get_operator_logs.add_command(custom_get_operator_logs) - -invoke.add_command(custom_invoke) -invoke.add_command(custom_invoke, name="hyp-jumpstart-endpoint") - -cli.add_command(list_cluster) -cli.add_command(set_cluster_context) -cli.add_command(get_cluster_context) -cli.add_command(get_monitoring) +get_operator_logs.__doc__ = CLI_HELP_TEXT['get-operator-logs'] if __name__ == "__main__": diff --git a/src/sagemaker/hyperpod/common/lazy_loading.py b/src/sagemaker/hyperpod/common/lazy_loading.py new file mode 100644 index 00000000..a26eda70 --- /dev/null +++ b/src/sagemaker/hyperpod/common/lazy_loading.py @@ -0,0 +1,329 @@ +""" +Common lazy loading infrastructure for deferred imports and CLI performance optimization. + +This module provides reusable components for implementing lazy loading patterns +that improve startup performance while maintaining full functionality. 
import functools
import importlib
import sys
from typing import Any, Dict, List, Callable, Optional, Union


def _import_from_path(import_path: str) -> Any:
    """Resolve an import path string to the object it names.

    Args:
        import_path: Either ``"package.module:attribute"`` or ``"package.module"``.

    Returns:
        The named attribute, or the module itself when no ``:attribute`` part
        is given.

    Raises:
        ImportError: If the module cannot be imported.
        AttributeError: If the module lacks the requested attribute.
    """
    module_path, separator, attr_name = import_path.partition(':')
    # importlib.import_module returns the leaf module for dotted paths, whereas a
    # bare __import__("a.b") returns the top-level package "a".
    module = importlib.import_module(module_path)
    return getattr(module, attr_name) if separator else module


class LazyRegistry:
    """
    A registry that provides version info for CLI generation but lazy-loads model classes for execution.

    This class implements a two-tier approach:
    - CLI Generation Time: Provides version info without importing heavy dependencies
    - Execution Time: Lazy-loads the real registry with model classes when needed
    """

    def __init__(
        self,
        versions: List[str],
        real_registry_loader: Optional[Callable] = None,
        registry_import_path: Optional[str] = None
    ):
        """
        Initialize LazyRegistry.

        Args:
            versions: List of supported versions for CLI generation
            real_registry_loader: Function to load the real registry (optional)
            registry_import_path: Import path for the real registry (alternative to loader)
        """
        self.versions = versions
        self.real_registry_loader = real_registry_loader
        self.registry_import_path = registry_import_path
        self._real_registry = None

    def _load_real_registry(self):
        """Load the real registry once, preferring the loader over the import path.

        Raises:
            ValueError: If neither a loader nor an import path was configured.
        """
        if self._real_registry is not None:
            return

        if self.real_registry_loader:
            self._real_registry = self.real_registry_loader()
        elif self.registry_import_path:
            self._real_registry = _import_from_path(self.registry_import_path)
        else:
            raise ValueError("Either real_registry_loader or registry_import_path must be provided")

    def keys(self):
        """Provide version keys for CLI generation."""
        return self.versions

    def get(self, version):
        """Lazy-load the real registry and return its entry for ``version``."""
        self._load_real_registry()
        return self._real_registry.get(version)

    def __contains__(self, version):
        """Support version checking for CLI generation."""
        return version in self.versions

    def items(self):
        """Support iteration for CLI generation - only provide keys, not values."""
        return [(version, None) for version in self.versions]


class LazyDecorator:
    """
    A decorator that applies decorators based on their type.

    CLI generation decorators (like generate_click_command) are applied immediately
    to ensure proper help text generation. Execution decorators (like telemetry)
    are deferred until command execution.
    """

    # Getter function names whose decorators must be applied at decoration time.
    _IMMEDIATE_GETTERS = frozenset({'_get_generate_click_command'})

    def __init__(self, decorator_getter: Callable, *args, **kwargs):
        """
        Initialize LazyDecorator.

        Args:
            decorator_getter: Function that returns the decorator factory
            *args, **kwargs: Arguments to pass to the decorator factory. Any
                callable argument is invoked (with no arguments) and replaced
                by its result when the decorator is built.
        """
        self.decorator_getter = decorator_getter
        self.args = args
        self.kwargs = kwargs
        self._cached_decorator = None

    def __call__(self, func):
        """Apply the decorator immediately or lazily, depending on the getter's name."""
        getter_name = getattr(self.decorator_getter, '__name__', None)
        if getter_name in self._IMMEDIATE_GETTERS:
            return self._apply_immediately(func)
        # Defer execution decorators like telemetry
        return self._apply_deferred(func)

    def _build_decorator(self):
        """Fetch the decorator factory and call it with the resolved arguments."""
        decorator_factory = self.decorator_getter()

        # Resolve any callable arguments (like lambda functions)
        resolved_args = [arg() if callable(arg) else arg for arg in self.args]
        resolved_kwargs = {k: (v() if callable(v) else v) for k, v in self.kwargs.items()}

        return decorator_factory(*resolved_args, **resolved_kwargs)

    def _apply_immediately(self, func):
        """Apply decorator immediately for CLI generation."""
        return self._build_decorator()(func)

    def _apply_deferred(self, func):
        """Return a wrapper that builds and applies the decorator on first call.

        The decorated callable is cached per wrapped function (in the closure),
        so one LazyDecorator instance can safely decorate several functions.
        """
        decorated = None

        @functools.wraps(func)
        def wrapper(*wrapper_args, **wrapper_kwargs):
            nonlocal decorated
            if decorated is None:
                if self._cached_decorator is None:
                    self._cached_decorator = self._build_decorator()
                decorated = self._cached_decorator(func)
            return decorated(*wrapper_args, **wrapper_kwargs)

        # functools.wraps skips attributes the wrapped object lacks; keep the
        # historical fallback name for callables without __name__.
        if not hasattr(func, '__name__'):
            wrapper.__name__ = 'wrapped_function'
        return wrapper


class LazyImportManager:
    """
    Manages lazy imports using a mapping-based approach.

    This provides a clean, declarative way to define lazy imports without
    coupling to specific module implementations.
    """

    def __init__(self, import_mapping: Dict[str, str]):
        """
        Initialize LazyImportManager.

        Args:
            import_mapping: Dict mapping attribute names to import paths
                           Format: "module.path:attribute" or "module_name"
        """
        self.import_mapping = import_mapping
        self._cached_imports = {}

    def get_lazy_import(self, name: str) -> Any:
        """
        Get a lazy import by name.

        Args:
            name: Name of the import to retrieve

        Returns:
            The imported object

        Raises:
            AttributeError: If the import name is not found
        """
        if name in self._cached_imports:
            return self._cached_imports[name]

        if name not in self.import_mapping:
            raise AttributeError(f"No lazy import defined for '{name}'")

        obj = _import_from_path(self.import_mapping[name])

        # Cache for future access
        self._cached_imports[name] = obj
        return obj

    def create_getattr_function(self, module_name: str) -> Callable[[str], Any]:
        """
        Create a __getattr__ function for a module.

        Args:
            module_name: Name of the module (for error messages)

        Returns:
            A __getattr__ function that can be used in a module
        """
        def __getattr__(name: str) -> Any:
            try:
                obj = self.get_lazy_import(name)
            except AttributeError as exc:
                # Chain the cause so a misconfigured import path stays diagnosable.
                raise AttributeError(f"module '{module_name}' has no attribute '{name}'") from exc
            # Cache it in the module namespace for direct access
            setattr(sys.modules[module_name], name, obj)
            return obj

        return __getattr__


def create_critical_deps_loader(
    dependencies: Dict[str, str],
    module_name: str,
    extra_setup: Optional[Callable] = None
) -> Callable:
    """
    Create a function to load critical dependencies for decorators.

    Args:
        dependencies: Dict mapping dependency names to import paths
        module_name: Name of the module (for setting in sys.modules)
        extra_setup: Optional function for additional setup logic

    Returns:
        Function that loads critical dependencies
    """
    def _ensure_critical_deps():
        """Load critical dependencies needed for decorators."""
        deps = {}

        for name, import_path in dependencies.items():
            try:
                obj = _import_from_path(import_path)
            except ImportError:
                # Deliberate best-effort: a dependency that cannot be imported
                # during module loading is left out rather than failing the load.
                continue

            deps[name] = obj
            # Set in module namespace for immediate access
            setattr(sys.modules[module_name], name, obj)

        # Call any extra setup function
        if extra_setup:
            extra_setup(deps)

        return deps

    return _ensure_critical_deps
+ + Args: + module_name: The __name__ of the module + config: Configuration dict with keys: + - exports: List of what should be available (__all__) + - template_packages: Dict of template package configs + - critical_deps: List of dependencies needed for decorators + - lazy_imports: Dict mapping names to import paths + - extra_setup: Optional function for additional setup + """ + module = sys.modules[module_name] + + # Set __all__ for the module + if 'exports' in config: + setattr(module, '__all__', config['exports']) + + # Create module config if template packages specified + if 'template_packages' in config: + setattr(module, '_MODULE_CONFIG', config['template_packages']) + + # Setup critical dependencies + if 'critical_deps' in config: + critical_deps = {} + base_paths = { + 'telemetry_emitter': 'sagemaker.hyperpod.common.telemetry.telemetry_logging:_hyperpod_telemetry_emitter', + 'telemetry_feature': 'sagemaker.hyperpod.common.telemetry.constants:Feature', + 'training_utils': 'sagemaker.hyperpod.cli.training_utils:generate_click_command', + 'inference_utils': 'sagemaker.hyperpod.cli.inference_utils:generate_click_command' + } + + for dep in config['critical_deps']: + if dep in base_paths: + # Map to final attribute names + if dep == 'telemetry_emitter': + critical_deps['_hyperpod_telemetry_emitter'] = base_paths[dep] + elif dep == 'telemetry_feature': + critical_deps['Feature'] = base_paths[dep] + elif dep == 'training_utils': + critical_deps['generate_click_command'] = base_paths[dep] + elif dep == 'inference_utils': + critical_deps['generate_click_command'] = base_paths[dep] + else: + critical_deps[dep] = config['lazy_imports'].get(dep, dep) + + # Create and call critical deps loader + extra_setup = config.get('extra_setup') + ensure_critical_deps = create_critical_deps_loader( + dependencies=critical_deps, + module_name=module_name, + extra_setup=extra_setup + ) + ensure_critical_deps() + setattr(module, '_ensure_critical_deps', ensure_critical_deps) + + # 
Setup lazy imports + if 'lazy_imports' in config: + import_manager = LazyImportManager(config['lazy_imports']) + getattr_func = import_manager.create_getattr_function(module_name) + setattr(module, '__getattr__', getattr_func) + setattr(module, '_import_manager', import_manager) + + # Helper functions for decorators + if 'telemetry_emitter' in config.get('critical_deps', []): + def _get_telemetry_emitter(): + return getattr(module, '_hyperpod_telemetry_emitter') + setattr(module, '_get_telemetry_emitter', _get_telemetry_emitter) + + if any('utils' in dep for dep in config.get('critical_deps', [])): + def _get_generate_click_command(): + return getattr(module, 'generate_click_command') + setattr(module, '_get_generate_click_command', _get_generate_click_command)