
Commit 1496852

Added save/load for NaiveBayes
1 parent 8d46386 commit 1496852

File tree

3 files changed: 38 additions, 34 deletions
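The net effect of the three diffs below is that NaiveBayesModel follows the same convention as LogisticRegressionModel: save() writes JSON metadata plus Parquet data, and the companion object loads it back with an explicit schema check. A rough round-trip sketch, not part of the patch: it assumes Importable gives NaiveBayesModel a load(sc, path) counterpart to save (as the load hunk below implies), and the path and toy training data are made up.

import org.apache.spark.SparkContext
import org.apache.spark.mllib.classification.{NaiveBayes, NaiveBayesModel}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint

val sc = new SparkContext("local", "nb-save-load")
val training = sc.parallelize(Seq(
  LabeledPoint(0.0, Vectors.dense(1.0, 0.0)),
  LabeledPoint(1.0, Vectors.dense(0.0, 1.0))))
val model = NaiveBayes.train(training)

// save() writes human-readable JSON metadata to <path>/metadata
// and the model parameters as Parquet to <path>/data.
model.save(sc, "/tmp/nbModel")

// load() reads the Parquet data back, validates its schema, and rebuilds the model.
val sameModel = NaiveBayesModel.load(sc, "/tmp/nbModel")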

mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala

Lines changed: 1 addition & 0 deletions

@@ -147,6 +147,7 @@ class LogisticRegressionModel (
       clazz = this.getClass.getName, version = Exportable.latestVersion)
     val metadataRDD: DataFrame = sc.parallelize(Seq(metadata))
     metadataRDD.toJSON.saveAsTextFile(path + "/metadata")
+
     // Create Parquet data.
     val data = LogisticRegressionModel.Data(weights, intercept, threshold)
     val dataRDD: DataFrame = sc.parallelize(Seq(data))

mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala

Lines changed: 9 additions & 5 deletions

@@ -18,6 +18,8 @@
 package org.apache.spark.mllib.classification
 
 import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum}
+import org.apache.spark.sql.catalyst.ScalaReflection
+import org.apache.spark.sql.types.{ArrayType, DataType, DoubleType, StructField, StructType}
 
 import org.apache.spark.{SparkContext, SparkException, Logging}
 import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector}

@@ -79,6 +81,7 @@ class NaiveBayesModel private[mllib] (
       clazz = this.getClass.getName, version = Exportable.latestVersion)
     val metadataRDD: DataFrame = sc.parallelize(Seq(metadata))
     metadataRDD.toJSON.saveAsTextFile(path + "/metadata")
+
     // Create Parquet data.
     val data = NaiveBayesModel.Data(labels, pi, theta)
     val dataRDD: DataFrame = sc.parallelize(Seq(data))

@@ -117,11 +120,12 @@ object NaiveBayesModel extends Importable[NaiveBayesModel] {
     assert(dataArray.size == 1, s"Unable to load NaiveBayesModel data from: ${path + "/data"}")
     val data = dataArray(0)
     assert(data.size == 3, s"Unable to load NaiveBayesModel data from: ${path + "/data"}")
-    val nb = data match {
-      case Row(labels: Seq[Double], pi: Seq[Double], theta: Seq[Seq[Double]]) =>
-        new NaiveBayesModel(labels.toArray, pi.toArray, theta.map(_.toArray).toArray)
-    }
-    nb
+    // Check schema explicitly since erasure makes it hard to use match-case for checking.
+    Importable.checkSchema[Data](dataRDD.schema)
+    val labels = data.getAs[Seq[Double]](0).toArray
+    val pi = data.getAs[Seq[Double]](1).toArray
+    val theta = data.getAs[Seq[Seq[Double]]](2).map(_.toArray).toArray
+    new NaiveBayesModel(labels, pi, theta)
   }
 }
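The removed match-case was unsound because of JVM type erasure, which the new comment alludes to. A minimal, standalone illustration (not from the patch) of the failure mode:

val row: Any = Seq("not", "a", "double")
row match {
  // Erasure: the Seq's element type is not checked at runtime, so this
  // case matches even though the elements are Strings (the compiler
  // emits an "unchecked" warning here).
  case xs: Seq[Double] => println(s"matched ${xs.size} elements")
  case _ => println("no match")
}
// Reading an element as a Double would only fail later, at the point of use,
// with a ClassCastException — hence load now checks the schema explicitly
// and extracts fields with getAs.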

mllib/src/main/scala/org/apache/spark/mllib/util/modelImportExport.scala

Lines changed: 28 additions & 29 deletions

@@ -17,8 +17,13 @@
 
 package org.apache.spark.mllib.util
 
+import scala.reflect.runtime.universe.TypeTag
+
 import org.apache.spark.SparkContext
 import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.sql.catalyst.ScalaReflection
+import org.apache.spark.sql.types.{DataType, StructType, StructField}
+import org.apache.spark.sql.{DataFrame, Row, SQLContext}
 
 /**
  * :: DeveloperApi ::

@@ -46,11 +51,7 @@ trait Exportable {
 
 }
 
-/**
- * :: DeveloperApi ::
- */
-@DeveloperApi
-object Exportable {
+private[mllib] object Exportable {
 
   /** Current version of model import/export format. */
   val latestVersion: String = "1.0"

@@ -79,34 +80,32 @@ trait Importable[Model <: Exportable] {
 
 }
 
-/*
-/**
- * :: DeveloperApi ::
- *
- * Trait for models and transformers which may be saved as files.
- * This should be inherited by the class which implements model instances.
- *
- * This specializes [[Exportable]] for local models which can be stored on a single machine.
- * This provides helper functionality, but developers can choose to use [[Exportable]] instead,
- * even for local models.
- */
-@DeveloperApi
-trait LocalExportable {
+private[mllib] object Importable {
 
   /**
-   * Save this model to the given path.
-   *
-   * This saves:
-   *  - human-readable (JSON) model metadata to path/metadata/
-   *  - Parquet formatted data to path/data/
+   * Check the schema of loaded model data.
    *
-   * The model may be loaded using [[Importable.load]].
+   * This checks every field in the expected schema to make sure that a field with the same
+   * name and DataType appears in the loaded schema. Note that this does NOT check metadata
+   * or containsNull.
    *
-   * @param sc  Spark context used to save model data.
-   * @param path  Path specifying the directory in which to save this model.
-   *              This directory and any intermediate directory will be created if needed.
+   * @param loadedSchema  Schema for model data loaded from file.
+   * @tparam Data  Expected data type from which an expected schema can be derived.
    */
-  def save(sc: SparkContext, path: String): Unit
+  def checkSchema[Data: TypeTag](loadedSchema: StructType): Unit = {
+    // Check schema explicitly since erasure makes it hard to use match-case for checking.
+    val expectedFields: Array[StructField] =
+      ScalaReflection.schemaFor[Data].dataType.asInstanceOf[StructType].fields
+    val loadedFields: Map[String, DataType] =
+      loadedSchema.map(field => field.name -> field.dataType).toMap
+    expectedFields.foreach { field =>
+      assert(loadedFields.contains(field.name), s"Unable to parse model data." +
+        s" Expected field with name ${field.name} was missing in loaded schema:" +
+        s" ${loadedFields.mkString(", ")}")
+      assert(loadedFields(field.name) == field.dataType,
+        s"Unable to parse model data. Expected field $field but found field" +
+        s" with different type: ${loadedFields(field.name)}")
+    }
+  }
 
 }
-*/
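For context, a sketch of how a loader inside mllib would use this helper, mirroring the NaiveBayesModel hunk above. It is not part of the patch: the Data case class stands in for a model's own, the path layout follows the <path>/data convention, and SQLContext.parquetFile is assumed as the Parquet reader of this era.

import org.apache.spark.SparkContext
import org.apache.spark.sql.{Row, SQLContext}

// Hypothetical stand-in for a model's own serialized-data case class.
case class Data(labels: Seq[Double], pi: Seq[Double], theta: Seq[Seq[Double]])

def loadData(sc: SparkContext, path: String): Row = {
  val sqlContext = new SQLContext(sc)
  // Read the Parquet data written by save().
  val dataRDD = sqlContext.parquetFile(path + "/data")
  // Fail fast if field names or DataTypes don't match, before any positional getAs calls.
  Importable.checkSchema[Data](dataRDD.schema)
  dataRDD.collect()(0)
}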
