@@ -32,10 +32,10 @@ import org.apache.spark.util.Utils
3232 * @param blockIdCol The column index of this block
3333 * @param mat The underlying local matrix
3434 */
35- case class BlockPartition (blockIdRow : Int , blockIdCol : Int , mat : DenseMatrix ) extends Serializable
35+ case class SubMatrix (blockIdRow : Int , blockIdCol : Int , mat : DenseMatrix ) extends Serializable
3636
3737/**
38- * Information about the BlockMatrix maintained on the driver
38+ * Information about the submatrices of the BlockMatrix, maintained on the driver
3939 *
4040 * @param partitionId The id of the partition the block is found in
4141 * @param blockIdRow The row index of this block
@@ -45,7 +45,7 @@ case class BlockPartition(blockIdRow: Int, blockIdCol: Int, mat: DenseMatrix) ex
4545 * @param startCol The starting column index with respect to the distributed BlockMatrix
4646 * @param numCols The number of columns in this block
4747 */
48- case class BlockPartitionInfo (
48+ case class SubMatrixInfo (
4949 partitionId : Int ,
5050 blockIdRow : Int ,
5151 blockIdCol : Int ,
@@ -67,6 +67,13 @@ abstract class BlockMatrixPartitioner(
6767 val colPerBlock : Int ) extends Partitioner {
6868 val name : String
6969
70+ /**
71+ * Returns the index of the partition the SubMatrix belongs to.
72+ *
73+ * @param key The key for the SubMatrix. Can be its row index, column index or position in the
74+ * grid. Cast to Int with asInstanceOf, so a key of any other type fails at runtime.
75+ * @return The index of the partition that the SubMatrix belongs to.
76+ */
7077 override def getPartition (key : Any ): Int = {
7178 Utils .nonNegativeMod(key.asInstanceOf [Int ], numPartitions)
7279 }
@@ -91,6 +98,7 @@ class GridPartitioner(
9198
9299 override val numPartitions = numRowBlocks * numColBlocks
93100
101+ /** Checks whether the partitioners have the same characteristics */
94102 override def equals (obj : Any ): Boolean = {
95103 obj match {
96104 case r : GridPartitioner =>
@@ -118,6 +126,7 @@ class RowBasedPartitioner(
118126
119127 override val name = " row"
120128
129+ /** Checks whether the partitioners have the same characteristics */
121130 override def equals (obj : Any ): Boolean = {
122131 obj match {
123132 case r : RowBasedPartitioner =>
@@ -145,6 +154,7 @@ class ColumnBasedPartitioner(
145154
146155 override val name = " column"
147156
157+ /** Checks whether the partitioners have the same characteristics */
148158 override def equals (obj : Any ): Boolean = {
149159 obj match {
150160 case p : ColumnBasedPartitioner =>
@@ -163,19 +173,19 @@ class ColumnBasedPartitioner(
163173 *
164174 * @param numRowBlocks Number of blocks that form the rows of this matrix
165175 * @param numColBlocks Number of blocks that form the columns of this matrix
166- * @param rdd The RDD of BlockPartitions (local matrices) that form this matrix
167- * @param partitioner A partitioner that specifies how BlockPartitions are stored in the cluster
176+ * @param rdd The RDD of SubMatrices (local matrices) that form this matrix
177+ * @param partitioner A partitioner that specifies how SubMatrices are stored in the cluster
168178 */
169179class BlockMatrix (
170180 val numRowBlocks : Int ,
171181 val numColBlocks : Int ,
172- val rdd : RDD [BlockPartition ],
182+ val rdd : RDD [SubMatrix ],
173183 val partitioner : BlockMatrixPartitioner ) extends DistributedMatrix with Logging {
174184
175185 // A key-value pair RDD is required to partition properly
176- private var matrixRDD : RDD [(Int , BlockPartition )] = keyBy()
186+ private var matrixRDD : RDD [(Int , SubMatrix )] = keyBy()
177187
178- @ transient var blockInfo_ : Map [(Int , Int ), BlockPartitionInfo ] = null
188+ @ transient var blockInfo_ : Map [(Int , Int ), SubMatrixInfo ] = null
179189
180190 private lazy val dims : (Long , Long ) = getDim
181191
@@ -184,40 +194,36 @@ class BlockMatrix(
184194
// Validate that the block counts agree with partitioner.numPartitions for the partitioner selected
// by partitioner.name; any other partitioner name is rejected with an IllegalArgumentException.
185195 if (partitioner.name.equals(" column" )) {
186196 require(numColBlocks == partitioner.numPartitions, " The number of column blocks should match" +
187- " the number of partitions of the column partitioner." )
197+ s " the number of partitions of the column partitioner. numColBlocks: $numColBlocks, " +
198+ s " partitioner.numPartitions: ${partitioner.numPartitions}" )
188199 } else if (partitioner.name.equals(" row" )) {
189200 require(numRowBlocks == partitioner.numPartitions, " The number of row blocks should match" +
190- " the number of partitions of the row partitioner." )
201+ s " the number of partitions of the row partitioner. numRowBlocks: $numRowBlocks, " +
202+ s " partitioner.numPartitions: ${partitioner.numPartitions}" )
191203 } else if (partitioner.name.equals(" grid" )) {
192204 require(numRowBlocks * numColBlocks == partitioner.numPartitions, " The number of blocks " +
193- " should match the number of partitions of the grid partitioner." )
205+ s " should match the number of partitions of the grid partitioner. numRowBlocks * " +
206+ s " numColBlocks: ${numRowBlocks * numColBlocks}, " +
207+ s " partitioner.numPartitions: ${partitioner.numPartitions}" )
194208 } else {
195209 throw new IllegalArgumentException (" Unrecognized partitioner." )
196210 }
197211
198- /* Returns the dimensions of the matrix. */
212+ /** Returns (rows, cols): block sizes from getBlockInfo summed per block row / block column id. */
199213 def getDim : (Long , Long ) = {
200214 val bi = getBlockInfo
201215 val xDim = bi.map { x =>
202216 (x._1._1, x._2.numRows.toLong)
203- }.groupBy(x => x._1).values.map { x =>
204- x.head._2.toLong
205- }.reduceLeft {
206- _ + _
207- }
217+ }.groupBy(x => x._1).values.map(_.head._2.toLong).reduceLeft(_ + _)
208218
209219 val yDim = bi.map { x =>
210220 (x._1._2, x._2.numCols.toLong)
211- }.groupBy(x => x._1).values.map { x =>
212- x.head._2.toLong
213- }.reduceLeft {
214- _ + _
215- }
221+ }.groupBy(x => x._1).values.map(_.head._2.toLong).reduceLeft(_ + _)
216222
217223 (xDim, yDim)
218224 }
219225
220- /* Calculates the information for each block and collects it on the driver */
226+ /** Calculates the information for each block and collects it on the driver */
221227 private def calculateBlockInfo (): Unit = {
222228 // collect may cause akka frameSize errors
223229 val blockStartRowColsParts = matrixRDD.mapPartitionsWithIndex { case (partId, iter) =>
@@ -243,38 +249,38 @@ class BlockMatrix(
243249 }.toMap
244250
245251 blockInfo_ = blockStartRowCols.map{ case ((rowId, colId), (partId, numRow, numCol)) =>
246- ((rowId, colId), new BlockPartitionInfo (partId, rowId, colId, cumulativeRowSum(rowId),
252+ ((rowId, colId), new SubMatrixInfo (partId, rowId, colId, cumulativeRowSum(rowId),
247253 numRow, cumulativeColSum(colId), numCol))
248254 }.toMap
249255 }
250256
251- /* Returns a map of the information of the blocks that form the distributed matrix. */
252- def getBlockInfo : Map [(Int , Int ), BlockPartitionInfo ] = {
257+ /** Returns the block information map, calling calculateBlockInfo() first if blockInfo_ is null. */
258+ def getBlockInfo : Map [(Int , Int ), SubMatrixInfo ] = {
253259 if (blockInfo_ == null ) {
254260 calculateBlockInfo()
255261 }
256262 blockInfo_
257263 }
258264
259- /* Returns the Frobenius Norm of the matrix */
265+ /** Returns the Frobenius norm: the square root of the sum of squared entries over all blocks. */
260266 def normFro (): Double = {
261267 math.sqrt(rdd.map(lm => lm.mat.values.map(x => math.pow(x, 2 )).sum).reduce(_ + _))
262268 }
263269
264- /* Cache the underlying RDD. */
270+ /** Caches the keyed RDD (matrixRDD) backing this matrix and returns this matrix. */
265271 def cache (): DistributedMatrix = {
266272 matrixRDD.cache()
267273 this
268274 }
269275
270- /* Set the storage level for the underlying RDD. */
276+ /** Sets the storage level of the keyed RDD (matrixRDD) and returns this matrix. */
271277 def persist (storageLevel : StorageLevel ): DistributedMatrix = {
272278 matrixRDD.persist(storageLevel)
273279 this
274280 }
275281
276- /* Add a key to the underlying rdd for partitioning and joins. */
277- private def keyBy (part : BlockMatrixPartitioner = partitioner): RDD [(Int , BlockPartition )] = {
282+ /** Add a key to the underlying rdd for partitioning and joins. */
283+ private def keyBy (part : BlockMatrixPartitioner = partitioner): RDD [(Int , SubMatrix )] = {
278284 rdd.map { block =>
279285 part match {
280286 case r : RowBasedPartitioner => (block.blockIdRow, block)
@@ -296,7 +302,7 @@ class BlockMatrix(
296302 this
297303 }
298304
299- /* Collect the distributed matrix on the driver. */
305+ /** Collect the distributed matrix on the driver. */
300306 def collect (): DenseMatrix = {
301307 val parts = rdd.map(x => ((x.blockIdRow, x.blockIdCol), x.mat)).
302308 collect().sortBy(x => (x._1._2, x._1._1))
@@ -324,6 +330,7 @@ class BlockMatrix(
324330 new DenseMatrix (nRows, nCols, values)
325331 }
326332
333+ /** Collects data and assembles a local dense breeze matrix (for test only). */
327334 private [mllib] def toBreeze (): BDM [Double ] = {
328335 val localMat = collect()
329336 new BDM [Double ](localMat.numRows, localMat.numCols, localMat.values)
0 commit comments