Skip to content

Commit 418ba1b

Browse files
committed
Added save, load to mllib.classification.LogisticRegressionModel, plus test suite
1 parent eccb9fb commit 418ba1b

File tree

3 files changed

+216
-26
lines changed

3 files changed

+216
-26
lines changed

mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala

Lines changed: 71 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,16 @@
1717

1818
package org.apache.spark.mllib.classification
1919

20+
import org.apache.spark.SparkContext
2021
import org.apache.spark.annotation.Experimental
2122
import org.apache.spark.mllib.linalg.BLAS.dot
2223
import org.apache.spark.mllib.linalg.{DenseVector, Vector}
2324
import org.apache.spark.mllib.optimization._
2425
import org.apache.spark.mllib.regression._
2526
import org.apache.spark.mllib.util.{DataValidators, MLUtils}
27+
import org.apache.spark.mllib.util.{Importable, DataValidators, Exportable}
2628
import org.apache.spark.rdd.RDD
29+
import org.apache.spark.sql.{Row, SQLContext, SchemaRDD}
2730

2831
/**
2932
* Classification model trained using Multinomial/Binary Logistic Regression.
@@ -42,7 +45,8 @@ class LogisticRegressionModel (
4245
override val intercept: Double,
4346
val numFeatures: Int,
4447
val numClasses: Int)
45-
extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable {
48+
extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable
49+
with Exportable {
4650

4751
def this(weights: Vector, intercept: Double) =
  this(weights, intercept, numFeatures = weights.size, numClasses = 2)
4852

@@ -60,6 +64,13 @@ class LogisticRegressionModel (
6064
this
6165
}
6266

67+
/**
68+
* :: Experimental ::
69+
* Returns the threshold (if any) used for converting raw prediction scores into 0/1 predictions.
70+
*/
71+
@Experimental
72+
def getThreshold: Option[Double] = this.threshold
73+
6374
/**
6475
* :: Experimental ::
6576
* Clears the threshold so that `predict` will output raw prediction scores.
@@ -126,6 +137,65 @@ class LogisticRegressionModel (
126137
bestClass.toDouble
127138
}
128139
}
140+
141+
// Writes this model under `path` as two pieces:
//  - `path/metadata`: JSON text holding the model class name and the format version
//  - `path/data`: a Parquet file holding weights, intercept and threshold
override def save(sc: SparkContext, path: String): Unit = {
  val sqlCtx = new SQLContext(sc)
  // Presumably supplies the implicit RDD -> SchemaRDD conversion relied on by the
  // type ascriptions below.
  import sqlCtx._

  // TODO: Do we need to use a SELECT statement to make the column ordering deterministic?

  // JSON metadata: a single record describing what was saved.
  val metadataRecord = LogisticRegressionModel.Metadata(
    clazz = this.getClass.getName,
    version = Exportable.version)
  val metadataSchemaRDD: SchemaRDD = sc.parallelize(Seq(metadataRecord))
  metadataSchemaRDD.toJSON.saveAsTextFile(s"$path/metadata")

  // Parquet data: a single record holding the model parameters.
  val modelRecord = LogisticRegressionModel.Data(weights, intercept, threshold)
  val modelSchemaRDD: SchemaRDD = sc.parallelize(Seq(modelRecord))
  modelSchemaRDD.saveAsParquetFile(s"$path/data")
}
155+
}
156+
157+
object LogisticRegressionModel extends Importable[LogisticRegressionModel] {

  /**
   * Loads a model from the layout written by `LogisticRegressionModel.save`:
   * JSON metadata under `path/metadata` and Parquet data under `path/data`.
   *
   * The metadata must name this model class and the current format version.
   * The threshold is restored from the data: a null threshold column clears the
   * threshold on the loaded model, otherwise the stored value is set.
   *
   * NOTE(review): only weights, intercept and threshold are stored in `Data`, and the
   * model is rebuilt with the two-argument constructor, which fixes
   * numFeatures = weights.size and numClasses = 2. A model constructed with other
   * values will not round-trip unchanged; confirm whether that is intended.
   *
   * @param sc  Spark context used for loading model files.
   * @param path  Directory to which the model was saved.
   * @return  The loaded model.
   * @throws IllegalArgumentException if the metadata or data row does not have the
   *                                  expected column types.
   */
  override def load(sc: SparkContext, path: String): LogisticRegressionModel = {
    val sqlContext = new SQLContext(sc)
    import sqlContext._

    val metadataPath = path + "/metadata"
    val dataPath = path + "/data"

    // Load JSON metadata.
    val metadataRDD = sqlContext.jsonFile(metadataPath)
    val metadataArray = metadataRDD.select("clazz".attr, "version".attr).take(1)
    assert(metadataArray.size == 1,
      s"Unable to load LogisticRegressionModel metadata from: $metadataPath")
    metadataArray(0) match {
      case Row(clazz: String, version: String) =>
        assert(clazz == classOf[LogisticRegressionModel].getName, s"LogisticRegressionModel.load" +
          s" was given model file with metadata specifying a different model class: $clazz")
        assert(version == Importable.version, // only 1 version exists currently
          s"LogisticRegressionModel.load did not recognize model format version: $version")
      case other =>
        // Without this case a null or non-string column surfaces as a bare MatchError.
        throw new IllegalArgumentException(
          s"LogisticRegressionModel.load found malformed metadata in $metadataPath: $other")
    }

    // Load Parquet data.
    val dataRDD = sqlContext.parquetFile(dataPath)
    val dataArray = dataRDD.select("weights".attr, "intercept".attr, "threshold".attr).take(1)
    assert(dataArray.size == 1,
      s"Unable to load LogisticRegressionModel data from: $dataPath")
    val data = dataArray(0)
    assert(data.size == 3, s"Unable to load LogisticRegressionModel data from: $dataPath")
    val lr = data match {
      case Row(weights: Vector, intercept: Double, _) =>
        new LogisticRegressionModel(weights, intercept)
      case other =>
        // Same reasoning as above: report what was read instead of a MatchError.
        throw new IllegalArgumentException(
          s"LogisticRegressionModel.load found malformed data in $dataPath: $other")
    }
    // Column 2 is the optional threshold; null means the model was saved without one.
    if (data.isNullAt(2)) {
      lr.clearThreshold()
    } else {
      lr.setThreshold(data.getDouble(2))
    }
    lr
  }

  /** Record written as JSON metadata: the saved model's class name and format version. */
  private case class Metadata(clazz: String, version: String)

  /** Record written as Parquet data: the model parameters. */
  private case class Data(weights: Vector, intercept: Double, threshold: Option[Double])

}
130200

