@@ -343,6 +343,123 @@ def ndcgAt(self, k):
343343        return  self .call ("ndcgAt" , int (k ))
344344
345345
346+ class  MultilabelMetrics (JavaModelWrapper ):
347+     """ 
348+     Evaluator for multilabel classification. 
349+ 
350+     >>> predictionAndLabels = sc.parallelize([([0.0, 1.0], [0.0, 2.0]), ([0.0, 2.0], [0.0, 1.0]), 
351+     ...     ([], [0.0]), ([2.0], [2.0]), ([2.0, 0.0], [2.0, 0.0]), 
352+     ...     ([0.0, 1.0, 2.0], [0.0, 1.0]), ([1.0], [1.0, 2.0])]) 
353+     >>> metrics = MultilabelMetrics(predictionAndLabels) 
354+     >>> metrics.precision(0.0) 
355+     1.0 
356+     >>> metrics.recall(1.0) 
357+     0.66... 
358+     >>> metrics.f1Measure(2.0) 
359+     0.5 
360+     >>> metrics.precision() 
361+     0.66... 
362+     >>> metrics.recall() 
363+     0.64... 
364+     >>> metrics.f1Measure() 
365+     0.63... 
366+     >>> metrics.microPrecision 
367+     0.72... 
368+     >>> metrics.microRecall 
369+     0.66... 
370+     >>> metrics.microF1Measure 
371+     0.69... 
372+     >>> metrics.hammingLoss 
373+     0.33... 
374+     >>> metrics.subsetAccuracy 
375+     0.28... 
376+     >>> metrics.accuracy 
377+     0.54... 
378+     """ 
379+ 
380+     def  __init__ (self , predictionAndLabels ):
381+         sc  =  predictionAndLabels .ctx 
382+         sql_ctx  =  SQLContext (sc )
383+         df  =  sql_ctx .createDataFrame (predictionAndLabels ,
384+                                      schema = sql_ctx ._inferSchema (predictionAndLabels ))
385+         java_class  =  sc ._jvm .org .apache .spark .mllib .evaluation .MultilabelMetrics 
386+         java_model  =  java_class (df ._jdf )
387+         super (MultilabelMetrics , self ).__init__ (java_model )
388+ 
389+     def  precision (self , label = None ):
390+         """ 
391+         Returns precision or precision for a given label (category) if specified. 
392+         """ 
393+         if  label  is  None :
394+             return  self .call ("precision" )
395+         else :
396+             return  self .call ("precision" , float (label ))
397+ 
398+     def  recall (self , label = None ):
399+         """ 
400+         Returns recall or recall for a given label (category) if specified. 
401+         """ 
402+         if  label  is  None :
403+             return  self .call ("recall" )
404+         else :
405+             return  self .call ("recall" , float (label ))
406+ 
407+     def  f1Measure (self , label = None ):
408+         """ 
409+         Returns f1Measure or f1Measure for a given label (category) if specified. 
410+         """ 
411+         if  label  is  None :
412+             return  self .call ("f1Measure" )
413+         else :
414+             return  self .call ("f1Measure" , float (label ))
415+ 
416+     @property  
417+     def  microPrecision (self ):
418+         """ 
419+         Returns micro-averaged label-based precision. 
420+         (equals to micro-averaged document-based precision) 
421+         """ 
422+         return  self .call ("microPrecision" )
423+ 
424+     @property  
425+     def  microRecall (self ):
426+         """ 
427+         Returns micro-averaged label-based recall. 
428+         (equals to micro-averaged document-based recall) 
429+         """ 
430+         return  self .call ("microRecall" )
431+ 
432+     @property  
433+     def  microF1Measure (self ):
434+         """ 
435+         Returns micro-averaged label-based f1-measure. 
436+         (equals to micro-averaged document-based f1-measure) 
437+         """ 
438+         return  self .call ("microF1Measure" )
439+ 
440+     @property  
441+     def  hammingLoss (self ):
442+         """ 
443+         Returns Hamming-loss. 
444+         """ 
445+         return  self .call ("hammingLoss" )
446+ 
447+     @property  
448+     def  subsetAccuracy (self ):
449+         """ 
450+         Returns subset accuracy. 
451+         (for equal sets of labels) 
452+         """ 
453+         return  self .call ("subsetAccuracy" )
454+ 
455+     @property  
456+     def  accuracy (self ):
457+         """ 
458+         Returns accuracy. 
459+         """ 
460+         return  self .call ("accuracy" )
461+ 
462+ 
346463def  _test ():
347464    import  doctest 
348465    from  pyspark  import  SparkContext 
0 commit comments