@@ -101,6 +101,18 @@ def transform(self, dataset, params={}):
101101 raise NotImplementedError ()
102102
103103
104+ @inherit_doc
105+ class Model (Transformer ):
106+ """
107+ Abstract class for models fitted by :py:class:`Estimator`s.
108+ """
109+
110+ ___metaclass__ = ABCMeta
111+
112+ def __init__ (self ):
113+ super (Model , self ).__init__ ()
114+
115+
104116@inherit_doc
105117class Pipeline (Estimator ):
106118 """
@@ -169,7 +181,7 @@ def fit(self, dataset, params={}):
169181
170182
171183@inherit_doc
172- class PipelineModel (Transformer ):
184+ class PipelineModel (Model ):
173185 """
174186 Represents a compiled pipeline with transformers and fitted models.
175187 """
@@ -204,9 +216,9 @@ def _java_class(self):
204216 """
205217 raise NotImplementedError
206218
207- def _create_java_obj (self ):
219+ def _java_obj (self ):
208220 """
209- Creates a new Java object and returns its reference .
221+ Returns or creates a Java object.
210222 """
211223 java_obj = _jvm ()
212224 for name in self ._java_class .split ("." ):
@@ -231,6 +243,13 @@ def _empty_java_param_map(self):
231243 """
232244 return _jvm ().org .apache .spark .ml .param .ParamMap ()
233245
246+ def _create_java_param_map (self , params , java_obj ):
247+ paramMap = self ._empty_java_param_map ()
248+ for param , value in params .items ():
249+ if param .parent is self :
250+ paramMap .put (java_obj .getParam (param .name ), value )
251+ return paramMap
252+
234253
235254@inherit_doc
236255class JavaEstimator (Estimator , JavaWrapper ):
@@ -259,7 +278,7 @@ def _fit_java(self, dataset, params={}):
259278 :param params: additional params (overwriting embedded values)
260279 :return: fitted Java model
261280 """
262- java_obj = self ._create_java_obj ()
281+ java_obj = self ._java_obj ()
263282 self ._transfer_params_to_java (params , java_obj )
264283 return java_obj .fit (dataset ._jschema_rdd , self ._empty_java_param_map ())
265284
@@ -281,7 +300,24 @@ def __init__(self):
281300 super (JavaTransformer , self ).__init__ ()
282301
283302 def transform (self , dataset , params = {}):
284- java_obj = self ._create_java_obj ()
285- self ._transfer_params_to_java (params , java_obj )
286- return SchemaRDD (java_obj .transform (dataset ._jschema_rdd , self ._empty_java_param_map ()),
303+ java_obj = self ._java_obj ()
304+ self ._transfer_params_to_java ({}, java_obj )
305+ java_param_map = self ._create_java_param_map (params , java_obj )
306+ return SchemaRDD (java_obj .transform (dataset ._jschema_rdd , java_param_map ),
287307 dataset .sql_ctx )
308+
309+
310+ @inherit_doc
311+ class JavaModel (JavaTransformer ):
312+ """
313+ Base class for :py:class:`Model`s that wrap Java/Scala
314+ implementations.
315+ """
316+
317+ __metaclass__ = ABCMeta
318+
319+ def __init__ (self ):
320+ super (JavaTransformer , self ).__init__ ()
321+
322+ def _java_obj (self ):
323+ return self ._java_model
0 commit comments