diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
index 881dcefb79be3..59aaa1cd457a7 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -82,6 +82,9 @@ class DecisionTreeClassifier @Since("1.4.0") (
   @Since("1.6.0")
   override def setSeed(value: Long): this.type = super.setSeed(value)
 
+  @Since("2.0.0")
+  override def setClassWeights(value: Array[Double]): this.type = super.setClassWeights(value)
+
   override protected def train(dataset: Dataset[_]): DecisionTreeClassificationModel = {
     val categoricalFeatures: Map[Int, Int] =
       MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
@@ -119,7 +122,7 @@ class DecisionTreeClassifier @Since("1.4.0") (
       categoricalFeatures: Map[Int, Int],
       numClasses: Int): OldStrategy = {
     super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity,
-      subsamplingRate = 1.0)
+      subsamplingRate = 1.0, getClassWeights)
   }
 
   @Since("1.4.1")
@@ -129,7 +132,7 @@ class DecisionTreeClassifier @Since("1.4.0") (
 @Since("1.4.0")
 @Experimental
 object DecisionTreeClassifier extends DefaultParamsReadable[DecisionTreeClassifier] {
-  /** Accessor for supported impurities: entropy, gini */
+  /** Accessor for supported impurities: entropy, gini, weightedgini */
   @Since("1.4.0")
   final val supportedImpurities: Array[String] = TreeClassifierParams.supportedImpurities
 
@@ -168,7 +171,7 @@ class DecisionTreeClassificationModel private[ml] (
   }
 
   override protected def predictRaw(features: Vector): Vector = {
-    Vectors.dense(rootNode.predictImpl(features).impurityStats.stats.clone())
+    Vectors.dense(rootNode.predictImpl(features).impurityStats.weightedStats.clone())
   }
 
   override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = {
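Note on the `predictRaw` change above: raw scores now come from the class-weighted leaf
statistics rather than the raw counts, so non-uniform weights can change the predicted class.
A minimal sketch with made-up numbers (the real path goes through `Node.predictImpl` and an
`ImpurityCalculator`):

```scala
// Illustrative only: weighting leaf counts can flip the argmax prediction.
val leafCounts = Array(90.0, 10.0)   // raw per-class counts at a leaf
val classWeights = Array(1.0, 10.0)  // minority class up-weighted
val weightedStats = leafCounts.zip(classWeights).map { case (c, w) => c * w }
println(leafCounts.indexOf(leafCounts.max))        // 0: unweighted prediction
println(weightedStats.indexOf(weightedStats.max))  // 1: weighted prediction
```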
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index b3c074f839250..5e61b759c7c6e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -98,13 +98,17 @@ class RandomForestClassifier @Since("1.4.0") (
   override def setFeatureSubsetStrategy(value: String): this.type =
     super.setFeatureSubsetStrategy(value)
 
+  @Since("2.0.0")
+  override def setClassWeights(value: Array[Double]): this.type = super.setClassWeights(value)
+
   override protected def train(dataset: Dataset[_]): RandomForestClassificationModel = {
     val categoricalFeatures: Map[Int, Int] =
       MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
     val numClasses: Int = getNumClasses(dataset)
     val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, numClasses)
     val strategy =
-      super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity)
+      super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification,
+        getOldImpurity, getSubsamplingRate, getClassWeights)
 
     val instr = Instrumentation.create(this, oldDataset)
     instr.logParams(params: _*)
@@ -195,7 +199,8 @@ class RandomForestClassificationModel private[ml] (
     // Ignore the tree weights since all are 1.0 for now.
     val votes = Array.fill[Double](numClasses)(0.0)
     _trees.view.foreach { tree =>
-      val classCounts: Array[Double] = tree.rootNode.predictImpl(features).impurityStats.stats
+      val classCounts: Array[Double] =
+        tree.rootNode.predictImpl(features).impurityStats.weightedStats
       val total = classCounts.sum
       if (total != 0) {
         var i = 0
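The voting loop above normalizes each tree's weighted class counts before summing, so every
tree contributes one probability-like vote regardless of leaf size. A self-contained sketch of
that normalization (values are illustrative):

```scala
val treeLeafStats = Seq(Array(8.0, 2.0), Array(1.0, 3.0))  // weighted counts, one array per tree
val votes = Array(0.0, 0.0)
treeLeafStats.foreach { classCounts =>
  val total = classCounts.sum
  if (total != 0) {
    classCounts.indices.foreach(i => votes(i) += classCounts(i) / total)
  }
}
// votes == Array(1.05, 0.95): each tree contributed a normalized vector.
```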
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
index c4df9d11127f4..b2fe5ded61793 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
@@ -117,7 +117,7 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S
   /** (private[ml]) Create a Strategy instance to use with the old API. */
   private[ml] def getOldStrategy(categoricalFeatures: Map[Int, Int]): OldStrategy = {
     super.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, getOldImpurity,
-      subsamplingRate = 1.0)
+      subsamplingRate = 1.0, classWeights = Array())
   }
 
   @Since("1.4.0")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
index a6dbf21d55e2b..9429a053804d1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
@@ -98,7 +98,8 @@ class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S
       MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
     val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
     val strategy =
-      super.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, getOldImpurity)
+      super.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression,
+        getOldImpurity, getSubsamplingRate, classWeights = Array())
 
     val instr = Instrumentation.create(this, oldDataset)
     instr.logParams(params: _*)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala
index 61091bb803e49..3d175006b9abe 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala
@@ -20,7 +20,6 @@ package org.apache.spark.ml.tree.impl
 import org.apache.spark.mllib.tree.impurity._
 
 
-
 /**
  * DecisionTree statistics aggregator for a node.
  * This holds a flat array of statistics for a set of (features, bins)
@@ -38,6 +37,7 @@ private[spark] class DTStatsAggregator(
     case Gini => new GiniAggregator(metadata.numClasses)
     case Entropy => new EntropyAggregator(metadata.numClasses)
     case Variance => new VarianceAggregator()
+    case WeightedGini => new WeightedGiniAggregator(metadata.numClasses, metadata.classWeights)
     case _ => throw new IllegalArgumentException(s"Bad impurity parameter: ${metadata.impurity}")
   }
 
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala
index 442f52bf0231d..a8ad966adf1cb 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala
@@ -53,7 +53,8 @@ private[spark] class DecisionTreeMetadata(
     val minInstancesPerNode: Int,
     val minInfoGain: Double,
     val numTrees: Int,
-    val numFeaturesPerNode: Int) extends Serializable {
+    val numFeaturesPerNode: Int,
+    val classWeights: Array[Double]) extends Serializable {
 
   def isUnordered(featureIndex: Int): Boolean = unorderedFeatures.contains(featureIndex)
 
@@ -207,7 +208,8 @@ private[spark] object DecisionTreeMetadata extends Logging {
     new DecisionTreeMetadata(numFeatures, numExamples, numClasses, numBins.max,
       strategy.categoricalFeaturesInfo, unorderedFeatures.toSet, numBins,
       strategy.impurity, strategy.quantileCalculationStrategy, strategy.maxDepth,
-      strategy.minInstancesPerNode, strategy.minInfoGain, numTrees, numFeaturesPerNode)
+      strategy.minInstancesPerNode, strategy.minInfoGain, numTrees, numFeaturesPerNode,
+      strategy.classWeights)
   }
 
   /**
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
index 71c8c42ce5eba..fe83d602764a4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
@@ -657,8 +657,15 @@ private[spark] object RandomForest extends Logging {
     val leftImpurity = leftImpurityCalculator.calculate() // Note: This equals 0 if count = 0
     val rightImpurity = rightImpurityCalculator.calculate()
 
-    val leftWeight = leftCount / totalCount.toDouble
-    val rightWeight = rightCount / totalCount.toDouble
+    // For Gini or Entropy impurity the class weights are assumed to be uniform,
+    // so the weighted count equals the plain count.
+    val leftWeightedCount = leftImpurityCalculator.weightedCount
+    val rightWeightedCount = rightImpurityCalculator.weightedCount
+
+    val totalWeightedCount = leftWeightedCount + rightWeightedCount
+
+    val leftWeight = leftWeightedCount / totalWeightedCount
+    val rightWeight = rightWeightedCount / totalWeightedCount
 
     val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity
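The gain computation above weights each child's impurity by its share of the weighted count
instead of the raw count. A worked sketch of the same formula, applying plain Gini to counts
that have already been class-weighted (illustrative numbers):

```scala
def gini(counts: Array[Double]): Double = {
  val total = counts.sum
  if (total == 0) 0.0 else 1.0 - counts.map(c => (c / total) * (c / total)).sum
}
val left = Array(80.0, 20.0)   // class-weighted counts in the left child
val right = Array(20.0, 80.0)  // class-weighted counts in the right child
val parent = left.zip(right).map { case (l, r) => l + r }
val (leftWeight, rightWeight) = (left.sum / parent.sum, right.sum / parent.sum)
val gain = gini(parent) - leftWeight * gini(left) - rightWeight * gini(right)
// gini(parent) = 0.5, gini(left) = gini(right) = 0.32, gain = 0.18
```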
 
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
index 56c85c9b53e17..029ccfec2e2cb 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
@@ -342,9 +342,17 @@ private[ml] object DecisionTreeModelReadWrite {
       Param.jsonDecode[String](compact(render(impurityJson)))
     }
 
+    // Get the class weights needed to construct the ImpurityCalculator.
+    // The value is ignored unless the impurity is WeightedGini.
+    val classWeights: Array[Double] = {
+      val classWeightsJson: JValue = metadata.getParamValue("classWeights")
+      compact(render(classWeightsJson)).split("\\[|,|\\]")
+        .filter(_.nonEmpty).map(_.toDouble)
+    }
+
     val dataPath = new Path(path, "data").toString
     val data = sqlContext.read.parquet(dataPath).as[NodeData]
-    buildTreeFromNodes(data.collect(), impurityType)
+    buildTreeFromNodes(data.collect(), impurityType, classWeights)
   }
 
   /**
@@ -353,7 +361,8 @@ private[ml] object DecisionTreeModelReadWrite {
    * @param impurityType  Impurity type for this tree
    * @return Root node of reconstructed tree
    */
-  def buildTreeFromNodes(data: Array[NodeData], impurityType: String): Node = {
+  def buildTreeFromNodes(data: Array[NodeData], impurityType: String,
+      classWeights: Array[Double]): Node = {
     // Load all nodes, sorted by ID.
     val nodes = data.sortBy(_.id)
     // Sanity checks; could remove
@@ -365,7 +374,8 @@ private[ml] object DecisionTreeModelReadWrite {
     // traversal, this guarantees that child nodes will be built before parent nodes.
     val finalNodes = new Array[Node](nodes.length)
     nodes.reverseIterator.foreach { case n: NodeData =>
-      val impurityStats = ImpurityCalculator.getCalculator(impurityType, n.impurityStats)
+      val impurityStats = ImpurityCalculator.getCalculator(impurityType,
+        n.impurityStats, classWeights)
       val node = if (n.leftChild != -1) {
         val leftChild = finalNodes(n.leftChild)
         val rightChild = finalNodes(n.rightChild)
@@ -437,6 +447,15 @@ private[ml] object EnsembleModelReadWrite {
       Param.jsonDecode[String](compact(render(impurityJson)))
     }
 
+    // Get the class weights needed to construct the ImpurityCalculator.
+    // The value is ignored unless the impurity is WeightedGini.
+    val classWeights: Array[Double] = {
+      val classWeightsJson: JValue = metadata.getParamValue("classWeights")
+      val classWeightsArray = compact(render(classWeightsJson)).split("\\[|,|\\]")
+        .filter(_.nonEmpty).map(_.toDouble)
+      classWeightsArray
+    }
+
     val treesMetadataPath = new Path(path, "treesMetadata").toString
     val treesMetadataRDD: RDD[(Int, (Metadata, Double))] = sql.read.parquet(treesMetadataPath)
       .select("treeID", "metadata", "weights").as[(Int, String, Double)].rdd.map {
@@ -454,7 +473,8 @@ private[ml] object EnsembleModelReadWrite {
     val rootNodesRDD: RDD[(Int, Node)] =
       nodeData.rdd.map(d => (d.treeID, d.nodeData)).groupByKey().map {
         case (treeID: Int, nodeData: Iterable[NodeData]) =>
-          treeID -> DecisionTreeModelReadWrite.buildTreeFromNodes(nodeData.toArray, impurityType)
+          treeID -> DecisionTreeModelReadWrite.buildTreeFromNodes(nodeData.toArray,
+            impurityType, classWeights)
       }
     val rootNodes: Array[Node] = rootNodesRDD.sortByKey().values.collect()
     (metadata, treesMetadata.zip(rootNodes), treesWeights)
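The classWeights recovery above round-trips the persisted param through JSON4S and a regex
split. A sketch of what that split produces for a rendered array (assuming the param was saved
as a JSON array of doubles):

```scala
// compact(render(...)) on a Double-array param yields e.g. "[1.0,2.5]".
val rendered = "[1.0,2.5]"
val weights = rendered.split("\\[|,|\\]").filter(_.nonEmpty).map(_.toDouble)
// weights == Array(1.0, 2.5)
```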
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
index d7559f8950c3d..aba5ab1aec455 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
@@ -24,7 +24,7 @@ import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util.SchemaUtils
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, BoostingStrategy => OldBoostingStrategy, Strategy => OldStrategy}
-import org.apache.spark.mllib.tree.impurity.{Entropy => OldEntropy, Gini => OldGini, Impurity => OldImpurity, Variance => OldVariance}
+import org.apache.spark.mllib.tree.impurity.{Entropy => OldEntropy, Gini => OldGini, Impurity => OldImpurity, Variance => OldVariance, WeightedGini}
 import org.apache.spark.mllib.tree.loss.{AbsoluteError => OldAbsoluteError, LogLoss => OldLogLoss, Loss => OldLoss, SquaredError => OldSquaredError}
 import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
 
@@ -155,7 +155,31 @@ private[ml] trait DecisionTreeParams extends PredictorParams
    */
   def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value)
 
-  /** (private[ml]) Create a Strategy instance to use with the old API. */
+  /** (private[ml]) Create a Strategy to use with the old API, taking explicit class weights. */
+  private[ml] def getOldStrategy(
+      categoricalFeatures: Map[Int, Int],
+      numClasses: Int,
+      oldAlgo: OldAlgo.Algo,
+      oldImpurity: OldImpurity,
+      subsamplingRate: Double,
+      classWeights: Array[Double]): OldStrategy = {
+    val strategy = OldStrategy.defaultStrategy(oldAlgo)
+    strategy.impurity = oldImpurity
+    strategy.checkpointInterval = getCheckpointInterval
+    strategy.maxBins = getMaxBins
+    strategy.maxDepth = getMaxDepth
+    strategy.maxMemoryInMB = getMaxMemoryInMB
+    strategy.minInfoGain = getMinInfoGain
+    strategy.minInstancesPerNode = getMinInstancesPerNode
+    strategy.useNodeIdCache = getCacheNodeIds
+    strategy.numClasses = numClasses
+    strategy.categoricalFeaturesInfo = categoricalFeatures
+    strategy.subsamplingRate = subsamplingRate
+    strategy.classWeights = classWeights
+    strategy
+  }
+
+  /** (private[ml]) Create a Strategy to use with the old API, assuming uniform class weights. */
   private[ml] def getOldStrategy(
       categoricalFeatures: Map[Int, Int],
       numClasses: Int,
@@ -174,6 +198,7 @@ private[ml] trait DecisionTreeParams extends PredictorParams
     strategy.numClasses = numClasses
     strategy.categoricalFeaturesInfo = categoricalFeatures
     strategy.subsamplingRate = subsamplingRate
+    strategy.classWeights = Array(1.0, 1.0)  // uniform default; used only by WeightedGini
     strategy
   }
 }
@@ -185,7 +210,7 @@ private[ml] trait TreeClassifierParams extends Params {
 
   /**
    * Criterion used for information gain calculation (case-insensitive).
-   * Supported: "entropy" and "gini".
+   * Supported: "entropy", "gini" and "weightedgini".
    * (default = gini)
    * @group param
    */
@@ -194,7 +219,15 @@ private[ml] trait TreeClassifierParams extends Params {
     s" ${TreeClassifierParams.supportedImpurities.mkString(", ")}",
     (value: String) => TreeClassifierParams.supportedImpurities.contains(value.toLowerCase))
 
-  setDefault(impurity -> "gini")
+  /**
+   * An array that stores the weights of class labels. All elements must be non-negative.
+   * (default = Array(1.0, 1.0))
+   * @group param
+   */
+  final val classWeights: DoubleArrayParam = new DoubleArrayParam(this, "classWeights", "An array" +
+    " that stores the weights of class labels. All elements must be non-negative.")
+
+  setDefault(impurity -> "gini", classWeights -> Array(1.0, 1.0))
 
   /** @group setParam */
   def setImpurity(value: String): this.type = set(impurity, value)
@@ -202,11 +235,18 @@ private[ml] trait TreeClassifierParams extends Params {
   /** @group getParam */
   final def getImpurity: String = $(impurity).toLowerCase
 
+  /** @group setParam */
+  def setClassWeights(value: Array[Double]): this.type = set(classWeights, value)
+
+  /** @group getParam */
+  final def getClassWeights: Array[Double] = $(classWeights)
+
   /** Convert new impurity to old impurity. */
   private[ml] def getOldImpurity: OldImpurity = {
     getImpurity match {
       case "entropy" => OldEntropy
       case "gini" => OldGini
+      case "weightedgini" => WeightedGini
       case _ =>
         // Should never happen because of check in setter method.
         throw new RuntimeException(
@@ -217,7 +257,8 @@ private[ml] trait TreeClassifierParams extends Params {
 
 private[ml] object TreeClassifierParams {
   // These options should be lowercase.
-  final val supportedImpurities: Array[String] = Array("entropy", "gini").map(_.toLowerCase)
+  final val supportedImpurities: Array[String] = Array("entropy", "gini", "weightedgini")
+    .map(_.toLowerCase)
 }
 
 private[ml] trait DecisionTreeClassifierParams
@@ -239,7 +280,16 @@ private[ml] trait TreeRegressorParams extends Params {
     s" ${TreeRegressorParams.supportedImpurities.mkString(", ")}",
     (value: String) => TreeRegressorParams.supportedImpurities.contains(value.toLowerCase))
 
-  setDefault(impurity -> "variance")
+  /**
+   * An array that stores the weights of class labels. This parameter is ignored by
+   * regression trees.
+   * (default = Array())
+   * @group param
+   */
+  final val classWeights: DoubleArrayParam = new DoubleArrayParam(this, "classWeights", "An array" +
+    " that stores the weights of class labels. Ignored by regression trees.")
+
+  setDefault(impurity -> "variance", classWeights -> Array())
 
   /** @group setParam */
   def setImpurity(value: String): this.type = set(impurity, value)
@@ -247,6 +297,12 @@ private[ml] trait TreeRegressorParams extends Params {
   /** @group getParam */
   final def getImpurity: String = $(impurity).toLowerCase
 
+  /** @group setParam */
+  def setClassWeights(value: Array[Double]): this.type = set(classWeights, value)
+
+  /** @group getParam */
+  final def getClassWeights: Array[Double] = $(classWeights)
+
   /** Convert new impurity to old impurity. */
   private[ml] def getOldImpurity: OldImpurity = {
     getImpurity match {
@@ -312,8 +368,19 @@ private[ml] trait TreeEnsembleParams extends DecisionTreeParams {
       categoricalFeatures: Map[Int, Int],
       numClasses: Int,
       oldAlgo: OldAlgo.Algo,
-      oldImpurity: OldImpurity): OldStrategy = {
-    super.getOldStrategy(categoricalFeatures, numClasses, oldAlgo, oldImpurity, getSubsamplingRate)
+      oldImpurity: OldImpurity,
+      classWeights: Array[Double]): OldStrategy = {
+    super.getOldStrategy(categoricalFeatures, numClasses, oldAlgo,
+      oldImpurity, getSubsamplingRate, classWeights)
+  }
+
+  private[ml] def getOldStrategy(
+      categoricalFeatures: Map[Int, Int],
+      numClasses: Int,
+      oldAlgo: OldAlgo.Algo,
+      oldImpurity: OldImpurity): OldStrategy = {
+    super.getOldStrategy(categoricalFeatures, numClasses, oldAlgo,
+      oldImpurity, getSubsamplingRate, Array(1.0, 1.0))
   }
 }
 
@@ -455,7 +522,9 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasS
   private[ml] def getOldBoostingStrategy(
       categoricalFeatures: Map[Int, Int],
       oldAlgo: OldAlgo.Algo): OldBoostingStrategy = {
-    val strategy = super.getOldStrategy(categoricalFeatures, numClasses = 2, oldAlgo, OldVariance)
+    val strategy = super.getOldStrategy(categoricalFeatures, numClasses = 2,
+      oldAlgo, OldVariance, Array(1.0, 1.0))
+
     // NOTE: The old API does not support "seed" so we ignore it.
     new OldBoostingStrategy(strategy, getOldLossType, getMaxIter, getStepSize)
   }
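Putting the new params together, a hedged usage sketch assuming this patch is applied:
"weightedgini" expects one non-negative weight per class, and the defaults stay uniform so
existing pipelines behave as before:

```scala
import org.apache.spark.ml.classification.DecisionTreeClassifier

val dt = new DecisionTreeClassifier()
  .setImpurity("weightedgini")
  .setClassWeights(Array(1.0, 5.0))  // up-weight the positive class
  .setMaxDepth(5)
```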
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
index b34e1b1b56c43..e96350db6bb1d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
@@ -23,7 +23,7 @@ import scala.collection.JavaConverters._
 import org.apache.spark.annotation.Since
 import org.apache.spark.mllib.tree.configuration.Algo._
 import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
-import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Impurity, Variance}
+import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Impurity, Variance, WeightedGini}
 
 /**
  * Stores all the configuration options for tree construction
@@ -32,6 +32,7 @@ import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Impurity, Variance}
  *              [[org.apache.spark.mllib.tree.configuration.Algo.Regression]]
  * @param impurity Criterion used for information gain calculation.
  *                 Supported for Classification: [[org.apache.spark.mllib.tree.impurity.Gini]],
+ *                  [[org.apache.spark.mllib.tree.impurity.WeightedGini]],
  *                  [[org.apache.spark.mllib.tree.impurity.Entropy]].
  *                 Supported for Regression: [[org.apache.spark.mllib.tree.impurity.Variance]].
  * @param maxDepth Maximum depth of the tree (e.g. depth 0 means 1 leaf node, depth 1 means
@@ -65,6 +66,8 @@ import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Impurity, Variance}
  *                           E.g. 10 means that the cache will get checkpointed every 10 updates. If
  *                           the checkpoint directory is not set in
  *                           [[org.apache.spark.SparkContext]], this setting is ignored.
+ * @param classWeights Weights of the classes used in classification problems; ignored for
+ *                     regression problems.
  */
 @Since("1.0.0")
 class Strategy @Since("1.3.0") (
@@ -80,7 +83,9 @@ class Strategy @Since("1.3.0") (
     @Since("1.0.0") @BeanProperty var maxMemoryInMB: Int = 256,
     @Since("1.2.0") @BeanProperty var subsamplingRate: Double = 1,
     @Since("1.2.0") @BeanProperty var useNodeIdCache: Boolean = false,
-    @Since("1.2.0") @BeanProperty var checkpointInterval: Int = 10) extends Serializable {
+    @Since("1.2.0") @BeanProperty var checkpointInterval: Int = 10,
+    @Since("2.0.0") @BeanProperty var classWeights: Array[Double] = Array(1.0, 1.0))
+  extends Serializable {
 
   /**
    */
@@ -96,6 +101,29 @@ class Strategy @Since("1.3.0") (
     isMulticlassClassification && (categoricalFeaturesInfo.size > 0)
   }
 
+  /**
+   * Backwards-compatible constructor that omits classWeights (defaults to an empty array).
+   */
+  @Since("2.0.0")
+  def this(
+      algo: Algo,
+      impurity: Impurity,
+      maxDepth: Int,
+      numClasses: Int,
+      maxBins: Int,
+      quantileCalculationStrategy: QuantileStrategy,
+      categoricalFeaturesInfo: Map[Int, Int],
+      minInstancesPerNode: Int,
+      minInfoGain: Double,
+      maxMemoryInMB: Int,
+      subsamplingRate: Double,
+      useNodeIdCache: Boolean,
+      checkpointInterval: Int) {
+    this(algo, impurity, maxDepth, numClasses, maxBins, quantileCalculationStrategy,
+      categoricalFeaturesInfo, minInstancesPerNode, minInfoGain, maxMemoryInMB,
+      subsamplingRate, useNodeIdCache, checkpointInterval, Array())
+  }
+
   /**
    * Java-friendly constructor for [[org.apache.spark.mllib.tree.configuration.Strategy]]
    */
@@ -140,9 +168,9 @@ class Strategy @Since("1.3.0") (
         require(numClasses >= 2,
           s"DecisionTree Strategy for Classification must have numClasses >= 2," +
           s" but numClasses = $numClasses.")
-        require(Set(Gini, Entropy).contains(impurity),
+        require(Set(Gini, Entropy, WeightedGini).contains(impurity),
           s"DecisionTree Strategy given invalid impurity for Classification: $impurity." +
-          s"  Valid settings: Gini, Entropy")
+          s"  Valid settings: Gini, Entropy, WeightedGini")
       case Regression =>
         require(impurity == Variance,
           s"DecisionTree Strategy given invalid impurity for Regression: $impurity." +
@@ -163,6 +191,14 @@ class Strategy @Since("1.3.0") (
     require(subsamplingRate > 0 && subsamplingRate <= 1,
       s"DecisionTree Strategy requires subsamplingRate <=1 and >0, but was given " +
       s"$subsamplingRate")
+    if (impurity == WeightedGini) {
+      require(numClasses == classWeights.length,
+        s"DecisionTree Strategy requires the number of class weights to equal the number of" +
+        s" classes, but there are $numClasses classes and ${classWeights.length} weights")
+      require(classWeights.forall(_ >= 0),
+        s"DecisionTree Strategy requires all class weights to be non-negative," +
+        s" but at least one given weight is negative")
+    }
   }
 
   /**
@@ -172,7 +208,7 @@ class Strategy @Since("1.3.0") (
   def copy: Strategy = {
     new Strategy(algo, impurity, maxDepth, numClasses, maxBins,
       quantileCalculationStrategy, categoricalFeaturesInfo, minInstancesPerNode, minInfoGain,
-      maxMemoryInMB, subsamplingRate, useNodeIdCache, checkpointInterval)
+      maxMemoryInMB, subsamplingRate, useNodeIdCache, checkpointInterval, classWeights)
   }
 }
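A sketch of what the new assertValid branch enforces, assuming this patch: with WeightedGini,
the weight array must have exactly numClasses entries and none may be negative:

```scala
import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
import org.apache.spark.mllib.tree.impurity.WeightedGini

val strategy = Strategy.defaultStrategy(Algo.Classification)
strategy.impurity = WeightedGini
strategy.numClasses = 2
strategy.classWeights = Array(1.0, 3.0)  // valid: two non-negative weights
// Training calls assertValid(), which would reject Array(1.0) (wrong length)
// and Array(1.0, -1.0) (negative weight).
```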
 
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
index ff7700d2d1b7f..de24ba8444512 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
@@ -138,6 +138,11 @@ private[spark] class EntropyCalculator(stats: Array[Double]) extends ImpurityCal
    */
   def count: Long = stats.sum.toLong
 
+  /**
+   * Weighted count of data points; equals the plain count since Entropy assumes uniform weights.
+   */
+  def weightedCount: Double = stats.sum
+
   /**
    * Prediction which should be made based on the sufficient statistics.
    */
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
index 58dc79b7398e2..ded6488ddc79c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
@@ -134,6 +134,11 @@ private[spark] class GiniCalculator(stats: Array[Double]) extends ImpurityCalcul
    */
   def count: Long = stats.sum.toLong
 
+  /**
+   * Weighted count of data points; equals the plain count since Gini assumes uniform weights.
+   */
+  def weightedCount: Double = stats.sum
+
   /**
    * Prediction which should be made based on the sufficient statistics.
    */
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
index 65f0163ec6059..b91752f0ff2c4 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
@@ -99,6 +99,7 @@ private[spark] abstract class ImpurityAggregator(val statsSize: Int) extends Ser
  */
 private[spark] abstract class ImpurityCalculator(val stats: Array[Double]) extends Serializable {
 
+  val weightedStats: Array[Double] = stats  // class-weighted view of stats; identity by default
   /**
    * Make a deep copy of this [[ImpurityCalculator]].
    */
@@ -147,6 +148,11 @@ private[spark] abstract class ImpurityCalculator(val stats: Array[Double]) exten
    */
   def count: Long
 
+  /**
+   * Weighted count of the data points accounted for in the sufficient statistics.
+   */
+  def weightedCount: Double
+
   /**
    * Prediction which should be made based on the sufficient statistics.
    */
@@ -185,11 +191,13 @@ private[spark] object ImpurityCalculator {
    * Create an [[ImpurityCalculator]] instance of the given impurity type and with
    * the given stats.
    */
-  def getCalculator(impurity: String, stats: Array[Double]): ImpurityCalculator = {
+  def getCalculator(impurity: String, stats: Array[Double],
+      classWeights: Array[Double]): ImpurityCalculator = {
     impurity match {
       case "gini" => new GiniCalculator(stats)
       case "entropy" => new EntropyCalculator(stats)
       case "variance" => new VarianceCalculator(stats)
+      case "weightedgini" => new WeightedGiniCalculator(stats, classWeights)
       case _ =>
         throw new IllegalArgumentException(
           s"ImpurityCalculator builder did not recognize impurity type: $impurity")
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
index 2423516123b82..1087139fb4bd2 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
@@ -122,6 +122,11 @@ private[spark] class VarianceCalculator(stats: Array[Double]) extends ImpurityCa
    */
   def count: Long = stats(0).toLong
 
+  /**
+   * Weighted count of data points; class weights do not apply to regression.
+   */
+  def weightedCount: Double = stats(0)
+
   /**
    * Prediction which should be made based on the sufficient statistics.
    */
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/WeightedGini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/WeightedGini.scala
new file mode 100644
index 0000000000000..90232d07a6916
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/WeightedGini.scala
@@ -0,0 +1,209 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.impurity
+
+import org.apache.spark.annotation.{DeveloperApi, Experimental, Since}
+
+/**
+ * :: Experimental ::
+ * Class for calculating the Gini impurity with class weights using
+ * altered prior method during classification.
+ */
+@Since("2.0.0")
+@Experimental
+object WeightedGini extends Impurity {
+
+  /**
+   * :: DeveloperApi ::
+   * Information calculation for multiclass classification.
+   * @param weightedCounts Array[Double] with class-weighted counts for each label
+   * @param weightedTotalCount sum of the weighted counts over all labels
+   * @return information value, or 0 if weightedTotalCount = 0
+   */
+  @Since("2.0.0")
+  @DeveloperApi
+  override def calculate(weightedCounts: Array[Double], weightedTotalCount: Double): Double = {
+    if (weightedTotalCount == 0) {
+      return 0
+    }
+    val numClasses = weightedCounts.length
+    var impurity = 1.0
+    var classIndex = 0
+    while (classIndex < numClasses) {
+      val freq = weightedCounts(classIndex) / weightedTotalCount
+      impurity -= freq * freq
+      classIndex += 1
+    }
+    impurity
+  }
+
+  /**
+   * :: DeveloperApi ::
+   * Variance calculation; not supported by WeightedGini.
+   * @param count number of instances
+   * @param sum sum of labels
+   * @param sumSquares summation of squares of the labels
+   * @return never returns; WeightedGini is classification-only and always throws
+   */
+  @Since("2.0.0")
+  @DeveloperApi
+  override def calculate(count: Double, sum: Double, sumSquares: Double): Double =
+    throw new UnsupportedOperationException("WeightedGini.calculate")
+
+  /**
+   * Get this impurity instance.
+   * This is useful for passing impurity parameters to a Strategy in Java.
+   */
+  @Since("2.0.0")
+  def instance: this.type = this
+
+}
+
+/**
+ * Class for updating views of a vector of sufficient statistics,
+ * in order to compute impurity from a sample.
+ * Note: Instances of this class do not hold the data; they operate on views of the data.
+ * @param numClasses  Number of classes for label.
+ * @param classWeights Weights of classes
+ */
+private[spark] class WeightedGiniAggregator(numClasses: Int, classWeights: Array[Double])
+  extends ImpurityAggregator(numClasses) with Serializable {
+
+  /**
+   * Update stats for one (node, feature, bin) with the given label.
+   * @param allStats  Flat stats array, with stats for this (node, feature, bin) contiguous.
+   * @param offset    Start index of stats for this (node, feature, bin).
+   */
+  def update(allStats: Array[Double], offset: Int, label: Double, instanceWeight: Double): Unit = {
+    if (label >= statsSize) {
+      throw new IllegalArgumentException(s"WeightedGiniAggregator given label $label" +
+        s" but requires label < numClasses (= $statsSize).")
+    }
+    if (label < 0) {
+      throw new IllegalArgumentException(s"WeightedGiniAggregator given label $label" +
+        s"but requires label is non-negative.")
+    }
+    allStats(offset + label.toInt) += instanceWeight
+  }
+
+  /**
+   * Get an [[ImpurityCalculator]] for a (node, feature, bin).
+   * @param allStats  Flat stats array, with stats for this (node, feature, bin) contiguous.
+   * @param offset    Start index of stats for this (node, feature, bin).
+   */
+  def getCalculator(allStats: Array[Double], offset: Int): WeightedGiniCalculator = {
+    new WeightedGiniCalculator(allStats.view(offset, offset + statsSize).toArray, classWeights)
+  }
+}
+
+/**
+ * Stores statistics for one (node, feature, bin) for calculating impurity.
+ * Unlike [[WeightedGiniAggregator]], this class stores its own data and is for a specific
+ * (node, feature, bin).
+ * @param stats  Array of sufficient statistics for a (node, feature, bin).
+ * @param classWeights Weights of classes
+ */
+private[spark] class WeightedGiniCalculator(stats: Array[Double], classWeights: Array[Double])
+  extends ImpurityCalculator(stats) {
+
+  override val weightedStats = stats.zip(classWeights).map { case (s, w) => s * w }
+  /**
+   * Make a deep copy of this [[ImpurityCalculator]].
+   */
+  def copy: WeightedGiniCalculator = new WeightedGiniCalculator(stats.clone(), classWeights.clone())
+
+  /**
+   * Calculate the impurity from the stored sufficient statistics.
+   */
+  def calculate(): Double = WeightedGini.calculate(weightedStats, weightedStats.sum)
+
+  /**
+   * Number of data points accounted for in the sufficient statistics.
+   */
+  def count: Long = stats.sum.toLong
+
+  /**
+   * Weighted count of the data points, i.e. the sum of the class-weighted statistics.
+   */
+  def weightedCount: Double = weightedStats.sum
+
+  /**
+   * Prediction which should be made based on the sufficient statistics.
+   */
+  def predict: Double = if (count == 0) {
+    0
+  } else {
+    indexOfLargestArrayElement(weightedStats)
+  }
+
+  /**
+   * Probability of the label given by [[predict]].
+   */
+  override def prob(label: Double): Double = {
+    val lbl = label.toInt
+    require(lbl < stats.length,
+      s"WeightedGiniCalculator.prob given invalid label: $lbl (should be < ${stats.length}")
+    require(lbl >= 0, "WeightedGiniImpurity does not support negative labels")
+    val cnt = weightedCount
+    if (cnt == 0) {
+      0
+    } else {
+      weightedStats(lbl) / cnt
+    }
+  }
+
+  override def toString: String = s"WeightedGiniCalculator(stats = [${stats.mkString(", ")}])"
+
+  /**
+   * Add the stats from another calculator into this one, modifying and returning this calculator.
+   * Update the weightedStats at the same time
+   */
+  override def add(other: ImpurityCalculator): ImpurityCalculator = {
+    require(stats.length == other.stats.length,
+      s"Two ImpurityCalculator instances cannot be added with different counts sizes." +
+        s"  Sizes are ${stats.length} and ${other.stats.length}.")
+    val otherCalculator = other.asInstanceOf[WeightedGiniCalculator]
+    val len = otherCalculator.stats.length
+    var i = 0
+    while (i < len) {
+      stats(i) += otherCalculator.stats(i)
+      weightedStats(i) += otherCalculator.weightedStats(i)
+      i += 1
+    }
+    this
+  }
+
+  /**
+   * Subtract the stats from another calculator from this one, modifying and returning this
+   * calculator. Update the weightedStats at the same time
+   */
+  override def subtract(other: ImpurityCalculator): ImpurityCalculator = {
+    require(stats.length == other.stats.length,
+      s"Two ImpurityCalculator instances cannot be subtracted with different counts sizes." +
+        s"  Sizes are ${stats.length} and ${other.stats.length}.")
+    val otherCalculator = other.asInstanceOf[WeightedGiniCalculator]
+    val len = otherCalculator.stats.length
+    var i = 0
+    while (i < len) {
+      stats(i) -= otherCalculator.stats(i)
+      weightedStats(i) -= otherCalculator.weightedStats(i)
+      i += 1
+    }
+    this
+  }
+}
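A worked example of the altered-priors calculation implemented above: weight the class counts,
then apply plain Gini to the weighted counts (numbers are illustrative):

```scala
val counts = Array(90.0, 10.0)  // raw counts: 90/10 imbalance
val weights = Array(1.0, 9.0)   // up-weight the minority class 9x
val weighted = counts.zip(weights).map { case (c, w) => c * w }  // Array(90.0, 90.0)
val total = weighted.sum
val impurity = 1.0 - weighted.map(c => (c / total) * (c / total)).sum
// impurity == 0.5: the weighted node looks maximally impure, so splits that
// isolate the minority class now yield a much larger information gain.
```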
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
index 089d30abb5ef9..096ab2467ab83 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
@@ -69,6 +69,18 @@ class DecisionTreeClassifierSuite
   // Tests calling train()
   /////////////////////////////////////////////////////////////////////////////
 
+  test("Binary classification with explicitly setting uniform class weights") {
+    val dt = new DecisionTreeClassifier()
+      .setImpurity("WeightedGini")
+      .setMaxDepth(2)
+      .setMaxBins(100)
+      .setSeed(1)
+      .setClassWeights(Array(1, 1))
+    val categoricalFeatures = Map(0 -> 3, 1 -> 3)
+    val numClasses = 2
+    compareAPIs(categoricalDataPointsRDD, dt, categoricalFeatures, numClasses)
+  }
+
   test("Binary classification stump with ordered categorical features") {
     val dt = new DecisionTreeClassifier()
       .setImpurity("gini")
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
index 2e99ee157ae95..5ea110ec0d020 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
@@ -234,7 +234,8 @@ private object RandomForestClassifierSuite extends SparkFunSuite {
       numClasses: Int): Unit = {
     val numFeatures = data.first().features.size
     val oldStrategy =
-      rf.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, rf.getOldImpurity)
+      rf.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification,
+        rf.getOldImpurity, rf.getSubsamplingRate, rf.getClassWeights)
     val oldModel = OldRandomForest.trainClassifier(
       data.map(OldLabeledPoint.fromML), oldStrategy, rf.getNumTrees, rf.getFeatureSubsetStrategy,
       rf.getSeed.toInt)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
index c08335f9f84af..169dcdd3f567d 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
@@ -140,7 +140,9 @@ private object RandomForestRegressorSuite extends SparkFunSuite {
       categoricalFeatures: Map[Int, Int]): Unit = {
     val numFeatures = data.first().features.size
     val oldStrategy =
-      rf.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, rf.getOldImpurity)
+      rf.getOldStrategy(categoricalFeatures, numClasses = 0,
+        OldAlgo.Regression, rf.getOldImpurity, rf.getSubsamplingRate,
+        classWeights = Array())
     val oldModel = OldRandomForest.trainRegressor(data.map(OldLabeledPoint.fromML), oldStrategy,
       rf.getNumTrees, rf.getFeatureSubsetStrategy, rf.getSeed.toInt)
     val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
index dcc2f305df75a..dce4e698b82cd 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
@@ -93,7 +93,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
       val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
         Map(), Set(),
         Array(6), Gini, QuantileStrategy.Sort,
-        0, 0, 0.0, 0, 0
+        0, 0, 0.0, 0, 0, Array[Double]()
       )
       val featureSamples = Array.fill(200000)(math.random)
       val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
@@ -110,7 +110,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
       val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
         Map(), Set(),
         Array(5), Gini, QuantileStrategy.Sort,
-        0, 0, 0.0, 0, 0
+        0, 0, 0.0, 0, 0, Array[Double]()
       )
       val featureSamples = Array(1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3).map(_.toDouble)
       val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
@@ -124,7 +124,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
       val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
         Map(), Set(),
         Array(3), Gini, QuantileStrategy.Sort,
-        0, 0, 0.0, 0, 0
+        0, 0, 0.0, 0, 0, Array[Double]()
       )
       val featureSamples = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5).map(_.toDouble)
       val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
@@ -138,7 +138,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
       val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
         Map(), Set(),
         Array(3), Gini, QuantileStrategy.Sort,
-        0, 0, 0.0, 0, 0
+        0, 0, 0.0, 0, 0, Array[Double]()
       )
       val featureSamples = Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2).map(_.toDouble)
       val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)