From a4592c625bd0cb60b5c5c83e9391f75d38c2eaa6 Mon Sep 17 00:00:00 2001
From: Pravin Gadakh
Date: Fri, 15 Apr 2016 20:24:51 +0530
Subject: [PATCH] Added accessor methods for Pipeline

---
 .../scala/org/apache/spark/ml/Pipeline.scala | 78 +++++++++++++++++++
 1 file changed, 78 insertions(+)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
index 82066726a0694..007f371b6f4f1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
@@ -111,6 +111,31 @@ class Pipeline @Since("1.4.0") (
   @Since("1.2.0")
   def getStages: Array[PipelineStage] = $(stages).clone()
 
+  /**
+   * Returns the stage at index `i` in this Pipeline, cast to `T`.
+   *
+   * The cast is unchecked because of type erasure; if the stage is not a `T`, the
+   * `ClassCastException` surfaces at the call site rather than inside this method.
+   *
+   * @tparam T expected type of the stage
+   * @param i zero-based index of the stage
+   */
+  @Since("2.0.0")
+  def getStage[T <: PipelineStage](i: Int): T = $(stages).apply(i).asInstanceOf[T]
+
+  /**
+   * Returns all stages that are instances of `T`, in pipeline order.
+   *
+   * The `ClassTag` context bound turns the type pattern into a real runtime check (a bare
+   * `case stage: T` is unchecked under erasure) and allows an `Array[T]` to be built.
+   *
+   * @tparam T type of the stages to select
+   */
+  @Since("2.0.0")
+  def getStagesOfType[T <: PipelineStage : scala.reflect.ClassTag]: Array[T] = {
+    $(stages).collect { case stage: T => stage }
+  }
+
   /**
    * Fits the pipeline to the input dataset with additional parameters. If a stage is an
    * [[Estimator]], its [[Estimator#fit]] method will be called on the input dataset to fit a model.
@@ -309,6 +334,59 @@ class PipelineModel private[ml] (
 
   @Since("1.6.0")
   override def write: MLWriter = new PipelineModel.PipelineModelWriter(this)
+
+  /**
+   * Returns the stage at index `i` in this PipelineModel, cast to `T`.
+   *
+   * The cast is unchecked because of type erasure; if the stage is not a `T`, the
+   * `ClassCastException` surfaces at the call site rather than inside this method.
+   *
+   * @tparam T expected type of the stage
+   * @param i zero-based index of the stage
+   */
+  @Since("2.0.0")
+  def getStage[T <: Transformer](i: Int): T = stages.apply(i).asInstanceOf[T]
+
+  /**
+   * Returns the stage of this PipelineModel that corresponds to the given stage of the
+   * parent Pipeline. E.g., if this PipelineModel was created from a Pipeline containing a
+   * stage `myStage`, then passing `myStage` to this method will return the corresponding
+   * stage in this PipelineModel.
+   *
+   * @tparam T expected type of the returned stage (unchecked cast)
+   * @tparam E type of the Pipeline stage used for the lookup
+   * @param stage a stage of the parent Pipeline
+   * @throws IllegalStateException if the parent of this model is not a Pipeline
+   * @throws NoSuchElementException if `stage` is not a stage of the parent Pipeline
+   */
+  @Since("2.0.0")
+  def getStage[T <: Transformer, E <: PipelineStage](stage: E): T = {
+    val parentPipeline = this.parent match {
+      case pipeline: Pipeline => pipeline
+      case _ =>
+        throw new IllegalStateException(
+          "PipelineModel has no parent Pipeline; cannot look up a stage by its parent stage.")
+    }
+    val idxInPipeline = parentPipeline.getStages.indexOf(stage)
+    if (idxInPipeline < 0) {
+      throw new NoSuchElementException(
+        s"Stage $stage is not a stage of the parent Pipeline of this PipelineModel.")
+    }
+    stages.apply(idxInPipeline).asInstanceOf[T]
+  }
+
+  /**
+   * Returns all stages that are instances of `T`, in pipeline order.
+   *
+   * The `ClassTag` context bound turns the type pattern into a real runtime check (a bare
+   * `case stage: T` is unchecked under erasure) and allows an `Array[T]` to be built.
+   *
+   * @tparam T type of the stages to select
+   */
+  @Since("2.0.0")
+  def getStagesOfType[T <: Transformer : scala.reflect.ClassTag]: Array[T] = {
+    stages.collect { case stage: T => stage }
+  }
 }
 
 @Since("1.6.0")