@@ -152,17 +152,15 @@ object DecisionTree extends Serializable with Logging {
152152 // Find the number of features by looking at the first sample
153153 val numFeatures = input.take(1 )(0 ).features.length
154154 logDebug(" numFeatures = " + numFeatures)
155- val numSplits = strategy.numBins
156- logDebug(" numSplits = " + numSplits )
155+ val numBins = strategy.numBins
156+ logDebug(" numBins = " + numBins )
157157
158158 /* Find the filters used before reaching the current node: none at level 0, otherwise filters(2^level - 1 + nodeIndex) */
159159 def findParentFilters (nodeIndex : Int ): List [Filter ] = {
160160 if (level == 0 ) {
161161 List [Filter ]()
162162 } else {
163163 val nodeFilterIndex = scala.math.pow(2 , level).toInt - 1 + nodeIndex
164- // val parentFilterIndex = nodeFilterIndex / 2
165- // TODO: Check left or right filter
166164 filters(nodeFilterIndex)
167165 }
168166 }
@@ -204,9 +202,9 @@ object DecisionTree extends Serializable with Logging {
204202 }
205203
206204 /* Finds the right bin for the given feature*/
207- def findBin (featureIndex : Int , labeledPoint : LabeledPoint , isFeatureContinous : Boolean ) : Int = {
205+ def findBin (featureIndex : Int , labeledPoint : LabeledPoint , isFeatureContinuous : Boolean ) : Int = {
208206
209- if (isFeatureContinous ){
207+ if (isFeatureContinuous ){
210208 // TODO: Do binary search
211209 for (binIndex <- 0 until strategy.numBins) {
212210 val bin = bins(featureIndex)(binIndex)
@@ -245,11 +243,11 @@ object DecisionTree extends Serializable with Logging {
245243 // calculating bin index and label per feature per node
246244 val arr = new Array [Double ](1 + (numFeatures * numNodes))
247245 arr(0 ) = labeledPoint.label
248- for (index <- 0 until numNodes) {
249- val parentFilters = findParentFilters(index )
246+ for (nodeIndex <- 0 until numNodes) {
247+ val parentFilters = findParentFilters(nodeIndex )
250248 // Find out whether the sample qualifies for the particular node
251249 val sampleValid = isSampleValid(parentFilters, labeledPoint)
252- val shift = 1 + numFeatures * index
250+ val shift = 1 + numFeatures * nodeIndex
253251 if (! sampleValid) {
254252 // Add to invalid bin index -1
255253 for (featureIndex <- 0 until numFeatures) {
@@ -274,11 +272,11 @@ object DecisionTree extends Serializable with Logging {
274272 val isSampleValidForNode = if (arr(validSignalIndex) != - 1 ) true else false
275273 if (isSampleValidForNode) {
276274 val label = arr(0 )
277- for (feature <- 0 until numFeatures) {
275+ for (featureIndex <- 0 until numFeatures) {
278276 val arrShift = 1 + numFeatures * node
279- val aggShift = 2 * numSplits * numFeatures * node
280- val arrIndex = arrShift + feature
281- val aggIndex = aggShift + 2 * feature * numSplits + arr(arrIndex).toInt * 2
277+ val aggShift = 2 * numBins * numFeatures * node
278+ val arrIndex = arrShift + featureIndex
279+ val aggIndex = aggShift + 2 * featureIndex * numBins + arr(arrIndex).toInt * 2
282280 label match {
283281 case (0.0 ) => agg(aggIndex) = agg(aggIndex) + 1
284282 case (1.0 ) => agg(aggIndex + 1 ) = agg(aggIndex + 1 ) + 1
@@ -296,9 +294,9 @@ object DecisionTree extends Serializable with Logging {
296294 val label = arr(0 )
297295 for (feature <- 0 until numFeatures) {
298296 val arrShift = 1 + numFeatures * node
299- val aggShift = 3 * numSplits * numFeatures * node
297+ val aggShift = 3 * numBins * numFeatures * node
300298 val arrIndex = arrShift + feature
301- val aggIndex = aggShift + 3 * feature * numSplits + arr(arrIndex).toInt * 3
299+ val aggIndex = aggShift + 3 * feature * numBins + arr(arrIndex).toInt * 3
302300 // count, sum, sum^2
303301 agg(aggIndex) = agg(aggIndex) + 1
304302 agg(aggIndex + 1 ) = agg(aggIndex + 1 ) + label
@@ -318,7 +316,6 @@ object DecisionTree extends Serializable with Logging {
318316 @return Array[Double] storing aggregate calculation of size 2*numBins*numFeatures*numNodes for classification
319317 */
320318 def binSeqOp (agg : Array [Double ], arr : Array [Double ]) : Array [Double ] = {
321- // TODO: Requires logic for regressions
322319 strategy.algo match {
323320 case Classification => classificationBinSeqOp(arr, agg)
324321 // TODO: Implement this
@@ -327,10 +324,9 @@ object DecisionTree extends Serializable with Logging {
327324 agg
328325 }
329326
330- // TODO: This length is different for regression
331327 val binAggregateLength = strategy.algo match {
332- case Classification => 2 * numSplits * numFeatures * numNodes
333- case Regression => 3 * numSplits * numFeatures * numNodes
328+ case Classification => 2 * numBins * numFeatures * numNodes
329+ case Regression => 3 * numBins * numFeatures * numNodes
334330 }
335331 logDebug(" binAggregateLength = " + binAggregateLength)
336332
@@ -453,52 +449,52 @@ object DecisionTree extends Serializable with Logging {
453449 strategy.algo match {
454450 case Classification => {
455451
456- val leftNodeAgg = Array .ofDim[Double ](numFeatures, 2 * (numSplits - 1 ))
457- val rightNodeAgg = Array .ofDim[Double ](numFeatures, 2 * (numSplits - 1 ))
452+ val leftNodeAgg = Array .ofDim[Double ](numFeatures, 2 * (numBins - 1 ))
453+ val rightNodeAgg = Array .ofDim[Double ](numFeatures, 2 * (numBins - 1 ))
458454 for (featureIndex <- 0 until numFeatures) {
459- val shift = 2 * featureIndex* numSplits
455+ val shift = 2 * featureIndex* numBins
460456 leftNodeAgg(featureIndex)(0 ) = binData(shift + 0 )
461457 leftNodeAgg(featureIndex)(1 ) = binData(shift + 1 )
462- rightNodeAgg(featureIndex)(2 * (numSplits - 2 )) = binData(shift + (2 * (numSplits - 1 )))
463- rightNodeAgg(featureIndex)(2 * (numSplits - 2 ) + 1 ) = binData(shift + (2 * (numSplits - 1 )) + 1 )
464- for (splitIndex <- 1 until numSplits - 1 ) {
458+ rightNodeAgg(featureIndex)(2 * (numBins - 2 )) = binData(shift + (2 * (numBins - 1 )))
459+ rightNodeAgg(featureIndex)(2 * (numBins - 2 ) + 1 ) = binData(shift + (2 * (numBins - 1 )) + 1 )
460+ for (splitIndex <- 1 until numBins - 1 ) {
465461 leftNodeAgg(featureIndex)(2 * splitIndex)
466462 = binData(shift + 2 * splitIndex) + leftNodeAgg(featureIndex)(2 * splitIndex - 2 )
467463 leftNodeAgg(featureIndex)(2 * splitIndex + 1 )
468464 = binData(shift + 2 * splitIndex + 1 ) + leftNodeAgg(featureIndex)(2 * splitIndex - 2 + 1 )
469- rightNodeAgg(featureIndex)(2 * (numSplits - 2 - splitIndex))
470- = binData(shift + (2 * (numSplits - 1 - splitIndex))) + rightNodeAgg(featureIndex)(2 * (numSplits - 1 - splitIndex))
471- rightNodeAgg(featureIndex)(2 * (numSplits - 2 - splitIndex) + 1 )
472- = binData(shift + (2 * (numSplits - 1 - splitIndex) + 1 )) + rightNodeAgg(featureIndex)(2 * (numSplits - 1 - splitIndex) + 1 )
465+ rightNodeAgg(featureIndex)(2 * (numBins - 2 - splitIndex))
466+ = binData(shift + (2 * (numBins - 1 - splitIndex))) + rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex))
467+ rightNodeAgg(featureIndex)(2 * (numBins - 2 - splitIndex) + 1 )
468+ = binData(shift + (2 * (numBins - 1 - splitIndex) + 1 )) + rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex) + 1 )
473469 }
474470 }
475471 (leftNodeAgg, rightNodeAgg)
476472 }
477473 case Regression => {
478474
479- val leftNodeAgg = Array .ofDim[Double ](numFeatures, 3 * (numSplits - 1 ))
480- val rightNodeAgg = Array .ofDim[Double ](numFeatures, 3 * (numSplits - 1 ))
475+ val leftNodeAgg = Array .ofDim[Double ](numFeatures, 3 * (numBins - 1 ))
476+ val rightNodeAgg = Array .ofDim[Double ](numFeatures, 3 * (numBins - 1 ))
481477 for (featureIndex <- 0 until numFeatures) {
482- val shift = 3 * featureIndex* numSplits
478+ val shift = 3 * featureIndex* numBins
483479 leftNodeAgg(featureIndex)(0 ) = binData(shift + 0 )
484480 leftNodeAgg(featureIndex)(1 ) = binData(shift + 1 )
485481 leftNodeAgg(featureIndex)(2 ) = binData(shift + 2 )
486- rightNodeAgg(featureIndex)(3 * (numSplits - 2 )) = binData(shift + (3 * (numSplits - 1 )))
487- rightNodeAgg(featureIndex)(3 * (numSplits - 2 ) + 1 ) = binData(shift + (3 * (numSplits - 1 )) + 1 )
488- rightNodeAgg(featureIndex)(3 * (numSplits - 2 ) + 2 ) = binData(shift + (3 * (numSplits - 1 )) + 2 )
489- for (splitIndex <- 1 until numSplits - 1 ) {
482+ rightNodeAgg(featureIndex)(3 * (numBins - 2 )) = binData(shift + (3 * (numBins - 1 )))
483+ rightNodeAgg(featureIndex)(3 * (numBins - 2 ) + 1 ) = binData(shift + (3 * (numBins - 1 )) + 1 )
484+ rightNodeAgg(featureIndex)(3 * (numBins - 2 ) + 2 ) = binData(shift + (3 * (numBins - 1 )) + 2 )
485+ for (splitIndex <- 1 until numBins - 1 ) {
490486 leftNodeAgg(featureIndex)(3 * splitIndex)
491487 = binData(shift + 3 * splitIndex) + leftNodeAgg(featureIndex)(3 * splitIndex - 3 )
492488 leftNodeAgg(featureIndex)(3 * splitIndex + 1 )
493489 = binData(shift + 3 * splitIndex + 1 ) + leftNodeAgg(featureIndex)(3 * splitIndex - 3 + 1 )
494490 leftNodeAgg(featureIndex)(3 * splitIndex + 2 )
495491 = binData(shift + 3 * splitIndex + 2 ) + leftNodeAgg(featureIndex)(3 * splitIndex - 3 + 2 )
496- rightNodeAgg(featureIndex)(3 * (numSplits - 2 - splitIndex))
497- = binData(shift + (3 * (numSplits - 1 - splitIndex))) + rightNodeAgg(featureIndex)(3 * (numSplits - 1 - splitIndex))
498- rightNodeAgg(featureIndex)(3 * (numSplits - 2 - splitIndex) + 1 )
499- = binData(shift + (3 * (numSplits - 1 - splitIndex) + 1 )) + rightNodeAgg(featureIndex)(3 * (numSplits - 1 - splitIndex) + 1 )
500- rightNodeAgg(featureIndex)(3 * (numSplits - 2 - splitIndex) + 2 )
501- = binData(shift + (3 * (numSplits - 1 - splitIndex) + 2 )) + rightNodeAgg(featureIndex)(3 * (numSplits - 1 - splitIndex) + 2 )
492+ rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex))
493+ = binData(shift + (3 * (numBins - 1 - splitIndex))) + rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex))
494+ rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex) + 1 )
495+ = binData(shift + (3 * (numBins - 1 - splitIndex) + 1 )) + rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 1 )
496+ rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex) + 2 )
497+ = binData(shift + (3 * (numBins - 1 - splitIndex) + 2 )) + rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 2 )
502498 }
503499 }
504500 (leftNodeAgg, rightNodeAgg)
@@ -509,10 +505,10 @@ object DecisionTree extends Serializable with Logging {
509505 def calculateGainsForAllNodeSplits (leftNodeAgg : Array [Array [Double ]], rightNodeAgg : Array [Array [Double ]], nodeImpurity : Double )
510506 : Array [Array [InformationGainStats ]] = {
511507
512- val gains = Array .ofDim[InformationGainStats ](numFeatures, numSplits - 1 )
508+ val gains = Array .ofDim[InformationGainStats ](numFeatures, numBins - 1 )
513509
514510 for (featureIndex <- 0 until numFeatures) {
515- for (index <- 0 until numSplits - 1 ) {
511+ for (index <- 0 until numBins - 1 ) {
516512 // logDebug("splitIndex = " + index)
517513 gains(featureIndex)(index) = calculateGainForSplit(leftNodeAgg, featureIndex, index, rightNodeAgg, nodeImpurity)
518514 }
@@ -521,10 +517,10 @@ object DecisionTree extends Serializable with Logging {
521517 }
522518
523519 /*
524- Find the best split for a node given bin aggregate data
520+ Find the best split for a node given bin aggregate data
525521
526- @param binData Array[Double] of size 2*numSplits*numFeatures
527- */
522+ @param binData Array[Double] of size 2*numBins*numFeatures for classification (3*numBins*numFeatures for regression)
523+ */
528524 def binsToBestSplit (binData : Array [Double ], nodeImpurity : Double ) : (Split , InformationGainStats ) = {
529525 logDebug(" node impurity = " + nodeImpurity)
530526 val (leftNodeAgg, rightNodeAgg) = extractLeftRightNodeAggregates(binData)
@@ -536,7 +532,7 @@ object DecisionTree extends Serializable with Logging {
536532 // Initialization with infeasible values
537533 var bestGainStats = new InformationGainStats (Double .MinValue ,- 1.0 ,- 1.0 ,- 1.0 ,- 1 )
538534 for (featureIndex <- 0 until numFeatures) {
539- for (splitIndex <- 0 until numSplits - 1 ){
535+ for (splitIndex <- 0 until numBins - 1 ){
540536 val gainStats = gains(featureIndex)(splitIndex)
541537 if (gainStats.gain > bestGainStats.gain) {
542538 bestGainStats = gainStats
@@ -556,13 +552,13 @@ object DecisionTree extends Serializable with Logging {
556552 def getBinDataForNode (node : Int ): Array [Double ] = {
557553 strategy.algo match {
558554 case Classification => {
559- val shift = 2 * node * numSplits * numFeatures
560- val binsForNode = binAggregates.slice(shift, shift + 2 * numSplits * numFeatures)
555+ val shift = 2 * node * numBins * numFeatures
556+ val binsForNode = binAggregates.slice(shift, shift + 2 * numBins * numFeatures)
561557 binsForNode
562558 }
563559 case Regression => {
564- val shift = 3 * node * numSplits * numFeatures
565- val binsForNode = binAggregates.slice(shift, shift + 3 * numSplits * numFeatures)
560+ val shift = 3 * node * numBins * numFeatures
561+ val binsForNode = binAggregates.slice(shift, shift + 3 * numBins * numFeatures)
566562 binsForNode
567563 }
568564 }
0 commit comments