Skip to content

Commit 00cabd0

Browse files
author
Debasish Das
committed
testcases for similarUsers, similarProducts
1 parent 2541cd7 commit 00cabd0

File tree

1 file changed

+49
-1
lines changed

1 file changed

+49
-1
lines changed

mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
package org.apache.spark.mllib.recommendation
1919

20+
import org.apache.spark.mllib.linalg.{DenseMatrix, Vectors}
2021
import org.scalatest.FunSuite
2122

2223
import org.apache.spark.mllib.util.MLlibTestSparkContext
@@ -92,4 +93,51 @@ class MatrixFactorizationModelSuite extends FunSuite with MLlibTestSparkContext
9293
assert(recommendations(2)(1).user == 0)
9394
assert(recommendations(2)(1).rating ~== 17.0 relTol 1e-14)
9495
}
95-
}
96+
97+
test("batch similar users/products") {
98+
val n = 3
99+
100+
val userFeatures = sc.parallelize(Seq(
101+
(0, Array(0.0, 3.0, 6.0, 9.0)),
102+
(1, Array(1.0, 4.0, 7.0, 0.0)),
103+
(2, Array(2.0, 5.0, 8.0, 1.0))
104+
), 2)
105+
106+
val model = new MatrixFactorizationModel(4, userFeatures, userFeatures)
107+
108+
val topk = 2
109+
110+
val similarUsers = model.similarUsers(topk)
111+
112+
val similarProducts = model.similarProducts(topk)
113+
114+
assert(similarUsers.numRows() == n)
115+
assert(similarUsers.entries.count() == n * topk)
116+
117+
assert(similarProducts.numRows() == n)
118+
assert(similarProducts.entries.count() == n * topk)
119+
120+
val similarEntriesUsers = similarUsers.entries.collect()
121+
val similarEntriesProducts = similarProducts.entries.collect()
122+
123+
val colMags = Vectors.dense(Math.sqrt(126), Math.sqrt(66), Math.sqrt(94))
124+
125+
val expected =
126+
new DenseMatrix(3, 3,
127+
Array(126.0, 54.0, 72.0, 54.0, 66.0, 78.0, 72.0, 78.0, 94.0))
128+
129+
for (i <- 0 until n; j <- 0 until n) expected(i, j) /= (colMags(i) * colMags(j))
130+
131+
similarEntriesUsers.foreach { entry =>
132+
val row = entry.i.toInt
133+
val col = entry.j.toInt
134+
assert(entry.value ~== expected(row, col) relTol 1e-6)
135+
}
136+
137+
similarEntriesProducts.foreach { entry =>
138+
val row = entry.i.toInt
139+
val col = entry.j.toInt
140+
assert(entry.value ~== expected(row, col) relTol 1e-6)
141+
}
142+
}
143+
}

0 commit comments

Comments
 (0)