Skip to content

Commit 12dae73

Browse files
committed
[SPARK-4406] FIX: Validate k in SVD
1 parent c082385 commit 12dae73

File tree

3 files changed

+24
-0
lines changed

3 files changed

+24
-0
lines changed

mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,8 @@ class IndexedRowMatrix(
102102
k: Int,
103103
computeU: Boolean = false,
104104
rCond: Double = 1e-9): SingularValueDecomposition[IndexedRowMatrix, Matrix] = {
105+
106+
require(k >= 1, "k should be at least one.")
105107
val indices = rows.map(_.index)
106108
val svd = toRowMatrix().computeSVD(k, computeU, rCond)
107109
val U = if (computeU) {

mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,16 @@ class IndexedRowMatrixSuite extends FunSuite with MLlibTestSparkContext {
113113
assert(closeToZero(U * brzDiag(s) * V.t - localA))
114114
}
115115

116+
test("validate k in svd") {
117+
val A = new IndexedRowMatrix(indexedRows)
118+
try {
119+
A.computeSVD(-1)
120+
} catch {
121+
case ie: IllegalArgumentException =>
122+
}
123+
}
124+
125+
116126
def closeToZero(G: BDM[Double]): Boolean = {
117127
G.valuesIterator.map(math.abs).sum < 1e-6
118128
}

mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,18 @@ class RowMatrixSuite extends FunSuite with MLlibTestSparkContext {
171171
}
172172
}
173173

174+
test("validate k in svd") {
175+
for (mat <- Seq(denseMat, sparseMat)) {
176+
for (mode <- Seq("auto", "local-svd", "local-eigs", "dist-eigs")) {
177+
try {
178+
mat.computeSVD(-1, computeU = true, 1e-6, 300, 1e-10, mode)
179+
} catch {
180+
case ie: IllegalArgumentException =>
181+
}
182+
}
183+
}
184+
}
185+
174186
def closeToZero(G: BDM[Double]): Boolean = {
175187
G.valuesIterator.map(math.abs).sum < 1e-6
176188
}

0 commit comments

Comments
 (0)