@@ -35,6 +35,8 @@ import org.apache.spark.mllib.util.{Loader, Saveable}
3535import org .apache .spark .rdd .RDD
3636import org .apache .spark .sql .{DataFrame , SQLContext }
3737
38+ import NaiveBayes .ModelType .{Bernoulli , Multinomial }
39+
3840
3941/**
4042 * Model for Naive Bayes Classifiers.
@@ -54,7 +56,7 @@ class NaiveBayesModel private[mllib] (
5456 extends ClassificationModel with Serializable with Saveable {
5557
/**
 * Builds a model from labels, class priors (pi) and class conditionals (theta),
 * using the multinomial model type.
 */
private[mllib] def this(
    labels: Array[Double],
    pi: Array[Double],
    theta: Array[Array[Double]]) =
  this(labels, pi, theta, Multinomial)
5860
5961 /** A Java-friendly constructor that takes three Iterable parameters. */
6062 private [mllib] def this (
@@ -70,10 +72,13 @@ class NaiveBayesModel private[mllib] (
7072 // This precomputes log(1.0 - exp(theta)) and its sum which are used for the linear algebra
7173 // application of this condition (in predict function).
// Holds log(1.0 - exp(theta)) and its sum along Axis._1; both stay None for the
// multinomial model, which does not use them.
private val (brzNegTheta, brzNegThetaSum) = modelType match {
  case Multinomial =>
    (None, None)
  case Bernoulli =>
    // Work on a copy so brzTheta itself is left untouched by the in-place updates.
    val expTheta = brzExp(brzTheta.copy)
    val oneMinusExpTheta = (expTheta :*= (-1.0)) :+= 1.0
    val logOneMinusExpTheta = brzLog(oneMinusExpTheta) // log(1.0 - exp(x))
    (Option(logOneMinusExpTheta), Option(brzSum(logOneMinusExpTheta, Axis._1)))
  case _ =>
    // This should never happen.
    throw new UnknownError(s"NaiveBayesModel was created with an unknown ModelType: $modelType")
}
7883
7984 override def predict (testData : RDD [Vector ]): RDD [Double ] = {
@@ -86,29 +91,32 @@ class NaiveBayesModel private[mllib] (
8691
/**
 * Predicts the label of a single feature vector: computes one score per class for the
 * configured model type and returns the label at the index of the largest score.
 */
override def predict(testData: Vector): Double = {
  val classScores = modelType match {
    case Multinomial =>
      brzPi + brzTheta * testData.toBreeze
    case Bernoulli =>
      // brzNegTheta / brzNegThetaSum are populated at construction for the Bernoulli model.
      brzPi + (brzTheta - brzNegTheta.get) * testData.toBreeze + brzNegThetaSum.get
    case _ =>
      // This should never happen.
      throw new UnknownError(s"NaiveBayesModel was created with an unknown ModelType: $modelType")
  }
  labels(brzArgmax(classScores))
}
96104
/** Saves this model under `path` using the version 2.0 save/load format. */
override def save(sc: SparkContext, path: String): Unit = {
  import NaiveBayesModel.SaveLoadV2_0
  // The model type is persisted by its string name.
  val modelData = SaveLoadV2_0.Data(labels, pi, theta, modelType.toString)
  SaveLoadV2_0.save(sc, path, modelData)
}
101109
/** Format version written by `save`; kept in one place with the V2_0 saver ("2.0"). */
override protected def formatVersion: String =
  NaiveBayesModel.SaveLoadV2_0.thisFormatVersion
103111}
104112
105113object NaiveBayesModel extends Loader [NaiveBayesModel ] {
106114
107115 import org .apache .spark .mllib .util .Loader ._
108116
109- private object SaveLoadV1_0 {
117+ private [mllib] object SaveLoadV2_0 {
110118
111- def thisFormatVersion : String = " 1 .0"
119+ def thisFormatVersion : String = " 2 .0"
112120
113121 /** Hard-code class name string in case it changes in the future */
114122 def thisClassName : String = " org.apache.spark.mllib.classification.NaiveBayesModel"
@@ -127,8 +135,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
127135 // Create JSON metadata.
128136 val metadata = compact(render(
129137 (" class" -> thisClassName) ~ (" version" -> thisFormatVersion) ~
130- (" numFeatures" -> data.theta(0 ).length) ~ (" numClasses" -> data.pi.length) ~
131- (" modelType" -> data.modelType)))
138+ (" numFeatures" -> data.theta(0 ).length) ~ (" numClasses" -> data.pi.length)))
132139 sc.parallelize(Seq (metadata), 1 ).saveAsTextFile(metadataPath(path))
133140
134141 // Create Parquet data.
@@ -151,36 +158,82 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
151158 val modelType = NaiveBayes .ModelType .fromString(data.getString(3 ))
152159 new NaiveBayesModel (labels, pi, theta, modelType)
153160 }
161+
154162 }
155163
156- override def load (sc : SparkContext , path : String ): NaiveBayesModel = {
157- def getModelType (metadata : JValue ): NaiveBayes .ModelType = {
158- implicit val formats = DefaultFormats
159- NaiveBayes .ModelType .fromString((metadata \ " modelType" ).extract[String ])
/**
 * Save/load support for format version 1.0, whose data holds only labels, pi and theta.
 * Models loaded through it are built with the three-argument constructor.
 */
private[mllib] object SaveLoadV1_0 {

  def thisFormatVersion: String = "1.0"

  /** Hard-code class name string in case it changes in the future */
  def thisClassName: String = "org.apache.spark.mllib.classification.NaiveBayesModel"

  /** Model data for model import/export */
  case class Data(
      labels: Array[Double],
      pi: Array[Double],
      theta: Array[Array[Double]])

  /** Writes the JSON metadata and the Parquet model data under `path`. */
  def save(sc: SparkContext, path: String, data: Data): Unit = {
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    // Create JSON metadata.
    val numFeatures = data.theta(0).length
    val numClasses = data.pi.length
    val metadataJson =
      ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~
        ("numFeatures" -> numFeatures) ~ ("numClasses" -> numClasses)
    val metadata = compact(render(metadataJson))
    sc.parallelize(Seq(metadata), 1).saveAsTextFile(metadataPath(path))

    // Create Parquet data.
    val modelDF: DataFrame = sc.parallelize(Seq(data), 1).toDF()
    modelDF.saveAsParquetFile(dataPath(path))
  }

  /** Reads the Parquet model data under `path` and rebuilds the model from its single row. */
  def load(sc: SparkContext, path: String): NaiveBayesModel = {
    val sqlContext = new SQLContext(sc)
    // Load Parquet data.
    val modelDF = sqlContext.parquetFile(dataPath(path))
    // Check schema explicitly since erasure makes it hard to use match-case for checking.
    checkSchema[Data](modelDF.schema)
    val rows = modelDF.select("labels", "pi", "theta").take(1)
    assert(rows.size == 1, s"Unable to load NaiveBayesModel data from: ${dataPath(path)}")
    val row = rows(0)
    val loadedLabels = row.getAs[Seq[Double]](0).toArray
    val loadedPi = row.getAs[Seq[Double]](1).toArray
    val loadedTheta = row.getAs[Seq[Seq[Double]]](2).map(_.toArray).toArray
    new NaiveBayesModel(loadedLabels, loadedPi, loadedTheta)
  }
}
207+
/**
 * Loads a NaiveBayesModel saved under `path`.
 *
 * Dispatches on the (class name, format version) pair read from the metadata, then checks
 * the loaded pi and theta against the numFeatures / numClasses values from that metadata.
 * Throws an Exception listing the supported formats when the pair is not recognized.
 */
override def load(sc: SparkContext, path: String): NaiveBayesModel = {
  val (loadedClassName, version, metadata) = loadMetadata(sc, path)
  val classNameV1_0 = SaveLoadV1_0.thisClassName
  val classNameV2_0 = SaveLoadV2_0.thisClassName
  val (model, numFeatures, numClasses) = (loadedClassName, version) match {
    case (className, "1.0") if className == classNameV1_0 =>
      val (numFeatures, numClasses) = ClassificationModel.getNumFeaturesClasses(metadata)
      val model = SaveLoadV1_0.load(sc, path)
      (model, numFeatures, numClasses)
    case (className, "2.0") if className == classNameV2_0 =>
      val (numFeatures, numClasses) = ClassificationModel.getNumFeaturesClasses(metadata)
      val model = SaveLoadV2_0.load(sc, path)
      (model, numFeatures, numClasses)
    case _ => throw new Exception(
      // List every format this loader accepts, not only 1.0.
      s"NaiveBayesModel.load did not recognize model with (className, format version):" +
      s"($loadedClassName, $version).  Supported:\n" +
      s"  ($classNameV1_0, 1.0)\n" +
      s"  ($classNameV2_0, 2.0)")
  }
  // Shape checks shared by every supported format.
  assert(model.pi.size == numClasses,
    s"NaiveBayesModel.load expected $numClasses classes," +
    s" but class priors vector pi had ${model.pi.size} elements")
  assert(model.theta.size == numClasses,
    s"NaiveBayesModel.load expected $numClasses classes," +
    s" but class conditionals array theta had ${model.theta.size} elements")
  assert(model.theta.forall(_.size == numFeatures),
    s"NaiveBayesModel.load expected $numFeatures features," +
    s" but class conditionals array theta had elements of size:" +
    s" ${model.theta.map(_.size).mkString(",")}")
  model
}
185238}
186239
@@ -197,9 +250,9 @@ class NaiveBayes private (
197250 private var lambda : Double ,
198251 private var modelType : NaiveBayes .ModelType ) extends Serializable with Logging {
199252
/** Creates a multinomial trainer with the given smoothing parameter. */
def this(lambda: Double) = this(lambda, Multinomial)

/** Creates a multinomial trainer with smoothing parameter 1.0. */
def this() = this(1.0)
203256
204257 /** Set the smoothing parameter. Default: 1.0. */
205258 def setLambda (lambda : Double ): NaiveBayes = {
@@ -210,9 +263,22 @@ class NaiveBayes private (
210263 /** Get the smoothing parameter. */
211264 def getLambda : Double = lambda
212265
/**
 * Set the model type from its string name (case-insensitive).
 * Supported options: "multinomial" and "bernoulli".
 * (default: multinomial)
 */
def setModelType(modelType: String): NaiveBayes = {
  val parsedType: NaiveBayes.ModelType = NaiveBayes.ModelType.fromString(modelType)
  setModelType(parsedType)
}

/**
 * Set the model type.
 * Supported options: [[NaiveBayes.ModelType.Bernoulli]], [[NaiveBayes.ModelType.Multinomial]]
 * (default: Multinomial)
 */
def setModelType(modelType: NaiveBayes.ModelType): NaiveBayes = {
  this.modelType = modelType
  // Return this instance so setter calls can be chained.
  this
}
218284
@@ -270,8 +336,11 @@ class NaiveBayes private (
270336 labels(i) = label
271337 pi(i) = math.log(n + lambda) - piLogDenom
272338 val thetaLogDenom = modelType match {
273- case NaiveBayes .Multinomial => math.log(brzSum(sumTermFreqs) + numFeatures * lambda)
274- case NaiveBayes .Bernoulli => math.log(n + 2.0 * lambda)
339+ case Multinomial => math.log(brzSum(sumTermFreqs) + numFeatures * lambda)
340+ case Bernoulli => math.log(n + 2.0 * lambda)
341+ case _ =>
342+ // This should never happen.
343+ throw new UnknownError (s " NaiveBayes was created with an unknown ModelType: $modelType" )
275344 }
276345 var j = 0
277346 while (j < numFeatures) {
@@ -317,7 +386,7 @@ object NaiveBayes {
317386 * @param lambda The smoothing parameter
318387 */
def train(input: RDD[LabeledPoint], lambda: Double): NaiveBayesModel = {
  // This overload always trains the multinomial model.
  val trainer = new NaiveBayes(lambda, ModelType.Multinomial)
  trainer.run(input)
}
322391
323392 /**
@@ -339,35 +408,42 @@ object NaiveBayes {
339408 * multinomial or bernoulli
340409 */
def train(input: RDD[LabeledPoint], lambda: Double, modelType: String): NaiveBayesModel = {
  // Resolve the string name first; fromString rejects unsupported names.
  val parsedType = ModelType.fromString(modelType)
  new NaiveBayes(lambda, parsedType).run(input)
}
344413
/** Model type of a naive Bayes classifier; supported values live in the companion object. */
sealed abstract class ModelType extends Serializable

/** Supported [[ModelType]] values and parsing from their string names. */
object ModelType extends Serializable {

  /**
   * Get the model type from a string.
   * @param modelType Supported: "multinomial" or "bernoulli" (case-insensitive)
   * @throws IllegalArgumentException if the string is null or not a supported name
   */
  def fromString(modelType: String): ModelType = {
    // Locale.ROOT keeps the lower-casing independent of the JVM default locale.
    val normalized = Option(modelType).map(_.toLowerCase(java.util.Locale.ROOT))
    normalized match {
      case Some("multinomial") => Multinomial
      case Some("bernoulli") => Bernoulli
      case _ =>
        throw new IllegalArgumentException(
          s"NaiveBayes.ModelType.fromString did not recognize string: $modelType")
    }
  }

  // Declared as member case objects (not objects local to a val initializer) so that a
  // Java-serialization round trip resolves back to the same singleton and
  // `case Multinomial` / `case Bernoulli` matches keep working on deserialized values.

  /** Multinomial model type; its string name is "multinomial". */
  case object Multinomial extends ModelType {
    override def toString: String = "multinomial"
  }

  /** Bernoulli model type; its string name is "bernoulli". */
  case object Bernoulli extends ModelType {
    override def toString: String = "bernoulli"
  }
}
366445
367- /** Constant for specifying ModelType parameter: bernoulli model */
368- final val Bernoulli : ModelType = new ModelType {
369- override def toString : String = ModelType .BERNOULLI_STRING
370- }
/** Java-friendly accessor for the supported ModelType options. */
final val modelTypes: ModelType.type = ModelType
371448
372449}
373-
0 commit comments