28 commits
929f0e6
updating DT APIf
jkbradley Jul 18, 2014
29e29b8
Merging multiclass DT PR, plus others, into branch with updates to DT…
jkbradley Jul 18, 2014
20fc805
Mostly done with DecisionTree API re-config. Still need to update De…
jkbradley Jul 19, 2014
0ced13a
Major changes to DecisionTree API and internals. Unit tests work. S…
jkbradley Jul 23, 2014
4ba347f
Merge remote-tracking branch 'upstream/master' into decisiontree-api
jkbradley Jul 23, 2014
a853bfc
Last non-merge commit said it changed the maxDepth meaning, but it di…
jkbradley Jul 23, 2014
4506844
Changed all config/impurity classes/objects to be private[mllib].
jkbradley Jul 24, 2014
b6b0809
removed mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionT…
jkbradley Jul 24, 2014
a2a9311
removed mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionT…
jkbradley Jul 24, 2014
0cb9866
Merge branch 'decisiontree-api' of github.com:jkbradley/spark into de…
jkbradley Jul 24, 2014
3ff5027
Bug fix: Indexing was inconsistent for aggregate calculations for uno…
jkbradley Jul 24, 2014
3ba5b4c
Merge remote-tracking branch 'upstream/master' into decisiontree-api
jkbradley Jul 25, 2014
e1243a5
Fixed scala style issues reported by Jenkins
jkbradley Jul 25, 2014
62c2fbc
Merge remote-tracking branch 'upstream/master' into decisiontree-api
jkbradley Jul 25, 2014
3eea304
Added Algo exception to MimaExcludes.scala
jkbradley Jul 25, 2014
cda2a80
Added more exceptions to MimaExcludes.scala
jkbradley Jul 25, 2014
e73dc32
Added yet more exceptions to MimaExcludes.scala
jkbradley Jul 25, 2014
07e9c16
Modified Decision Tree params classes to use Scala BeansProperty for …
jkbradley Jul 28, 2014
c42d85e
Merge remote-tracking branch 'upstream/master' into decisiontree-api
jkbradley Jul 28, 2014
becec3f
Made DTParams class abstract. Moved supported* methods in DT*Params …
jkbradley Jul 28, 2014
c0a46be
added newline character for Scala style
jkbradley Jul 28, 2014
4bea4bd
Updated documentation for Decision Trees based on new API
jkbradley Jul 28, 2014
e67ea9c
Small updates based on @manishamde comments:
jkbradley Jul 29, 2014
40c81e3
Merge remote-tracking branch 'upstream/master' into decisiontree-api
jkbradley Jul 29, 2014
f543f94
Merge remote-tracking branch 'upstream/master' into decisiontree-api
jkbradley Jul 29, 2014
bdc2aa7
Changed DecisionTree*Model print() methods to be called toString(). …
jkbradley Jul 29, 2014
17dcc09
Added @Experimental tags to some Decision Tree objects.
jkbradley Jul 29, 2014
d2c1dad
Fixed bug in DecisionTreeRunner with old print function name. Added …
jkbradley Jul 29, 2014
36 changes: 23 additions & 13 deletions docs/mllib-decision-tree.md
@@ -116,23 +116,28 @@ maximum tree depth of 5. The training error is calculated to measure the algorit
<div class="codetabs">
<div data-lang="scala">
{% highlight scala %}
import org.apache.spark.SparkContext
import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.rdd.DatasetInfo
import org.apache.spark.mllib.tree.DecisionTreeClassifier
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.impurity.Gini

// Load and parse the data file
val data = sc.textFile("data/mllib/sample_tree_data.csv")
val parsedData = data.map { line =>
val parts = line.split(',').map(_.toDouble)
LabeledPoint(parts(0), Vectors.dense(parts.tail))
}
val numFeatures = parsedData.take(1)(0).features.size
val datasetInfo = new DatasetInfo(numClasses = 2, numFeatures = numFeatures)

// Run training algorithm to build the model
val maxDepth = 5
val model = DecisionTree.train(parsedData, Classification, Gini, maxDepth)
val dtParams = DecisionTreeClassifier.defaultParams()
dtParams.impurity = "gini"
dtParams.maxDepth = 4
val model = DecisionTreeClassifier.train(parsedData, datasetInfo, dtParams)

// Print model in human-readable format.
println(model.toString)

// Evaluate model on training examples and compute training error
val labelAndPreds = parsedData.map { point =>
@@ -155,31 +155,36 @@ depth of 5. The Mean Squared Error (MSE) is computed at the end to evaluate
<div class="codetabs">
<div data-lang="scala">
{% highlight scala %}
import org.apache.spark.SparkContext
import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.rdd.DatasetInfo
import org.apache.spark.mllib.tree.DecisionTreeRegressor
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.impurity.Variance

// Load and parse the data file
val data = sc.textFile("data/mllib/sample_tree_data.csv")
val parsedData = data.map { line =>
val parts = line.split(',').map(_.toDouble)
LabeledPoint(parts(0), Vectors.dense(parts.tail))
}
val numFeatures = parsedData.take(1)(0).features.size
val datasetInfo = new DatasetInfo(numClasses = 0, numFeatures = numFeatures)

// Run training algorithm to build the model
val maxDepth = 5
val model = DecisionTree.train(parsedData, Regression, Variance, maxDepth)
val dtParams = DecisionTreeRegressor.defaultParams()
dtParams.impurity = "variance"
dtParams.maxDepth = 4
val model = DecisionTreeRegressor.train(parsedData, datasetInfo, dtParams)

// Print model in human-readable format.
println(model.toString)

// Evaluate model on training examples and compute training error
val valuesAndPreds = parsedData.map { point =>
val prediction = model.predict(point.features)
(point.label, prediction)
}
val MSE = valuesAndPreds.map{ case(v, p) => math.pow((v - p), 2)}.mean()
println("training Mean Squared Error = " + MSE)
println("Training Mean Squared Error = " + MSE)
{% endhighlight %}
</div>
</div>
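Both examples above follow the same three-step pattern introduced by this PR: describe the data with a `DatasetInfo`, configure a mutable params object, and call `train`. Below is a minimal end-to-end classification sketch assembled from the API names shown in this diff; the training-error computation, which the first hunk truncates, is filled in as an assumption, and `sc` is assumed to be an existing `SparkContext`.

```scala
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.rdd.DatasetInfo
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.DecisionTreeClassifier

// Load and parse a dense CSV file: label,f0,f1,f2,...
val data = sc.textFile("data/mllib/sample_tree_data.csv")
val parsedData = data.map { line =>
  val parts = line.split(',').map(_.toDouble)
  LabeledPoint(parts(0), Vectors.dense(parts.tail))
}.cache()

// Describe the dataset: binary labels, all features treated as real-valued.
val numFeatures = parsedData.take(1)(0).features.size
val datasetInfo = new DatasetInfo(numClasses = 2, numFeatures = numFeatures)

// Configure and train a classifier.
val dtParams = DecisionTreeClassifier.defaultParams()
dtParams.impurity = "gini"
dtParams.maxDepth = 4
val model = DecisionTreeClassifier.train(parsedData, datasetInfo, dtParams)

// Training error (assumed continuation of the truncated example above).
val labelAndPreds = parsedData.map { point =>
  (point.label, model.predict(point.features))
}
val trainErr = labelAndPreds.filter(lp => lp._1 != lp._2).count().toDouble / parsedData.count()
println("Training Error = " + trainErr)
```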
examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
@@ -21,11 +21,10 @@ import scopt.OptionParser

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.SparkContext._
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.rdd.DatasetInfo
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{DecisionTree, impurity}
import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.{DecisionTreeClassifier, DecisionTreeRegressor}
import org.apache.spark.mllib.tree.configuration.{DTClassifierParams, DTRegressorParams}
import org.apache.spark.mllib.tree.model.DecisionTreeModel
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
@@ -36,59 +35,66 @@ import org.apache.spark.rdd.RDD
* ./bin/spark-example org.apache.spark.examples.mllib.DecisionTreeRunner [options]
* }}}
* If you use it as a template to create your own app, please use `spark-submit` to submit your app.
*
* Note: This script treats all features as real-valued (not categorical).
* To include categorical features, modify
* [[org.apache.spark.mllib.rdd.DatasetInfo.categoricalFeaturesInfo]].
*/
object DecisionTreeRunner {

object ImpurityType extends Enumeration {
type ImpurityType = Value
val Gini, Entropy, Variance = Value
}

import ImpurityType._

case class Params(
input: String = null,
algo: Algo = Classification,
numClassesForClassification: Int = 2,
maxDepth: Int = 5,
impurity: ImpurityType = Gini,
maxBins: Int = 100)
dataFormat: String = null,
algo: String = "classification",
impurity: Option[String] = None,
maxDepth: Int = 4,
maxBins: Int = 100,
fracTest: Double = 0.2)

private val defaultCImpurity = new DTClassifierParams().impurity
private val defaultRImpurity = new DTRegressorParams().impurity

def main(args: Array[String]) {
val defaultParams = Params()

val parser = new OptionParser[Params]("DecisionTreeRunner") {
head("DecisionTreeRunner: an example decision tree app.")
opt[String]("algo")
.text(s"algorithm (${Algo.values.mkString(",")}), default: ${defaultParams.algo}")
.action((x, c) => c.copy(algo = Algo.withName(x)))
.text(s"algorithm (classification, regression), default: ${defaultParams.algo}")
.action((x, c) => c.copy(algo = x))
opt[String]("impurity")
.text(s"impurity type (${ImpurityType.values.mkString(",")}), " +
s"default: ${defaultParams.impurity}")
.action((x, c) => c.copy(impurity = ImpurityType.withName(x)))
.text(
s"impurity type\n" +
s"\tFor classification: ${DTClassifierParams.supportedImpurities.mkString(",")}\n" +
s"\t default: $defaultCImpurity" +
s"\tFor regression: ${DTRegressorParams.supportedImpurities.mkString(",")}\n" +
s"\t default: $defaultRImpurity")
.action((x, c) => c.copy(impurity = Some(x)))
opt[Int]("maxDepth")
.text(s"max depth of the tree, default: ${defaultParams.maxDepth}")
.action((x, c) => c.copy(maxDepth = x))
opt[Int]("numClassesForClassification")
.text(s"number of classes for classification, "
+ s"default: ${defaultParams.numClassesForClassification}")
.action((x, c) => c.copy(numClassesForClassification = x))
opt[Int]("maxBins")
.text(s"max number of bins, default: ${defaultParams.maxBins}")
.action((x, c) => c.copy(maxBins = x))
opt[Double]("fracTest")
.text(s"fraction of data to hold out for testing, default: ${defaultParams.fracTest}")
.action((x, c) => c.copy(fracTest = x))
arg[String]("<input>")
.text("input paths to labeled examples in dense format (label,f0 f1 f2 ...)")
.text("input paths to labeled examples")
.required()
.action((x, c) => c.copy(input = x))
arg[String]("<dataFormat>")
.text("data format: dense/libsvm")
.required()
.action((x, c) => c.copy(dataFormat = x))
checkConfig { params =>
if (params.algo == Classification &&
(params.impurity == Gini || params.impurity == Entropy)) {
success
} else if (params.algo == Regression && params.impurity == Variance) {
success
} else {
failure(s"Algo ${params.algo} is not compatible with impurity ${params.impurity}.")
if (!List("classification", "regression").contains(params.algo)) {
failure(s"Did not recognize Algo: ${params.algo}")
}
if (params.fracTest < 0 || params.fracTest > 1) {
failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1].")
}
success
}
}

@@ -104,42 +110,92 @@ object DecisionTreeRunner {
val sc = new SparkContext(conf)

// Load training data and cache it.
val examples = MLUtils.loadLabeledPoints(sc, params.input).cache()
val origExamples = params.dataFormat match {
case "dense" => MLUtils.loadLabeledPoints(sc, params.input).cache()
case "libsvm" => MLUtils.loadLibSVMFile(sc, params.input, multiclass = true).cache()
}
// For classification, re-index classes if needed.
val (examples, numClasses) = params.algo match {
case "classification" => {
// classCounts: class --> # examples in class
val classCounts = origExamples.map(_.label).countByValue
val numClasses = classCounts.size
// classIndex: class --> index in 0,...,numClasses-1
val classIndex = {
if (classCounts.keySet != Set[Double](0.0, 1.0)) {
classCounts.keys.toList.sorted.zipWithIndex.toMap
} else {
Map[Double, Int]()
}
}
val examples = {
if (classIndex.isEmpty) {
origExamples
} else {
origExamples.map(lp => LabeledPoint(classIndex(lp.label), lp.features))
}
}
println(s"numClasses = $numClasses.")
println(s"Per-class example fractions, counts:")
println(s"Class\tFrac\tCount")
val numExamples = examples.count()
classCounts.keys.toList.sorted.foreach { c =>
val frac = classCounts(c).toDouble / numExamples
println(s"$c\t$frac\t${classCounts(c)}")
}
(examples, numClasses)
}
case "regression" => {
(origExamples, 0)
}
case _ => {
throw new IllegalArgumentException(s"Algo ${params.algo} not supported.")
}
}

// Split into training, test.
val splits = examples.randomSplit(Array(1.0 - params.fracTest, params.fracTest))
val training = splits(0).cache()
val test = splits(1).cache()

val numTraining = training.count()
val numTest = test.count()

println(s"numTraining = $numTraining, numTest = $numTest.")
println(s"numTraining = $numTraining, numTest = $numTest")

examples.unpersist(blocking = false)

val impurityCalculator = params.impurity match {
case Gini => impurity.Gini
case Entropy => impurity.Entropy
case Variance => impurity.Variance
}

val strategy
= new Strategy(
algo = params.algo,
impurity = impurityCalculator,
maxDepth = params.maxDepth,
maxBins = params.maxBins,
numClassesForClassification = params.numClassesForClassification)
val model = DecisionTree.train(training, strategy)

if (params.algo == Classification) {
val accuracy = accuracyScore(model, test)
println(s"Test accuracy = $accuracy.")
}
val numFeatures = examples.take(1)(0).features.size
val datasetInfo = new DatasetInfo(numClasses, numFeatures)

if (params.algo == Regression) {
val mse = meanSquaredError(model, test)
println(s"Test mean squared error = $mse.")
params.algo match {
case "classification" => {
val dtParams = DecisionTreeClassifier.defaultParams()
dtParams.maxDepth = params.maxDepth
dtParams.maxBins = params.maxBins
dtParams.impurity = params.impurity.getOrElse(defaultCImpurity)
val dtLearner = new DecisionTreeClassifier(dtParams)
val model = dtLearner.run(training, datasetInfo)
println(model.toString)
val accuracy = accuracyScore(model, test)
println(s"Test accuracy = $accuracy")
}
case "regression" => {
val dtParams = DecisionTreeRegressor.defaultParams()
dtParams.maxDepth = params.maxDepth
dtParams.maxBins = params.maxBins
dtParams.impurity = params.impurity.getOrElse(defaultRImpurity)
val dtLearner = new DecisionTreeRegressor(dtParams)
val model = dtLearner.run(training, datasetInfo)
println(model.toString)
val mse = meanSquaredError(model, test)
println(s"Test mean squared error = $mse")
}
case _ => {
throw new IllegalArgumentException(s"Algo ${params.algo} not supported.")
}
}

sc.stop()
@@ -159,9 +215,11 @@
/**
* Calculates the mean squared error for regression.
*/
private def meanSquaredError(tree: DecisionTreeModel, data: RDD[LabeledPoint]): Double = {
private def meanSquaredError(
model: DecisionTreeModel,
data: RDD[LabeledPoint]): Double = {
data.map { y =>
val err = tree.predict(y.features) - y.label
val err = model.predict(y.features) - y.label
err * err
}.mean()
}
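The `classification` branch above re-maps arbitrary real-valued labels onto the index range 0, ..., numClasses-1, skipping the remap only when the labels are already exactly {0.0, 1.0}. A plain-Scala sketch of that logic, pulled out into a hypothetical helper (`indexClasses` is not part of this PR), makes the behavior easy to check in isolation:

```scala
// Hypothetical standalone version of the label re-indexing done in the
// "classification" branch of DecisionTreeRunner (not part of this PR).
def indexClasses(labels: Seq[Double]): (Seq[Double], Int) = {
  // classCounts: class --> # examples in class
  val classCounts = labels.groupBy(identity).mapValues(_.size)
  val numClasses = classCounts.size
  if (classCounts.keySet == Set(0.0, 1.0)) {
    // Canonical binary labels: no re-indexing needed.
    (labels, numClasses)
  } else {
    // classIndex: class --> index in 0,...,numClasses-1
    val classIndex = classCounts.keys.toList.sorted.zipWithIndex.toMap
    (labels.map(l => classIndex(l).toDouble), numClasses)
  }
}

// Example: labels {1.0, 3.0, 7.0} become indices {0.0, 1.0, 2.0}.
println(indexClasses(Seq(1.0, 3.0, 7.0, 3.0)))  // (List(0.0, 1.0, 2.0, 1.0), 3)
```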
66 changes: 66 additions & 0 deletions mllib/src/main/scala/org/apache/spark/mllib/rdd/DatasetInfo.scala
@@ -0,0 +1,66 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.rdd

/**
* :: Experimental ::
* A class for holding dataset metadata.
* @param numClasses Number of classes for classification. Values of 0 or 1 indicate regression.
* @param numFeatures Number of features.
* @param categoricalFeaturesInfo A map storing information about the categorical variables and the
* number of discrete values they take. For example, an entry (n ->
* k) implies the feature n is categorical with k categories 0,
* 1, 2, ... , k-1. It's important to note that features are
* zero-indexed.
*/
class DatasetInfo (
val numClasses: Int,
val numFeatures: Int,
val categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int]())
extends Serializable {

/**
* Indicates if this dataset's label is real-valued (numClasses < 2).
*/
def isRegression: Boolean = {
numClasses < 2
}

/**
* Indicates if this dataset's label is categorical (numClasses >= 2).
*/
def isClassification: Boolean = {
numClasses >= 2
}

/**
* Indicates if this dataset's label is categorical with >2 categories.
*/
def isMulticlass: Boolean = {
numClasses > 2
}

/**
* Indicates if this dataset's label is categorical with >2 categories,
* and there is at least one categorical feature.
*/
def isMulticlassWithCategoricalFeatures: Boolean = {
isMulticlass && categoricalFeaturesInfo.nonEmpty
}

}
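A brief usage sketch of the metadata class (the numbers here are chosen purely for illustration): a three-class problem over five features, where feature 2 is categorical with four categories.

```scala
import org.apache.spark.mllib.rdd.DatasetInfo

// Three classes, five features; feature 2 takes values in {0, 1, 2, 3}.
val info = new DatasetInfo(
  numClasses = 3,
  numFeatures = 5,
  categoricalFeaturesInfo = Map(2 -> 4))

assert(info.isClassification)
assert(info.isMulticlass)
assert(info.isMulticlassWithCategoricalFeatures)

// numClasses of 0 (or 1) marks the label as real-valued (regression).
val regressionInfo = new DatasetInfo(numClasses = 0, numFeatures = 5)
assert(regressionInfo.isRegression)
```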