Skip to content

Commit 376545e

Browse files
bwahlgreenjkbradley
authored andcommitted
[SPARK-17721][MLLIB][BACKPORT] Fix for multiplying transposed SparseMatrix with SparseVector
Backport PR of changes relevant to mllib only, but otherwise identical to #15296 jkbradley Author: Bjarne Fruergaard <[email protected]> Closes #15311 from bwahlgreen/bugfix-spark-17721-1.6.
1 parent b999fa4 commit 376545e

File tree

2 files changed

+23
-2
lines changed

2 files changed

+23
-2
lines changed

mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -638,12 +638,16 @@ private[spark] object BLAS extends Serializable with Logging {
638638
val indEnd = Arows(rowCounter + 1)
639639
var sum = 0.0
640640
var k = 0
641-
while (k < xNnz && i < indEnd) {
641+
while (i < indEnd && k < xNnz) {
642642
if (xIndices(k) == Acols(i)) {
643643
sum += Avals(i) * xValues(k)
644+
k += 1
645+
i += 1
646+
} else if (xIndices(k) < Acols(i)) {
647+
k += 1
648+
} else {
644649
i += 1
645650
}
646-
k += 1
647651
}
648652
yValues(rowCounter) = sum * alpha + beta * yValues(rowCounter)
649653
rowCounter += 1

mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -392,6 +392,23 @@ class BLASSuite extends SparkFunSuite {
392392
}
393393
}
394394

395+
val y17 = new DenseVector(Array(0.0, 0.0))
396+
val y18 = y17.copy
397+
398+
val sA3 = new SparseMatrix(3, 2, Array(0, 2, 4), Array(1, 2, 0, 1), Array(2.0, 1.0, 1.0, 2.0))
399+
.transpose
400+
val sA4 =
401+
new SparseMatrix(2, 3, Array(0, 1, 3, 4), Array(1, 0, 1, 0), Array(1.0, 2.0, 2.0, 1.0))
402+
val sx3 = new SparseVector(3, Array(1, 2), Array(2.0, 1.0))
403+
404+
val expected4 = new DenseVector(Array(5.0, 4.0))
405+
406+
gemv(1.0, sA3, sx3, 0.0, y17)
407+
gemv(1.0, sA4, sx3, 0.0, y18)
408+
409+
assert(y17 ~== expected4 absTol 1e-15)
410+
assert(y18 ~== expected4 absTol 1e-15)
411+
395412
val dAT =
396413
new DenseMatrix(3, 4, Array(0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 3.0))
397414
val sAT =

0 commit comments

Comments
 (0)