| 
18 | 18 | package org.apache.spark.mllib.classification  | 
19 | 19 | 
 
  | 
20 | 20 | import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum, Axis}  | 
21 |  | -import org.apache.spark.mllib.classification.NaiveBayesModels.NaiveBayesModels  | 
 | 21 | +import breeze.numerics.{exp => brzExp, log => brzLog}  | 
22 | 22 | 
 
  | 
23 | 23 | import org.apache.spark.{SparkException, Logging}  | 
24 | 24 | import org.apache.spark.SparkContext._  | 
25 | 25 | import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector}  | 
26 | 26 | import org.apache.spark.mllib.regression.LabeledPoint  | 
 | 27 | +import org.apache.spark.mllib.classification.NaiveBayesModels.NaiveBayesModels  | 
27 | 28 | import org.apache.spark.rdd.RDD  | 
28 | 29 | 
 
  | 
29 | 30 | 
 
  | 
@@ -52,29 +53,14 @@ class NaiveBayesModel private[mllib] (  | 
52 | 53 |     val theta: Array[Array[Double]],  | 
53 | 54 |     val model: NaiveBayesModels) extends ClassificationModel with Serializable {  | 
54 | 55 | 
 
  | 
55 |  | -  def populateMatrix(arrayIn: Array[Array[Double]],  | 
56 |  | -                     matrixIn: BDM[Double],  | 
57 |  | -                     transformation: (Double) => Double = (x) => x) = {  | 
58 |  | -    var i = 0  | 
59 |  | -    while (i < arrayIn.length) {  | 
60 |  | -      var j = 0  | 
61 |  | -      while (j < arrayIn(i).length) {  | 
62 |  | -        matrixIn(i, j) = transformation(theta(i)(j))  | 
63 |  | -        j += 1  | 
64 |  | -      }  | 
65 |  | -      i += 1  | 
66 |  | -    }  | 
67 |  | -  }  | 
68 |  | - | 
69 | 56 |   private val brzPi = new BDV[Double](pi)  | 
70 |  | -  private val brzTheta = new BDM[Double](theta.length, theta(0).length)  | 
71 |  | -  populateMatrix(theta, brzTheta)  | 
 | 57 | +  private val brzTheta = new BDM(theta(0).length, theta.length, theta.flatten).t  | 
72 | 58 | 
 
  | 
73 | 59 |   private val brzNegTheta: Option[BDM[Double]] = model match {  | 
74 | 60 |     case NaiveBayesModels.Multinomial => None  | 
75 | 61 |     case NaiveBayesModels.Bernoulli =>  | 
76 |  | -      val negTheta = new BDM[Double](theta.length, theta(0).length)  | 
77 |  | -      populateMatrix(theta, negTheta, (x) => math.log(1.0 - math.exp(x)))  | 
 | 62 | +      val negTheta = brzLog((brzExp(brzTheta.copy) :*= (-1.0)) :+= 1.0)  | 
 | 63 | +      // Element-wise equivalent of (x) => math.log(1.0 - math.exp(x))  | 
78 | 64 |       Option(negTheta)  | 
79 | 65 |   }  | 
80 | 66 | 
 
  | 
@@ -244,7 +230,7 @@ object NaiveBayes {  | 
244 | 230 |    * @param model The type of NB model to fit, taken from the enumeration NaiveBayesModels:  | 
245 | 231 |    *              either Multinomial or Bernoulli.  | 
246 | 232 |    */  | 
247 |  | -  def train(input: RDD[LabeledPoint], lambda: Double, model: NaiveBayesModels): NaiveBayesModel = {  | 
248 |  | -    new NaiveBayes(lambda, model).run(input)  | 
 | 233 | +  def train(input: RDD[LabeledPoint], lambda: Double, model: String): NaiveBayesModel = {  | 
 | 234 | +    new NaiveBayes(lambda,  NaiveBayesModels.withName(model)).run(input)  | 
249 | 235 |   }  | 
250 | 236 | }  | 
0 commit comments