Skip to content

Commit 10a63a6

Browse files
committed
[SPARK-4409] Fourth pass of code review
1 parent f62d6c7 commit 10a63a6

File tree

2 files changed

+92
-67
lines changed

2 files changed

+92
-67
lines changed

mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala

Lines changed: 83 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ package org.apache.spark.mllib.linalg
1919

2020
import java.util.{Arrays, Random}
2121

22-
import scala.collection.mutable.{ArrayBuffer, ArrayBuilder, Map => MutableMap}
22+
import scala.collection.mutable.{ArrayBuilder => MArrayBuilder, HashSet => MHashSet, ArrayBuffer}
2323

2424
import breeze.linalg.{CSCMatrix => BSM, DenseMatrix => BDM, Matrix => BM}
2525

@@ -150,11 +150,10 @@ class DenseMatrix(val numRows: Int, val numCols: Int, val values: Array[Double])
150150

151151
/** Generate a `SparseMatrix` from the given `DenseMatrix`. */
152152
def toSparse(): SparseMatrix = {
153-
val spVals: ArrayBuilder[Double] = new ArrayBuilder.ofDouble
153+
val spVals: MArrayBuilder[Double] = new MArrayBuilder.ofDouble
154154
val colPtrs: Array[Int] = new Array[Int](numCols + 1)
155-
val rowIndices: ArrayBuilder[Int] = new ArrayBuilder.ofInt
155+
val rowIndices: MArrayBuilder[Int] = new MArrayBuilder.ofInt
156156
var nnz = 0
157-
var lastCol = -1
158157
var j = 0
159158
while (j < numCols) {
160159
var i = 0
@@ -164,19 +163,12 @@ class DenseMatrix(val numRows: Int, val numCols: Int, val values: Array[Double])
164163
if (v != 0.0) {
165164
rowIndices += i
166165
spVals += v
167-
while (j != lastCol) {
168-
colPtrs(lastCol + 1) = nnz
169-
lastCol += 1
170-
}
171166
nnz += 1
172167
}
173168
i += 1
174169
}
175170
j += 1
176-
}
177-
while (numCols > lastCol) {
178-
colPtrs(lastCol + 1) = nnz
179-
lastCol += 1
171+
colPtrs(j) = nnz
180172
}
181173
new SparseMatrix(numRows, numCols, colPtrs, rowIndices.result(), spVals.result())
182174
}
@@ -362,30 +354,54 @@ object SparseMatrix {
362354

363355
/**
364356
* Generate a `SparseMatrix` from Coordinate List (COO) format. Input must be an array of
365-
* (row, column, value) tuples.
357+
* (i, j, value) tuples. Entries that have duplicate values of i and j are
358+
* added together. Tuples where value is equal to zero will be omitted.
366359
* @param numRows number of rows of the matrix
367360
* @param numCols number of columns of the matrix
368-
* @param entries Array of (row, column, value) tuples
361+
* @param entries Array of (i, j, value) tuples
369362
* @return The corresponding `SparseMatrix`
370363
*/
371364
def fromCOO(numRows: Int, numCols: Int, entries: Array[(Int, Int, Double)]): SparseMatrix = {
372365
val sortedEntries = entries.sortBy(v => (v._2, v._1))
373366
val colPtrs = new Array[Int](numCols + 1)
374367
var nnz = 0
375368
var lastCol = -1
376-
val values = sortedEntries.map { case (i, j, v) =>
377-
while (j != lastCol) {
378-
colPtrs(lastCol + 1) = nnz
379-
lastCol += 1
369+
var lastIndex = -1
370+
sortedEntries.foreach { case (i, j, v) =>
371+
require(i >= 0 && j >= 0, "Negative indices given. Please make sure all indices are " +
372+
s"greater than or equal to zero. i: $i, j: $j, value: $v")
373+
if (v != 0.0) {
374+
while (j != lastCol) {
375+
colPtrs(lastCol + 1) = nnz
376+
lastCol += 1
377+
}
378+
val index = j * numRows + i
379+
if (lastIndex != index) {
380+
nnz += 1
381+
lastIndex = index
382+
}
380383
}
381-
nnz += 1
382-
v
383384
}
384385
while (numCols > lastCol) {
385386
colPtrs(lastCol + 1) = nnz
386387
lastCol += 1
387388
}
388-
new SparseMatrix(numRows, numCols, colPtrs.toArray, sortedEntries.map(_._1), values)
389+
val values = new Array[Double](nnz)
390+
val rowIndices = new Array[Int](nnz)
391+
lastIndex = -1
392+
var cnt = -1
393+
sortedEntries.foreach { case (i, j, v) =>
394+
if (v != 0.0) {
395+
val index = j * numRows + i
396+
if (lastIndex != index) {
397+
cnt += 1
398+
lastIndex = index
399+
}
400+
values(cnt) += v
401+
rowIndices(cnt) = i
402+
}
403+
}
404+
new SparseMatrix(numRows, numCols, colPtrs.toArray, rowIndices, values)
389405
}
390406

391407
/**
@@ -397,57 +413,54 @@ object SparseMatrix {
397413
new SparseMatrix(n, n, (0 to n).toArray, (0 until n).toArray, Array.fill(n)(1.0))
398414
}
399415

400-
/** Generates a `SparseMatrix` with a given random number generator and `method`, which
401-
* specifies the distribution. */
416+
/** Generates the skeleton of a random `SparseMatrix` with a given random number generator. */
402417
private def genRandMatrix(
403418
numRows: Int,
404419
numCols: Int,
405420
density: Double,
406-
rng: Random,
407-
method: Random => Double): SparseMatrix = {
421+
rng: Random): SparseMatrix = {
408422
require(density >= 0.0 && density <= 1.0, "density must be a double in the range " +
409423
s"0.0 <= d <= 1.0. Currently, density: $density")
410424
val length = math.ceil(numRows * numCols * density).toInt
411-
val entries = MutableMap[(Int, Int), Double]()
412425
var i = 0
413426
if (density == 0.0) {
414427
return new SparseMatrix(numRows, numCols, new Array[Int](numCols + 1),
415428
Array[Int](), Array[Double]())
416429
} else if (density == 1.0) {
430+
val rowIndices = Array.tabulate(numCols, numRows)((j, i) => i).flatten
417431
return new SparseMatrix(numRows, numCols, (0 to numRows * numCols by numRows).toArray,
418-
(0 until numRows * numCols).toArray, Array.fill(numRows * numCols)(method(rng)))
432+
rowIndices, new Array[Double](numRows * numCols))
419433
}
420-
// Expected number of iterations is less than 1.5 * length
421-
if (density < 0.34) {
422-
while (i < length) {
423-
var rowIndex = rng.nextInt(numRows)
424-
var colIndex = rng.nextInt(numCols)
425-
while (entries.contains((rowIndex, colIndex))) {
426-
rowIndex = rng.nextInt(numRows)
427-
colIndex = rng.nextInt(numCols)
428-
}
429-
entries += (rowIndex, colIndex) -> method(rng)
430-
i += 1
434+
if (density < 0.34) { // Expected number of iterations is less than 1.5 * length
435+
val entries = MHashSet[(Int, Int)]()
436+
while (entries.size < length) {
437+
entries += ((rng.nextInt(numRows), rng.nextInt(numCols)))
431438
}
439+
val entryList = entries.toArray.map(v => (v._1, v._2, 1.0))
440+
SparseMatrix.fromCOO(numRows, numCols, entryList)
432441
} else { // selection - rejection method
433442
var j = 0
434443
val pool = numRows * numCols
435-
// loop over columns so that the sort in fromCOO requires less sorting
444+
val rowIndexBuilder = new MArrayBuilder.ofInt
445+
val colPtrs = new Array[Int](numCols + 1)
436446
while (i < length && j < numCols) {
437447
var passedInPool = j * numRows
438448
var r = 0
439449
while (i < length && r < numRows) {
440450
if (rng.nextDouble() < 1.0 * (length - i) / (pool - passedInPool)) {
441-
entries += (r, j) -> method(rng)
451+
rowIndexBuilder += r
442452
i += 1
443453
}
444454
r += 1
445455
passedInPool += 1
446456
}
447457
j += 1
458+
colPtrs(j) = i
448459
}
460+
val rowIndices = rowIndexBuilder.result()
461+
new SparseMatrix(numRows, numCols, colPtrs, rowIndices, new Array[Double](rowIndices.size))
449462
}
450-
SparseMatrix.fromCOO(numRows, numCols, entries.toArray.map(v => (v._1._1, v._1._2, v._2)))
463+
451464
}
452465

453466
/**
@@ -461,8 +474,8 @@ object SparseMatrix {
461474
* @return `SparseMatrix` with size `numRows` x `numCols` and values in U(0, 1)
462475
*/
463476
def sprand(numRows: Int, numCols: Int, density: Double, rng: Random): SparseMatrix = {
464-
def method(rand: Random): Double = rand.nextDouble()
465-
genRandMatrix(numRows, numCols, density, rng, method)
477+
val mat = genRandMatrix(numRows, numCols, density, rng)
478+
mat.update(i => rng.nextDouble())
466479
}
467480

468481
/**
@@ -474,8 +487,8 @@ object SparseMatrix {
474487
* @return `SparseMatrix` with size `numRows` x `numCols` and values in N(0, 1)
475488
*/
476489
def sprandn(numRows: Int, numCols: Int, density: Double, rng: Random): SparseMatrix = {
477-
def method(rand: Random): Double = rand.nextGaussian()
478-
genRandMatrix(numRows, numCols, density, rng, method)
490+
val mat = genRandMatrix(numRows, numCols, density, rng)
491+
mat.update(i => rng.nextGaussian())
479492
}
480493

481494
/**
@@ -488,11 +501,10 @@ object SparseMatrix {
488501
val n = vector.size
489502
vector match {
490503
case sVec: SparseVector =>
491-
val indices = sVec.indices
492-
SparseMatrix.fromCOO(n, n, indices.zip(sVec.values).map(v => (v._1, v._1, v._2)))
504+
SparseMatrix.fromCOO(n, n, sVec.indices.zip(sVec.values).map(v => (v._1, v._1, v._2)))
493505
case dVec: DenseVector =>
494-
val values = dVec.values.zipWithIndex
495-
val nnzVals = values.filter(v => v._1 != 0.0)
506+
val entries = dVec.values.zipWithIndex
507+
val nnzVals = entries.filter(v => v._1 != 0.0)
496508
SparseMatrix.fromCOO(n, n, nnzVals.map(v => (v._2, v._2, v._1)))
497509
}
498510
}
@@ -645,11 +657,11 @@ object Matrices {
645657
return matrices(0)
646658
}
647659
val numRows = matrices(0).numRows
648-
var rowsMatch = true
649660
var hasSparse = false
650661
var numCols = 0
651662
matrices.foreach { mat =>
652-
if (numRows != mat.numRows) rowsMatch = false
663+
require(numRows == mat.numRows, "The number of rows of the matrices in this sequence, " +
664+
"don't match!")
653665
mat match {
654666
case sparse: SparseMatrix => hasSparse = true
655667
case dense: DenseMatrix => // empty on purpose
@@ -658,8 +670,6 @@ object Matrices {
658670
}
659671
numCols += mat.numCols
660672
}
661-
require(rowsMatch, "The number of rows of the matrices in this sequence, don't match!")
662-
663673
if (!hasSparse) {
664674
new DenseMatrix(numRows, numCols, matrices.flatMap(_.toArray))
665675
} else {
@@ -723,12 +733,12 @@ object Matrices {
723733
return matrices(0)
724734
}
725735
val numCols = matrices(0).numCols
726-
var colsMatch = true
727736
var hasSparse = false
728737
var numRows = 0
729738
var valsLength = 0
730739
matrices.foreach { mat =>
731-
if (numCols != mat.numCols) colsMatch = false
740+
require(numCols == mat.numCols, "The number of rows of the matrices in this sequence, " +
741+
"don't match!")
732742
mat match {
733743
case sparse: SparseMatrix =>
734744
hasSparse = true
@@ -741,15 +751,26 @@ object Matrices {
741751
numRows += mat.numRows
742752

743753
}
744-
require(colsMatch, "The number of rows of the matrices in this sequence, don't match!")
745-
746754
if (!hasSparse) {
747-
val matData = matrices.zipWithIndex.flatMap { case (mat, ind) =>
755+
val allValues = new Array[Double](numRows * numCols)
756+
var startRow = 0
757+
matrices.foreach { mat =>
758+
var j = 0
759+
val nRows = mat.numRows
748760
val values = mat.toArray
749-
for (j <- 0 until numCols) yield (j, ind,
750-
values.slice(j * mat.numRows, (j + 1) * mat.numRows))
751-
}.sortBy(x => (x._1, x._2))
752-
new DenseMatrix(numRows, numCols, matData.flatMap(_._3))
761+
while (j < numCols) {
762+
var i = 0
763+
val indStart = j * numRows + startRow
764+
val subMatStart = j * nRows
765+
while (i < nRows) {
766+
allValues(indStart + i) = values(subMatStart + i)
767+
i += 1
768+
}
769+
j += 1
770+
}
771+
startRow += nRows
772+
}
773+
new DenseMatrix(numRows, numCols, allValues)
753774
} else {
754775
var startRow = 0
755776
val entries: Array[(Int, Int, Double)] = matrices.flatMap {

mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -54,11 +54,12 @@ class MatricesSuite extends FunSuite {
5454
assert(mat.colPtrs.eq(colPtrs), "should not copy data")
5555
assert(mat.rowIndices.eq(rowIndices), "should not copy data")
5656

57-
val entries: Array[(Int, Int, Double)] = Array((1, 0, 1.0), (2, 0, 2.0),
58-
(1, 2, 4.0), (2, 2, 5.0))
57+
val entries: Array[(Int, Int, Double)] = Array((2, 2, 3.0), (1, 0, 1.0), (2, 0, 2.0),
58+
(1, 2, 2.0), (2, 2, 2.0), (1, 2, 2.0), (0, 0, 0.0))
5959

6060
val mat2 = SparseMatrix.fromCOO(m, n, entries)
6161
assert(mat.toBreeze === mat2.toBreeze)
62+
assert(mat2.values.length == 4)
6263
}
6364

6465
test("sparse matrix construction with wrong number of elements") {
@@ -308,12 +309,15 @@ class MatricesSuite extends FunSuite {
308309
test("sprand") {
309310
val rng = mock[Random]
310311
when(rng.nextInt(4)).thenReturn(0, 1, 1, 3, 2, 2, 0, 1, 3, 0)
311-
when(rng.nextDouble()).thenReturn(1.0, 2.0, 3.0, 4.0)
312+
when(rng.nextDouble()).thenReturn(1.0, 2.0, 3.0, 4.0, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0)
312313
val mat = SparseMatrix.sprand(4, 4, 0.25, rng)
313314
assert(mat.numRows === 4)
314315
assert(mat.numCols === 4)
315316
assert(mat.rowIndices.toSeq === Seq(3, 0, 2, 1))
316-
assert(mat.values.toSeq === Seq(4.0, 1.0, 3.0, 2.0))
317+
assert(mat.values.toSeq === Seq(1.0, 2.0, 3.0, 4.0))
318+
val mat2 = SparseMatrix.sprand(2, 3, 1.0, rng)
319+
assert(mat2.rowIndices.toSeq === Seq(0, 1, 0, 1, 0, 1))
320+
assert(mat2.colPtrs.toSeq === Seq(0, 2, 4, 6))
317321
}
318322

319323
test("sprandn") {
@@ -324,6 +328,6 @@ class MatricesSuite extends FunSuite {
324328
assert(mat.numRows === 4)
325329
assert(mat.numCols === 4)
326330
assert(mat.rowIndices.toSeq === Seq(3, 0, 2, 1))
327-
assert(mat.values.toSeq === Seq(4.0, 1.0, 3.0, 2.0))
331+
assert(mat.values.toSeq === Seq(1.0, 2.0, 3.0, 4.0))
328332
}
329333
}

0 commit comments

Comments
 (0)