[SPARK-3086] [SPARK-3043] [SPARK-3156] [mllib] DecisionTree aggregation improvements #2125
Changes from all commits
@@ -0,0 +1,213 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.mllib.tree.impl

import org.apache.spark.mllib.tree.impurity._

/**
 * DecisionTree statistics aggregator.
 * This holds a flat array of statistics for a set of (nodes, features, bins)
 * and helps with indexing.
 */
private[tree] class DTStatsAggregator(
    val metadata: DecisionTreeMetadata,
    val numNodes: Int) extends Serializable {

  /**
   * [[ImpurityAggregator]] instance specifying the impurity type.
   */
  val impurityAggregator: ImpurityAggregator = metadata.impurity match {
    case Gini => new GiniAggregator(metadata.numClasses)
    case Entropy => new EntropyAggregator(metadata.numClasses)
    case Variance => new VarianceAggregator()
    case _ => throw new IllegalArgumentException(s"Bad impurity parameter: ${metadata.impurity}")
  }

  /**
   * Number of elements (Double values) used for the sufficient statistics of each bin.
   */
  val statsSize: Int = impurityAggregator.statsSize

  val numFeatures: Int = metadata.numFeatures

  /**
   * Number of bins for each feature. This is indexed by the feature index.
   */
  val numBins: Array[Int] = metadata.numBins

  /**
   * Number of splits for the given feature.
   */
  def numSplits(featureIndex: Int): Int = metadata.numSplits(featureIndex)

  /**
   * Indicator for each feature of whether that feature is an unordered feature.
   * TODO: Is Array[Boolean] any faster?
   */
  def isUnordered(featureIndex: Int): Boolean = metadata.isUnordered(featureIndex)

  /**
   * Offset for each feature for calculating indices into the [[allStats]] array.
   */
  private val featureOffsets: Array[Int] = {
    def featureOffsetsCalc(total: Int, featureIndex: Int): Int = {
      if (isUnordered(featureIndex)) {
        total + 2 * numBins(featureIndex)
      } else {
        total + numBins(featureIndex)
      }
    }
    Range(0, numFeatures).scanLeft(0)(featureOffsetsCalc).map(statsSize * _).toArray
  }

  /**
   * Number of elements for each node, corresponding to the stride between nodes in [[allStats]].
   */
  private val nodeStride: Int = featureOffsets.last

  /**
   * Total number of elements stored in this aggregator.
   */
  val allStatsSize: Int = numNodes * nodeStride

  /**
   * Flat array of elements.
   * Index for start of stats for a (node, feature, bin) is:
   *   index = nodeIndex * nodeStride + featureOffsets(featureIndex) + binIndex * statsSize
   * Note: For unordered features, the left child stats have binIndex in [0, numBins(featureIndex))
   *       and the right child stats in [numBins(featureIndex), 2 * numBins(featureIndex)).
   */
  val allStats: Array[Double] = new Array[Double](allStatsSize)
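
Editor's note: the index arithmetic above can be replayed in a Scala REPL with a small self-contained sketch (not part of the patch; the bin counts and statsSize below are made-up values):

  // Toy configuration: 3 features, of which feature 1 is unordered.
  val toyNumBins = Array(4, 3, 5)
  val toyUnordered = Array(false, true, false)
  val toyStatsSize = 2  // e.g., one count per class for a 2-class Gini aggregator

  // Same scanLeft construction as featureOffsets above; an unordered feature
  // reserves 2 * numBins slots (left child stats, then right child stats).
  val toyOffsets = Range(0, toyNumBins.length).scanLeft(0) { (total, f) =>
    total + (if (toyUnordered(f)) 2 * toyNumBins(f) else toyNumBins(f))
  }.map(toyStatsSize * _).toArray
  assert(toyOffsets.sameElements(Array(0, 8, 20, 30)))

  val toyNodeStride = toyOffsets.last  // 30 doubles per node
  // Stats for (node = 1, feature = 2, bin = 3) start at 1 * 30 + 20 + 3 * 2 = 56.
  assert(1 * toyNodeStride + toyOffsets(2) + 3 * toyStatsSize == 56)
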
  /**
   * Get an [[ImpurityCalculator]] for a given (node, feature, bin).
   * @param nodeFeatureOffset For ordered features, this is a pre-computed (node, feature) offset
   *                          from [[getNodeFeatureOffset]].
   *                          For unordered features, this is a pre-computed
   *                          (node, feature, left/right child) offset from
   *                          [[getLeftRightNodeFeatureOffsets]].
   */
  def getImpurityCalculator(nodeFeatureOffset: Int, binIndex: Int): ImpurityCalculator = {
    impurityAggregator.getCalculator(allStats, nodeFeatureOffset + binIndex * statsSize)
  }

  /**
   * Update the stats for a given (node, feature, bin) for ordered features, using the given label.
   */
  def update(nodeIndex: Int, featureIndex: Int, binIndex: Int, label: Double): Unit = {
    val i = nodeIndex * nodeStride + featureOffsets(featureIndex) + binIndex * statsSize
    impurityAggregator.update(allStats, i, label)
  }

  /**
   * Pre-compute node offset for use with [[nodeUpdate]].
   */
  def getNodeOffset(nodeIndex: Int): Int = nodeIndex * nodeStride

  /**
   * Faster version of [[update]].
   * Update the stats for a given (node, feature, bin) for ordered features, using the given label.
   * @param nodeOffset Pre-computed node offset from [[getNodeOffset]].
   */
  def nodeUpdate(nodeOffset: Int, featureIndex: Int, binIndex: Int, label: Double): Unit = {
    val i = nodeOffset + featureOffsets(featureIndex) + binIndex * statsSize
    impurityAggregator.update(allStats, i, label)
  }

  /**
   * Pre-compute (node, feature) offset for use with [[nodeFeatureUpdate]].
   * For ordered features only.
   */
  def getNodeFeatureOffset(nodeIndex: Int, featureIndex: Int): Int = {
    require(!isUnordered(featureIndex),
      s"DTStatsAggregator.getNodeFeatureOffset is for ordered features only, but was called" +
        s" for unordered feature $featureIndex.")
    nodeIndex * nodeStride + featureOffsets(featureIndex)
  }

Review comment: Should we enable these requires only in debug mode?

  /**
   * Pre-compute (node, feature) offset for use with [[nodeFeatureUpdate]].
   * For unordered features only.
   */
  def getLeftRightNodeFeatureOffsets(nodeIndex: Int, featureIndex: Int): (Int, Int) = {
    require(isUnordered(featureIndex),
      s"DTStatsAggregator.getLeftRightNodeFeatureOffsets is for unordered features only," +
        s" but was called for ordered feature $featureIndex.")
    val baseOffset = nodeIndex * nodeStride + featureOffsets(featureIndex)
    (baseOffset, baseOffset + numBins(featureIndex) * statsSize)
  }

Review comment: Ditto.

  /**
   * Faster version of [[update]].
   * Update the stats for a given (node, feature, bin), using the given label.
   * @param nodeFeatureOffset For ordered features, this is a pre-computed (node, feature) offset
   *                          from [[getNodeFeatureOffset]].
   *                          For unordered features, this is a pre-computed
   *                          (node, feature, left/right child) offset from
   *                          [[getLeftRightNodeFeatureOffsets]].
   */
  def nodeFeatureUpdate(nodeFeatureOffset: Int, binIndex: Int, label: Double): Unit = {
    impurityAggregator.update(allStats, nodeFeatureOffset + binIndex * statsSize, label)
  }
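
Editor's note: the pre-computed-offset variants exist to hoist the index arithmetic out of hot loops. A hypothetical calling pattern, with agg, points, and binOf as placeholder names (not patch code):

  // Instead of agg.update(nodeIndex, featureIndex, binOf(p), p.label) per point,
  // compute the (node, feature) offset once and reuse it across the loop:
  //   val offset = agg.getNodeFeatureOffset(nodeIndex, featureIndex)
  //   points.foreach(p => agg.nodeFeatureUpdate(offset, binOf(p), p.label))
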
  /**
   * For a given (node, feature), merge the stats for two bins.
   * @param nodeFeatureOffset For ordered features, this is a pre-computed (node, feature) offset
   *                          from [[getNodeFeatureOffset]].
   *                          For unordered features, this is a pre-computed
   *                          (node, feature, left/right child) offset from
   *                          [[getLeftRightNodeFeatureOffsets]].
   * @param binIndex The other bin is merged into this bin.
   * @param otherBinIndex This bin is not modified.
   */
  def mergeForNodeFeature(nodeFeatureOffset: Int, binIndex: Int, otherBinIndex: Int): Unit = {
    impurityAggregator.merge(allStats, nodeFeatureOffset + binIndex * statsSize,
      nodeFeatureOffset + otherBinIndex * statsSize)
  }

  /**
   * Merge this aggregator with another, and return this aggregator.
   * This method modifies this aggregator in-place.
   */
  def merge(other: DTStatsAggregator): DTStatsAggregator = {
    require(allStatsSize == other.allStatsSize,
      s"DTStatsAggregator.merge requires that both aggregators have the same length stats vectors."
        + s" This aggregator is of length $allStatsSize, but the other is ${other.allStatsSize}.")
    var i = 0
    // TODO: Test BLAS.axpy
    while (i < allStatsSize) {
      allStats(i) += other.allStats(i)
      i += 1
    }
    this
  }

Review comment: We should also possibly create a JIRA for handling precision loss while using Double for large aggregates during variance calculation. This was an observation from @mengxr during the first DT PR. I tried to incorporate it then, but it was hard with that code structure. This abstraction might be a good place to incorporate @mengxr's suggestion. Again, not urgent, but good for future improvement.

Reply: Good idea; I'll make a JIRA for that.

}

private[tree] object DTStatsAggregator extends Serializable {

  /**
   * Combines two aggregates (modifying the first) and returns the combination.
   */
  def binCombOp(
      agg1: DTStatsAggregator,
      agg2: DTStatsAggregator): DTStatsAggregator = {
    agg1.merge(agg2)
  }

}
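
Editor's note: since merge is in-place and returns this aggregator, binCombOp has the usual shape of a combiner for Spark aggregations: combining partition-local aggregators yields one aggregator whose allStats is the element-wise sum. The while loop computes exactly that, e.g. on plain arrays (illustration only):

  val a = Array(1.0, 2.0, 3.0)  // stats from one partition's aggregator
  val b = Array(0.5, 0.5, 0.5)  // stats from another
  var i = 0
  while (i < a.length) { a(i) += b(i); i += 1 }
  assert(a.sameElements(Array(1.5, 2.5, 3.5)))
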
@@ -26,14 +26,15 @@ import org.apache.spark.mllib.tree.configuration.Strategy
 import org.apache.spark.mllib.tree.impurity.Impurity
 import org.apache.spark.rdd.RDD

 /**
  * Learning and dataset metadata for DecisionTree.
  *
  * @param numClasses For classification: labels can take values {0, ..., numClasses - 1}.
  *                   For regression: fixed at 0 (no meaning).
  * @param maxBins Maximum number of bins, for all features.
  * @param featureArity Map: categorical feature index --> arity.
  *                     I.e., the feature takes values in {0, ..., arity - 1}.

Review comment: Javadoc does not match the method arguments.

Reply: This actually does match the parameter meaning. The parameter is set when DecisionTreeMetadata is constructed by buildMetadata, and it is set to be the actual max, not the max possible/allowed number of bins.

+ * @param numBins Number of bins for each feature.
  */
 private[tree] class DecisionTreeMetadata(
     val numFeatures: Int,
@@ -42,6 +43,7 @@ private[tree] class DecisionTreeMetadata(
     val maxBins: Int,
     val featureArity: Map[Int, Int],
     val unorderedFeatures: Set[Int],
+    val numBins: Array[Int],
     val impurity: Impurity,
     val quantileStrategy: QuantileStrategy) extends Serializable {
@@ -57,10 +59,26 @@

   def isContinuous(featureIndex: Int): Boolean = !featureArity.contains(featureIndex)

+  /**
+   * Number of splits for the given feature.
+   * For unordered features, there are 2 bins per split.
+   * For ordered features, there is 1 more bin than split.
+   */
+  def numSplits(featureIndex: Int): Int = if (isUnordered(featureIndex)) {
+    numBins(featureIndex) >> 1
+  } else {
+    numBins(featureIndex) - 1
+  }
+
 }

 private[tree] object DecisionTreeMetadata {

+  /**
+   * Construct a [[DecisionTreeMetadata]] instance for this dataset and parameters.
+   * This computes which categorical features will be ordered vs. unordered,
+   * as well as the number of splits and bins for each feature.
+   */
   def buildMetadata(input: RDD[LabeledPoint], strategy: Strategy): DecisionTreeMetadata = {

     val numFeatures = input.take(1)(0).features.size
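
Editor's note: the splits/bins convention added here can be sanity-checked in isolation (sketch with illustrative bin counts, not patch code):

  // Mirrors DecisionTreeMetadata.numSplits:
  def numSplitsFor(bins: Int, unordered: Boolean): Int =
    if (unordered) bins >> 1 else bins - 1

  assert(numSplitsFor(32, unordered = false) == 31)  // ordered: 1 more bin than splits
  assert(numSplitsFor(14, unordered = true) == 7)    // unordered: 2 bins per split
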
@@ -70,32 +88,55 @@
       case Regression => 0
     }

-    val maxBins = math.min(strategy.maxBins, numExamples).toInt
-    val log2MaxBinsp1 = math.log(maxBins + 1) / math.log(2.0)
+    val maxPossibleBins = math.min(strategy.maxBins, numExamples).toInt
+
+    // We check the number of bins here against maxPossibleBins.
+    // This needs to be checked here instead of in Strategy since maxPossibleBins can be modified
+    // based on the number of training examples.
+    if (strategy.categoricalFeaturesInfo.nonEmpty) {
+      val maxCategoriesPerFeature = strategy.categoricalFeaturesInfo.values.max
+      require(maxCategoriesPerFeature <= maxPossibleBins,
+        s"DecisionTree requires maxBins (= $maxPossibleBins) >= max categories " +
+          s"in categorical features (= $maxCategoriesPerFeature)")
+    }

     val unorderedFeatures = new mutable.HashSet[Int]()
+    val numBins = Array.fill[Int](numFeatures)(maxPossibleBins)
     if (numClasses > 2) {

Review comment: val isMulticlassClassification = numClasses > 2 might be more readable.

-      strategy.categoricalFeaturesInfo.foreach { case (f, k) =>
-        if (k - 1 < log2MaxBinsp1) {
-          // Note: The above check is equivalent to checking:
-          //       numUnorderedBins = (1 << k - 1) - 1 < maxBins
-          unorderedFeatures.add(f)
+      // Multiclass classification
+      val maxCategoriesForUnorderedFeature =
+        ((math.log(maxPossibleBins / 2 + 1) / math.log(2.0)) + 1).floor.toInt
+      strategy.categoricalFeaturesInfo.foreach { case (featureIndex, numCategories) =>
+        // Decide if some categorical features should be treated as unordered features,
+        // which require 2 * ((1 << numCategories - 1) - 1) bins.
+        // We do this check with log values to prevent overflows in case numCategories is large.
+        // The next check is equivalent to: 2 * ((1 << numCategories - 1) - 1) <= maxBins
+        if (numCategories <= maxCategoriesForUnorderedFeature) {
+          unorderedFeatures.add(featureIndex)
+          numBins(featureIndex) = numUnorderedBins(numCategories)
         } else {
-          // TODO: Allow this case, where we simply will know nothing about some categories?
-          require(k < maxBins, s"maxBins (= $maxBins) should be greater than max categories " +
-            s"in categorical features (>= $k)")
+          numBins(featureIndex) = numCategories
         }
       }
     } else {
-      strategy.categoricalFeaturesInfo.foreach { case (f, k) =>
-        require(k < maxBins, s"maxBins (= $maxBins) should be greater than max categories " +
-          s"in categorical features (>= $k)")
+      // Binary classification or regression
+      strategy.categoricalFeaturesInfo.foreach { case (featureIndex, numCategories) =>
+        numBins(featureIndex) = numCategories
       }
     }

-    new DecisionTreeMetadata(numFeatures, numExamples, numClasses, maxBins,
-      strategy.categoricalFeaturesInfo, unorderedFeatures.toSet,
+    new DecisionTreeMetadata(numFeatures, numExamples, numClasses, numBins.max,
+      strategy.categoricalFeaturesInfo, unorderedFeatures.toSet, numBins,
       strategy.impurity, strategy.quantileCalculationStrategy)
   }

+  /**
+   * Given the arity of a categorical feature (arity = number of categories),
+   * return the number of bins for the feature if it is to be treated as an unordered feature.
+   * There is 1 split for every partitioning of categories into 2 disjoint, non-empty sets;
+   * there are math.pow(2, arity - 1) - 1 such splits.
+   * Each split has 2 corresponding bins.
+   */
+  def numUnorderedBins(arity: Int): Int = 2 * ((1 << arity - 1) - 1)
+
 }
Review comment: Would the factor of 2 for unordered categorical features be more suitable in the numBins calculation in the DecisionTreeMetadata class?
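
Editor's note: a worked example of the unordered-feature cutoff used in buildMetadata (sketch; maxPossibleBins = 32 is an illustrative value, not one from the patch):

  val maxPossibleBins = 32
  // floor(log2(32 / 2 + 1) + 1) = floor(log2(17) + 1) = floor(5.09) = 5
  val maxCategoriesForUnorderedFeature =
    ((math.log(maxPossibleBins / 2 + 1) / math.log(2.0)) + 1).floor.toInt
  assert(maxCategoriesForUnorderedFeature == 5)

  def numUnorderedBins(arity: Int): Int = 2 * ((1 << arity - 1) - 1)
  assert(numUnorderedBins(5) == 30)  // 30 <= 32 bins: 5 categories can be unordered
  assert(numUnorderedBins(6) == 62)  // 62 > 32 bins: 6 categories stay ordered

This is where the factor of 2 the reviewer asks about enters: each unordered split stores both left-child and right-child bin stats.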