Closed
51 commits
a95bc22
timing for DecisionTree internals
jkbradley Aug 5, 2014
511ec85
Merge remote-tracking branch 'upstream/master' into dt-timing
jkbradley Aug 6, 2014
bcf874a
Merge remote-tracking branch 'upstream/master' into dt-timing
jkbradley Aug 7, 2014
f61e9d2
Merge remote-tracking branch 'upstream/master' into dt-timing
jkbradley Aug 8, 2014
3211f02
Optimizing DecisionTree
jkbradley Aug 8, 2014
0f676e2
Optimizations + Bug fix for DecisionTree
jkbradley Aug 8, 2014
b2ed1f3
Merge remote-tracking branch 'upstream/master' into dt-opt
jkbradley Aug 8, 2014
b914f3b
DecisionTree optimization: eliminated filters + small changes
jkbradley Aug 9, 2014
c1565a5
Small DecisionTree updates:
jkbradley Aug 11, 2014
fd65372
Major changes:
jkbradley Aug 13, 2014
51ef781
Fixed bug introduced by last commit: Variance impurity calculation wa…
jkbradley Aug 13, 2014
e3c84cc
Added stuff for mnist8m to D T Runner
jkbradley Aug 14, 2014
86e217f
added cache to DT input
jkbradley Aug 14, 2014
438a660
removed subsampling for mnist8m from DT
jkbradley Aug 14, 2014
dd4d3aa
Mid-process in bug fix: bug for binary classification with categorica…
jkbradley Aug 14, 2014
a87e08f
Merge remote-tracking branch 'upstream/master' into dt-opt1
jkbradley Aug 14, 2014
8464a6e
Moved TimeTracker to tree/impl/ in its own file, and cleaned it up. …
jkbradley Aug 14, 2014
e66f1b1
TreePoint
jkbradley Aug 14, 2014
d036089
Print timing info to logDebug.
jkbradley Aug 14, 2014
430d782
Added more debug info on binning error. Added some docs.
jkbradley Aug 14, 2014
356daba
Merge branch 'dt-opt1' into dt-opt2
jkbradley Aug 14, 2014
26d10dd
Removed tree/model/Filter.scala since no longer used. Removed debugg…
jkbradley Aug 15, 2014
5fce635
Merge branch 'dt-opt2' into dt-opt3
jkbradley Aug 15, 2014
45f7ea7
partial merge, not yet done
jkbradley Aug 17, 2014
9c83363
partial merge but not done yet
jkbradley Aug 18, 2014
b314659
Merge remote-tracking branch 'upstream/master' into dt-opt3
jkbradley Aug 18, 2014
3ba7166
Merge remote-tracking branch 'upstream/master' into dt-opt3
jkbradley Aug 18, 2014
61c4509
Fixed bugs from merge: missing DT timer call, and numBins setting. C…
jkbradley Aug 18, 2014
5f94342
Added treeAggregate since not yet merged from master. Moved node ind…
jkbradley Aug 18, 2014
95cad7c
Merge remote-tracking branch 'upstream/master' into dt-opt3
jkbradley Aug 18, 2014
a40f8f1
Changed nodes to be indexed from 1. Tests work.
jkbradley Aug 19, 2014
d7c53ee
Added more doc for ImpurityAggregator
jkbradley Aug 19, 2014
fd8df30
Moved some aggregation helpers outside of findBestSplitsPerGroup
jkbradley Aug 19, 2014
92f7118
Added partly written DTStatsAggregator
jkbradley Aug 20, 2014
f2166fd
still working on DTStatsAggregator
jkbradley Aug 20, 2014
807cd00
Finished DTStatsAggregator, a wrapper around the aggregate statistics…
jkbradley Aug 24, 2014
6d32ccd
In DecisionTree.binsToBestSplit, changed loops to iterators to shorte…
jkbradley Aug 24, 2014
062c31d
Merge remote-tracking branch 'upstream/master' into dt-opt3alt
jkbradley Aug 24, 2014
105f8ab
Removed commented-out getEmptyBinAggregates from DecisionTree
jkbradley Aug 25, 2014
37ca845
Fixed problem with how DecisionTree handles ordered categorical featu…
jkbradley Aug 25, 2014
e676da1
Updated documentation for DecisionTree
jkbradley Aug 26, 2014
92f934f
Merge remote-tracking branch 'upstream/master' into dt-opt3alt
jkbradley Aug 26, 2014
1485fcc
Made some DecisionTree methods private.
jkbradley Aug 27, 2014
1e3b1c7
Merge remote-tracking branch 'upstream/master' into dt-opt3alt
jkbradley Sep 2, 2014
4651154
Changed numBins semantics for unordered features.
jkbradley Sep 2, 2014
aa4e4df
Updated DTStatsAggregator with bug fix (nodeString should not be mult…
jkbradley Sep 3, 2014
a2acea5
Small optimizations based on profiling
jkbradley Sep 5, 2014
425716c
Merge remote-tracking branch 'upstream/master' into rfs
jkbradley Sep 5, 2014
00e4404
optimization for TreePoint construction (pre-computing featureArity a…
jkbradley Sep 5, 2014
d3cc46b
Merge remote-tracking branch 'upstream/master' into dt-opt3alt
jkbradley Sep 6, 2014
42c192a
Merge branch 'rfs' into dt-opt3alt
jkbradley Sep 6, 2014
1,341 changes: 450 additions & 891 deletions mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Large diffs are not rendered by default.

@@ -0,0 +1,213 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.tree.impl

import org.apache.spark.mllib.tree.impurity._

/**
* DecisionTree statistics aggregator.
* This holds a flat array of statistics for a set of (nodes, features, bins)
* and helps with indexing.
*/
private[tree] class DTStatsAggregator(
val metadata: DecisionTreeMetadata,
val numNodes: Int) extends Serializable {

/**
* [[ImpurityAggregator]] instance specifying the impurity type.
*/
val impurityAggregator: ImpurityAggregator = metadata.impurity match {
case Gini => new GiniAggregator(metadata.numClasses)
case Entropy => new EntropyAggregator(metadata.numClasses)
case Variance => new VarianceAggregator()
case _ => throw new IllegalArgumentException(s"Bad impurity parameter: ${metadata.impurity}")
}

/**
* Number of elements (Double values) used for the sufficient statistics of each bin.
*/
val statsSize: Int = impurityAggregator.statsSize

val numFeatures: Int = metadata.numFeatures

/**
* Number of bins for each feature. This is indexed by the feature index.
*/
val numBins: Array[Int] = metadata.numBins

/**
* Number of splits for the given feature.
*/
def numSplits(featureIndex: Int): Int = metadata.numSplits(featureIndex)

/**
* Indicator for each feature of whether that feature is an unordered feature.
* TODO: Is Array[Boolean] any faster?
*/
def isUnordered(featureIndex: Int): Boolean = metadata.isUnordered(featureIndex)

/**
* Offset for each feature for calculating indices into the [[allStats]] array.
*/
private val featureOffsets: Array[Int] = {
def featureOffsetsCalc(total: Int, featureIndex: Int): Int = {
if (isUnordered(featureIndex)) {
total + 2 * numBins(featureIndex)
Review comment (Contributor):
Would the factor of 2 for unordered categorical features be better placed in the numBins calculation in the DecisionTreeMetadata class?

} else {
total + numBins(featureIndex)
}
}
Range(0, numFeatures).scanLeft(0)(featureOffsetsCalc).map(statsSize * _).toArray
}

/**
* Number of elements for each node, corresponding to stride between nodes in [[allStats]].
*/
private val nodeStride: Int = featureOffsets.last

/**
* Total number of elements stored in this aggregator.
*/
val allStatsSize: Int = numNodes * nodeStride

/**
* Flat array of elements.
* Index for start of stats for a (node, feature, bin) is:
* index = nodeIndex * nodeStride + featureOffsets(featureIndex) + binIndex * statsSize
* Note: For unordered features, the left child stats have binIndex in [0, numBins(featureIndex))
* and the right child stats in [numBins(featureIndex), 2 * numBins(featureIndex))
*/
val allStats: Array[Double] = new Array[Double](allStatsSize)

/**
* Get an [[ImpurityCalculator]] for a given (node, feature, bin).
* @param nodeFeatureOffset For ordered features, this is a pre-computed (node, feature) offset
* from [[getNodeFeatureOffset]].
* For unordered features, this is a pre-computed
* (node, feature, left/right child) offset from
* [[getLeftRightNodeFeatureOffsets]].
*/
def getImpurityCalculator(nodeFeatureOffset: Int, binIndex: Int): ImpurityCalculator = {
impurityAggregator.getCalculator(allStats, nodeFeatureOffset + binIndex * statsSize)
}

/**
* Update the stats for a given (node, feature, bin) for ordered features, using the given label.
*/
def update(nodeIndex: Int, featureIndex: Int, binIndex: Int, label: Double): Unit = {
val i = nodeIndex * nodeStride + featureOffsets(featureIndex) + binIndex * statsSize
impurityAggregator.update(allStats, i, label)
}

/**
* Pre-compute node offset for use with [[nodeUpdate]].
*/
def getNodeOffset(nodeIndex: Int): Int = nodeIndex * nodeStride

/**
* Faster version of [[update]].
* Update the stats for a given (node, feature, bin) for ordered features, using the given label.
* @param nodeOffset Pre-computed node offset from [[getNodeOffset]].
*/
def nodeUpdate(nodeOffset: Int, featureIndex: Int, binIndex: Int, label: Double): Unit = {
val i = nodeOffset + featureOffsets(featureIndex) + binIndex * statsSize
impurityAggregator.update(allStats, i, label)
}

/**
* Pre-compute (node, feature) offset for use with [[nodeFeatureUpdate]].
* For ordered features only.
*/
def getNodeFeatureOffset(nodeIndex: Int, featureIndex: Int): Int = {
require(!isUnordered(featureIndex),
Review comment (Contributor):
Should we enable these requires only in debug mode?

s"DTStatsAggregator.getNodeFeatureOffset is for ordered features only, but was called" +
s" for unordered feature $featureIndex.")
nodeIndex * nodeStride + featureOffsets(featureIndex)
}

/**
* Pre-compute the (left child, right child) pair of (node, feature) offsets
* for use with [[nodeFeatureUpdate]].
* For unordered features only.
*/
def getLeftRightNodeFeatureOffsets(nodeIndex: Int, featureIndex: Int): (Int, Int) = {
require(isUnordered(featureIndex),
Review comment (Contributor):
Ditto.

s"DTStatsAggregator.getLeftRightNodeFeatureOffsets is for unordered features only," +
s" but was called for ordered feature $featureIndex.")
val baseOffset = nodeIndex * nodeStride + featureOffsets(featureIndex)
(baseOffset, baseOffset + numBins(featureIndex) * statsSize)
}

/**
* Faster version of [[update]].
* Update the stats for a given (node, feature, bin), using the given label.
* @param nodeFeatureOffset For ordered features, this is a pre-computed (node, feature) offset
* from [[getNodeFeatureOffset]].
* For unordered features, this is a pre-computed
* (node, feature, left/right child) offset from
* [[getLeftRightNodeFeatureOffsets]].
*/
def nodeFeatureUpdate(nodeFeatureOffset: Int, binIndex: Int, label: Double): Unit = {
impurityAggregator.update(allStats, nodeFeatureOffset + binIndex * statsSize, label)
}

/**
* For a given (node, feature), merge the stats for two bins.
* @param nodeFeatureOffset For ordered features, this is a pre-computed (node, feature) offset
* from [[getNodeFeatureOffset]].
* For unordered features, this is a pre-computed
* (node, feature, left/right child) offset from
* [[getLeftRightNodeFeatureOffsets]].
* @param binIndex The other bin is merged into this bin.
* @param otherBinIndex This bin is not modified.
*/
def mergeForNodeFeature(nodeFeatureOffset: Int, binIndex: Int, otherBinIndex: Int): Unit = {
impurityAggregator.merge(allStats, nodeFeatureOffset + binIndex * statsSize,
nodeFeatureOffset + otherBinIndex * statsSize)
}

/**
* Merge this aggregator with another, and return this aggregator.
* This method modifies this aggregator in-place.
*/
def merge(other: DTStatsAggregator): DTStatsAggregator = {
Review comment (Contributor):
We should also possibly create a JIRA for handling precision loss while using Double for large aggregates during variance calculation. This was an observation from @mengxr during the first DT PR. I tried to incorporate it then but it was hard with that code structure. This abstraction might be a good place to incorporate @mengxr's suggestion. Again, not urgent but good for future improvement.

Reply (Member, Author):
Good idea; I'll make a JIRA for that.

require(allStatsSize == other.allStatsSize,
s"DTStatsAggregator.merge requires that both aggregators have the same length stats vectors."
+ s" This aggregator is of length $allStatsSize, but the other is ${other.allStatsSize}.")
var i = 0
// TODO: Test BLAS.axpy
while (i < allStatsSize) {
allStats(i) += other.allStats(i)
i += 1
}
this
}

}
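
The review thread on `merge` raises precision loss when large aggregates are accumulated in `Double`, but it does not say what form a fix would take. As a hedged illustration only, one well-known way to reduce accumulation error is compensated (Kahan) summation. The object and method names below are hypothetical and are not part of this PR or of MLlib.

```scala
// Hypothetical sketch: not the PR's code and not necessarily the fix the reviewers had in mind.
object CompensatedSumSketch {

  /** Plain left-to-right summation, shown for comparison. */
  def naiveSum(values: Iterator[Double]): Double = {
    var sum = 0.0
    while (values.hasNext) {
      sum += values.next()
    }
    sum
  }

  /** Kahan (compensated) summation: carries the low-order bits lost by each addition. */
  def kahanSum(values: Iterator[Double]): Double = {
    var sum = 0.0
    var compensation = 0.0
    while (values.hasNext) {
      val y = values.next() - compensation // re-apply the error from the previous step
      val t = sum + y                      // low-order bits of y may be lost here
      compensation = (t - sum) - y         // recover what was lost
      sum = t
    }
    sum
  }
}
```

Summing `1.0` followed by many values of `1e-16` shows the difference: the plain loop never moves away from `1.0`, while the compensated loop keeps the small contributions.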

private[tree] object DTStatsAggregator extends Serializable {

/**
* Combines two aggregates (modifying the first) and returns the combination.
*/
def binCombOp(
agg1: DTStatsAggregator,
agg2: DTStatsAggregator): DTStatsAggregator = {
agg1.merge(agg2)
}

}
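
The flat-array layout used by `DTStatsAggregator` can be sketched in isolation. The snippet below is a minimal stand-in, assuming the same offset rule as the diff above (unordered features reserve `2 * numBins` slots, ordered features reserve `numBins`); the object name and the function signatures are hypothetical and chosen only for illustration.

```scala
// Hypothetical, simplified stand-in for the indexing done by DTStatsAggregator.
object FlatStatsIndexSketch {

  /** Cumulative per-feature offsets, scaled by statsSize (length = numFeatures + 1). */
  def featureOffsets(
      numBins: Array[Int],
      isUnordered: Int => Boolean,
      statsSize: Int): Array[Int] = {
    numBins.indices.scanLeft(0) { (total, featureIndex) =>
      // Unordered features keep left-child and right-child stats for every bin.
      val slots =
        if (isUnordered(featureIndex)) 2 * numBins(featureIndex) else numBins(featureIndex)
      total + slots
    }.map(_ * statsSize).toArray
  }

  /** Start index of the stats for a (node, feature, bin) triple. */
  def index(
      offsets: Array[Int],
      statsSize: Int,
      nodeIndex: Int,
      featureIndex: Int,
      binIndex: Int): Int = {
    val nodeStride = offsets.last // number of elements per node
    nodeIndex * nodeStride + offsets(featureIndex) + binIndex * statsSize
  }
}
```

With `numBins = Array(3, 2)`, feature 1 unordered, and `statsSize = 2`, the offsets are `(0, 6, 14)`, so the node stride is 14 and the stats for (node 1, feature 1, bin 1) start at 14 + 6 + 2 = 22.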
@@ -26,14 +26,15 @@ import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.impurity.Impurity
import org.apache.spark.rdd.RDD


/**
* Learning and dataset metadata for DecisionTree.
*
* @param numClasses For classification: labels can take values {0, ..., numClasses - 1}.
* For regression: fixed at 0 (no meaning).
* @param maxBins Maximum number of bins, for all features.
* @param featureArity Map: categorical feature index --> arity.
* I.e., the feature takes values in {0, ..., arity - 1}.
Review comment (Contributor):
Javadoc does not match the method arguments.

Reply (Member, Author):
This actually does match the parameter meaning. The parameter is set when DecisionTreeMetadata is constructed by buildMetadata, and it is set to be the actual max, not the max possible/allowed number of bins.

* @param numBins Number of bins for each feature.
*/
private[tree] class DecisionTreeMetadata(
val numFeatures: Int,
@@ -42,6 +43,7 @@ private[tree] class DecisionTreeMetadata(
val maxBins: Int,
val featureArity: Map[Int, Int],
val unorderedFeatures: Set[Int],
val numBins: Array[Int],
val impurity: Impurity,
val quantileStrategy: QuantileStrategy) extends Serializable {

@@ -57,10 +59,26 @@ private[tree] class DecisionTreeMetadata(

def isContinuous(featureIndex: Int): Boolean = !featureArity.contains(featureIndex)

/**
* Number of splits for the given feature.
* For unordered features, there are 2 bins per split.
* For ordered features, there is 1 more bin than split.
*/
def numSplits(featureIndex: Int): Int = if (isUnordered(featureIndex)) {
numBins(featureIndex) >> 1
} else {
numBins(featureIndex) - 1
}

}

private[tree] object DecisionTreeMetadata {

/**
* Construct a [[DecisionTreeMetadata]] instance for this dataset and parameters.
* This computes which categorical features will be ordered vs. unordered,
* as well as the number of splits and bins for each feature.
*/
def buildMetadata(input: RDD[LabeledPoint], strategy: Strategy): DecisionTreeMetadata = {

val numFeatures = input.take(1)(0).features.size
@@ -70,32 +88,55 @@ private[tree] object DecisionTreeMetadata {
  case Regression => 0
  }

- val maxBins = math.min(strategy.maxBins, numExamples).toInt
- val log2MaxBinsp1 = math.log(maxBins + 1) / math.log(2.0)
+ val maxPossibleBins = math.min(strategy.maxBins, numExamples).toInt

+ // We check the number of bins here against maxPossibleBins.
+ // This needs to be checked here instead of in Strategy since maxPossibleBins can be modified
+ // based on the number of training examples.
+ if (strategy.categoricalFeaturesInfo.nonEmpty) {
+ val maxCategoriesPerFeature = strategy.categoricalFeaturesInfo.values.max
+ require(maxCategoriesPerFeature <= maxPossibleBins,
+ s"DecisionTree requires maxBins (= $maxPossibleBins) >= max categories " +
+ s"in categorical features (= $maxCategoriesPerFeature)")
+ }
+
  val unorderedFeatures = new mutable.HashSet[Int]()
+ val numBins = Array.fill[Int](numFeatures)(maxPossibleBins)
  if (numClasses > 2) {

Review comment (Contributor):
val isMulticlassClassification = numClasses > 2 might be more readable

- strategy.categoricalFeaturesInfo.foreach { case (f, k) =>
- if (k - 1 < log2MaxBinsp1) {
- // Note: The above check is equivalent to checking:
- // numUnorderedBins = (1 << k - 1) - 1 < maxBins
- unorderedFeatures.add(f)
+ // Multiclass classification
+ val maxCategoriesForUnorderedFeature =
+ ((math.log(maxPossibleBins / 2 + 1) / math.log(2.0)) + 1).floor.toInt
+ strategy.categoricalFeaturesInfo.foreach { case (featureIndex, numCategories) =>
+ // Decide if some categorical features should be treated as unordered features,
+ // which require 2 * ((1 << numCategories - 1) - 1) bins.
+ // We do this check with log values to prevent overflows in case numCategories is large.
+ // The next check is equivalent to: 2 * ((1 << numCategories - 1) - 1) <= maxBins
+ if (numCategories <= maxCategoriesForUnorderedFeature) {
+ unorderedFeatures.add(featureIndex)
+ numBins(featureIndex) = numUnorderedBins(numCategories)
  } else {
- // TODO: Allow this case, where we simply will know nothing about some categories?
- require(k < maxBins, s"maxBins (= $maxBins) should be greater than max categories " +
- s"in categorical features (>= $k)")
+ numBins(featureIndex) = numCategories
  }
  }
  } else {
- strategy.categoricalFeaturesInfo.foreach { case (f, k) =>
- require(k < maxBins, s"maxBins (= $maxBins) should be greater than max categories " +
- s"in categorical features (>= $k)")
+ // Binary classification or regression
+ strategy.categoricalFeaturesInfo.foreach { case (featureIndex, numCategories) =>
+ numBins(featureIndex) = numCategories
  }
  }

- new DecisionTreeMetadata(numFeatures, numExamples, numClasses, maxBins,
- strategy.categoricalFeaturesInfo, unorderedFeatures.toSet,
+ new DecisionTreeMetadata(numFeatures, numExamples, numClasses, numBins.max,
+ strategy.categoricalFeaturesInfo, unorderedFeatures.toSet, numBins,
  strategy.impurity, strategy.quantileCalculationStrategy)
  }

+ /**
+ * Given the arity of a categorical feature (arity = number of categories),
+ * return the number of bins for the feature if it is to be treated as an unordered feature.
+ * There is 1 split for every partitioning of categories into 2 disjoint, non-empty sets;
+ * there are math.pow(2, arity - 1) - 1 such splits.
+ * Each split has 2 corresponding bins.
+ */
+ def numUnorderedBins(arity: Int): Int = 2 * ((1 << arity - 1) - 1)
+
  }
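
The relationship between `numUnorderedBins` and the log-based threshold in `buildMetadata` can be checked on a small case. This is a sketch under the assumption that both formulas are copied unchanged from the diff above; the enclosing object name is hypothetical.

```scala
// Hypothetical wrapper object; the two formulas are taken from the diff above.
object UnorderedBinsSketch {

  /** 2 bins for each of the 2^(arity - 1) - 1 possible splits of the categories. */
  def numUnorderedBins(arity: Int): Int = 2 * ((1 << arity - 1) - 1)

  /** Log-based threshold from buildMetadata, which avoids overflow for large arities. */
  def maxCategoriesForUnorderedFeature(maxPossibleBins: Int): Int =
    ((math.log(maxPossibleBins / 2 + 1) / math.log(2.0)) + 1).floor.toInt
}
```

For `maxPossibleBins = 32` the threshold is 5: a 5-category feature needs 2 * (16 - 1) = 30 bins, which fits, while a 6-category feature would need 2 * (32 - 1) = 62 bins and is therefore treated as ordered.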