Skip to content

Commit a8eace2

Browse files
committed
Merge pull request #2 from mengxr/brkyvz-SPARK-3974
Simplify GridPartitioner partitioning
2 parents 5eecd48 + feb32a7 commit a8eace2

File tree

2 files changed

+161
-170
lines changed

2 files changed

+161
-170
lines changed

mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala

Lines changed: 95 additions & 127 deletions
Original file line numberDiff line numberDiff line change
@@ -18,146 +18,111 @@
1818
package org.apache.spark.mllib.linalg.distributed
1919

2020
import breeze.linalg.{DenseMatrix => BDM}
21-
import org.apache.spark.util.Utils
2221

2322
import org.apache.spark.{Logging, Partitioner}
24-
import org.apache.spark.mllib.linalg._
25-
import org.apache.spark.mllib.rdd.RDDFunctions._
23+
import org.apache.spark.mllib.linalg.{DenseMatrix, Matrix}
2624
import org.apache.spark.rdd.RDD
2725
import org.apache.spark.storage.StorageLevel
2826

2927
/**
30-
* A grid partitioner, which stores every block in a separate partition.
28+
* A grid partitioner, which uses a regular grid to partition coordinates.
3129
*
32-
* @param numRowBlocks Number of blocks that form the rows of the matrix.
33-
* @param numColBlocks Number of blocks that form the columns of the matrix.
34-
* @param suggestedNumPartitions Number of partitions to partition the rdd into. The final number
35-
* of partitions will be set to `min(suggestedNumPartitions,
36-
* numRowBlocks * numColBlocks)`, because setting the number of
37-
* partitions greater than the number of sub matrices is not useful.
30+
* @param rows Number of rows.
31+
* @param cols Number of columns.
32+
* @param rowsPerPart Number of rows per partition, which may be less at the bottom edge.
33+
* @param colsPerPart Number of columns per partition, which may be less at the right edge.
3834
*/
3935
private[mllib] class GridPartitioner(
40-
val numRowBlocks: Int,
41-
val numColBlocks: Int,
42-
suggestedNumPartitions: Int) extends Partitioner {
43-
private val totalBlocks = numRowBlocks.toLong * numColBlocks
44-
// Having the number of partitions greater than the number of sub matrices does not help
45-
override val numPartitions = math.min(suggestedNumPartitions, totalBlocks).toInt
46-
47-
private val blockLengthsPerPartition = findOptimalBlockLengths
48-
// Number of neighboring blocks to take in each row
49-
private val numRowBlocksPerPartition = blockLengthsPerPartition._1
50-
// Number of neighboring blocks to take in each column
51-
private val numColBlocksPerPartition = blockLengthsPerPartition._2
52-
// Number of rows of partitions
53-
private val blocksPerRow = math.ceil(numRowBlocks * 1.0 / numRowBlocksPerPartition).toInt
36+
val rows: Int,
37+
val cols: Int,
38+
val rowsPerPart: Int,
39+
val colsPerPart: Int) extends Partitioner {
40+
41+
require(rows > 0)
42+
require(cols > 0)
43+
require(rowsPerPart > 0)
44+
require(colsPerPart > 0)
45+
46+
private val rowPartitions = math.ceil(rows / rowsPerPart).toInt
47+
private val colPartitions = math.ceil(cols / colsPerPart).toInt
48+
49+
override val numPartitions = rowPartitions * colPartitions
5450

5551
/**
56-
* Returns the index of the partition the SubMatrix belongs to. Tries to achieve block wise
57-
* partitioning.
52+
* Returns the index of the partition the input coordinate belongs to.
5853
*
59-
* @param key The key for the SubMatrix. Can be its position in the grid (its column major index)
60-
* or a tuple of three integers that are the final row index after the multiplication,
61-
* the index of the block to multiply with, and the final column index after the
62-
* multiplication.
63-
* @return The index of the partition, which the SubMatrix belongs to.
54+
* @param key The coordinate (i, j) or a tuple (i, j, k), where k is the inner index used in
55+
* multiplication. k is ignored in computing partitions.
56+
* @return The index of the partition, which the coordinate belongs to.
6457
*/
6558
override def getPartition(key: Any): Int = {
6659
key match {
67-
case (blockRowIndex: Int, blockColIndex: Int) =>
68-
getPartitionId(blockRowIndex, blockColIndex)
69-
case (blockRowIndex: Int, innerIndex: Int, blockColIndex: Int) =>
70-
getPartitionId(blockRowIndex, blockColIndex)
60+
case (i: Int, j: Int) =>
61+
getPartitionId(i, j)
62+
case (i: Int, j: Int, _: Int) =>
63+
getPartitionId(i, j)
7164
case _ =>
72-
throw new IllegalArgumentException(s"Unrecognized key. key: $key")
65+
throw new IllegalArgumentException(s"Unrecognized key: $key.")
7366
}
7467
}
7568

7669
/** Partitions sub-matrices as blocks with neighboring sub-matrices. */
77-
private def getPartitionId(blockRowIndex: Int, blockColIndex: Int): Int = {
78-
require(0 <= blockRowIndex && blockRowIndex < numRowBlocks, "The blockRowIndex in the key " +
79-
s"must be in the range 0 <= blockRowIndex < numRowBlocks. blockRowIndex: $blockRowIndex," +
80-
s"numRowBlocks: $numRowBlocks")
81-
require(0 <= blockRowIndex && blockColIndex < numColBlocks, "The blockColIndex in the key " +
82-
s"must be in the range 0 <= blockRowIndex < numColBlocks. blockColIndex: $blockColIndex, " +
83-
s"numColBlocks: $numColBlocks")
84-
// Coordinates of the block
85-
val i = blockRowIndex / numRowBlocksPerPartition
86-
val j = blockColIndex / numColBlocksPerPartition
87-
// The mod shouldn't be required but is added as a guarantee for possible corner cases
88-
Utils.nonNegativeMod(j * blocksPerRow + i, numPartitions)
70+
private def getPartitionId(i: Int, j: Int): Int = {
71+
require(0 <= i && i < rows, s"Row index $i out of range [0, $rows).")
72+
require(0 <= j && j < cols, s"Column index $j out of range [0, $cols).")
73+
i / rowsPerPart + j / colsPerPart * rowPartitions
8974
}
9075

91-
/** Tries to calculate the optimal number of blocks that should be in each partition. */
92-
private def findOptimalBlockLengths: (Int, Int) = {
93-
// Gives the optimal number of blocks that need to be in each partition
94-
val targetNumBlocksPerPartition = math.ceil(totalBlocks * 1.0 / numPartitions).toInt
95-
// Number of neighboring blocks to take in each row
96-
var m = math.ceil(math.sqrt(targetNumBlocksPerPartition)).toInt
97-
// Number of neighboring blocks to take in each column
98-
var n = math.ceil(targetNumBlocksPerPartition * 1.0 / m).toInt
99-
// Try to make m and n close to each other while making sure that we don't exceed the number
100-
// of partitions
101-
var numBlocksForRows = math.ceil(numRowBlocks * 1.0 / m)
102-
var numBlocksForCols = math.ceil(numColBlocks * 1.0 / n)
103-
while ((numBlocksForRows * numBlocksForCols > numPartitions) && (m * n != 0)) {
104-
if (numRowBlocks <= numColBlocks) {
105-
m += 1
106-
n = math.ceil(targetNumBlocksPerPartition * 1.0 / m).toInt
107-
} else {
108-
n += 1
109-
m = math.ceil(targetNumBlocksPerPartition * 1.0 / n).toInt
110-
}
111-
numBlocksForRows = math.ceil(numRowBlocks * 1.0 / m)
112-
numBlocksForCols = math.ceil(numColBlocks * 1.0 / n)
113-
}
114-
// If a good partitioning scheme couldn't be found, set the side with the smaller dimension to
115-
// 1 and the other to the number of targetNumBlocksPerPartition
116-
if (m * n == 0) {
117-
if (numRowBlocks <= numColBlocks) {
118-
m = 1
119-
n = targetNumBlocksPerPartition
120-
} else {
121-
n = 1
122-
m = targetNumBlocksPerPartition
123-
}
124-
}
125-
(m, n)
126-
}
127-
128-
/** Checks whether the partitioners have the same characteristics */
12976
override def equals(obj: Any): Boolean = {
13077
obj match {
13178
case r: GridPartitioner =>
132-
(this.numRowBlocks == r.numRowBlocks) && (this.numColBlocks == r.numColBlocks) &&
133-
(this.numPartitions == r.numPartitions)
79+
(this.rows == r.rows) && (this.cols == r.cols) &&
80+
(this.rowsPerPart == r.rowsPerPart) && (this.colsPerPart == r.colsPerPart)
13481
case _ =>
13582
false
13683
}
13784
}
13885
}
13986

87+
private[mllib] object GridPartitioner {
88+
89+
/** Creates a new [[GridPartitioner]] instance. */
90+
def apply(rows: Int, cols: Int, rowsPerPart: Int, colsPerPart: Int): GridPartitioner = {
91+
new GridPartitioner(rows, cols, rowsPerPart, colsPerPart)
92+
}
93+
94+
/** Creates a new [[GridPartitioner]] instance with the input suggested number of partitions. */
95+
def apply(rows: Int, cols: Int, suggestedNumPartitions: Int): GridPartitioner = {
96+
require(suggestedNumPartitions > 0)
97+
val scale = 1.0 / math.sqrt(suggestedNumPartitions)
98+
val rowsPerPart = math.round(math.max(scale * rows, 1.0)).toInt
99+
val colsPerPart = math.round(math.max(scale * cols, 1.0)).toInt
100+
new GridPartitioner(rows, cols, rowsPerPart, colsPerPart)
101+
}
102+
}
103+
140104
/**
141105
* Represents a distributed matrix in blocks of local matrices.
142106
*
143-
* @param rdd The RDD of SubMatrices (local matrices) that form this matrix
144-
* @param nRows Number of rows of this matrix. If the supplied value is less than or equal to zero,
145-
* the number of rows will be calculated when `numRows` is invoked.
146-
* @param nCols Number of columns of this matrix. If the supplied value is less than or equal to
147-
* zero, the number of columns will be calculated when `numCols` is invoked.
107+
* @param blocks The RDD of sub-matrix blocks (blockRowIndex, blockColIndex, sub-matrix) that form
108+
* this distributed matrix.
148109
* @param rowsPerBlock Number of rows that make up each block. The blocks forming the final
149110
* rows are not required to have the given number of rows
150111
* @param colsPerBlock Number of columns that make up each block. The blocks forming the final
151112
* columns are not required to have the given number of columns
113+
* @param nRows Number of rows of this matrix. If the supplied value is less than or equal to zero,
114+
* the number of rows will be calculated when `numRows` is invoked.
115+
* @param nCols Number of columns of this matrix. If the supplied value is less than or equal to
116+
* zero, the number of columns will be calculated when `numCols` is invoked.
152117
*/
153118
class BlockMatrix(
154-
val rdd: RDD[((Int, Int), Matrix)],
155-
private var nRows: Long,
156-
private var nCols: Long,
119+
val blocks: RDD[((Int, Int), Matrix)],
157120
val rowsPerBlock: Int,
158-
val colsPerBlock: Int) extends DistributedMatrix with Logging {
121+
val colsPerBlock: Int,
122+
private var nRows: Long,
123+
private var nCols: Long) extends DistributedMatrix with Logging {
159124

160-
private type SubMatrix = ((Int, Int), Matrix) // ((blockRowIndex, blockColIndex), matrix)
125+
private type MatrixBlock = ((Int, Int), Matrix) // ((blockRowIndex, blockColIndex), sub-matrix)
161126

162127
/**
163128
* Alternate constructor for BlockMatrix without the input of the number of rows and columns.
@@ -172,45 +137,48 @@ class BlockMatrix(
172137
rdd: RDD[((Int, Int), Matrix)],
173138
rowsPerBlock: Int,
174139
colsPerBlock: Int) = {
175-
this(rdd, 0L, 0L, rowsPerBlock, colsPerBlock)
140+
this(rdd, rowsPerBlock, colsPerBlock, 0L, 0L)
176141
}
177142

178-
private lazy val dims: (Long, Long) = getDim
179-
180143
override def numRows(): Long = {
181-
if (nRows <= 0L) nRows = dims._1
144+
if (nRows <= 0L) estimateDim()
182145
nRows
183146
}
184147

185148
override def numCols(): Long = {
186-
if (nCols <= 0L) nCols = dims._2
149+
if (nCols <= 0L) estimateDim()
187150
nCols
188151
}
189152

190153
val numRowBlocks = math.ceil(numRows() * 1.0 / rowsPerBlock).toInt
191154
val numColBlocks = math.ceil(numCols() * 1.0 / colsPerBlock).toInt
192155

193156
private[mllib] var partitioner: GridPartitioner =
194-
new GridPartitioner(numRowBlocks, numColBlocks, rdd.partitions.length)
195-
196-
/** Returns the dimensions of the matrix. */
197-
private def getDim: (Long, Long) = {
198-
val (rows, cols) = rdd.map { case ((blockRowIndex, blockColIndex), mat) =>
199-
(blockRowIndex * rowsPerBlock + mat.numRows, blockColIndex * colsPerBlock + mat.numCols)
200-
}.reduce((x0, x1) => (math.max(x0._1, x1._1), math.max(x0._2, x1._2)))
201-
202-
(math.max(rows, nRows), math.max(cols, nCols))
157+
GridPartitioner(numRowBlocks, numColBlocks, suggestedNumPartitions = blocks.partitions.size)
158+
159+
/** Estimates the dimensions of the matrix. */
160+
private def estimateDim(): Unit = {
161+
val (rows, cols) = blocks.map { case ((blockRowIndex, blockColIndex), mat) =>
162+
(blockRowIndex.toLong * rowsPerBlock + mat.numRows,
163+
blockColIndex.toLong * colsPerBlock + mat.numCols)
164+
}.reduce { (x0, x1) =>
165+
(math.max(x0._1, x1._1), math.max(x0._2, x1._2))
166+
}
167+
if (nRows <= 0L) nRows = rows
168+
assert(rows <= nRows, s"The number of rows $rows is more than claimed $nRows.")
169+
if (nCols <= 0L) nCols = cols
170+
assert(cols <= nCols, s"The number of columns $cols is more than claimed $nCols.")
203171
}
204172

205-
/** Cache the underlying RDD. */
206-
def cache(): BlockMatrix = {
207-
rdd.cache()
173+
/** Caches the underlying RDD. */
174+
def cache(): this.type = {
175+
blocks.cache()
208176
this
209177
}
210178

211-
/** Set the storage level for the underlying RDD. */
212-
def persist(storageLevel: StorageLevel): BlockMatrix = {
213-
rdd.persist(storageLevel)
179+
/** Persists the underlying RDD with the specified storage level. */
180+
def persist(storageLevel: StorageLevel): this.type = {
181+
blocks.persist(storageLevel)
214182
this
215183
}
216184

@@ -222,22 +190,22 @@ class BlockMatrix(
222190
s"Int.MaxValue. Currently numCols: ${numCols()}")
223191
require(numRows() * numCols() < Int.MaxValue, "The length of the values array must be " +
224192
s"less than Int.MaxValue. Currently numRows * numCols: ${numRows() * numCols()}")
225-
val nRows = numRows().toInt
226-
val nCols = numCols().toInt
227-
val mem = nRows * nCols / 125000
193+
val m = numRows().toInt
194+
val n = numCols().toInt
195+
val mem = m * n / 125000
228196
if (mem > 500) logWarning(s"Storing this matrix will require $mem MB of memory!")
229197

230-
val parts = rdd.collect()
231-
val values = new Array[Double](nRows * nCols)
232-
parts.foreach { case ((blockRowIndex, blockColIndex), block) =>
198+
val localBlocks = blocks.collect()
199+
val values = new Array[Double](m * n)
200+
localBlocks.foreach { case ((blockRowIndex, blockColIndex), submat) =>
233201
val rowOffset = blockRowIndex * rowsPerBlock
234202
val colOffset = blockColIndex * colsPerBlock
235-
block.foreachActive { (i, j, v) =>
236-
val indexOffset = (j + colOffset) * nRows + rowOffset + i
203+
submat.foreachActive { (i, j, v) =>
204+
val indexOffset = (j + colOffset) * m + rowOffset + i
237205
values(indexOffset) = v
238206
}
239207
}
240-
new DenseMatrix(nRows, nCols, values)
208+
new DenseMatrix(m, n, values)
241209
}
242210

243211
/** Collects data and assembles a local dense breeze matrix (for test only). */

0 commit comments

Comments
 (0)