Skip to content

Commit ed08214

Browse files
committed
check preconditions and unit tests
1 parent a624c12 commit ed08214

File tree

2 files changed

+79
-4
lines changed

2 files changed

+79
-4
lines changed

mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,16 @@ import java.lang.{Integer => JavaInteger}
2121

2222
import org.jblas.DoubleMatrix
2323

24-
import org.apache.spark.SparkContext._
24+
import org.apache.spark.Logging
2525
import org.apache.spark.api.java.{JavaPairRDD, JavaRDD}
2626
import org.apache.spark.rdd.RDD
27+
import org.apache.spark.storage.StorageLevel
2728

2829
/**
2930
* Model representing the result of matrix factorization.
3031
*
31-
* NB: If you create the model directly using constructor, please be aware that fast prediction
32-
* requires cached user/product features and the availability of their partitioning information.
32+
* Note: If you create the model directly using the constructor, please be aware that fast prediction
33+
* requires cached user/product features and their associated partitioners.
3334
*
3435
* @param rank Rank for the features in this model.
3536
* @param userFeatures RDD of tuples where each tuple represents the userId and
@@ -40,7 +41,25 @@ import org.apache.spark.rdd.RDD
4041
class MatrixFactorizationModel(
4142
val rank: Int,
4243
val userFeatures: RDD[(Int, Array[Double])],
43-
val productFeatures: RDD[(Int, Array[Double])]) extends Serializable {
44+
val productFeatures: RDD[(Int, Array[Double])]) extends Serializable with Logging {
45+
46+
// Fail fast at construction time: the rank must be positive, and both factor RDDs are
// checked against it (see validateFeatures).
require(rank > 0, s"Rank must be positive but got $rank.")
validateFeatures("User", userFeatures)
validateFeatures("Product", productFeatures)
49+
50+
/**
 * Validates factors and warns users if there are performance concerns.
 *
 * Only the first record of `features` is inspected for the dimension check; the remaining
 * records are not scanned.
 *
 * @param name label used as the prefix of error and warning messages
 * @param features RDD of (id, feature vector) pairs to validate against `rank`
 * @throws IllegalArgumentException if `features` is empty, or if the first feature vector's
 *                                  length differs from `rank`
 */
private def validateFeatures(name: String, features: RDD[(Int, Array[Double])]): Unit = {
  // take(1) rather than first(): on an empty RDD, first() throws a generic
  // UnsupportedOperationException("empty collection"); here we report it as a
  // precondition failure that names the offending factor.
  val firstFactor = features.take(1).headOption
  require(firstFactor.nonEmpty, s"$name features are empty.")
  // Array.length avoids the implicit ArrayOps conversion that .size goes through.
  require(firstFactor.exists(_._2.length == rank),
    s"$name feature dimension does not match the rank $rank.")
  if (features.partitioner.isEmpty) {
    logWarning(s"$name factor does not have a partitioner. "
      + "Prediction on individual records could be slow.")
  }
  if (features.getStorageLevel == StorageLevel.NONE) {
    logWarning(s"$name factor is not cached. Prediction could be slow.")
  }
}
62+
4463
/** Predict the rating of one user for one product. */
4564
def predict(user: Int, product: Int): Double = {
4665
val userVector = new DoubleMatrix(userFeatures.lookup(user).head)
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.mllib.recommendation
19+
20+
import org.apache.spark.mllib.util.MLlibTestSparkContext
21+
import org.apache.spark.rdd.RDD
22+
import org.scalatest.FunSuite
23+
import org.apache.spark.mllib.util.TestingUtils._
24+
25+
26+
/**
 * Checks that the [[MatrixFactorizationModel]] constructor accepts consistent inputs and
 * rejects a rank or feature RDDs whose dimensions disagree.
 */
class MatrixFactorizationModelSuite extends FunSuite with MLlibTestSparkContext {

  val rank = 2
  var userFeatures: RDD[(Int, Array[Double])] = _
  var prodFeatures: RDD[(Int, Array[Double])] = _

  override def beforeAll(): Unit = {
    super.beforeAll()
    // The fixtures are built from `sc`, so they are created after super.beforeAll().
    val users = Seq((0, Array(1.0, 2.0)), (1, Array(3.0, 4.0)))
    val products = Seq((2, Array(5.0, 6.0)))
    userFeatures = sc.parallelize(users)
    prodFeatures = sc.parallelize(products)
  }

  test("constructor") {
    // Consistent inputs: the expected 17.0 equals 1.0 * 5.0 + 2.0 * 6.0, i.e. the factors
    // of user 0 combined with those of product 2.
    val validModel = new MatrixFactorizationModel(rank, userFeatures, prodFeatures)
    val predicted = validModel.predict(0, 2)
    assert(predicted ~== 17.0 relTol 1e-14)

    // A rank of 1 disagrees with the two-dimensional fixtures.
    intercept[IllegalArgumentException] {
      new MatrixFactorizationModel(1, userFeatures, prodFeatures)
    }

    // One-dimensional user features disagree with rank 2.
    val narrowUserFeatures = sc.parallelize(Seq((0, Array(1.0)), (1, Array(3.0))))
    intercept[IllegalArgumentException] {
      new MatrixFactorizationModel(rank, narrowUserFeatures, prodFeatures)
    }

    // One-dimensional product features disagree with rank 2.
    val narrowProdFeatures = sc.parallelize(Seq((2, Array(5.0))))
    intercept[IllegalArgumentException] {
      new MatrixFactorizationModel(rank, userFeatures, narrowProdFeatures)
    }
  }
}

0 commit comments

Comments
 (0)