diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
index 881dcefb79be3..59aaa1cd457a7 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -82,6 +82,9 @@ class DecisionTreeClassifier @Since("1.4.0") (
   @Since("1.6.0")
   override def setSeed(value: Long): this.type = super.setSeed(value)
 
+  @Since("2.0.0")
+  override def setClassWeights(value: Array[Double]): this.type = super.setClassWeights(value)
+
   override protected def train(dataset: Dataset[_]): DecisionTreeClassificationModel = {
     val categoricalFeatures: Map[Int, Int] =
       MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
@@ -119,7 +122,7 @@ class DecisionTreeClassifier @Since("1.4.0") (
       categoricalFeatures: Map[Int, Int],
       numClasses: Int): OldStrategy = {
     super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity,
-      subsamplingRate = 1.0)
+      subsamplingRate = 1.0, getClassWeights)
   }
 
   @Since("1.4.1")
@@ -129,7 +132,7 @@ class DecisionTreeClassifier @Since("1.4.0") (
 @Since("1.4.0")
 @Experimental
 object DecisionTreeClassifier extends DefaultParamsReadable[DecisionTreeClassifier] {
-  /** Accessor for supported impurities: entropy, gini */
+  /** Accessor for supported impurities: entropy, gini, weightedgini */
   @Since("1.4.0")
   final val supportedImpurities: Array[String] = TreeClassifierParams.supportedImpurities
 
@@ -168,7 +171,7 @@ class DecisionTreeClassificationModel private[ml] (
   }
 
   override protected def predictRaw(features: Vector): Vector = {
-    Vectors.dense(rootNode.predictImpl(features).impurityStats.stats.clone())
+    Vectors.dense(rootNode.predictImpl(features).impurityStats.weightedStats.clone())
   }
 
   override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index b3c074f839250..5e61b759c7c6e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -98,13 +98,17 @@ class RandomForestClassifier @Since("1.4.0") (
   override def setFeatureSubsetStrategy(value: String): this.type =
     super.setFeatureSubsetStrategy(value)
 
+  @Since("2.0.0")
+  override def setClassWeights(value: Array[Double]): this.type = super.setClassWeights(value)
+
   override protected def train(dataset: Dataset[_]): RandomForestClassificationModel = {
     val categoricalFeatures: Map[Int, Int] =
       MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
     val numClasses: Int = getNumClasses(dataset)
     val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, numClasses)
     val strategy =
-      super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity)
+      super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification,
+        getOldImpurity, getSubsamplingRate, getClassWeights)
 
     val instr = Instrumentation.create(this, oldDataset)
     instr.logParams(params: _*)
@@ -195,7 +199,8 @@ class RandomForestClassificationModel private[ml] (
       // Ignore the tree weights since all are 1.0 for now.
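+      // Each tree contributes its leaf's (class-weighted) distribution,
+      // normalized below to sum to 1, so every tree casts one unit vote.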
       val votes = Array.fill[Double](numClasses)(0.0)
       _trees.view.foreach { tree =>
-        val classCounts: Array[Double] = tree.rootNode.predictImpl(features).impurityStats.stats
+        val classCounts: Array[Double] =
+          tree.rootNode.predictImpl(features).impurityStats.weightedStats
         val total = classCounts.sum
         if (total != 0) {
           var i = 0
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
index c4df9d11127f4..b2fe5ded61793 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
@@ -117,7 +117,7 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S
   /** (private[ml]) Create a Strategy instance to use with the old API. */
   private[ml] def getOldStrategy(categoricalFeatures: Map[Int, Int]): OldStrategy = {
     super.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, getOldImpurity,
-      subsamplingRate = 1.0)
+      subsamplingRate = 1.0, classWeights = Array())
   }
 
   @Since("1.4.0")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
index a6dbf21d55e2b..9429a053804d1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
@@ -98,7 +98,8 @@ class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S
       MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
     val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
     val strategy =
-      super.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, getOldImpurity)
+      super.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression,
+        getOldImpurity, getSubsamplingRate, classWeights = Array())
 
     val instr = Instrumentation.create(this, oldDataset)
     instr.logParams(params: _*)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala
index 61091bb803e49..3d175006b9abe 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala
@@ -20,7 +20,6 @@ package org.apache.spark.ml.tree.impl
 
 import org.apache.spark.mllib.tree.impurity._
 
-
 /**
  * DecisionTree statistics aggregator for a node.
  * This holds a flat array of statistics for a set of (features, bins)
@@ -38,6 +37,7 @@ private[spark] class DTStatsAggregator(
     case Gini => new GiniAggregator(metadata.numClasses)
     case Entropy => new EntropyAggregator(metadata.numClasses)
     case Variance => new VarianceAggregator()
+    case WeightedGini => new WeightedGiniAggregator(metadata.numClasses, metadata.classWeights)
     case _ => throw new IllegalArgumentException(s"Bad impurity parameter: ${metadata.impurity}")
   }
 
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala
index 442f52bf0231d..a8ad966adf1cb 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala
@@ -53,7 +53,8 @@ private[spark] class DecisionTreeMetadata(
     val minInstancesPerNode: Int,
     val minInfoGain: Double,
     val numTrees: Int,
-    val numFeaturesPerNode: Int) extends Serializable {
+    val numFeaturesPerNode: Int,
+    val classWeights: Array[Double]) extends Serializable {
 
   def isUnordered(featureIndex: Int): Boolean = unorderedFeatures.contains(featureIndex)
 
@@ -207,7 +208,8 @@ private[spark] object DecisionTreeMetadata extends Logging {
     new DecisionTreeMetadata(numFeatures, numExamples, numClasses, numBins.max,
       strategy.categoricalFeaturesInfo, unorderedFeatures.toSet, numBins, strategy.impurity,
       strategy.quantileCalculationStrategy, strategy.maxDepth,
-      strategy.minInstancesPerNode, strategy.minInfoGain, numTrees, numFeaturesPerNode)
+      strategy.minInstancesPerNode, strategy.minInfoGain, numTrees, numFeaturesPerNode,
+      strategy.classWeights)
   }
 
   /**
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
index 71c8c42ce5eba..fe83d602764a4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
@@ -657,8 +657,15 @@ private[spark] object RandomForest extends Logging {
     val leftImpurity = leftImpurityCalculator.calculate() // Note: This equals 0 if count = 0
     val rightImpurity = rightImpurityCalculator.calculate()
 
-    val leftWeight = leftCount / totalCount.toDouble
-    val rightWeight = rightCount / totalCount.toDouble
+    // The weighted count reduces to the plain count for Gini or Entropy impurity,
+    // where the class weights are implicitly uniform.
+    val leftWeightedCount = leftImpurityCalculator.weightedCount
+    val rightWeightedCount = rightImpurityCalculator.weightedCount
+
+    val totalWeightedCount = leftWeightedCount + rightWeightedCount
+
+    val leftWeight = leftWeightedCount / totalWeightedCount
+    val rightWeight = rightWeightedCount / totalWeightedCount
 
     val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity
 
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
index 56c85c9b53e17..029ccfec2e2cb 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
@@ -342,9 +342,17 @@ private[ml] object DecisionTreeModelReadWrite {
       Param.jsonDecode[String](compact(render(impurityJson)))
     }
 
+    // Get class weights to construct the ImpurityCalculator. This value
+    // is ignored unless the impurity is WeightedGini.
+    val classWeights: Array[Double] = {
+      val classWeightsJson: JValue = metadata.getParamValue("classWeights")
+      compact(render(classWeightsJson)).split("\\[|,|\\]")
+        .filter((s: String) => s.length() != 0).map((s: String) => s.toDouble)
+    }
+
     val dataPath = new Path(path, "data").toString
     val data = sqlContext.read.parquet(dataPath).as[NodeData]
-    buildTreeFromNodes(data.collect(), impurityType)
+    buildTreeFromNodes(data.collect(), impurityType, classWeights)
   }
 
   /**
@@ -353,7 +361,8 @@ private[ml] object DecisionTreeModelReadWrite {
    * @param impurityType Impurity type for this tree
+   * @param classWeights Class weights used to rebuild each node's weighted impurity statistics
    * @return Root node of reconstructed tree
    */
-  def buildTreeFromNodes(data: Array[NodeData], impurityType: String): Node = {
+  def buildTreeFromNodes(data: Array[NodeData], impurityType: String,
+      classWeights: Array[Double]): Node = {
     // Load all nodes, sorted by ID.
     val nodes = data.sortBy(_.id)
     // Sanity checks; could remove
@@ -365,7 +374,8 @@ private[ml] object DecisionTreeModelReadWrite {
     // traversal, this guarantees that child nodes will be built before parent nodes.
     val finalNodes = new Array[Node](nodes.length)
     nodes.reverseIterator.foreach { case n: NodeData =>
-      val impurityStats = ImpurityCalculator.getCalculator(impurityType, n.impurityStats)
+      val impurityStats = ImpurityCalculator.getCalculator(impurityType,
+        n.impurityStats, classWeights)
       val node = if (n.leftChild != -1) {
         val leftChild = finalNodes(n.leftChild)
         val rightChild = finalNodes(n.rightChild)
@@ -437,6 +447,15 @@ private[ml] object EnsembleModelReadWrite {
       Param.jsonDecode[String](compact(render(impurityJson)))
     }
 
+    // Get class weights to construct the ImpurityCalculator. This value
+    // is ignored unless the impurity is WeightedGini.
+    val classWeights: Array[Double] = {
+      val classWeightsJson: JValue = metadata.getParamValue("classWeights")
+      compact(render(classWeightsJson)).split("\\[|,|\\]")
+        .filter((s: String) => s.length() != 0).map((s: String) => s.toDouble)
+    }
+
     val treesMetadataPath = new Path(path, "treesMetadata").toString
     val treesMetadataRDD: RDD[(Int, (Metadata, Double))] = sql.read.parquet(treesMetadataPath)
       .select("treeID", "metadata", "weights").as[(Int, String, Double)].rdd.map {
@@ -454,7 +473,8 @@ private[ml] object EnsembleModelReadWrite {
     val rootNodesRDD: RDD[(Int, Node)] =
       nodeData.rdd.map(d => (d.treeID, d.nodeData)).groupByKey().map {
         case (treeID: Int, nodeData: Iterable[NodeData]) =>
-          treeID -> DecisionTreeModelReadWrite.buildTreeFromNodes(nodeData.toArray, impurityType)
+          treeID -> DecisionTreeModelReadWrite.buildTreeFromNodes(nodeData.toArray,
+            impurityType, classWeights)
       }
     val rootNodes: Array[Node] = rootNodesRDD.sortByKey().values.collect()
     (metadata, treesMetadata.zip(rootNodes), treesWeights)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
index d7559f8950c3d..aba5ab1aec455 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
@@ -24,7 +24,7 @@ import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util.SchemaUtils
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, BoostingStrategy => OldBoostingStrategy, Strategy => OldStrategy}
-import org.apache.spark.mllib.tree.impurity.{Entropy => OldEntropy, Gini => OldGini, Impurity => OldImpurity, Variance => OldVariance}
+import org.apache.spark.mllib.tree.impurity.{Entropy => OldEntropy, Gini => OldGini, Impurity => OldImpurity, Variance => OldVariance, WeightedGini}
 import org.apache.spark.mllib.tree.loss.{AbsoluteError => OldAbsoluteError, LogLoss => OldLogLoss, Loss => OldLoss, SquaredError => OldSquaredError}
 import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
 
@@ -155,7 +155,31 @@ private[ml] trait DecisionTreeParams extends PredictorParams
    */
   def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value)
 
-  /** (private[ml]) Create a Strategy instance to use with the old API. */
+  /** (private[ml]) Create a Strategy instance with the given subsampling rate and class weights. */
+  private[ml] def getOldStrategy(
+      categoricalFeatures: Map[Int, Int],
+      numClasses: Int,
+      oldAlgo: OldAlgo.Algo,
+      oldImpurity: OldImpurity,
+      subsamplingRate: Double,
+      classWeights: Array[Double]): OldStrategy = {
+    val strategy = OldStrategy.defaultStrategy(oldAlgo)
+    strategy.impurity = oldImpurity
+    strategy.checkpointInterval = getCheckpointInterval
+    strategy.maxBins = getMaxBins
+    strategy.maxDepth = getMaxDepth
+    strategy.maxMemoryInMB = getMaxMemoryInMB
+    strategy.minInfoGain = getMinInfoGain
+    strategy.minInstancesPerNode = getMinInstancesPerNode
+    strategy.useNodeIdCache = getCacheNodeIds
+    strategy.numClasses = numClasses
+    strategy.categoricalFeaturesInfo = categoricalFeatures
+    strategy.subsamplingRate = subsamplingRate
+    strategy.classWeights = classWeights
+    strategy
+  }
+
+  /** (private[ml]) Create a Strategy with the old API's signature, using uniform class weights. */
   private[ml] def getOldStrategy(
       categoricalFeatures: Map[Int, Int],
       numClasses: Int,
@@ -174,6 +198,7 @@ private[ml] trait DecisionTreeParams extends PredictorParams
     strategy.numClasses = numClasses
     strategy.categoricalFeaturesInfo = categoricalFeatures
     strategy.subsamplingRate = subsamplingRate
+    strategy.classWeights = Array(1.0, 1.0)
     strategy
   }
 }
@@ -185,7 +210,7 @@ private[ml] trait TreeClassifierParams extends Params {
 
   /**
    * Criterion used for information gain calculation (case-insensitive).
-   * Supported: "entropy" and "gini".
+   * Supported: "entropy", "gini" and "weightedgini".
    * (default = gini)
    * @group param
    */
@@ -194,7 +219,15 @@ private[ml] trait TreeClassifierParams extends Params {
       s" ${TreeClassifierParams.supportedImpurities.mkString(", ")}",
     (value: String) => TreeClassifierParams.supportedImpurities.contains(value.toLowerCase))
 
-  setDefault(impurity -> "gini")
+  /**
+   * An array that stores the weights of class labels. All elements must be non-negative.
+   * (default = Array(1.0, 1.0))
+   * @group param
+   */
+  final val classWeights: DoubleArrayParam = new DoubleArrayParam(this, "classWeights", "An array" +
+    " that stores the weights of class labels. All elements must be non-negative.")
+
+  setDefault(impurity -> "gini", classWeights -> Array(1.0, 1.0))
 
   /** @group setParam */
   def setImpurity(value: String): this.type = set(impurity, value)
@@ -202,11 +235,18 @@ private[ml] trait TreeClassifierParams extends Params {
   /** @group getParam */
   final def getImpurity: String = $(impurity).toLowerCase
 
+  /** @group setParam */
+  def setClassWeights(value: Array[Double]): this.type = set(classWeights, value)
+
+  /** @group getParam */
+  final def getClassWeights: Array[Double] = $(classWeights)
+
   /** Convert new impurity to old impurity. */
   private[ml] def getOldImpurity: OldImpurity = {
     getImpurity match {
       case "entropy" => OldEntropy
       case "gini" => OldGini
+      case "weightedgini" => WeightedGini
       case _ =>
         // Should never happen because of check in setter method.
         throw new RuntimeException(
@@ -217,7 +257,8 @@ private[ml] trait TreeClassifierParams extends Params {
 
 private[ml] object TreeClassifierParams {
   // These options should be lowercase.
-  final val supportedImpurities: Array[String] = Array("entropy", "gini").map(_.toLowerCase)
+  final val supportedImpurities: Array[String] = Array("entropy", "gini", "weightedgini")
+    .map(_.toLowerCase)
 }
 
 private[ml] trait DecisionTreeClassifierParams
@@ -239,7 +280,16 @@ private[ml] trait TreeRegressorParams extends Params {
       s" ${TreeRegressorParams.supportedImpurities.mkString(", ")}",
     (value: String) => TreeRegressorParams.supportedImpurities.contains(value.toLowerCase))
 
-  setDefault(impurity -> "variance")
+  /**
+   * An array that stores the weights of class labels. This parameter will be ignored in
+   * regression trees.
+   * (default = Array())
+   * @group param
+   */
+  final val classWeights: DoubleArrayParam = new DoubleArrayParam(this, "classWeights", "An array" +
+    " that stores the weights of class labels. All elements must be non-negative.")
+
+  setDefault(impurity -> "variance", classWeights -> Array())
 
   /** @group setParam */
   def setImpurity(value: String): this.type = set(impurity, value)
@@ -247,6 +297,12 @@ private[ml] trait TreeRegressorParams extends Params {
   /** @group getParam */
   final def getImpurity: String = $(impurity).toLowerCase
 
+  /** @group setParam */
+  def setClassWeights(value: Array[Double]): this.type = set(classWeights, value)
+
+  /** @group getParam */
+  final def getClassWeights: Array[Double] = $(classWeights)
+
   /** Convert new impurity to old impurity. */
   private[ml] def getOldImpurity: OldImpurity = {
     getImpurity match {
@@ -312,8 +368,19 @@ private[ml] trait TreeEnsembleParams extends DecisionTreeParams {
       categoricalFeatures: Map[Int, Int],
       numClasses: Int,
       oldAlgo: OldAlgo.Algo,
-      oldImpurity: OldImpurity): OldStrategy = {
-    super.getOldStrategy(categoricalFeatures, numClasses, oldAlgo, oldImpurity, getSubsamplingRate)
+      oldImpurity: OldImpurity,
+      classWeights: Array[Double]): OldStrategy = {
+    super.getOldStrategy(categoricalFeatures, numClasses, oldAlgo,
+      oldImpurity, getSubsamplingRate, classWeights)
+  }
+
+  private[ml] def getOldStrategy(
+      categoricalFeatures: Map[Int, Int],
+      numClasses: Int,
+      oldAlgo: OldAlgo.Algo,
+      oldImpurity: OldImpurity): OldStrategy = {
+    super.getOldStrategy(categoricalFeatures, numClasses, oldAlgo,
+      oldImpurity, getSubsamplingRate, Array(1.0, 1.0))
   }
 }
 
@@ -455,7 +522,9 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasS
   private[ml] def getOldBoostingStrategy(
       categoricalFeatures: Map[Int, Int],
       oldAlgo: OldAlgo.Algo): OldBoostingStrategy = {
-    val strategy = super.getOldStrategy(categoricalFeatures, numClasses = 2, oldAlgo, OldVariance)
+    val strategy = super.getOldStrategy(categoricalFeatures, numClasses = 2,
+      oldAlgo, OldVariance, Array(1.0, 1.0))
+
     // NOTE: The old API does not support "seed" so we ignore it.
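+    // The uniform class weights here are placeholders: Variance impurity ignores them.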
     new OldBoostingStrategy(strategy, getOldLossType, getMaxIter, getStepSize)
   }
 
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
index b34e1b1b56c43..e96350db6bb1d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
@@ -23,7 +23,7 @@ import scala.collection.JavaConverters._
 import org.apache.spark.annotation.Since
 import org.apache.spark.mllib.tree.configuration.Algo._
 import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
-import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Impurity, Variance}
+import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Impurity, Variance, WeightedGini}
 
 /**
  * Stores all the configuration options for tree construction
@@ -32,6 +32,7 @@ import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Impurity, Variance}
  *             [[org.apache.spark.mllib.tree.configuration.Algo.Regression]]
  * @param impurity Criterion used for information gain calculation.
  *                 Supported for Classification: [[org.apache.spark.mllib.tree.impurity.Gini]],
+ *                  [[org.apache.spark.mllib.tree.impurity.WeightedGini]],
  *                  [[org.apache.spark.mllib.tree.impurity.Entropy]].
 *                 Supported for Regression: [[org.apache.spark.mllib.tree.impurity.Variance]].
 * @param maxDepth Maximum depth of the tree (e.g. depth 0 means 1 leaf node, depth 1 means
@@ -65,6 +66,8 @@ import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Impurity, Variance}
 *                           E.g. 10 means that the cache will get checkpointed every 10 updates. If
 *                           the checkpoint directory is not set in
 *                           [[org.apache.spark.SparkContext]], this setting is ignored.
+ * @param classWeights Weights of classes used in classification problems. It will be ignored in
+ *                     regression problems.
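+ *                     Must contain one non-negative weight per class when the impurity
+ *                     is WeightedGini.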
*/ @Since("1.0.0") class Strategy @Since("1.3.0") ( @@ -80,7 +83,9 @@ class Strategy @Since("1.3.0") ( @Since("1.0.0") @BeanProperty var maxMemoryInMB: Int = 256, @Since("1.2.0") @BeanProperty var subsamplingRate: Double = 1, @Since("1.2.0") @BeanProperty var useNodeIdCache: Boolean = false, - @Since("1.2.0") @BeanProperty var checkpointInterval: Int = 10) extends Serializable { + @Since("1.2.0") @BeanProperty var checkpointInterval: Int = 10, + @Since("2.0.0") @BeanProperty var classWeights: Array[Double] = Array(1.0, 1.0)) + extends Serializable { /** */ @@ -96,6 +101,29 @@ class Strategy @Since("1.3.0") ( isMulticlassClassification && (categoricalFeaturesInfo.size > 0) } + /** + * Make the Strategy class compatible with old API + */ + @Since("2.0.0") + def this( + algo: Algo, + impurity: Impurity, + maxDepth: Int, + numClasses: Int, + maxBins: Int, + quantileCalculationStrategy: QuantileStrategy, + categoricalFeaturesInfo: Map[Int, Int], + minInstancesPerNode: Int, + minInfoGain: Double, + maxMemoryInMB: Int, + subsamplingRate: Double, + useNodeIdCache: Boolean, + checkpointInterval: Int) { + this(algo, impurity, maxDepth, numClasses, maxBins, quantileCalculationStrategy, + categoricalFeaturesInfo, minInstancesPerNode, minInfoGain, maxMemoryInMB, + subsamplingRate, useNodeIdCache, checkpointInterval, Array()) + } + /** * Java-friendly constructor for [[org.apache.spark.mllib.tree.configuration.Strategy]] */ @@ -140,9 +168,9 @@ class Strategy @Since("1.3.0") ( require(numClasses >= 2, s"DecisionTree Strategy for Classification must have numClasses >= 2," + s" but numClasses = $numClasses.") - require(Set(Gini, Entropy).contains(impurity), + require(Set(Gini, Entropy, WeightedGini).contains(impurity), s"DecisionTree Strategy given invalid impurity for Classification: $impurity." + - s" Valid settings: Gini, Entropy") + s" Valid settings: Gini, Entropy, WeightedGini") case Regression => require(impurity == Variance, s"DecisionTree Strategy given invalid impurity for Regression: $impurity." 
@@ -163,6 +191,14 @@ class Strategy @Since("1.3.0") (
     require(subsamplingRate > 0 && subsamplingRate <= 1,
       s"DecisionTree Strategy requires subsamplingRate <=1 and >0, but was given " +
      s"$subsamplingRate")
+    if (impurity == WeightedGini) {
+      require(numClasses == classWeights.length,
+        s"DecisionTree Strategy requires the number of class weights to be the same as the " +
+        s"number of classes, but there are $numClasses classes and ${classWeights.length} weights")
+      require(classWeights.forall((x: Double) => x >= 0),
+        s"DecisionTree Strategy requires all the class weights to be non-negative," +
+        s" but at least one of them is negative")
+    }
   }
 
   /**
@@ -172,7 +208,7 @@ class Strategy @Since("1.3.0") (
   def copy: Strategy = {
     new Strategy(algo, impurity, maxDepth, numClasses, maxBins,
       quantileCalculationStrategy, categoricalFeaturesInfo, minInstancesPerNode, minInfoGain,
-      maxMemoryInMB, subsamplingRate, useNodeIdCache, checkpointInterval)
+      maxMemoryInMB, subsamplingRate, useNodeIdCache, checkpointInterval, classWeights)
   }
 }
 
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
index ff7700d2d1b7f..de24ba8444512 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
@@ -138,6 +138,11 @@ private[spark] class EntropyCalculator(stats: Array[Double]) extends ImpurityCal
    */
   def count: Long = stats.sum.toLong
 
+  /**
+   * Weighted count of the data points, which in this case assumes uniform class weights.
+   */
+  def weightedCount: Double = stats.sum
+
   /**
    * Prediction which should be made based on the sufficient statistics.
    */
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
index 58dc79b7398e2..ded6488ddc79c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
@@ -134,6 +134,11 @@ private[spark] class GiniCalculator(stats: Array[Double]) extends ImpurityCalcul
    */
   def count: Long = stats.sum.toLong
 
+  /**
+   * Weighted count of the data points, which in this case assumes uniform class weights.
+   */
+  def weightedCount: Double = stats.sum
+
   /**
    * Prediction which should be made based on the sufficient statistics.
    */
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
index 65f0163ec6059..b91752f0ff2c4 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
@@ -99,6 +99,7 @@ private[spark] abstract class ImpurityAggregator(val statsSize: Int) extends Ser
  */
 private[spark] abstract class ImpurityCalculator(val stats: Array[Double]) extends Serializable {
 
+  // By default the stats are used unweighted; weighted impurities override this.
+  val weightedStats: Array[Double] = stats
 
   /**
    * Make a deep copy of this [[ImpurityCalculator]].
   */
@@ -147,6 +148,11 @@ private[spark] abstract class ImpurityCalculator(val stats: Array[Double]) exten
   */
  def count: Long
 
+  /**
+   * Weighted count of the data points accounted for in the sufficient statistics.
+   */
+  def weightedCount: Double
+
  /**
   * Prediction which should be made based on the sufficient statistics.
   */
@@ -185,11 +191,13 @@ private[spark] object ImpurityCalculator {
   * Create an [[ImpurityCalculator]] instance of the given impurity type and with
   * the given stats.
   */
-  def getCalculator(impurity: String, stats: Array[Double]): ImpurityCalculator = {
+  def getCalculator(impurity: String, stats: Array[Double],
+      classWeights: Array[Double]): ImpurityCalculator = {
     impurity match {
       case "gini" => new GiniCalculator(stats)
       case "entropy" => new EntropyCalculator(stats)
       case "variance" => new VarianceCalculator(stats)
+      case "weightedgini" => new WeightedGiniCalculator(stats, classWeights)
       case _ =>
         throw new IllegalArgumentException(
           s"ImpurityCalculator builder did not recognize impurity type: $impurity")
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
index 2423516123b82..1087139fb4bd2 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
@@ -122,6 +122,11 @@ private[spark] class VarianceCalculator(stats: Array[Double]) extends ImpurityCa
    */
   def count: Long = stats(0).toLong
 
+  /**
+   * Weighted count of the data points; class weights do not apply to regression,
+   * so this equals the unweighted count.
+   */
+  def weightedCount: Double = stats(0)
+
   /**
    * Prediction which should be made based on the sufficient statistics.
    */
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/WeightedGini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/WeightedGini.scala
new file mode 100644
index 0000000000000..90232d07a6916
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/WeightedGini.scala
@@ -0,0 +1,209 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.impurity
+
+import org.apache.spark.annotation.{DeveloperApi, Experimental, Since}
+
+/**
+ * :: Experimental ::
+ * Class for calculating the Gini impurity with class weights, using the
+ * altered prior method, during classification.
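+ * Class counts are multiplied by their class weights before the Gini index is
+ * computed, which is equivalent to re-weighting the class priors.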
+ */
+@Since("2.0.0")
+@Experimental
+object WeightedGini extends Impurity {
+
+  /**
+   * :: DeveloperApi ::
+   * information calculation for multiclass classification
+   * @param weightedCounts Array[Double] with the weighted count for each label
+   * @param weightedTotalCount sum of the weighted counts over all labels
+   * @return information value, or 0 if weightedTotalCount = 0
+   */
+  @Since("2.0.0")
+  @DeveloperApi
+  override def calculate(weightedCounts: Array[Double], weightedTotalCount: Double): Double = {
+    if (weightedTotalCount == 0) {
+      return 0
+    }
+    val numClasses = weightedCounts.length
+    var impurity = 1.0
+    var classIndex = 0
+    while (classIndex < numClasses) {
+      val freq = weightedCounts(classIndex) / weightedTotalCount
+      impurity -= freq * freq
+      classIndex += 1
+    }
+    impurity
+  }
+
+  /**
+   * :: DeveloperApi ::
+   * variance calculation (unsupported for this classification impurity)
+   * @param count number of instances
+   * @param sum sum of labels
+   * @param sumSquares summation of squares of the labels
+   * @return never returns; always throws UnsupportedOperationException
+   */
+  @Since("2.0.0")
+  @DeveloperApi
+  override def calculate(count: Double, sum: Double, sumSquares: Double): Double =
+    throw new UnsupportedOperationException("WeightedGini.calculate")
+
+  /**
+   * Get this impurity instance.
+   * This is useful for passing impurity parameters to a Strategy in Java.
+   */
+  @Since("2.0.0")
+  def instance: this.type = this
+
+}
+
+/**
+ * Class for updating views of a vector of sufficient statistics,
+ * in order to compute impurity from a sample.
+ * Note: Instances of this class do not hold the data; they operate on views of the data.
+ * @param numClasses Number of classes for label.
+ * @param classWeights Weights of classes
+ */
+private[spark] class WeightedGiniAggregator(numClasses: Int, classWeights: Array[Double])
+  extends ImpurityAggregator(numClasses) with Serializable {
+
+  /**
+   * Update stats for one (node, feature, bin) with the given label.
+   * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous.
+   * @param offset Start index of stats for this (node, feature, bin).
+   */
+  def update(allStats: Array[Double], offset: Int, label: Double, instanceWeight: Double): Unit = {
+    if (label >= statsSize) {
+      throw new IllegalArgumentException(s"WeightedGiniAggregator given label $label" +
+        s" but requires label < numClasses (= $statsSize).")
+    }
+    if (label < 0) {
+      throw new IllegalArgumentException(s"WeightedGiniAggregator given label $label" +
+        s" but requires label to be non-negative.")
+    }
+    allStats(offset + label.toInt) += instanceWeight
+  }
+
+  /**
+   * Get an [[ImpurityCalculator]] for a (node, feature, bin).
+   * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous.
+   * @param offset Start index of stats for this (node, feature, bin).
+   */
+  def getCalculator(allStats: Array[Double], offset: Int): WeightedGiniCalculator = {
+    new WeightedGiniCalculator(allStats.view(offset, offset + statsSize).toArray, classWeights)
+  }
+}
+
+/**
+ * Stores statistics for one (node, feature, bin) for calculating impurity.
+ * Unlike [[WeightedGiniAggregator]], this class stores its own data and is for a specific
+ * (node, feature, bin).
+ * @param stats Array of sufficient statistics for a (node, feature, bin).
+ * @param classWeights Weights of classes
+ */
+private[spark] class WeightedGiniCalculator(stats: Array[Double], classWeights: Array[Double])
+  extends ImpurityCalculator(stats) {
+
+  // Class counts scaled by their class weights; kept in sync by add() and subtract().
+  override val weightedStats: Array[Double] =
+    stats.zip(classWeights).map { case (stat, weight) => stat * weight }
+
+  /**
+   * Make a deep copy of this [[ImpurityCalculator]].
+   */
+  def copy: WeightedGiniCalculator =
+    new WeightedGiniCalculator(stats.clone(), classWeights.clone())
+
+  /**
+   * Calculate the impurity from the stored sufficient statistics.
+   */
+  def calculate(): Double = WeightedGini.calculate(weightedStats, weightedStats.sum)
+
+  /**
+   * Number of data points accounted for in the sufficient statistics.
+   */
+  def count: Long = stats.sum.toLong
+
+  /**
+   * Weighted count of the data points accounted for in the sufficient statistics.
+   */
+  def weightedCount: Double = weightedStats.sum
+
+  /**
+   * Prediction which should be made based on the sufficient statistics.
+   */
+  def predict: Double = if (count == 0) {
+    0
+  } else {
+    indexOfLargestArrayElement(weightedStats)
+  }
+
+  /**
+   * Probability of the label given by [[predict]].
+   */
+  override def prob(label: Double): Double = {
+    val lbl = label.toInt
+    require(lbl < stats.length,
+      s"WeightedGiniCalculator.prob given invalid label: $lbl (should be < ${stats.length})")
+    require(lbl >= 0, "WeightedGiniImpurity does not support negative labels")
+    val cnt = weightedCount
+    if (cnt == 0) {
+      0
+    } else {
+      weightedStats(lbl) / cnt
+    }
+  }
+
+  override def toString: String = s"WeightedGiniCalculator(stats = [${stats.mkString(", ")}])"
+
+  /**
+   * Add the stats from another calculator into this one, modifying and returning this calculator.
+   * The weightedStats array is updated at the same time.
+   */
+  override def add(other: ImpurityCalculator): ImpurityCalculator = {
+    require(stats.length == other.stats.length,
+      s"Two ImpurityCalculator instances cannot be added with different counts sizes." +
+      s" Sizes are ${stats.length} and ${other.stats.length}.")
+    val otherCalculator = other.asInstanceOf[WeightedGiniCalculator]
+    val len = otherCalculator.stats.length
+    var i = 0
+    while (i < len) {
+      stats(i) += otherCalculator.stats(i)
+      weightedStats(i) += otherCalculator.weightedStats(i)
+      i += 1
+    }
+    this
+  }
+
+  /**
+   * Subtract the stats from another calculator from this one, modifying and returning this
+   * calculator. The weightedStats array is updated at the same time.
+   */
+  override def subtract(other: ImpurityCalculator): ImpurityCalculator = {
+    require(stats.length == other.stats.length,
+      s"Two ImpurityCalculator instances cannot be subtracted with different counts sizes." +
+ + s" Sizes are ${stats.length} and ${other.stats.length}.") + val otherCalculator = other.asInstanceOf[WeightedGiniCalculator] + val len = otherCalculator.stats.length + var i = 0 + while (i < len) { + stats(i) -= otherCalculator.stats(i) + weightedStats(i) -= otherCalculator.weightedStats(i) + i += 1 + } + this + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index 089d30abb5ef9..096ab2467ab83 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -69,6 +69,18 @@ class DecisionTreeClassifierSuite // Tests calling train() ///////////////////////////////////////////////////////////////////////////// + test("Binary classification with explicitly setting uniform class weights") { + val dt = new DecisionTreeClassifier() + .setImpurity("WeightedGini") + .setMaxDepth(2) + .setMaxBins(100) + .setSeed(1) + .setClassWeights(Array(1, 1)) + val categoricalFeatures = Map(0 -> 3, 1 -> 3) + val numClasses = 2 + compareAPIs(categoricalDataPointsRDD, dt, categoricalFeatures, numClasses) + } + test("Binary classification stump with ordered categorical features") { val dt = new DecisionTreeClassifier() .setImpurity("gini") diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala index 2e99ee157ae95..5ea110ec0d020 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala @@ -234,7 +234,8 @@ private object RandomForestClassifierSuite extends SparkFunSuite { numClasses: Int): Unit = { val numFeatures = data.first().features.size val oldStrategy = - rf.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, rf.getOldImpurity) + rf.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, + rf.getOldImpurity, rf.getSubsamplingRate, rf.getClassWeights) val oldModel = OldRandomForest.trainClassifier( data.map(OldLabeledPoint.fromML), oldStrategy, rf.getNumTrees, rf.getFeatureSubsetStrategy, rf.getSeed.toInt) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala index c08335f9f84af..169dcdd3f567d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala @@ -140,7 +140,9 @@ private object RandomForestRegressorSuite extends SparkFunSuite { categoricalFeatures: Map[Int, Int]): Unit = { val numFeatures = data.first().features.size val oldStrategy = - rf.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, rf.getOldImpurity) + rf.getOldStrategy(categoricalFeatures, numClasses = 0, + OldAlgo.Regression, rf.getOldImpurity, rf.getSubsamplingRate, + classWeights = Array()) val oldModel = OldRandomForest.trainRegressor(data.map(OldLabeledPoint.fromML), oldStrategy, rf.getNumTrees, rf.getFeatureSubsetStrategy, rf.getSeed.toInt) val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0) diff --git 
index dcc2f305df75a..dce4e698b82cd 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
@@ -93,7 +93,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
     val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
       Map(), Set(),
       Array(6), Gini, QuantileStrategy.Sort,
-      0, 0, 0.0, 0, 0
+      0, 0, 0.0, 0, 0, Array[Double]()
     )
     val featureSamples = Array.fill(200000)(math.random)
     val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
@@ -110,7 +110,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
     val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
       Map(), Set(),
       Array(5), Gini, QuantileStrategy.Sort,
-      0, 0, 0.0, 0, 0
+      0, 0, 0.0, 0, 0, Array[Double]()
     )
     val featureSamples = Array(1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3).map(_.toDouble)
     val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
@@ -124,7 +124,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
     val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
       Map(), Set(),
       Array(3), Gini, QuantileStrategy.Sort,
-      0, 0, 0.0, 0, 0
+      0, 0, 0.0, 0, 0, Array[Double]()
    )
     val featureSamples = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5).map(_.toDouble)
     val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
@@ -138,7 +138,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
     val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
       Map(), Set(),
       Array(3), Gini, QuantileStrategy.Sort,
-      0, 0, 0.0, 0, 0
+      0, 0, 0.0, 0, 0, Array[Double]()
    )
     val featureSamples = Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2).map(_.toDouble)
     val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
diff --git a/scalastyle-config.xml b/scalastyle-config.xml
index 270104f85b838..57c275baed215 100644
--- a/scalastyle-config.xml
+++ b/scalastyle-config.xml
@@ -94,7 +94,7 @@ This file is divided into 3 sections:
-
+
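Usage note (editor's sketch, not part of the patch): a minimal example of the new API, assuming a DataFrame named `training` with the usual "label" and "features" columns and a binary label. The name `training` is a placeholder; per the patch, the weights array must contain one non-negative entry per class, and it is only consulted when the impurity is "weightedgini".

    import org.apache.spark.ml.classification.DecisionTreeClassifier

    // Hypothetical imbalanced binary problem: up-weight class 1 tenfold.
    val dt = new DecisionTreeClassifier()
      .setImpurity("weightedgini")        // case-insensitive; enables class-weighted splits
      .setClassWeights(Array(1.0, 10.0))  // one non-negative weight per class
      .setMaxDepth(5)
      .setSeed(42L)

    val model = dt.fit(training)          // training: DataFrame with label/features columns

With uniform weights (the default, Array(1.0, 1.0)), "weightedgini" should select the same splits as plain "gini", which is what the new DecisionTreeClassifierSuite test above exercises.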