@@ -188,7 +188,7 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas
188188 */
189189 @ Since (" 2.2.0" )
190190 val lowerBoundsOnCoefficients : Param [Matrix ] = new Param (this , " lowerBoundsOnCoefficients" ,
191- " The lower bound of coefficients if fitting under bound constrained optimization." )
191+ " The lower bounds on coefficients if fitting under bound constrained optimization." )
192192
193193 /** @group getParam */
194194 @ Since (" 2.2.0" )
@@ -204,7 +204,7 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas
204204 */
205205 @ Since (" 2.2.0" )
206206 val upperBoundsOnCoefficients : Param [Matrix ] = new Param (this , " upperBoundsOnCoefficients" ,
207- " The upper bound of coefficients if fitting under bound constrained optimization." )
207+ " The upper bounds on coefficients if fitting under bound constrained optimization." )
208208
209209 /** @group getParam */
210210 @ Since (" 2.2.0" )
@@ -219,7 +219,7 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas
219219 */
220220 @ Since (" 2.2.0" )
221221 val lowerBoundsOnIntercepts : Param [Vector ] = new Param (this , " lowerBoundsOnIntercepts" ,
222- " The lower bound of intercept if fitting under bound constrained optimization." )
222+ " The lower bounds on intercepts if fitting under bound constrained optimization." )
223223
224224 /** @group getParam */
225225 @ Since (" 2.2.0" )
@@ -234,17 +234,30 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas
234234 */
235235 @ Since (" 2.2.0" )
236236 val upperBoundsOnIntercepts : Param [Vector ] = new Param (this , " upperBoundsOnIntercepts" ,
237- " The upper bound of coefficients if fitting under bound constrained optimization." )
237+ " The upper bounds on intercepts if fitting under bound constrained optimization." )
238238
239239 /** @group getParam */
240240 @ Since (" 2.2.0" )
241241 def getUpperBoundsOnIntercepts : Vector = $(upperBoundsOnIntercepts)
242242
243+ protected def usingBoundConstrainedOptimization : Boolean = {
244+ isSet(lowerBoundsOnCoefficients) || isSet(upperBoundsOnCoefficients) ||
245+ isSet(lowerBoundsOnIntercepts) || isSet(upperBoundsOnIntercepts)
246+ }
247+
243248 override protected def validateAndTransformSchema (
244249 schema : StructType ,
245250 fitting : Boolean ,
246251 featuresDataType : DataType ): StructType = {
247252 checkThresholdConsistency()
253+ if (usingBoundConstrainedOptimization) {
254+ require($(elasticNetParam) == 0.0 , " Fitting under bound constrained optimization only " +
255+ s " supports L2 regularization, but got elasticNetParam = $getElasticNetParam. " )
256+ }
257+ if (! $(fitIntercept)) {
258+ require(! isSet(lowerBoundsOnIntercepts) && ! isSet(upperBoundsOnIntercepts),
259+ " Pls don't set bounds on intercepts if fitting without intercept." )
260+ }
248261 super .validateAndTransformSchema(schema, fitting, featuresDataType)
249262 }
250263}
@@ -409,11 +422,6 @@ class LogisticRegression @Since("1.2.0") (
409422 @ Since (" 2.2.0" )
410423 def setUpperBoundsOnIntercepts (value : Vector ): this .type = set(upperBoundsOnIntercepts, value)
411424
412- private def usingBoundConstrainedOptimization : Boolean = {
413- isSet(lowerBoundsOnCoefficients) || isSet(upperBoundsOnCoefficients) ||
414- isSet(lowerBoundsOnIntercepts) || isSet(upperBoundsOnIntercepts)
415- }
416-
417425 private def assertBoundConstrainedOptimizationParamsValid (
418426 numCoefficientSets : Int ,
419427 numFeatures : Int ): Unit = {
@@ -459,22 +467,6 @@ class LogisticRegression @Since("1.2.0") (
459467 train(dataset, handlePersistence)
460468 }
461469
462- @ Since (" 2.2.0" )
463- override def validateAndTransformSchema (
464- schema : StructType ,
465- fitting : Boolean ,
466- featuresDataType : DataType ): StructType = {
467- if (usingBoundConstrainedOptimization) {
468- require($(elasticNetParam) == 0.0 , " Fitting under bound constrained optimization only " +
469- s " supports L2 regularization, but got elasticNetParam = $getElasticNetParam. " )
470- }
471- if (! $(fitIntercept)) {
472- require(! isSet(lowerBoundsOnIntercepts) && ! isSet(upperBoundsOnIntercepts),
473- " Pls don't set bounds on intercepts if fitting without intercept." )
474- }
475- super .validateAndTransformSchema(schema, fitting, featuresDataType)
476- }
477-
478470 protected [spark] def train (
479471 dataset : Dataset [_],
480472 handlePersistence : Boolean ): LogisticRegressionModel = {
0 commit comments