@@ -121,7 +121,7 @@ object DecisionTree extends Serializable {
121121
122122 /* Finds the right bin for the given feature*/
123123 def findBin (featureIndex : Int , labeledPoint : LabeledPoint ) : Int = {
124- println(" finding bin for labeled point " + labeledPoint.features(featureIndex))
124+ // println("finding bin for labeled point " + labeledPoint.features(featureIndex))
125125 // TODO: Do binary search
126126 for (binIndex <- 0 until strategy.numSplits) {
127127 val bin = bins(featureIndex)(binIndex)
@@ -227,21 +227,27 @@ object DecisionTree extends Serializable {
227227
228228 val binAggregates = binMappedRDD.aggregate(Array .fill[Double ](2 * numSplits* numFeatures* numNodes)(0 ))(binSeqOp,binCombOp)
229229 println(" binAggregates.length = " + binAggregates.length)
230- binAggregates.foreach(x => println(x))
230+ // binAggregates.foreach(x => println(x))
231231
232232
233233 def calculateGainForSplit (leftNodeAgg : Array [Array [Double ]], featureIndex : Int , index : Int , rightNodeAgg : Array [Array [Double ]], topImpurity : Double ): Double = {
234234
235235 val left0Count = leftNodeAgg(featureIndex)(2 * index)
236236 val left1Count = leftNodeAgg(featureIndex)(2 * index + 1 )
237237 val leftCount = left0Count + left1Count
238- println(" left0count = " + left0Count + " , left1count = " + left1Count + " , leftCount = " + leftCount)
238+
239+ if (leftCount == 0 ) return 0
240+
241+ // println("left0count = " + left0Count + ", left1count = " + left1Count + ", leftCount = " + leftCount)
239242 val leftImpurity = strategy.impurity.calculate(left0Count, left1Count)
240243
241244 val right0Count = rightNodeAgg(featureIndex)(2 * index)
242245 val right1Count = rightNodeAgg(featureIndex)(2 * index + 1 )
243246 val rightCount = right0Count + right1Count
244- println(" right0count = " + right0Count + " , right1count = " + right1Count + " , rightCount = " + rightCount)
247+
248+ if (rightCount == 0 ) return 0
249+
250+ // println("right0count = " + right0Count + ", right1count = " + right1Count + ", rightCount = " + rightCount)
245251 val rightImpurity = strategy.impurity.calculate(right0Count, right1Count)
246252
247253 val leftWeight = leftCount.toDouble / (leftCount + rightCount)
@@ -261,21 +267,21 @@ object DecisionTree extends Serializable {
261267 def extractLeftRightNodeAggregates (binData : Array [Double ]): (Array [Array [Double ]], Array [Array [Double ]]) = {
262268 val leftNodeAgg = Array .ofDim[Double ](numFeatures, 2 * (numSplits - 1 ))
263269 val rightNodeAgg = Array .ofDim[Double ](numFeatures, 2 * (numSplits - 1 ))
264- println(" binData.length = " + binData.length)
265- println(" binData.sum = " + binData.sum)
270+ // println("binData.length = " + binData.length)
271+ // println("binData.sum = " + binData.sum)
266272 for (featureIndex <- 0 until numFeatures) {
267- println(" featureIndex = " + featureIndex)
273+ // println("featureIndex = " + featureIndex)
268274 val shift = 2 * featureIndex* numSplits
269275 leftNodeAgg(featureIndex)(0 ) = binData(shift + 0 )
270- println(" binData(shift + 0) = " + binData(shift + 0 ))
276+ // println("binData(shift + 0) = " + binData(shift + 0))
271277 leftNodeAgg(featureIndex)(1 ) = binData(shift + 1 )
272- println(" binData(shift + 1) = " + binData(shift + 1 ))
278+ // println("binData(shift + 1) = " + binData(shift + 1))
273279 rightNodeAgg(featureIndex)(2 * (numSplits - 2 )) = binData(shift + (2 * (numSplits - 1 )))
274- println(binData(shift + (2 * (numSplits - 1 ))))
280+ // println(binData(shift + (2 * (numSplits - 1))))
275281 rightNodeAgg(featureIndex)(2 * (numSplits - 2 ) + 1 ) = binData(shift + (2 * (numSplits - 1 )) + 1 )
276- println(binData(shift + (2 * (numSplits - 1 )) + 1 ))
282+ // println(binData(shift + (2 * (numSplits - 1)) + 1))
277283 for (splitIndex <- 1 until numSplits - 1 ) {
278- println(" splitIndex = " + splitIndex)
284+ // println("splitIndex = " + splitIndex)
279285 leftNodeAgg(featureIndex)(2 * splitIndex)
280286 = binData(shift + 2 * splitIndex) + leftNodeAgg(featureIndex)(2 * splitIndex - 2 )
281287 leftNodeAgg(featureIndex)(2 * splitIndex + 1 )
@@ -295,7 +301,7 @@ object DecisionTree extends Serializable {
295301
296302 for (featureIndex <- 0 until numFeatures) {
297303 for (index <- 0 until numSplits - 1 ) {
298- println(" splitIndex = " + index)
304+ // println("splitIndex = " + index)
299305 gains(featureIndex)(index) = calculateGainForSplit(leftNodeAgg, featureIndex, index, rightNodeAgg, nodeImpurity)
300306 }
301307 }
@@ -312,8 +318,8 @@ object DecisionTree extends Serializable {
312318 val (leftNodeAgg, rightNodeAgg) = extractLeftRightNodeAggregates(binData)
313319 val gains = calculateGainsForAllNodeSplits(leftNodeAgg, rightNodeAgg, nodeImpurity)
314320
315- println(" gains.size = " + gains.size)
316- println(" gains(0).size = " + gains(0 ).size)
321+ // println("gains.size = " + gains.size)
322+ // println("gains(0).size = " + gains(0).size)
317323
318324 val (bestFeatureIndex,bestSplitIndex) = {
319325 var bestFeatureIndex = 0
@@ -322,7 +328,7 @@ object DecisionTree extends Serializable {
322328 for (featureIndex <- 0 until numFeatures) {
323329 for (splitIndex <- 0 until numSplits - 1 ){
324330 val gain = gains(featureIndex)(splitIndex)
325- println(" featureIndex = " + featureIndex + " , splitIndex = " + splitIndex + " , gain = " + gain)
331+ // println("featureIndex = " + featureIndex + ", splitIndex = " + splitIndex + ", gain = " + gain)
326332 if (gain > maxGain) {
327333 maxGain = gain
328334 bestFeatureIndex = featureIndex
@@ -335,6 +341,8 @@ object DecisionTree extends Serializable {
335341 }
336342
337343 splits(bestFeatureIndex)(bestSplitIndex)
344+
345+ // TODO: Return array of node stats with split and impurity information
338346 }
339347
340348 // Calculate best splits for all nodes at a given level
@@ -388,6 +396,9 @@ object DecisionTree extends Serializable {
388396 for (featureIndex <- 0 until numFeatures){
389397 val featureSamples = sampledInput.map(lp => lp.features(featureIndex)).sorted
390398 val stride : Double = numSamples.toDouble/ numSplits
399+
400+ println(" stride = " + stride)
401+
391402 for (index <- 0 until numSplits- 1 ) {
392403 val sampleIndex = (index+ 1 )* stride.toInt
393404 val split = new Split (featureIndex,featureSamples(sampleIndex)," continuous" )
0 commit comments