mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -51,7 +51,8 @@ import org.apache.spark.util.VersionUtils
*/
private[classification] trait LogisticRegressionParams extends ProbabilisticClassifierParams
with HasRegParam with HasElasticNetParam with HasMaxIter with HasFitIntercept with HasTol
with HasStandardization with HasWeightCol with HasThreshold with HasAggregationDepth {
with HasStandardization with HasWeightCol with HasThreshold with HasAggregationDepth
with HasHandlePersistence {

import org.apache.spark.ml.classification.LogisticRegression.supportedFamilyNames

@@ -431,6 +432,10 @@ class LogisticRegression @Since("1.2.0") (
@Since("2.2.0")
def setUpperBoundsOnIntercepts(value: Vector): this.type = set(upperBoundsOnIntercepts, value)

/** @group setParam */
@Since("2.3.0")
def setHandlePersistence(value: Boolean): this.type = set(handlePersistence, value)

private def assertBoundConstrainedOptimizationParamsValid(
numCoefficientSets: Int,
numFeatures: Int): Unit = {
@@ -483,24 +488,17 @@ class LogisticRegression @Since("1.2.0") (
this
}

override protected[spark] def train(dataset: Dataset[_]): LogisticRegressionModel = {
val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
train(dataset, handlePersistence)
}

protected[spark] def train(
dataset: Dataset[_],
handlePersistence: Boolean): LogisticRegressionModel = {
protected[spark] def train(dataset: Dataset[_]): LogisticRegressionModel = {
val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
val instances: RDD[Instance] =
dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd.map {
case Row(label: Double, weight: Double, features: Vector) =>
Instance(label, weight, features)
}

if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)
if ($(handlePersistence)) instances.persist(StorageLevel.MEMORY_AND_DISK)

val instr = Instrumentation.create(this, instances)
val instr = Instrumentation.create(this, dataset)
instr.logParams(regParam, elasticNetParam, standardization, threshold,
maxIter, tol, fitIntercept)

@@ -878,7 +876,7 @@ class LogisticRegression @Since("1.2.0") (
}
}

if (handlePersistence) instances.unpersist()
if ($(handlePersistence)) instances.unpersist()

val model = copyValues(new LogisticRegressionModel(uid, coefficientMatrix, interceptVector,
numClasses, isMultinomial))
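Net effect of the LogisticRegression changes above: whether the internal instances RDD is cached is no longer inferred from the input's storage level but is controlled by the new handlePersistence param (default: true). A minimal usage sketch, assuming a SparkSession named `spark` and a hypothetical parquet training set with the default "label"/"features" columns; a caller that has already cached its input can now switch the extra caching off:

import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.storage.StorageLevel

// Hypothetical training data, already cached by the caller.
val training = spark.read.parquet("/path/to/training")
  .persist(StorageLevel.MEMORY_AND_DISK)

// Skip the estimator's internal persist of the instances RDD,
// since the input DataFrame is already in memory.
val lr = new LogisticRegression()
  .setMaxIter(100)
  .setRegParam(0.1)
  .setHandlePersistence(false)

val model = lr.fit(training)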
mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
@@ -17,10 +17,8 @@

package org.apache.spark.ml.classification

import java.util.{List => JList}
import java.util.UUID

import scala.collection.JavaConverters._
import scala.language.existentials

import org.apache.hadoop.fs.Path
@@ -34,7 +32,7 @@ import org.apache.spark.ml._
import org.apache.spark.ml.attribute._
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params}
import org.apache.spark.ml.param.shared.HasWeightCol
import org.apache.spark.ml.param.shared.{HasHandlePersistence, HasWeightCol}
import org.apache.spark.ml.util._
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions._
@@ -55,7 +53,7 @@ private[ml] trait ClassifierTypeTrait {
* Params for [[OneVsRest]].
*/
private[ml] trait OneVsRestParams extends PredictorParams
with ClassifierTypeTrait with HasWeightCol {
with ClassifierTypeTrait with HasWeightCol with HasHandlePersistence {

/**
* param for the base binary classifier that we reduce multiclass classification into.
@@ -67,6 +65,10 @@ private[ml] trait OneVsRestParams extends PredictorParams

/** @group getParam */
def getClassifier: ClassifierType = $(classifier)

/** @group setParam */
@Since("2.3.0")
def setHandlePersistence(value: Boolean): this.type = set(handlePersistence, value)
}

private[ml] object OneVsRestParams extends ClassifierTypeTrait {
@@ -163,9 +165,7 @@ final class OneVsRestModel private[ml] (
val initUDF = udf { () => Map[Int, Double]() }
val newDataset = dataset.withColumn(accColName, initUDF())

// persist if underlying dataset is not persistent.
val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
if (handlePersistence) {
if ($(handlePersistence)) {
newDataset.persist(StorageLevel.MEMORY_AND_DISK)
}

@@ -190,7 +190,7 @@
updatedDataset.select(newColumns: _*).withColumnRenamed(tmpColName, accColName)
}

if (handlePersistence) {
if ($(handlePersistence)) {
newDataset.unpersist()
}

@@ -346,9 +346,7 @@ final class OneVsRest @Since("1.4.0") (
dataset.select($(labelCol), $(featuresCol))
}

// persist if underlying dataset is not persistent.
val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
if (handlePersistence) {
if ($(handlePersistence)) {
multiclassLabeled.persist(StorageLevel.MEMORY_AND_DISK)
}

@@ -374,7 +372,7 @@
}.toArray[ClassificationModel[_, _]]
instr.logNumFeatures(models.head.numFeatures)

if (handlePersistence) {
if ($(handlePersistence)) {
multiclassLabeled.unpersist()
}

13 changes: 8 additions & 5 deletions mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -40,7 +40,7 @@ import org.apache.spark.util.VersionUtils.majorVersion
* Common params for KMeans and KMeansModel
*/
private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFeaturesCol
with HasSeed with HasPredictionCol with HasTol {
with HasSeed with HasPredictionCol with HasTol with HasHandlePersistence {

/**
* The number of clusters to create (k). Must be > 1. Note that it is possible for fewer than
@@ -300,20 +300,23 @@ class KMeans @Since("1.5.0") (
@Since("1.5.0")
def setSeed(value: Long): this.type = set(seed, value)

/** @group setParam */
@Since("2.3.0")
def setHandlePersistence(value: Boolean): this.type = set(handlePersistence, value)

@Since("2.0.0")
override def fit(dataset: Dataset[_]): KMeansModel = {
transformSchema(dataset.schema, logging = true)

val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
val instances: RDD[OldVector] = dataset.select(col($(featuresCol))).rdd.map {
case Row(point: Vector) => OldVectors.fromML(point)
}

if (handlePersistence) {
if ($(handlePersistence)) {
instances.persist(StorageLevel.MEMORY_AND_DISK)
}

val instr = Instrumentation.create(this, instances)
val instr = Instrumentation.create(this, dataset)
instr.logParams(featuresCol, predictionCol, k, initMode, initSteps, maxIter, seed, tol)
val algo = new MLlibKMeans()
.setK($(k))
@@ -329,7 +332,7 @@

model.setSummary(Some(summary))
instr.logSuccess(model)
if (handlePersistence) {
if ($(handlePersistence)) {
instances.unpersist()
}
model
mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
@@ -82,7 +82,8 @@ private[shared] object SharedParamsCodeGen {
"all instance weights as 1.0"),
ParamDesc[String]("solver", "the solver algorithm for optimization", finalFields = false),
ParamDesc[Int]("aggregationDepth", "suggested depth for treeAggregate (>= 2)", Some("2"),
isValid = "ParamValidators.gtEq(2)", isExpertParam = true))
isValid = "ParamValidators.gtEq(2)", isExpertParam = true),
ParamDesc[Boolean]("handlePersistence", "whether to handle data persistence", Some("true")))

val code = genSharedParams(params)
val file = "src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala"
mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
@@ -402,4 +402,21 @@ private[ml] trait HasAggregationDepth extends Params {
/** @group expertGetParam */
final def getAggregationDepth: Int = $(aggregationDepth)
}

/**
* Trait for shared param handlePersistence (default: true).
*/
private[ml] trait HasHandlePersistence extends Params {

/**
* Param for whether to handle data persistence.
* @group param
*/
final val handlePersistence: BooleanParam = new BooleanParam(this, "handlePersistence", "whether to handle data persistence")

setDefault(handlePersistence, true)

/** @group getParam */
final def getHandlePersistence: Boolean = $(handlePersistence)
}
// scalastyle:on
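The generated HasHandlePersistence trait above is what the estimators in this PR mix into their Params traits; a condensed sketch of the pattern follows (the package and trait name are illustrative, and a subpackage of org.apache.spark.ml is needed because the generated trait is private[ml]):

package org.apache.spark.ml.example  // hypothetical subpackage; required for access to the private[ml] trait

import org.apache.spark.ml.param.Params
import org.apache.spark.ml.param.shared.HasHandlePersistence

// An estimator's params trait picks up the shared param by mixing it in...
private[ml] trait MyAlgoParams extends Params with HasHandlePersistence

// ...and its fit()/train() then gates caching on the param value instead of
// on the input's storage level, mirroring the diffs in this PR:
//   if ($(handlePersistence)) instances.persist(StorageLevel.MEMORY_AND_DISK)
//   // ... run the algorithm ...
//   if ($(handlePersistence)) instances.unpersist()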
mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
@@ -46,7 +46,8 @@ import org.apache.spark.storage.StorageLevel
*/
private[regression] trait AFTSurvivalRegressionParams extends Params
with HasFeaturesCol with HasLabelCol with HasPredictionCol with HasMaxIter
with HasTol with HasFitIntercept with HasAggregationDepth with Logging {
with HasTol with HasFitIntercept with HasAggregationDepth with HasHandlePersistence
with Logging {

/**
* Param for censor column name.
@@ -197,6 +198,10 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
def setAggregationDepth(value: Int): this.type = set(aggregationDepth, value)
setDefault(aggregationDepth -> 2)

/** @group setParam */
@Since("2.3.0")
def setHandlePersistence(value: Boolean): this.type = set(handlePersistence, value)

/**
* Extract [[featuresCol]], [[labelCol]] and [[censorCol]] from input dataset,
* and put it in an RDD with strong types.
@@ -213,8 +218,7 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
override def fit(dataset: Dataset[_]): AFTSurvivalRegressionModel = {
transformSchema(dataset.schema, logging = true)
val instances = extractAFTPoints(dataset)
val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)
if ($(handlePersistence)) instances.persist(StorageLevel.MEMORY_AND_DISK)

val featuresSummarizer = {
val seqOp = (c: MultivariateOnlineSummarizer, v: AFTPoint) => c.add(v.features)
@@ -273,7 +277,7 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
}

bcFeaturesStd.destroy(blocking = false)
if (handlePersistence) instances.unpersist()
if ($(handlePersistence)) instances.unpersist()

val rawCoefficients = parameters.slice(2, parameters.length)
var i = 0
mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
@@ -39,7 +39,8 @@ import org.apache.spark.storage.StorageLevel
* Params for isotonic regression.
*/
private[regression] trait IsotonicRegressionBase extends Params with HasFeaturesCol
with HasLabelCol with HasPredictionCol with HasWeightCol with Logging {
with HasLabelCol with HasPredictionCol with HasWeightCol with HasHandlePersistence
with Logging {

/**
* Param for whether the output sequence should be isotonic/increasing (true) or
@@ -157,6 +158,10 @@ class IsotonicRegression @Since("1.5.0") (@Since("1.5.0") override val uid: Stri
@Since("1.5.0")
def setFeatureIndex(value: Int): this.type = set(featureIndex, value)

/** @group setParam */
@Since("2.3.0")
def setHandlePersistence(value: Boolean): this.type = set(handlePersistence, value)

@Since("1.5.0")
override def copy(extra: ParamMap): IsotonicRegression = defaultCopy(extra)

@@ -165,8 +170,7 @@ class IsotonicRegression @Since("1.5.0") (@Since("1.5.0") override val uid: Stri
transformSchema(dataset.schema, logging = true)
// Extract columns from data. If dataset is persisted, do not persist oldDataset.
val instances = extractWeightedLabeledPoints(dataset)
val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)
if ($(handlePersistence)) instances.persist(StorageLevel.MEMORY_AND_DISK)

val instr = Instrumentation.create(this, dataset)
instr.logParams(labelCol, featuresCol, weightCol, predictionCol, featureIndex, isotonic)
@@ -175,7 +179,7 @@ class IsotonicRegression @Since("1.5.0") (@Since("1.5.0") override val uid: Stri
val isotonicRegression = new MLlibIsotonicRegression().setIsotonic($(isotonic))
val oldModel = isotonicRegression.run(instances)

if (handlePersistence) instances.unpersist()
if ($(handlePersistence)) instances.unpersist()

val model = copyValues(new IsotonicRegressionModel(uid, oldModel).setParent(this))
instr.logSuccess(model)
mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -53,7 +53,7 @@ import org.apache.spark.storage.StorageLevel
private[regression] trait LinearRegressionParams extends PredictorParams
with HasRegParam with HasElasticNetParam with HasMaxIter with HasTol
with HasFitIntercept with HasStandardization with HasWeightCol with HasSolver
with HasAggregationDepth {
with HasAggregationDepth with HasHandlePersistence {

import LinearRegression._

@@ -208,6 +208,10 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
def setAggregationDepth(value: Int): this.type = set(aggregationDepth, value)
setDefault(aggregationDepth -> 2)

/** @group setParam */
@Since("2.3.0")
def setHandlePersistence(value: Boolean): this.type = set(handlePersistence, value)

override protected def train(dataset: Dataset[_]): LinearRegressionModel = {
// Extract the number of features before deciding optimization solver.
val numFeatures = dataset.select(col($(featuresCol))).first().getAs[Vector](0).size
@@ -251,8 +255,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
return lrModel
}

val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)
if ($(handlePersistence)) instances.persist(StorageLevel.MEMORY_AND_DISK)

val (featuresSummarizer, ySummarizer) = {
val seqOp = (c: (MultivariateOnlineSummarizer, MultivariateOnlineSummarizer),
@@ -285,7 +288,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
s"zeros and the intercept will be the mean of the label; as a result, " +
s"training is not needed.")
}
if (handlePersistence) instances.unpersist()
if ($(handlePersistence)) instances.unpersist()
val coefficients = Vectors.sparse(numFeatures, Seq.empty)
val intercept = yMean

@@ -422,7 +425,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
0.0
}

if (handlePersistence) instances.unpersist()
if ($(handlePersistence)) instances.unpersist()

val model = copyValues(new LinearRegressionModel(uid, coefficients, intercept))
// Handle possible missing or invalid prediction columns
mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
@@ -444,13 +444,13 @@ class LogisticRegressionWithLBFGS
lr.setFitIntercept(addIntercept)
lr.setMaxIter(optimizer.getNumIterations())
lr.setTol(optimizer.getConvergenceTol())
// Determine if we should cache the DF
lr.setHandlePersistence(input.getStorageLevel == StorageLevel.NONE)
// Convert our input into a DataFrame
val spark = SparkSession.builder().sparkContext(input.context).getOrCreate()
val df = spark.createDataFrame(input.map(_.asML))
// Determine if we should cache the DF
val handlePersistence = input.getStorageLevel == StorageLevel.NONE
// Train our model
val mlLogisticRegressionModel = lr.train(df, handlePersistence)
val mlLogisticRegressionModel = lr.train(df)
// convert the model
val weights = Vectors.dense(mlLogisticRegressionModel.coefficients.toArray)
createModel(weights, mlLogisticRegressionModel.intercept)
7 changes: 6 additions & 1 deletion project/MimaExcludes.scala
@@ -68,7 +68,12 @@ object MimaExcludes {

// [SPARK-14280] Support Scala 2.12
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.FutureAction.transformWith"),
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.FutureAction.transform")
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.FutureAction.transform"),

// [SPARK-18608] Add Param HasHandlePersistence
ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasHandlePersistence.handlePersistence"),
ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasHandlePersistence.getHandlePersistence"),
ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasHandlePersistence.org$apache$spark$ml$param$shared$HasHandlePersistence$_setter_$handlePersistence_=")
)

// Exclude rules for 2.2.x