@@ -26,7 +26,7 @@ import org.apache.spark.mllib.tree.model.Split
 import org.apache.spark.mllib.tree.impurity.Gini
 
 
-class DecisionTree(val strategy: Strategy) {
+class DecisionTree(val strategy: Strategy) extends Logging {
 
   def train(input: RDD[LabeledPoint]): DecisionTreeModel = {
 
@@ -42,20 +42,43 @@ class DecisionTree(val strategy: Strategy) {
 
     val maxNumNodes = scala.math.pow(2, maxDepth).toInt - 1
     val filters = new Array[List[Filter]](maxNumNodes)
+    filters(0) = List()
+    val parentImpurities = new Array[Double](maxNumNodes)
+    // Dummy value for top node (calculate from scratch during first split calculation)
+    parentImpurities(0) = Double.MinValue
 
     for (level <- 0 until maxDepth) {
+
+      println("#####################################")
+      println("level = " + level)
+      println("#####################################")
+
       // Find best split for all nodes at a level
       val numNodes = scala.math.pow(2, level).toInt
-      // TODO: Change the input parent impurities values
-      val splits_stats_for_level = DecisionTree.findBestSplits(input, Array(2.0), strategy, level, filters, splits, bins)
-      for (tmp <- splits_stats_for_level) {
-        println("final best split = " + tmp._1)
+      val splitsStatsForLevel = DecisionTree.findBestSplits(input, parentImpurities, strategy, level, filters, splits, bins)
+      for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) {
+        for (i <- 0 to 1) {
+          val nodeIndex = (scala.math.pow(2, level + 1)).toInt - 1 + 2 * index + i
+          if (level < maxDepth - 1) {
+            val impurity = if (i == 0) nodeSplitStats._2.leftImpurity else nodeSplitStats._2.rightImpurity
+            println("nodeIndex = " + nodeIndex + ", impurity = " + impurity)
+            parentImpurities(nodeIndex) = impurity
+            println("updating nodeIndex = " + nodeIndex)
+            filters(nodeIndex) = new Filter(nodeSplitStats._1, if (i == 0) -1 else 1) :: filters((nodeIndex - 1) / 2)
+            for (filter <- filters(nodeIndex)) {
+              println(filter)
+            }
+          }
+        }
+        println("final best split = " + nodeSplitStats._1)
       }
-      // TODO: update filters and decision tree model
-      require(scala.math.pow(2, level) == splits_stats_for_level.length)
+      require(scala.math.pow(2, level) == splitsStatsForLevel.length)
+
 
     }
 
+    // TODO: Extract decision tree model
+
     return new DecisionTreeModel()
   }
 
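The hunk above stores the tree as a flat array in level order: nodes at depth level occupy indices 2^level - 1 through 2^(level+1) - 2, the two children of the index-th node at that depth sit at 2^(level+1) - 1 + 2*index + i (i = 0 for left, i = 1 for right), and a node's parent is at (nodeIndex - 1) / 2. A minimal standalone sketch of that arithmetic (object and method names are made up for illustration and are not part of the patch):

object LevelOrderIndexing {
  // Child index of the index-th node at a given depth; i == 0 -> left child, i == 1 -> right child.
  def childIndex(level: Int, index: Int, i: Int): Int =
    scala.math.pow(2, level + 1).toInt - 1 + 2 * index + i

  // Parent index of any non-root node in the level-order array.
  def parentIndex(nodeIndex: Int): Int = (nodeIndex - 1) / 2

  def main(args: Array[String]): Unit = {
    // Depth 1 holds nodes 1 and 2; their children are 3,4 and 5,6.
    assert(childIndex(1, 0, 0) == 3 && childIndex(1, 0, 1) == 4)
    assert(childIndex(1, 1, 0) == 5 && childIndex(1, 1, 1) == 6)
    assert(parentIndex(3) == 1 && parentIndex(6) == 2)
  }
}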
@@ -99,7 +122,7 @@ object DecisionTree extends Serializable {
       if (level == 0) {
         List[Filter]()
       } else {
-        val nodeFilterIndex = scala.math.pow(2, level).toInt + nodeIndex
+        val nodeFilterIndex = scala.math.pow(2, level).toInt - 1 + nodeIndex
         // val parentFilterIndex = nodeFilterIndex / 2
         // TODO: Check left or right filter
         filters(nodeFilterIndex)
@@ -155,11 +178,11 @@ object DecisionTree extends Serializable {
       // calculating bin index and label per feature per node
       val arr = new Array[Double](1 + (numFeatures * numNodes))
       arr(0) = labeledPoint.label
-      for (nodeIndex <- 0 until numNodes) {
-        val parentFilters = findParentFilters(nodeIndex)
+      for (index <- 0 until numNodes) {
+        val parentFilters = findParentFilters(index)
         // Find out whether the sample qualifies for the particular node
         val sampleValid = isSampleValid(parentFilters, labeledPoint)
-        val shift = 1 + numFeatures * nodeIndex
+        val shift = 1 + numFeatures * index
         if (!sampleValid) {
           // Add to invalid bin index -1
           for (featureIndex <- 0 until numFeatures) {
@@ -251,22 +274,26 @@ object DecisionTree extends Serializable {
       val right1Count = rightNodeAgg(featureIndex)(2 * index + 1)
       val rightCount = right0Count + right1Count
 
+      val impurity = if (level > 0) topImpurity else strategy.impurity.calculate(left0Count + right0Count, left1Count + right1Count)
+
       if (leftCount == 0) return new InformationGainStats(0, topImpurity, Double.MinValue, 0, topImpurity, rightCount.toLong)
       if (rightCount == 0) return new InformationGainStats(0, topImpurity, topImpurity, leftCount.toLong, Double.MinValue, 0)
 
-      // println("left0count = " + left0Count + ", left1count = " + left1Count + ", leftCount = " + leftCount)
       val leftImpurity = strategy.impurity.calculate(left0Count, left1Count)
-
-
-      // println("right0count = " + right0Count + ", right1count = " + right1Count + ", rightCount = " + rightCount)
       val rightImpurity = strategy.impurity.calculate(right0Count, right1Count)
 
       val leftWeight = leftCount.toDouble / (leftCount + rightCount)
       val rightWeight = rightCount.toDouble / (leftCount + rightCount)
 
-      val gain = topImpurity - leftWeight * leftImpurity - rightWeight * rightImpurity
+      val gain = {
+        if (level > 0) {
+          impurity - leftWeight * leftImpurity - rightWeight * rightImpurity
+        } else {
+          impurity - leftWeight * leftImpurity - rightWeight * rightImpurity
+        }
+      }
 
-      new InformationGainStats(gain, topImpurity, leftImpurity, leftCount.toLong, rightImpurity, rightCount.toLong)
+      new InformationGainStats(gain, impurity, leftImpurity, leftCount.toLong, rightImpurity, rightCount.toLong)
 
     }
 
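The rewritten gain computation above is the standard weighted-impurity information gain: gain = impurity(parent) - (n_left / n) * impurity(left) - (n_right / n) * impurity(right), with the parent impurity taken from the passed-in topImpurity for level > 0 and recomputed from the aggregated counts at the root. A small self-contained sketch with binary Gini impurity (an assumption for illustration; the patch itself goes through strategy.impurity.calculate):

object GainSketch {
  // Gini impurity of a two-class node: 1 - p0^2 - p1^2.
  def gini(c0: Double, c1: Double): Double = {
    val total = c0 + c1
    if (total == 0) 0.0 else 1.0 - math.pow(c0 / total, 2) - math.pow(c1 / total, 2)
  }

  def main(args: Array[String]): Unit = {
    val (left0, left1, right0, right1) = (30.0, 10.0, 10.0, 30.0)
    val (leftCount, rightCount) = (left0 + left1, right0 + right1)
    val parentImpurity = gini(left0 + right0, left1 + right1)   // 0.5 for a balanced 40/40 parent
    val leftWeight = leftCount / (leftCount + rightCount)
    val rightWeight = rightCount / (leftCount + rightCount)
    val gain = parentImpurity - leftWeight * gini(left0, left1) - rightWeight * gini(right0, right1)
    println(gain)                                               // 0.5 - 0.375 = 0.125
  }
}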
@@ -339,7 +366,7 @@ object DecisionTree extends Serializable {
       var bestFeatureIndex = 0
       var bestSplitIndex = 0
       // Initialization with infeasible values
-      var bestGainStats = new InformationGainStats(-1.0, -1.0, -1.0, 0, -1.0, 0)
+      var bestGainStats = new InformationGainStats(Double.MinValue, -1.0, -1.0, 0, -1.0, 0)
 //      var maxGain = Double.MinValue
 //      var leftSamples = Long.MinValue
 //      var rightSamples = Long.MinValue
@@ -351,8 +378,8 @@ object DecisionTree extends Serializable {
           bestGainStats = gainStats
           bestFeatureIndex = featureIndex
           bestSplitIndex = splitIndex
-          println("bestFeatureIndex = " + bestFeatureIndex + ", bestSplitIndex = " + bestSplitIndex
-            + ", gain stats = " + bestGainStats)
+          // println("bestFeatureIndex = " + bestFeatureIndex + ", bestSplitIndex = " + bestSplitIndex)
+          // println("gain stats = " + bestGainStats)
         }
       }
     }
@@ -365,9 +392,12 @@ object DecisionTree extends Serializable {
       // Calculate best splits for all nodes at a given level
       val bestSplits = new Array[(Split, InformationGainStats)](numNodes)
       for (node <- 0 until numNodes) {
+        val nodeImpurityIndex = scala.math.pow(2, level).toInt - 1 + node
         val shift = 2 * node * numSplits * numFeatures
         val binsForNode = binAggregates.slice(shift, shift + 2 * numSplits * numFeatures)
-        val parentNodeImpurity = parentImpurities(node / 2)
+        println("nodeImpurityIndex = " + nodeImpurityIndex)
+        val parentNodeImpurity = parentImpurities(nodeImpurityIndex)
+        println("node impurity = " + parentNodeImpurity)
         bestSplits(node) = binsToBestSplit(binsForNode, parentNodeImpurity)
       }
 
@@ -456,8 +486,9 @@ object DecisionTree extends Serializable {
 
     val sc = new SparkContext(args(0), "DecisionTree")
     val data = loadLabeledData(sc, args(1))
+    val maxDepth = args(2).toInt
 
-    val strategy = new Strategy(kind = "classification", impurity = Gini, maxDepth = 2, numSplits = 569)
+    val strategy = new Strategy(kind = "classification", impurity = Gini, maxDepth = maxDepth, numSplits = 569)
     val model = new DecisionTree(strategy).train(data)
 
     sc.stop()
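With the last hunk the driver reads maxDepth from the command line instead of hard-coding 2, so the entry point now takes three arguments: the Spark master, the labeled-data path, and maxDepth. A hypothetical invocation (the data path is made up, and this assumes the code above sits in a main method of object DecisionTree):

// args(0) = Spark master, args(1) = labeled data file, args(2) = maxDepth
DecisionTree.main(Array("local[2]", "/tmp/labeled_points.csv", "3"))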