Commit ab5cb21

multiclass logic
1 parent d8e4a11 commit ab5cb21

4 files changed, +94 -54 lines changed

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 71 additions & 23 deletions
@@ -545,17 +545,24 @@ object DecisionTree extends Serializable with Logging {
         -1
       }
 
+      /**
+       * Sequential search helper method to find bin for categorical feature in multiclass
+       * classification. Dummy value of 0 used since it is not used in future calculation
+       */
+      def sequentialBinSearchForCategoricalFeatureInBinaryClassification(): Int = 0
+
       /**
        * Sequential search helper method to find bin for categorical feature.
        */
-      def sequentialBinSearchForCategoricalFeature(): Int = {
-        val numCategoricalBins = strategy.categoricalFeaturesInfo(featureIndex)
+      def sequentialBinSearchForCategoricalFeatureInMultiClassClassification(): Int = {
+        val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
+        val numCategoricalBins = math.pow(2.0, featureCategories - 1).toInt - 1
         var binIndex = 0
         while (binIndex < numCategoricalBins) {
           val bin = bins(featureIndex)(binIndex)
-          val category = bin.category
+          val categories = bin.highSplit.categories
           val features = labeledPoint.features
-          if (category == features(featureIndex)) {
+          if (categories.contains(features(featureIndex))) {
             return binIndex
           }
           binIndex += 1
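
Note (not part of the commit): the helper above iterates over 2^(featureCategories - 1) - 1 bins because, in the multiclass case, each bin corresponds to one candidate subset of categories (one side of a binary partition) rather than to a single category value, and membership is tested against the subset stored on the bin's highSplit. A standalone Scala sketch of that bin count:

object CategoricalBinCountSketch {
  // 2^(M - 1) - 1 candidate bins/splits for a categorical feature with M categories
  // in multiclass classification, mirroring the expression used in the diff above.
  def numBins(featureCategories: Int): Int =
    math.pow(2.0, featureCategories - 1).toInt - 1

  def main(args: Array[String]): Unit =
    Seq(2, 3, 4).foreach(m => println(s"$m categories -> ${numBins(m)} bins"))
}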
@@ -572,7 +579,14 @@ object DecisionTree extends Serializable with Logging {
         binIndex
       } else {
         // Perform sequential search to find bin for categorical features.
-        val binIndex = sequentialBinSearchForCategoricalFeature()
+        val binIndex = {
+          if (strategy.isMultiClassification) {
+            sequentialBinSearchForCategoricalFeatureInBinaryClassification()
+          }
+          else {
+            sequentialBinSearchForCategoricalFeatureInMultiClassClassification()
+          }
+        }
         if (binIndex == -1){
           throw new UnknownError("no bin was found for categorical variable.")
         }
@@ -584,7 +598,8 @@ object DecisionTree extends Serializable with Logging {
     * Finds bins for all nodes (and all features) at a given level.
     * For l nodes, k features the storage is as follows:
     * label, b_11, b_12, .. , b_1k, b_21, b_22, .. , b_2k, b_l1, b_l2, .. , b_lk,
-    * where b_ij is an integer between 0 and numBins - 1.
+    * where b_ij is an integer between 0 and numBins - 1 for regressions and binary
+    * classification and an invalid value for categorical feature in multiclass classification.
     * Invalid sample is denoted by noting bin for feature 1 as -1.
     */
    def findBinsForLevel(labeledPoint: WeightedLabeledPoint): Array[Double] = {
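
Note (not part of the commit): in that layout a sample row for l nodes and k features holds 1 + l * k values. For example, with l = 2 and k = 3 (illustrative numbers), the row is the label followed by b_11..b_13 and b_21..b_23, i.e. 7 entries.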
@@ -646,7 +661,22 @@ object DecisionTree extends Serializable with Logging {
            = aggShift + numClasses * featureIndex * numBins + arr(arrIndex).toInt * numClasses
          label.toInt match {
            case n: Int =>
-              agg(aggIndex + n) = agg(aggIndex + n) + 1 * labelWeights.getOrElse(n, 1)
+              val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
+              if (isFeatureContinuous && strategy.isMultiClassification) {
+                // Find all matching bins and increment their values
+                val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
+                val numCategoricalBins = math.pow(2.0, featureCategories - 1).toInt - 1
+                var binIndex = 0
+                while (binIndex < numCategoricalBins) {
+                  if (bins(featureIndex)(binIndex).highSplit.categories.contains(n)){
+                    agg(aggIndex + binIndex)
+                      = agg(aggIndex + binIndex) + labelWeights.getOrElse(binIndex, 1)
+                  }
+                  binIndex += 1
+                }
+              } else {
+                agg(aggIndex + n) = agg(aggIndex + n) + labelWeights.getOrElse(n, 1)
+              }
          }
          featureIndex += 1
        }
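
Note (not part of the commit): the new multiclass branch above increments every bin whose category subset matches, rather than a single bin per sample. A simplified, standalone illustration of that "increment all matching bins" idea, with invented names and values:

object MultiBinUpdateSketch {
  def main(args: Array[String]): Unit = {
    // 3 categories -> 2^(3 - 1) - 1 = 3 candidate bins; each bin accepts a subset of categories.
    val binCategories = Array(List(0.0), List(1.0), List(0.0, 1.0))
    val agg = Array.fill(binCategories.length)(0.0)
    val featureValue = 0.0
    binCategories.indices.foreach { binIndex =>
      // Every bin whose subset contains the value gets incremented, not just one bin.
      if (binCategories(binIndex).contains(featureValue)) agg(binIndex) += 1.0
    }
    println(agg.toList) // List(1.0, 0.0, 1.0)
  }
}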
@@ -705,6 +735,7 @@ object DecisionTree extends Serializable with Logging {
      agg
    }
 
+    // TODO: Double-check this
    // Calculate bin aggregate length for classification or regression.
    val binAggregateLength = strategy.algo match {
      case Classification => numClasses * numBins * numFeatures * numNodes
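
Note (not part of the commit): as a quick sanity check of the classification aggregate size, with numClasses = 3, numBins = 7, numFeatures = 2 and numNodes = 4 (illustrative values only), the aggregate would hold 3 * 7 * 2 * 4 = 168 entries.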
@@ -785,10 +816,10 @@ object DecisionTree extends Serializable with Logging {
      }
 
      if (leftTotalCount == 0) {
-        return new InformationGainStats(0, topImpurity, Double.MinValue, topImpurity,1)
+        return new InformationGainStats(0, topImpurity, topImpurity, Double.MinValue, 1)
      }
      if (rightTotalCount == 0) {
-        return new InformationGainStats(0, topImpurity, topImpurity, Double.MinValue,0)
+        return new InformationGainStats(0, topImpurity, Double.MinValue, topImpurity, 1)
      }
 
      val leftImpurity = strategy.impurity.calculate(leftCounts, leftTotalCount)
@@ -812,16 +843,16 @@ object DecisionTree extends Serializable with Logging {
        = leftCounts.zip(rightCounts)
          .map{case (leftCount, rightCount) => leftCount + rightCount}
 
-      def indexOfLargest(array: Seq[Double]): Int = {
+      def indexOfLargestArrayElement(array: Array[Double]): Int = {
        val result = array.foldLeft(-1,Double.MinValue,0) {
          case ((maxIndex, maxValue, currentIndex), currentValue) =>
            if(currentValue > maxValue) (currentIndex,currentValue,currentIndex+1)
            else (maxIndex,maxValue,currentIndex+1)
        }
-        if (result._1 < 0) result._1 else 0
+        if (result._1 < 0) 0 else result._1
      }
 
-      val predict = indexOfLargest(leftRightCounts)
+      val predict = indexOfLargestArrayElement(leftRightCounts)
      val prob = leftRightCounts(predict) / totalCount
 
      new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict, prob)
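
Note (not part of the commit): the renamed helper is a foldLeft-based argmax over the combined class counts; besides the rename, the change makes it return the found index and fall back to 0 only when nothing was found (the old version could only return -1 or 0). A self-contained sketch of the same pattern, with the accumulator tuple written explicitly:

object ArgMaxSketch {
  // Track (bestIndex, bestValue, currentIndex) across the fold and keep the first maximum.
  def indexOfLargestArrayElement(array: Array[Double]): Int = {
    val result = array.foldLeft((-1, Double.MinValue, 0)) {
      case ((maxIndex, maxValue, currentIndex), currentValue) =>
        if (currentValue > maxValue) (currentIndex, currentValue, currentIndex + 1)
        else (maxIndex, maxValue, currentIndex + 1)
    }
    if (result._1 < 0) 0 else result._1
  }

  def main(args: Array[String]): Unit =
    println(indexOfLargestArrayElement(Array(2.0, 5.0, 3.0))) // prints 1
}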
@@ -1051,8 +1082,20 @@ object DecisionTree extends Serializable with Logging {
      while (featureIndex < numFeatures) {
        // Iterate over all splits.
        var splitIndex = 0
-        // TODO: Modify this for categorical variables to go over only valid splits
-        while (splitIndex < numBins - 1) {
+        val maxSplitIndex : Double = {
+          val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
+          if (isFeatureContinuous) {
+            numBins - 1
+          } else { // Categorical feature
+            val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
+            if (strategy.isMultiClassification) {
+              math.pow(2.0, featureCategories - 1).toInt - 1
+            } else { // Binary classification
+              featureCategories
+            }
+          }
+        }
+        while (splitIndex < maxSplitIndex) {
          val gainStats = gains(featureIndex)(splitIndex)
          if (gainStats.gain > bestGainStats.gain) {
            bestGainStats = gainStats
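
Note (not part of the commit): concretely, with numBins = 100 a continuous feature still gets 99 candidate splits, while a categorical feature with 4 categories gets 2^3 - 1 = 7 candidates under multiclass classification and 4 under binary classification (illustrative numbers only).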
@@ -1176,24 +1219,29 @@ object DecisionTree extends Serializable with Logging {
            splits(featureIndex)(index) = split
          }
        } else { // Categorical feature
-          val maxFeatureValue = strategy.categoricalFeaturesInfo(featureIndex)
+          val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
 
          // Use different bin/split calculation strategy for multiclass classification
          if (strategy.isMultiClassification) {
-            // Iterate from 0 to 2^maxFeatureValue - 1 leading to 2^(maxFeatureValue- 1) - 1
-            // combinations.
+            // 2^(maxFeatureValue- 1) - 1 combinations
            var index = 0
-            while (index < math.pow(2.0, maxFeatureValue).toInt - 1) {
+            while (index < math.pow(2.0, featureCategories - 1).toInt - 1) {
              val categories: List[Double]
-                = extractMultiClassCategories(index + 1, maxFeatureValue)
+                = extractMultiClassCategories(index + 1, featureCategories)
              splits(featureIndex)(index)
                = new Split(featureIndex, Double.MinValue, Categorical, categories)
              bins(featureIndex)(index) = {
                if (index == 0) {
-                  new Bin(new DummyCategoricalSplit(featureIndex, Categorical),
-                    splits(featureIndex)(0), Categorical, Double.MinValue)
+                  new Bin(
+                    new DummyCategoricalSplit(featureIndex, Categorical),
+                    splits(featureIndex)(0),
+                    Categorical,
+                    Double.MinValue)
                } else {
-                  new Bin(splits(featureIndex)(index-1), splits(featureIndex)(index), Categorical,
+                  new Bin(
+                    splits(featureIndex)(index - 1),
+                    splits(featureIndex)(index),
+                    Categorical,
                    Double.MinValue)
                }
              }
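
Note (not part of the commit): extractMultiClassCategories(index + 1, featureCategories) is called here, but its body is not part of this diff. A plausible, simplified sketch of the underlying idea, treating the integer as a bitmask over category ids, is shown below; it is illustrative only and not the commit's implementation:

object CategorySubsetSketch {
  // Decode the set bits of `input` into the corresponding category ids (as Doubles).
  def extractCategories(input: Int, featureCategories: Int): List[Double] =
    (0 until featureCategories).collect {
      case bit if ((input >> bit) & 1) == 1 => bit.toDouble
    }.toList

  def main(args: Array[String]): Unit =
    // For 3 categories, indices 1 to 3 give the subsets {0.0}, {1.0}, {0.0, 1.0}.
    (1 to 3).foreach(i => println(s"$i -> ${extractCategories(i, 3)}"))
}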
@@ -1210,7 +1258,7 @@ object DecisionTree extends Serializable with Logging {
 
        // Check for missing categorical variables and putting them last in the sorted list.
        val fullCentroidForCategories = scala.collection.mutable.Map[Double,Double]()
-        for (i <- 0 until maxFeatureValue) {
+        for (i <- 0 until featureCategories) {
          if (centroidForCategories.contains(i)) {
            fullCentroidForCategories(i) = centroidForCategories(i)
          } else {

mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala

Lines changed: 5 additions & 2 deletions
@@ -42,8 +42,11 @@ object Entropy extends Impurity {
      var impurity = 0.0
      var classIndex = 0
      while (classIndex < numClasses) {
-        val freq = counts(classIndex) / totalCount
-        impurity -= freq * log2(freq)
+        val classCount = counts(classIndex)
+        if (classCount != 0) {
+          val freq = classCount / totalCount
+          impurity -= freq * log2(freq)
+        }
        classIndex += 1
      }
      impurity
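
Note (not part of the commit): the guard matters because a class with zero count would otherwise give freq = 0 and log2(0) = -Infinity, so 0 * -Infinity would turn the sum into NaN. A minimal standalone version of the guarded loop:

object EntropySketch {
  private def log2(x: Double): Double = math.log(x) / math.log(2)

  // Entropy over class counts; classes with zero count are skipped.
  def entropy(counts: Array[Double], totalCount: Double): Double = {
    var impurity = 0.0
    var classIndex = 0
    while (classIndex < counts.length) {
      val classCount = counts(classIndex)
      if (classCount != 0) {
        val freq = classCount / totalCount
        impurity -= freq * log2(freq)
      }
      classIndex += 1
    }
    impurity
  }

  def main(args: Array[String]): Unit =
    println(entropy(Array(5.0, 5.0, 0.0), 10.0)) // 1.0; the absent class contributes nothing
}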

mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala

Lines changed: 1 addition & 1 deletion
@@ -28,7 +28,7 @@ import org.apache.spark.mllib.tree.configuration.FeatureType._
 * @param highSplit signifying the upper threshold for the continuous feature to be
 *                  accepted in the bin
 * @param featureType type of feature -- categorical or continuous
- * @param category categorical label value accepted in the bin
+ * @param category categorical label value accepted in the bin for binary classification
 */
private[tree]
case class Bin(lowSplit: Split, highSplit: Split, featureType: FeatureType, category: Double)

mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala

Lines changed: 17 additions & 28 deletions
@@ -35,7 +35,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
    val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()
    assert(arr.length === 1000)
    val rdd = sc.parallelize(arr)
-    val strategy = new Strategy(Classification, Gini, 3, 100)
+    val strategy = new Strategy(Classification, Gini, 3, 2, 100)
    val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
    assert(splits.length === 2)
    assert(bins.length === 2)
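
Note (not part of the commit): the extra 2 in new Strategy(Classification, Gini, 3, 2, 100) supplies the number of classes, matching the named-argument form numClassesForClassification = 2 used in the other updated tests; this suggests the positional order (algo, impurity, maxDepth, numClassesForClassification, maxBins), though the constructor itself is not shown in this diff.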
@@ -51,6 +51,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
      Classification,
      Gini,
      maxDepth = 3,
+      numClassesForClassification = 2,
      maxBins = 100,
      categoricalFeaturesInfo = Map(0 -> 2, 1-> 2))
    val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
@@ -130,6 +131,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
      Classification,
      Gini,
      maxDepth = 3,
+      numClassesForClassification = 2,
      maxBins = 100,
      categoricalFeaturesInfo = Map(0 -> 3, 1-> 3))
    val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
@@ -237,20 +239,20 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
    assert(List(3.0, 2.0, 0.0).toSeq == l.toSeq)
  }
 
-  test("split and bin calculations for categorical variables wiht multiclass classification") {
+  test("split and bin calculations for categorical variables with multiclass classification") {
    val arr = DecisionTreeSuite.generateCategoricalDataPoints()
    assert(arr.length === 1000)
    val rdd = sc.parallelize(arr)
    val strategy = new Strategy(
      Classification,
      Gini,
      maxDepth = 3,
+      numClassesForClassification = 100,
      maxBins = 100,
-      categoricalFeaturesInfo = Map(0 -> 2, 1-> 2),
-      numClassesForClassification = 3)
+      categoricalFeaturesInfo = Map(0 -> 3, 1-> 3))
    val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
 
-    // Expecting 2^3 - 1 = 7 bins/splits
+    // Expecting 2^2 - 1 = 3 bins/splits
    assert(splits(0)(0).feature === 0)
    assert(splits(0)(0).threshold === Double.MinValue)
    assert(splits(0)(0).featureType === Categorical)
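
Note (not part of the commit): with categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3) each feature has 3 categories, so the multiclass rule gives 2^(3 - 1) - 1 = 3 candidate splits/bins per feature; the updated comment and the assertions on indices 0 through 2 (with index 3 expected to be null) check exactly that count.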
@@ -287,6 +289,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
    assert(splits(1)(2).categories.contains(1.0))
 
    assert(splits(0)(3) === null)
+    assert(splits(1)(3) === null)
 
 
    // Check bins.
@@ -329,29 +332,14 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
 
  }
 
-  test("split and bin calculations for categorical variables with no sample for one category " +
-    "for multiclass classification") {
-    val arr = DecisionTreeSuite.generateCategoricalDataPoints()
-    assert(arr.length === 1000)
-    val rdd = sc.parallelize(arr)
-    val strategy = new Strategy(
-      Classification,
-      Gini,
-      maxDepth = 3,
-      maxBins = 100,
-      categoricalFeaturesInfo = Map(0 -> 3, 1-> 3),
-      numClassesForClassification = 3)
-    val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
-
-  }
-
  test("classification stump with all categorical variables") {
    val arr = DecisionTreeSuite.generateCategoricalDataPoints()
    assert(arr.length === 1000)
    val rdd = sc.parallelize(arr)
    val strategy = new Strategy(
      Classification,
      Gini,
+      numClassesForClassification = 2,
      maxDepth = 3,
      maxBins = 100,
      categoricalFeaturesInfo = Map(0 -> 3, 1-> 3))
@@ -367,8 +355,9 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
 
    val stats = bestSplits(0)._2
    assert(stats.gain > 0)
-    assert(stats.predict > 0.4)
-    assert(stats.predict < 0.5)
+    assert(stats.predict === 0)
+    assert(stats.prob > 0.5)
+    assert(stats.prob < 0.6)
    assert(stats.impurity > 0.2)
  }
 
@@ -403,7 +392,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
    val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0()
    assert(arr.length === 1000)
    val rdd = sc.parallelize(arr)
-    val strategy = new Strategy(Classification, Gini, 3, 100)
+    val strategy = new Strategy(Classification, Gini, 3, 2, 100)
    val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
    assert(splits.length === 2)
    assert(splits(0).length === 99)
@@ -426,7 +415,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
    val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()
    assert(arr.length === 1000)
    val rdd = sc.parallelize(arr)
-    val strategy = new Strategy(Classification, Gini, 3, 100)
+    val strategy = new Strategy(Classification, Gini, 3, 2, 100)
    val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
    assert(splits.length === 2)
    assert(splits(0).length === 99)
@@ -450,7 +439,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
    val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0()
    assert(arr.length === 1000)
    val rdd = sc.parallelize(arr)
-    val strategy = new Strategy(Classification, Entropy, 3, 100)
+    val strategy = new Strategy(Classification, Entropy, 3, 2, 100)
    val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
    assert(splits.length === 2)
    assert(splits(0).length === 99)
@@ -474,7 +463,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
    val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()
    assert(arr.length === 1000)
    val rdd = sc.parallelize(arr)
-    val strategy = new Strategy(Classification, Entropy, 3, 100)
+    val strategy = new Strategy(Classification, Entropy, 3, 2, 100)
    val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
    assert(splits.length === 2)
    assert(splits(0).length === 99)
@@ -498,7 +487,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
    val arr = DecisionTreeSuite.generateOrderedLabeledPoints()
    assert(arr.length === 1000)
    val rdd = sc.parallelize(arr)
-    val strategy = new Strategy(Classification, Entropy, 3, 100)
+    val strategy = new Strategy(Classification, Entropy, 3, 2, 100)
    val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
    assert(splits.length === 2)
    assert(splits(0).length === 99)
