-
Notifications
You must be signed in to change notification settings - Fork 28.9k
[SPARK-14585][ML][WIP] Provide accessor methods for Pipeline stages #12420
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -111,6 +111,18 @@ class Pipeline @Since("1.4.0") ( | |
| @Since("1.2.0") | ||
| def getStages: Array[PipelineStage] = $(stages).clone() | ||
|
|
||
| /** | ||
| * Returns the stage at index `i` of this Pipeline, cast to `T`. | ||
| * | ||
| * The cast is unchecked because `T` is erased at runtime, so asking for the wrong | ||
| * type is only detected where the caller uses the result as a `T`. | ||
| * | ||
| * @tparam T type the caller expects the stage to have | ||
| * @param i zero-based position of the stage within the stages array | ||
| * @return the stage stored at position `i` | ||
| */ | ||
| @Since("2.0.0") | ||
| // Read the param directly: getStages would clone the whole array for a single lookup. | ||
| def getStage[T <: PipelineStage](i: Int): T = $(stages)(i).asInstanceOf[T] | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Here and elsewhere, I would just write `array(i)` instead of `array.apply(i)`. Also, can you please document the |
||
|
|
||
| /** Returns all stages of this type */ | ||
| @Since("2.0.0") | ||
| def getStagesOfType[T <: PipelineStage]: Array[T] = { | ||
| getStages.collect { | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Nit: This is probably more natural as a one-liner. |
||
| case stage: T => stage | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * Fits the pipeline to the input dataset with additional parameters. If a stage is an | ||
| * [[Estimator]], its [[Estimator#fit]] method will be called on the input dataset to fit a model. | ||
|
|
@@ -309,6 +321,32 @@ class PipelineModel private[ml] ( | |
|
|
||
| @Since("1.6.0") | ||
| override def write: MLWriter = new PipelineModel.PipelineModelWriter(this) | ||
|
|
||
| /** | ||
| * Looks up the stage stored at position `i` of this PipelineModel and hands it | ||
| * back viewed as a `T`. | ||
| * | ||
| * The cast is unchecked because `T` is erased at runtime; a mismatch is only | ||
| * detected where the caller uses the result as a `T`. | ||
| * | ||
| * @tparam T type the caller expects the stage to have | ||
| * @param i position of the stage within `stages` | ||
| * @return the stage stored at position `i` | ||
| */ | ||
| @Since("2.0.0") | ||
| def getStage[T <: Transformer](i: Int): T = stages(i).asInstanceOf[T] | ||
|
|
||
| /** | ||
| * Returns stage given its parent or generating instance in PipelineModel. | ||
| * E.g., if this PipelineModel was created from a Pipeline containing a stage | ||
| * `myStage`, then passing `myStage` to this method will return the | ||
| * corresponding stage in this PipelineModel. | ||
| */ | ||
| @Since("2.0.0") | ||
| def getStage[T <: Transformer, E <: PipelineStage](stage: E): T = { | ||
| val idxInPipeline = this.parent.asInstanceOf[Pipeline].getStages.indexOf(stage) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This should test against the uid, not the stage reference. |
||
| stages.apply(idxInPipeline).asInstanceOf[T] | ||
| } | ||
|
|
||
| /** | ||
| * Returns every stage of this PipelineModel that is an instance of `T`. | ||
| * | ||
| * A `ClassTag` is required because `T` is erased at runtime: it turns the type | ||
| * pattern below into a real runtime check and allows the result to be built as | ||
| * an `Array[T]`. | ||
| * | ||
| * @tparam T type of the stages to select | ||
| * @return the matching stages, in the order they appear in `stages` | ||
| */ | ||
| @Since("2.0.0") | ||
| def getStagesOfType[T <: Transformer : scala.reflect.ClassTag]: Array[T] = | ||
| stages.collect { case stage: T => stage } | ||
| } | ||
|
|
||
| @Since("1.6.0") | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can't be since 2.0.0 at this point. Also use a
`@return` tag in the docs.