@@ -35,8 +35,6 @@ import org.apache.spark.mllib.util.{Loader, Saveable}
3535import org .apache .spark .rdd .RDD
3636import org .apache .spark .sql .{DataFrame , SQLContext }
3737
38- import NaiveBayes .ModelType .{Bernoulli , Multinomial }
39-
4038
4139/**
4240 * Model for Naive Bayes Classifiers.
@@ -45,18 +43,17 @@ import NaiveBayes.ModelType.{Bernoulli, Multinomial}
4543 * @param pi log of class priors, whose dimension is C, number of labels
4644 * @param theta log of class conditional probabilities, whose dimension is C-by-D,
4745 * where D is number of features
48- * @param modelType The type of NB model to fit from the enumeration NaiveBayesModels, can be
49- * Multinomial or Bernoulli
46+ * @param modelType The type of NB model to fit: either "Multinomial" or "Bernoulli"
5047 */
5148class NaiveBayesModel private [mllib] (
5249 val labels : Array [Double ],
5350 val pi : Array [Double ],
5451 val theta : Array [Array [Double ]],
55- val modelType : NaiveBayes . ModelType )
52+ val modelType : String )
5653 extends ClassificationModel with Serializable with Saveable {
5754
5855 private [mllib] def this (labels : Array [Double ], pi : Array [Double ], theta : Array [Array [Double ]]) =
59- this (labels, pi, theta, Multinomial )
56+ this (labels, pi, theta, " Multinomial" )
6057
6158 /** A Java-friendly constructor that takes three Iterable parameters. */
6259 private [mllib] def this (
@@ -72,8 +69,8 @@ class NaiveBayesModel private[mllib] (
7269 // This precomputes log(1.0 - exp(theta)) and its sum which are used for the linear algebra
7370 // application of this condition (in predict function).
7471 private val (brzNegTheta, brzNegThetaSum) = modelType match {
75- case Multinomial => (None , None )
76- case Bernoulli =>
72+ case " Multinomial" => (None , None )
73+ case " Bernoulli" =>
7774 val negTheta = brzLog((brzExp(brzTheta.copy) :*= (- 1.0 )) :+= 1.0 ) // log(1.0 - exp(x))
7875 (Option (negTheta), Option (brzSum(negTheta, Axis ._1)))
7976 case _ =>
@@ -91,9 +88,9 @@ class NaiveBayesModel private[mllib] (
9188
9289 override def predict (testData : Vector ): Double = {
9390 modelType match {
94- case Multinomial =>
91+ case " Multinomial" =>
9592 labels (brzArgmax (brzPi + brzTheta * testData.toBreeze) )
96- case Bernoulli =>
93+ case " Bernoulli" =>
9794 labels (brzArgmax (brzPi +
9895 (brzTheta - brzNegTheta.get) * testData.toBreeze + brzNegThetaSum.get))
9996 case _ =>
@@ -103,7 +100,7 @@ class NaiveBayesModel private[mllib] (
103100 }
104101
105102 override def save (sc : SparkContext , path : String ): Unit = {
106- val data = NaiveBayesModel .SaveLoadV2_0 .Data (labels, pi, theta, modelType.toString )
103+ val data = NaiveBayesModel .SaveLoadV2_0 .Data (labels, pi, theta, modelType)
107104 NaiveBayesModel .SaveLoadV2_0 .save(sc, path, data)
108105 }
109106
@@ -155,7 +152,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
155152 val labels = data.getAs[Seq [Double ]](0 ).toArray
156153 val pi = data.getAs[Seq [Double ]](1 ).toArray
157154 val theta = data.getAs[Seq [Seq [Double ]]](2 ).map(_.toArray).toArray
158- val modelType = NaiveBayes . ModelType .fromString( data.getString(3 ) )
155+ val modelType = data.getString(3 )
159156 new NaiveBayesModel (labels, pi, theta, modelType)
160157 }
161158
@@ -248,11 +245,11 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
248245
249246class NaiveBayes private (
250247 private var lambda : Double ,
251- private var modelType : NaiveBayes . ModelType ) extends Serializable with Logging {
248+ private var modelType : String ) extends Serializable with Logging {
252249
253- def this (lambda : Double ) = this (lambda, Multinomial )
250+ def this (lambda : Double ) = this (lambda, " Multinomial" )
254251
255- def this () = this (1.0 , Multinomial )
252+ def this () = this (1.0 , " Multinomial" )
256253
257254 /** Set the smoothing parameter. Default: 1.0. */
258255 def setLambda (lambda : Double ): NaiveBayes = {
@@ -264,26 +261,21 @@ class NaiveBayes private (
264261 def getLambda : Double = lambda
265262
266263 /**
267- * Set the model type using a string (case-insensitive).
268- * Supported options: "multinomial" and "bernoulli".
269- * (default: multinomial)
270- */
271- def setModelType (modelType : String ): NaiveBayes = {
272- setModelType(NaiveBayes .ModelType .fromString(modelType))
273- }
274-
275- /**
276- * Set the model type.
277- * Supported options: [[NaiveBayes.ModelType.Bernoulli ]], [[NaiveBayes.ModelType.Multinomial ]]
264+ * Set the model type using a string (case-sensitive).
265+ * Supported options: "Multinomial" and "Bernoulli".
278266 * (default: Multinomial)
279267 */
280- def setModelType (modelType : NaiveBayes .ModelType ): NaiveBayes = {
281- this .modelType = modelType
282- this
268+ def setModelType (modelType: String ): NaiveBayes = {
269+ if (NaiveBayes .supportedModelTypes.contains(modelType)) {
270+ this .modelType = modelType
271+ this
272+ } else {
273+ throw new UnknownError (s " NaiveBayesModel does not support ModelType: $modelType" )
274+ }
283275 }
284276
285277 /** Get the model type. */
286- def getModelType : NaiveBayes . ModelType = this .modelType
278+ def getModelType : String = this .modelType
287279
288280 /**
289281 * Run the algorithm with the configured parameters on an input RDD of LabeledPoint entries.
@@ -336,8 +328,8 @@ class NaiveBayes private (
336328 labels(i) = label
337329 pi(i) = math.log(n + lambda) - piLogDenom
338330 val thetaLogDenom = modelType match {
339- case Multinomial => math.log(brzSum(sumTermFreqs) + numFeatures * lambda)
340- case Bernoulli => math.log(n + 2.0 * lambda)
331+ case " Multinomial" => math.log(brzSum(sumTermFreqs) + numFeatures * lambda)
332+ case " Bernoulli" => math.log(n + 2.0 * lambda)
341333 case _ =>
342334 // This should never happen.
343335 throw new UnknownError (s " NaiveBayes was created with an unknown ModelType: $modelType" )
@@ -358,6 +350,10 @@ class NaiveBayes private (
358350 * Top-level methods for calling naive Bayes.
359351 */
360352object NaiveBayes {
353+
354+ /* Names of the model types that NaiveBayes supports (matched case-sensitively) */
355+ private [mllib] val supportedModelTypes = Set (" Multinomial" , " Bernoulli" )
356+
361357 /**
362358 * Trains a Naive Bayes model given an RDD of `(label, features)` pairs.
363359 *
@@ -386,7 +382,7 @@ object NaiveBayes {
386382 * @param lambda The smoothing parameter
387383 */
388384 def train (input : RDD [LabeledPoint ], lambda : Double ): NaiveBayesModel = {
389- new NaiveBayes (lambda, NaiveBayes . ModelType . Multinomial ).run(input)
385+ new NaiveBayes (lambda, " Multinomial" ).run(input)
390386 }
391387
392388 /**
@@ -408,42 +404,11 @@ object NaiveBayes {
408404 * multinomial or bernoulli
409405 */
410406 def train (input : RDD [LabeledPoint ], lambda : Double , modelType : String ): NaiveBayesModel = {
411- new NaiveBayes (lambda, ModelType .fromString(modelType)).run(input)
412- }
413-
414- /** Provides static methods for using ModelType. */
415- sealed abstract class ModelType extends Serializable
416-
417- object ModelType extends Serializable {
418-
419- /**
420- * Get the model type from a string.
421- * @param modelType Supported: "multinomial" or "bernoulli" (case-insensitive)
422- */
423- def fromString (modelType : String ): ModelType = modelType.toLowerCase match {
424- case " multinomial" => Multinomial
425- case " bernoulli" => Bernoulli
426- case _ =>
427- throw new IllegalArgumentException (
428- s " NaiveBayes.ModelType.fromString did not recognize string: $modelType" )
429- }
430-
431- final val Multinomial : ModelType = {
432- case object Multinomial extends ModelType with Serializable {
433- override def toString : String = " multinomial"
434- }
435- Multinomial
436- }
437-
438- final val Bernoulli : ModelType = {
439- case object Bernoulli extends ModelType with Serializable {
440- override def toString : String = " bernoulli"
441- }
442- Bernoulli
407+ if (supportedModelTypes.contains(modelType)) {
408+ new NaiveBayes (lambda, modelType).run(input)
409+ } else {
410+ throw new UnknownError (s " NaiveBayes was created with an unknown ModelType: $modelType" )
443411 }
444412 }
445413
446- /** Java-friendly accessor for supported ModelType options */
447- final val modelTypes = ModelType
448-
449414}
0 commit comments