1717
1818package  org .apache .spark .mllib .recommendation 
1919
20+ import  org .apache .spark .mllib .linalg .{DenseMatrix , Vectors }
2021import  org .scalatest .FunSuite 
2122
2223import  org .apache .spark .mllib .util .MLlibTestSparkContext 
@@ -92,4 +93,51 @@ class MatrixFactorizationModelSuite extends FunSuite with MLlibTestSparkContext
9293    assert(recommendations(2 )(1 ).user ==  0 )
9394    assert(recommendations(2 )(1 ).rating ~==  17.0  relTol 1e-14 )
9495  }
95- }
96+ 
97+   test(" batch similar users/products" 
98+     val  n  =  3 
99+ 
100+     val  userFeatures  =  sc.parallelize(Seq (
101+       (0 , Array (0.0 , 3.0 , 6.0 , 9.0 )),
102+       (1 , Array (1.0 , 4.0 , 7.0 , 0.0 )),
103+       (2 , Array (2.0 , 5.0 , 8.0 , 1.0 ))
104+     ), 2 )
105+ 
106+     val  model  =  new  MatrixFactorizationModel (4 , userFeatures, userFeatures)
107+ 
108+     val  topk  =  2 
109+ 
110+     val  similarUsers  =  model.similarUsers(topk)
111+ 
112+     val  similarProducts  =  model.similarProducts(topk)
113+ 
114+     assert(similarUsers.numRows() ==  n)
115+     assert(similarUsers.entries.count() ==  n *  topk)
116+ 
117+     assert(similarProducts.numRows() ==  n)
118+     assert(similarProducts.entries.count() ==  n *  topk)
119+ 
120+     val  similarEntriesUsers  =  similarUsers.entries.collect()
121+     val  similarEntriesProducts  =  similarProducts.entries.collect()
122+ 
123+     val  colMags  =  Vectors .dense(Math .sqrt(126 ), Math .sqrt(66 ), Math .sqrt(94 ))
124+ 
125+     val  expected  = 
126+       new  DenseMatrix (3 , 3 ,
127+         Array (126.0 , 54.0 , 72.0 , 54.0 , 66.0 , 78.0 , 72.0 , 78.0 , 94.0 ))
128+ 
129+     for  (i <-  0  until n; j <-  0  until n) expected(i, j) /=  (colMags(i) *  colMags(j))
130+ 
131+     similarEntriesUsers.foreach { entry => 
132+       val  row  =  entry.i.toInt
133+       val  col  =  entry.j.toInt
134+       assert(entry.value ~==  expected(row, col) relTol 1e-6 )
135+     }
136+ 
137+     similarEntriesProducts.foreach { entry => 
138+       val  row  =  entry.i.toInt
139+       val  col  =  entry.j.toInt
140+       assert(entry.value ~==  expected(row, col) relTol 1e-6 )
141+     }
142+   }
143+ }
0 commit comments