@@ -27,11 +27,12 @@ import org.apache.spark.SparkContext._
2727import org .jblas ._
2828import org .apache .spark .rdd .RDD
2929import org .apache .spark .mllib .regression .LabeledPoint
30- import org .apache .spark .mllib .tree .impurity .{Entropy , Gini }
30+ import org .apache .spark .mllib .tree .impurity .{Entropy , Gini , Variance }
3131import org .apache .spark .mllib .tree .model .Filter
3232import org .apache .spark .mllib .tree .configuration .Strategy
3333import org .apache .spark .mllib .tree .configuration .Algo ._
3434import scala .collection .mutable
35+ import org .apache .spark .mllib .tree .configuration .FeatureType ._
3536
3637class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
3738
@@ -56,7 +57,6 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
5657 assert(bins.length== 2 )
5758 assert(splits(0 ).length== 99 )
5859 assert(bins(0 ).length== 100 )
59- // println(splits(1)(98))
6060 }
6161
6262 test(" split and bin calculation for categorical variables" ){
@@ -69,13 +69,71 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
6969 assert(bins.length== 2 )
7070 assert(splits(0 ).length== 99 )
7171 assert(bins(0 ).length== 100 )
72- println(splits(0 )(0 ))
73- println(splits(0 )(1 ))
74- println(bins(0 )(0 ))
75- println(splits(1 )(0 ))
76- println(splits(1 )(1 ))
77- println(bins(1 )(0 ))
78- // TODO: Add asserts
72+
73+ // Checking splits
74+
75+ assert(splits(0 )(0 ).feature == 0 )
76+ assert(splits(0 )(0 ).threshold == Double .MinValue )
77+ assert(splits(0 )(0 ).featureType == Categorical )
78+ assert(splits(0 )(0 ).categories.length == 1 )
79+ assert(splits(0 )(0 ).categories.contains(1.0 ))
80+
81+
82+ assert(splits(0 )(1 ).feature == 0 )
83+ assert(splits(0 )(1 ).threshold == Double .MinValue )
84+ assert(splits(0 )(1 ).featureType == Categorical )
85+ assert(splits(0 )(1 ).categories.length == 2 )
86+ assert(splits(0 )(1 ).categories.contains(1.0 ))
87+ assert(splits(0 )(1 ).categories.contains(0.0 ))
88+
89+ assert(splits(0 )(2 ) == null )
90+
91+ assert(splits(1 )(0 ).feature == 1 )
92+ assert(splits(1 )(0 ).threshold == Double .MinValue )
93+ assert(splits(1 )(0 ).featureType == Categorical )
94+ assert(splits(1 )(0 ).categories.length == 1 )
95+ assert(splits(1 )(0 ).categories.contains(0.0 ))
96+
97+
98+ assert(splits(1 )(1 ).feature == 1 )
99+ assert(splits(1 )(1 ).threshold == Double .MinValue )
100+ assert(splits(1 )(1 ).featureType == Categorical )
101+ assert(splits(1 )(1 ).categories.length == 2 )
102+ assert(splits(1 )(1 ).categories.contains(1.0 ))
103+ assert(splits(1 )(1 ).categories.contains(0.0 ))
104+
105+ assert(splits(1 )(2 ) == null )
106+
107+
108+ // Checking bins
109+
110+ assert(bins(0 )(0 ).category == 1.0 )
111+ assert(bins(0 )(0 ).lowSplit.categories.length == 0 )
112+ assert(bins(0 )(0 ).highSplit.categories.length == 1 )
113+ assert(bins(0 )(0 ).highSplit.categories.contains(1.0 ))
114+
115+ assert(bins(0 )(1 ).category == 0.0 )
116+ assert(bins(0 )(1 ).lowSplit.categories.length == 1 )
117+ assert(bins(0 )(1 ).lowSplit.categories.contains(1.0 ))
118+ assert(bins(0 )(1 ).highSplit.categories.length == 2 )
119+ assert(bins(0 )(1 ).highSplit.categories.contains(1.0 ))
120+ assert(bins(0 )(1 ).highSplit.categories.contains(0.0 ))
121+
122+ assert(bins(0 )(2 ).category == Double .MaxValue )
123+
124+ assert(bins(1 )(0 ).category == 0.0 )
125+ assert(bins(1 )(0 ).lowSplit.categories.length == 0 )
126+ assert(bins(1 )(0 ).highSplit.categories.length == 1 )
127+ assert(bins(1 )(0 ).highSplit.categories.contains(0.0 ))
128+
129+ assert(bins(1 )(1 ).category == 1.0 )
130+ assert(bins(1 )(1 ).lowSplit.categories.length == 1 )
131+ assert(bins(1 )(1 ).lowSplit.categories.contains(0.0 ))
132+ assert(bins(1 )(1 ).highSplit.categories.length == 2 )
133+ assert(bins(1 )(1 ).highSplit.categories.contains(0.0 ))
134+ assert(bins(1 )(1 ).highSplit.categories.contains(1.0 ))
135+
136+ assert(bins(1 )(2 ).category == Double .MaxValue )
79137
80138 }
81139
@@ -85,29 +143,106 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
85143 val rdd = sc.parallelize(arr)
86144 val strategy = new Strategy (Classification ,Gini ,3 ,100 ,categoricalFeaturesInfo = Map (0 -> 3 , 1 -> 3 ))
87145 val (splits, bins) = DecisionTree .findSplitsBins(rdd,strategy)
88- assert(splits.length== 2 )
89- assert(bins.length== 2 )
90- assert(splits(0 ).length== 99 )
91- assert(bins(0 ).length== 100 )
92- println(splits(0 )(0 ))
93- println(splits(0 )(1 ))
94- println(splits(0 )(2 ))
95- println(bins(0 )(0 ))
96- println(bins(0 )(1 ))
97- println(bins(0 )(2 ))
98- println(splits(1 )(0 ))
99- println(splits(1 )(1 ))
100- println(splits(1 )(2 ))
101- println(bins(1 )(0 ))
102- println(bins(1 )(1 ))
103- println(bins(0 )(2 ))
104- println(bins(0 )(3 ))
105- // TODO: Add asserts
106146
107- }
147+ // Checking splits
148+
149+ assert(splits(0 )(0 ).feature == 0 )
150+ assert(splits(0 )(0 ).threshold == Double .MinValue )
151+ assert(splits(0 )(0 ).featureType == Categorical )
152+ assert(splits(0 )(0 ).categories.length == 1 )
153+ assert(splits(0 )(0 ).categories.contains(1.0 ))
154+
155+ assert(splits(0 )(1 ).feature == 0 )
156+ assert(splits(0 )(1 ).threshold == Double .MinValue )
157+ assert(splits(0 )(1 ).featureType == Categorical )
158+ assert(splits(0 )(1 ).categories.length == 2 )
159+ assert(splits(0 )(1 ).categories.contains(1.0 ))
160+ assert(splits(0 )(1 ).categories.contains(0.0 ))
161+
162+ assert(splits(0 )(2 ).feature == 0 )
163+ assert(splits(0 )(2 ).threshold == Double .MinValue )
164+ assert(splits(0 )(2 ).featureType == Categorical )
165+ assert(splits(0 )(2 ).categories.length == 3 )
166+ assert(splits(0 )(2 ).categories.contains(1.0 ))
167+ assert(splits(0 )(2 ).categories.contains(0.0 ))
168+ assert(splits(0 )(2 ).categories.contains(2.0 ))
169+
170+ assert(splits(0 )(3 ) == null )
171+
172+ assert(splits(1 )(0 ).feature == 1 )
173+ assert(splits(1 )(0 ).threshold == Double .MinValue )
174+ assert(splits(1 )(0 ).featureType == Categorical )
175+ assert(splits(1 )(0 ).categories.length == 1 )
176+ assert(splits(1 )(0 ).categories.contains(0.0 ))
177+
178+ assert(splits(1 )(1 ).feature == 1 )
179+ assert(splits(1 )(1 ).threshold == Double .MinValue )
180+ assert(splits(1 )(1 ).featureType == Categorical )
181+ assert(splits(1 )(1 ).categories.length == 2 )
182+ assert(splits(1 )(1 ).categories.contains(1.0 ))
183+ assert(splits(1 )(1 ).categories.contains(0.0 ))
184+
185+ assert(splits(1 )(2 ).feature == 1 )
186+ assert(splits(1 )(2 ).threshold == Double .MinValue )
187+ assert(splits(1 )(2 ).featureType == Categorical )
188+ assert(splits(1 )(2 ).categories.length == 3 )
189+ assert(splits(1 )(2 ).categories.contains(1.0 ))
190+ assert(splits(1 )(2 ).categories.contains(0.0 ))
191+ assert(splits(1 )(2 ).categories.contains(2.0 ))
192+
193+ assert(splits(1 )(3 ) == null )
194+
195+
196+ // Checking bins
197+
198+ assert(bins(0 )(0 ).category == 1.0 )
199+ assert(bins(0 )(0 ).lowSplit.categories.length == 0 )
200+ assert(bins(0 )(0 ).highSplit.categories.length == 1 )
201+ assert(bins(0 )(0 ).highSplit.categories.contains(1.0 ))
202+
203+ assert(bins(0 )(1 ).category == 0.0 )
204+ assert(bins(0 )(1 ).lowSplit.categories.length == 1 )
205+ assert(bins(0 )(1 ).lowSplit.categories.contains(1.0 ))
206+ assert(bins(0 )(1 ).highSplit.categories.length == 2 )
207+ assert(bins(0 )(1 ).highSplit.categories.contains(1.0 ))
208+ assert(bins(0 )(1 ).highSplit.categories.contains(0.0 ))
209+
210+ assert(bins(0 )(2 ).category == 2.0 )
211+ assert(bins(0 )(2 ).lowSplit.categories.length == 2 )
212+ assert(bins(0 )(2 ).lowSplit.categories.contains(1.0 ))
213+ assert(bins(0 )(2 ).lowSplit.categories.contains(0.0 ))
214+ assert(bins(0 )(2 ).highSplit.categories.length == 3 )
215+ assert(bins(0 )(2 ).highSplit.categories.contains(1.0 ))
216+ assert(bins(0 )(2 ).highSplit.categories.contains(0.0 ))
217+ assert(bins(0 )(2 ).highSplit.categories.contains(2.0 ))
218+
219+ assert(bins(0 )(3 ).category == Double .MaxValue )
220+
221+ assert(bins(1 )(0 ).category == 0.0 )
222+ assert(bins(1 )(0 ).lowSplit.categories.length == 0 )
223+ assert(bins(1 )(0 ).highSplit.categories.length == 1 )
224+ assert(bins(1 )(0 ).highSplit.categories.contains(0.0 ))
225+
226+ assert(bins(1 )(1 ).category == 1.0 )
227+ assert(bins(1 )(1 ).lowSplit.categories.length == 1 )
228+ assert(bins(1 )(1 ).lowSplit.categories.contains(0.0 ))
229+ assert(bins(1 )(1 ).highSplit.categories.length == 2 )
230+ assert(bins(1 )(1 ).highSplit.categories.contains(0.0 ))
231+ assert(bins(1 )(1 ).highSplit.categories.contains(1.0 ))
232+
233+ assert(bins(1 )(2 ).category == 2.0 )
234+ assert(bins(1 )(2 ).lowSplit.categories.length == 2 )
235+ assert(bins(1 )(2 ).lowSplit.categories.contains(0.0 ))
236+ assert(bins(1 )(2 ).lowSplit.categories.contains(1.0 ))
237+ assert(bins(1 )(2 ).highSplit.categories.length == 3 )
238+ assert(bins(1 )(2 ).highSplit.categories.contains(0.0 ))
239+ assert(bins(1 )(2 ).highSplit.categories.contains(1.0 ))
240+ assert(bins(1 )(2 ).highSplit.categories.contains(2.0 ))
241+
242+ assert(bins(1 )(3 ).category == Double .MaxValue )
108243
109- // TODO: Test max feature value > num bins
110244
245+ }
111246
112247 test(" classification stump with all categorical variables" ){
113248 val arr = DecisionTreeSuite .generateCategoricalDataPoints()
@@ -117,22 +252,41 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
117252 val (splits, bins) = DecisionTree .findSplitsBins(rdd,strategy)
118253 strategy.numBins = 100
119254 val bestSplits = DecisionTree .findBestSplits(rdd,new Array (7 ),strategy,0 ,Array [List [Filter ]](),splits,bins)
120- println(bestSplits(0 )._1)
121- println(bestSplits(0 )._2)
122- // TODO: Add asserts
255+
256+ val split = bestSplits(0 )._1
257+ assert(split.categories.length == 1 )
258+ assert(split.categories.contains(1.0 ))
259+ assert(split.featureType == Categorical )
260+ assert(split.threshold == Double .MinValue )
261+
262+ val stats = bestSplits(0 )._2
263+ assert(stats.gain > 0 )
264+ assert(stats.predict > 0.4 )
265+ assert(stats.predict < 0.5 )
266+ assert(stats.impurity > 0.2 )
267+
123268 }
124269
// Test: computes splits/bins for the categorical fixture, asks DecisionTree.findBestSplits for
// the best splits, and checks the first returned (split, stats) pair.
// This is a diff view: lines marked `-` are the removed version, lines marked `+` the added one.
125270 test(" regression stump with all categorical variables" ){
126271 val arr = DecisionTreeSuite .generateCategoricalDataPoints()
// The fixture is expected to contain exactly 1000 labeled points.
127272 assert(arr.length == 1000 )
128273 val rdd = sc.parallelize(arr)
// The change below swaps Classification/Gini for Regression/Variance so the strategy matches
// the test's name. categoricalFeaturesInfo presumably maps feature index -> number of
// categories (features 0 and 1, 3 each) -- TODO confirm against Strategy.
129- val strategy = new Strategy (Classification , Gini ,3 ,100 ,categoricalFeaturesInfo = Map (0 -> 3 , 1 -> 3 ))
274+ val strategy = new Strategy (Regression , Variance ,3 ,100 ,categoricalFeaturesInfo = Map (0 -> 3 , 1 -> 3 ))
130275 val (splits, bins) = DecisionTree .findSplitsBins(rdd,strategy)
// NOTE(review): numBins is reassigned after findSplitsBins has run; presumably that call
// changes it for categorical features -- confirm why it has to be reset to 100 here.
131276 strategy.numBins = 100
132277 val bestSplits = DecisionTree .findBestSplits(rdd,new Array (7 ),strategy,0 ,Array [List [Filter ]](),splits,bins)
// The old version only printed the result; the added lines replace the TODO with assertions.
133- println(bestSplits(0 )._1)
134- println(bestSplits(0 )._2)
135- // TODO: Add asserts
278+
// First element of the first result: the chosen split. It must be a categorical split on the
// single category 1.0, with the threshold left at Double.MinValue.
279+ val split = bestSplits(0 )._1
280+ assert(split.categories.length == 1 )
281+ assert(split.categories.contains(1.0 ))
282+ assert(split.featureType == Categorical )
283+ assert(split.threshold == Double .MinValue )
284+
// Second element: the statistics reported for that split.
// NOTE(review): these bounds (gain > 0, predict in (0.4, 0.5), impurity > 0.2) are identical to
// the ones asserted in the classification stump test -- confirm they are also the intended
// expectations under Variance impurity rather than a copy-paste.
285+ val stats = bestSplits(0 )._2
286+ assert(stats.gain > 0 )
287+ assert(stats.predict > 0.4 )
288+ assert(stats.predict < 0.5 )
289+ assert(stats.impurity > 0.2 )
136290 }
137291
138292
@@ -157,7 +311,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
157311 assert(0 == bestSplits(0 )._2.gain)
158312 assert(0 == bestSplits(0 )._2.leftImpurity)
159313 assert(0 == bestSplits(0 )._2.rightImpurity)
160- println(bestSplits( 0 )._2.predict)
314+
161315 }
162316
163317 test(" stump with fixed label 1 for Gini" ){
0 commit comments