File tree Expand file tree Collapse file tree 2 files changed +6
-6
lines changed
main/scala/org/apache/spark/mllib/classification
test/scala/org/apache/spark/mllib/classification Expand file tree Collapse file tree 2 files changed +6
-6
lines changed Original file line number Diff line number Diff line change @@ -310,21 +310,21 @@ object NaiveBayes {
310310 *
311311 * The model type can be set to either Multinomial NB ([[http://tinyurl.com/lsdw6p ]])
312312 * or Bernoulli NB ([[http://tinyurl.com/p7c96j6 ]]). The Multinomial NB can handle
313- * discrete count data and can be called by setting the model type to "Multinomial ".
313+ * discrete count data and can be called by setting the model type to "multinomial ".
314314 * For example, it can be used with word counts or TF_IDF vectors of documents.
315315 * The Bernoulli model fits presence or absence (0-1) counts. By making every vector a
316- * 0-1 vector and setting the model type to "Bernoulli ", the fits and predicts as
316+ * 0-1 vector and setting the model type to "bernoulli ", the fits and predicts as
317317 * Bernoulli NB.
318318 *
319319 * @param input RDD of `(label, array of features)` pairs. Every vector should be a frequency
320320 * vector or a count vector.
321321 * @param lambda The smoothing parameter
322322 *
323323 * @param modelType The type of NB model to fit from the enumeration NaiveBayesModels, can be
324- * Multinomial or Bernoulli
324+ * multinomial or bernoulli
325325 */
326326 def train (input : RDD [LabeledPoint ], lambda : Double , modelType : String ): NaiveBayesModel = {
327- new NaiveBayes (lambda, Multinomial ).run(input)
327+ new NaiveBayes (lambda, MODELTYPE .fromString(modelType) ).run(input)
328328 }
329329
330330
Original file line number Diff line number Diff line change @@ -124,7 +124,7 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
124124 val testRDD = sc.parallelize(testData, 2 )
125125 testRDD.cache()
126126
127- val model = NaiveBayes .train(testRDD, 1.0 , " Multinomial " )
127+ val model = NaiveBayes .train(testRDD, 1.0 , " multinomial " )
128128 validateModelFit(pi, theta, model)
129129
130130 val validationData = NaiveBayesSuite .generateNaiveBayesInput(
@@ -161,7 +161,7 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
161161 val testRDD = sc.parallelize(testData, 2 )
162162 testRDD.cache()
163163
164- val model = NaiveBayes .train(testRDD, 1.0 , " Bernoulli " )
164+ val model = NaiveBayes .train(testRDD, 1.0 , " bernoulli " )
165165 validateModelFit(pi, theta, model)
166166
167167 val validationData = NaiveBayesSuite .generateNaiveBayesInput(
You can’t perform that action at this time.
0 commit comments