@@ -204,15 +204,12 @@ object DecisionTree extends Serializable with Logging {
204204 }
205205
206206 /* Finds the right bin for the given feature*/
207- def findBin (featureIndex : Int , labeledPoint : LabeledPoint ) : Int = {
208- // logDebug("finding bin for labeled point " + labeledPoint.features(featureIndex))
207+ def findBin (featureIndex : Int , labeledPoint : LabeledPoint , isFeatureContinous : Boolean ) : Int = {
209208
210- val isFeatureContinous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
211209 if (isFeatureContinous){
212210 // TODO: Do binary search
213211 for (binIndex <- 0 until strategy.numBins) {
214212 val bin = bins(featureIndex)(binIndex)
215- // TODO: Remove this requirement post basic functional
216213 val lowThreshold = bin.lowSplit.threshold
217214 val highThreshold = bin.highSplit.threshold
218215 val features = labeledPoint.features
@@ -222,9 +219,9 @@ object DecisionTree extends Serializable with Logging {
222219 }
223220 throw new UnknownError (" no bin was found for continuous variable." )
224221 } else {
222+
225223 for (binIndex <- 0 until strategy.numBins) {
226224 val bin = bins(featureIndex)(binIndex)
227- // TODO: Remove this requirement post basic functional
228225 val category = bin.category
229226 val features = labeledPoint.features
230227 if (category == features(featureIndex)) {
@@ -262,7 +259,8 @@ object DecisionTree extends Serializable with Logging {
262259 } else {
263260 for (featureIndex <- 0 until numFeatures) {
264261 // logDebug("shift+featureIndex =" + (shift+featureIndex))
265- arr(shift + featureIndex) = findBin(featureIndex, labeledPoint)
262+ val isFeatureContinous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
263+ arr(shift + featureIndex) = findBin(featureIndex, labeledPoint,isFeatureContinous)
266264 }
267265 }
268266
0 commit comments