@@ -40,7 +40,8 @@ import org.apache.spark.ml.util._
4040import org .apache .spark .mllib .linalg .CholeskyDecomposition
4141import org .apache .spark .mllib .optimization .NNLS
4242import org .apache .spark .rdd .RDD
43- import org .apache .spark .sql .{DataFrame , Dataset }
43+ import org .apache .spark .sql .{DataFrame , Dataset , Row }
44+ import org .apache .spark .sql .catalyst .encoders .RowEncoder
4445import org .apache .spark .sql .functions ._
4546import org .apache .spark .sql .types ._
4647import org .apache .spark .storage .StorageLevel
@@ -250,7 +251,8 @@ class ALSModel private[ml] (
250251
251252 private val predict = udf { (userFeatures : Seq [Float ], itemFeatures : Seq [Float ]) =>
252253 if (userFeatures != null && itemFeatures != null ) {
253- // TODO: try dot-producting on Seqs or another non-converted type for potential optimization
254+ // TODO(SPARK-19759): try dot-producting on Seqs or another non-converted type for
255+ // potential optimization.
254256 blas.sdot(rank, userFeatures.toArray, 1 , itemFeatures.toArray, 1 )
255257 } else {
256258 Float .NaN
@@ -288,48 +290,50 @@ class ALSModel private[ml] (
288290 override def write : MLWriter = new ALSModel .ALSModelWriter (this )
289291
290292 /**
291- * Returns top `num ` items recommended for each user, for all users.
292- * @param num number of recommendations for each user
293+ * Returns top `numItems ` items recommended for each user, for all users.
294+ * @param numItems max number of recommendations for each user
293295 * @return a DataFrame of (userCol: Int, recommendations), where recommendations are
294296 * stored as an array of (itemId: Int, rating: Double) tuples.
295297 */
296298 @ Since (" 2.2.0" )
297- def recommendForAllUsers (num : Int ): DataFrame = {
298- recommendForAll(userFactors, itemFactors, $(userCol), num )
299+ def recommendForAllUsers (numItems : Int ): DataFrame = {
300+ recommendForAll(userFactors, itemFactors, $(userCol), $(itemCol), numItems )
299301 }
300302
301303 /**
302- * Returns top `num ` users recommended for each item, for all items.
303- * @param num number of recommendations for each item
304+ * Returns top `numUsers ` users recommended for each item, for all items.
305+ * @param numUsers max number of recommendations for each item
304306 * @return a DataFrame of (itemCol: Int, recommendations), where recommendations are
305307 * stored as an array of (userId: Int, rating: Double) tuples.
306308 */
307309 @ Since (" 2.2.0" )
308- def recommendForAllItems (num : Int ): DataFrame = {
309- recommendForAll(itemFactors, userFactors, $(itemCol), num )
310+ def recommendForAllItems (numUsers : Int ): DataFrame = {
311+ recommendForAll(itemFactors, userFactors, $(itemCol), $(userCol), numUsers )
310312 }
311313
312314 /**
313315 * Makes recommendations for all users (or items).
314316 * @param srcFactors src factors for which to generate recommendations
315317 * @param dstFactors dst factors used to make recommendations
316318 * @param srcOutputColumn name of the column for the source in the output DataFrame
317- * @param num number of recommendations for each record
319+ * @param dstOutputColumn name of the column for the destination in the output DataFrame
320+ * @param num max number of recommendations for each record
318321 * @return a DataFrame of (srcOutputColumn: Int, recommendations), where recommendations are
319322 * stored as an array of (dstId: Int, rating: Double) tuples.
320323 */
321324 private def recommendForAll (
322325 srcFactors : DataFrame ,
323326 dstFactors : DataFrame ,
324327 srcOutputColumn : String ,
328+ dstOutputColumn : String ,
325329 num : Int ): DataFrame = {
326330 import srcFactors .sparkSession .implicits ._
327331
328332 val ratings = srcFactors.crossJoin(dstFactors)
329333 .select(
330- srcFactors(" id" ).as( " srcId " ) ,
331- dstFactors(" id" ).as( " dstId " ) ,
332- predict(srcFactors(" features" ), dstFactors(" features" )).as($(predictionCol)) )
334+ srcFactors(" id" ),
335+ dstFactors(" id" ),
336+ predict(srcFactors(" features" ), dstFactors(" features" )))
333337 // We'll force the IDs to be Int. Unfortunately this converts IDs to Int in the output.
334338 val topKAggregator = new TopByKeyAggregator [Int , Int , Float ](num, Ordering .by(_._2))
335339 ratings.as[(Int , Int , Float )].groupByKey(_._1).agg(topKAggregator.toColumn)
0 commit comments