Skip to content

Commit 7b33b4e

Browse files
committed
[SPARK-1406] Added a PMMLExportable interface
Restructured code in a new package mllib.pmml Supported models implements the new PMMLExportable interface: LogisticRegression, SVM, KMeansModel, LinearRegression, RidgeRegression, Lasso
1 parent d559ec5 commit 7b33b4e

13 files changed

+110
-168
lines changed

mllib/src/main/scala/org/apache/spark/mllib/export/ModelExport.scala

Lines changed: 0 additions & 22 deletions
This file was deleted.

mllib/src/main/scala/org/apache/spark/mllib/export/ModelExportType.scala

Lines changed: 0 additions & 30 deletions
This file was deleted.

mllib/src/main/scala/org/apache/spark/mllib/export/ModelExporter.scala

Lines changed: 0 additions & 45 deletions
This file was deleted.
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.mllib.pmml
19+
20+
import java.io.File
21+
import java.io.OutputStream
22+
import java.io.StringWriter
23+
import javax.xml.transform.stream.StreamResult
24+
import org.jpmml.model.JAXBUtil
25+
import org.apache.spark.mllib.pmml.export.PMMLModelExport
26+
import org.apache.spark.mllib.pmml.export.PMMLModelExportFactory
27+
28+
/**
29+
* Export model to the PMML format
30+
* Predictive Model Markup Language (PMML) in an XML-based file format
31+
* developed by the Data Mining Group (www.dmg.org).
32+
*/
33+
trait PMMLExportable {
34+
35+
/**
36+
* Export the model to the stream result in PMML format
37+
*/
38+
private def toPMML(streamResult: StreamResult): Unit = {
39+
val pmmlModelExport = PMMLModelExportFactory.createPMMLModelExport(this)
40+
JAXBUtil.marshalPMML(pmmlModelExport.getPmml(), streamResult)
41+
}
42+
43+
/**
44+
* Export the model to a local File in PMML format
45+
*/
46+
def toPMML(localPath: String): Unit = {
47+
toPMML(new StreamResult(new File(localPath)))
48+
}
49+
50+
/**
51+
* Export the model to the Outputtream in PMML format
52+
*/
53+
def toPMML(outputStream: OutputStream): Unit = {
54+
toPMML(new StreamResult(outputStream))
55+
}
56+
57+
/**
58+
* Export the model to a String in PMML format
59+
*/
60+
def toPMML(): String = {
61+
var writer = new StringWriter();
62+
toPMML(new StreamResult(writer))
63+
return writer.toString();
64+
}
65+
66+
}
Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
* limitations under the License.
1616
*/
1717

18-
package org.apache.spark.mllib.export.pmml
18+
package org.apache.spark.mllib.pmml.export
1919

2020
import org.dmg.pmml.DataDictionary
2121
import org.dmg.pmml.DataField
@@ -29,7 +29,6 @@ import org.dmg.pmml.NumericPredictor
2929
import org.dmg.pmml.OpType
3030
import org.dmg.pmml.RegressionModel
3131
import org.dmg.pmml.RegressionTable
32-
3332
import org.apache.spark.mllib.regression.GeneralizedLinearModel
3433

3534
/**

mllib/src/main/scala/org/apache/spark/mllib/export/pmml/KMeansPMMLModelExport.scala renamed to mllib/src/main/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExport.scala

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
* limitations under the License.
1616
*/
1717

18-
package org.apache.spark.mllib.export.pmml
18+
package org.apache.spark.mllib.pmml.export
1919

2020
import org.dmg.pmml.Array.Type
2121
import org.dmg.pmml.Cluster
@@ -35,7 +35,6 @@ import org.dmg.pmml.MiningFunctionType
3535
import org.dmg.pmml.MiningSchema
3636
import org.dmg.pmml.OpType
3737
import org.dmg.pmml.SquaredEuclidean
38-
3938
import org.apache.spark.mllib.clustering.KMeansModel
4039

4140
/**
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
* limitations under the License.
1616
*/
1717

18-
package org.apache.spark.mllib.export.pmml
18+
package org.apache.spark.mllib.pmml.export
1919

