-
Notifications
You must be signed in to change notification settings - Fork 28.9k
[SPARK-7681][MLlib] Add SparseVector support for gemv #6209
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
c069507
5d6d07a
4616696
410381a
054f05d
458d1ae
57a8c1e
b890e63
ce0bb8b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -463,7 +463,7 @@ private[spark] object BLAS extends Serializable with Logging { | |
| def gemv( | ||
| alpha: Double, | ||
| A: Matrix, | ||
| x: DenseVector, | ||
| x: Vector, | ||
| beta: Double, | ||
| y: DenseVector): Unit = { | ||
| require(A.numCols == x.size, | ||
|
|
@@ -473,13 +473,16 @@ private[spark] object BLAS extends Serializable with Logging { | |
| if (alpha == 0.0) { | ||
| logDebug("gemv: alpha is equal to 0. Returning y.") | ||
| } else { | ||
| A match { | ||
| case sparse: SparseMatrix => | ||
| gemv(alpha, sparse, x, beta, y) | ||
| case dense: DenseMatrix => | ||
| gemv(alpha, dense, x, beta, y) | ||
| (A, x) match { | ||
| case (sparse: SparseMatrix, dx: DenseVector) => | ||
| gemv(alpha, sparse, dx, beta, y) | ||
| case (dense: DenseMatrix, dx: DenseVector) => | ||
| gemv(alpha, dense, dx, beta, y) | ||
| case (dense: DenseMatrix, sx: SparseVector) => | ||
| gemv(alpha, dense, sx, beta, y) | ||
| case _ => | ||
| throw new IllegalArgumentException(s"gemv doesn't support matrix type ${A.getClass}.") | ||
| throw new IllegalArgumentException(s"gemv doesn't support running on matrix type " + | ||
| s"${A.getClass} and vector type ${x.getClass}.") | ||
| } | ||
| } | ||
| } | ||
|
|
@@ -500,6 +503,55 @@ private[spark] object BLAS extends Serializable with Logging { | |
| nativeBLAS.dgemv(tStrA, mA, nA, alpha, A.values, mA, x.values, 1, beta, | ||
| y.values, 1) | ||
| } | ||
|
|
||
| /** | ||
| * y := alpha * A * x + beta * y | ||
| * For `DenseMatrix` A and SparseVector x. | ||
| */ | ||
| private def gemv( | ||
| alpha: Double, | ||
| A: DenseMatrix, | ||
| x: SparseVector, | ||
| beta: Double, | ||
| y: DenseVector): Unit = { | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. extra space. please also fix the one in |
||
| val mA: Int = A.numRows | ||
| val nA: Int = A.numCols | ||
|
|
||
| val Avals = A.values | ||
| var colCounterForA = 0 | ||
|
|
||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. are you using this? |
||
| var xIndices = x.indices | ||
| var xNnz = xIndices.size | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
| var xValues = x.values | ||
|
|
||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. all should be |
||
| scal(beta, y) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why we should check it? |
||
|
|
||
| if (!A.isTransposed) { | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
| var rowCounterForA = 0 | ||
| while (rowCounterForA < mA) { | ||
| var sum = 0.0 | ||
| var k = 0 | ||
| while (k < xNnz) { | ||
| sum += xValues(k) * Avals(xIndices(k) * mA + rowCounterForA) | ||
| k += 1 | ||
| } | ||
| y.values(rowCounterForA) += sum * alpha | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
| rowCounterForA += 1 | ||
| } | ||
| } else { | ||
| var rowCounterForA = 0 | ||
| while (rowCounterForA < mA) { | ||
| var sum = 0.0 | ||
| var k = 0 | ||
| while (k < xNnz) { | ||
| sum += xValues(k) * Avals(xIndices(k) + rowCounterForA * nA) | ||
| k += 1 | ||
| } | ||
| y.values(rowCounterForA) += sum * alpha | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ditto |
||
| rowCounterForA += 1 | ||
| } | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * y := alpha * A * x + beta * y | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -257,11 +257,12 @@ class BLASSuite extends FunSuite { | |
| new DenseMatrix(4, 3, Array(0.0, 1.0, 0.0, 0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 3.0)) | ||
| val sA = new SparseMatrix(4, 3, Array(0, 1, 3, 4), Array(1, 0, 2, 3), Array(1.0, 2.0, 1.0, 3.0)) | ||
|
|
||
| val x = new DenseVector(Array(1.0, 2.0, 3.0)) | ||
| val dx = new DenseVector(Array(1.0, 2.0, 3.0)) | ||
| val sx = dx.toSparse | ||
| val expected = new DenseVector(Array(4.0, 1.0, 2.0, 9.0)) | ||
|
|
||
| assert(dA.multiply(x) ~== expected absTol 1e-15) | ||
| assert(sA.multiply(x) ~== expected absTol 1e-15) | ||
| assert(dA.multiply(dx) ~== expected absTol 1e-15) | ||
| assert(sA.multiply(dx) ~== expected absTol 1e-15) | ||
|
|
||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. add There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @viirya Adding |
||
| val y1 = new DenseVector(Array(1.0, 3.0, 1.0, 0.0)) | ||
| val y2 = y1.copy | ||
|
|
@@ -270,17 +271,26 @@ class BLASSuite extends FunSuite { | |
| val expected2 = new DenseVector(Array(6.0, 7.0, 4.0, 9.0)) | ||
| val expected3 = new DenseVector(Array(10.0, 8.0, 6.0, 18.0)) | ||
|
|
||
| gemv(1.0, dA, x, 2.0, y1) | ||
| gemv(1.0, sA, x, 2.0, y2) | ||
| gemv(2.0, dA, x, 2.0, y3) | ||
| gemv(2.0, sA, x, 2.0, y4) | ||
| gemv(1.0, dA, dx, 2.0, y1) | ||
| gemv(1.0, sA, dx, 2.0, y2) | ||
| gemv(2.0, dA, dx, 2.0, y3) | ||
| gemv(2.0, sA, dx, 2.0, y4) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. To the following instead. |
||
| assert(y1 ~== expected2 absTol 1e-15) | ||
| assert(y2 ~== expected2 absTol 1e-15) | ||
| assert(y3 ~== expected3 absTol 1e-15) | ||
| assert(y4 ~== expected3 absTol 1e-15) | ||
|
|
||
| val y1_copy = new DenseVector(Array(1.0, 3.0, 1.0, 0.0)) | ||
| val y3_copy = y1_copy.copy | ||
|
|
||
| gemv(1.0, dA, sx, 2.0, y1_copy) | ||
| gemv(2.0, dA, sx, 2.0, y3_copy) | ||
| assert(y1_copy ~== expected2 absTol 1e-15) | ||
| assert(y3_copy ~== expected3 absTol 1e-15) | ||
|
|
||
| withClue("columns of A don't match the rows of B") { | ||
| intercept[Exception] { | ||
| gemv(1.0, dA.transpose, x, 2.0, y1) | ||
| gemv(1.0, dA.transpose, dx, 2.0, y1) | ||
| } | ||
| } | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. check exception for |
||
| val dAT = | ||
|
|
@@ -291,7 +301,7 @@ class BLASSuite extends FunSuite { | |
| val dATT = dAT.transpose | ||
| val sATT = sAT.transpose | ||
|
|
||
| assert(dATT.multiply(x) ~== expected absTol 1e-15) | ||
| assert(sATT.multiply(x) ~== expected absTol 1e-15) | ||
| assert(dATT.multiply(dx) ~== expected absTol 1e-15) | ||
| assert(sATT.multiply(dx) ~== expected absTol 1e-15) | ||
| } | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. add |
||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How about
SparseMatrixandSparseVector? To make the consistent naming, we can usedmA,smA,dvx, andsvx.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If you don't really want to add
SparseMatrixandSparseVector, the type safety will be broken when you call with this configuration. Previously, this function is totally type safe in compile time, and no way to get into "case _".