@@ -42,7 +42,8 @@ private[mllib] class GridPartitioner(
4242    override  val  numPartitions :  Int ) extends  Partitioner  {
4343
4444  /**  
45-    * Returns the index of the partition the SubMatrix belongs to. 
45+    * Returns the index of the partition the SubMatrix belongs to. Tries to achieve block wise 
46+    * partitioning. 
4647   * 
4748   * @param  key  The key for the SubMatrix. Can be its position in the grid (its column major index) 
4849   *            or a tuple of three integers that are the final row index after the multiplication, 
@@ -51,22 +52,25 @@ private[mllib] class GridPartitioner(
5152   * @return  The index of the partition, which the SubMatrix belongs to. 
5253   */  
5354  override  def  getPartition (key : Any ):  Int  =  {
55+     val  sqrtPartition  =  math.round(math.sqrt(numPartitions)).toInt
56+     //  numPartitions may not be the square of a number, it can even be a prime number
57+     
5458    key match  {
55-       case  (rowIndex : Int , colIndex : Int ) => 
56-         Utils .nonNegativeMod(rowIndex  +  colIndex  *  numRowBlocks, numPartitions)
57-       case  (rowIndex : Int , innerIndex : Int , colIndex : Int ) => 
58-         Utils .nonNegativeMod(rowIndex  +  colIndex  *  numRowBlocks, numPartitions)
59+       case  (blockRowIndex : Int , blockColIndex : Int ) => 
60+         Utils .nonNegativeMod(blockRowIndex  +  blockColIndex  *  numRowBlocks, numPartitions)
61+       case  (blockRowIndex : Int , innerIndex : Int , blockColIndex : Int ) => 
62+         Utils .nonNegativeMod(blockRowIndex  +  blockColIndex  *  numRowBlocks, numPartitions)
5963      case  _ => 
60-         throw  new  IllegalArgumentException (" Unrecognized key" 
64+         throw  new  IllegalArgumentException (s " Unrecognized key. key:   $ key" )
6165    }
6266  }
6367
6468  /**  Checks whether the partitioners have the same characteristics */  
6569  override  def  equals (obj : Any ):  Boolean  =  {
6670    obj match  {
6771      case  r : GridPartitioner  => 
68-         (this .numPartitions  ==  r.numPartitions ) &&  (this .rowsPerBlock  ==  r.rowsPerBlock)  && 
69-           (this .colsPerBlock ==  r.colsPerBlock)
72+         (this .numRowBlocks  ==  r.numRowBlocks ) &&  (this .numColBlocks  ==  r.numColBlocks) 
73+           (this .rowsPerBlock  ==  r.rowsPerBlock)  &&  ( this . colsPerBlock ==  r.colsPerBlock)
7074      case  _ => 
7175        false 
7276    }
@@ -85,7 +89,7 @@ class BlockMatrix(
8589    val  numColBlocks :  Int ,
8690    val  rdd :  RDD [((Int , Int ), Matrix )]) extends  DistributedMatrix  with  Logging  {
8791
88-   type  SubMatrix  =  ((Int , Int ), Matrix ) //  ((blockRowIndex, blockColIndex), matrix)
92+   private   type  SubMatrix  =  ((Int , Int ), Matrix ) //  ((blockRowIndex, blockColIndex), matrix)
8993
9094  /**  
9195   * Alternate constructor for BlockMatrix without the input of a partitioner. Will use a Grid 
0 commit comments