1818package org .apache .spark .mllib .linalg .distributed
1919
2020import breeze .linalg .{DenseMatrix => BDM }
21- import org .apache .spark .util .Utils
2221
2322import org .apache .spark .{Logging , Partitioner }
24- import org .apache .spark .mllib .linalg ._
25- import org .apache .spark .mllib .rdd .RDDFunctions ._
23+ import org .apache .spark .mllib .linalg .{DenseMatrix , Matrix }
2624import org .apache .spark .rdd .RDD
2725import org .apache .spark .storage .StorageLevel
2826
2927/**
30- * A grid partitioner, which stores every block in a separate partition.
28+ * A grid partitioner, which uses a regular grid to partition coordinates .
3129 *
32- * @param numRowBlocks Number of blocks that form the rows of the matrix.
33- * @param numColBlocks Number of blocks that form the columns of the matrix.
34- * @param suggestedNumPartitions Number of partitions to partition the rdd into. The final number
35- * of partitions will be set to `min(suggestedNumPartitions,
36- * numRowBlocks * numColBlocks)`, because setting the number of
37- * partitions greater than the number of sub matrices is not useful.
30+ * @param rows Number of rows.
31+ * @param cols Number of columns.
32+ * @param rowsPerPart Number of rows per partition, which may be less at the bottom edge.
33+ * @param colsPerPart Number of columns per partition, which may be less at the right edge.
3834 */
3935private [mllib] class GridPartitioner (
40- val numRowBlocks : Int ,
41- val numColBlocks : Int ,
42- suggestedNumPartitions : Int ) extends Partitioner {
43- private val totalBlocks = numRowBlocks.toLong * numColBlocks
44- // Having the number of partitions greater than the number of sub matrices does not help
45- override val numPartitions = math.min(suggestedNumPartitions, totalBlocks).toInt
46-
47- private val blockLengthsPerPartition = findOptimalBlockLengths
48- // Number of neighboring blocks to take in each row
49- private val numRowBlocksPerPartition = blockLengthsPerPartition._1
50- // Number of neighboring blocks to take in each column
51- private val numColBlocksPerPartition = blockLengthsPerPartition._2
52- // Number of rows of partitions
53- private val blocksPerRow = math.ceil(numRowBlocks * 1.0 / numRowBlocksPerPartition).toInt
36+ val rows : Int ,
37+ val cols : Int ,
38+ val rowsPerPart : Int ,
39+ val colsPerPart : Int ) extends Partitioner {
40+
41+ require(rows > 0 )
42+ require(cols > 0 )
43+ require(rowsPerPart > 0 )
44+ require(colsPerPart > 0 )
45+
46+ private val rowPartitions = math.ceil(rows / rowsPerPart).toInt
47+ private val colPartitions = math.ceil(cols / colsPerPart).toInt
48+
49+ override val numPartitions = rowPartitions * colPartitions
5450
5551 /**
56- * Returns the index of the partition the SubMatrix belongs to. Tries to achieve block wise
57- * partitioning.
52+ * Returns the index of the partition the input coordinate belongs to.
5853 *
59- * @param key The key for the SubMatrix. Can be its position in the grid (its column major index)
60- * or a tuple of three integers that are the final row index after the multiplication,
61- * the index of the block to multiply with, and the final column index after the
62- * multiplication.
63- * @return The index of the partition, which the SubMatrix belongs to.
54+ * @param key The coordinate (i, j) or a tuple (i, j, k), where k is the inner index used in
55+ * multiplication. k is ignored in computing partitions.
56+ * @return The index of the partition, which the coordinate belongs to.
6457 */
6558 override def getPartition (key : Any ): Int = {
6659 key match {
67- case (blockRowIndex : Int , blockColIndex : Int ) =>
68- getPartitionId(blockRowIndex, blockColIndex )
69- case (blockRowIndex : Int , innerIndex : Int , blockColIndex : Int ) =>
70- getPartitionId(blockRowIndex, blockColIndex )
60+ case (i : Int , j : Int ) =>
61+ getPartitionId(i, j )
62+ case (i : Int , j : Int , _ : Int ) =>
63+ getPartitionId(i, j )
7164 case _ =>
72- throw new IllegalArgumentException (s " Unrecognized key. key : $key" )
65+ throw new IllegalArgumentException (s " Unrecognized key: $key. " )
7366 }
7467 }
7568
7669 /** Partitions sub-matrices as blocks with neighboring sub-matrices. */
77- private def getPartitionId (blockRowIndex : Int , blockColIndex : Int ): Int = {
78- require(0 <= blockRowIndex && blockRowIndex < numRowBlocks, " The blockRowIndex in the key " +
79- s " must be in the range 0 <= blockRowIndex < numRowBlocks. blockRowIndex: $blockRowIndex, " +
80- s " numRowBlocks: $numRowBlocks" )
81- require(0 <= blockRowIndex && blockColIndex < numColBlocks, " The blockColIndex in the key " +
82- s " must be in the range 0 <= blockRowIndex < numColBlocks. blockColIndex: $blockColIndex, " +
83- s " numColBlocks: $numColBlocks" )
84- // Coordinates of the block
85- val i = blockRowIndex / numRowBlocksPerPartition
86- val j = blockColIndex / numColBlocksPerPartition
87- // The mod shouldn't be required but is added as a guarantee for possible corner cases
88- Utils .nonNegativeMod(j * blocksPerRow + i, numPartitions)
70+ private def getPartitionId (i : Int , j : Int ): Int = {
71+ require(0 <= i && i < rows, s " Row index $i out of range [0, $rows). " )
72+ require(0 <= j && j < cols, s " Column index $j out of range [0, $cols). " )
73+ i / rowsPerPart + j / colsPerPart * rowPartitions
8974 }
9075
91- /** Tries to calculate the optimal number of blocks that should be in each partition. */
92- private def findOptimalBlockLengths : (Int , Int ) = {
93- // Gives the optimal number of blocks that need to be in each partition
94- val targetNumBlocksPerPartition = math.ceil(totalBlocks * 1.0 / numPartitions).toInt
95- // Number of neighboring blocks to take in each row
96- var m = math.ceil(math.sqrt(targetNumBlocksPerPartition)).toInt
97- // Number of neighboring blocks to take in each column
98- var n = math.ceil(targetNumBlocksPerPartition * 1.0 / m).toInt
99- // Try to make m and n close to each other while making sure that we don't exceed the number
100- // of partitions
101- var numBlocksForRows = math.ceil(numRowBlocks * 1.0 / m)
102- var numBlocksForCols = math.ceil(numColBlocks * 1.0 / n)
103- while ((numBlocksForRows * numBlocksForCols > numPartitions) && (m * n != 0 )) {
104- if (numRowBlocks <= numColBlocks) {
105- m += 1
106- n = math.ceil(targetNumBlocksPerPartition * 1.0 / m).toInt
107- } else {
108- n += 1
109- m = math.ceil(targetNumBlocksPerPartition * 1.0 / n).toInt
110- }
111- numBlocksForRows = math.ceil(numRowBlocks * 1.0 / m)
112- numBlocksForCols = math.ceil(numColBlocks * 1.0 / n)
113- }
114- // If a good partitioning scheme couldn't be found, set the side with the smaller dimension to
115- // 1 and the other to the number of targetNumBlocksPerPartition
116- if (m * n == 0 ) {
117- if (numRowBlocks <= numColBlocks) {
118- m = 1
119- n = targetNumBlocksPerPartition
120- } else {
121- n = 1
122- m = targetNumBlocksPerPartition
123- }
124- }
125- (m, n)
126- }
127-
128- /** Checks whether the partitioners have the same characteristics */
12976 override def equals (obj : Any ): Boolean = {
13077 obj match {
13178 case r : GridPartitioner =>
132- (this .numRowBlocks == r.numRowBlocks ) && (this .numColBlocks == r.numColBlocks ) &&
133- (this .numPartitions == r.numPartitions )
79+ (this .rows == r.rows ) && (this .cols == r.cols ) &&
80+ (this .rowsPerPart == r.rowsPerPart) && ( this .colsPerPart == r.colsPerPart )
13481 case _ =>
13582 false
13683 }
13784 }
13885}
13986
87+ private [mllib] object GridPartitioner {
88+
89+ /** Creates a new [[GridPartitioner ]] instance. */
90+ def apply (rows : Int , cols : Int , rowsPerPart : Int , colsPerPart : Int ): GridPartitioner = {
91+ new GridPartitioner (rows, cols, rowsPerPart, colsPerPart)
92+ }
93+
94+ /** Creates a new [[GridPartitioner ]] instance with the input suggested number of partitions. */
95+ def apply (rows : Int , cols : Int , suggestedNumPartitions : Int ): GridPartitioner = {
96+ require(suggestedNumPartitions > 0 )
97+ val scale = 1.0 / math.sqrt(suggestedNumPartitions)
98+ val rowsPerPart = math.round(math.max(scale * rows, 1.0 )).toInt
99+ val colsPerPart = math.round(math.max(scale * cols, 1.0 )).toInt
100+ new GridPartitioner (rows, cols, rowsPerPart, colsPerPart)
101+ }
102+ }
103+
140104/**
141105 * Represents a distributed matrix in blocks of local matrices.
142106 *
143- * @param rdd The RDD of SubMatrices (local matrices) that form this matrix
144- * @param nRows Number of rows of this matrix. If the supplied value is less than or equal to zero,
145- * the number of rows will be calculated when `numRows` is invoked.
146- * @param nCols Number of columns of this matrix. If the supplied value is less than or equal to
147- * zero, the number of columns will be calculated when `numCols` is invoked.
107+ * @param blocks The RDD of sub-matrix blocks (blockRowIndex, blockColIndex, sub-matrix) that form
108+ * this distributed matrix.
148109 * @param rowsPerBlock Number of rows that make up each block. The blocks forming the final
149110 * rows are not required to have the given number of rows
150111 * @param colsPerBlock Number of columns that make up each block. The blocks forming the final
151112 * columns are not required to have the given number of columns
113+ * @param nRows Number of rows of this matrix. If the supplied value is less than or equal to zero,
114+ * the number of rows will be calculated when `numRows` is invoked.
115+ * @param nCols Number of columns of this matrix. If the supplied value is less than or equal to
116+ * zero, the number of columns will be calculated when `numCols` is invoked.
152117 */
153118class BlockMatrix (
154- val rdd : RDD [((Int , Int ), Matrix )],
155- private var nRows : Long ,
156- private var nCols : Long ,
119+ val blocks : RDD [((Int , Int ), Matrix )],
157120 val rowsPerBlock : Int ,
158- val colsPerBlock : Int ) extends DistributedMatrix with Logging {
121+ val colsPerBlock : Int ,
122+ private var nRows : Long ,
123+ private var nCols : Long ) extends DistributedMatrix with Logging {
159124
160- private type SubMatrix = ((Int , Int ), Matrix ) // ((blockRowIndex, blockColIndex), matrix)
125+ private type MatrixBlock = ((Int , Int ), Matrix ) // ((blockRowIndex, blockColIndex), sub- matrix)
161126
162127 /**
163128 * Alternate constructor for BlockMatrix without the input of the number of rows and columns.
@@ -172,45 +137,48 @@ class BlockMatrix(
172137 rdd : RDD [((Int , Int ), Matrix )],
173138 rowsPerBlock : Int ,
174139 colsPerBlock : Int ) = {
175- this (rdd, 0L , 0L , rowsPerBlock, colsPerBlock )
140+ this (rdd, rowsPerBlock, colsPerBlock, 0L , 0L )
176141 }
177142
178- private lazy val dims : (Long , Long ) = getDim
179-
180143 override def numRows (): Long = {
181- if (nRows <= 0L ) nRows = dims._1
144+ if (nRows <= 0L ) estimateDim()
182145 nRows
183146 }
184147
185148 override def numCols (): Long = {
186- if (nCols <= 0L ) nCols = dims._2
149+ if (nCols <= 0L ) estimateDim()
187150 nCols
188151 }
189152
190153 val numRowBlocks = math.ceil(numRows() * 1.0 / rowsPerBlock).toInt
191154 val numColBlocks = math.ceil(numCols() * 1.0 / colsPerBlock).toInt
192155
193156 private [mllib] var partitioner : GridPartitioner =
194- new GridPartitioner (numRowBlocks, numColBlocks, rdd.partitions.length)
195-
196- /** Returns the dimensions of the matrix. */
197- private def getDim : (Long , Long ) = {
198- val (rows, cols) = rdd.map { case ((blockRowIndex, blockColIndex), mat) =>
199- (blockRowIndex * rowsPerBlock + mat.numRows, blockColIndex * colsPerBlock + mat.numCols)
200- }.reduce((x0, x1) => (math.max(x0._1, x1._1), math.max(x0._2, x1._2)))
201-
202- (math.max(rows, nRows), math.max(cols, nCols))
157+ GridPartitioner (numRowBlocks, numColBlocks, suggestedNumPartitions = blocks.partitions.size)
158+
159+ /** Estimates the dimensions of the matrix. */
160+ private def estimateDim (): Unit = {
161+ val (rows, cols) = blocks.map { case ((blockRowIndex, blockColIndex), mat) =>
162+ (blockRowIndex.toLong * rowsPerBlock + mat.numRows,
163+ blockColIndex.toLong * colsPerBlock + mat.numCols)
164+ }.reduce { (x0, x1) =>
165+ (math.max(x0._1, x1._1), math.max(x0._2, x1._2))
166+ }
167+ if (nRows <= 0L ) nRows = rows
168+ assert(rows <= nRows, s " The number of rows $rows is more than claimed $nRows. " )
169+ if (nCols <= 0L ) nCols = cols
170+ assert(cols <= nCols, s " The number of columns $cols is more than claimed $nCols. " )
203171 }
204172
205- /** Cache the underlying RDD. */
206- def cache (): BlockMatrix = {
207- rdd .cache()
173+ /** Caches the underlying RDD. */
174+ def cache (): this . type = {
175+ blocks .cache()
208176 this
209177 }
210178
211- /** Set the storage level for the underlying RDD . */
212- def persist (storageLevel : StorageLevel ): BlockMatrix = {
213- rdd .persist(storageLevel)
179+ /** Persists the underlying RDD with the specified storage level . */
180+ def persist (storageLevel : StorageLevel ): this . type = {
181+ blocks .persist(storageLevel)
214182 this
215183 }
216184
@@ -222,22 +190,22 @@ class BlockMatrix(
222190 s " Int.MaxValue. Currently numCols: ${numCols()}" )
223191 require(numRows() * numCols() < Int .MaxValue , " The length of the values array must be " +
224192 s " less than Int.MaxValue. Currently numRows * numCols: ${numRows() * numCols()}" )
225- val nRows = numRows().toInt
226- val nCols = numCols().toInt
227- val mem = nRows * nCols / 125000
193+ val m = numRows().toInt
194+ val n = numCols().toInt
195+ val mem = m * n / 125000
228196 if (mem > 500 ) logWarning(s " Storing this matrix will require $mem MB of memory! " )
229197
230- val parts = rdd .collect()
231- val values = new Array [Double ](nRows * nCols )
232- parts .foreach { case ((blockRowIndex, blockColIndex), block ) =>
198+ val localBlocks = blocks .collect()
199+ val values = new Array [Double ](m * n )
200+ localBlocks .foreach { case ((blockRowIndex, blockColIndex), submat ) =>
233201 val rowOffset = blockRowIndex * rowsPerBlock
234202 val colOffset = blockColIndex * colsPerBlock
235- block .foreachActive { (i, j, v) =>
236- val indexOffset = (j + colOffset) * nRows + rowOffset + i
203+ submat .foreachActive { (i, j, v) =>
204+ val indexOffset = (j + colOffset) * m + rowOffset + i
237205 values(indexOffset) = v
238206 }
239207 }
240- new DenseMatrix (nRows, nCols , values)
208+ new DenseMatrix (m, n , values)
241209 }
242210
243211 /** Collects data and assembles a local dense breeze matrix (for test only). */
0 commit comments