diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index f491a679b2422..ff48b3029a00b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -51,7 +51,8 @@ import org.apache.spark.util.VersionUtils
  */
 private[classification] trait LogisticRegressionParams extends ProbabilisticClassifierParams
   with HasRegParam with HasElasticNetParam with HasMaxIter with HasFitIntercept with HasTol
-  with HasStandardization with HasWeightCol with HasThreshold with HasAggregationDepth {
+  with HasStandardization with HasWeightCol with HasThreshold with HasAggregationDepth
+  with HasHandlePersistence {
 
   import org.apache.spark.ml.classification.LogisticRegression.supportedFamilyNames
 
@@ -431,6 +432,10 @@ class LogisticRegression @Since("1.2.0") (
   @Since("2.2.0")
   def setUpperBoundsOnIntercepts(value: Vector): this.type = set(upperBoundsOnIntercepts, value)
 
+  /** @group setParam */
+  @Since("2.3.0")
+  def setHandlePersistence(value: Boolean): this.type = set(handlePersistence, value)
+
   private def assertBoundConstrainedOptimizationParamsValid(
       numCoefficientSets: Int,
       numFeatures: Int): Unit = {
@@ -483,14 +488,7 @@ class LogisticRegression @Since("1.2.0") (
     this
   }
 
-  override protected[spark] def train(dataset: Dataset[_]): LogisticRegressionModel = {
-    val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
-    train(dataset, handlePersistence)
-  }
-
-  protected[spark] def train(
-      dataset: Dataset[_],
-      handlePersistence: Boolean): LogisticRegressionModel = {
+  protected[spark] def train(dataset: Dataset[_]): LogisticRegressionModel = {
     val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
     val instances: RDD[Instance] =
       dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd.map {
@@ -498,9 +496,9 @@ class LogisticRegression @Since("1.2.0") (
           Instance(label, weight, features)
       }
 
-    if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)
+    if ($(handlePersistence)) instances.persist(StorageLevel.MEMORY_AND_DISK)
 
-    val instr = Instrumentation.create(this, instances)
+    val instr = Instrumentation.create(this, dataset)
     instr.logParams(regParam, elasticNetParam, standardization, threshold,
       maxIter, tol, fitIntercept)
 
@@ -878,7 +876,7 @@ class LogisticRegression @Since("1.2.0") (
       }
     }
 
-    if (handlePersistence) instances.unpersist()
+    if ($(handlePersistence)) instances.unpersist()
 
     val model = copyValues(new LogisticRegressionModel(uid, coefficientMatrix, interceptVector,
       numClasses, isMultinomial))
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
index 05b8c3ab5456e..0513932756cfc 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
@@ -17,10 +17,8 @@
 
 package org.apache.spark.ml.classification
 
-import java.util.{List => JList}
 import java.util.UUID
 
-import scala.collection.JavaConverters._
 import scala.language.existentials
 
 import org.apache.hadoop.fs.Path
@@ -34,7 +32,7 @@ import org.apache.spark.ml._
 import org.apache.spark.ml.attribute._
 import org.apache.spark.ml.linalg.Vector
 import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params}
-import org.apache.spark.ml.param.shared.HasWeightCol
+import org.apache.spark.ml.param.shared.{HasHandlePersistence, HasWeightCol}
 import org.apache.spark.ml.util._
 import org.apache.spark.sql.{DataFrame, Dataset, Row}
 import org.apache.spark.sql.functions._
@@ -55,7 +53,7 @@ private[ml] trait ClassifierTypeTrait {
  * Params for [[OneVsRest]].
  */
 private[ml] trait OneVsRestParams extends PredictorParams
-  with ClassifierTypeTrait with HasWeightCol {
+  with ClassifierTypeTrait with HasWeightCol with HasHandlePersistence {
 
   /**
    * param for the base binary classifier that we reduce multiclass classification into.
@@ -67,6 +65,10 @@ private[ml] trait OneVsRestParams extends PredictorParams
 
   /** @group getParam */
   def getClassifier: ClassifierType = $(classifier)
+
+  /** @group setParam */
+  @Since("2.3.0")
+  def setHandlePersistence(value: Boolean): this.type = set(handlePersistence, value)
 }
 
 private[ml] object OneVsRestParams extends ClassifierTypeTrait {
@@ -163,9 +165,7 @@ final class OneVsRestModel private[ml] (
     val initUDF = udf { () => Map[Int, Double]() }
     val newDataset = dataset.withColumn(accColName, initUDF())
 
-    // persist if underlying dataset is not persistent.
-    val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
-    if (handlePersistence) {
+    if ($(handlePersistence)) {
       newDataset.persist(StorageLevel.MEMORY_AND_DISK)
     }
 
@@ -190,7 +190,7 @@ final class OneVsRestModel private[ml] (
       updatedDataset.select(newColumns: _*).withColumnRenamed(tmpColName, accColName)
     }
 
-    if (handlePersistence) {
+    if ($(handlePersistence)) {
       newDataset.unpersist()
     }
 
@@ -346,9 +346,7 @@ final class OneVsRest @Since("1.4.0") (
       dataset.select($(labelCol), $(featuresCol))
     }
 
-    // persist if underlying dataset is not persistent.
-    val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
-    if (handlePersistence) {
+    if ($(handlePersistence)) {
       multiclassLabeled.persist(StorageLevel.MEMORY_AND_DISK)
     }
 
@@ -374,7 +372,7 @@ final class OneVsRest @Since("1.4.0") (
     }.toArray[ClassificationModel[_, _]]
     instr.logNumFeatures(models.head.numFeatures)
 
-    if (handlePersistence) {
+    if ($(handlePersistence)) {
       multiclassLabeled.unpersist()
     }
 
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
index e02b532ca8a93..3ea7f0b1594b9 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -40,7 +40,7 @@ import org.apache.spark.util.VersionUtils.majorVersion
  * Common params for KMeans and KMeansModel
  */
 private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFeaturesCol
-  with HasSeed with HasPredictionCol with HasTol {
+  with HasSeed with HasPredictionCol with HasTol with HasHandlePersistence {
 
   /**
    * The number of clusters to create (k). Must be > 1. Note that it is possible for fewer than
@@ -300,20 +300,23 @@ class KMeans @Since("1.5.0") (
   @Since("1.5.0")
   def setSeed(value: Long): this.type = set(seed, value)
 
+  /** @group setParam */
+  @Since("2.3.0")
+  def setHandlePersistence(value: Boolean): this.type = set(handlePersistence, value)
+
   @Since("2.0.0")
   override def fit(dataset: Dataset[_]): KMeansModel = {
     transformSchema(dataset.schema, logging = true)
 
-    val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
     val instances: RDD[OldVector] = dataset.select(col($(featuresCol))).rdd.map {
       case Row(point: Vector) => OldVectors.fromML(point)
     }
 
-    if (handlePersistence) {
+    if ($(handlePersistence)) {
       instances.persist(StorageLevel.MEMORY_AND_DISK)
     }
 
-    val instr = Instrumentation.create(this, instances)
+    val instr = Instrumentation.create(this, dataset)
     instr.logParams(featuresCol, predictionCol, k, initMode, initSteps, maxIter, seed, tol)
     val algo = new MLlibKMeans()
       .setK($(k))
@@ -329,7 +332,7 @@ class KMeans @Since("1.5.0") (
     model.setSummary(Some(summary))
     instr.logSuccess(model)
 
-    if (handlePersistence) {
+    if ($(handlePersistence)) {
       instances.unpersist()
     }
     model
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
index 1860fe8361749..fe540a43d6e81 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
@@ -82,7 +82,8 @@ private[shared] object SharedParamsCodeGen {
       "all instance weights as 1.0"),
     ParamDesc[String]("solver", "the solver algorithm for optimization", finalFields = false),
     ParamDesc[Int]("aggregationDepth", "suggested depth for treeAggregate (>= 2)", Some("2"),
-      isValid = "ParamValidators.gtEq(2)", isExpertParam = true))
+      isValid = "ParamValidators.gtEq(2)", isExpertParam = true),
+    ParamDesc[Boolean]("handlePersistence", "whether to persist the input data", Some("true")))
 
   val code = genSharedParams(params)
   val file = "src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala"
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
index 6061d9ca0a084..a7a46e8a46e5c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
@@ -402,4 +402,21 @@ private[ml] trait HasAggregationDepth extends Params {
   /** @group expertGetParam */
   final def getAggregationDepth: Int = $(aggregationDepth)
 }
+
+/**
+ * Trait for shared param handlePersistence (default: true).
+ */
+private[ml] trait HasHandlePersistence extends Params {
+
+  /**
+   * Param for whether to persist the input data.
+   * @group param
+   */
+  final val handlePersistence: BooleanParam = new BooleanParam(this, "handlePersistence", "whether to persist the input data")
+
+  setDefault(handlePersistence, true)
+
+  /** @group getParam */
+  final def getHandlePersistence: Boolean = $(handlePersistence)
+}
 // scalastyle:on
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
index 16821f317760e..e7dfa8b7b748f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
@@ -46,7 +46,8 @@ import org.apache.spark.storage.StorageLevel
  */
 private[regression] trait AFTSurvivalRegressionParams extends Params
   with HasFeaturesCol with HasLabelCol with HasPredictionCol with HasMaxIter
-  with HasTol with HasFitIntercept with HasAggregationDepth with Logging {
+  with HasTol with HasFitIntercept with HasAggregationDepth with HasHandlePersistence
+  with Logging {
 
   /**
    * Param for censor column name.
@@ -197,6 +198,10 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
   def setAggregationDepth(value: Int): this.type = set(aggregationDepth, value)
   setDefault(aggregationDepth -> 2)
 
+  /** @group setParam */
+  @Since("2.3.0")
+  def setHandlePersistence(value: Boolean): this.type = set(handlePersistence, value)
+
   /**
    * Extract [[featuresCol]], [[labelCol]] and [[censorCol]] from input dataset,
    * and put it in an RDD with strong types.
@@ -213,8 +218,7 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
   override def fit(dataset: Dataset[_]): AFTSurvivalRegressionModel = {
     transformSchema(dataset.schema, logging = true)
     val instances = extractAFTPoints(dataset)
-    val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
-    if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)
+    if ($(handlePersistence)) instances.persist(StorageLevel.MEMORY_AND_DISK)
 
     val featuresSummarizer = {
       val seqOp = (c: MultivariateOnlineSummarizer, v: AFTPoint) => c.add(v.features)
@@ -273,7 +277,7 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
     }
 
     bcFeaturesStd.destroy(blocking = false)
-    if (handlePersistence) instances.unpersist()
+    if ($(handlePersistence)) instances.unpersist()
 
     val rawCoefficients = parameters.slice(2, parameters.length)
     var i = 0
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
index 529f66eadbcff..301aaaea5d2e0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
@@ -39,7 +39,8 @@ import org.apache.spark.storage.StorageLevel
  * Params for isotonic regression.
  */
 private[regression] trait IsotonicRegressionBase extends Params with HasFeaturesCol
-  with HasLabelCol with HasPredictionCol with HasWeightCol with Logging {
+  with HasLabelCol with HasPredictionCol with HasWeightCol with HasHandlePersistence
+  with Logging {
 
   /**
    * Param for whether the output sequence should be isotonic/increasing (true) or
@@ -157,6 +158,10 @@ class IsotonicRegression @Since("1.5.0") (@Since("1.5.0") override val uid: Stri
   @Since("1.5.0")
   def setFeatureIndex(value: Int): this.type = set(featureIndex, value)
 
+  /** @group setParam */
+  @Since("2.3.0")
+  def setHandlePersistence(value: Boolean): this.type = set(handlePersistence, value)
+
   @Since("1.5.0")
   override def copy(extra: ParamMap): IsotonicRegression = defaultCopy(extra)
 
@@ -165,8 +170,7 @@ class IsotonicRegression @Since("1.5.0") (@Since("1.5.0") override val uid: Stri
     transformSchema(dataset.schema, logging = true)
     // Extract columns from data. If dataset is persisted, do not persist oldDataset.
     val instances = extractWeightedLabeledPoints(dataset)
-    val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
-    if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)
+    if ($(handlePersistence)) instances.persist(StorageLevel.MEMORY_AND_DISK)
 
     val instr = Instrumentation.create(this, dataset)
     instr.logParams(labelCol, featuresCol, weightCol, predictionCol, featureIndex, isotonic)
@@ -175,7 +179,7 @@ class IsotonicRegression @Since("1.5.0") (@Since("1.5.0") override val uid: Stri
     val isotonicRegression = new MLlibIsotonicRegression().setIsotonic($(isotonic))
     val oldModel = isotonicRegression.run(instances)
 
-    if (handlePersistence) instances.unpersist()
+    if ($(handlePersistence)) instances.unpersist()
 
     val model = copyValues(new IsotonicRegressionModel(uid, oldModel).setParent(this))
     instr.logSuccess(model)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index ed431f550817e..dc090f85c9a34 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -53,7 +53,7 @@ import org.apache.spark.storage.StorageLevel
 private[regression] trait LinearRegressionParams extends PredictorParams
     with HasRegParam with HasElasticNetParam with HasMaxIter with HasTol
     with HasFitIntercept with HasStandardization with HasWeightCol with HasSolver
-    with HasAggregationDepth {
+    with HasAggregationDepth with HasHandlePersistence {
 
   import LinearRegression._
 
@@ -208,6 +208,10 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
   def setAggregationDepth(value: Int): this.type = set(aggregationDepth, value)
   setDefault(aggregationDepth -> 2)
 
+  /** @group setParam */
+  @Since("2.3.0")
+  def setHandlePersistence(value: Boolean): this.type = set(handlePersistence, value)
+
   override protected def train(dataset: Dataset[_]): LinearRegressionModel = {
     // Extract the number of features before deciding optimization solver.
     val numFeatures = dataset.select(col($(featuresCol))).first().getAs[Vector](0).size
@@ -251,8 +255,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
       return lrModel
     }
 
-    val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
-    if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)
+    if ($(handlePersistence)) instances.persist(StorageLevel.MEMORY_AND_DISK)
 
     val (featuresSummarizer, ySummarizer) = {
       val seqOp = (c: (MultivariateOnlineSummarizer, MultivariateOnlineSummarizer),
@@ -285,7 +288,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
             s"zeros and the intercept will be the mean of the label; as a result, " +
             s"training is not needed.")
         }
-        if (handlePersistence) instances.unpersist()
+        if ($(handlePersistence)) instances.unpersist()
 
         val coefficients = Vectors.sparse(numFeatures, Seq.empty)
         val intercept = yMean
@@ -422,7 +425,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
       0.0
     }
 
-    if (handlePersistence) instances.unpersist()
+    if ($(handlePersistence)) instances.unpersist()
 
     val model = copyValues(new LinearRegressionModel(uid, coefficients, intercept))
     // Handle possible missing or invalid prediction columns
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
index 4b650000736e2..82511679cdc02 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
@@ -444,13 +444,13 @@ class LogisticRegressionWithLBFGS
     lr.setFitIntercept(addIntercept)
     lr.setMaxIter(optimizer.getNumIterations())
     lr.setTol(optimizer.getConvergenceTol())
+    // Cache internally only if the caller has not already persisted the input RDD
+    lr.setHandlePersistence(input.getStorageLevel == StorageLevel.NONE)
     // Convert our input into a DataFrame
     val spark = SparkSession.builder().sparkContext(input.context).getOrCreate()
     val df = spark.createDataFrame(input.map(_.asML))
-    // Determine if we should cache the DF
-    val handlePersistence = input.getStorageLevel == StorageLevel.NONE
     // Train our model
-    val mlLogisticRegressionModel = lr.train(df, handlePersistence)
+    val mlLogisticRegressionModel = lr.train(df)
     // convert the model
     val weights = Vectors.dense(mlLogisticRegressionModel.coefficients.toArray)
     createModel(weights, mlLogisticRegressionModel.intercept)
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index dd299e074535e..a2706ae442cbd 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -68,7 +68,12 @@ object MimaExcludes {
 
     // [SPARK-14280] Support Scala 2.12
     ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.FutureAction.transformWith"),
-    ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.FutureAction.transform")
+    ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.FutureAction.transform"),
+
+    // [SPARK-18608] Add Param HasHandlePersistence
+    ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasHandlePersistence.handlePersistence"),
+    ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasHandlePersistence.getHandlePersistence"),
+    ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasHandlePersistence.org$apache$spark$ml$param$shared$HasHandlePersistence$_setter_$handlePersistence_=")
   )
 
   // Exclude rules for 2.2.x
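
Usage sketch (reviewer note, not part of the patch): a minimal spark-shell example of the new param, assuming the sample data file data/mllib/sample_kmeans_data.txt that ships with Spark. Setting handlePersistence to false skips the estimator's internal MEMORY_AND_DISK persist when the caller already manages caching:

import org.apache.spark.ml.clustering.KMeans
import org.apache.spark.storage.StorageLevel

// The caller caches the input itself, so the estimator should not
// persist an internal copy of the same data a second time.
val training = spark.read.format("libsvm").load("data/mllib/sample_kmeans_data.txt")
training.persist(StorageLevel.MEMORY_AND_DISK)

val kmeans = new KMeans()
  .setK(2)
  .setSeed(1L)
  .setHandlePersistence(false) // default is true: persist internally during fit
val model = kmeans.fit(training)

training.unpersist()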
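
Note the behavioral trade-off: with the default handlePersistence = true the data is persisted internally even when the caller has already cached it, whereas the removed code tried to auto-detect via dataset.rdd.getStorageLevel, a check that is unreliable for cached Datasets (the double-caching problem SPARK-18608 reports). A caller who wants detection back can do it explicitly, as LogisticRegressionWithLBFGS now does for its RDD input; a sketch for the Dataset case, reusing the kmeans and training values above:

// Persist internally only when the input Dataset is not already cached.
// Dataset.storageLevel reflects the Dataset's actual cache state, unlike
// dataset.rdd.getStorageLevel in the removed auto-detection code.
kmeans.setHandlePersistence(training.storageLevel == StorageLevel.NONE)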