     Processor,
 )
 from sagemaker.transformer import Transformer, _TransformJob
+from sagemaker.tuner import HyperparameterTuner, _TuningJob
 from sagemaker.workflow.entities import (
     DefaultEnumMeta,
     Entity,
     PropertyFile,
     Properties,
 )
+from sagemaker.workflow.functions import Join


 class StepTypeEnum(Enum, metaclass=DefaultEnumMeta):
@@ -51,6 +53,7 @@ class StepTypeEnum(Enum, metaclass=DefaultEnumMeta):
     TRAINING = "Training"
     TRANSFORM = "Transform"
     CALLBACK = "Callback"
+    TUNING = "Tuning"


 @attr.s
@@ -92,6 +95,7 @@ def add_depends_on(self, step_names: List[str]):
9295 """Add step names to the current step depends on list"""
9396 if not step_names :
9497 return
98+
9599 if not self .depends_on :
96100 self .depends_on = []
97101 self .depends_on .extend (step_names )
@@ -429,3 +433,132 @@ def to_request(self) -> RequestType:
             property_file.expr for property_file in self.property_files
         ]
         return request_dict
+
+
+class TuningStep(Step):
+    """Tuning step for workflow."""
+
+    def __init__(
+        self,
+        name: str,
+        tuner: HyperparameterTuner,
+        inputs=None,
+        job_arguments: List[str] = None,
+        cache_config: CacheConfig = None,
+        depends_on: List[str] = None,
+    ):
+        """Construct a TuningStep, given a `HyperparameterTuner` instance.
+
+        In addition to the tuner instance, the other arguments are those that are supplied to
+        the `fit` method of the `sagemaker.tuner.HyperparameterTuner`.
+
+        Args:
+            name (str): The name of the tuning step.
+            tuner (HyperparameterTuner): A `sagemaker.tuner.HyperparameterTuner` instance.
+            inputs: Information about the training data. Please refer to the
+                ``fit()`` method of the associated estimator, as this can take
+                any of the following forms:
+
+                * (str) - The S3 location where training data is saved.
+                * (dict[str, str] or dict[str, sagemaker.inputs.TrainingInput]) -
+                    If using multiple channels for training data, you can specify
+                    a dict mapping channel names to strings or
+                    :func:`~sagemaker.inputs.TrainingInput` objects.
+                * (sagemaker.inputs.TrainingInput) - Channel configuration for S3 data sources
+                    that can provide additional information about the training dataset.
+                    See :func:`sagemaker.inputs.TrainingInput` for full details.
+                * (sagemaker.session.FileSystemInput) - Channel configuration for
+                    a file system data source that can provide additional information as well as
+                    the path to the training dataset.
+                * (sagemaker.amazon.amazon_estimator.RecordSet) - A collection of
+                    Amazon :class:`~Record` objects serialized and stored in S3.
+                    For use with an estimator for an Amazon algorithm.
+                * (sagemaker.amazon.amazon_estimator.FileSystemRecordSet) -
+                    Amazon SageMaker channel configuration for a file system data source for
+                    Amazon algorithms.
+                * (list[sagemaker.amazon.amazon_estimator.RecordSet]) - A list of
+                    :class:`~sagemaker.amazon.amazon_estimator.RecordSet` objects,
+                    where each instance is a different channel of training data.
+                * (list[sagemaker.amazon.amazon_estimator.FileSystemRecordSet]) - A list of
+                    :class:`~sagemaker.amazon.amazon_estimator.FileSystemRecordSet` objects,
+                    where each instance is a different channel of training data.
+            job_arguments (List[str]): A list of strings to be passed into the tuning job.
+                Defaults to `None`.
+            cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
+            depends_on (List[str]): A list of step names that this `sagemaker.workflow.steps.TuningStep`
+                depends on.
490+ """
491+ super (TuningStep , self ).__init__ (name , StepTypeEnum .TUNING , depends_on )
492+ self .tuner = tuner
493+ self .inputs = inputs
494+ self .job_arguments = job_arguments
495+ self ._properties = Properties (
496+ path = f"Steps.{ name } " ,
497+ shape_names = [
498+ "DescribeHyperParameterTuningJobResponse" ,
499+ "ListTrainingJobsForHyperParameterTuningJobResponse" ,
500+ ],
501+ )
502+ self .cache_config = cache_config
503+
504+ @property
505+ def arguments (self ) -> RequestType :
506+ """The arguments dict that is used to call `create_hyper_parameter_tuning_job`.
507+
508+ NOTE: The CreateHyperParameterTuningJob request is not quite the
509+ args list that workflow needs.
510+ The HyperParameterTuningJobName attribute cannot be included.
511+ """
+        if self.tuner.estimator is not None:
+            self.tuner.estimator._prepare_for_training()
+        else:
+            for _, estimator in self.tuner.estimator_dict.items():
+                estimator._prepare_for_training()
+
+        self.tuner._prepare_for_tuning()
+        tuner_args = _TuningJob._get_tuner_args(self.tuner, self.inputs)
+        request_dict = self.tuner.sagemaker_session._get_tuning_request(**tuner_args)
+        request_dict.pop("HyperParameterTuningJobName")
+
+        return request_dict
+
+    @property
+    def properties(self):
527+ """A Properties object representing
528+
529+ `DescribeHyperParameterTuningJobResponse` and
530+ `ListTrainingJobsForHyperParameterTuningJobResponse` data model.
531+ """
+        return self._properties
+
+    def to_request(self) -> RequestType:
+        """Updates the dictionary with cache configuration."""
+        request_dict = super().to_request()
+        if self.cache_config:
+            request_dict.update(self.cache_config.config)
+
+        return request_dict
+
+    def get_top_model_s3_uri(self, top_k: int, s3_bucket: str, prefix: str = ""):
+        """Get the model artifact S3 URI from the top performing training jobs.
+
+        Args:
+            top_k (int): The index of the top performing training job. The
+                tuning step stores up to 50 top performing training jobs, so
+                a valid top_k value ranges from 0 to 49; the best training
+                job's model is at index 0.
+            s3_bucket (str): The S3 bucket that stores the training job output artifacts.
+            prefix (str): The S3 key prefix under which the training job output artifacts are stored.
+        """
+        values = ["s3:/", s3_bucket]
+        if prefix != "" and prefix is not None:
+            values.append(prefix)
+
+        return Join(
+            on="/",
+            values=values
+            + [
+                self.properties.TrainingJobSummaries[top_k].TrainingJobName,
+                "output/model.tar.gz",
+            ],
+        )
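
A minimal usage sketch (not part of the diff above) showing how the new `TuningStep` could be wired into a pipeline. The estimator settings, metric definition, bucket, and other names below are illustrative assumptions rather than code from this change:

    from sagemaker.estimator import Estimator
    from sagemaker.inputs import TrainingInput
    from sagemaker.tuner import ContinuousParameter, HyperparameterTuner
    from sagemaker.workflow.steps import TuningStep

    # Placeholder training image, role, and bucket -- replace with real values.
    estimator = Estimator(
        image_uri="<training-image-uri>",
        role="<execution-role-arn>",
        instance_count=1,
        instance_type="ml.m5.xlarge",
    )
    tuner = HyperparameterTuner(
        estimator=estimator,
        objective_metric_name="validation:rmse",
        objective_type="Minimize",
        hyperparameter_ranges={"eta": ContinuousParameter(0.01, 0.3)},
        metric_definitions=[{"Name": "validation:rmse", "Regex": "rmse: ([0-9\\.]+)"}],
        max_jobs=4,
        max_parallel_jobs=2,
    )

    step_tuning = TuningStep(
        name="HPTuning",
        tuner=tuner,
        inputs={"train": TrainingInput(s3_data="s3://<bucket>/train")},
    )

    # The best model artifact can then feed a downstream step (e.g. model creation).
    best_model_uri = step_tuning.get_top_model_s3_uri(
        top_k=0, s3_bucket="<bucket>", prefix="tuning-output"
    )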