Skip to content

Commit daff601

Browse files
committed
fix index error of sparse vector
1 parent 6bd0a10 commit daff601

File tree

2 files changed

+8
-5
lines changed

2 files changed

+8
-5
lines changed

mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialMapper.scala

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -99,13 +99,14 @@ object PolynomialMapper {
9999
val len = numMonomials(currDegree, nDim)
100100
var numToRemoveCum = 0
101101
val allExpansions = lVal.zip(lIdx).flatMap { case (lv, li) =>
102+
val numToRemove = numMonomials(currDegree - 1, nDim - li)
102103
val currExpansions = rVal.zip(rIdx).map { case (rv, ri) =>
103-
val realIdx = li * nDim + ri
104-
(if(realIdx > numToRemoveCum) lv * rv else 0.0, realIdx - numToRemoveCum)
104+
val realIdx = ri - (rLen - numToRemove)
105+
(if (realIdx >= 0) lv * rv else 0.0, numToRemoveCum + realIdx)
105106
}
106-
numToRemoveCum += numMonomials(currDegree - 1, nDim - li)
107+
numToRemoveCum += numToRemove
107108
currExpansions
108-
}
109+
}.filter(_._1 != 0.0)
109110
Vectors.sparse(len, allExpansions.map(_._2), allExpansions.map(_._1))
110111

111112
case _ => throw new Exception("vector types are not match.")

mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialMapperSuite.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,9 @@ class PolynomialMapperSuite extends FunSuite with MLlibTestSparkContext {
7979
}
8080

8181
test("fake") {
82-
val result = collectResult(polynomialMapper.setDegree(2).transform(dataFrame))
82+
polynomialMapper.setDegree(3)
83+
println(polynomialMapper.getDegree)
84+
val result = collectResult(polynomialMapper.transform(dataFrame))
8385
for(r <- result) {
8486
println(r)
8587
}

0 commit comments

Comments
 (0)