131201
/**
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.mllib.util
19+
20+
import org.apache.spark.SparkContext
21+
import org.apache.spark.annotation.DeveloperApi
22+
23+
24+
/**
25+
* :: DeveloperApi ::
26+
*
27+
* Trait for models and transformers which may be saved as files.
28+
* This should be inherited by the class which implements model instances.
29+
*/
30+
@DeveloperApi
31+
trait Exportable {

  /**
   * Write this model to the directory `path`.
   *
   * Two pieces are written:
   *  - path/metadata/ : model metadata as human-readable JSON
   *  - path/data/ : model data in Parquet format
   *
   * A model written this way can be read back with [[Importable.load]].
   *
   * @param sc  Spark context used to save model data.
   * @param path  Directory in which to save this model. The directory and any
   *              intermediate directories will be created if needed.
   */
  def save(sc: SparkContext, path: String): Unit

}
49+
50+
object Exportable {

  /** Version string of the model import/export format currently in use. */
  val version: String = "1.0"

}
56+
57+
/**
58+
* :: DeveloperApi ::
59+
*
60+
* Trait for models and transformers which may be loaded from files.
61+
* This should be inherited by an object paired with the model class.
62+
*/
63+
@DeveloperApi
64+
trait Importable[Model <: Exportable] {

  /**
   * Read a model back from the directory `path`.
   *
   * The directory is expected to have been written by [[Exportable.save]].
   *
   * @param sc  Spark context used for loading model files.
   * @param path  Directory to which the model was saved.
   * @return  The loaded model instance.
   */
  def load(sc: SparkContext, path: String): Model

}
78+
79+
object Importable {

  /** Version string of the model import/export format; always equal to `Exportable.version`. */
  val version: String = Exportable.version

}

mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala

Lines changed: 61 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
package org.apache.spark.mllib.classification
1919

2020
import scala.util.control.Breaks._
21+
import org.apache.spark.util.Utils
22+
2123
import scala.util.Random
2224
import scala.collection.JavaConversions._
2325

