Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import scala.reflect.ClassTag

import net.razorvine.pickle._

import org.apache.spark.SparkContext
import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
import org.apache.spark.api.python.SerDeUtil
import org.apache.spark.mllib.classification._
Expand Down Expand Up @@ -641,6 +642,8 @@ private[python] class PythonMLLibAPI extends Serializable {
/**
 * Exposes the vectors of the wrapped `model` as Java collections.
 *
 * @return a Java map with the same String keys as `model.getVectors`, where
 *         each value has been converted to a Java list of Float
 */
def getVectors: JMap[String, JList[Float]] = {
  // Convert each value to a Java list first, then convert the map itself.
  val javaEntries = model.getVectors.map { case (key, vector) =>
    key -> vector.toList.asJava
  }
  javaEntries.asJava
}

/**
 * Saves the wrapped `model` by delegating to `model.save` with the same
 * arguments.
 *
 * @param sc   context forwarded unchanged to `model.save`
 * @param path destination forwarded unchanged to `model.save`
 */
def save(sc: SparkContext, path: String): Unit = {
  model.save(sc, path)
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this necessary to do? Isn't the wrapper written in JavaSaveable sufficient?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I see that _java_model is not a Word2VecModel but a wrapper around it. Sorry about this comment.

}

/**
Expand Down
21 changes: 20 additions & 1 deletion python/pyspark/mllib/feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from pyspark.mllib.linalg import (
Vector, Vectors, DenseVector, SparseVector, _convert_to_vector)
from pyspark.mllib.regression import LabeledPoint
from pyspark.mllib.util import JavaLoader, JavaSaveable

__all__ = ['Normalizer', 'StandardScalerModel', 'StandardScaler',
'HashingTF', 'IDFModel', 'IDF', 'Word2Vec', 'Word2VecModel',
Expand Down Expand Up @@ -416,7 +417,7 @@ def fit(self, dataset):
return IDFModel(jmodel)


class Word2VecModel(JavaVectorTransformer):
class Word2VecModel(JavaVectorTransformer, JavaSaveable, JavaLoader):
"""
class for Word2Vec model
"""
Expand Down Expand Up @@ -455,6 +456,12 @@ def getVectors(self):
"""
return self.call("getVectors")

@classmethod
def load(cls, sc, path):
    """
    Load a model from the given path.

    :param sc: object exposing ``_jvm`` and ``_jsc`` (presumably the active
        SparkContext -- confirm against callers).
    :param path: location handed unchanged to the JVM-side
        ``Word2VecModel.load``.
    :return: an instance of ``cls`` built around the loaded JVM model.
    """
    jmodel = sc._jvm.org.apache.spark.mllib.feature \
        .Word2VecModel.load(sc._jsc.sc(), path)
    # NOTE(review): this passes the object returned by the JVM
    # Word2VecModel.load directly; models produced elsewhere appear to hold
    # a JVM-side wrapper instead -- confirm every method invoked through
    # ``self.call`` also exists on the unwrapped model.
    # Instantiate through ``cls`` rather than the hard-coded class name so
    # that a subclass calling ``load`` gets an instance of its own type.
    return cls(jmodel)


@ignore_unicode_prefix
class Word2Vec(object):
Expand Down Expand Up @@ -488,6 +495,18 @@ class Word2Vec(object):
>>> syms = model.findSynonyms(vec, 2)
>>> [s[0] for s in syms]
[u'b', u'c']

>>> import os, tempfile
>>> path = tempfile.mkdtemp()
>>> model.save(sc, path)
>>> sameModel = Word2VecModel.load(sc, path)
>>> model.transform("a") == sameModel.transform("a")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mengxr has advised keeping doctests separate from correctness tests (#6499 (comment)).
It might be better to just call sameModel.transform("a") in the doctest, and to write a small separate test that checks model and sameModel give the same results, perhaps using getVectors.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What I meant was this test is not for correctness. I wanted to check whether a saved model works or not.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, but I think it would be better to just call
sameModel.transform("a") on its own and show its output.

However, I have no objections either way!

True
>>> from shutil import rmtree
>>> try:
... rmtree(path)
... except OSError:
... pass
"""
def __init__(self):
"""
Expand Down