Skip to content

Commit 66ce18c

Browse files
committed
some cleanups before sending to Xiangrui
1 parent 7431272 commit 66ce18c

File tree

2 files changed

+12
-14
lines changed

2 files changed

+12
-14
lines changed

python/pyspark/ml/param/__init__.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -243,12 +243,3 @@ def _copyValues(self, to, extra={}):
243243
if paramMap.has_key(p) and to.hasParam(p.name):
244244
to._set((p.name, paramMap[p]))
245245
return to
246-
247-
@staticmethod
248-
def _copyParamMap(paramMap, to):
249-
"""
250-
Create a copy of the given ParamMap, but with parameter
251-
:param paramMap:
252-
:param to:
253-
:return:
254-
"""

python/pyspark/ml/tuning.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -232,14 +232,21 @@ def fit(self, dataset, params={}):
232232
def copy(self, extra={}):
233233
"""
234234
Creates a copy of this instance with a randomly generated uid
235-
and some extra params. This copies the underlying estimator, creates a deep copy of the embedded paramMap, and
236-
copies the embedded and extra parameters over.
235+
and some extra params. This copies the underlying estimator,
236+
evaluator, and estimatorParamMap, creates a deep copy of the
237+
embedded paramMap, and copies the embedded and extra parameters
238+
over.
237239
:param extra: Extra parameters to copy to the new instance
238240
:return: Copy of this instance
239241
"""
240-
paramMap = self.extractParamMap(extra)
241-
stages = map(lambda stage: stage.copy(extra), paramMap[self.stages])
242-
return CrossValidator().setStages(stages)
242+
newCV = Params.copy(self, extra)
243+
if self.isSet(self.estimator):
244+
newCV.setEstimator(self.getEstimator().copy(extra))
245+
if self.isSet(self.estimatorParamMaps):
246+
newCV.setEstimatorParamMaps(self.getEstimatorParamMaps().MAGIC_COPY_TO_BE_IMPLEMENTED(extra)) # TODO
247+
if self.isSet(self.evaluator):
248+
newCV.setEvaluator(self.getEvaluator().copy(extra))
249+
return newCV
243250

244251

245252
class CrossValidatorModel(Model):

0 commit comments

Comments
 (0)