Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,23 +21,45 @@ import java.lang.{Integer => JavaInteger}

import org.jblas.DoubleMatrix

import org.apache.spark.SparkContext._
import org.apache.spark.Logging
import org.apache.spark.api.java.{JavaPairRDD, JavaRDD}
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel

/**
* Model representing the result of matrix factorization.
*
* Note: If you create the model directly using constructor, please be aware that fast prediction
* requires cached user/product features and their associated partitioners.
*
* @param rank Rank for the features in this model.
* @param userFeatures RDD of tuples where each tuple represents the userId and
* the features computed for this user.
* @param productFeatures RDD of tuples where each tuple represents the productId
* and the features computed for this product.
*/
class MatrixFactorizationModel private[mllib] (
class MatrixFactorizationModel(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With this now public, it might be good to add either (a) one check upon initialization doing a take(1) and comparing with rank, or (b) runtime checks in the various methods in MatrixFactorizationModel. IMO, it's OK if not since those would both add extra overhead, but perhaps there should be a warning for the constructor noting that the arguments are not checked.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

val rank: Int,
val userFeatures: RDD[(Int, Array[Double])],
val productFeatures: RDD[(Int, Array[Double])]) extends Serializable {
val productFeatures: RDD[(Int, Array[Double])]) extends Serializable with Logging {

require(rank > 0)
validateFeatures("User", userFeatures)
validateFeatures("Product", productFeatures)

/** Validates factors and warns users if there are performance concerns. */
private def validateFeatures(name: String, features: RDD[(Int, Array[Double])]): Unit = {
require(features.first()._2.size == rank,
s"$name feature dimension does not match the rank $rank.")
if (features.partitioner.isEmpty) {
logWarning(s"$name factor does not have a partitioner. "
+ "Prediction on individual records could be slow.")
}
if (features.getStorageLevel == StorageLevel.NONE) {
logWarning(s"$name factor is not cached. Prediction could be slow.")
}
}

/** Predict the rating of one user for one product. */
def predict(user: Int, product: Int): Double = {
val userVector = new DoubleMatrix(userFeatures.lookup(user).head)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.recommendation

import org.scalatest.FunSuite
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Organize imports


import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.rdd.RDD

class MatrixFactorizationModelSuite extends FunSuite with MLlibTestSparkContext {

val rank = 2
var userFeatures: RDD[(Int, Array[Double])] = _
var prodFeatures: RDD[(Int, Array[Double])] = _

override def beforeAll(): Unit = {
super.beforeAll()
userFeatures = sc.parallelize(Seq((0, Array(1.0, 2.0)), (1, Array(3.0, 4.0))))
prodFeatures = sc.parallelize(Seq((2, Array(5.0, 6.0))))
}

test("constructor") {
val model = new MatrixFactorizationModel(rank, userFeatures, prodFeatures)
assert(model.predict(0, 2) ~== 17.0 relTol 1e-14)

intercept[IllegalArgumentException] {
new MatrixFactorizationModel(1, userFeatures, prodFeatures)
}

val userFeatures1 = sc.parallelize(Seq((0, Array(1.0)), (1, Array(3.0))))
intercept[IllegalArgumentException] {
new MatrixFactorizationModel(rank, userFeatures1, prodFeatures)
}

val prodFeatures1 = sc.parallelize(Seq((2, Array(5.0))))
intercept[IllegalArgumentException] {
new MatrixFactorizationModel(rank, userFeatures, prodFeatures1)
}
}
}