@@ -172,7 +172,11 @@ class DecisionTree private(val strategy: Strategy) extends Serializable with Log
172172object DecisionTree extends Serializable with Logging {
173173
174174 /**
175- * Method to train a decision tree model over an RDD
175+ * Method to train a decision tree model where the instances are represented as an RDD of
176+ * (label, features) pairs. The method supports binary classification and regression. For
177+ * binary classification, the label for each instance should be either 0 or 1 to denote the two
178+ * classes. The parameters for the algorithm are specified using the strategy parameter.
179+ *
176180 * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint ]] used as training data
177181 * for DecisionTree
178182 * @param strategy The configuration parameters for the tree algorithm which specify the type
@@ -185,7 +189,11 @@ object DecisionTree extends Serializable with Logging {
185189 }
186190
187191 /**
188- * Method to train a decision tree model over an RDD
192+ * Method to train a decision tree model where the instances are represented as an RDD of
193+ * (label, features) pairs. The method supports binary classification and regression. For
194+ * binary classification, the label for each instance should be either 0 or 1 to denote the two
195+ * classes.
196+ *
189197 * @param input input RDD of [[org.apache.spark.mllib.regression.LabeledPoint ]] used as
190198 * training data
191199 * @param algo algorithm, classification or regression
@@ -204,8 +212,13 @@ object DecisionTree extends Serializable with Logging {
204212
205213
206214 /**
207- * Method to train a decision tree model over an RDD
208- * @param input input RDD of [[org.apache.spark.mllib.regression.LabeledPoint ]] used as
215+ * Method to train a decision tree model where the instances are represented as an RDD of
216+ * (label, features) pairs. The decision tree method supports binary classification and
217+ * regression. For binary classification, the label for each instance should be either 0 or
218+ * 1 to denote the two classes. The method also supports categorical feature inputs where the
219+ * number of categories can be specified using the categoricalFeaturesInfo option.
220+ *
221+ * @param input input RDD of [[org.apache.spark.mllib.regression.LabeledPoint ]] used as
209222 * training data for DecisionTree
210223 * @param algo classification or regression
211224 * @param impurity criterion used for information gain calculation
@@ -236,6 +249,7 @@ object DecisionTree extends Serializable with Logging {
236249
237250 /**
238251 * Returns an array of optimal splits for all nodes at a given level
252+ *
239253 * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint ]] used as training data
240254 * for DecisionTree
241255 * @param parentImpurities Impurities for all parent nodes for the current level
@@ -247,7 +261,7 @@ object DecisionTree extends Serializable with Logging {
247261 * @param bins possible bins for all features
248262 * @return array of splits with best splits for all nodes at a given level.
249263 */
250- private def findBestSplits (
264+ protected [tree] def findBestSplits (
251265 input : RDD [LabeledPoint ],
252266 parentImpurities : Array [Double ],
253267 strategy : Strategy ,
@@ -885,7 +899,7 @@ object DecisionTree extends Serializable with Logging {
885899 * .model.Split] of size (numFeatures, numSplits-1) and bins is an Array of [org.apache
886900 * .spark.mllib.tree.model.Bin] of size (numFeatures, numSplits1)
887901 */
888- private def findSplitsBins (
902+ protected [tree] def findSplitsBins (
889903 input : RDD [LabeledPoint ],
890904 strategy : Strategy ): (Array [Array [Split ]], Array [Array [Bin ]]) = {
891905 val count = input.count()
0 commit comments