2222from pyspark .rdd import RDD
2323from pyspark .mllib .common import JavaModelWrapper , callMLlibFunc , inherit_doc
2424from pyspark .mllib .util import JavaLoader , JavaSaveable
25+ from pyspark .sql import DataFrame
2526
2627__all__ = ['MatrixFactorizationModel' , 'ALS' , 'Rating' ]
2728
@@ -78,18 +79,23 @@ class MatrixFactorizationModel(JavaModelWrapper, JavaSaveable, JavaLoader):
7879 True
7980
8081 >>> model = ALS.train(ratings, 1, nonnegative=True, seed=10)
81- >>> model.predict(2,2)
82+ >>> model.predict(2, 2)
83+ 3.8...
84+
85+ >>> df = sqlContext.createDataFrame([Rating(1, 1, 1.0), Rating(1, 2, 2.0), Rating(2, 1, 2.0)])
86+ >>> model = ALS.train(df, 1, nonnegative=True, seed=10)
87+ >>> model.predict(2, 2)
8288 3.8...
8389
8490 >>> model = ALS.trainImplicit(ratings, 1, nonnegative=True, seed=10)
85- >>> model.predict(2,2)
91+ >>> model.predict(2, 2)
8692 0.4...
8793
8894 >>> import os, tempfile
8995 >>> path = tempfile.mkdtemp()
9096 >>> model.save(sc, path)
9197 >>> sameModel = MatrixFactorizationModel.load(sc, path)
92- >>> sameModel.predict(2,2)
98+ >>> sameModel.predict(2, 2)
9399 0.4...
94100 >>> sameModel.predictAll(testset).collect()
95101 [Rating(...
@@ -125,13 +131,20 @@ class ALS(object):
125131
@classmethod
def _prepare(cls, ratings):
    """
    Normalize training input into an RDD whose first element is a Rating.

    :param ratings: an RDD or a DataFrame. A DataFrame is replaced by its
                    underlying ``.rdd`` before any further checks.
    :return: the input RDD unchanged when its first element is already a
             ``Rating``; otherwise, when the first element is a tuple or
             list, a mapped RDD in which every element ``x`` becomes
             ``Rating(*x)``.
    :raises TypeError: if ``ratings`` is neither an RDD nor a DataFrame, or
                       if the first element is not a Rating/tuple/list.
    """
    # Explicit raise instead of `assert`: assertions are removed when Python
    # runs with -O, which would let an unsupported container slip through.
    if not isinstance(ratings, RDD):
        if not isinstance(ratings, DataFrame):
            raise TypeError("Ratings should be represented by either an RDD or a DataFrame, "
                            "but got %s." % type(ratings))
        ratings = ratings.rdd

    # Only the first element is inspected; the remaining elements are
    # assumed to share its type (NOTE(review): not verified here).
    first = ratings.first()
    if isinstance(first, Rating):
        return ratings
    if isinstance(first, (tuple, list)):
        return ratings.map(lambda x: Rating(*x))
    raise TypeError("Expect a Rating or a tuple/list, but got %s." % type(first))
136149
137150 @classmethod
@@ -152,8 +165,11 @@ def trainImplicit(cls, ratings, rank, iterations=5, lambda_=0.01, blocks=-1, alp
152165def _test ():
153166 import doctest
154167 import pyspark .mllib .recommendation
168+ from pyspark .sql import SQLContext
155169 globs = pyspark .mllib .recommendation .__dict__ .copy ()
156- globs ['sc' ] = SparkContext ('local[4]' , 'PythonTest' )
170+ sc = SparkContext ('local[4]' , 'PythonTest' )
171+ globs ['sc' ] = sc
172+ globs ['sqlContext' ] = SQLContext (sc )
157173 (failure_count , test_count ) = doctest .testmod (globs = globs , optionflags = doctest .ELLIPSIS )
158174 globs ['sc' ].stop ()
159175 if failure_count :
0 commit comments