
Commit 048a759

zhengruifeng authored and srowen committed
[SPARK-14030][MLLIB] Add parameter check to MLLIB
## What changes were proposed in this pull request?

Add parameter verification to MLlib, e.g. numCorrections > 0, tolerance >= 0, iters > 0, regParam >= 0.

## How was this patch tested?

Manual tests.

Author: Ruifeng Zheng <[email protected]>
Author: Zheng RuiFeng <mllabs@datanode1.(none)>
Author: mllabs <mllabs@datanode1.(none)>
Author: Zheng RuiFeng <[email protected]>

Closes #11852 from zhengruifeng/lbfgs_check.
1 parent 1803bf6 commit 048a759
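
Editorial note: the commit replaces ad-hoc if/throw blocks and bare require calls with require(condition, message), which throws java.lang.IllegalArgumentException (prefixed with "requirement failed: ") as soon as an invalid parameter is set. A minimal Scala sketch of the pattern, using a hypothetical ExampleEstimator rather than any class from the diff:

// Illustrative sketch only: ExampleEstimator is hypothetical, not part of this commit.
// require(cond, msg) throws IllegalArgumentException("requirement failed: " + msg)
// when cond is false, so bad parameters fail fast in the setter instead of during training.
class ExampleEstimator {
  private var k: Int = 2
  private var convergenceTol: Double = 1e-3

  def setK(k: Int): this.type = {
    require(k > 0, s"Number of clusters must be positive but got ${k}")
    this.k = k
    this
  }

  def setConvergenceTol(convergenceTol: Double): this.type = {
    require(convergenceTol >= 0.0,
      s"Convergence tolerance must be nonnegative but got ${convergenceTol}")
    this.convergenceTol = convergenceTol
    this
  }
}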

13 files changed: +83, -13 lines

mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala

Lines changed: 2 additions & 0 deletions
@@ -326,6 +326,8 @@ class NaiveBayes private (
   /** Set the smoothing parameter. Default: 1.0. */
   @Since("0.9.0")
   def setLambda(lambda: Double): NaiveBayes = {
+    require(lambda >= 0,
+      s"Smoothing parameter must be nonnegative but got ${lambda}")
     this.lambda = lambda
     this
   }

mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala

Lines changed: 9 additions & 5 deletions
@@ -78,11 +78,9 @@ class GaussianMixture private (
    */
   @Since("1.3.0")
   def setInitialModel(model: GaussianMixtureModel): this.type = {
-    if (model.k == k) {
-      initialModel = Some(model)
-    } else {
-      throw new IllegalArgumentException("mismatched cluster count (model.k != k)")
-    }
+    require(model.k == k,
+      s"Mismatched cluster count (model.k ${model.k} != k ${k})")
+    initialModel = Some(model)
     this
   }
 
@@ -97,6 +95,8 @@ class GaussianMixture private (
    */
   @Since("1.3.0")
   def setK(k: Int): this.type = {
+    require(k > 0,
+      s"Number of Gaussians must be positive but got ${k}")
     this.k = k
     this
   }
@@ -112,6 +112,8 @@ class GaussianMixture private (
    */
   @Since("1.3.0")
   def setMaxIterations(maxIterations: Int): this.type = {
+    require(maxIterations >= 0,
+      s"Maximum of iterations must be nonnegative but got ${maxIterations}")
     this.maxIterations = maxIterations
     this
   }
@@ -128,6 +130,8 @@ class GaussianMixture private (
    */
   @Since("1.3.0")
   def setConvergenceTol(convergenceTol: Double): this.type = {
+    require(convergenceTol >= 0.0,
+      s"Convergence tolerance must be nonnegative but got ${convergenceTol}")
     this.convergenceTol = convergenceTol
     this
   }

mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala

Lines changed: 8 additions & 3 deletions
@@ -65,6 +65,8 @@ class KMeans private (
    */
   @Since("0.8.0")
   def setK(k: Int): this.type = {
+    require(k > 0,
+      s"Number of clusters must be positive but got ${k}")
     this.k = k
     this
   }
@@ -80,6 +82,8 @@ class KMeans private (
    */
   @Since("0.8.0")
   def setMaxIterations(maxIterations: Int): this.type = {
+    require(maxIterations >= 0,
+      s"Maximum of iterations must be nonnegative but got ${maxIterations}")
     this.maxIterations = maxIterations
     this
   }
@@ -147,9 +151,8 @@ class KMeans private (
    */
   @Since("0.8.0")
   def setInitializationSteps(initializationSteps: Int): this.type = {
-    if (initializationSteps <= 0) {
-      throw new IllegalArgumentException("Number of initialization steps must be positive")
-    }
+    require(initializationSteps > 0,
+      s"Number of initialization steps must be positive but got ${initializationSteps}")
     this.initializationSteps = initializationSteps
     this
   }
@@ -166,6 +169,8 @@ class KMeans private (
    */
   @Since("0.8.0")
   def setEpsilon(epsilon: Double): this.type = {
+    require(epsilon >= 0,
+      s"Distance threshold must be nonnegative but got ${epsilon}")
     this.epsilon = epsilon
     this
   }

mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala

Lines changed: 2 additions & 0 deletions
@@ -232,6 +232,8 @@ class LDA private (
    */
   @Since("1.3.0")
   def setMaxIterations(maxIterations: Int): this.type = {
+    require(maxIterations >= 0,
+      s"Maximum of iterations must be nonnegative but got ${maxIterations}")
     this.maxIterations = maxIterations
     this
   }

mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala

Lines changed: 4 additions & 0 deletions
@@ -137,6 +137,8 @@ class PowerIterationClustering private[clustering] (
    */
   @Since("1.3.0")
   def setK(k: Int): this.type = {
+    require(k > 0,
+      s"Number of clusters must be positive but got ${k}")
     this.k = k
     this
   }
@@ -146,6 +148,8 @@ class PowerIterationClustering private[clustering] (
    */
   @Since("1.3.0")
   def setMaxIterations(maxIterations: Int): this.type = {
+    require(maxIterations >= 0,
+      s"Maximum of iterations must be nonnegative but got ${maxIterations}")
     this.maxIterations = maxIterations
     this
   }

mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala

Lines changed: 6 additions & 0 deletions
@@ -178,6 +178,8 @@ class StreamingKMeans @Since("1.2.0") (
    */
   @Since("1.2.0")
   def setK(k: Int): this.type = {
+    require(k > 0,
+      s"Number of clusters must be positive but got ${k}")
     this.k = k
     this
   }
@@ -187,6 +189,8 @@ class StreamingKMeans @Since("1.2.0") (
    */
   @Since("1.2.0")
   def setDecayFactor(a: Double): this.type = {
+    require(a >= 0,
+      s"Decay factor must be nonnegative but got ${a}")
     this.decayFactor = a
     this
   }
@@ -198,6 +202,8 @@ class StreamingKMeans @Since("1.2.0") (
    */
   @Since("1.2.0")
   def setHalfLife(halfLife: Double, timeUnit: String): this.type = {
+    require(halfLife > 0,
+      s"Half life must be positive but got ${halfLife}")
     if (timeUnit != StreamingKMeans.BATCHES && timeUnit != StreamingKMeans.POINTS) {
       throw new IllegalArgumentException("Invalid time unit for decay: " + timeUnit)
     }

mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala

Lines changed: 2 additions & 1 deletion
@@ -30,7 +30,8 @@ import org.apache.spark.rdd.RDD
  */
 @Since("1.4.0")
 class PCA @Since("1.4.0") (@Since("1.4.0") val k: Int) {
-  require(k >= 1, s"PCA requires a number of principal components k >= 1 but was given $k")
+  require(k > 0,
+    s"Number of principal components must be positive but got ${k}")
 
   /**
    * Computes a [[PCAModel]] that contains the principal components of the input vectors.

mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala

Lines changed: 14 additions & 1 deletion
@@ -84,6 +84,8 @@ class Word2Vec extends Serializable with Logging {
    */
   @Since("2.0.0")
   def setMaxSentenceLength(maxSentenceLength: Int): this.type = {
+    require(maxSentenceLength > 0,
+      s"Maximum length of sentences must be positive but got ${maxSentenceLength}")
     this.maxSentenceLength = maxSentenceLength
     this
   }
@@ -93,6 +95,8 @@ class Word2Vec extends Serializable with Logging {
    */
   @Since("1.1.0")
   def setVectorSize(vectorSize: Int): this.type = {
+    require(vectorSize > 0,
+      s"vector size must be positive but got ${vectorSize}")
     this.vectorSize = vectorSize
     this
   }
@@ -102,6 +106,8 @@ class Word2Vec extends Serializable with Logging {
    */
   @Since("1.1.0")
   def setLearningRate(learningRate: Double): this.type = {
+    require(learningRate > 0,
+      s"Initial learning rate must be positive but got ${learningRate}")
     this.learningRate = learningRate
     this
   }
@@ -111,7 +117,8 @@ class Word2Vec extends Serializable with Logging {
    */
   @Since("1.1.0")
   def setNumPartitions(numPartitions: Int): this.type = {
-    require(numPartitions > 0, s"numPartitions must be greater than 0 but got $numPartitions")
+    require(numPartitions > 0,
+      s"Number of partitions must be positive but got ${numPartitions}")
     this.numPartitions = numPartitions
     this
   }
@@ -122,6 +129,8 @@ class Word2Vec extends Serializable with Logging {
    */
   @Since("1.1.0")
   def setNumIterations(numIterations: Int): this.type = {
+    require(numIterations >= 0,
+      s"Number of iterations must be nonnegative but got ${numIterations}")
     this.numIterations = numIterations
     this
   }
@@ -140,6 +149,8 @@ class Word2Vec extends Serializable with Logging {
    */
   @Since("1.6.0")
   def setWindowSize(window: Int): this.type = {
+    require(window > 0,
+      s"Window of words must be positive but got ${window}")
     this.window = window
     this
   }
@@ -150,6 +161,8 @@ class Word2Vec extends Serializable with Logging {
    */
   @Since("1.3.0")
   def setMinCount(minCount: Int): this.type = {
+    require(minCount >= 0,
+      s"Minimum number of times must be nonnegative but got ${minCount}")
     this.minCount = minCount
     this
   }

mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala

Lines changed: 2 additions & 1 deletion
@@ -50,7 +50,8 @@ class AssociationRules private[fpm] (
    */
   @Since("1.5.0")
   def setMinConfidence(minConfidence: Double): this.type = {
-    require(minConfidence >= 0.0 && minConfidence <= 1.0)
+    require(minConfidence >= 0.0 && minConfidence <= 1.0,
+      s"Minimal confidence must be in range [0, 1] but got ${minConfidence}")
     this.minConfidence = minConfidence
     this
   }

mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala

Lines changed: 4 additions & 0 deletions
@@ -180,6 +180,8 @@ class FPGrowth private (
    */
   @Since("1.3.0")
   def setMinSupport(minSupport: Double): this.type = {
+    require(minSupport >= 0.0 && minSupport <= 1.0,
+      s"Minimal support level must be in range [0, 1] but got ${minSupport}")
     this.minSupport = minSupport
     this
   }
@@ -190,6 +192,8 @@ class FPGrowth private (
    */
   @Since("1.3.0")
   def setNumPartitions(numPartitions: Int): this.type = {
+    require(numPartitions > 0,
+      s"Number of partitions must be positive but got ${numPartitions}")
     this.numPartitions = numPartitions
     this
   }
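
A short usage sketch of the effect of these checks (assuming spark-mllib 2.x on the classpath); an out-of-range argument now fails at the setter call rather than later during training:

import org.apache.spark.mllib.clustering.KMeans

// Passes the new require checks and returns the configured estimator.
val kmeans = new KMeans().setK(3).setMaxIterations(20)

// Fails immediately with:
//   java.lang.IllegalArgumentException: requirement failed:
//   Number of clusters must be positive but got 0
// new KMeans().setK(0)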
