Skip to content
66 changes: 59 additions & 7 deletions mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,7 @@ private[spark] object BLAS extends Serializable with Logging {
def gemv(
alpha: Double,
A: Matrix,
x: DenseVector,
x: Vector,
beta: Double,
y: DenseVector): Unit = {
require(A.numCols == x.size,
Expand All @@ -473,13 +473,16 @@ private[spark] object BLAS extends Serializable with Logging {
if (alpha == 0.0) {
logDebug("gemv: alpha is equal to 0. Returning y.")
} else {
A match {
case sparse: SparseMatrix =>
gemv(alpha, sparse, x, beta, y)
case dense: DenseMatrix =>
gemv(alpha, dense, x, beta, y)
(A, x) match {
case (sparse: SparseMatrix, dx: DenseVector) =>
gemv(alpha, sparse, dx, beta, y)
case (dense: DenseMatrix, dx: DenseVector) =>
gemv(alpha, dense, dx, beta, y)
case (dense: DenseMatrix, sx: SparseVector) =>
gemv(alpha, dense, sx, beta, y)
case _ =>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about SparseMatrix and SparseVector? To make the consistent naming, we can use dmA, smA, dvx, and svx.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you don't really want to add SparseMatrix and SparseVector, the type safety will be broken when you call with this configuration. Previously, this function is totally type safe in compile time, and no way to get into "case _".

throw new IllegalArgumentException(s"gemv doesn't support matrix type ${A.getClass}.")
throw new IllegalArgumentException(s"gemv doesn't support running on matrix type " +
s"${A.getClass} and vector type ${x.getClass}.")
}
}
}
Expand All @@ -500,6 +503,55 @@ private[spark] object BLAS extends Serializable with Logging {
nativeBLAS.dgemv(tStrA, mA, nA, alpha, A.values, mA, x.values, 1, beta,
y.values, 1)
}

/**
* y := alpha * A * x + beta * y
* For `DenseMatrix` A and SparseVector x.
*/
private def gemv(
alpha: Double,
A: DenseMatrix,
x: SparseVector,
beta: Double,
y: DenseVector): Unit = {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

extra space. please also fix the one in

  private def gemv(
      alpha: Double,
      A: DenseMatrix,
      x: DenseVector,
      beta: Double,
      y: DenseVector): Unit =  {

val mA: Int = A.numRows
val nA: Int = A.numCols

val Avals = A.values
var colCounterForA = 0

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

are you using this?

var xIndices = x.indices
var xNnz = xIndices.size
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

xIndices.length when you try to get the size of native scala Array. size will call length which is another jvm call.

var xValues = x.values

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

all should be val

scal(beta, y)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why we should check it?


if (!A.isTransposed) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if (A.isTransposed) and change the order of code.

var rowCounterForA = 0
while (rowCounterForA < mA) {
var sum = 0.0
var k = 0
while (k < xNnz) {
sum += xValues(k) * Avals(xIndices(k) * mA + rowCounterForA)
k += 1
}
y.values(rowCounterForA) += sum * alpha
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

y.values is slow. Do val yValues = y.values

rowCounterForA += 1
}
} else {
var rowCounterForA = 0
while (rowCounterForA < mA) {
var sum = 0.0
var k = 0
while (k < xNnz) {
sum += xValues(k) * Avals(xIndices(k) + rowCounterForA * nA)
k += 1
}
y.values(rowCounterForA) += sum * alpha
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

rowCounterForA += 1
}
}
}

/**
* y := alpha * A * x + beta * y
Expand Down
30 changes: 20 additions & 10 deletions mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -257,11 +257,12 @@ class BLASSuite extends FunSuite {
new DenseMatrix(4, 3, Array(0.0, 1.0, 0.0, 0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 3.0))
val sA = new SparseMatrix(4, 3, Array(0, 1, 3, 4), Array(1, 0, 2, 3), Array(1.0, 2.0, 1.0, 3.0))

val x = new DenseVector(Array(1.0, 2.0, 3.0))
val dx = new DenseVector(Array(1.0, 2.0, 3.0))
val sx = dx.toSparse
val expected = new DenseVector(Array(4.0, 1.0, 2.0, 9.0))

assert(dA.multiply(x) ~== expected absTol 1e-15)
assert(sA.multiply(x) ~== expected absTol 1e-15)
assert(dA.multiply(dx) ~== expected absTol 1e-15)
assert(sA.multiply(dx) ~== expected absTol 1e-15)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add

 assert(dA.multiply(sx) ~== expected absTol 1e-15)
 assert(sA.multiply(sx) ~== expected absTol 1e-15)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Matrix.multiply has a problem. Now its signature is multiply(y: DenseVector). Turning it to multiply(y: Vector) or adding new multiply(y: Vector) can't pass binary compatibility check as I did in #6189. I am asking @mengxr whether we can add it into MimaExcludes.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@viirya Adding MimaExcludes should be fine since you are making it more generalized.

val y1 = new DenseVector(Array(1.0, 3.0, 1.0, 0.0))
val y2 = y1.copy
Expand All @@ -270,17 +271,26 @@ class BLASSuite extends FunSuite {
val expected2 = new DenseVector(Array(6.0, 7.0, 4.0, 9.0))
val expected3 = new DenseVector(Array(10.0, 8.0, 6.0, 18.0))

gemv(1.0, dA, x, 2.0, y1)
gemv(1.0, sA, x, 2.0, y2)
gemv(2.0, dA, x, 2.0, y3)
gemv(2.0, sA, x, 2.0, y4)
gemv(1.0, dA, dx, 2.0, y1)
gemv(1.0, sA, dx, 2.0, y2)
gemv(2.0, dA, dx, 2.0, y3)
gemv(2.0, sA, dx, 2.0, y4)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To the following instead.

    gemv(1.0, dA, dx, 2.0, y1)
    gemv(1.0, sA, dx, 2.0, y2)
    gemv(1.0, dA, sx, 2.0, y3)
    gemv(1.0, sA, sx, 2.0, y4)

    gemv(2.0, dA, dx, 2.0, y5)
    gemv(2.0, sA, dx, 2.0, y6)
    gemv(2.0, dA, sx, 2.0, y7)
    gemv(2.0, sA, sx, 2.0, y8)

    assert(y1 ~== expected2 absTol 1e-15)
    assert(y2 ~== expected2 absTol 1e-15)
    assert(y3 ~== expected3 absTol 1e-15)
    assert(y4 ~== expected3 absTol 1e-15)
    assert(y5 ~== expected2 absTol 1e-15)
    assert(y6 ~== expected2 absTol 1e-15)
    assert(y7 ~== expected3 absTol 1e-15)
    assert(y8 ~== expected3 absTol 1e-15)

assert(y1 ~== expected2 absTol 1e-15)
assert(y2 ~== expected2 absTol 1e-15)
assert(y3 ~== expected3 absTol 1e-15)
assert(y4 ~== expected3 absTol 1e-15)

val y1_copy = new DenseVector(Array(1.0, 3.0, 1.0, 0.0))
val y3_copy = y1_copy.copy

gemv(1.0, dA, sx, 2.0, y1_copy)
gemv(2.0, dA, sx, 2.0, y3_copy)
assert(y1_copy ~== expected2 absTol 1e-15)
assert(y3_copy ~== expected3 absTol 1e-15)

withClue("columns of A don't match the rows of B") {
intercept[Exception] {
gemv(1.0, dA.transpose, x, 2.0, y1)
gemv(1.0, dA.transpose, dx, 2.0, y1)
}
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

check exception for

gemv(1.0, dA.transpose, dx, 2.0, y1)
gemv(1.0, sA.transpose, dx, 2.0, y1)
gemv(1.0, dA.transpose, sx, 2.0, y1)
gemv(1.0, sA.transpose, sx, 2.0, y1)

val dAT =
Expand All @@ -291,7 +301,7 @@ class BLASSuite extends FunSuite {
val dATT = dAT.transpose
val sATT = sAT.transpose

assert(dATT.multiply(x) ~== expected absTol 1e-15)
assert(sATT.multiply(x) ~== expected absTol 1e-15)
assert(dATT.multiply(dx) ~== expected absTol 1e-15)
assert(sATT.multiply(dx) ~== expected absTol 1e-15)
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add

assert(dATT.multiply(sx) ~== expected absTol 1e-15)
assert(sATT.multiply(sx) ~== expected absTol 1e-15)

}