-
Couldn't load subscription status.
- Fork 28.9k
[SPARK-7104][MLlib] Support model save/load in Python's Word2Vec #6821
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -36,6 +36,7 @@ | |
| from pyspark.mllib.linalg import ( | ||
| Vector, Vectors, DenseVector, SparseVector, _convert_to_vector) | ||
| from pyspark.mllib.regression import LabeledPoint | ||
| from pyspark.mllib.util import JavaLoader, JavaSaveable | ||
|
|
||
| __all__ = ['Normalizer', 'StandardScalerModel', 'StandardScaler', | ||
| 'HashingTF', 'IDFModel', 'IDF', 'Word2Vec', 'Word2VecModel', | ||
|
|
@@ -416,7 +417,7 @@ def fit(self, dataset): | |
| return IDFModel(jmodel) | ||
|
|
||
|
|
||
| class Word2VecModel(JavaVectorTransformer): | ||
| class Word2VecModel(JavaVectorTransformer, JavaSaveable, JavaLoader): | ||
| """ | ||
| class for Word2Vec model | ||
| """ | ||
|
|
@@ -455,6 +456,12 @@ def getVectors(self): | |
| """ | ||
| return self.call("getVectors") | ||
|
|
||
| @classmethod | ||
| def load(cls, sc, path): | ||
| jmodel = sc._jvm.org.apache.spark.mllib.feature \ | ||
| .Word2VecModel.load(sc._jsc.sc(), path) | ||
| return Word2VecModel(jmodel) | ||
|
|
||
|
|
||
| @ignore_unicode_prefix | ||
| class Word2Vec(object): | ||
|
|
@@ -488,6 +495,18 @@ class Word2Vec(object): | |
| >>> syms = model.findSynonyms(vec, 2) | ||
| >>> [s[0] for s in syms] | ||
| [u'b', u'c'] | ||
|
|
||
| >>> import os, tempfile | ||
| >>> path = tempfile.mkdtemp() | ||
| >>> model.save(sc, path) | ||
| >>> sameModel = Word2VecModel.load(sc, path) | ||
| >>> model.transform("a") == sameModel.transform("a") | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It has been advised by @mengxr to keep doctests and testing for correctness separately. (#6499 (comment)) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What I meant was this test is not for correctness. I wanted to check whether a saved model works or not. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sure, but I think it would be better to just write However, I have no objections either ! |
||
| True | ||
| >>> from shutil import rmtree | ||
| >>> try: | ||
| ... rmtree(path) | ||
| ... except OSError: | ||
| ... pass | ||
| """ | ||
| def __init__(self): | ||
| """ | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this necessary to do? Isn't the wrapper written in
JavaSaveablesufficient?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah I see that
_java_modelis notWord2VecModelbut a wrapper to it. Sorry about this comment.