Skip to content

Commit 80e8c66

Browse files
committed
working version of multi-level split calculation
Signed-off-by: Manish Amde <[email protected]>
1 parent 4798aae commit 80e8c66

File tree

2 files changed

+63
-28
lines changed

2 files changed

+63
-28
lines changed

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 53 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ import org.apache.spark.mllib.tree.model.Split
2626
import org.apache.spark.mllib.tree.impurity.Gini
2727

2828

29-
class DecisionTree(val strategy : Strategy) {
29+
class DecisionTree(val strategy : Strategy) extends Logging {
3030

3131
def train(input : RDD[LabeledPoint]) : DecisionTreeModel = {
3232

@@ -42,20 +42,43 @@ class DecisionTree(val strategy : Strategy) {
4242

4343
val maxNumNodes = scala.math.pow(2,maxDepth).toInt - 1
4444
val filters = new Array[List[Filter]](maxNumNodes)
45+
filters(0) = List()
46+
val parentImpurities = new Array[Double](maxNumNodes)
47+
//Dummy value for top node (calculate from scratch during first split calculation)
48+
parentImpurities(0) = Double.MinValue
4549

4650
for (level <- 0 until maxDepth){
51+
52+
println("#####################################")
53+
println("level = " + level)
54+
println("#####################################")
55+
4756
//Find best split for all nodes at a level
4857
val numNodes= scala.math.pow(2,level).toInt
49-
//TODO: Change the input parent impurities values
50-
val splits_stats_for_level = DecisionTree.findBestSplits(input, Array(2.0), strategy, level, filters,splits,bins)
51-
for (tmp <- splits_stats_for_level){
52-
println("final best split = " + tmp._1)
58+
val splitsStatsForLevel = DecisionTree.findBestSplits(input, parentImpurities, strategy, level, filters,splits,bins)
59+
for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex){
60+
for (i <- 0 to 1){
61+
val nodeIndex = (scala.math.pow(2,level+1)).toInt - 1 + 2*index + i
62+
if(level < maxDepth - 1){
63+
val impurity = if (i == 0) nodeSplitStats._2.leftImpurity else nodeSplitStats._2.rightImpurity
64+
println("nodeIndex = " + nodeIndex + ", impurity = " + impurity)
65+
parentImpurities(nodeIndex) = impurity
66+
println("updating nodeIndex = " + nodeIndex)
67+
filters(nodeIndex) = new Filter(nodeSplitStats._1, if(i == 0) - 1 else 1) :: filters((nodeIndex-1)/2)
68+
for (filter <- filters(nodeIndex)){
69+
println(filter)
70+
}
71+
}
72+
}
73+
println("final best split = " + nodeSplitStats._1)
5374
}
54-
//TODO: update filters and decision tree model
55-
require(scala.math.pow(2,level)==splits_stats_for_level.length)
75+
require(scala.math.pow(2,level)==splitsStatsForLevel.length)
76+
5677

5778
}
5879

80+
//TODO: Extract decision tree model
81+
5982
return new DecisionTreeModel()
6083
}
6184

@@ -99,7 +122,7 @@ object DecisionTree extends Serializable {
99122
if (level == 0) {
100123
List[Filter]()
101124
} else {
102-
val nodeFilterIndex = scala.math.pow(2, level).toInt + nodeIndex
125+
val nodeFilterIndex = scala.math.pow(2, level).toInt - 1 + nodeIndex
103126
//val parentFilterIndex = nodeFilterIndex / 2
104127
//TODO: Check left or right filter
105128
filters(nodeFilterIndex)
@@ -155,11 +178,11 @@ object DecisionTree extends Serializable {
155178
// calculating bin index and label per feature per node
156179
val arr = new Array[Double](1+(numFeatures * numNodes))
157180
arr(0) = labeledPoint.label
158-
for (nodeIndex <- 0 until numNodes) {
159-
val parentFilters = findParentFilters(nodeIndex)
181+
for (index <- 0 until numNodes) {
182+
val parentFilters = findParentFilters(index)
160183
//Find out whether the sample qualifies for the particular node
161184
val sampleValid = isSampleValid(parentFilters, labeledPoint)
162-
val shift = 1 + numFeatures * nodeIndex
185+
val shift = 1 + numFeatures * index
163186
if (!sampleValid) {
164187
//Add to invalid bin index -1
165188
for (featureIndex <- 0 until numFeatures) {
@@ -251,22 +274,26 @@ object DecisionTree extends Serializable {
251274
val right1Count = rightNodeAgg(featureIndex)(2 * index + 1)
252275
val rightCount = right0Count + right1Count
253276

277+
val impurity = if (level > 0) topImpurity else strategy.impurity.calculate(left0Count + right0Count, left1Count + right1Count)
278+
254279
if (leftCount == 0) return new InformationGainStats(0,topImpurity,Double.MinValue,0,topImpurity,rightCount.toLong)
255280
if (rightCount == 0) return new InformationGainStats(0,topImpurity,topImpurity,leftCount.toLong,Double.MinValue,0)
256281

257-
//println("left0count = " + left0Count + ", left1count = " + left1Count + ", leftCount = " + leftCount)
258282
val leftImpurity = strategy.impurity.calculate(left0Count, left1Count)
259-
260-
261-
//println("right0count = " + right0Count + ", right1count = " + right1Count + ", rightCount = " + rightCount)
262283
val rightImpurity = strategy.impurity.calculate(right0Count, right1Count)
263284

264285
val leftWeight = leftCount.toDouble / (leftCount + rightCount)
265286
val rightWeight = rightCount.toDouble / (leftCount + rightCount)
266287

267-
val gain = topImpurity - leftWeight * leftImpurity - rightWeight * rightImpurity
288+
val gain = {
289+
if (level > 0) {
290+
impurity - leftWeight * leftImpurity - rightWeight * rightImpurity
291+
} else {
292+
impurity - leftWeight * leftImpurity - rightWeight * rightImpurity
293+
}
294+
}
268295

269-
new InformationGainStats(gain,topImpurity,leftImpurity,leftCount.toLong,rightImpurity,rightCount.toLong)
296+
new InformationGainStats(gain,impurity,leftImpurity,leftCount.toLong,rightImpurity,rightCount.toLong)
270297

271298
}
272299

@@ -339,7 +366,7 @@ object DecisionTree extends Serializable {
339366
var bestFeatureIndex = 0
340367
var bestSplitIndex = 0
341368
//Initialization with infeasible values
342-
var bestGainStats = new InformationGainStats(-1.0,-1.0,-1.0,0,-1.0,0)
369+
var bestGainStats = new InformationGainStats(Double.MinValue,-1.0,-1.0,0,-1.0,0)
343370
// var maxGain = Double.MinValue
344371
// var leftSamples = Long.MinValue
345372
// var rightSamples = Long.MinValue
@@ -351,8 +378,8 @@ object DecisionTree extends Serializable {
351378
bestGainStats = gainStats
352379
bestFeatureIndex = featureIndex
353380
bestSplitIndex = splitIndex
354-
println("bestFeatureIndex = " + bestFeatureIndex + ", bestSplitIndex = " + bestSplitIndex
355-
+ ", gain stats = " + bestGainStats)
381+
//println("bestFeatureIndex = " + bestFeatureIndex + ", bestSplitIndex = " + bestSplitIndex)
382+
//println( "gain stats = " + bestGainStats)
356383
}
357384
}
358385
}
@@ -365,9 +392,12 @@ object DecisionTree extends Serializable {
365392
//Calculate best splits for all nodes at a given level
366393
val bestSplits = new Array[(Split, InformationGainStats)](numNodes)
367394
for (node <- 0 until numNodes){
395+
val nodeImpurityIndex = scala.math.pow(2, level).toInt - 1 + node
368396
val shift = 2*node*numSplits*numFeatures
369397
val binsForNode = binAggregates.slice(shift,shift+2*numSplits*numFeatures)
370-
val parentNodeImpurity = parentImpurities(node/2)
398+
println("nodeImpurityIndex = " + nodeImpurityIndex)
399+
val parentNodeImpurity = parentImpurities(nodeImpurityIndex)
400+
println("node impurity = " + parentNodeImpurity)
371401
bestSplits(node) = binsToBestSplit(binsForNode, parentNodeImpurity)
372402
}
373403

@@ -456,8 +486,9 @@ object DecisionTree extends Serializable {
456486

457487
val sc = new SparkContext(args(0), "DecisionTree")
458488
val data = loadLabeledData(sc, args(1))
489+
val maxDepth = args(2).toInt
459490

460-
val strategy = new Strategy(kind = "classification", impurity = Gini, maxDepth = 2, numSplits = 569)
491+
val strategy = new Strategy(kind = "classification", impurity = Gini, maxDepth = maxDepth, numSplits = 569)
461492
val model = new DecisionTree(strategy).train(data)
462493

463494
sc.stop()

mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,15 @@ package org.apache.spark.mllib.tree.impurity
1818

1919
object Gini extends Impurity {
2020

21-
def calculate(c0 : Double, c1 : Double): Double = {
22-
val total = c0 + c1
23-
val f0 = c0 / total
24-
val f1 = c1 / total
25-
1 - f0*f0 - f1*f1
26-
}
21+
def calculate(c0 : Double, c1 : Double): Double = {
22+
if (c0 == 0 || c1 == 0) {
23+
0
24+
} else {
25+
val total = c0 + c1
26+
val f0 = c0 / total
27+
val f1 = c1 / total
28+
1 - f0*f0 - f1*f1
29+
}
30+
}
2731

2832
}

0 commit comments

Comments
 (0)