2020
import org.dmg.pmml.DataDictionary
2121
import org.dmg.pmml.DataField
@@ -29,8 +29,8 @@ import org.dmg.pmml.NumericPredictor
2929
import org.dmg.pmml.OpType
3030
import org.dmg.pmml.RegressionModel
3131
import org.dmg.pmml.RegressionTable
32-
import org.apache.spark.mllib.classification.LogisticRegressionModel
3332
import org.dmg.pmml.RegressionNormalizationMethodType
33+
import org.apache.spark.mllib.classification.LogisticRegressionModel
3434

3535
/**
3636
* PMML Model Export for LogisticRegressionModel class

mllib/src/main/scala/org/apache/spark/mllib/export/pmml/PMMLModelExport.scala renamed to mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,21 +15,17 @@
1515
* limitations under the License.
1616
*/
1717

18-
package org.apache.spark.mllib.export.pmml
18+
package org.apache.spark.mllib.pmml.export
1919

2020
import java.text.SimpleDateFormat
2121
import java.util.Date
22-
2322
import scala.beans.BeanProperty
24-
2523
import org.dmg.pmml.Application
2624
import org.dmg.pmml.Header
2725
import org.dmg.pmml.PMML
2826
import org.dmg.pmml.Timestamp
2927

30-
import org.apache.spark.mllib.export.ModelExport
31-
32-
private[mllib] trait PMMLModelExport extends ModelExport{
28+
private[mllib] trait PMMLModelExport {
3329

3430
/**
3531
* Holder of the exported model in PMML format

mllib/src/main/scala/org/apache/spark/mllib/export/ModelExportFactory.scala renamed to mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactory.scala

Lines changed: 8 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -15,30 +15,23 @@
1515
* limitations under the License.
1616
*/
1717

18-
package org.apache.spark.mllib.export
18+
package org.apache.spark.mllib.pmml.export
1919

2020
import org.apache.spark.mllib.classification.LogisticRegressionModel
2121
import org.apache.spark.mllib.classification.SVMModel
2222
import org.apache.spark.mllib.clustering.KMeansModel
23-
import org.apache.spark.mllib.export.ModelExportType.ModelExportType
24-
import org.apache.spark.mllib.export.ModelExportType.PMML
25-
import org.apache.spark.mllib.export.pmml.GeneralizedLinearPMMLModelExport
26-
import org.apache.spark.mllib.export.pmml.KMeansPMMLModelExport
27-
import org.apache.spark.mllib.export.pmml.LogisticRegressionPMMLModelExport
2823
import org.apache.spark.mllib.regression.LassoModel
2924
import org.apache.spark.mllib.regression.LinearRegressionModel
3025
import org.apache.spark.mllib.regression.RidgeRegressionModel
3126

32-
private[mllib] object ModelExportFactory {
27+
private[mllib] object PMMLModelExportFactory {
3328

3429
/**
35-
* Factory object to help creating the necessary ModelExport implementation
36-
* taking as input the ModelExportType (for example PMML)
37-
* and the machine learning model (for example KMeansModel).
30+
* Factory object to help creating the necessary PMMLModelExport implementation
31+
* taking as input the machine learning model (for example KMeansModel).
3832
*/
39-
def createModelExport(model: Any, exportType: ModelExportType): ModelExport = {
40-
return exportType match{
41-
case PMML => model match{
33+
def createPMMLModelExport(model: Any): PMMLModelExport = {
34+
return model match{
4235
case kmeans: KMeansModel =>
4336
new KMeansPMMLModelExport(kmeans)
4437
case linearRegression: LinearRegressionModel =>
@@ -54,10 +47,8 @@ private[mllib] object ModelExportFactory {
5447
case logisticRegression: LogisticRegressionModel =>
5548
new LogisticRegressionPMMLModelExport(logisticRegression, "logistic regression")
5649
case _ =>
57-
throw new IllegalArgumentException("Export not supported for model: " + model.getClass)
58-
}
59-
case _ => throw new IllegalArgumentException("Export type not supported:" + exportType)
60-
}
50+
throw new IllegalArgumentException("PMML Export not supported for model: " + model.getClass)
51+
}
6152
}
6253

6354
}
Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,11 @@
1515
* limitations under the License.
1616
*/
1717

18-
package org.apache.spark.mllib.export.pmml
18+
package org.apache.spark.mllib.pmml.export
1919

2020
import org.dmg.pmml.RegressionModel
2121
import org.scalatest.FunSuite
22-
2322
import org.apache.spark.mllib.classification.SVMModel
24-
import org.apache.spark.mllib.export.ModelExportFactory
25-
import org.apache.spark.mllib.export.ModelExportType
2623
import org.apache.spark.mllib.regression.LassoModel
2724
import org.apache.spark.mllib.regression.LinearRegressionModel
2825
import org.apache.spark.mllib.regression.RidgeRegressionModel
@@ -41,7 +38,7 @@ class GeneralizedLinearPMMLModelExportSuite extends FunSuite{
4138
val svmModel = new SVMModel(linearInput(0).features, linearInput(0).label);
4239

4340
//act by exporting the model to the PMML format
44-
val linearModelExport = ModelExportFactory.createModelExport(linearRegressionModel, ModelExportType.PMML)
41+
val linearModelExport = PMMLModelExportFactory.createPMMLModelExport(linearRegressionModel)
4542
//assert that the PMML format is as expected
4643
assert(linearModelExport.isInstanceOf[PMMLModelExport])
4744
var pmml = linearModelExport.asInstanceOf[PMMLModelExport].getPmml()
@@ -54,7 +51,7 @@ class GeneralizedLinearPMMLModelExportSuite extends FunSuite{
5451
.getRegressionTables().get(0).getNumericPredictors().size() === linearRegressionModel.weights.size)
5552

5653
//act
57-
val ridgeModelExport = ModelExportFactory.createModelExport(ridgeRegressionModel, ModelExportType.PMML)
54+
val ridgeModelExport = PMMLModelExportFactory.createPMMLModelExport(ridgeRegressionModel)
5855
//assert that the PMML format is as expected
5956
assert(ridgeModelExport.isInstanceOf[PMMLModelExport])
6057
pmml = ridgeModelExport.asInstanceOf[PMMLModelExport].getPmml()
@@ -67,7 +64,7 @@ class GeneralizedLinearPMMLModelExportSuite extends FunSuite{
6764
.getRegressionTables().get(0).getNumericPredictors().size() === ridgeRegressionModel.weights.size)
6865

6966
//act
70-
val lassoModelExport = ModelExportFactory.createModelExport(lassoModel, ModelExportType.PMML)
67+
val lassoModelExport = PMMLModelExportFactory.createPMMLModelExport(lassoModel)
7168
//assert that the PMML format is as expected
7269
assert(lassoModelExport.isInstanceOf[PMMLModelExport])
7370
pmml = lassoModelExport.asInstanceOf[PMMLModelExport].getPmml()
@@ -80,7 +77,7 @@ class GeneralizedLinearPMMLModelExportSuite extends FunSuite{
8077
.getRegressionTables().get(0).getNumericPredictors().size() === lassoModel.weights.size)
8178

8279
//act
83-
val svmModelExport = ModelExportFactory.createModelExport(svmModel, ModelExportType.PMML)
80+
val svmModelExport = PMMLModelExportFactory.createPMMLModelExport(svmModel)
8481
//assert that the PMML format is as expected
8582
assert(svmModelExport.isInstanceOf[PMMLModelExport])
8683
pmml = svmModelExport.asInstanceOf[PMMLModelExport].getPmml()
@@ -93,10 +90,10 @@ class GeneralizedLinearPMMLModelExportSuite extends FunSuite{
9390
.getRegressionTables().get(0).getNumericPredictors().size() === svmModel.weights.size)
9491

9592
//manual checking
96-
//ModelExporter.toPMML(linearRegressionModel,"/tmp/linearregression.xml")
97-
//ModelExporter.toPMML(ridgeRegressionModel,"/tmp/ridgeregression.xml")
98-
//ModelExporter.toPMML(lassoModel,"/tmp/lassoregression.xml")
99-
//ModelExporter.toPMML(svmModel,"/tmp/linearsvm.xml")
93+
//linearRegressionModel.toPMML("/tmp/linearregression.xml")
94+
//ridgeRegressionModel.toPMML("/tmp/ridgeregression.xml")
95+
//lassoModel.toPMML("/tmp/lassoregression.xml")
96+
//svmModel.toPMML("/tmp/linearsvm.xml")
10097

10198
}
10299

0 commit comments

Comments
 (0)