mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala

@@ -19,6 +19,7 @@ package org.apache.spark.mllib.evaluation

import org.apache.spark.rdd.RDD
import org.apache.spark.SparkContext._
import org.apache.spark.sql.DataFrame

/**
 * Evaluator for multilabel classification.
@@ -27,6 +28,13 @@ import org.apache.spark.SparkContext._
 */
class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])]) {

  /**
   * An auxiliary constructor taking a DataFrame.
   * @param predictionAndLabels a DataFrame with two double array columns: prediction and label
   */
  private[mllib] def this(predictionAndLabels: DataFrame) =
    this(predictionAndLabels.map(r => (r.getSeq[Double](0).toArray, r.getSeq[Double](1).toArray)))

  private lazy val numDocs: Long = predictionAndLabels.count()

  private lazy val numLabels: Long = predictionAndLabels.flatMap { case (_, labels) =>
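The auxiliary constructor exists so the Python wrapper can hand data across the Py4J bridge: the Python side converts an RDD of (prediction array, label array) pairs into a two-column array-of-double DataFrame, and this constructor converts it back, reading the two columns positionally. A minimal sketch of that DataFrame shape, assuming a live SparkContext bound to `sc` (illustrative only; the wrapper's __init__ below performs this conversion internally):

    # Illustrative sketch: the shape of the DataFrame the private constructor consumes.
    from pyspark.sql import SQLContext

    sql_ctx = SQLContext(sc)  # assumes an existing SparkContext `sc`
    pairs = sc.parallelize([([0.0, 1.0], [0.0, 2.0]), ([2.0], [2.0])])
    df = sql_ctx.createDataFrame(pairs)  # schema inferred: two array<double> columns
    # df._jdf is the Py4J handle that gets passed to the JVM-side constructor.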
python/pyspark/mllib/evaluation.py (117 additions, 0 deletions)
@@ -343,6 +343,123 @@ def ndcgAt(self, k):
        return self.call("ndcgAt", int(k))


class MultilabelMetrics(JavaModelWrapper):
    """
    Evaluator for multilabel classification.

    >>> predictionAndLabels = sc.parallelize([([0.0, 1.0], [0.0, 2.0]), ([0.0, 2.0], [0.0, 1.0]),
    ...     ([], [0.0]), ([2.0], [2.0]), ([2.0, 0.0], [2.0, 0.0]),
    ...     ([0.0, 1.0, 2.0], [0.0, 1.0]), ([1.0], [1.0, 2.0])])
    >>> metrics = MultilabelMetrics(predictionAndLabels)
    >>> metrics.precision(0.0)
    1.0
    >>> metrics.recall(1.0)
    0.66...
    >>> metrics.f1Measure(2.0)
    0.5
    >>> metrics.precision()
    0.66...
    >>> metrics.recall()
    0.64...
    >>> metrics.f1Measure()
    0.63...
    >>> metrics.microPrecision
    0.72...
    >>> metrics.microRecall
    0.66...
    >>> metrics.microF1Measure
    0.69...
    >>> metrics.hammingLoss
    0.33...
    >>> metrics.subsetAccuracy
    0.28...
    >>> metrics.accuracy
    0.54...
    """

    def __init__(self, predictionAndLabels):
        sc = predictionAndLabels.ctx
        sql_ctx = SQLContext(sc)
        df = sql_ctx.createDataFrame(predictionAndLabels,
                                     schema=sql_ctx._inferSchema(predictionAndLabels))
        java_class = sc._jvm.org.apache.spark.mllib.evaluation.MultilabelMetrics
        java_model = java_class(df._jdf)
        super(MultilabelMetrics, self).__init__(java_model)

    def precision(self, label=None):
        """
        Returns the overall precision, or the precision for a given
        label (category) if specified.
        """
        if label is None:
            return self.call("precision")
        else:
            return self.call("precision", float(label))

    def recall(self, label=None):
        """
        Returns the overall recall, or the recall for a given
        label (category) if specified.
        """
        if label is None:
            return self.call("recall")
        else:
            return self.call("recall", float(label))

    def f1Measure(self, label=None):
        """
        Returns the overall F1-measure, or the F1-measure for a given
        label (category) if specified.
        """
        if label is None:
            return self.call("f1Measure")
        else:
            return self.call("f1Measure", float(label))

    @property
    def microPrecision(self):
        """
        Returns micro-averaged label-based precision
        (equal to micro-averaged document-based precision).
        """
        return self.call("microPrecision")

    @property
    def microRecall(self):
        """
        Returns micro-averaged label-based recall
        (equal to micro-averaged document-based recall).
        """
        return self.call("microRecall")

    @property
    def microF1Measure(self):
        """
        Returns micro-averaged label-based F1-measure
        (equal to micro-averaged document-based F1-measure).
        """
        return self.call("microF1Measure")

    @property
    def hammingLoss(self):
        """
        Returns the Hamming loss (the fraction of document-label pairs
        that are misclassified).
        """
        return self.call("hammingLoss")

    @property
    def subsetAccuracy(self):
        """
        Returns subset accuracy: the fraction of documents for which the
        predicted label set exactly matches the true label set.
        """
        return self.call("subsetAccuracy")

    @property
    def accuracy(self):
        """
        Returns accuracy: the Jaccard similarity between the predicted and
        true label sets, averaged over all documents.
        """
        return self.call("accuracy")


def _test():
    import doctest
    from pyspark import SparkContext
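As a cross-check on the doctest values above, the metrics can be recomputed in pure Python from the same seven (prediction, label) pairs. This is an illustration of the definitions in the docstrings, not part of the patch:

    # Pure-Python recomputation of the doctest values (no Spark required).
    data = [([0.0, 1.0], [0.0, 2.0]), ([0.0, 2.0], [0.0, 1.0]),
            ([], [0.0]), ([2.0], [2.0]), ([2.0, 0.0], [2.0, 0.0]),
            ([0.0, 1.0, 2.0], [0.0, 1.0]), ([1.0], [1.0, 2.0])]

    # Micro metrics aggregate true/false positives and false negatives
    # over all documents before dividing.
    tp = sum(len(set(p) & set(l)) for p, l in data)  # 8
    fp = sum(len(set(p) - set(l)) for p, l in data)  # 3
    fn = sum(len(set(l) - set(p)) for p, l in data)  # 4
    micro_precision = tp / float(tp + fp)            # 8/11  = 0.72...
    micro_recall = tp / float(tp + fn)               # 8/12  = 0.66...
    micro_f1 = 2.0 * tp / (2.0 * tp + fp + fn)       # 16/23 = 0.69...

    # Hamming loss counts misclassified (document, label) cells.
    num_docs = len(data)                                         # 7
    num_labels = len({l for _, labels in data for l in labels})  # 3 distinct true labels
    hamming_loss = (fp + fn) / float(num_docs * num_labels)      # 7/21 = 0.33...

    # Subset accuracy counts exact set matches (documents 4 and 5 only).
    subset_accuracy = sum(set(p) == set(l) for p, l in data) / float(num_docs)  # 2/7 = 0.28...

    # Per-label precision, e.g. label 0.0 is predicted four times, always correctly:
    tp0 = sum(0.0 in p and 0.0 in l for p, l in data)      # 4
    fp0 = sum(0.0 in p and 0.0 not in l for p, l in data)  # 0
    precision_0 = tp0 / float(tp0 + fp0)                   # 1.0

These agree with metrics.microPrecision (0.72...), metrics.microRecall (0.66...), metrics.microF1Measure (0.69...), metrics.hammingLoss (0.33...), metrics.subsetAccuracy (0.28...), and metrics.precision(0.0) (1.0) in the doctest.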