@@ -46,19 +46,19 @@ class NaiveBayesModel private[mllib] (
4646    val  labels :  Array [Double ],
4747    val  pi :  Array [Double ],
4848    val  theta :  Array [Array [Double ]],
49-     val  modelType :  NaiveBayes . ModelType )
49+     val  modelType :  String )
5050  extends  ClassificationModel  with  Serializable  with  Saveable  {
5151
5252  def  this (labels : Array [Double ], pi : Array [Double ], theta : Array [Array [Double ]]) = 
53-     this (labels, pi, theta, NaiveBayes .Multinomial )
53+     this (labels, pi, theta, NaiveBayes .Multinomial .toString )
5454
5555  private  val  brzPi  =  new  BDV [Double ](pi)
5656  private  val  brzTheta  =  new  BDM (theta(0 ).length, theta.length, theta.flatten).t
5757
5858  //  Bernoulli scoring requires log(condprob) if 1 log(1-condprob) if 0
5959  //  this precomputes log(1.0 - exp(theta)) and its sum for linear algebra application
6060  //  of this condition in predict function
61-   private  val  (brzNegTheta, brzNegThetaSum) =  modelType match  {
61+   private  val  (brzNegTheta, brzNegThetaSum) =  NaiveBayes . ModelType .fromString( modelType)  match  {
6262    case  NaiveBayes .Multinomial  =>  (None , None )
6363    case  NaiveBayes .Bernoulli  => 
6464      val  negTheta  =  brzLog((brzExp(brzTheta.copy) :*=  (- 1.0 )) :+=  1.0 ) //  log(1.0 - exp(x))
@@ -74,7 +74,7 @@ class NaiveBayesModel private[mllib] (
7474  }
7575
7676  override  def  predict (testData : Vector ):  Double  =  {
77-     modelType match  {
77+     NaiveBayes . ModelType .fromString( modelType)  match  {
7878      case  NaiveBayes .Multinomial  => 
7979        labels (brzArgmax (brzPi +  brzTheta *  testData.toBreeze) )
8080      case  NaiveBayes .Bernoulli  => 
@@ -84,7 +84,7 @@ class NaiveBayesModel private[mllib] (
8484  }
8585
8686  override  def  save (sc : SparkContext , path : String ):  Unit  =  {
87-     val  data  =  NaiveBayesModel .SaveLoadV1_0 .Data (labels, pi, theta, modelType.toString )
87+     val  data  =  NaiveBayesModel .SaveLoadV1_0 .Data (labels, pi, theta, modelType)
8888    NaiveBayesModel .SaveLoadV1_0 .save(sc, path, data)
8989  }
9090
@@ -137,15 +137,15 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
137137      val  labels  =  data.getAs[Seq [Double ]](0 ).toArray
138138      val  pi  =  data.getAs[Seq [Double ]](1 ).toArray
139139      val  theta  =  data.getAs[Seq [Seq [Double ]]](2 ).map(_.toArray).toArray
140-       val  modelType  =  NaiveBayes .ModelType .fromString(data.getString(3 ))
140+       val  modelType  =  NaiveBayes .ModelType .fromString(data.getString(3 )).toString 
141141      new  NaiveBayesModel (labels, pi, theta, modelType)
142142    }
143143  }
144144
145145  override  def  load (sc : SparkContext , path : String ):  NaiveBayesModel  =  {
146-     def  getModelType (metadata : JValue ):  NaiveBayes . ModelType  =  {
146+     def  getModelType (metadata : JValue ):  String  =  {
147147      implicit  val  formats  =  DefaultFormats 
148-       NaiveBayes .ModelType .fromString((metadata \  " modelType"  ).extract[String ])
148+       NaiveBayes .ModelType .fromString((metadata \  " modelType"  ).extract[String ]).toString 
149149    }
150150    val  (loadedClassName, version, metadata) =  loadMetadata(sc, path)
151151    val  classNameV1_0  =  SaveLoadV1_0 .thisClassName
@@ -265,7 +265,7 @@ class NaiveBayes private (
265265      i +=  1 
266266    }
267267
268-     new  NaiveBayesModel (labels, pi, theta, modelType)
268+     new  NaiveBayesModel (labels, pi, theta, modelType.toString )
269269  }
270270}
271271
0 commit comments