Skip to content

Commit 01f3be7

Browse files
author
Nick Pentreath
committed
[SPARK-20300][ML][PYSPARK] Python API for ALSModel.recommendForAllUsers,Items
Add Python API for `ALSModel` methods `recommendForAllUsers`, `recommendForAllItems` ## How was this patch tested? New doc tests. Author: Nick Pentreath <[email protected]> Closes #17622 from MLnick/SPARK-20300-pyspark-recall. (cherry picked from commit e300a5a) Signed-off-by: Nick Pentreath <[email protected]>
1 parent ef5e2a0 commit 01f3be7

File tree

1 file changed

+30
-0
lines changed

1 file changed

+30
-0
lines changed

python/pyspark/ml/recommendation.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,14 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha
8282
Row(user=1, item=0, prediction=2.6258413791656494)
8383
>>> predictions[2]
8484
Row(user=2, item=0, prediction=-1.5018409490585327)
85+
>>> user_recs = model.recommendForAllUsers(3)
86+
>>> user_recs.where(user_recs.user == 0)\
87+
.select("recommendations.item", "recommendations.rating").collect()
88+
[Row(item=[0, 1, 2], rating=[3.910..., 1.992..., -0.138...])]
89+
>>> item_recs = model.recommendForAllItems(3)
90+
>>> item_recs.where(item_recs.item == 2)\
91+
.select("recommendations.user", "recommendations.rating").collect()
92+
[Row(user=[2, 1, 0], rating=[4.901..., 3.981..., -0.138...])]
8593
>>> als_path = temp_path + "/als"
8694
>>> als.save(als_path)
8795
>>> als2 = ALS.load(als_path)
@@ -384,6 +392,28 @@ def itemFactors(self):
384392
"""
385393
return self._call_java("itemFactors")
386394

395+
@since("2.2.0")
396+
def recommendForAllUsers(self, numItems):
397+
"""
398+
Returns top `numItems` items recommended for each user, for all users.
399+
400+
:param numItems: max number of recommendations for each user
401+
:return: a DataFrame of (userCol, recommendations), where recommendations are
402+
stored as an array of (itemCol, rating) Rows.
403+
"""
404+
return self._call_java("recommendForAllUsers", numItems)
405+
406+
@since("2.2.0")
407+
def recommendForAllItems(self, numUsers):
408+
"""
409+
Returns top `numUsers` users recommended for each item, for all items.
410+
411+
:param numUsers: max number of recommendations for each item
412+
:return: a DataFrame of (itemCol, recommendations), where recommendations are
413+
stored as an array of (userCol, rating) Rows.
414+
"""
415+
return self._call_java("recommendForAllItems", numUsers)
416+
387417

388418
if __name__ == "__main__":
389419
import doctest

0 commit comments

Comments
 (0)