@@ -81,7 +81,7 @@ The test error is calculated to measure the algorithm accuracy.
 
 <div data-lang="scala">
 {% highlight scala %}
-import org.apache.spark.mllib.tree.DecisionTree
+import org.apache.spark.mllib.tree.RandomForest
 import org.apache.spark.mllib.util.MLUtils
 
 // Load and parse the data file.
@@ -90,16 +90,18 @@ val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
 val splits = data.randomSplit(Array(0.7, 0.3))
 val (trainingData, testData) = (splits(0), splits(1))
 
-// Train a DecisionTree model.
+// Train a RandomForest model.
 // Empty categoricalFeaturesInfo indicates all features are continuous.
 val numClasses = 2
 val categoricalFeaturesInfo = Map[Int, Int]()
+val numTrees = 3 // Use more in practice.
+val featureSubsetStrategy = "auto" // Let the algorithm choose.
 val impurity = "gini"
-val maxDepth = 5
+val maxDepth = 4
 val maxBins = 32
 
-val model = DecisionTree.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo,
-  impurity, maxDepth, maxBins)
+val model = RandomForest.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo,
+  numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins)
 
 // Evaluate model on test instances and compute test error
 val labelAndPreds = testData.map { point =>
@@ -108,26 +110,26 @@ val labelAndPreds = testData.map { point =>
 }
 val testErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / testData.count()
 println("Test Error = " + testErr)
-println("Learned classification tree model:\n" + model.toDebugString)
+println("Learned classification forest model:\n" + model.toDebugString)
 {% endhighlight %}
 </div>
 
 <div data-lang="java">
 {% highlight java %}
-import java.util.HashMap;
 import scala.Tuple2;
+import java.util.HashMap;
+import org.apache.spark.SparkConf;
 import org.apache.spark.api.java.JavaPairRDD;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.api.java.function.Function;
 import org.apache.spark.api.java.function.PairFunction;
 import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.mllib.tree.DecisionTree;
-import org.apache.spark.mllib.tree.model.DecisionTreeModel;
+import org.apache.spark.mllib.tree.RandomForest;
+import org.apache.spark.mllib.tree.model.RandomForestModel;
 import org.apache.spark.mllib.util.MLUtils;
-import org.apache.spark.SparkConf;
 
-SparkConf sparkConf = new SparkConf().setAppName("JavaDecisionTree");
+SparkConf sparkConf = new SparkConf().setAppName("JavaRandomForestClassification");
 JavaSparkContext sc = new JavaSparkContext(sparkConf);
 
 // Load and parse the data file.
@@ -138,17 +140,20 @@ JavaRDD<LabeledPoint>[] splits = data.randomSplit(new double[]{0.7, 0.3});
 JavaRDD<LabeledPoint> trainingData = splits[0];
 JavaRDD<LabeledPoint> testData = splits[1];
 
-// Set parameters.
+// Train a RandomForest model.
 // Empty categoricalFeaturesInfo indicates all features are continuous.
 Integer numClasses = 2;
-Map<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();
+HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();
+Integer numTrees = 3; // Use more in practice.
+String featureSubsetStrategy = "auto"; // Let the algorithm choose.
 String impurity = "gini";
 Integer maxDepth = 5;
 Integer maxBins = 32;
+Integer seed = 12345;
 
-// Train a DecisionTree model for classification.
-final DecisionTreeModel model = DecisionTree.trainClassifier(trainingData, numClasses,
-  categoricalFeaturesInfo, impurity, maxDepth, maxBins);
+final RandomForestModel model = RandomForest.trainClassifier(trainingData, numClasses,
+  categoricalFeaturesInfo, numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins,
+  seed);
 
 // Evaluate model on test instances and compute test error
 JavaPairRDD<Double, Double> predictionAndLabel =
@@ -166,38 +171,36 @@ Double testErr =
     }
   }).count() / testData.count();
 System.out.println("Test Error: " + testErr);
-System.out.println("Learned classification tree model:\n" + model.toDebugString());
+System.out.println("Learned classification forest model:\n" + model.toDebugString());
 {% endhighlight %}
 </div>
 
 <div data-lang="python">
 {% highlight python %}
-from pyspark.mllib.regression import LabeledPoint
-from pyspark.mllib.tree import DecisionTree
+from pyspark.mllib.tree import RandomForest
 from pyspark.mllib.util import MLUtils
 
 # Load and parse the data file into an RDD of LabeledPoint.
 data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt')
 # Split the data into training and test sets (30% held out for testing)
 (trainingData, testData) = data.randomSplit([0.7, 0.3])
 
-# Train a DecisionTree model.
+# Train a RandomForest model.
 # Empty categoricalFeaturesInfo indicates all features are continuous.
-model = DecisionTree.trainClassifier(trainingData, numClasses=2, categoricalFeaturesInfo={},
-                                     impurity='gini', maxDepth=5, maxBins=32)
+# Note: Use larger numTrees in practice.
+# Setting featureSubsetStrategy="auto" lets the algorithm choose.
+model = RandomForest.trainClassifier(trainingData, numClasses=2, categoricalFeaturesInfo={},
+                                     numTrees=3, featureSubsetStrategy="auto",
+                                     impurity='gini', maxDepth=4, maxBins=32)
 
 # Evaluate model on test instances and compute test error
 predictions = model.predict(testData.map(lambda x: x.features))
 labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions)
 testErr = labelsAndPredictions.filter(lambda (v, p): v != p).count() / float(testData.count())
 print('Test Error = ' + str(testErr))
-print('Learned classification tree model:')
+print('Learned classification forest model:')
 print(model.toDebugString())
 {% endhighlight %}
-
-Note: When making predictions for a dataset, it is more efficient to do batch prediction rather
-than separately calling `predict` on each data point. This is because the Python code makes calls
-to an underlying `DecisionTree` model in Scala.
 </div>
 
 </div>
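For readers following the classification example: the returned `RandomForestModel` can also score individual feature vectors and report basic forest statistics. A minimal sketch, not part of the patch, assuming the `model` and imports from the Scala tab above (the zero vector is just a placeholder input):

{% highlight scala %}
import org.apache.spark.mllib.linalg.Vectors

// Score a single point; for whole datasets, prefer model.predict(rdd)
// so prediction stays distributed across the cluster.
val singlePoint = Vectors.dense(Array.fill(692)(0.0)) // sample_libsvm_data.txt has 692 features
val predictedLabel = model.predict(singlePoint)
println("Predicted label: " + predictedLabel)

// Basic statistics exposed by the trained ensemble.
println("Number of trees: " + model.numTrees)
println("Total number of nodes: " + model.totalNumNodes)
{% endhighlight %}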
@@ -215,7 +218,7 @@ The Mean Squared Error (MSE) is computed at the end to evaluate
 
 <div data-lang="scala">
 {% highlight scala %}
-import org.apache.spark.mllib.tree.DecisionTree
+import org.apache.spark.mllib.tree.RandomForest
 import org.apache.spark.mllib.util.MLUtils
 
 // Load and parse the data file.
@@ -224,15 +227,18 @@ val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
 val splits = data.randomSplit(Array(0.7, 0.3))
 val (trainingData, testData) = (splits(0), splits(1))
 
-// Train a DecisionTree model.
+// Train a RandomForest model.
 // Empty categoricalFeaturesInfo indicates all features are continuous.
+val numClasses = 2
 val categoricalFeaturesInfo = Map[Int, Int]()
+val numTrees = 3 // Use more in practice.
+val featureSubsetStrategy = "auto" // Let the algorithm choose.
 val impurity = "variance"
-val maxDepth = 5
+val maxDepth = 4
 val maxBins = 32
 
-val model = DecisionTree.trainRegressor(trainingData, categoricalFeaturesInfo, impurity,
-  maxDepth, maxBins)
+val model = RandomForest.trainRegressor(trainingData, categoricalFeaturesInfo,
+  numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins)
 
 // Evaluate model on test instances and compute test error
 val labelsAndPredictions = testData.map { point =>
@@ -241,7 +247,7 @@ val labelsAndPredictions = testData.map { point =>
 }
 val testMSE = labelsAndPredictions.map{ case(v, p) => math.pow((v - p), 2)}.mean()
 println("Test Mean Squared Error = " + testMSE)
-println("Learned regression tree model:\n" + model.toDebugString)
+println("Learned regression forest model:\n" + model.toDebugString)
 {% endhighlight %}
 </div>
 
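One detail the Scala regression tab leaves implicit: `RandomForest.trainRegressor` also accepts a random seed (the Java example below passes `seed = 12345` explicitly), and when it is omitted, as here, each run samples features and data differently. A hedged sketch, reusing the values defined in the Scala example above, for when reproducible forests matter:

{% highlight scala %}
// Identical to the call in the patch, but with the seed pinned so that
// repeated runs build the same forest; 12345 is an arbitrary choice.
val seededModel = RandomForest.trainRegressor(trainingData, categoricalFeaturesInfo,
  numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins, seed = 12345)
{% endhighlight %}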
@@ -256,12 +262,12 @@ import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.api.java.function.Function;
 import org.apache.spark.api.java.function.PairFunction;
 import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.mllib.tree.DecisionTree;
-import org.apache.spark.mllib.tree.model.DecisionTreeModel;
+import org.apache.spark.mllib.tree.RandomForest;
+import org.apache.spark.mllib.tree.model.RandomForestModel;
 import org.apache.spark.mllib.util.MLUtils;
 import org.apache.spark.SparkConf;
 
-SparkConf sparkConf = new SparkConf().setAppName("JavaDecisionTree");
+SparkConf sparkConf = new SparkConf().setAppName("JavaRandomForest");
 JavaSparkContext sc = new JavaSparkContext(sparkConf);
 
 // Load and parse the data file.
@@ -276,11 +282,14 @@ JavaRDD<LabeledPoint> testData = splits[1];
 // Empty categoricalFeaturesInfo indicates all features are continuous.
 Map<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();
+Integer numTrees = 3; // Use more in practice.
+String featureSubsetStrategy = "auto"; // Let the algorithm choose.
 String impurity = "variance";
-Integer maxDepth = 5;
+Integer maxDepth = 4;
 Integer maxBins = 32;
+Integer seed = 12345;
 
-// Train a DecisionTree model.
-final DecisionTreeModel model = DecisionTree.trainRegressor(trainingData,
-  categoricalFeaturesInfo, impurity, maxDepth, maxBins);
+// Train a RandomForest model.
+final RandomForestModel model = RandomForest.trainRegressor(trainingData,
+  categoricalFeaturesInfo, numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins, seed);
 
 // Evaluate model on test instances and compute test error
@@ -305,38 +314,36 @@ Double testMSE =
     }
   }) / data.count();
 System.out.println("Test Mean Squared Error: " + testMSE);
-System.out.println("Learned regression tree model:\n" + model.toDebugString());
+System.out.println("Learned regression forest model:\n" + model.toDebugString());
 {% endhighlight %}
 </div>
 
 <div data-lang="python">
 {% highlight python %}
-from pyspark.mllib.regression import LabeledPoint
-from pyspark.mllib.tree import DecisionTree
+from pyspark.mllib.tree import RandomForest
 from pyspark.mllib.util import MLUtils
 
 # Load and parse the data file into an RDD of LabeledPoint.
 data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt')
 # Split the data into training and test sets (30% held out for testing)
 (trainingData, testData) = data.randomSplit([0.7, 0.3])
 
-# Train a DecisionTree model.
+# Train a RandomForest model.
 # Empty categoricalFeaturesInfo indicates all features are continuous.
-model = DecisionTree.trainRegressor(trainingData, categoricalFeaturesInfo={},
-                                    impurity='variance', maxDepth=5, maxBins=32)
+# Note: Use larger numTrees in practice.
+# Setting featureSubsetStrategy="auto" lets the algorithm choose.
+model = RandomForest.trainRegressor(trainingData, categoricalFeaturesInfo={},
+                                    numTrees=3, featureSubsetStrategy="auto",
+                                    impurity='variance', maxDepth=4, maxBins=32)
 
 # Evaluate model on test instances and compute test error
 predictions = model.predict(testData.map(lambda x: x.features))
 labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions)
 testMSE = labelsAndPredictions.map(lambda (v, p): (v - p) * (v - p)).sum() / float(testData.count())
 print('Test Mean Squared Error = ' + str(testMSE))
-print('Learned regression tree model:')
+print('Learned regression forest model:')
 print(model.toDebugString())
 {% endhighlight %}
-
-Note: When making predictions for a dataset, it is more efficient to do batch prediction rather
-than separately calling `predict` on each data point. This is because the Python code makes calls
-to an underlying `DecisionTree` model in Scala.
 </div>
 
 </div>
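A closing note on `featureSubsetStrategy`, which every example in this patch sets to `"auto"`: MLlib's `RandomForest` also accepts `"all"`, `"sqrt"`, `"log2"`, and `"onethird"`, controlling how many features are considered at each split; with `"auto"`, a multi-tree classifier uses `"sqrt"` and a multi-tree regressor uses `"onethird"`. A small illustrative variation, assuming the variables from the Scala classification example above (the choice of `"sqrt"` here is just for demonstration):

{% highlight scala %}
// Consider sqrt(numFeatures) randomly chosen features at each node,
// the common default for classification forests.
val sqrtModel = RandomForest.trainClassifier(trainingData, numClasses,
  categoricalFeaturesInfo, numTrees, "sqrt", impurity, maxDepth, maxBins)
{% endhighlight %}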