@@ -49,15 +49,15 @@ class NaiveBayesModel private[mllib] (
4949 val modelType : String )
5050 extends ClassificationModel with Serializable with Saveable {
5151
52- def this (labels : Array [Double ], pi : Array [Double ], theta : Array [Array [Double ]]) =
52+ private [mllib] def this (labels : Array [Double ], pi : Array [Double ], theta : Array [Array [Double ]]) =
5353 this (labels, pi, theta, NaiveBayes .Multinomial .toString)
5454
5555 private val brzPi = new BDV [Double ](pi)
5656 private val brzTheta = new BDM (theta(0 ).length, theta.length, theta.flatten).t
5757
58- // Bernoulli scoring requires log(condprob) if 1 log(1-condprob) if 0
59- // this precomputes log(1.0 - exp(theta)) and its sum for linear algebra application
60- // of this condition in predict function
58+ // Bernoulli scoring requires log(condprob) if 1, log(1-condprob) if 0.
59+ // This precomputes log(1.0 - exp(theta)) and its sum which are used for the linear algebra
60+ // application of this condition ( in predict function).
6161 private val (brzNegTheta, brzNegThetaSum) = NaiveBayes .ModelType .fromString(modelType) match {
6262 case NaiveBayes .Multinomial => (None , None )
6363 case NaiveBayes .Bernoulli =>
@@ -186,8 +186,6 @@ class NaiveBayes private (
186186 private var lambda : Double ,
187187 private var modelType : NaiveBayes .ModelType ) extends Serializable with Logging {
188188
189- def this (lambda : Double ) = this (lambda, NaiveBayes .Multinomial )
190-
191189 def this () = this (1.0 , NaiveBayes .Multinomial )
192190
193191 /** Set the smoothing parameter. Default: 1.0. */
@@ -202,6 +200,7 @@ class NaiveBayes private (
202200 this
203201 }
204202
203+ def getModelType (): NaiveBayes .ModelType = this .modelType
205204
206205 /**
207206 * Run the algorithm with the configured parameters on an input RDD of LabeledPoint entries.
@@ -301,10 +300,9 @@ object NaiveBayes {
301300 * @param lambda The smoothing parameter
302301 */
303302 def train (input : RDD [LabeledPoint ], lambda : Double ): NaiveBayesModel = {
304- new NaiveBayes (lambda).run(input)
303+ new NaiveBayes (lambda, NaiveBayes . Multinomial ).run(input)
305304 }
306305
307-
308306 /**
309307 * Trains a Naive Bayes model given an RDD of `(label, features)` pairs.
310308 *
@@ -327,11 +325,7 @@ object NaiveBayes {
327325 new NaiveBayes (lambda, MODELTYPE .fromString(modelType)).run(input)
328326 }
329327
330-
331- /**
332- * Model types supported in Naive Bayes:
333- * multinomial and Bernoulli currently supported
334- */
328+ /** Provides static methods for using ModelType. */
335329 sealed abstract class ModelType
336330
337331 object MODELTYPE {
@@ -348,10 +342,12 @@ object NaiveBayes {
348342
349343 final val ModelType = MODELTYPE
350344
345+ /** Constant for specifying ModelType parameter: multinomial model */
351346 final val Multinomial : ModelType = new ModelType {
352347 override def toString : String = ModelType .MULTINOMIAL_STRING
353348 }
354349
350+ /** Constant for specifying ModelType parameter: bernoulli model */
355351 final val Bernoulli : ModelType = new ModelType {
356352 override def toString : String = ModelType .BERNOULLI_STRING
357353 }
0 commit comments