@@ -545,17 +545,24 @@ object DecisionTree extends Serializable with Logging {
545545        - 1 
546546      }
547547
548+       /**  
549+        * Sequential search helper method to find the bin for a categorical feature in multiclass
550+        * classification. Returns a dummy value of 0 because it is unused in later calculations.
551+        */  
552+       def  sequentialBinSearchForCategoricalFeatureInBinaryClassification ():  Int  =  0 
553+ 
548554      /**  
549555       * Sequential search helper method to find bin for categorical feature. 
550556       */  
551-       def  sequentialBinSearchForCategoricalFeature ():  Int  =  {
552-         val  numCategoricalBins  =  strategy.categoricalFeaturesInfo(featureIndex)
557+       def  sequentialBinSearchForCategoricalFeatureInMultiClassClassification ():  Int  =  {
558+         val  featureCategories  =  strategy.categoricalFeaturesInfo(featureIndex)
559+         val  numCategoricalBins  =  math.pow(2.0 , featureCategories -  1 ).toInt -  1 
553560        var  binIndex  =  0 
554561        while  (binIndex <  numCategoricalBins) {
555562          val  bin  =  bins(featureIndex)(binIndex)
556-           val  category  =  bin.category 
563+           val  categories  =  bin.highSplit.categories 
557564          val  features  =  labeledPoint.features
558-           if  (category  ==   features(featureIndex)) {
565+           if  (categories.contains( features(featureIndex) )) {
559566            return  binIndex
560567          }
561568          binIndex +=  1 
@@ -572,7 +579,14 @@ object DecisionTree extends Serializable with Logging {
572579        binIndex
573580      } else  {
574581        //  Perform sequential search to find bin for categorical features.
575-         val  binIndex  =  sequentialBinSearchForCategoricalFeature()
582+         val  binIndex  =  {
583+           if  (strategy.isMultiClassification) {
584+             sequentialBinSearchForCategoricalFeatureInBinaryClassification()
585+           }
586+           else  {
587+             sequentialBinSearchForCategoricalFeatureInMultiClassClassification()
588+           }
589+         }
576590        if  (binIndex ==  - 1 ){
577591          throw  new  UnknownError (" no bin was found for categorical variable."  )
578592        }
@@ -584,7 +598,8 @@ object DecisionTree extends Serializable with Logging {
584598     * Finds bins for all nodes (and all features) at a given level. 
585599     * For l nodes, k features the storage is as follows: 
586600     * label, b_11, b_12, .. , b_1k, b_21, b_22, .. , b_2k, b_l1, b_l2, .. , b_lk, 
587-      * where b_ij is an integer between 0 and numBins - 1. 
601+      * where b_ij is an integer between 0 and numBins - 1 for regression and binary
602+      * classification, and an invalid value for categorical features in multiclass classification.
588603     * Invalid sample is denoted by noting bin for feature 1 as -1. 
589604     */  
590605    def  findBinsForLevel (labeledPoint : WeightedLabeledPoint ):  Array [Double ] =  {
@@ -646,7 +661,22 @@ object DecisionTree extends Serializable with Logging {
646661              =  aggShift +  numClasses *  featureIndex *  numBins +  arr(arrIndex).toInt *  numClasses
647662            label.toInt match  {
648663              case  n : Int  => 
649-                 agg(aggIndex +  n) =  agg(aggIndex +  n) +  1  *  labelWeights.getOrElse(n, 1 )
664+                 val  isFeatureContinuous  =  strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
665+                 if  (isFeatureContinuous &&  strategy.isMultiClassification) {
666+                   //  Find all matching bins and increment their values
667+                   val  featureCategories  =  strategy.categoricalFeaturesInfo(featureIndex)
668+                   val  numCategoricalBins  =  math.pow(2.0 , featureCategories -  1 ).toInt -  1 
669+                   var  binIndex  =  0 
670+                   while  (binIndex <  numCategoricalBins) {
671+                     if  (bins(featureIndex)(binIndex).highSplit.categories.contains(n)){
672+                       agg(aggIndex +  binIndex)
673+                         =  agg(aggIndex +  binIndex) +  labelWeights.getOrElse(binIndex, 1 )
674+                     }
675+                     binIndex +=  1 
676+                   }
677+                 } else  {
678+                   agg(aggIndex +  n) =  agg(aggIndex +  n) +  labelWeights.getOrElse(n, 1 )
679+                 }
650680            }
651681            featureIndex +=  1 
652682          }
@@ -705,6 +735,7 @@ object DecisionTree extends Serializable with Logging {
705735      agg
706736    }
707737
738+     //  TODO: Double-check this
708739    //  Calculate bin aggregate length for classification or regression.
709740    val  binAggregateLength  =  strategy.algo match  {
710741      case  Classification  =>  numClasses *  numBins *  numFeatures *  numNodes
@@ -785,10 +816,10 @@ object DecisionTree extends Serializable with Logging {
785816          }
786817
787818          if  (leftTotalCount ==  0 ) {
788-             return  new  InformationGainStats (0 , topImpurity, Double .MinValue , topImpurity, 1 )
819+             return  new  InformationGainStats (0 , topImpurity, topImpurity,  Double .MinValue , 1 )
789820          }
790821          if  (rightTotalCount ==  0 ) {
791-             return  new  InformationGainStats (0 , topImpurity, topImpurity,  Double .MinValue ,0 )
822+             return  new  InformationGainStats (0 , topImpurity, Double .MinValue , topImpurity,  1 )
792823          }
793824
794825          val  leftImpurity  =  strategy.impurity.calculate(leftCounts, leftTotalCount)
@@ -812,16 +843,16 @@ object DecisionTree extends Serializable with Logging {
812843            =  leftCounts.zip(rightCounts)
813844              .map{case  (leftCount, rightCount) =>  leftCount +  rightCount}
814845
815-           def  indexOfLargest (array : Seq [Double ]):  Int  =  {
846+           def  indexOfLargestArrayElement (array : Array [Double ]):  Int  =  {
816847            val  result  =  array.foldLeft(- 1 ,Double .MinValue ,0 ) {
817848              case  ((maxIndex, maxValue, currentIndex), currentValue) => 
818849                if (currentValue >  maxValue) (currentIndex,currentValue,currentIndex+ 1 )
819850                else  (maxIndex,maxValue,currentIndex+ 1 )
820851            }
821-             if  (result._1 <  0 ) result._1  else  0 
852+             if  (result._1 <  0 ) 0  else  result._1 
822853          }
823854
824-           val  predict  =  indexOfLargest (leftRightCounts)
855+           val  predict  =  indexOfLargestArrayElement (leftRightCounts)
825856          val  prob  =  leftRightCounts(predict) /  totalCount
826857
827858          new  InformationGainStats (gain, impurity, leftImpurity, rightImpurity, predict, prob)
@@ -1051,8 +1082,20 @@ object DecisionTree extends Serializable with Logging {
10511082        while  (featureIndex <  numFeatures) {
10521083          //  Iterate over all splits.
10531084          var  splitIndex  =  0 
1054-           //  TODO: Modify this for categorical variables to go over only valid splits
1055-           while  (splitIndex <  numBins -  1 ) {
1085+           val  maxSplitIndex  :  Double  =  {
1086+             val  isFeatureContinuous  =  strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
1087+             if  (isFeatureContinuous) {
1088+               numBins -  1 
1089+             } else  { //  Categorical feature
1090+               val  featureCategories  =  strategy.categoricalFeaturesInfo(featureIndex)
1091+               if  (strategy.isMultiClassification) {
1092+                 math.pow(2.0 , featureCategories -  1 ).toInt -  1 
1093+               } else  { //  Binary classification
1094+                 featureCategories
1095+               }
1096+             }
1097+           }
1098+           while  (splitIndex <  maxSplitIndex) {
10561099            val  gainStats  =  gains(featureIndex)(splitIndex)
10571100            if  (gainStats.gain >  bestGainStats.gain) {
10581101              bestGainStats =  gainStats
@@ -1176,24 +1219,29 @@ object DecisionTree extends Serializable with Logging {
11761219              splits(featureIndex)(index) =  split
11771220            }
11781221          } else  { //  Categorical feature
1179-             val  maxFeatureValue  =  strategy.categoricalFeaturesInfo(featureIndex)
1222+             val  featureCategories  =  strategy.categoricalFeaturesInfo(featureIndex)
11801223
11811224            //  Use different bin/split calculation strategy for multiclass classification
11821225            if  (strategy.isMultiClassification) {
1183-               //  Iterate from 0 to 2^maxFeatureValue - 1 leading to 2^(maxFeatureValue- 1) - 1
1184-               //  combinations.
1226+               //  2^(featureCategories - 1) - 1 combinations
11851227              var  index  =  0 
1186-               while  (index <  math.pow(2.0 , maxFeatureValue ).toInt -  1 ) {
1228+               while  (index <  math.pow(2.0 , featureCategories  -   1 ).toInt -  1 ) {
11871229                val  categories :  List [Double ]
1188-                   =  extractMultiClassCategories(index +  1 , maxFeatureValue )
1230+                   =  extractMultiClassCategories(index +  1 , featureCategories )
11891231                splits(featureIndex)(index)
11901232                  =  new  Split (featureIndex, Double .MinValue , Categorical , categories)
11911233                bins(featureIndex)(index) =  {
11921234                  if  (index ==  0 ) {
1193-                     new  Bin (new  DummyCategoricalSplit (featureIndex, Categorical ),
1194-                       splits(featureIndex)(0 ), Categorical , Double .MinValue )
1235+                     new  Bin (
1236+                       new  DummyCategoricalSplit (featureIndex, Categorical ),
1237+                       splits(featureIndex)(0 ),
1238+                       Categorical ,
1239+                       Double .MinValue )
11951240                  } else  {
1196-                     new  Bin (splits(featureIndex)(index- 1 ), splits(featureIndex)(index), Categorical ,
1241+                     new  Bin (
1242+                       splits(featureIndex)(index -  1 ),
1243+                       splits(featureIndex)(index),
1244+                       Categorical ,
11971245                      Double .MinValue )
11981246                  }
11991247                }
@@ -1210,7 +1258,7 @@ object DecisionTree extends Serializable with Logging {
12101258
12111259              //  Check for missing categorical variables and putting them last in the sorted list.
12121260              val  fullCentroidForCategories  =  scala.collection.mutable.Map [Double ,Double ]()
1213-               for  (i <-  0  until maxFeatureValue ) {
1261+               for  (i <-  0  until featureCategories ) {
12141262                if  (centroidForCategories.contains(i)) {
12151263                  fullCentroidForCategories(i) =  centroidForCategories(i)
12161264                } else  {
0 commit comments