Skip to content

Commit 96fcec4

Browse files
committed
Reorg some code.
1 parent 43192a4 commit 96fcec4

File tree

1 file changed

+17
-25
lines changed

1 file changed

+17
-25
lines changed

mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala

Lines changed: 17 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas
188188
*/
189189
@Since("2.2.0")
190190
val lowerBoundsOnCoefficients: Param[Matrix] = new Param(this, "lowerBoundsOnCoefficients",
191-
"The lower bound of coefficients if fitting under bound constrained optimization.")
191+
"The lower bounds on coefficients if fitting under bound constrained optimization.")
192192

193193
/** @group getParam */
194194
@Since("2.2.0")
@@ -204,7 +204,7 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas
204204
*/
205205
@Since("2.2.0")
206206
val upperBoundsOnCoefficients: Param[Matrix] = new Param(this, "upperBoundsOnCoefficients",
207-
"The upper bound of coefficients if fitting under bound constrained optimization.")
207+
"The upper bounds on coefficients if fitting under bound constrained optimization.")
208208

209209
/** @group getParam */
210210
@Since("2.2.0")
@@ -219,7 +219,7 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas
219219
*/
220220
@Since("2.2.0")
221221
val lowerBoundsOnIntercepts: Param[Vector] = new Param(this, "lowerBoundsOnIntercepts",
222-
"The lower bound of intercept if fitting under bound constrained optimization.")
222+
"The lower bounds on intercepts if fitting under bound constrained optimization.")
223223

224224
/** @group getParam */
225225
@Since("2.2.0")
@@ -234,17 +234,30 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas
234234
*/
235235
@Since("2.2.0")
236236
val upperBoundsOnIntercepts: Param[Vector] = new Param(this, "upperBoundsOnIntercepts",
237-
"The upper bound of coefficients if fitting under bound constrained optimization.")
237+
"The upper bounds on intercepts if fitting under bound constrained optimization.")
238238

239239
/** @group getParam */
240240
@Since("2.2.0")
241241
def getUpperBoundsOnIntercepts: Vector = $(upperBoundsOnIntercepts)
242242

243+
protected def usingBoundConstrainedOptimization: Boolean = {
244+
isSet(lowerBoundsOnCoefficients) || isSet(upperBoundsOnCoefficients) ||
245+
isSet(lowerBoundsOnIntercepts) || isSet(upperBoundsOnIntercepts)
246+
}
247+
243248
override protected def validateAndTransformSchema(
244249
schema: StructType,
245250
fitting: Boolean,
246251
featuresDataType: DataType): StructType = {
247252
checkThresholdConsistency()
253+
if (usingBoundConstrainedOptimization) {
254+
require($(elasticNetParam) == 0.0, "Fitting under bound constrained optimization only " +
255+
s"supports L2 regularization, but got elasticNetParam = $getElasticNetParam.")
256+
}
257+
if (!$(fitIntercept)) {
258+
require(!isSet(lowerBoundsOnIntercepts) && !isSet(upperBoundsOnIntercepts),
259+
"Pls don't set bounds on intercepts if fitting without intercept.")
260+
}
248261
super.validateAndTransformSchema(schema, fitting, featuresDataType)
249262
}
250263
}
@@ -409,11 +422,6 @@ class LogisticRegression @Since("1.2.0") (
409422
@Since("2.2.0")
410423
def setUpperBoundsOnIntercepts(value: Vector): this.type = set(upperBoundsOnIntercepts, value)
411424

412-
private def usingBoundConstrainedOptimization: Boolean = {
413-
isSet(lowerBoundsOnCoefficients) || isSet(upperBoundsOnCoefficients) ||
414-
isSet(lowerBoundsOnIntercepts) || isSet(upperBoundsOnIntercepts)
415-
}
416-
417425
private def assertBoundConstrainedOptimizationParamsValid(
418426
numCoefficientSets: Int,
419427
numFeatures: Int): Unit = {
@@ -459,22 +467,6 @@ class LogisticRegression @Since("1.2.0") (
459467
train(dataset, handlePersistence)
460468
}
461469

462-
@Since("2.2.0")
463-
override def validateAndTransformSchema(
464-
schema: StructType,
465-
fitting: Boolean,
466-
featuresDataType: DataType): StructType = {
467-
if (usingBoundConstrainedOptimization) {
468-
require($(elasticNetParam) == 0.0, "Fitting under bound constrained optimization only " +
469-
s"supports L2 regularization, but got elasticNetParam = $getElasticNetParam.")
470-
}
471-
if (!$(fitIntercept)) {
472-
require(!isSet(lowerBoundsOnIntercepts) && !isSet(upperBoundsOnIntercepts),
473-
"Pls don't set bounds on intercepts if fitting without intercept.")
474-
}
475-
super.validateAndTransformSchema(schema, fitting, featuresDataType)
476-
}
477-
478470
protected[spark] def train(
479471
dataset: Dataset[_],
480472
handlePersistence: Boolean): LogisticRegressionModel = {

0 commit comments

Comments
 (0)