 package org.apache.spark.mllib.regression
 
 import org.apache.spark.mllib.linalg.Vector
-import org.apache.spark.mllib.regression.MonotonicityConstraint.MonotonicityConstraint._
 import org.apache.spark.rdd.RDD
 
-/**
- * Monotonicity constrains for monotone regression
- * Isotonic (increasing)
- * Antitonic (decreasing)
- */
-object MonotonicityConstraint {
-
-  object MonotonicityConstraint {
-
-    sealed trait MonotonicityConstraint {
-      private[regression] def holds(
-        current: WeightedLabeledPoint,
-        next: WeightedLabeledPoint): Boolean
-    }
-
-    /**
-     * Isotonic monotonicity constraint. Increasing sequence
-     */
-    case object Isotonic extends MonotonicityConstraint {
-      override def holds(current: WeightedLabeledPoint, next: WeightedLabeledPoint): Boolean = {
-        current.label <= next.label
-      }
-    }
-
-    /**
-     * Antitonic monotonicity constrain. Decreasing sequence
-     */
-    case object Antitonic extends MonotonicityConstraint {
-      override def holds(current: WeightedLabeledPoint, next: WeightedLabeledPoint): Boolean = {
-        current.label >= next.label
-      }
-    }
-  }
-
-  val Isotonic = MonotonicityConstraint.Isotonic
-  val Antitonic = MonotonicityConstraint.Antitonic
-}
-
 /**
  * Regression model for Isotonic regression
  *
  * @param predictions Weights computed for every feature.
- * @param monotonicityConstraint specifies if the sequence is increasing or decreasing
+ * @param isotonic isotonic (increasing) or antitonic (decreasing) sequence
  */
 class IsotonicRegressionModel(
     val predictions: Seq[(Double, Double, Double)],
-    val monotonicityConstraint: MonotonicityConstraint)
+    val isotonic: Boolean)
   extends RegressionModel {
 
   override def predict(testData: RDD[Vector]): RDD[Double] =
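A side note on the new flag (not part of the patch): the Boolean stands in for the removed Isotonic/Antitonic case objects. A minimal sketch of the ordering check it encodes, mirroring the comparator introduced inside poolAdjacentViolators further down; `ordered` is a hypothetical helper used only for illustration:

    // isotonic = true  => non-decreasing labels (old Isotonic)
    // isotonic = false => non-increasing labels (old Antitonic)
    def ordered(isotonic: Boolean)(x: Double, y: Double): Boolean =
      if (isotonic) x <= y else x >= y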
@@ -91,23 +52,23 @@ trait IsotonicRegressionAlgorithm
    *
    * @param predictions labels estimated using isotonic regression algorithm.
    *                    Used for predictions on new data points.
-   * @param monotonicityConstraint isotonic or antitonic
+   * @param isotonic isotonic (increasing) or antitonic (decreasing) sequence
    * @return isotonic regression model
    */
   protected def createModel(
       predictions: Seq[(Double, Double, Double)],
-      monotonicityConstraint: MonotonicityConstraint): IsotonicRegressionModel
+      isotonic: Boolean): IsotonicRegressionModel
 
   /**
    * Run algorithm to obtain isotonic regression model
    *
    * @param input data
-   * @param monotonicityConstraint ascending or descenting
+   * @param isotonic isotonic (increasing) or antitonic (decreasing) sequence
    * @return isotonic regression model
    */
   def run(
       input: RDD[(Double, Double, Double)],
-      monotonicityConstraint: MonotonicityConstraint): IsotonicRegressionModel
+      isotonic: Boolean): IsotonicRegressionModel
 }
 
 /**
@@ -118,16 +79,16 @@ class PoolAdjacentViolators private[mllib]
 
   override def run(
       input: RDD[(Double, Double, Double)],
-      monotonicityConstraint: MonotonicityConstraint): IsotonicRegressionModel = {
+      isotonic: Boolean): IsotonicRegressionModel = {
     createModel(
-      parallelPoolAdjacentViolators(input, monotonicityConstraint),
-      monotonicityConstraint)
+      parallelPoolAdjacentViolators(input, isotonic),
+      isotonic)
   }
 
   override protected def createModel(
       predictions: Seq[(Double, Double, Double)],
-      monotonicityConstraint: MonotonicityConstraint): IsotonicRegressionModel = {
-    new IsotonicRegressionModel(predictions, monotonicityConstraint)
+      isotonic: Boolean): IsotonicRegressionModel = {
+    new IsotonicRegressionModel(predictions, isotonic)
   }
 
   /**
@@ -138,32 +99,38 @@ class PoolAdjacentViolators private[mllib]
    * Method in situ mutates input array
    *
    * @param in input data
-   * @param monotonicityConstraint asc or desc
+   * @param isotonic asc or desc
    * @return result
    */
   private def poolAdjacentViolators(
-      in: Array[WeightedLabeledPoint],
-      monotonicityConstraint: MonotonicityConstraint): Array[WeightedLabeledPoint] = {
+      in: Array[(Double, Double, Double)],
+      isotonic: Boolean): Array[(Double, Double, Double)] = {
 
     // Pools sub array within given bounds assigning weighted average value to all elements
-    def pool(in: Array[WeightedLabeledPoint], start: Int, end: Int): Unit = {
+    def pool(in: Array[(Double, Double, Double)], start: Int, end: Int): Unit = {
       val poolSubArray = in.slice(start, end + 1)
 
-      val weightedSum = poolSubArray.map(lp => lp.label * lp.weight).sum
-      val weight = poolSubArray.map(_.weight).sum
+      val weightedSum = poolSubArray.map(lp => lp._1 * lp._3).sum
+      val weight = poolSubArray.map(_._3).sum
 
       for (i <- start to end) {
-        in(i) = WeightedLabeledPoint(weightedSum / weight, in(i).features, in(i).weight)
+        in(i) = (weightedSum / weight, in(i)._2, in(i)._3)
       }
     }
 
     var i = 0
 
+    val monotonicityConstrainter: (Double, Double) => Boolean = (x, y) => if (isotonic) {
+      x <= y
+    } else {
+      x >= y
+    }
+
     while (i < in.length) {
       var j = i
 
       // Find monotonicity violating sequence, if any
-      while (j < in.length - 1 && !monotonicityConstraint.holds(in(j), in(j + 1))) {
+      while (j < in.length - 1 && !monotonicityConstrainter(in(j)._1, in(j + 1)._1)) {
         j = j + 1
       }
 
@@ -173,7 +140,7 @@ class PoolAdjacentViolators private[mllib]
       } else {
         // Otherwise pool the violating sequence
        // And check if pooling caused monotonicity violation in previously processed points
-        while (i >= 0 && !monotonicityConstraint.holds(in(i), in(i + 1))) {
+        while (i >= 0 && !monotonicityConstrainter(in(i)._1, in(i + 1)._1)) {
           pool(in, i, j)
           i = i - 1
         }
@@ -190,19 +157,19 @@ class PoolAdjacentViolators private[mllib]
    * Calls Pool adjacent violators on each partition and then again on the result
    *
    * @param testData input
-   * @param monotonicityConstraint asc or desc
+   * @param isotonic isotonic (increasing) or antitonic (decreasing) sequence
    * @return result
    */
   private def parallelPoolAdjacentViolators(
       testData: RDD[(Double, Double, Double)],
-      monotonicityConstraint: MonotonicityConstraint): Seq[(Double, Double, Double)] = {
+      isotonic: Boolean): Seq[(Double, Double, Double)] = {
 
     poolAdjacentViolators(
       testData
         .sortBy(_._2)
         .cache()
-        .mapPartitions(it => poolAdjacentViolators(it.toArray, monotonicityConstraint).toIterator)
-        .collect(), monotonicityConstraint)
+        .mapPartitions(it => poolAdjacentViolators(it.toArray, isotonic).toIterator)
+        .collect(), isotonic)
   }
 }
 
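As an aside, a self-contained sketch of the pooling step used above, operating on hypothetical (label, feature, weight) triples; the `poolSketch` helper and the sample values are illustrative only and not part of the patch:

    // Weighted-average pooling of the sub-array in(start..end), mirroring pool() above.
    def poolSketch(in: Array[(Double, Double, Double)], start: Int, end: Int): Unit = {
      val sub = in.slice(start, end + 1)
      val weightedSum = sub.map(p => p._1 * p._3).sum  // sum of label * weight
      val weight = sub.map(_._3).sum                   // total weight
      for (i <- start to end) {
        in(i) = (weightedSum / weight, in(i)._2, in(i)._3)  // keep feature and weight
      }
    }

    // Labels 3.0 and 1.0 at indices 1..2 violate the isotonic (non-decreasing) constraint;
    // pooling assigns both the weighted average (3.0 * 1 + 1.0 * 1) / 2 = 2.0.
    val data = Array((1.0, 0.0, 1.0), (3.0, 1.0, 1.0), (1.0, 2.0, 1.0), (4.0, 3.0, 1.0))
    poolSketch(data, 1, 2)
    // data is now: (1.0,0.0,1.0), (2.0,1.0,1.0), (2.0,2.0,1.0), (4.0,3.0,1.0)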
@@ -221,11 +188,11 @@ object IsotonicRegression {
    *              Each point describes a row of the data
    *              matrix A as well as the corresponding right hand side label y
    *              and weight as number of measurements
-   * @param monotonicityConstraint Isotonic (increasing) or Antitonic (decreasing) sequence
+   * @param isotonic isotonic (increasing) or antitonic (decreasing) sequence
    */
   def train(
       input: RDD[(Double, Double, Double)],
-      monotonicityConstraint: MonotonicityConstraint = Isotonic): IsotonicRegressionModel = {
-    new PoolAdjacentViolators().run(input, monotonicityConstraint)
+      isotonic: Boolean = true): IsotonicRegressionModel = {
+    new PoolAdjacentViolators().run(input, isotonic)
   }
 }
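Finally, a usage sketch of the reworked API; it assumes an existing SparkContext named `sc`, and the (label, feature, weight) triples below are made up for illustration:

    import org.apache.spark.mllib.regression.IsotonicRegression

    val input = sc.parallelize(Seq(
      (1.0, 1.0, 1.0),
      (3.0, 2.0, 1.0),
      (2.0, 3.0, 1.0)))

    // isotonic = true (the default) requests a non-decreasing fit;
    // pass false for an antitonic (non-increasing) fit.
    val model = IsotonicRegression.train(input, isotonic = true)
    // model.predictions holds the fitted (label, feature, weight) triples.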