Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,8 @@ class MultilayerPerceptronClassificationModel private[ml] (

@Since("1.5.0")
override def copy(extra: ParamMap): MultilayerPerceptronClassificationModel = {
copyValues(new MultilayerPerceptronClassificationModel(uid, layers, weights), extra)
val copied = new MultilayerPerceptronClassificationModel(uid, layers, weights).setParent(parent)
copyValues(copied, extra)
}

@Since("2.0.0")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,10 @@ class BucketedRandomProjectionLSHModel private[ml](
}

@Since("2.1.0")
override def copy(extra: ParamMap): this.type = defaultCopy(extra)
override def copy(extra: ParamMap): BucketedRandomProjectionLSHModel = {
val copied = new BucketedRandomProjectionLSHModel(uid, randUnitVectors).setParent(parent)
copyValues(copied, extra)
}

@Since("2.1.0")
override def write: MLWriter = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,10 @@ class MinHashLSHModel private[ml](
}

@Since("2.1.0")
override def copy(extra: ParamMap): this.type = defaultCopy(extra)
override def copy(extra: ParamMap): MinHashLSHModel = {
val copied = new MinHashLSHModel(uid, randCoefficients).setParent(parent)
copyValues(copied, extra)
}

@Since("2.1.0")
override def write: MLWriter = new MinHashLSHModel.MinHashLSHModelWriter(this)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -268,8 +268,10 @@ class RFormulaModel private[feature](
}

@Since("1.5.0")
override def copy(extra: ParamMap): RFormulaModel = copyValues(
new RFormulaModel(uid, resolvedFormula, pipelineModel))
override def copy(extra: ParamMap): RFormulaModel = {
val copied = new RFormulaModel(uid, resolvedFormula, pipelineModel).setParent(parent)
copyValues(copied, extra)
}

@Since("2.0.0")
override def toString: String = s"RFormulaModel($resolvedFormula) (uid=$uid)"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ class MultilayerPerceptronClassifierSuite
.setMaxIter(100)
.setSolver("l-bfgs")
val model = trainer.fit(dataset)
MLTestingUtils.checkCopy(model)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just something for consideration. Not sure if this is the best way to add unit test for this. I don't think we need to add check only for the classes included in this PR. And even if we do, maybe it's possible to do it in a more uniform way, like in testEstimatorAndModelReadWrite or add a pipelineCompatibilityTest as necessary.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @hhbyyh , I agree that this is maybe not the best general way to test it, but this is how it is done in every other Model test. Maybe this is ok for now and we could follow up with a discussion on the best way to check "basic" ML functionality?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's do it this way for now - would be good to log a JIRA to investigate a better and more generic way of testing so that it doesn't have to be done manually for each newly added Estimator/Model (or when a Model changes such that defaultCopy is no longer sufficient, etc)

val result = model.transform(dataset)
val predictionAndLabels = result.select("prediction", "label").collect()
predictionAndLabels.foreach { case Row(p: Double, l: Double) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import breeze.numerics.constants.Pi
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.Dataset
Expand Down Expand Up @@ -89,10 +89,12 @@ class BucketedRandomProjectionLSHSuite
.setOutputCol("values")
.setBucketLength(1.0)
.setSeed(12345)
val unitVectors = brp.fit(dataset).randUnitVectors
val brpModel = brp.fit(dataset)
val unitVectors = brpModel.randUnitVectors
unitVectors.foreach { v: Vector =>
assert(Vectors.norm(v, 2.0) ~== 1.0 absTol 1e-14)
}
MLTestingUtils.checkCopy(brpModel)
}

test("BucketedRandomProjectionLSH: test of LSH property") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ package org.apache.spark.ml.feature
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.Dataset

Expand Down Expand Up @@ -57,6 +57,15 @@ class MinHashLSHSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
testEstimatorAndModelReadWrite(mh, dataset, settings, settings, checkModelData)
}

test("Model copy and uid checks") {
val mh = new MinHashLSH()
.setInputCol("keys")
.setOutputCol("values")
val model = mh.fit(dataset)
assert(mh.uid === model.uid)
MLTestingUtils.checkCopy(model)
}

test("hashFunction") {
val model = new MinHashLSHModel("mh", randCoefficients = Array((0, 1), (1, 2), (3, 0)))
val res = model.hashFunction(Vectors.sparse(10, Seq((2, 1.0), (3, 1.0), (5, 1.0), (7, 1.0))))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
val formula = new RFormula().setFormula("id ~ v1 + v2")
val original = Seq((0, 1.0, 3.0), (2, 2.0, 5.0)).toDF("id", "v1", "v2")
val model = formula.fit(original)
MLTestingUtils.checkCopy(model)
val result = model.transform(original)
val resultSchema = model.transformSchema(original.schema)
val expected = Seq(
Expand Down