1818package org .apache .spark .mllib .regression ;
1919
2020import java .io .Serializable ;
21- import java .util .ArrayList ;
22- import java .util .Arrays ;
2321import java .util .List ;
2422
25- import org .apache .spark .api .java .JavaDoubleRDD ;
2623import scala .Tuple3 ;
2724
25+ import com .google .common .collect .Lists ;
2826import org .junit .After ;
2927import org .junit .Assert ;
3028import org .junit .Before ;
3129import org .junit .Test ;
3230
31+ import org .apache .spark .api .java .JavaDoubleRDD ;
3332import org .apache .spark .api .java .JavaRDD ;
3433import org .apache .spark .api .java .JavaSparkContext ;
3534
3635public class JavaIsotonicRegressionSuite implements Serializable {
3736 private transient JavaSparkContext sc ;
3837
3938 private List <Tuple3 <Double , Double , Double >> generateIsotonicInput (double [] labels ) {
40- List <Tuple3 <Double , Double , Double >> input = new ArrayList <> ();
39+ List <Tuple3 <Double , Double , Double >> input = Lists . newArrayList ();
4140
42- for (int i = 1 ; i <= labels .length ; i ++) {
43- input .add (new Tuple3 (labels [i -1 ], (double )i , 1d ));
41+ for (int i = 1 ; i <= labels .length ; i ++) {
42+ input .add (new Tuple3 < Double , Double , Double > (labels [i -1 ], (double ) i , 1d ));
4443 }
4544
4645 return input ;
4746 }
4847
49- private double difference (List <Tuple3 <Double , Double , Double >> expected , IsotonicRegressionModel model ) {
50- double diff = 0 ;
51-
52- for (int i = 0 ; i < model .predictions ().length ; i ++) {
53- Tuple3 <Double , Double , Double > exp = expected .get (i );
54- diff += Math .abs (model .predict (exp ._2 ()) - exp ._1 ());
55- }
56-
57- return diff ;
58- }
59-
6048 private IsotonicRegressionModel runIsotonicRegression (double [] labels ) {
6149 JavaRDD <Tuple3 <Double , Double , Double >> trainRDD =
62- sc .parallelize (generateIsotonicInput (labels )).cache ();
50+ sc .parallelize (generateIsotonicInput (labels ), 2 ).cache ();
6351
6452 return new IsotonicRegression ().run (trainRDD );
6553 }
@@ -80,20 +68,16 @@ public void testIsotonicRegressionJavaRDD() {
8068 IsotonicRegressionModel model =
8169 runIsotonicRegression (new double []{1 , 2 , 3 , 3 , 1 , 6 , 7 , 8 , 11 , 9 , 10 , 12 });
8270
83- List <Tuple3 <Double , Double , Double >> expected =
84- generateIsotonicInput (new double [] {1 , 2 , 7d /3 , 7d /3 , 7d /3 , 6 , 7 , 8 , 10 , 10 , 10 , 12 });
85-
86- Assert .assertTrue (difference (expected , model ) == 0 );
71+ Assert .assertArrayEquals (
72+ new double [] {1 , 2 , 7d /3 , 7d /3 , 6 , 7 , 8 , 10 , 10 , 12 }, model .predictions (), 1e-14 );
8773 }
8874
8975 @ Test
9076 public void testIsotonicRegressionPredictionsJavaRDD () {
9177 IsotonicRegressionModel model =
9278 runIsotonicRegression (new double []{1 , 2 , 3 , 3 , 1 , 6 , 7 , 8 , 11 , 9 , 10 , 12 });
9379
94- JavaDoubleRDD testRDD =
95- sc .parallelizeDoubles (Arrays .asList (new Double [] {0.0 , 1.0 , 9.5 , 12.0 , 13.0 }));
96-
80+ JavaDoubleRDD testRDD = sc .parallelizeDoubles (Lists .newArrayList (0.0 , 1.0 , 9.5 , 12.0 , 13.0 ));
9781 List <Double > predictions = model .predict (testRDD ).collect ();
9882
9983 Assert .assertTrue (predictions .get (0 ) == 1d );
@@ -102,4 +86,4 @@ public void testIsotonicRegressionPredictionsJavaRDD() {
10286 Assert .assertTrue (predictions .get (3 ) == 12d );
10387 Assert .assertTrue (predictions .get (4 ) == 12d );
10488 }
105- }
89+ }
0 commit comments