@@ -407,16 +409,16 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with M
407409
*
408410
* First, use the following Scala code to save the data into `path`.
409411
*
410-
* testRDD.map(x => x.label+ ", " + x.features(0) + ", " + x.features(1) + ", " +
411-
* x.features(2) + ", " + x.features(3)).saveAsTextFile("path")
412+
* testRDD.map(x => x.label+ ", " + x.features(0) + ", " + x.features(1) + ", " +
413+
* x.features(2) + ", " + x.features(3)).saveAsTextFile("path")
412414
*
413415
* Then use the following R code to load the data and train the model using the glmnet package.
414416
*
415-
* library("glmnet")
416-
* data <- read.csv("path", header=FALSE)
417-
* label = factor(data$V1)
418-
* features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
419-
* weights = coef(glmnet(features,label, family="multinomial", alpha = 0, lambda = 0))
417+
* library("glmnet")
418+
* data <- read.csv("path", header=FALSE)
419+
* label = factor(data$V1)
420+
* features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
421+
* weights = coef(glmnet(features,label, family="multinomial", alpha = 0, lambda = 0))
420422
*
421423
* The model weights of multinomial logistic regression in R have `K` sets of linear predictors
422424
* for `K` classes classification problem; however, only `K-1` set is required if the first
@@ -425,25 +427,25 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with M
425427
* weights. The mathematical discussion and proof can be found here:
426428
* http://en.wikipedia.org/wiki/Multinomial_logistic_regression
427429
*
428-
* weights1 = weights$`1` - weights$`0`
429-
* weights2 = weights$`2` - weights$`0`
430+
* weights1 = weights$`1` - weights$`0`
431+
* weights2 = weights$`2` - weights$`0`
430432
*
431-
* > weights1
432-
* 5 x 1 sparse Matrix of class "dgCMatrix"
433-
* s0
434-
* 2.6228269
435-
* data.V2 -0.5837166
436-
* data.V3 0.9285260
437-
* data.V4 -0.3783612
438-
* data.V5 -0.8123411
439-
* > weights2
440-
* 5 x 1 sparse Matrix of class "dgCMatrix"
441-
* s0
442-
* 4.11197445
443-
* data.V2 -0.16918650
444-
* data.V3 -0.81104784
445-
* data.V4 -0.06463799
446-
* data.V5 -0.29198337
433+
* > weights1
434+
* 5 x 1 sparse Matrix of class "dgCMatrix"
435+
* s0
436+
* 2.6228269
437+
* data.V2 -0.5837166
438+
* data.V3 0.9285260
439+
* data.V4 -0.3783612
440+
* data.V5 -0.8123411
441+
* > weights2
442+
* 5 x 1 sparse Matrix of class "dgCMatrix"
443+
* s0
444+
* 4.11197445
445+
* data.V2 -0.16918650
446+
* data.V3 -0.81104784
447+
* data.V4 -0.06463799
448+
* data.V5 -0.29198337
447449
*/
448450

449451
val weightsR = Vectors.dense(Array(
@@ -459,7 +461,41 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with M
459461
// very steep curve in logistic function so that when we draw samples from distribution, it's
460462
// very easy to assign to another labels. However, this prediction result is consistent to R.
461463
validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData, 0.47)
464+
}
465+
466+
test("model export/import") {
  val nPoints = 20
  val A = 2.0
  val B = -1.5

  val testData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 42)
  val testRDD = sc.parallelize(testData, 2)
  testRDD.cache()

  // One iteration is enough: this test checks the save/load round trip, not model quality.
  val lr = new LogisticRegressionWithLBFGS().setIntercept(true)
  lr.optimizer.setNumIterations(1)
  val model = lr.run(testRDD)
  model.clearThreshold()
  assert(model.getThreshold.isEmpty)

  val tempDir = Utils.createTempDir()
  val path = tempDir.toURI.toString

  // Save model without a threshold, load it back, and compare.
  // Cleanup runs in `finally` so a failed assertion does not leave files behind.
  try {
    model.save(sc, path)
    val sameModel = LogisticRegressionModel.load(sc, path)
    assert(model.weights == sameModel.weights)
    assert(model.intercept == sameModel.intercept)
    assert(sameModel.getThreshold.isEmpty)
  } finally {
    Utils.deleteRecursively(tempDir)
  }

  // Save model with threshold, load it back, and compare.
  try {
    model.setThreshold(0.7)
    model.save(sc, path)
    val sameModel2 = LogisticRegressionModel.load(sc, path)
    assert(model.weights == sameModel2.weights)
    assert(model.intercept == sameModel2.intercept)
    assert(sameModel2.getThreshold.isDefined)
    assert(model.getThreshold == sameModel2.getThreshold)
  } finally {
    Utils.deleteRecursively(tempDir)
  }
}
464500

465501
}

0 commit comments

Comments
 (0)