1+ /*
2+ * Licensed to the Apache Software Foundation (ASF) under one or more
3+ * contributor license agreements. See the NOTICE file distributed with
4+ * this work for additional information regarding copyright ownership.
5+ * The ASF licenses this file to You under the Apache License, Version 2.0
6+ * (the "License"); you may not use this file except in compliance with
7+ * the License. You may obtain a copy of the License at
8+ *
9+ * http://www.apache.org/licenses/LICENSE-2.0
10+ *
11+ * Unless required by applicable law or agreed to in writing, software
12+ * distributed under the License is distributed on an "AS IS" BASIS,
13+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+ * See the License for the specific language governing permissions and
15+ * limitations under the License.
16+ */
17+
18+ package org .apache .spark .mllib .evaluation
19+
20+ import org .apache .spark .rdd .RDD
21+ import org .apache .spark .Logging
22+ import org .apache .spark .SparkContext ._
23+
/**
 * Evaluator for multiclass classification.
 *
 * Terminology used throughout: a "class" is a category, a "label" is the true class of an
 * instance, and a "prediction" is the class assigned to that instance by the classifier.
 *
 * All statistics are computed lazily from the RDD on first use and then cached on the driver.
 *
 * @param scoreAndLabels an RDD of (prediction, label) pairs.
 */
class MulticlassEvaluator(scoreAndLabels: RDD[(Double, Double)]) extends Logging {

  /** Number of instances per true label. */
  private lazy val labelCountByClass: scala.collection.Map[Double, Long] =
    scoreAndLabels.values.countByValue()

  /** Total number of instances. */
  private lazy val labelCount: Long = labelCountByClass.values.sum

  /**
   * True positives keyed by true label: instances whose prediction equals their label.
   * Counted as Long so that very large RDDs cannot overflow an Int accumulator.
   */
  private lazy val tpByClass: scala.collection.Map[Double, Long] =
    scoreAndLabels
      .map { case (prediction, label) => (label, if (label == prediction) 1L else 0L) }
      .reduceByKey(_ + _)
      .collectAsMap()

  /**
   * False positives keyed by predicted class: instances predicted as a class they do not
   * belong to.
   */
  private lazy val fpByClass: scala.collection.Map[Double, Long] =
    scoreAndLabels
      .map { case (prediction, label) => (prediction, if (prediction != label) 1L else 0L) }
      .reduceByKey(_ + _)
      .collectAsMap()

  /** Average of `metric` over all true labels, each weighted by its share of the instances. */
  private def weightedAverage(metric: Double => Double): Double =
    labelCountByClass.foldLeft(0.0) { case (acc, (category, count)) =>
      acc + metric(category) * count.toDouble / labelCount.toDouble
    }

  /** Evaluates `metric` for every true label and returns the results keyed by label. */
  private def perClass(metric: Double => Double): Map[Double, Double] =
    labelCountByClass.keys.map(category => (category, metric(category))).toMap

  /**
   * Returns Precision for a given label (category).
   *
   * @param label the label.
   * @return Precision, or 0.0 when the label was never predicted (including labels that do
   *         not occur in the data at all).
   */
  def precision(label: Double): Double = {
    // A label that only ever shows up as a prediction has no entry in tpByClass.
    val truePositives = tpByClass.getOrElse(label, 0L)
    val predicted = truePositives + fpByClass.getOrElse(label, 0L)
    if (predicted == 0L) 0.0 else truePositives.toDouble / predicted.toDouble
  }

  /**
   * Returns Recall for a given label (category).
   *
   * @param label the label.
   * @return Recall, or 0.0 when no instance carries this label.
   */
  def recall(label: Double): Double = {
    val actual = labelCountByClass.getOrElse(label, 0L)
    if (actual == 0L) 0.0 else tpByClass.getOrElse(label, 0L).toDouble / actual.toDouble
  }

  /**
   * Returns F1-measure for a given label (category).
   *
   * @param label the label.
   * @return F1-measure, or 0.0 when both precision and recall are 0.
   */
  def f1Measure(label: Double): Double = {
    val p = precision(label)
    val r = recall(label)
    // Guard the 0/0 case, which would otherwise yield NaN.
    if (p + r == 0.0) 0.0 else 2 * p * r / (p + r)
  }

  /**
   * Returns micro-averaged Recall
   * (equal to microPrecision and microF1Measure for a multiclass classifier).
   *
   * @return microRecall, or 0.0 when there are no instances.
   */
  def microRecall: Double =
    if (labelCount == 0L) 0.0 else tpByClass.values.sum.toDouble / labelCount.toDouble

  /**
   * Returns micro-averaged Precision
   * (equal to microRecall and microF1Measure for a multiclass classifier).
   *
   * @return microPrecision.
   */
  def microPrecision: Double = microRecall

  /**
   * Returns micro-averaged F1-measure
   * (equal to microPrecision and microRecall for a multiclass classifier).
   *
   * @return microF1Measure.
   */
  def microF1Measure: Double = microRecall

  /**
   * Returns Recall averaged over classes, weighted by the number of instances per label.
   *
   * @return weightedRecall.
   */
  def weightedRecall: Double = weightedAverage(recall)

  /**
   * Returns Precision averaged over classes, weighted by the number of instances per label.
   *
   * @return weightedPrecision.
   */
  def weightedPrecision: Double = weightedAverage(precision)

  /**
   * Returns the F1-measure computed as the harmonic mean of weightedPrecision and
   * weightedRecall.
   *
   * @return weightedF1Measure, or 0.0 when both weighted precision and recall are 0.
   */
  def weightedF1Measure: Double = {
    val p = weightedPrecision
    val r = weightedRecall
    // Guard the 0/0 case, which would otherwise yield NaN.
    if (p + r == 0.0) 0.0 else 2 * p * r / (p + r)
  }

  /**
   * Returns map with Precisions for individual classes.
   *
   * @return precisionPerClass, keyed by true label.
   */
  def precisionPerClass: Map[Double, Double] = perClass(precision)

  /**
   * Returns map with Recalls for individual classes.
   *
   * @return recallPerClass, keyed by true label.
   */
  def recallPerClass: Map[Double, Double] = perClass(recall)

  /**
   * Returns map with F1-measures for individual classes.
   *
   * @return f1MeasurePerClass, keyed by true label.
   */
  def f1MeasurePerClass: Map[Double, Double] = perClass(f1Measure)
}
0 commit comments