Skip to content

Commit ebd2604

Browse files
author
Your Name
committed
clean-ups, comments.
1 parent 832b066 commit ebd2604

File tree

2 files changed

+19
-15
lines changed

2 files changed

+19
-15
lines changed

mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,8 @@ import org.apache.spark.ml.util._
4040
import org.apache.spark.mllib.linalg.CholeskyDecomposition
4141
import org.apache.spark.mllib.optimization.NNLS
4242
import org.apache.spark.rdd.RDD
43-
import org.apache.spark.sql.{DataFrame, Dataset}
43+
import org.apache.spark.sql.{DataFrame, Dataset, Row}
44+
import org.apache.spark.sql.catalyst.encoders.RowEncoder
4445
import org.apache.spark.sql.functions._
4546
import org.apache.spark.sql.types._
4647
import org.apache.spark.storage.StorageLevel
@@ -250,7 +251,8 @@ class ALSModel private[ml] (
250251

251252
private val predict = udf { (userFeatures: Seq[Float], itemFeatures: Seq[Float]) =>
252253
if (userFeatures != null && itemFeatures != null) {
253-
// TODO: try dot-producting on Seqs or another non-converted type for potential optimization
254+
// TODO(SPARK-19759): try dot-producting on Seqs or another non-converted type for
255+
// potential optimization.
254256
blas.sdot(rank, userFeatures.toArray, 1, itemFeatures.toArray, 1)
255257
} else {
256258
Float.NaN
@@ -288,48 +290,50 @@ class ALSModel private[ml] (
288290
override def write: MLWriter = new ALSModel.ALSModelWriter(this)
289291

290292
/**
291-
* Returns top `num` items recommended for each user, for all users.
292-
* @param num number of recommendations for each user
293+
* Returns top `numItems` items recommended for each user, for all users.
294+
* @param numItems max number of recommendations for each user
293295
* @return a DataFrame of (userCol: Int, recommendations), where recommendations are
294296
* stored as an array of (itemId: Int, rating: Double) tuples.
295297
*/
296298
@Since("2.2.0")
297-
def recommendForAllUsers(num: Int): DataFrame = {
298-
recommendForAll(userFactors, itemFactors, $(userCol), num)
299+
def recommendForAllUsers(numItems: Int): DataFrame = {
300+
recommendForAll(userFactors, itemFactors, $(userCol), $(itemCol), numItems)
299301
}
300302

301303
/**
302-
* Returns top `num` users recommended for each item, for all items.
303-
* @param num number of recommendations for each item
304+
* Returns top `numUsers` users recommended for each item, for all items.
305+
* @param numUsers max number of recommendations for each item
304306
* @return a DataFrame of (itemCol: Int, recommendations), where recommendations are
305307
* stored as an array of (userId: Int, rating: Double) tuples.
306308
*/
307309
@Since("2.2.0")
308-
def recommendForAllItems(num: Int): DataFrame = {
309-
recommendForAll(itemFactors, userFactors, $(itemCol), num)
310+
def recommendForAllItems(numUsers: Int): DataFrame = {
311+
recommendForAll(itemFactors, userFactors, $(itemCol), $(userCol), numUsers)
310312
}
311313

312314
/**
313315
* Makes recommendations for all users (or items).
314316
* @param srcFactors src factors for which to generate recommendations
315317
* @param dstFactors dst factors used to make recommendations
316318
* @param srcOutputColumn name of the column for the source in the output DataFrame
317-
* @param num number of recommendations for each record
319+
* @param dstOutputColumn name of the column for the destination in the output DataFrame
320+
* @param num max number of recommendations for each record
318321
* @return a DataFrame of (srcOutputColumn: Int, recommendations), where recommendations are
319322
* stored as an array of (dstId: Int, rating: Double) tuples.
320323
*/
321324
private def recommendForAll(
322325
srcFactors: DataFrame,
323326
dstFactors: DataFrame,
324327
srcOutputColumn: String,
328+
dstOutputColumn: String,
325329
num: Int): DataFrame = {
326330
import srcFactors.sparkSession.implicits._
327331

328332
val ratings = srcFactors.crossJoin(dstFactors)
329333
.select(
330-
srcFactors("id").as("srcId"),
331-
dstFactors("id").as("dstId"),
332-
predict(srcFactors("features"), dstFactors("features")).as($(predictionCol)))
334+
srcFactors("id"),
335+
dstFactors("id"),
336+
predict(srcFactors("features"), dstFactors("features")))
333337
// We'll force the IDs to be Int. Unfortunately this converts IDs to Int in the output.
334338
val topKAggregator = new TopByKeyAggregator[Int, Int, Float](num, Ordering.by(_._2))
335339
ratings.as[(Int, Int, Float)].groupByKey(_._1).agg(topKAggregator.toColumn)

mllib/src/main/scala/org/apache/spark/ml/recommendation/TopByKeyAggregator.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,5 +56,5 @@ private[recommendation] class TopByKeyAggregator[K1: TypeTag, K2: TypeTag, V: Ty
5656
Encoders.kryo[BoundedPriorityQueue[(K2, V)]]
5757
}
5858

59-
override def outputEncoder: Encoder[Array[(K2, V)]] = ExpressionEncoder[Array[(K2, V)]]
59+
override def outputEncoder: Encoder[Array[(K2, V)]] = ExpressionEncoder[Array[(K2, V)]]()
6060
}

0 commit comments

Comments
 (0)