@@ -27,7 +27,7 @@ import org.json4s.JsonDSL._
2727import  org .json4s .jackson .JsonMethods ._ 
2828
2929import  org .apache .spark .{Logging , SparkContext , SparkException }
30- import  org .apache .spark .mllib .linalg .{BLAS , DenseVector , SparseVector , Vector }
30+ import  org .apache .spark .mllib .linalg .{BLAS , DenseMatrix ,  DenseVector , SparseVector , Vector ,  Vectors }
3131import  org .apache .spark .mllib .regression .LabeledPoint 
3232import  org .apache .spark .mllib .util .{Loader , Saveable }
3333import  org .apache .spark .rdd .RDD 
@@ -50,6 +50,9 @@ class NaiveBayesModel private[mllib] (
5050    val  modelType :  String )
5151  extends  ClassificationModel  with  Serializable  with  Saveable  {
5252
53+   val  piVector  =  Vectors .dense(pi).asInstanceOf [DenseVector ]
54+   val  thetaMatrix  =  new  DenseMatrix (labels.size, theta(0 ).size, theta.flatten, true )
55+ 
5356  private [mllib] def  this (labels : Array [Double ], pi : Array [Double ], theta : Array [Array [Double ]]) = 
5457    this (labels, pi, theta, " Multinomial"  )
5558
@@ -60,17 +63,18 @@ class NaiveBayesModel private[mllib] (
6063      theta : JIterable [JIterable [Double ]]) = 
6164    this (labels.asScala.toArray, pi.asScala.toArray, theta.asScala.toArray.map(_.asScala.toArray))
6265
63-   private  val  brzPi  =  new  BDV [Double ](pi)
64-   private  val  brzTheta  =  new  BDM (theta(0 ).length, theta.length, theta.flatten).t
65- 
6666  //  Bernoulli scoring requires log(condprob) if 1, log(1-condprob) if 0.
67-   //  This precomputes log(1.0 - exp(theta)) and its sum   which are used for the   linear algebra
67+   //  This precomputes log(1.0 - exp(theta)) and its sum which are used for the linear algebra
6868  //  application of this condition (in predict function).
69-   private  val  (brzNegTheta , brzNegThetaSum) =  modelType match  {
69+   private  val  (thetaMinusnegTheta , brzNegThetaSum) =  modelType match  {
7070    case  " Multinomial"   =>  (None , None )
7171    case  " Bernoulli"   => 
72-       val  negTheta  =  brzLog((brzExp(brzTheta.copy) :*=  (- 1.0 )) :+=  1.0 ) //  log(1.0 - exp(x))
73-       (Option (negTheta), Option (brzSum(negTheta, Axis ._1)))
72+       val  negTheta  =  thetaMatrix.map(value =>  math.log(1.0  -  math.exp(value)))
73+       val  ones  =  Vectors .dense(Array .fill(thetaMatrix.numCols){1.0 }).asInstanceOf [DenseVector ]
74+       val  thetaMinusnegTheta  =  thetaMatrix.map { value => 
75+         value -  math.log(1.0  -  math.exp(value))
76+       }
77+       (Option (thetaMinusnegTheta), Option (negTheta.multiply(ones)))
7478    case  _ => 
7579      //  This should never happen.
7680      throw  new  UnknownError (s " NaiveBayesModel was created with an unknown ModelType:  $modelType" )
@@ -85,17 +89,22 @@ class NaiveBayesModel private[mllib] (
8589  }
8690
8791  override  def  predict (testData : Vector ):  Double  =  {
88-     val  brzData  =  testData.toBreeze
8992    modelType match  {
9093      case  " Multinomial"   => 
91-         labels(brzArgmax(brzPi +  brzTheta *  brzData))
94+         val  prob  =  thetaMatrix.multiply(testData.toDense)
95+         BLAS .axpy(1.0 , piVector, prob)
96+         labels(prob.argmax)
9297      case  " Bernoulli"   => 
93-         if  (! brzData.forall(v =>  v ==  0.0  ||  v ==  1.0 )) {
94-           throw  new  SparkException (
95-             s " Bernoulli Naive Bayes requires 0 or 1 feature values but found  $testData. " )
98+         testData.foreachActive { (index, value) => 
99+           if  (value !=  0.0  &&  value !=  1.0 ) {
100+             throw  new  SparkException (
101+               s " Bernoulli Naive Bayes requires 0 or 1 feature values but found  $testData. " )
102+           }
96103        }
97-         labels(brzArgmax(brzPi + 
98-           (brzTheta -  brzNegTheta.get) *  brzData +  brzNegThetaSum.get))
104+         val  prob  =  thetaMinusnegTheta.get.multiply(testData.toDense)
105+         BLAS .axpy(1.0 , piVector, prob)
106+         BLAS .axpy(1.0 , brzNegThetaSum.get, prob)
107+         labels(prob.argmax)
99108      case  _ => 
100109        //  This should never happen.
101110        throw  new  UnknownError (s " NaiveBayesModel was created with an unknown ModelType:  $modelType" )
0 commit comments