@@ -62,7 +62,6 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
6262
6363 timer.start(" total" )
6464
65- // Cache input RDD for speedup during multiple passes.
6665 timer.start(" init" )
6766 val retaggedInput = input.retag(classOf [LabeledPoint ])
6867 logDebug(" algo = " + strategy.algo)
@@ -77,17 +76,15 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
7776 logDebug(" numBins = " + numBins)
7877
7978 timer.start(" init" )
79+ // Bin feature values (TreePoint representation).
80+ // Cache input RDD for speedup during multiple passes.
8081 val treeInput = TreePoint .convertToTreeRDD(retaggedInput, strategy, bins).cache()
8182 timer.stop(" init" )
8283
8384 // depth of the decision tree
8485 val maxDepth = strategy.maxDepth
8586 // the max number of nodes possible given the depth of the tree
8687 val maxNumNodes = math.pow(2 , maxDepth + 1 ).toInt - 1
87- // Initialize an array to hold filters applied to points for each node.
88- // val filters = new Array[List[Filter]](maxNumNodes)
89- // The filter at the top node is an empty list.
90- // filters(0) = List()
9188 // Initialize an array to hold parent impurity calculations for each node.
9289 val parentImpurities = new Array [Double ](maxNumNodes)
9390 // dummy value for top node (updated during first split calculation)
@@ -118,9 +115,8 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
118115 /*
119116 * The main idea here is to perform level-wise training of the decision tree nodes thus
120117 * reducing the passes over the data from l to log2(l) where l is the total number of nodes.
121- * Each data sample is checked for validity w.r.t to each node at a given level -- i.e.,
122- * the sample is only used for the split calculation at the node if the sampled would have
123- * still survived the filters of the parent nodes.
118+ * Each data sample is handled by a particular node at that level (or it reaches a leaf
119+ * beforehand and is not used in later levels).
124120 */
125121
126122 var level = 0
@@ -169,7 +165,6 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
169165 }
170166 require(math.pow(2 , level) == splitsStatsForLevel.length)
171167 // Check whether all the nodes at the current level are leaves.
172- println(s " LOOP over levels: level= $level, splitStats...gains: ${splitsStatsForLevel.map(_._2.gain).mkString(" ," )}" )
173168 val allLeaf = splitsStatsForLevel.forall(_._2.gain <= 0 )
174169 logDebug(" all leaf = " + allLeaf)
175170 if (allLeaf) {
@@ -237,8 +232,6 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
237232 logDebug(" nodeIndex = " + nodeIndex + " , impurity = " + impurity)
238233 // noting the parent impurities
239234 parentImpurities(nodeIndex) = impurity
240- // noting the parents filters for the child nodes
241- val childFilter = new Filter (nodeSplitStats._1, if (i == 0 ) - 1 else 1 )
242235 i += 1
243236 }
244237 }
@@ -461,7 +454,6 @@ object DecisionTree extends Serializable with Logging {
461454 * @param maxLevelForSingleGroup the deepest level for single-group level-wise computation.
462455 * @param unorderedFeatures Set of unordered (categorical) features.
463456 * @return array (over nodes) of splits with best split for each node at a given level.
464- * TODO: UPDATE DOC
465457 */
466458 protected [tree] def findBestSplits (
467459 input : RDD [TreePoint ],
@@ -512,7 +504,6 @@ object DecisionTree extends Serializable with Logging {
512504 * @param numGroups total number of node groups at the current level. Default value is set to 1.
513505 * @param groupIndex index of the node group being processed. Default value is set to 0.
514506 * @return array of splits with best splits for all nodes at a given level.
515- * TODO: UPDATE DOC
516507 */
517508 private def findBestSplitsPerGroup (
518509 input : RDD [TreePoint ],
@@ -539,7 +530,7 @@ object DecisionTree extends Serializable with Logging {
539530 * We use a bin-wise best split computation strategy instead of a straightforward best split
540531 * computation strategy. Instead of analyzing each sample for contribution to the left/right
541532 * child node impurity of every split, we first categorize each feature of a sample into a
542- * bin. Each bin is an interval between a low and high split. Since each splits , and thus bin,
533+ * bin. Each bin is an interval between a low and high split. Since each split , and thus bin,
543534 * is ordered (read ordering for categorical variables in the findSplitsBins method),
544535 * we exploit this structure to calculate aggregates for bins and then use these aggregates
545536 * to calculate information gain for each split.
@@ -660,7 +651,6 @@ object DecisionTree extends Serializable with Logging {
660651 * numClasses * numBins * numFeatures * numNodes.
661652 * Indexed by (node, feature, bin, label) where label is the least significant bit.
662653 * @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group).
663- * TODO: UPDATE DOC
664654 */
665655 def updateBinForOrderedFeature (
666656 treePoint : TreePoint ,
@@ -681,21 +671,19 @@ object DecisionTree extends Serializable with Logging {
681671 * where [bins] ranges over all bins.
682672 * Updates left or right side of aggregate depending on split.
683673 *
674+ * @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group).
684675 * @param treePoint Data point being aggregated.
685676 * @param agg Indexed by (left/right, node, feature, bin, label)
686677 * where label is the least significant bit.
687678 * The left/right specifier is a 0/1 index indicating left/right child info.
688679 * @param rightChildShift Offset for right side of agg.
689- * TODO: UPDATE DOC
690- * TODO: Make arg order same as for ordered feature.
691680 */
692681 def updateBinForUnorderedFeature (
693682 nodeIndex : Int ,
694683 featureIndex : Int ,
695684 treePoint : TreePoint ,
696685 agg : Array [Double ],
697686 rightChildShift : Int ): Unit = {
698- // println(s"-- updateBinForUnorderedFeature node:$nodeIndex, feature:$featureIndex, label:$label.")
699687 val featureValue = treePoint.features(featureIndex)
700688 // Update the left or right count for one bin.
701689 val aggShift =
@@ -780,7 +768,6 @@ object DecisionTree extends Serializable with Logging {
780768 * @return agg
781769 */
782770 def regressionBinSeqOp (agg : Array [Double ], treePoint : TreePoint , nodeIndex : Int ): Unit = {
783- // TODO: Move stuff outside loop.
784771 val label = treePoint.label
785772 // Iterate over all features.
786773 var featureIndex = 0
@@ -791,9 +778,6 @@ object DecisionTree extends Serializable with Logging {
791778 3 * numBins * numFeatures * nodeIndex +
792779 3 * numBins * featureIndex +
793780 3 * binIndex
794- if (aggIndex >= agg.size) {
795- println(s " aggIndex = $aggIndex, agg.size = ${agg.size}. binIndex = $binIndex, featureIndex = $featureIndex, nodeIndex = $nodeIndex, numBins = $numBins, numFeatures = $numFeatures" )
796- }
797781 agg(aggIndex) = agg(aggIndex) + 1
798782 agg(aggIndex + 1 ) = agg(aggIndex + 1 ) + label
799783 agg(aggIndex + 2 ) = agg(aggIndex + 2 ) + label * label
@@ -1025,7 +1009,6 @@ object DecisionTree extends Serializable with Logging {
10251009 * Element i (i = 1, ..., numSplits - 1) is set to be
10261010 * the cumulative sum (from right) over binData for bins
10271011 * numBins - 1, ..., numBins - 1 - i.
1028- * TODO: We could avoid doing one of these cumulative sums.
10291012 */
10301013 def findAggForOrderedFeatureClassification (
10311014 leftNodeAgg : Array [Array [Array [Double ]]],
@@ -1196,16 +1179,6 @@ object DecisionTree extends Serializable with Logging {
11961179 } else {
11971180 featureCategories
11981181 }
1199- /*
1200- val isSpaceSufficientForAllCategoricalSplits =
1201- numBins > math.pow(2, featureCategories.toInt - 1) - 1
1202- if (isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits) {
1203- math.pow(2.0, featureCategories - 1).toInt - 1
1204- } else {
1205- // Ordered features
1206- featureCategories
1207- }
1208- */
12091182 }
12101183 }
12111184
0 commit comments