
Commit 26d10dd

Removed tree/model/Filter.scala since no longer used. Removed debugging println calls in DecisionTree.scala.
1 parent 356daba commit 26d10dd

2 files changed: +6 -61 lines

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 6 additions & 33 deletions
@@ -62,7 +62,6 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
 
     timer.start("total")
 
-    // Cache input RDD for speedup during multiple passes.
     timer.start("init")
     val retaggedInput = input.retag(classOf[LabeledPoint])
     logDebug("algo = " + strategy.algo)
@@ -77,17 +76,15 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
     logDebug("numBins = " + numBins)
 
     timer.start("init")
+    // Bin feature values (TreePoint representation).
+    // Cache input RDD for speedup during multiple passes.
     val treeInput = TreePoint.convertToTreeRDD(retaggedInput, strategy, bins).cache()
     timer.stop("init")
 
     // depth of the decision tree
     val maxDepth = strategy.maxDepth
     // the max number of nodes possible given the depth of the tree
     val maxNumNodes = math.pow(2, maxDepth + 1).toInt - 1
-    // Initialize an array to hold filters applied to points for each node.
-    //val filters = new Array[List[Filter]](maxNumNodes)
-    // The filter at the top node is an empty list.
-    //filters(0) = List()
     // Initialize an array to hold parent impurity calculations for each node.
     val parentImpurities = new Array[Double](maxNumNodes)
     // dummy value for top node (updated during first split calculation)
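Aside: the maxNumNodes bound in the context above is just the size of a complete binary tree of depth maxDepth. A quick sanity check (maxDepth = 3 is an arbitrary example value, not from the patch):

    // A complete binary tree of depth d has 2^(d+1) - 1 nodes.
    val maxDepth = 3
    val maxNumNodes = math.pow(2, maxDepth + 1).toInt - 1  // 1 + 2 + 4 + 8 = 15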
@@ -118,9 +115,8 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
    /*
     * The main idea here is to perform level-wise training of the decision tree nodes thus
     * reducing the passes over the data from l to log2(l) where l is the total number of nodes.
-     * Each data sample is checked for validity w.r.t to each node at a given level -- i.e.,
-     * the sample is only used for the split calculation at the node if the sampled would have
-     * still survived the filters of the parent nodes.
+     * Each data sample is handled by a particular node at that level (or it reaches a leaf
+     * beforehand and is not used in later levels).
     */
 
    var level = 0
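Note: the rewritten comment describes the level-wise scheme in which one pass over the cached data serves all 2^level nodes at a level. A minimal sketch of that control flow, where findBestSplitsForLevel is a hypothetical stand-in for the real findBestSplits machinery:

    // Level-wise training: one data pass per level rather than one per node,
    // cutting passes from l to about log2(l) for l total nodes.
    // findBestSplitsForLevel is a hypothetical stand-in, not the actual API.
    var level = 0
    var doneTraining = false
    while (level < maxDepth && !doneTraining) {
      // One pass over the cached TreePoint RDD handles all 2^level nodes.
      val splitsStatsForLevel = findBestSplitsForLevel(treeInput, level)
      // Stop early if every node at this level is a leaf (no positive gain).
      doneTraining = splitsStatsForLevel.forall(_._2.gain <= 0)
      level += 1
    }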
@@ -169,7 +165,6 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
      }
      require(math.pow(2, level) == splitsStatsForLevel.length)
      // Check whether all the nodes at the current level at leaves.
-      println(s"LOOP over levels: level=$level, splitStats...gains: ${splitsStatsForLevel.map(_._2.gain).mkString(",")}")
      val allLeaf = splitsStatsForLevel.forall(_._2.gain <= 0)
      logDebug("all leaf = " + allLeaf)
      if (allLeaf) {
@@ -237,8 +232,6 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
        logDebug("nodeIndex = " + nodeIndex + ", impurity = " + impurity)
        // noting the parent impurities
        parentImpurities(nodeIndex) = impurity
-        // noting the parents filters for the child nodes
-        val childFilter = new Filter(nodeSplitStats._1, if (i == 0) -1 else 1)
        i += 1
      }
    }
@@ -461,7 +454,6 @@ object DecisionTree extends Serializable with Logging {
   * @param maxLevelForSingleGroup the deepest level for single-group level-wise computation.
   * @param unorderedFeatures Set of unordered (categorical) features.
   * @return array (over nodes) of splits with best split for each node at a given level.
-   * TODO: UPDATE DOC
   */
  protected[tree] def findBestSplits(
      input: RDD[TreePoint],
@@ -512,7 +504,6 @@ object DecisionTree extends Serializable with Logging {
   * @param numGroups total number of node groups at the current level. Default value is set to 1.
   * @param groupIndex index of the node group being processed. Default value is set to 0.
   * @return array of splits with best splits for all nodes at a given level.
-   * TODO: UPDATE DOC
   */
  private def findBestSplitsPerGroup(
      input: RDD[TreePoint],
@@ -539,7 +530,7 @@ object DecisionTree extends Serializable with Logging {
   * We use a bin-wise best split computation strategy instead of a straightforward best split
   * computation strategy. Instead of analyzing each sample for contribution to the left/right
   * child node impurity of every split, we first categorize each feature of a sample into a
-   * bin. Each bin is an interval between a low and high split. Since each splits, and thus bin,
+   * bin. Each bin is an interval between a low and high split. Since each split, and thus bin,
   * is ordered (read ordering for categorical variables in the findSplitsBins method),
   * we exploit this structure to calculate aggregates for bins and then use these aggregates
   * to calculate information gain for each split.
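Aside: the bin-wise strategy described above replaces per-split scans of the data with a single per-bin aggregation pass; left-child statistics for every candidate split then fall out of a cumulative sum over bins. A sketch under that reading (names are illustrative, not the actual API):

    // For one ordered feature: split i sends bins 0..i to the left child, so a
    // single cumulative sum over per-bin counts yields left-child counts for
    // all numBins - 1 candidate splits without re-scanning the data.
    def leftCountsPerSplit(binCounts: Array[Long]): Array[Long] =
      binCounts.scanLeft(0L)(_ + _).slice(1, binCounts.length)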
@@ -660,7 +651,6 @@ object DecisionTree extends Serializable with Logging {
   *            numClasses * numBins * numFeatures * numNodes.
   *            Indexed by (node, feature, bin, label) where label is the least significant bit.
   * @param nodeIndex  Node corresponding to treePoint. Indexed from 0 at start of (level, group).
-   * TODO: UPDATE DOC
   */
  def updateBinForOrderedFeature(
      treePoint: TreePoint,
@@ -681,21 +671,19 @@ object DecisionTree extends Serializable with Logging {
   *            where [bins] ranges over all bins.
   *            Updates left or right side of aggregate depending on split.
   *
+   * @param nodeIndex  Node corresponding to treePoint. Indexed from 0 at start of (level, group).
   * @param treePoint  Data point being aggregated.
   * @param agg  Indexed by (left/right, node, feature, bin, label)
   *             where label is the least significant bit.
   *             The left/right specifier is a 0/1 index indicating left/right child info.
   * @param rightChildShift Offset for right side of agg.
-   * TODO: UPDATE DOC
-   * TODO: Make arg order same as for ordered feature.
   */
  def updateBinForUnorderedFeature(
      nodeIndex: Int,
      featureIndex: Int,
      treePoint: TreePoint,
      agg: Array[Double],
      rightChildShift: Int): Unit = {
-    //println(s"-- updateBinForUnorderedFeature node:$nodeIndex, feature:$featureIndex, label:$label.")
    val featureValue = treePoint.features(featureIndex)
    // Update the left or right count for one bin.
    val aggShift =
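Aside: given the (node, feature, bin, label) layout documented above, with label as the least significant bit, the flat offset into agg would be computed roughly as follows. This is a sketch of the documented layout, not the patch's code:

    // Offset into the flat agg array for the documented layout; the right-child
    // half of the unordered-feature aggregate sits rightChildShift further on,
    // i.e. agg(rightChildShift + offset).
    def aggOffset(numClasses: Int, numBins: Int, numFeatures: Int,
        nodeIndex: Int, featureIndex: Int, binIndex: Int, label: Int): Int = {
      numClasses * numBins * numFeatures * nodeIndex +
        numClasses * numBins * featureIndex +
        numClasses * binIndex +
        label
    }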
@@ -780,7 +768,6 @@ object DecisionTree extends Serializable with Logging {
   * @return agg
   */
  def regressionBinSeqOp(agg: Array[Double], treePoint: TreePoint, nodeIndex: Int): Unit = {
-    // TODO: Move stuff outside loop.
    val label = treePoint.label
    // Iterate over all features.
    var featureIndex = 0
@@ -791,9 +778,6 @@ object DecisionTree extends Serializable with Logging {
        3 * numBins * numFeatures * nodeIndex +
        3 * numBins * featureIndex +
        3 * binIndex
-      if (aggIndex >= agg.size) {
-        println(s"aggIndex = $aggIndex, agg.size = ${agg.size}. binIndex = $binIndex, featureIndex = $featureIndex, nodeIndex = $nodeIndex, numBins = $numBins, numFeatures = $numFeatures")
-      }
      agg(aggIndex) = agg(aggIndex) + 1
      agg(aggIndex + 1) = agg(aggIndex + 1) + label
      agg(aggIndex + 2) = agg(aggIndex + 2) + label * label
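Note: the three slots updated per bin above (count, sum of labels, sum of squared labels) are exactly the sufficient statistics for the variance impurity used in regression:

    // variance = E[y^2] - (E[y])^2, recovered from the per-bin triple.
    def varianceFromAgg(count: Double, sum: Double, sumSq: Double): Double =
      if (count == 0) 0.0 else sumSq / count - (sum / count) * (sum / count)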
@@ -1025,7 +1009,6 @@ object DecisionTree extends Serializable with Logging {
   *          Element i (i = 1, ..., numSplits - 1) is set to be
   *          the cumulative sum (from right) over binData for bins
   *          numBins - 1, ..., numBins - 1 - i.
-   * TODO: We could avoid doing one of these cumulative sums.
   */
  def findAggForOrderedFeatureClassification(
      leftNodeAgg: Array[Array[Array[Double]]],
@@ -1196,16 +1179,6 @@ object DecisionTree extends Serializable with Logging {
        } else {
          featureCategories
        }
-        /*
-        val isSpaceSufficientForAllCategoricalSplits =
-          numBins > math.pow(2, featureCategories.toInt - 1) - 1
-        if (isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits) {
-          math.pow(2.0, featureCategories - 1).toInt - 1
-        } else {
-          // Ordered features
-          featureCategories
-        }
-        */
      }
    }
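Aside: the deleted commented-out block duplicated the live logic above it. For a categorical feature with M categories under multiclass classification, there are 2^(M-1) - 1 distinct subset splits, since a subset and its complement define the same split and the empty subset is excluded:

    // e.g. featureCategories = 4 gives 2^3 - 1 = 7 candidate subset splits.
    def numUnorderedSplits(featureCategories: Int): Int =
      math.pow(2.0, featureCategories - 1).toInt - 1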

mllib/src/main/scala/org/apache/spark/mllib/tree/model/Filter.scala

Lines changed: 0 additions & 28 deletions
This file was deleted.
