@@ -37,7 +37,7 @@ class DecisionTree(val strategy : Strategy) extends Serializable with Logging {
3737 // Cache input RDD for speedup during multiple passes
3838 input.cache()
3939
40- val (splits, bins) = DecisionTree .find_splits_bins (input, strategy)
40+ val (splits, bins) = DecisionTree .findSplitsBins (input, strategy)
4141 logDebug(" numSplits = " + bins(0 ).length)
4242 strategy.numBins = bins(0 ).length
4343
@@ -54,8 +54,6 @@ class DecisionTree(val strategy : Strategy) extends Serializable with Logging {
5454
5555 logDebug(" algo = " + strategy.algo)
5656
57-
58-
5957 breakable {
6058 for (level <- 0 until maxDepth){
6159
@@ -185,10 +183,21 @@ object DecisionTree extends Serializable with Logging {
185183 val featureIndex = filter.split.feature
186184 val threshold = filter.split.threshold
187185 val comparison = filter.comparison
188- comparison match {
189- case (- 1 ) => if (features(featureIndex) > threshold) return false
190- case (0 ) => if (features(featureIndex) != threshold) return false
191- case (1 ) => if (features(featureIndex) <= threshold) return false
186+ val categories = filter.split.categories
187+ val isFeatureContinuous = filter.split.featureType == Continuous
188+ val feature = features(featureIndex)
189+ if (isFeatureContinuous){
190+ comparison match {
191+ case (- 1 ) => if (feature > threshold) return false
192+ case (1 ) => if (feature <= threshold) return false
193+ }
194+ } else {
195+ val containsFeature = categories.contains(feature)
196+ comparison match {
197+ case (- 1 ) => if (! containsFeature) return false
198+ case (1 ) => if (containsFeature) return false
199+ }
200+
192201 }
193202 }
194203 true
@@ -197,18 +206,34 @@ object DecisionTree extends Serializable with Logging {
197206 /*
   Finds the index of the bin that the value of feature `featureIndex` in
   `labeledPoint` falls into.

   NOTE(review): this block is shown as a rendered diff. Lines whose number is
   followed by "-" appear to be the removed (pre-change) body, lines whose number
   is followed by "+" the added body, and lines with two fused numbers unchanged
   context -- confirm against the real source file before relying on it.

   `strategy` and `bins` are not parameters; they are read from the enclosing scope.

   In the added ("+") version a feature is treated as continuous when
   strategy.categoricalFeaturesInfo has no entry for featureIndex, and as
   categorical otherwise:
     - continuous:  the first bin with lowSplit.threshold <= value < highSplit.threshold
     - categorical: the first bin whose `category` equals the value
   In both cases the bins are scanned linearly over `0 until strategy.numBins`.

   @param featureIndex index of the feature inside labeledPoint.features
   @param labeledPoint the point whose feature value is being binned
   @return the index of the first matching bin
   @throws UnknownError when no bin matches the feature value
 */
198207 def findBin (featureIndex : Int , labeledPoint : LabeledPoint ) : Int = {
199208 // logDebug("finding bin for labeled point " + labeledPoint.features(featureIndex))
200- // TODO: Do binary search
201- for (binIndex <- 0 until strategy.numBins) {
202- val bin = bins(featureIndex)(binIndex)
203- // TODO: Remove this requirement post basic functional
204- val lowThreshold = bin.lowSplit.threshold
205- val highThreshold = bin.highSplit.threshold
206- val features = labeledPoint.features
207- if ((lowThreshold <= features(featureIndex)) & (highThreshold > features(featureIndex))) {
208- return binIndex
209+
    // A feature with no entry in categoricalFeaturesInfo is handled as continuous.
210+ val isFeatureContinous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
211+ if (isFeatureContinous){
212+ // TODO: Do binary search
213+ for (binIndex <- 0 until strategy.numBins) {
214+ val bin = bins(featureIndex)(binIndex)
215+ // TODO: Remove this requirement post basic functional
216+ val lowThreshold = bin.lowSplit.threshold
217+ val highThreshold = bin.highSplit.threshold
218+ val features = labeledPoint.features
      // Half-open interval test: low bound inclusive, high bound exclusive.
      // `&` is the non-short-circuit operator, so both comparisons are always evaluated.
219+ if ((lowThreshold <= features(featureIndex)) & (highThreshold > features(featureIndex))) {
      // Leaves findBin from inside the loop with the first matching index.
220+ return binIndex
221+ }
222+ }
    // Reached only when the loop above returned nothing.
    // NOTE(review): UnknownError is a JVM-level java.lang error type; a more specific
    // exception may be intended -- confirm what callers expect before changing it.
223+ throw new UnknownError (" no bin was found for continuous variable." )
224+ } else {
    // Categorical feature: match on exact equality with the bin's category value.
225+ for (binIndex <- 0 until strategy.numBins) {
226+ val bin = bins(featureIndex)(binIndex)
227+ // TODO: Remove this requirement post basic functional
228+ val category = bin.category
229+ val features = labeledPoint.features
230+ if (category == features(featureIndex)) {
231+ return binIndex
232+ }
209233 }
    // Reached only when no bin's category equals the feature value.
234+ throw new UnknownError (" no bin was found for categorical variable." )
235+
210236 }
211- throw new UnknownError (" no bin was found." )
212237
213238 }
214239
@@ -565,7 +590,7 @@ object DecisionTree extends Serializable with Logging {
565590 @return a tuple of (splits,bins) where Split is an Array[Array[Split]] of size (numFeatures,numSplits-1) and bins is an
566591 Array[Array[Bin]] of size (numFeatures,numSplits)
567592 */
568- def find_splits_bins (input : RDD [LabeledPoint ], strategy : Strategy ) : (Array [Array [Split ]], Array [Array [Bin ]]) = {
593+ def findSplitsBins (input : RDD [LabeledPoint ], strategy : Strategy ) : (Array [Array [Split ]], Array [Array [Bin ]]) = {
569594
570595 val count = input.count()
571596
@@ -603,31 +628,71 @@ object DecisionTree extends Serializable with Logging {
603628 logDebug(" stride = " + stride)
604629 for (index <- 0 until numBins- 1 ) {
605630 val sampleIndex = (index+ 1 )* stride.toInt
606- val split = new Split (featureIndex,featureSamples(sampleIndex),Continuous )
631+ val split = new Split (featureIndex,featureSamples(sampleIndex),Continuous , List () )
607632 splits(featureIndex)(index) = split
608633 }
609634 } else {
610635 val maxFeatureValue = strategy.categoricalFeaturesInfo(featureIndex)
611- for (index <- 0 until maxFeatureValue){
612- // TODO: Sort by centriod
613- val split = new Split (featureIndex,index,Categorical )
614- splits(featureIndex)(index) = split
636+
637+ require(maxFeatureValue < numBins, " number of categories should be less than number of bins" )
638+
639+ val centriodForCategories
640+ = sampledInput.map(lp => (lp.features(featureIndex),lp.label))
641+ .groupBy(_._1).mapValues(x => x.map(_._2).sum / x.map(_._1).length)
642+
643+ // Checking for missing categorical variables
644+ val fullCentriodForCategories = scala.collection.mutable.Map [Double ,Double ]()
645+ for (i <- 0 until maxFeatureValue){
646+ if (centriodForCategories.contains(i)){
647+ fullCentriodForCategories(i) = centriodForCategories(i)
648+ } else {
649+ fullCentriodForCategories(i) = Double .MaxValue
650+ }
651+ }
652+
653+ val categoriesSortedByCentriod
654+ = fullCentriodForCategories.toList sortBy {_._2}
655+
656+ logDebug(" centriod for categorical variable = " + categoriesSortedByCentriod)
657+
658+ var categoriesForSplit = List [Double ]()
659+ categoriesSortedByCentriod.iterator.zipWithIndex foreach {
660+ case ((key, value), index) => {
661+ categoriesForSplit = key :: categoriesForSplit
662+ splits(featureIndex)(index) = new Split (featureIndex,Double .MinValue ,Categorical ,categoriesForSplit)
663+ bins(featureIndex)(index) = {
664+ if (index == 0 ) {
665+ new Bin (new DummyCategoricalSplit (featureIndex,Categorical ),splits(featureIndex)(0 ),Categorical ,key)
666+ }
667+ else {
668+ new Bin (splits(featureIndex)(index- 1 ),splits(featureIndex)(index),Categorical ,key)
669+ }
670+ }
671+ }
615672 }
616673 }
617674 }
618675
619676 // Find all bins
620677 for (featureIndex <- 0 until numFeatures){
621- bins(featureIndex)(0 )
622- = new Bin (new DummyLowSplit (Continuous ),splits(featureIndex)(0 ),Continuous )
623- for (index <- 1 until numBins - 1 ){
624- val bin = new Bin (splits(featureIndex)(index- 1 ),splits(featureIndex)(index),Continuous )
625- bins(featureIndex)(index) = bin
678+ val isFeatureContinous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
679+ if (isFeatureContinous) { // bins for categorical variables are already assigned
680+ bins(featureIndex)(0 )
681+ = new Bin (new DummyLowSplit (featureIndex, Continuous ),splits(featureIndex)(0 ),Continuous ,Double .MinValue )
682+ for (index <- 1 until numBins - 1 ){
683+ val bin = new Bin (splits(featureIndex)(index- 1 ),splits(featureIndex)(index),Continuous ,Double .MinValue )
684+ bins(featureIndex)(index) = bin
685+ }
686+ bins(featureIndex)(numBins- 1 )
687+ = new Bin (splits(featureIndex)(numBins- 2 ),new DummyHighSplit (featureIndex, Continuous ),Continuous ,Double .MinValue )
688+ } else {
689+ val maxFeatureValue = strategy.categoricalFeaturesInfo(featureIndex)
690+ for (i <- maxFeatureValue until numBins){
691+ bins(featureIndex)(i)
692+ = new Bin (new DummyCategoricalSplit (featureIndex,Categorical ),new DummyCategoricalSplit (featureIndex,Categorical ),Categorical ,Double .MaxValue )
693+ }
626694 }
627- bins(featureIndex)(numBins- 1 )
628- = new Bin (splits(featureIndex)(numBins- 2 ),new DummyHighSplit (Continuous ),Continuous )
629695 }
630-
631696 (splits,bins)
632697 }
633698 case MinMax => {
0 commit comments