Skip to content

Commit 1b6ebb3

Browse files
author
Ram Sriharsha
committed
[SPARK-7404][ml] Add RegressionEvaluator to spark.ml
1 parent e4136ea commit 1b6ebb3

File tree

2 files changed

+168
-0
lines changed

2 files changed

+168
-0
lines changed
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.ml.evaluation
19+
20+
import org.apache.spark.annotation.AlphaComponent
21+
import org.apache.spark.ml.Evaluator
22+
import org.apache.spark.ml.param.Param
23+
import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol}
24+
import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
25+
import org.apache.spark.mllib.evaluation.RegressionMetrics
26+
import org.apache.spark.sql.{DataFrame, Row}
27+
import org.apache.spark.sql.types.DoubleType
28+
29+
/**
30+
* :: AlphaComponent ::
31+
*
32+
* Evaluator for regression, which expects two input columns: score and label.
33+
*/
34+
@AlphaComponent
35+
class RegressionEvaluator(override val uid: String)
36+
extends Evaluator with HasPredictionCol with HasLabelCol {
37+
38+
def this() = this(Identifiable.randomUID("regEval"))
39+
40+
/**
41+
* param for metric name in evaluation
42+
* @group param
43+
*/
44+
val metricName: Param[String] = new Param(this, "metricName",
45+
"metric name in evaluation (rmse|r2|mae)")
46+
47+
/** @group getParam */
48+
def getMetricName: String = $(metricName)
49+
50+
/** @group setParam */
51+
def setMetricName(value: String): this.type = set(metricName, value)
52+
53+
/** @group setParam */
54+
def setScoreCol(value: String): this.type = set(predictionCol, value)
55+
56+
/** @group setParam */
57+
def setLabelCol(value: String): this.type = set(labelCol, value)
58+
59+
setDefault(metricName -> "rmse")
60+
61+
override def evaluate(dataset: DataFrame): Double = {
62+
val schema = dataset.schema
63+
SchemaUtils.checkColumnType(schema, $(predictionCol), DoubleType)
64+
SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType)
65+
66+
val scoreAndLabels = dataset.select($(predictionCol), $(labelCol))
67+
.map { case Row(prediction: Double, label: Double) =>
68+
(prediction, label)
69+
}
70+
val metrics = new RegressionMetrics(scoreAndLabels)
71+
val metric = $(metricName) match {
72+
case "rmse" =>
73+
metrics.rootMeanSquaredError
74+
case "mse" =>
75+
metrics.meanSquaredError
76+
case "r2" =>
77+
metrics.r2
78+
case "mae" =>
79+
metrics.meanAbsoluteError
80+
case other =>
81+
throw new IllegalArgumentException(s"Does not support metric $other.")
82+
}
83+
metric
84+
}
85+
}
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.ml.evaluation
19+
20+
import org.scalatest.FunSuite
21+
22+
import org.apache.spark.ml.regression.LinearRegression
23+
import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
24+
import org.apache.spark.mllib.util.TestingUtils._
25+
import org.apache.spark.sql.DataFrame
26+
27+
class RegressionEvaluatorSuite extends FunSuite with MLlibTestSparkContext {
28+
29+
@transient var dataset: DataFrame = _
30+
31+
override def beforeAll(): Unit = {
32+
super.beforeAll()
33+
/**
34+
* Here is the instruction describing how to export the test data into CSV format
35+
* so we can validate the metrics compared with scikit learns regression metrics package.
36+
*
37+
* import org.apache.spark.mllib.util.LinearDataGenerator
38+
* val data = sc.parallelize(LinearDataGenerator.generateLinearInput(6.3,
39+
* Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 100, 42, 0.1))
40+
* data.map(x=> x.label + ", " + x.features(0) + ", " + x.features(1))
41+
* .saveAsTextFile("path")
42+
*/
43+
dataset = sqlContext.createDataFrame(
44+
sc.parallelize(LinearDataGenerator.generateLinearInput(
45+
6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 100, 42, 0.1), 2))
46+
}
47+
48+
test("Regression Evaluator: default params") {
49+
/**
50+
* Using the following python code to load the data and train the model using scikit learn.
51+
*
52+
* > from sklearn.linear_model import LinearRegression
53+
* > from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
54+
* > import pandas as pd
55+
* > from patsy import dmatrices
56+
* > df = pd.read_csv("path")
57+
* > y, X = dmatrices('label ~ x + y',df, return_type="dataframe")
58+
* > regr = LinearRegression()
59+
* > regr.fit(X, y)
60+
* > print('Mean Squared Error: %.2f' % mean_squared_error(y, regr.predict(X)))
61+
* > print('Mean Absolute Error: %.2f' % mean_absolute_error(y, regr.predict(X)))
62+
* > print('R2 score: %.2f' % r2_score(y, regr.predict(X)))
63+
* > Mean Squared Error: 0.01
64+
* > Mean Absolute Error: 0.08
65+
* > R2 score: 1.00
66+
*/
67+
val trainer = new LinearRegression
68+
val model = trainer.fit(dataset)
69+
val predictions = model.transform(dataset)
70+
71+
// default = rmse
72+
val evaluator = new RegressionEvaluator()
73+
assert(evaluator.evaluate(predictions) ~== 0.1 relTol 0.02)
74+
75+
// r2 score
76+
evaluator.setMetricName("r2")
77+
assert(evaluator.evaluate(predictions) ~== 0.01 relTol 0.002)
78+
79+
// mae
80+
evaluator.setMetricName("mae")
81+
assert(evaluator.evaluate(predictions) ~== 0.08 relTol 0.01)
82+
}
83+
}

0 commit comments

Comments
 (0)