Skip to content

Commit a4592c6

Browse files
author
Pravin Gadakh
committed
Added accessor methods for Pipeline
1 parent 06b9d62 commit a4592c6

File tree

1 file changed

+38
-0
lines changed

1 file changed

+38
-0
lines changed

mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,18 @@ class Pipeline @Since("1.4.0") (
111111
@Since("1.2.0")
112112
def getStages: Array[PipelineStage] = $(stages).clone()
113113

114+
/** Returns stage at index i in Pipeline */
115+
@Since("2.0.0")
116+
def getStage[T <: PipelineStage](i: Int): T = getStages.apply(i).asInstanceOf[T]
117+
118+
/** Returns all stages of this type */
119+
@Since("2.0.0")
120+
def getStagesOfType[T <: PipelineStage]: Array[T] = {
121+
getStages.collect {
122+
case stage: T => stage
123+
}
124+
}
125+
114126
/**
115127
* Fits the pipeline to the input dataset with additional parameters. If a stage is an
116128
* [[Estimator]], its [[Estimator#fit]] method will be called on the input dataset to fit a model.
@@ -309,6 +321,32 @@ class PipelineModel private[ml] (
309321

310322
@Since("1.6.0")
311323
override def write: MLWriter = new PipelineModel.PipelineModelWriter(this)
324+
325+
/** Returns stage at index i in PipelineModel */
326+
@Since("2.0.0")
327+
def getStage[T <: Transformer](i: Int): T = {
328+
stages.apply(i).asInstanceOf[T]
329+
}
330+
331+
/**
332+
* Returns stage given its parent or generating instance in PipelineModel.
333+
* E.g., if this PipelineModel was created from a Pipeline containing a stage
334+
* {{myStage}}, then passing {{myStage}} to this method will return the
335+
* corresponding stage in this PipelineModel.
336+
*/
337+
@Since("2.0.0")
338+
def getStage[T <: Transformer, E <: PipelineStage](stage: E): T = {
339+
val idxInPipeline = this.parent.asInstanceOf[Pipeline].getStages.indexOf(stage)
340+
stages.apply(idxInPipeline).asInstanceOf[T]
341+
}
342+
343+
/** Returns all stages of this type */
344+
@Since("2.0.0")
345+
def getStagesOfType[T <: Transformer]: Array[T] = {
346+
stages.collect {
347+
case stage: T => stage
348+
}
349+
}
312350
}
313351

314352
@Since("1.6.0")

0 commit comments

Comments
 (0)