Skip to content

Commit 6d06cbb

Browse files
hhbyyhjeanlyn
authored andcommitted
[SPARK-7983] [MLLIB] Add require for one-based indices in loadLibSVMFile
jira: https://issues.apache.org/jira/browse/SPARK-7983 Customers frequently use zero-based indices in their LIBSVM files. No warnings or errors from Spark will be reported during their computation afterwards, and usually it will lead to wired result for many algorithms (like GBDT). add a quick check. Author: Yuhao Yang <[email protected]> Closes apache#6538 from hhbyyh/loadSVM and squashes the following commits: 79d9c11 [Yuhao Yang] optimization as respond to comments 4310710 [Yuhao Yang] merge conflict 96460f1 [Yuhao Yang] merge conflict 20a2811 [Yuhao Yang] use require 6e4f8ca [Yuhao Yang] add check for ascending order 9956365 [Yuhao Yang] add ut for 0-based loadlibsvm exception 5bd1f9a [Yuhao Yang] add require for one-based in loadLIBSVM
1 parent 9ff596c commit 6d06cbb

File tree

2 files changed

+47
-0
lines changed

2 files changed

+47
-0
lines changed

mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,18 @@ object MLUtils {
8282
val value = indexAndValue(1).toDouble
8383
(index, value)
8484
}.unzip
85+
86+
// check if indices are one-based and in ascending order
87+
var previous = -1
88+
var i = 0
89+
val indicesLength = indices.length
90+
while (i < indicesLength) {
91+
val current = indices(i)
92+
require(current > previous, "indices should be one-based and in ascending order" )
93+
previous = current
94+
i += 1
95+
}
96+
8597
(label, indices.toArray, values.toArray)
8698
}
8799

mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import breeze.linalg.{squaredDistance => breezeSquaredDistance}
2525
import com.google.common.base.Charsets
2626
import com.google.common.io.Files
2727

28+
import org.apache.spark.SparkException
2829
import org.apache.spark.SparkFunSuite
2930
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors}
3031
import org.apache.spark.mllib.regression.LabeledPoint
@@ -108,6 +109,40 @@ class MLUtilsSuite extends SparkFunSuite with MLlibTestSparkContext {
108109
Utils.deleteRecursively(tempDir)
109110
}
110111

112+
test("loadLibSVMFile throws IllegalArgumentException when indices is zero-based") {
113+
val lines =
114+
"""
115+
|0
116+
|0 0:4.0 4:5.0 6:6.0
117+
""".stripMargin
118+
val tempDir = Utils.createTempDir()
119+
val file = new File(tempDir.getPath, "part-00000")
120+
Files.write(lines, file, Charsets.US_ASCII)
121+
val path = tempDir.toURI.toString
122+
123+
intercept[SparkException] {
124+
loadLibSVMFile(sc, path).collect()
125+
}
126+
Utils.deleteRecursively(tempDir)
127+
}
128+
129+
test("loadLibSVMFile throws IllegalArgumentException when indices is not in ascending order") {
130+
val lines =
131+
"""
132+
|0
133+
|0 3:4.0 2:5.0 6:6.0
134+
""".stripMargin
135+
val tempDir = Utils.createTempDir()
136+
val file = new File(tempDir.getPath, "part-00000")
137+
Files.write(lines, file, Charsets.US_ASCII)
138+
val path = tempDir.toURI.toString
139+
140+
intercept[SparkException] {
141+
loadLibSVMFile(sc, path).collect()
142+
}
143+
Utils.deleteRecursively(tempDir)
144+
}
145+
111146
test("saveAsLibSVMFile") {
112147
val examples = sc.parallelize(Seq(
113148
LabeledPoint(1.1, Vectors.sparse(3, Seq((0, 1.23), (2, 4.56)))),

0 commit comments

Comments
 (0)