@@ -20,9 +20,10 @@ package org.apache.spark.mllib.tree
2020import org .apache .spark .SparkContext ._
2121import org .apache .spark .rdd .RDD
2222import org .apache .spark .mllib .tree .model ._
23- import org .apache .spark .Logging
23+ import org .apache .spark .{ SparkContext , Logging }
2424import org .apache .spark .mllib .regression .LabeledPoint
2525import org .apache .spark .mllib .tree .model .Split
26+ import org .apache .spark .mllib .tree .impurity .Gini
2627
2728
2829class DecisionTree (val strategy : Strategy ) {
@@ -46,8 +47,13 @@ class DecisionTree(val strategy : Strategy) {
4647 // Find best split for all nodes at a level
4748 val numNodes = scala.math.pow(2 ,level).toInt
4849 // TODO: Change the input parent impurities values
49- val bestSplits = DecisionTree .findBestSplits(input, Array (0.0 ), strategy, level, filters,splits,bins)
50+ val splits_stats_for_level = DecisionTree .findBestSplits(input, Array (2.0 ), strategy, level, filters,splits,bins)
51+ for (tmp <- splits_stats_for_level){
52+ println(" final best split = " + tmp._1)
53+ }
5054 // TODO: update filters and decision tree model
55+ require(scala.math.pow(2 ,level)== splits_stats_for_level.length)
56+
5157 }
5258
5359 return new DecisionTreeModel ()
@@ -77,7 +83,7 @@ object DecisionTree extends Serializable {
7783 level : Int ,
7884 filters : Array [List [Filter ]],
7985 splits : Array [Array [Split ]],
80- bins : Array [Array [Bin ]]) : Array [Split ] = {
86+ bins : Array [Array [Bin ]]) : Array [( Split , Double , Long , Long ) ] = {
8187
8288 // Common calculations for multiple nested methods
8389 val numNodes = scala.math.pow(2 , level).toInt
@@ -94,8 +100,9 @@ object DecisionTree extends Serializable {
94100 List [Filter ]()
95101 } else {
96102 val nodeFilterIndex = scala.math.pow(2 , level).toInt + nodeIndex
97- val parentFilterIndex = nodeFilterIndex / 2
98- filters(parentFilterIndex)
103+ // val parentFilterIndex = nodeFilterIndex / 2
104+ // TODO: Check left or right filter
105+ filters(nodeFilterIndex)
99106 }
100107 }
101108
@@ -230,30 +237,34 @@ object DecisionTree extends Serializable {
230237 // binAggregates.foreach(x => println(x))
231238
232239
// Computes the impurity gain of one candidate split of one feature, together with
// the number of samples on each side of that split.
//
// For the given feature, slots 2 * index and 2 * index + 1 of each aggregate row are
// read as a pair of counts (presumably the counts of the two label classes -- confirm
// against the code that fills the aggregates).
//
// Returns (gain, leftCount, rightCount). The gain is 0 when either side is empty.
def calculateGainForSplit(leftNodeAgg: Array[Array[Double]],
                          featureIndex: Int,
                          index: Int,
                          rightNodeAgg: Array[Array[Double]],
                          topImpurity: Double): (Double, Long, Long) = {

  val evenSlot = 2 * index
  val oddSlot = evenSlot + 1

  val leftRow = leftNodeAgg(featureIndex)
  val leftZeros = leftRow(evenSlot)
  val leftOnes = leftRow(oddSlot)
  val leftTotal = leftZeros + leftOnes

  val rightRow = rightNodeAgg(featureIndex)
  val rightZeros = rightRow(evenSlot)
  val rightOnes = rightRow(oddSlot)
  val rightTotal = rightZeros + rightOnes

  val leftSamples = leftTotal.toLong
  val rightSamples = rightTotal.toLong

  if (leftTotal == 0) {
    // Nothing falls to the left: the split separates nothing.
    (0.0, leftSamples, rightSamples)
  } else {
    // println("left0count = " + leftZeros + ", left1count = " + leftOnes + ", leftCount = " + leftTotal)
    val leftImpurity = strategy.impurity.calculate(leftZeros, leftOnes)

    if (rightTotal == 0) {
      // Nothing falls to the right: the split separates nothing.
      (0.0, leftSamples, rightSamples)
    } else {
      // println("right0count = " + rightZeros + ", right1count = " + rightOnes + ", rightCount = " + rightTotal)
      val rightImpurity = strategy.impurity.calculate(rightZeros, rightOnes)

      // Weight each child's impurity by its share of the samples.
      val leftWeight = leftTotal / (leftTotal + rightTotal)
      val rightWeight = rightTotal / (leftTotal + rightTotal)

      (topImpurity - leftWeight * leftImpurity - rightWeight * rightImpurity, leftSamples, rightSamples)
    }
  }
}
259270
@@ -295,9 +306,10 @@ object DecisionTree extends Serializable {
295306 (leftNodeAgg, rightNodeAgg)
296307 }
297308
298- def calculateGainsForAllNodeSplits (leftNodeAgg : Array [Array [Double ]], rightNodeAgg : Array [Array [Double ]], nodeImpurity : Double ): Array [Array [Double ]] = {
309+ def calculateGainsForAllNodeSplits (leftNodeAgg : Array [Array [Double ]], rightNodeAgg : Array [Array [Double ]], nodeImpurity : Double )
310+ : Array [Array [(Double ,Long ,Long )]] = {
299311
300- val gains = Array .ofDim[Double ](numFeatures, numSplits - 1 )
312+ val gains = Array .ofDim[( Double , Long , Long ) ](numFeatures, numSplits - 1 )
301313
302314 for (featureIndex <- 0 until numFeatures) {
303315 for (index <- 0 until numSplits - 1 ) {
@@ -313,40 +325,44 @@ object DecisionTree extends Serializable {
313325
314326 @param binData Array[Double] of size 2*numSplits*numFeatures
315327 */
def binsToBestSplit(binData: Array[Double], nodeImpurity: Double): (Split, Double, Long, Long) = {
  println(" node impurity = " + nodeImpurity)
  val (leftNodeAgg, rightNodeAgg) = extractLeftRightNodeAggregates(binData)
  val gains = calculateGainsForAllNodeSplits(leftNodeAgg, rightNodeAgg, nodeImpurity)

  // println("gains.size = " + gains.size)
  // println("gains(0).size = " + gains(0).size)

  // Running winner of the scan below. The sentinels stay in place only when no
  // candidate exists or no candidate's gain compares greater than Double.MinValue.
  var winningFeature = 0
  var winningSplit = 0
  var winningGain = Double.MinValue
  var winningLeft = Long.MinValue
  var winningRight = Long.MinValue

  // Scan every (feature, split) candidate in feature-major order. The comparison is
  // strict, so the first candidate reaching the maximum gain is the one kept.
  for {
    featureIndex <- 0 until numFeatures
    splitIndex <- 0 until numSplits - 1
  } {
    val candidate = gains(featureIndex)(splitIndex)
    // println("featureIndex = " + featureIndex + ", splitIndex = " + splitIndex + ", gain = " + candidate)
    if (candidate._1 > winningGain) {
      winningGain = candidate._1
      winningLeft = candidate._2
      winningRight = candidate._3
      winningFeature = featureIndex
      winningSplit = splitIndex
      println(" bestFeatureIndex = " + winningFeature + " , bestSplitIndex = " + winningSplit +
        " , maxGain = " + winningGain + " , leftSamples = " + winningLeft + " ,rightSamples = " + winningRight)
    }
  }

  // TODO: Return array of node stats with split and impurity information
  (splits(winningFeature)(winningSplit), winningGain, winningLeft, winningRight)
}
347363
348364 // Calculate best splits for all nodes at a given level
349- val bestSplits = new Array [Split ](numNodes)
365+ val bestSplits = new Array [( Split , Double , Long , Long ) ](numNodes)
350366 for (node <- 0 until numNodes){
351367 val shift = 2 * node* numSplits* numFeatures
352368 val binsForNode = binAggregates.slice(shift,shift+ 2 * numSplits* numFeatures)
@@ -381,9 +397,6 @@ object DecisionTree extends Serializable {
381397 val sampledInput = input.sample(false , fraction, 42 ).collect()
382398 val numSamples = sampledInput.length
383399
384- // TODO: Remove this requirement
385- require(numSamples > numSplits, " length of input samples should be greater than numSplits" )
386-
387400 // Find the number of features by looking at the first sample
388401 val numFeatures = input.take(1 )(0 ).features.length
389402
@@ -395,14 +408,22 @@ object DecisionTree extends Serializable {
395408 // Find all splits
396409 for (featureIndex <- 0 until numFeatures){
397410 val featureSamples = sampledInput.map(lp => lp.features(featureIndex)).sorted
398- val stride : Double = numSamples.toDouble/ numSplits
399-
400- println(" stride = " + stride)
401411
402- for (index <- 0 until numSplits- 1 ) {
403- val sampleIndex = (index+ 1 )* stride.toInt
404- val split = new Split (featureIndex,featureSamples(sampleIndex)," continuous" )
405- splits(featureIndex)(index) = split
412+ if (numSamples < numSplits) {
413+ // TODO: Test this
414+ println(" numSamples = " + numSamples + " , less than numSplits = " + numSplits)
415+ for (index <- 0 until numSplits- 1 ) {
416+ val split = new Split (featureIndex,featureSamples(index)," continuous" )
417+ splits(featureIndex)(index) = split
418+ }
419+ } else {
420+ val stride : Double = numSamples.toDouble/ numSplits
421+ println(" stride = " + stride)
422+ for (index <- 0 until numSplits- 1 ) {
423+ val sampleIndex = (index+ 1 )* stride.toInt
424+ val split = new Split (featureIndex,featureSamples(sampleIndex)," continuous" )
425+ splits(featureIndex)(index) = split
426+ }
406427 }
407428 }
408429
@@ -430,4 +451,36 @@ object DecisionTree extends Serializable {
430451 }
431452 }
432453
/**
 * Command-line entry point: loads labeled data and trains a decision tree on it.
 *
 * @param args args(0) is handed to SparkContext as its first constructor argument and
 *             args(1) is the input path handed to loadLabeledData.
 */
def main(args: Array[String]): Unit = {
  require(args.length >= 2, "Usage: DecisionTree <master> <input path>")

  val sc = new SparkContext(args(0), " DecisionTree")
  try {
    val data = loadLabeledData(sc, args(1))

    // NOTE(review): the string literals in this method are kept exactly as they appear
    // in the reviewed source, including the leading space -- confirm against the values
    // Strategy expects for `kind`.
    val strategy = new Strategy(kind = " classification", impurity = Gini, maxDepth = 2, numSplits = 569)

    // The returned model is not used here.
    new DecisionTree(strategy).train(data)
    ()
  } finally {
    // Stop the context even when loading or training throws.
    sc.stop()
  }
}
464+
/**
 * Load labeled data from a file. Each line holds one labeled point:
 * <L>, <f1> <f2> ...
 * where <L> is the label and <f1>, <f2>, ... are the feature values, all parsed as Double.
 * Fields may be separated by commas, whitespace, or any mix of the two, so fully
 * comma-separated lines (<L>,<f1>,<f2>,...) are accepted as well.
 *
 * A line with a field that is not a valid Double (including an empty line) fails
 * with a NumberFormatException when the RDD is evaluated.
 *
 * @param sc SparkContext
 * @param dir Path passed to SparkContext.textFile to read the input data.
 * @return An RDD of LabeledPoint. Each labeled point has two elements: the first element is
 *         the label, and the second element represents the feature values (an array of Double).
 */
def loadLabeledData(sc: SparkContext, dir: String): RDD[LabeledPoint] = {
  sc.textFile(dir).map { line =>
    // Split on runs of commas and/or whitespace so that the documented
    // "<L>, <f1> <f2> ..." layout and plain comma-separated lines both parse.
    val fields = line.trim.split("[,\\s]+")
    val label = fields(0).toDouble
    val features = fields.drop(1).map(_.toDouble)
    LabeledPoint(label, features)
  }
}
483+
484+
485+
433486}
0 commit comments