@@ -545,17 +545,24 @@ object DecisionTree extends Serializable with Logging {
545545 - 1
546546 }
547547
548+ /**
549+ * Sequential search helper method to find bin for categorical feature in multiclass
550+ * classification. Dummy value of 0 used since it is not used in future calculation
551+ */
552+ def sequentialBinSearchForCategoricalFeatureInBinaryClassification (): Int = 0
553+
548554 /**
549555 * Sequential search helper method to find bin for categorical feature.
550556 */
551- def sequentialBinSearchForCategoricalFeature (): Int = {
552- val numCategoricalBins = strategy.categoricalFeaturesInfo(featureIndex)
557+ def sequentialBinSearchForCategoricalFeatureInMultiClassClassification (): Int = {
558+ val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
559+ val numCategoricalBins = math.pow(2.0 , featureCategories - 1 ).toInt - 1
553560 var binIndex = 0
554561 while (binIndex < numCategoricalBins) {
555562 val bin = bins(featureIndex)(binIndex)
556- val category = bin.category
563+ val categories = bin.highSplit.categories
557564 val features = labeledPoint.features
558- if (category == features(featureIndex)) {
565+ if (categories.contains( features(featureIndex) )) {
559566 return binIndex
560567 }
561568 binIndex += 1
@@ -572,7 +579,14 @@ object DecisionTree extends Serializable with Logging {
572579 binIndex
573580 } else {
574581 // Perform sequential search to find bin for categorical features.
575- val binIndex = sequentialBinSearchForCategoricalFeature()
582+ val binIndex = {
583+ if (strategy.isMultiClassification) {
584+ sequentialBinSearchForCategoricalFeatureInBinaryClassification()
585+ }
586+ else {
587+ sequentialBinSearchForCategoricalFeatureInMultiClassClassification()
588+ }
589+ }
576590 if (binIndex == - 1 ){
577591 throw new UnknownError (" no bin was found for categorical variable." )
578592 }
@@ -584,7 +598,8 @@ object DecisionTree extends Serializable with Logging {
584598 * Finds bins for all nodes (and all features) at a given level.
585599 * For l nodes, k features the storage is as follows:
586600 * label, b_11, b_12, .. , b_1k, b_21, b_22, .. , b_2k, b_l1, b_l2, .. , b_lk,
587- * where b_ij is an integer between 0 and numBins - 1.
601+ * where b_ij is an integer between 0 and numBins - 1 for regressions and binary
602+ * classification and an invalid value for categorical feature in multiclass classification.
588603 * Invalid sample is denoted by noting bin for feature 1 as -1.
589604 */
590605 def findBinsForLevel (labeledPoint : WeightedLabeledPoint ): Array [Double ] = {
@@ -646,7 +661,22 @@ object DecisionTree extends Serializable with Logging {
646661 = aggShift + numClasses * featureIndex * numBins + arr(arrIndex).toInt * numClasses
647662 label.toInt match {
648663 case n : Int =>
649- agg(aggIndex + n) = agg(aggIndex + n) + 1 * labelWeights.getOrElse(n, 1 )
664+ val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
665+ if (isFeatureContinuous && strategy.isMultiClassification) {
666+ // Find all matching bins and increment their values
667+ val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
668+ val numCategoricalBins = math.pow(2.0 , featureCategories - 1 ).toInt - 1
669+ var binIndex = 0
670+ while (binIndex < numCategoricalBins) {
671+ if (bins(featureIndex)(binIndex).highSplit.categories.contains(n)){
672+ agg(aggIndex + binIndex)
673+ = agg(aggIndex + binIndex) + labelWeights.getOrElse(binIndex, 1 )
674+ }
675+ binIndex += 1
676+ }
677+ } else {
678+ agg(aggIndex + n) = agg(aggIndex + n) + labelWeights.getOrElse(n, 1 )
679+ }
650680 }
651681 featureIndex += 1
652682 }
@@ -705,6 +735,7 @@ object DecisionTree extends Serializable with Logging {
705735 agg
706736 }
707737
738+ // TODO: Double-check this
708739 // Calculate bin aggregate length for classification or regression.
709740 val binAggregateLength = strategy.algo match {
710741 case Classification => numClasses * numBins * numFeatures * numNodes
@@ -785,10 +816,10 @@ object DecisionTree extends Serializable with Logging {
785816 }
786817
787818 if (leftTotalCount == 0 ) {
788- return new InformationGainStats (0 , topImpurity, Double .MinValue , topImpurity, 1 )
819+ return new InformationGainStats (0 , topImpurity, topImpurity, Double .MinValue , 1 )
789820 }
790821 if (rightTotalCount == 0 ) {
791- return new InformationGainStats (0 , topImpurity, topImpurity, Double .MinValue ,0 )
822+ return new InformationGainStats (0 , topImpurity, Double .MinValue , topImpurity, 1 )
792823 }
793824
794825 val leftImpurity = strategy.impurity.calculate(leftCounts, leftTotalCount)
@@ -812,16 +843,16 @@ object DecisionTree extends Serializable with Logging {
812843 = leftCounts.zip(rightCounts)
813844 .map{case (leftCount, rightCount) => leftCount + rightCount}
814845
815- def indexOfLargest (array : Seq [Double ]): Int = {
846+ def indexOfLargestArrayElement (array : Array [Double ]): Int = {
816847 val result = array.foldLeft(- 1 ,Double .MinValue ,0 ) {
817848 case ((maxIndex, maxValue, currentIndex), currentValue) =>
818849 if (currentValue > maxValue) (currentIndex,currentValue,currentIndex+ 1 )
819850 else (maxIndex,maxValue,currentIndex+ 1 )
820851 }
821- if (result._1 < 0 ) result._1 else 0
852+ if (result._1 < 0 ) 0 else result._1
822853 }
823854
824- val predict = indexOfLargest (leftRightCounts)
855+ val predict = indexOfLargestArrayElement (leftRightCounts)
825856 val prob = leftRightCounts(predict) / totalCount
826857
827858 new InformationGainStats (gain, impurity, leftImpurity, rightImpurity, predict, prob)
@@ -1051,8 +1082,20 @@ object DecisionTree extends Serializable with Logging {
10511082 while (featureIndex < numFeatures) {
10521083 // Iterate over all splits.
10531084 var splitIndex = 0
1054- // TODO: Modify this for categorical variables to go over only valid splits
1055- while (splitIndex < numBins - 1 ) {
1085+ val maxSplitIndex : Double = {
1086+ val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
1087+ if (isFeatureContinuous) {
1088+ numBins - 1
1089+ } else { // Categorical feature
1090+ val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
1091+ if (strategy.isMultiClassification) {
1092+ math.pow(2.0 , featureCategories - 1 ).toInt - 1
1093+ } else { // Binary classification
1094+ featureCategories
1095+ }
1096+ }
1097+ }
1098+ while (splitIndex < maxSplitIndex) {
10561099 val gainStats = gains(featureIndex)(splitIndex)
10571100 if (gainStats.gain > bestGainStats.gain) {
10581101 bestGainStats = gainStats
@@ -1176,24 +1219,29 @@ object DecisionTree extends Serializable with Logging {
11761219 splits(featureIndex)(index) = split
11771220 }
11781221 } else { // Categorical feature
1179- val maxFeatureValue = strategy.categoricalFeaturesInfo(featureIndex)
1222+ val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
11801223
11811224 // Use different bin/split calculation strategy for multiclass classification
11821225 if (strategy.isMultiClassification) {
1183- // Iterate from 0 to 2^maxFeatureValue - 1 leading to 2^(maxFeatureValue- 1) - 1
1184- // combinations.
1226+ // 2^(maxFeatureValue- 1) - 1 combinations
11851227 var index = 0
1186- while (index < math.pow(2.0 , maxFeatureValue ).toInt - 1 ) {
1228+ while (index < math.pow(2.0 , featureCategories - 1 ).toInt - 1 ) {
11871229 val categories : List [Double ]
1188- = extractMultiClassCategories(index + 1 , maxFeatureValue )
1230+ = extractMultiClassCategories(index + 1 , featureCategories )
11891231 splits(featureIndex)(index)
11901232 = new Split (featureIndex, Double .MinValue , Categorical , categories)
11911233 bins(featureIndex)(index) = {
11921234 if (index == 0 ) {
1193- new Bin (new DummyCategoricalSplit (featureIndex, Categorical ),
1194- splits(featureIndex)(0 ), Categorical , Double .MinValue )
1235+ new Bin (
1236+ new DummyCategoricalSplit (featureIndex, Categorical ),
1237+ splits(featureIndex)(0 ),
1238+ Categorical ,
1239+ Double .MinValue )
11951240 } else {
1196- new Bin (splits(featureIndex)(index- 1 ), splits(featureIndex)(index), Categorical ,
1241+ new Bin (
1242+ splits(featureIndex)(index - 1 ),
1243+ splits(featureIndex)(index),
1244+ Categorical ,
11971245 Double .MinValue )
11981246 }
11991247 }
@@ -1210,7 +1258,7 @@ object DecisionTree extends Serializable with Logging {
12101258
12111259 // Check for missing categorical variables and putting them last in the sorted list.
12121260 val fullCentroidForCategories = scala.collection.mutable.Map [Double ,Double ]()
1213- for (i <- 0 until maxFeatureValue ) {
1261+ for (i <- 0 until featureCategories ) {
12141262 if (centroidForCategories.contains(i)) {
12151263 fullCentroidForCategories(i) = centroidForCategories(i)
12161264 } else {
0 commit comments